Update modeling_gator.py
Browse files- modeling_gator.py +4 -7
modeling_gator.py
CHANGED
|
@@ -100,13 +100,10 @@ class GatorForCausalLM(PreTrainedModel):
|
|
| 100 |
self.gator = GatorModel(config)
|
| 101 |
self.post_init()
|
| 102 |
|
| 103 |
-
def forward(self, input_ids):
|
| 104 |
-
return self.gator(input_ids)["logits"]
|
| 105 |
-
|
| 106 |
@torch.no_grad()
|
| 107 |
-
def
|
| 108 |
-
logits = self.
|
| 109 |
topk = torch.topk(logits, k=top_k, dim=-1)
|
| 110 |
probs = torch.softmax(topk.values, dim=-1)
|
| 111 |
-
next_token = topk.indices
|
| 112 |
-
return next_token.item()
|
|
|
|
| 100 |
self.gator = GatorModel(config)
|
| 101 |
self.post_init()
|
| 102 |
|
|
|
|
|
|
|
|
|
|
| 103 |
@torch.no_grad()
|
| 104 |
+
def forward(self, input_ids, temperature=0.8, top_k=5):
|
| 105 |
+
logits = self.gator(input_ids)["logits"][:, -1, :] / temperature
|
| 106 |
topk = torch.topk(logits, k=top_k, dim=-1)
|
| 107 |
probs = torch.softmax(topk.values, dim=-1)
|
| 108 |
+
next_token = topk.indices.gather(-1, torch.multinomial(probs, 1))
|
| 109 |
+
return next_token.squeeze().item()
|