dill-dev commited on
Commit
fecde63
Β·
verified Β·
1 Parent(s): 192968d

Update modeling_momo.py

Browse files
Files changed (1) hide show
  1. modeling_momo.py +12 -44
modeling_momo.py CHANGED
@@ -10,10 +10,6 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
10
  from .configuration_momo import MomoConfig
11
 
12
 
13
- # ════════════════════════════════════════════════════════════════
14
- # COMPONENTS
15
- # ════════════════════════════════════════════════════════════════
16
-
17
  class RMSNorm(nn.Module):
18
  def __init__(self, dim, eps=1e-5):
19
  super().__init__()
@@ -57,10 +53,6 @@ def apply_rope(q, k, cos, sin):
57
  return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)
58
 
59
 
60
- # ════════════════════════════════════════════════════════════════
61
- # ATTENTION β€” Grouped Query Attention (GQA)
62
- # ════════════════════════════════════════════════════════════════
63
-
64
  class MomoAttention(nn.Module):
65
  def __init__(self, cfg: MomoConfig):
66
  super().__init__()
@@ -112,10 +104,6 @@ class MomoAttention(nn.Module):
112
  return self.o(out), pres
113
 
114
 
115
- # ════════════════════════════════════════════════════════════════
116
- # FEED-FORWARD β€” SwiGLU
117
- # ════════════════════════════════════════════════════════════════
118
-
119
  class MomoFFN(nn.Module):
120
  def __init__(self, cfg: MomoConfig):
121
  super().__init__()
@@ -127,10 +115,6 @@ class MomoFFN(nn.Module):
127
  return self.down(F.silu(self.gate(x)) * self.up(x))
128
 
129
 
130
- # ════════════════════════════════════════════════════════════════
131
- # TRANSFORMER BLOCK
132
- # ════════════════════════════════════════════════════════════════
133
-
134
  class MomoBlock(nn.Module):
135
  def __init__(self, cfg: MomoConfig):
136
  super().__init__()
@@ -146,50 +130,37 @@ class MomoBlock(nn.Module):
146
  return x, p
147
 
148
 
149
- # ════════════════════════════════════════════════════════════════
150
- # 🌸 MOMO FOR CAUSAL LM
151
- # ════════════════════════════════════════════════════════════════
152
-
153
  class MomoForCausalLM(PreTrainedModel):
154
  config_class = MomoConfig
155
  _no_split_modules = ["MomoBlock"]
156
  _tied_weights_keys = ["lm_head.weight"]
 
 
157
 
158
  def __init__(self, cfg: MomoConfig):
159
  super().__init__(cfg)
160
- self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
161
- self.layers = nn.ModuleList([MomoBlock(cfg) for _ in range(cfg.num_hidden_layers)])
162
- self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
163
- # lm_head weight is tied to embed β€” do NOT pre-tie here,
164
- # HF will call tie_weights() after loading the state dict
165
- self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
166
  self.grad_ckpt = cfg.use_gradient_checkpointing
167
  self.apply(self._init_weights)
168
 
169
- # ── Required by HF 4.40+ ────────────────────────────────────
170
- @property
171
- def all_tied_weights_keys(self):
172
- # Must return a dict: {weight_to_tie: source_weight}
173
- return {"lm_head.weight": "embed.weight"}
174
-
175
- def tie_weights(self, missing_keys=None, recompute_mapping=False, **kwargs):
176
- self.lm_head.weight = self.embed.weight
177
-
178
- # ── Embedding accessors (needed by HF tie_weights logic) ─────
179
  def get_input_embeddings(self):
180
  return self.embed
181
 
182
  def set_input_embeddings(self, value):
183
  self.embed = value
184
- self.tie_weights()
185
 
186
  def get_output_embeddings(self):
187
  return self.lm_head
188
 
189
- def set_output_embeddings(self, new_embeddings):
190
- self.lm_head = new_embeddings
191
 
192
- # ── Weight init ──────────────────────────────────────────────
193
  def _init_weights(self, m):
194
  if isinstance(m, nn.Linear):
195
  nn.init.normal_(m.weight, std=0.02)
@@ -198,7 +169,6 @@ class MomoForCausalLM(PreTrainedModel):
198
  elif isinstance(m, nn.Embedding):
199
  nn.init.normal_(m.weight, std=0.02)
200
 
