robinfaro committed on
Commit
1fc04cf
·
verified ·
1 Parent(s): 30191af

Adding modeling.py file

Browse files
Files changed (1) hide show
  1. modeling.py +13 -2
modeling.py CHANGED
@@ -223,8 +223,19 @@ class MoLM(PreTrainedModel):
223
  # # apply softmax to convert logits to (normalized) probabilities
224
  # probs = F.softmax(logits, dim=-1)
225
  # # sample from the distribution
226
- probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
227
- idx_next = torch.multinomial(probs, num_samples=1)
 
 
 
 
 
 
 
 
 
 
 
228
  # append sampled index to the running sequence and continue
229
  idx = torch.cat((idx, idx_next), dim=1)
230
  # check if we hit the end of the sequence
 
223
  # # apply softmax to convert logits to (normalized) probabilities
224
  # probs = F.softmax(logits, dim=-1)
225
  # # sample from the distribution
226
+ log_probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
227
+ #idx_next = torch.multinomial(probs, num_samples=1)
228
+ # Sample from the log probabilities
229
+ if temperature == 0:
230
+ # If temperature is 0, take the argmax (greedy sampling)
231
+ idx_next = torch.argmax(log_probs, dim=-1, keepdim=True)
232
+ else:
233
+ # Apply temperature scaling
234
+ scaled_log_probs = log_probs / temperature
235
+ # Convert log probabilities to probabilities
236
+ probs = torch.exp(scaled_log_probs)
237
+ # Sample from the distribution
238
+ idx_next = torch.multinomial(probs, num_samples=1)
239
  # append sampled index to the running sequence and continue
240
  idx = torch.cat((idx, idx_next), dim=1)
241
  # check if we hit the end of the sequence