Commit ·
384a482
1
Parent(s): c96ed72
"."
Browse files
model.py
CHANGED
|
@@ -164,8 +164,8 @@ class GPT(PreTrainedModel):
|
|
| 164 |
return {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 165 |
|
| 166 |
@torch.no_grad()
|
| 167 |
-
def generate(self, input_ids,
|
| 168 |
-
for _ in range(
|
| 169 |
idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
|
| 170 |
out = self(idx_cond)
|
| 171 |
logits = out['logits'][:, -1, :] / temperature
|
|
|
|
| 164 |
return {"input_ids": input_ids, "past_key_values": past_key_values}
|
| 165 |
|
| 166 |
@torch.no_grad()
|
| 167 |
+
def generate(self, input_ids, max_length, temperature=1.0, top_k=None, attention_mask=None):
|
| 168 |
+
for _ in range(max_length - input_ids.size(1)):
|
| 169 |
idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]
|
| 170 |
out = self(idx_cond)
|
| 171 |
logits = out['logits'][:, -1, :] / temperature
|