Exquisique commited on
Commit
384a482
·
1 Parent(s): c96ed72
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -164,8 +164,8 @@ class GPT(PreTrainedModel):
164
  return {"input_ids": input_ids, "past_key_values": past_key_values}
165
 
166
  @torch.no_grad()
167
- def generate(self, input_ids, max_new_tokens, temperature=1.0, top_k=None, attention_mask=None):
168
- for _ in range(max_new_tokens):
169
  idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
170
  out = self(idx_cond)
171
  logits = out['logits'][:, -1, :] / temperature
 
164
  return {"input_ids": input_ids, "past_key_values": past_key_values}
165
 
166
  @torch.no_grad()
167
+ def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None):
168
+ for _ in range(max_length - input_ids.size(1)):
169
  idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
170
  out = self(idx_cond)
171
  logits = out['logits'][:, -1, :] / temperature