klemenk committed
Commit 192bd1d · verified · 1 Parent(s): abddcc6

Update modeling_auristream.py

Files changed (1)
  1. modeling_auristream.py +64 -29
modeling_auristream.py CHANGED
@@ -72,73 +72,108 @@ class AuriStream(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
 
-    def forward(self, seq, tgt=None, output_logits=False, output_hidden_states=False, return_dict=False, up_until_layer=None):
+    def forward(
+        self,
+        seq,
+        tgt=None,
+        output_logits=False,
+        output_hidden_states=False,
+        return_dict=False,
+        up_until_layer=None,
+        normalize_embeddings=None,
+    ):
         """
-        Input: coch: torch.Tensor of shape (b, t)
-               tgt_coch: torch.Tensor of shape (b, t) or None
+        Input: seq: torch.Tensor of shape (b, t)
+               tgt: torch.Tensor of shape (b, t) or None
+
+        Behavior (unchanged unless normalize_embeddings is set and output_hidden_states=True):
+        - When normalize_embeddings is None: identical to prior behavior.
+        - When normalize_embeddings in {'l2','learned'} and output_hidden_states=True:
+          the list returned in `hidden_states` is normalized per request.
+          (logits/loss computation remains unchanged.)
         """
-
+
         # forward the GPT model itself
-        tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd)
-
-        # if wpe exists in self.transformer apply leanred positional embedding
+        tok_emb = self.transformer.wte(seq) # (b, t, n_embd)
+
+        # learned positional embeddings if present
         if hasattr(self.transformer, 'wpe'):
             pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
-            pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
+            pos_emb = self.transformer.wpe(pos) # (t, n_embd)
             x = self.transformer.drop(tok_emb + pos_emb)
         else:
             x = self.transformer.drop(tok_emb)
-
+
         all_hidden_states = []
         for block_idx, block in enumerate(self.transformer.h):
-            # Forward the block
+            # capture pre-block hidden state
             all_hidden_states.append(x)
             if up_until_layer is not None and block_idx == up_until_layer:
                 break
             x = block(x)
-
-        # append the last hidden state if we did not exit early
+
+        # append final pre-ln_f state if we did not exit early
         if up_until_layer is None or block_idx == len(self.transformer.h) - 1:
             all_hidden_states.append(x)
-
+
+        # optional normalization of hidden states for returning
+        hs_to_return = all_hidden_states
+        if output_hidden_states and normalize_embeddings is not None:
+            if normalize_embeddings == 'l2':
+                hs_to_return = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
+            elif normalize_embeddings == 'learned':
+                hs_to_return = []
+                L = len(self.transformer.h)
+                for i, h in enumerate(all_hidden_states):
+                    if i < L:
+                        # input emb -> block0.norm1, block0 out -> block1.norm1, ...
+                        hs_to_return.append(self.transformer.h[i].norm1(h))
+                    else:
+                        # final layer -> transformer.ln_f
+                        hs_to_return.append(self.transformer.ln_f(h))
+            else:
+                # any other value behaves like None (no normalization)
+                hs_to_return = all_hidden_states
+
+        # if only hidden states are requested (and not logits), return here
         if output_hidden_states and not output_logits:
             model_output = BaseModelOutput(
-                last_hidden_state=x,
-                hidden_states=all_hidden_states,
+                last_hidden_state=x, # unchanged (pre-ln_f), to preserve original behavior
+                hidden_states=hs_to_return, # possibly normalized per the new option
             )
             return model_output
-
+
+        # standard logits path (unchanged)
         x = self.transformer.ln_f(x)
         logits = self.coch_head(x)
-
+
         if tgt is not None:
-
             if output_logits:
                 all_logits = [logits]
-
+
             loss = F.cross_entropy(
                 logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
             )
-
-            # If we have more than one future head, compute the loss for each head
+
+            # future multi-step heads (unchanged)
            if self.future_heads is not None:
                 for i, head in enumerate(self.future_heads):
-                    future_logits = head(x[:, :-(i+1)])
+                    future_logits = head(x[:, :-(i + 1)])
                     loss += F.cross_entropy(
-                        future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
+                        future_logits.reshape(-1, self.config.vocab_size),
+                        tgt[:, (i + 1):].reshape(-1),
                     )
                     if output_logits:
                         all_logits.append(future_logits)
-            # divide loss by number of future heads
             loss = loss / (len(self.future_heads) + 1)
-
+
             if return_dict:
                 if output_logits:
                     if output_hidden_states:
                         model_output = CausalLMOutput(
                             loss=loss,
                             logits=all_logits,
-                            hidden_states=all_hidden_states,
+                            hidden_states=hs_to_return,
                         )
                     else:
                         model_output = CausalLMOutput(
@@ -150,7 +185,7 @@ class AuriStream(PreTrainedModel):
                         model_output = CausalLMOutput(
                             loss=loss,
                             logits=logits,
-                            hidden_states=all_hidden_states,
+                            hidden_states=hs_to_return,
                         )
                     else:
                         model_output = CausalLMOutput(
@@ -158,9 +193,9 @@ class AuriStream(PreTrainedModel):
                             logits=logits,
                         )
                 return model_output
-
+
            return logits, loss
-
+
         return logits, None
 
     def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
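
As a quick illustration of what this commit changes, here is a small, self-contained sketch of how the new normalize_embeddings option transforms the returned hidden_states list. It is an assumption-laden sketch, not model code: the tensor sizes are arbitrary, and plain LayerNorm modules stand in for each block's norm1 and for ln_f.

# Sketch only: shapes and LayerNorm stand-ins are illustrative, not taken from the model.
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, time, n_embd, n_layers = 2, 5, 16, 4

# Stand-in for all_hidden_states: the embedding output plus one entry per block.
all_hidden_states = [torch.randn(batch, time, n_embd) for _ in range(n_layers + 1)]

# normalize_embeddings='l2': every hidden vector is rescaled to unit L2 norm.
l2_states = [F.normalize(h, p=2, dim=-1) for h in all_hidden_states]
assert all(torch.allclose(h.norm(dim=-1), torch.ones(batch, time), atol=1e-5) for h in l2_states)

# normalize_embeddings='learned': state i (the input to block i) goes through block i's
# norm1, and the final state goes through the final ln_f, matching the diff's indexing.
block_norm1 = [nn.LayerNorm(n_embd) for _ in range(n_layers)]  # stand-ins for transformer.h[i].norm1
ln_f = nn.LayerNorm(n_embd)                                     # stand-in for transformer.ln_f
learned_states = [
    block_norm1[i](h) if i < n_layers else ln_f(h)
    for i, h in enumerate(all_hidden_states)
]

Against the model itself, the same behavior would presumably be requested with a call along the lines of model(seq, output_hidden_states=True, normalize_embeddings='l2') (or 'learned'); per the diff, the logits and loss computation is unchanged in either case.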