dill-dev committed on
Commit
ed64422
·
verified ·
1 Parent(s): 2f08f53

Create modeling_momo.py

Browse files
Files changed (1) hide show
  1. modeling_momo.py +284 -0
modeling_momo.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_momo.py
2
+ # 🌸 Momo-336M — HuggingFace compatible model definition
3
+ # Upload this file to your HF repo alongside config.json and configuration_momo.py
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+
11
+ from .configuration_momo import MomoConfig
12
+
13
+
14
+ # ════════════════════════════════════════════════════════════════
15
+ # COMPONENTS
16
+ # ════════════════════════════════════════════════════════════════
17
+
18
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistics are computed in float32; the normalised activations are
    cast back to the input dtype before the per-channel ``weight`` is applied.
    """

    def __init__(self, dim, eps=1e-5):
        """Create the layer.

        Args:
            dim: Size of the last (normalised) dimension.
            eps: Constant added to the mean square before the reciprocal
                square root, to keep the division finite.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` scaled by the inverse RMS of its last axis, times ``weight``."""
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = (x32 * inv_rms).to(x.dtype)
        return normalised * self.weight
27
+
28
+
29
class RotaryEmbedding(nn.Module):
    """Lookup table of rotary position embedding cosines and sines.

    Buffers:
        inv_freq: ``dim / 2`` inverse frequencies ``theta ** (-2i / dim)``.
        cos_c, sin_c: Tables of shape ``(1, 1, n, dim)`` where ``n`` is the
            number of cached positions; each row holds the angles twice
            (``[freq, freq]``) so it lines up with the half-split rotation.
    """

    def __init__(self, dim, max_seq=512, theta=10000.0):
        """Build the frequency vector and pre-compute ``max_seq`` positions."""
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (theta ** exponents))
        self._cache(max_seq)

    def _cache(self, n):
        """(Re)build the cos/sin tables for positions ``0 .. n - 1``."""
        positions = torch.arange(n, device=self.inv_freq.device).float()
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        for name, values in (('cos_c', table.cos()), ('sin_c', table.sin())):
            # Two leading singleton axes so the table broadcasts against
            # (batch, heads, seq, head_dim) tensors.
            self.register_buffer(name, values[None, None])

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions in ``x``'s dtype.

        The tables are regrown on demand when ``seq_len`` exceeds what is cached.
        ``x`` is only consulted for its dtype.
        """
        if self.cos_c.shape[2] < seq_len:
            self._cache(seq_len)
        cos = self.cos_c[:, :, :seq_len].to(x.dtype)
        sin = self.sin_c[:, :, :seq_len].to(x.dtype)
        return cos, sin
50
+
51
+
52
def rot_half(x):
    """Split the last axis in two halves ``(a, b)`` and return ``(-b, a)``."""
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
55
+
56
+
57
def apply_rope(q, k, cos, sin):
    """Rotate ``q`` and ``k`` by the angles encoded in ``cos`` / ``sin``.

    Each tensor ``t`` becomes ``t * cos + rotate(t) * sin`` where ``rotate``
    maps the two halves ``(a, b)`` of the last axis to ``(-b, a)``.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """

    def _rotate(t):
        lower, upper = t.chunk(2, dim=-1)
        return torch.cat([-upper, lower], dim=-1)

    q_rotated = (q * cos) + (_rotate(q) * sin)
    k_rotated = (k * cos) + (_rotate(k) * sin)
    return q_rotated, k_rotated
59
+
60
+
61
+ # ════════════════════════════════════════════════════════════════
62
+ # ATTENTION — Grouped Query Attention (GQA)
63
+ # ════════════════════════════════════════════════════════════════
64
+
65
class MomoAttention(nn.Module):
    """Causal self-attention with rotary positions and shared key/value heads.

    ``num_attention_heads`` query heads attend over ``num_key_value_heads``
    key/value heads; each K/V head is reused by ``grp = nh // nkv`` query heads.
    All projections are bias-free.
    """

    def __init__(self, cfg: MomoConfig):
        """Build the projections from ``cfg``.

        Reads ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
        ``max_position_embeddings`` and ``rope_theta``.
        """
        super().__init__()
        self.nh = cfg.num_attention_heads
        self.nkv = cfg.num_key_value_heads
        self.hd = cfg.hidden_size // cfg.num_attention_heads
        self.grp = self.nh // self.nkv
        self.sc = self.hd ** -0.5  # 1 / sqrt(head_dim) score scaling
        H = cfg.hidden_size
        self.q = nn.Linear(H, self.nh * self.hd, bias=False)
        self.k = nn.Linear(H, self.nkv * self.hd, bias=False)
        self.v = nn.Linear(H, self.nkv * self.hd, bias=False)
        self.o = nn.Linear(self.nh * self.hd, H, bias=False)
        self.rope = RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)

    def forward(self, x, mask=None, past=None, use_cache=False):
        """Attend over ``x`` (and any cached keys/values).

        Args:
            x: Hidden states of shape ``(B, T, hidden_size)``.
            mask: Optional tensor that is *added* to the attention scores of
                shape ``(B, nh, T, S)``; it must broadcast against that shape.
            past: Optional ``(key, value)`` pair from a previous call, each of
                shape ``(B, nh, past_len, hd)`` (already expanded to ``nh`` heads).
            use_cache: When true, also return the updated ``(key, value)`` pair.

        Returns:
            ``(output, present)`` where ``output`` is ``(B, T, hidden_size)`` and
            ``present`` is the concatenated ``(key, value)`` pair or ``None``.
        """
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v(x).view(B, T, self.nkv, self.hd).transpose(1, 2)

        # New tokens continue the position count where the cache left off.
        past_len = past[0].shape[2] if past is not None else 0
        cos, sin = self.rope(q, past_len + T)
        cos = cos[:, :, past_len:past_len + T]
        sin = sin[:, :, past_len:past_len + T]
        q, k = apply_rope(q, k, cos, sin)

        if self.grp > 1:
            # Tile the K/V heads grp times along a new leading head axis, so
            # query head h reads K/V head (h % nkv). This layout is kept
            # exactly as-is: changing it would change which K/V head every
            # query head uses.
            k = k[:, None].expand(-1, self.grp, -1, -1, -1).reshape(B, self.nh, T, self.hd)
            v = v[:, None].expand(-1, self.grp, -1, -1, -1).reshape(B, self.nh, T, self.hd)

        if past is not None:
            pk, pv = past
            k = torch.cat([pk, k], 2)
            v = torch.cat([pv, v], 2)

        pres = (k, v) if use_cache else None
        S = k.shape[2]
        a = torch.matmul(q, k.transpose(-2, -1)) * self.sc

        # Query i (absolute position S - T + i) may see keys 0 .. S - T + i.
        # The mask is built in the score dtype so it does not silently promote
        # half-precision scores to float32.
        causal = torch.triu(
            torch.full((T, S), float('-inf'), device=x.device, dtype=a.dtype),
            diagonal=S - T + 1,
        )
        a = a + causal
        if mask is not None:
            a = a + mask

        # Softmax in float32 for stability, then cast back to the value dtype:
        # torch.matmul requires both operands to share a dtype, so leaving the
        # probabilities in float32 breaks fp16/bf16 weights.
        a = F.softmax(a, dim=-1, dtype=torch.float32).to(v.dtype)
        out = torch.matmul(a, v).transpose(1, 2).reshape(B, T, -1)
        return self.o(out), pres
114
+
115
+
116
+ # ════════════════════════════════════════════════════════════════
117
+ # FEED-FORWARD — SwiGLU
118
+ # ════════════════════════════════════════════════════════════════
119
+
120
class MomoFFN(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, no biases."""

    def __init__(self, cfg: MomoConfig):
        """Create the three projections from ``cfg.hidden_size`` / ``cfg.intermediate_size``."""
        super().__init__()
        model_dim, inner_dim = cfg.hidden_size, cfg.intermediate_size
        self.gate = nn.Linear(model_dim, inner_dim, bias=False)
        self.up = nn.Linear(model_dim, inner_dim, bias=False)
        self.down = nn.Linear(inner_dim, model_dim, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to the model width."""
        gate_act = F.silu(self.gate(x))
        gated = gate_act * self.up(x)
        return self.down(gated)
129
+
130
+
131
+ # ════════════════════════════════════════════════════════════════
132
+ # TRANSFORMER BLOCK
133
+ # ════════════════════════════════════════════════════════════════
134
+
135
class MomoBlock(nn.Module):
    """One pre-norm transformer layer: attention then feed-forward, each with a residual."""

    def __init__(self, cfg: MomoConfig):
        """Instantiate the attention, feed-forward and two RMSNorm sub-modules."""
        super().__init__()
        self.attn = MomoAttention(cfg)
        self.ffn = MomoFFN(cfg)
        self.norm1 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.norm2 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)

    def forward(self, x, mask=None, past=None, use_cache=False):
        """Run the layer.

        Returns:
            ``(hidden, present)`` — the updated hidden states and whatever
            cache entry the attention module returned.
        """
        attn_out, present = self.attn(self.norm1(x), mask, past, use_cache)
        hidden = x + attn_out
        ffn_out = self.ffn(self.norm2(hidden))
        hidden = hidden + ffn_out
        return hidden, present
148
+
149
+
150
+ # ════════════════════════════════════════════════════════════════
151
+ # 🌸 MOMO FOR CAUSAL LM — Main Model
152
+ # ════════════════════════════════════════════════════════════════
153
+
154
class MomoForCausalLM(PreTrainedModel):
    """Decoder-only language model: embedding, a stack of MomoBlocks, final norm, LM head.

    The LM head shares its weight tensor with the input embedding.
    """

    config_class = MomoConfig
    _no_split_modules = ['MomoBlock']

    def __init__(self, cfg: MomoConfig):
        """Build the model from ``cfg`` and initialise all Linear/Embedding weights."""
        super().__init__(cfg)
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([MomoBlock(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight  # weight tying
        self.grad_ckpt = cfg.use_gradient_checkpointing
        # NOTE(review): initialisation is done with a plain ``apply`` rather
        # than ``PreTrainedModel.post_init()``; confirm this is intended for
        # the transformers version this repo targets.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero any Linear bias."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    @staticmethod
    def _to_additive_mask(attention_mask, dtype):
        """Turn a 2-D 0/1 (or bool) padding mask into an additive attention bias.

        A ``(B, S)`` integer/bool mask becomes a ``(B, 1, 1, S)`` tensor that is
        0 where the mask is non-zero and the most negative finite value of
        ``dtype`` elsewhere. A finite value is used instead of ``-inf`` so a row
        whose visible keys are all padding does not produce NaN in softmax.
        Any other mask (``None``, floating point, or not 2-D) is returned
        unchanged and is assumed to already be additive.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() != 2 or attention_mask.is_floating_point():
            return attention_mask
        keep = attention_mask[:, None, None, :].to(torch.bool)
        bias = torch.zeros(keep.shape, dtype=dtype, device=attention_mask.device)
        return bias.masked_fill(~keep, torch.finfo(dtype).min)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=False,
        **kwargs,
    ):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids of shape ``(B, T)``. Required.
            attention_mask: Either a ``(B, S)`` integer/bool padding mask
                (non-zero = attend), which is converted to an additive bias,
                or a floating-point tensor that is added to the attention
                scores as-is. ``S`` is the cached length plus ``T``.
            labels: Optional target ids of shape ``(B, T)``; the loss compares
                logits at position ``i`` with ``labels[i + 1]`` and ignores -100.
            past_key_values: Optional per-layer list of ``(key, value)`` pairs
                returned by a previous call with ``use_cache=True``.
            use_cache: When true, return the updated per-layer cache.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``past_key_values`` (``None`` unless ``use_cache``).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")

        x = self.embed(input_ids)
        mask = self._to_additive_mask(attention_mask, x.dtype)
        pkvs = past_key_values or [None] * len(self.layers)
        cache = []

        use_ckpt = self.grad_ckpt and self.training
        if use_ckpt:
            # Imported explicitly: ``import torch`` alone is not guaranteed to
            # expose the ``torch.utils.checkpoint`` submodule.
            from torch.utils.checkpoint import checkpoint

        for layer, past in zip(self.layers, pkvs):
            if use_ckpt:
                # Bind ``layer`` as a default so the recomputation during the
                # backward pass uses this layer, not the loop's last one.
                def run_layer(h, layer=layer):
                    out, _ = layer(h, mask=mask, use_cache=False)
                    return out

                x = checkpoint(run_layer, x, use_reentrant=False)
                cache.append(None)
            else:
                x, p = layer(x, mask, past, use_cache)
                cache.append(p)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits[..., :-1, :].contiguous().view(-1, logits.size(-1)),
                labels[..., 1:].contiguous().view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache if use_cache else None,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=300,
        temperature=0.75,
        top_k=50,
        top_p=0.92,
        rep_penalty=1.1,
        eos_token_id=None,
        pad_token_id=None,
        **kwargs,
    ):
        """Sample a continuation of ``input_ids`` with a key/value cache.

        Puts the model in eval mode as a side effect.

        Args:
            input_ids: Prompt ids of shape ``(B, T)``.
            max_new_tokens: Upper bound on the number of sampled tokens.
            temperature: Logit divisor (clamped to at least 1e-6).
            top_k: Keep only the ``top_k`` largest logits; ``0`` disables.
            top_p: Nucleus threshold; ``1.0`` or more disables.
            rep_penalty: Penalty applied, per sequence, to every token already
                present in that sequence; ``1.0`` disables.
            eos_token_id: When set, a sequence is finished once it samples this
                id, and generation stops when every sequence has finished.
            pad_token_id: Id written to already-finished sequences while the
                others keep generating; falls back to ``eos_token_id``.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, T + n)`` holding the prompt followed by the
            ``n <= max_new_tokens`` generated tokens.
        """
        self.eval()
        gen = input_ids.clone()
        past = None

        track_eos = eos_token_id is not None
        fill_id = eos_token_id if pad_token_id is None else pad_token_id
        done = torch.zeros(gen.shape[0], dtype=torch.bool, device=gen.device)

        for _ in range(max_new_tokens):
            # After the first step only the newest token is fed; the rest
            # lives in the cache.
            inp = gen if past is None else gen[:, -1:]
            out = self(inp, use_cache=True, past_key_values=past)
            past = out.past_key_values
            logits = out.logits[:, -1, :].float()

            # Repetition penalty, applied row-wise to every sequence in the
            # batch: shrink positive logits, push negative ones further down.
            if rep_penalty != 1.0:
                seen = torch.gather(logits, 1, gen)
                seen = torch.where(seen > 0, seen / rep_penalty, seen * rep_penalty)
                logits.scatter_(1, gen, seen)

            logits = logits / max(temperature, 1e-6)

            # Top-k
            if top_k > 0:
                kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth[:, -1:]] = float('-inf')

            # Top-p (nucleus): drop a token when the probability mass strictly
            # before it already exceeds top_p, so the top token always stays.
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                sorted_logits[mass_before > top_p] = float('-inf')
                logits.scatter_(1, sorted_idx, sorted_logits)

            next_tok = torch.multinomial(F.softmax(logits, dim=-1), 1)

            if track_eos:
                # Sequences that already ended only receive filler tokens.
                next_tok = next_tok.masked_fill(done[:, None], fill_id)
                done = done | (next_tok[:, 0] == eos_token_id)

            gen = torch.cat([gen, next_tok], dim=1)

            if track_eos and bool(done.all()):
                break

        return gen