kunjcr2 commited on
Commit
a616980
·
verified ·
1 Parent(s): fc1be13

Update modeling_gator.py

Browse files
Files changed (1) hide show
  1. modeling_gator.py +4 -7
modeling_gator.py CHANGED
@@ -100,13 +100,10 @@ class GatorForCausalLM(PreTrainedModel):
100
  self.gator = GatorModel(config)
101
  self.post_init()
102
 
103
- def forward(self, input_ids):
104
- return self.gator(input_ids)["logits"]
105
-
106
  @torch.no_grad()
107
- def generate_one(self, input_ids, temperature=0.8, top_k=5):
108
- logits = self.forward(input_ids)[:, -1, :] / temperature
109
  topk = torch.topk(logits, k=top_k, dim=-1)
110
  probs = torch.softmax(topk.values, dim=-1)
111
- next_token = topk.indices[torch.multinomial(probs, num_samples=1)]
112
- return next_token.item()
 
100
  self.gator = GatorModel(config)
101
  self.post_init()
102
 
 
 
 
103
  @torch.no_grad()
104
+ def forward(self, input_ids, temperature=0.8, top_k=5):
105
+ logits = self.gator(input_ids)["logits"][:, -1, :] / temperature
106
  topk = torch.topk(logits, k=top_k, dim=-1)
107
  probs = torch.softmax(topk.values, dim=-1)
108
+ next_token = topk.indices.gather(-1, torch.multinomial(probs, 1))
109
+ return next_token.squeeze().item()