robinfaro committed on
Commit
076a580
·
verified ·
1 Parent(s): bc0146f

Adding modeling.py file

Browse files
Files changed (1) hide show
  1. modeling.py +12 -11
modeling.py CHANGED
@@ -212,17 +212,18 @@ class MoLM(PreTrainedModel):
212
  if idx.size(1) <= self.config.sequence_length
213
  else idx[:, -self.config.sequence_length :]
214
  )
215
- # forward the model to get the logits for the index in the sequence
216
- logits = self(idx_cond, date, get_logits=True).logits
217
- # pluck the logits at the final step and scale by desired temperature
218
- logits = logits[:, -1, :] / temperature
219
- # optionally crop the logits to only the top k options
220
- if top_k is not None:
221
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
222
- logits[logits < v[:, [-1]]] = -float("Inf")
223
- # apply softmax to convert logits to (normalized) probabilities
224
- probs = F.softmax(logits, dim=-1)
225
- # sample from the distribution
 
226
  idx_next = torch.multinomial(probs, num_samples=1)
227
  # append sampled index to the running sequence and continue
228
  idx = torch.cat((idx, idx_next), dim=1)
 
212
  if idx.size(1) <= self.config.sequence_length
213
  else idx[:, -self.config.sequence_length :]
214
  )
215
+ # # forward the model to get the logits for the index in the sequence
216
+ # logits = self(idx_cond, date, get_logits=True).logits
217
+ # # pluck the logits at the final step and scale by desired temperature
218
+ # logits = logits[:, -1, :] / temperature
219
+ # # optionally crop the logits to only the top k options
220
+ # if top_k is not None:
221
+ # v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
222
+ # logits[logits < v[:, [-1]]] = -float("Inf")
223
+ # # apply softmax to convert logits to (normalized) probabilities
224
+ # probs = F.softmax(logits, dim=-1)
225
+ # # sample from the distribution
226
+ probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
227
  idx_next = torch.multinomial(probs, num_samples=1)
228
  # append sampled index to the running sequence and continue
229
  idx = torch.cat((idx, idx_next), dim=1)