Adding modeling.py file
Browse files- modeling.py +13 -2
modeling.py
CHANGED
|
@@ -223,8 +223,19 @@ class MoLM(PreTrainedModel):
|
|
| 223 |
# # apply softmax to convert logits to (normalized) probabilities
|
| 224 |
# probs = F.softmax(logits, dim=-1)
|
| 225 |
# # sample from the distribution
|
| 226 |
-
|
| 227 |
-
idx_next = torch.multinomial(probs, num_samples=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# append sampled index to the running sequence and continue
|
| 229 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 230 |
# check if we hit the end of the sequence
|
|
|
|
| 223 |
# # apply softmax to convert logits to (normalized) probabilities
|
| 224 |
# probs = F.softmax(logits, dim=-1)
|
| 225 |
# # sample from the distribution
|
| 226 |
+
log_probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
|
| 227 |
+
#idx_next = torch.multinomial(probs, num_samples=1)
|
| 228 |
+
# Sample from the log probabilities
|
| 229 |
+
if temperature == 0:
|
| 230 |
+
# If temperature is 0, take the argmax (greedy sampling)
|
| 231 |
+
idx_next = torch.argmax(log_probs, dim=-1, keepdim=True)
|
| 232 |
+
else:
|
| 233 |
+
# Apply temperature scaling
|
| 234 |
+
scaled_log_probs = log_probs / temperature
|
| 235 |
+
# Convert log probabilities to probabilities
|
| 236 |
+
probs = torch.exp(scaled_log_probs)
|
| 237 |
+
# Sample from the distribution
|
| 238 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
| 239 |
# append sampled index to the running sequence and continue
|
| 240 |
idx = torch.cat((idx, idx_next), dim=1)
|
| 241 |
# check if we hit the end of the sequence
|