if001 committed on
Commit
178a1fe
·
verified ·
1 Parent(s): 48941ed

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ResidualNetV2ForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_residualnet_v2.ResidualNetV2Config",
7
+ "AutoModel": "modeling_residualnet_v2.ResidualNetV2Model",
8
+ "AutoModelForCausalLM": "modeling_residualnet_v2.ResidualNetV2ForCausalLM"
9
+ },
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": null,
12
+ "embd_pdrop": 0.0,
13
+ "eos_token_id": 151645,
14
+ "hidden_act": "silu",
15
+ "hidden_size": 128,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 64,
18
+ "max_position_embeddings": 1024,
19
+ "model_type": "ResidualNetV2Config",
20
+ "name": "residual-v2-tiny",
21
+ "num_attention_heads": 4,
22
+ "num_hidden_layers": 4,
23
+ "num_key_value_heads": 4,
24
+ "original_max_position_embeddings": 1024,
25
+ "pad_token_id": 151645,
26
+ "resid_pdrop": 0.0,
27
+ "rms_norm_eps": 1e-05,
28
+ "rope_scaling": null,
29
+ "rope_theta": 10000.0,
30
+ "sliding_window": null,
31
+ "tie_word_embeddings": false,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": "4.48.2",
34
+ "use_cache": true,
35
+ "vocab_size": 151669
36
+ }
configuration_residualnet_v2.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ from transformers.models.phi3.configuration_phi3 import Phi3Config
3
class ResidualNetV2Config(Phi3Config):
    """Phi3Config subclass registered under a custom model_type.

    Inherits every hyperparameter from Phi3Config unchanged; only the
    model_type string differs, so config.json's auto_map can dispatch
    AutoConfig to this project.
    """
    model_type = "ResidualNetV2Config"
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:32f3e4a3d58b407bd75a163cb45e0b46d42dc6fb22bf84b905cf35f4b9156c29
3
+ size 81442712
modeling_residualnet_v2.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ from typing import Optional, List, Tuple
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from transformers.models.phi3.configuration_phi3 import Phi3Config
8
+ from transformers.models.phi3.modeling_phi3 import (
9
+ Phi3PreTrainedModel,
10
+ Phi3RMSNorm,
11
+ Phi3MLP,
12
+ Phi3Attention,
13
+ # Phi3SdpaAttention,
14
+ Phi3RotaryEmbedding,
15
+ )
16
+ #from models.phi3_config import Phi3Config
17
+ #from models.phi3 import (
18
+ # Phi3PreTrainedModel,
19
+ # Phi3RMSNorm,
20
+ # Phi3MLP,
21
+ # # Phi3SdpaAttention,
22
+ # Phi3Attention,
23
+ # Phi3RotaryEmbedding,
24
+ #)
25
+
26
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
27
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
28
+ from transformers.generation.utils import GenerationMixin
29
+
30
+
31
+ class ResidualNetV2Config(Phi3Config):
32
+ model_type = "ResidualNetV2Config"
33
+ def __init__(self, **kwargs):
34
+ super().__init__(**kwargs)
35
+
36
+
37
+ # ==============
38
+ # helpers (差分/マスク, 形状ユーティリティ, optional RoPE)
39
+ # ==============
40
+
41
def first_order_diff(x: torch.Tensor) -> torch.Tensor:
    """Forward first-order difference along the sequence axis.

    (B, L, H) -> (B, L-1, H): y_t = x_{t+1} - x_t.
    """
    later, earlier = x[:, 1:, :], x[:, :-1, :]
    return later - earlier
44
+
45
def second_order_diff(x: torch.Tensor) -> torch.Tensor:
    """Forward second-order difference along the sequence axis.

    (B, L, H) -> (B, L-2, H): y_t = x_{t+2} - 2 * x_{t+1} + x_t.
    """
    ahead2 = x[:, 2:, :]
    ahead1 = x[:, 1:-1, :]
    base = x[:, :-2, :]
    return ahead2 - 2.0 * ahead1 + base