201
- # ── Forward ──────────────────────────────────────────────────
202
  def forward(
203
  self,
204
  input_ids=None,
@@ -216,7 +186,7 @@ class MomoForCausalLM(PreTrainedModel):
216
  if self.grad_ckpt and self.training:
217
  def _fn(layer):
218
  def fn(x):
219
- out, _ = layer(x, mask=attention_mask, use_cache=False)
220
  return out
221
  return fn
222
  x = torch.utils.checkpoint.checkpoint(
@@ -244,7 +214,6 @@ class MomoForCausalLM(PreTrainedModel):
244
  past_key_values=cache if use_cache else None,
245
  )
246
 
247
- # ── Generate ─────────────────────────────────────────────────
248
  @torch.no_grad()
249
  def generate(
250
  self,
@@ -268,7 +237,6 @@ class MomoForCausalLM(PreTrainedModel):
268
  past = out.past_key_values
269
  logits = out.logits[:, -1, :].float()
270
 
271
- # Repetition penalty
272
  if rep_penalty != 1.0:
273
  for tok in set(gen[0].tolist()):
274
  if logits[0, tok] > 0:
 
10
  from .configuration_momo import MomoConfig
11
 
12
 
 
 
 
 
13
  class RMSNorm(nn.Module):
14
  def __init__(self, dim, eps=1e-5):
15
  super().__init__()
 
53
  return (q * cos) + (rot_half(q) * sin), (k * cos) + (rot_half(k) * sin)
54
 
55
 
 
 
 
 
56
  class MomoAttention(nn.Module):
57
  def __init__(self, cfg: MomoConfig):
58
  super().__init__()
 
104
  return self.o(out), pres
105
 
106
 
 
 
 
 
107
  class MomoFFN(nn.Module):
108
  def __init__(self, cfg: MomoConfig):
109
  super().__init__()
 
115
  return self.down(F.silu(self.gate(x)) * self.up(x))
116
 
117
 
 
 
 
 
118
  class MomoBlock(nn.Module):
119
  def __init__(self, cfg: MomoConfig):
120
  super().__init__()
 
130
  return x, p
131
 
132
 
 
 
 
 
133
  class MomoForCausalLM(PreTrainedModel):
134
  config_class = MomoConfig
135
  _no_split_modules = ["MomoBlock"]
136
  _tied_weights_keys = ["lm_head.weight"]
137
+ # HF 4.40+ calls model.all_tied_weights_keys.keys() β€” must be a dict on the instance
138
+ all_tied_weights_keys = {"lm_head.weight": "embed.weight"}
139
 
140
  def __init__(self, cfg: MomoConfig):
141
  super().__init__(cfg)
142
+ self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
143
+ self.layers = nn.ModuleList([MomoBlock(cfg) for _ in range(cfg.num_hidden_layers)])
144
+ self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
145
+ self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
146
+ # Tie weights now β€” HF post-load also calls get_output_embeddings to re-tie
147
+ self.lm_head.weight = self.embed.weight
148
  self.grad_ckpt = cfg.use_gradient_checkpointing
149
  self.apply(self._init_weights)
150
 
151
+ # HF calls these to re-tie after loading β€” must be defined
 
 
 
 
 
 
 
 
 
152
  def get_input_embeddings(self):
153
  return self.embed
154
 
155
  def set_input_embeddings(self, value):
156
  self.embed = value
 
157
 
158
  def get_output_embeddings(self):
159
  return self.lm_head
160
 
161
+ def set_output_embeddings(self, value):
162
+ self.lm_head = value
163
 
 
164
  def _init_weights(self, m):
165
  if isinstance(m, nn.Linear):
166
  nn.init.normal_(m.weight, std=0.02)
 
169
  elif isinstance(m, nn.Embedding):
170
  nn.init.normal_(m.weight, std=0.02)
171
 
 
172
  def forward(
173
  self,
174
  input_ids=None,
 
186
  if self.grad_ckpt and self.training:
187
  def _fn(layer):
188
  def fn(x):
189
+ out, _ = layer(x, mask=None, use_cache=False)
190
  return out
191
  return fn
192
  x = torch.utils.checkpoint.checkpoint(
 
214
  past_key_values=cache if use_cache else None,
215
  )
216
 
 
217
  @torch.no_grad()
218
  def generate(
219
  self,
 
237
  past = out.past_key_values
238
  logits = out.logits[:, -1, :].float()
239
 
 
240
  if rep_penalty != 1.0:
241
  for tok in set(gen[0].tolist()):
242
  if logits[0, tok] > 0: