Adding modeling.py file
Browse files
- modeling.py +12 -11
modeling.py
CHANGED
|
@@ -212,17 +212,18 @@ class MoLM(PreTrainedModel):
|
|
| 212 |
if idx.size(1) <= self.config.sequence_length
|
| 213 |
else idx[:, -self.config.sequence_length :]
|
| 214 |
)
|
| 215 |
-
# forward the model to get the logits for the index in the sequence
|
| 216 |
-
logits = self(idx_cond, date, get_logits=True).logits
|
| 217 |
-
# pluck the logits at the final step and scale by desired temperature
|
| 218 |
-
logits = logits[:, -1, :] / temperature
|
| 219 |
-
# optionally crop the logits to only the top k options
|
| 220 |
-
if top_k is not None:
|
| 221 |
-
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 222 |
-
logits[logits < v[:, [-1]]] = -float("Inf")
|
| 223 |
-
# apply softmax to convert logits to (normalized) probabilities
|
| 224 |
-
probs = F.softmax(logits, dim=-1)
|
| 225 |
-
# sample from the distribution
|
|
|
|
| 226 |
idx_next = torch.multinomial(probs, num_samples=1)
|
| 227 |
# append sampled index to the running sequence and continue
|
| 228 |
idx = torch.cat((idx, idx_next), dim=1)
|
|
|
|
| 212 |
if idx.size(1) <= self.config.sequence_length
|
| 213 |
else idx[:, -self.config.sequence_length :]
|
| 214 |
)
|
| 215 |
+
# # forward the model to get the logits for the index in the sequence
|
| 216 |
+
# logits = self(idx_cond, date, get_logits=True).logits
|
| 217 |
+
# # pluck the logits at the final step and scale by desired temperature
|
| 218 |
+
# logits = logits[:, -1, :] / temperature
|
| 219 |
+
# # optionally crop the logits to only the top k options
|
| 220 |
+
# if top_k is not None:
|
| 221 |
+
# v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
| 222 |
+
# logits[logits < v[:, [-1]]] = -float("Inf")
|
| 223 |
+
# # apply softmax to convert logits to (normalized) probabilities
|
| 224 |
+
# probs = F.softmax(logits, dim=-1)
|
| 225 |
+
# # sample from the distribution
|
| 226 |
+
probs = self(idx_cond, date=date).combined_log_probs[:, -1, :]
|
| 227 |
idx_next = torch.multinomial(probs, num_samples=1)
|
| 228 |
# append sampled index to the running sequence and continue
|
| 229 |
idx = torch.cat((idx, idx_next), dim=1)
|