48
+
49
def mask_and(*tensors: torch.Tensor) -> torch.Tensor:
    """Elementwise logical AND of masks (float or bool accepted).

    The result is cast back to the dtype of the first mask so that
    {0, 1}-valued float masks stay float.
    """
    acc = tensors[0].bool()
    for extra in tensors[1:]:
        acc = acc & extra.bool()
    return acc.to(tensors[0].dtype)
55
+
56
def build_mask_for_diff(mask2d: Optional[torch.Tensor], order: int) -> Optional[torch.Tensor]:
    """Shrink a (B, L) validity mask to match an order-0/1/2 difference.

    A position of an order-k difference is valid only when all k+1
    contributing positions of the original sequence are valid.
    Returns None when no mask was given; raises ValueError for order > 2.
    """
    if mask2d is None:
        return None
    if order == 0:
        return mask2d
    if order == 1:
        return mask_and(mask2d[:, 1:], mask2d[:, :-1])
    if order == 2:
        return mask_and(mask2d[:, 2:], mask2d[:, 1:-1], mask2d[:, :-2])
    raise ValueError("order must be 0,1,2")
67
+
68
def shape_qkv(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Split a packed projection into per-head layout.

    (B, L, H) -> (B, num_heads, L, H // num_heads).
    """
    batch, seqlen, hidden = x.shape
    per_head = hidden // num_heads
    split = x.view(batch, seqlen, num_heads, per_head)
    return split.transpose(1, 2)
74
+
75
def unshape_ctx(x: torch.Tensor) -> torch.Tensor:
    """Merge per-head attention output back into a packed hidden state.

    (B, num_heads, L, head_dim) -> (B, L, num_heads * head_dim).
    """
    batch, heads, seqlen, depth = x.shape
    merged = x.transpose(1, 2).contiguous()
    return merged.view(batch, seqlen, heads * depth)
79
+
80
+ # RoPE helper(必要な場合のみ使用)
81
def rotate_half(x):
    """Rotate the last dimension's halves: (a, b) -> (-b, a), as used by RoPE."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
84
+
85
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to q and k.

    Only the leading `cos.shape[-1]` channels are rotated (partial-rotary);
    the remaining channels pass through unchanged. `unsqueeze_dim` selects
    the axis along which cos/sin are broadcast over the head dimension.
    `position_ids` is accepted for API compatibility but unused.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_embed = torch.cat([q_rot * cos + rotate_half(q_rot) * sin, q_pass], dim=-1)
    k_embed = torch.cat([k_rot * cos + rotate_half(k_rot) * sin, k_pass], dim=-1)
    return q_embed, k_embed
94
+
95
+
96
+ # ==============
97
+ # Self-Block (Phi-3 部品そのまま)
98
+ # ==============
99
+
100
class Phi3SelfBlock(nn.Module):
    """One Phi-3-style transformer layer built from stock Phi-3 parts.

    Flow: PreNorm -> Self-Attn (RoPE) -> residual -> PreNorm -> MLP -> residual.
    Runs without a KV cache and rebuilds position ids per call, so each branch
    can feed it sequences of different lengths.
    """
    def __init__(self, config: ResidualNetV2Config, layer_idx: int, rotary_emb: Phi3RotaryEmbedding):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.in_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # self.attn = Phi3SdpaAttention(config, layer_idx=layer_idx)
        self.attn = Phi3Attention(config, layer_idx=layer_idx)
        self.dropout_attn = nn.Dropout(config.resid_pdrop)
        self.ff_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = Phi3MLP(config)
        self.dropout_mlp = nn.Dropout(config.resid_pdrop)
        # Rotary-embedding module shared across blocks (owned by the parent model).
        self.rotary_emb = rotary_emb

    def _prepare_4d_mask(self, mask2d, bsz, seqlen, hidden_states):
        # Expand a (B, L) padding mask into the 4-D causal mask Phi3Attention expects.
        if mask2d is None:
            return None
        return _prepare_4d_causal_attention_mask(
            mask2d, (bsz, seqlen), hidden_states, past_key_values_length=0, sliding_window=self.config.sliding_window
        )

    def forward(
        self,
        hidden_states: torch.Tensor,          # (B, L, H)
        mask2d: Optional[torch.Tensor],       # (B, L) padding mask, 1 = valid
        position_ids: Optional[torch.Tensor], # unused here; positions are rebuilt per call
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return (hidden_states, attn_weights-or-None) after one attention+MLP layer."""
        x = self.in_norm(hidden_states)
        B, L, _ = x.shape
        # Fresh 0..L-1 positions per call, so the block is length-agnostic.
        pos = torch.arange(L, device=x.device).unsqueeze(0).expand(B, -1)
        attn_mask = self._prepare_4d_mask(mask2d, B, L, x)
        position_embeddings = self.rotary_emb(hidden_states, pos)

        attn_out, attn_weights = self.attn(
            hidden_states=x,
            attention_mask=attn_mask,
            position_ids=pos,
            position_embeddings=position_embeddings,
            past_key_value=None,
            output_attentions=output_attentions,
            use_cache=False,
        )
        # Residual is taken from the pre-norm input (pre-norm architecture).
        x = hidden_states + self.dropout_attn(attn_out)

        h = self.ff_norm(x)
        h = self.mlp(h)
        x = x + self.dropout_mlp(h)
        return x, (attn_weights if output_attentions else None)
152
+
153
+
154
+ # ==============
155
+ # Cross-Attention (シンプル実装/GQA対応。既定で RoPE 無し)
156
+ # ==============
157
+
158
class SimpleCrossAttention(nn.Module):
    """Cross-attention from a query stream onto a key/value stream (GQA-aware).

    Query: x_q (B, Lq, H)   Key/Value: x_kv (B, Lk, H)
    - num_heads / num_key_value_heads follow the Phi-3 config (GQA supported).
    - Default: no RoPE (relative positions across the two streams are
      ill-defined); pass use_rope_in_cross_attn=True to enable it.
    """
    def __init__(self, config: ResidualNetV2Config, rotary_emb: Phi3RotaryEmbedding, use_rope_in_cross_attn: bool = False):
        super().__init__()
        self.config = config
        self.rotary_emb = rotary_emb
        self.use_rope = use_rope_in_cross_attn

        H = config.hidden_size
        nH = config.num_attention_heads
        nKV = getattr(config, "num_key_value_heads", nH)
        self.nH = nH
        self.nKV = nKV
        self.groups = nH // nKV  # query heads served by each KV head
        self.head_dim = H // nH

        self.q_proj = nn.Linear(H, H, bias=False)
        self.k_proj = nn.Linear(H, nKV * self.head_dim, bias=False)
        self.v_proj = nn.Linear(H, nKV * self.head_dim, bias=False)
        self.o_proj = nn.Linear(H, H, bias=False)
        self.dropout = nn.Dropout(config.attention_dropout)

        # Extra input normalisation for stability.
        self.q_norm = Phi3RMSNorm(H, eps=config.rms_norm_eps)
        self.kv_norm = Phi3RMSNorm(H, eps=config.rms_norm_eps)

    def _kv_repeat(self, x: torch.Tensor) -> torch.Tensor:
        # Repeat KV heads to match query heads: (B, nKV, Lk, d) -> (B, nH, Lk, d).
        if self.nKV == self.nH:
            return x
        return x.repeat_interleave(self.groups, dim=1)

    def _make_attn_mask_bool(self, enc_mask2d: Optional[torch.Tensor], Lq: int, Lk: int, B: int) -> Optional[torch.Tensor]:
        # enc_mask2d: (B, Lk) in {0,1} -> broadcastable bool mask of shape (B, 1, Lq, Lk).
        # FIX: F.scaled_dot_product_attention's boolean mask uses True = "attend".
        # The previous code inverted the mask (True at padded positions), which made
        # attention look ONLY at padding whenever a mask was supplied.
        if enc_mask2d is None:
            return None
        keep = enc_mask2d.bool().unsqueeze(1).unsqueeze(2)  # True = valid key position
        return keep.expand(B, 1, Lq, Lk)

    def forward(
        self,
        x_q: torch.Tensor,                          # (B, Lq, H)
        x_kv: torch.Tensor,                         # (B, Lk, H)
        enc_mask2d: Optional[torch.Tensor] = None,  # (B, Lk)
    ) -> torch.Tensor:
        """Return the cross-attention context (B, Lq, H) of x_q over x_kv."""
        B, Lq, H = x_q.shape
        Lk = x_kv.size(1)

        q = self.q_proj(self.q_norm(x_q))    # (B, Lq, H)
        k = self.k_proj(self.kv_norm(x_kv))  # (B, Lk, nKV*Hd)
        v = self.v_proj(self.kv_norm(x_kv))  # (B, Lk, nKV*Hd)

        q = shape_qkv(q, self.nH)                                   # (B, nH, Lq, Hd)
        k = k.view(B, Lk, self.nKV, self.head_dim).transpose(1, 2)  # (B, nKV, Lk, Hd)
        v = v.view(B, Lk, self.nKV, self.head_dim).transpose(1, 2)  # (B, nKV, Lk, Hd)
        k = self._kv_repeat(k)  # (B, nH, Lk, Hd)
        v = self._kv_repeat(v)  # (B, nH, Lk, Hd)

        if self.use_rope:
            # Phi3RotaryEmbedding.forward is assumed to map (x, position_ids) -> (cos, sin);
            # guarded by try/except because that API differs across transformers versions.
            pos_q = torch.arange(Lq, device=x_q.device).unsqueeze(0).expand(B, -1)
            pos_k = torch.arange(Lk, device=x_q.device).unsqueeze(0).expand(B, -1)
            try:
                dummy_q = torch.zeros(B, Lq, self.head_dim, device=x_q.device, dtype=x_q.dtype)
                dummy_k = torch.zeros(B, Lk, self.head_dim, device=x_q.device, dtype=x_q.dtype)
                cos_q, sin_q = self.rotary_emb(dummy_q, pos_q)
                cos_k, sin_k = self.rotary_emb(dummy_k, pos_k)
                # FIX: rotate q and k with their OWN cos/sin tables. The old call
                # apply_rotary_pos_emb(q, k, cos_q, sin_k, unsqueeze_dim=2) paired the
                # q-cos with the k-sin and unsqueezed the wrong axis, so it raised a
                # shape error whenever Lq != Lk and was always swallowed by the except
                # below — RoPE never actually applied. unsqueeze_dim=1 broadcasts the
                # (B, 1, L, d) tables over the head axis.
                q, _ = apply_rotary_pos_emb(q, q, cos_q, sin_q, position_ids=None, unsqueeze_dim=1)
                k, _ = apply_rotary_pos_emb(k, k, cos_k, sin_k, position_ids=None, unsqueeze_dim=1)
            except Exception:
                # Best-effort: skip RoPE if the rotary module has a different API
                # (the self-attention branches still carry relative-position signal).
                pass

        # scaled dot-product attention (True = attend in the boolean mask)
        attn_mask = self._make_attn_mask_bool(enc_mask2d, Lq, Lk, B)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout.p if self.training else 0.0, is_causal=False
        )  # (B, nH, Lq, Hd)

        y = unshape_ctx(y)  # (B, Lq, H)
        y = self.o_proj(y)
        return y
246
+
247
+
248
+ # ==============
249
+ # v2: 3本の枝��「各3層」通してから、0階に Cross-Attn(←1階) → Cross-Attn(←2階)
250
+ # ==============
251
+
252
class ResidualNetV2Model(Phi3PreTrainedModel):
    """Three-branch difference encoder (late fusion).

    1) embedding -> three branches: order-0 / order-1 / order-2 differences
    2) each branch: [Phi3SelfBlock] x 3 (three layers within the same branch)
    3) x0 + CrossAttn(x1_final) residual, then x0 + CrossAttn(x2_final) residual
    The output is x0 at the original sequence length L.
    """
    def __init__(self, config: ResidualNetV2Config, use_rope_in_cross_attn: bool = False):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.rotary_emb = Phi3RotaryEmbedding(config=config)
        self.norm_out = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Three layers per branch; layer_idx offsets (0/100/200) keep indices
        # distinct across branches.
        self.branch0 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=i, rotary_emb=self.rotary_emb) for i in range(3)])
        self.branch1 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=100+i, rotary_emb=self.rotary_emb) for i in range(3)])
        self.branch2 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=200+i, rotary_emb=self.rotary_emb) for i in range(3)])

        # Two cross-attentions: order-0 <- order-1, order-0 <- order-2.
        # NOTE(review): "cross12_norm" actually normalises the query of cross02
        # (the 0 <- 2 path); the name looks like a leftover, but it is kept as-is
        # because renaming the attribute would break loading of saved checkpoints.
        self.cross01_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross01 = SimpleCrossAttention(config, self.rotary_emb, use_rope_in_cross_attn)
        self.cross12_norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.cross02 = SimpleCrossAttention(config, self.rotary_emb, use_rope_in_cross_attn)

        self.dropout = nn.Dropout(config.resid_pdrop)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,  # (B, L), 1 = valid
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BaseModelOutput:
        """Encode input_ids (or inputs_embeds) and return the fused order-0 stream."""
        output_attentions = False if output_attentions is None else output_attentions
        output_hidden_states = False if output_hidden_states is None else output_hidden_states
        return_dict = True if return_dict is None else return_dict

        if inputs_embeds is None:
            x0 = self.embed_tokens(input_ids)  # (B, L, H), order-0 branch
        else:
            x0 = inputs_embeds
        B, L, H = x0.shape

        m0 = attention_mask if attention_mask is not None else torch.ones(B, L, device=x0.device, dtype=torch.long)
        # First/second-order differences (shorter by 1 / 2 positions).
        x1 = first_order_diff(x0)   # (B, L-1, H)
        x2 = second_order_diff(x0)  # (B, L-2, H)
        m1 = build_mask_for_diff(m0, 1)
        m2 = build_mask_for_diff(m0, 2)

        # Three self-attention layers per branch, run independently.
        for blk in self.branch0:
            x0, _ = blk(x0, m0, None, output_attentions=False)
        for blk in self.branch1:
            x1, _ = blk(x1, m1, None, output_attentions=False)
        for blk in self.branch2:
            x2, _ = blk(x2, m2, None, output_attentions=False)

        # CrossAttn: x0 <- x1
        x0 = x0 + self.dropout(self.cross01(self.cross01_norm(x0), x1, enc_mask2d=m1))
        # CrossAttn: x0 <- x2
        x0 = x0 + self.dropout(self.cross02(self.cross12_norm(x0), x2, enc_mask2d=m2))

        x_out = self.norm_out(x0)

        if not return_dict:
            return (x_out,)

        return BaseModelOutput(
            last_hidden_state=x_out,
            hidden_states=None,
            attentions=None,
        )
331
+
332
+
333
class ResidualNetV2ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of ResidualNetV2Model.

    NOTE(review): the lm_head weight is tied to the input embedding here,
    while the shipped config.json sets tie_word_embeddings=false — confirm
    which is intended.
    """
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ResidualNetV2Config, use_rope_in_cross_attn: bool = False):
        super().__init__(config)
        self.model = ResidualNetV2Model(config, use_rope_in_cross_attn=use_rope_in_cross_attn)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # weight tying
        self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone, project to vocabulary logits, and (optionally)
        compute the shifted next-token cross-entropy loss."""
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=kwargs.get("output_attentions", False),
            output_hidden_states=kwargs.get("output_hidden_states", False),
            return_dict=True,
        )
        # .float() keeps the loss numerically stable regardless of model dtype.
        logits = self.lm_head(out.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None
        )

    @property
    def base_model(self):
        # Expose the backbone under the attribute name downstream tooling expects.
        return self.model
377
+
378
+
379
+ # ==============
380
+ # v3: 「各枝1層 + x0<-x1 Cross + x0<-x2 Cross」を1ブロックとして **3回** 反復
381
+ # ==============
382
+
383
class ResidualNetV3Model(Phi3PreTrainedModel):
    """Three-branch difference encoder, v3 (early fusion + iterative refinement).

    One block = { three branches: one Phi3SelfBlock each -> x0 <- x1 cross ->
    x0 <- x2 cross }; this block is repeated 3 times.
    """
    def __init__(self, config: ResidualNetV2Config, use_rope_in_cross_attn: bool = False):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.rotary_emb = Phi3RotaryEmbedding(config=config)
        self.norm_out = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.resid_pdrop)

        # Layers for the 3 blocks (each branch plus 2 cross-attentions per block);
        # layer_idx offsets (10/110/210) keep indices distinct across branches.
        self.blocks_branch0 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=10+i, rotary_emb=self.rotary_emb) for i in range(3)])
        self.blocks_branch1 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=110+i, rotary_emb=self.rotary_emb) for i in range(3)])
        self.blocks_branch2 = nn.ModuleList([Phi3SelfBlock(config, layer_idx=210+i, rotary_emb=self.rotary_emb) for i in range(3)])

        self.cross_norm_01 = nn.ModuleList([Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(3)])
        self.cross_01 = nn.ModuleList([SimpleCrossAttention(config, self.rotary_emb, use_rope_in_cross_attn) for _ in range(3)])

        self.cross_norm_02 = nn.ModuleList([Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(3)])
        self.cross_02 = nn.ModuleList([SimpleCrossAttention(config, self.rotary_emb, use_rope_in_cross_attn) for _ in range(3)])

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> BaseModelOutput:
        """Encode input_ids (or inputs_embeds) through 3 refinement blocks."""
        output_attentions = False if output_attentions is None else output_attentions
        output_hidden_states = False if output_hidden_states is None else output_hidden_states
        return_dict = True if return_dict is None else return_dict

        if inputs_embeds is None:
            x0 = self.embed_tokens(input_ids)
        else:
            x0 = inputs_embeds
        B, L, H = x0.shape
        m0 = attention_mask if attention_mask is not None else torch.ones(B, L, device=x0.device, dtype=torch.long)

        # Initial first/second-order differences are taken from x0.
        def mk_x1x2(x0, m0):
            return first_order_diff(x0), second_order_diff(x0), build_mask_for_diff(m0, 1), build_mask_for_diff(m0, 2)

        x1, x2, m1, m2 = mk_x1x2(x0, m0)

        # Three refinement blocks.
        for i in range(3):
            # One self-attention layer per branch.
            x0, _ = self.blocks_branch0[i](x0, m0, None, output_attentions=False)
            x1, _ = self.blocks_branch1[i](x1, m1, None, output_attentions=False)
            x2, _ = self.blocks_branch2[i](x2, m2, None, output_attentions=False)

            # Cross: x0 <- x1, then x0 <- x2.
            x0 = x0 + self.dropout(self.cross_01[i](self.cross_norm_01[i](x0), x1, enc_mask2d=m1))
            x0 = x0 + self.dropout(self.cross_02[i](self.cross_norm_02[i](x0), x2, enc_mask2d=m2))

            # Alternative design: re-derive the 1st/2nd-order differences from the
            # *updated* x0 for the next block (early fusion -> re-decomposition).
            # Here the branches continue as their own layer chains instead.
            # To enable re-decomposition, uncomment:
            # x1, x2, m1, m2 = mk_x1x2(x0, m0)

        x_out = self.norm_out(x0)

        if not return_dict:
            return (x_out,)

        return BaseModelOutput(
            last_hidden_state=x_out,
            hidden_states=None,
            attentions=None,
        )
464
+
465
+
466
class ResidualNetV3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
    """Causal-LM head on top of ResidualNetV3Model (same contract as the v2 head).

    NOTE(review): ties lm_head to the input embedding while the shipped
    config.json sets tie_word_embeddings=false — confirm which is intended.
    """
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: ResidualNetV2Config, use_rope_in_cross_attn: bool = False):
        super().__init__(config)
        self.model = ResidualNetV3Model(config, use_rope_in_cross_attn=use_rope_in_cross_attn)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # weight tying with the input embedding
        self.lm_head.weight = self.model.embed_tokens.weight
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the backbone, project to vocabulary logits, and (optionally)
        compute the shifted next-token cross-entropy loss."""
        out = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=kwargs.get("output_attentions", False),
            output_hidden_states=kwargs.get("output_hidden_states", False),
            return_dict=True,
        )
        # .float() keeps the loss numerically stable regardless of model dtype.
        logits = self.lm_head(out.last_hidden_state).float()

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return CausalLMOutputWithPast(
            loss=loss, logits=logits, past_key_values=None, hidden_states=None, attentions=None
        )

    @property
    def base_model(self):
        # Expose the backbone under the attribute name downstream tooling expects.
        return self.model
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff