Update inference.py
Browse files: inference.py (+1 -1)
inference.py
CHANGED
|
@@ -198,7 +198,7 @@ def main() -> None:
|
|
| 198 |
with torch.no_grad():
|
| 199 |
for _ in range(int(args.max_new_tokens)):
|
| 200 |
# full forward sur toute la séquence, sans cache
|
| 201 |
-
out = m(input_ids=tokens, use_cache=False)
|
| 202 |
logits = out.logits[:, -1, :]
|
| 203 |
|
| 204 |
full_seq = tokens[0].tolist()
|
|
|
|
| 198 |
with torch.no_grad():
|
| 199 |
for _ in range(int(args.max_new_tokens)):
|
| 200 |
# full forward sur toute la séquence, sans cache
|
| 201 |
+
out = m(input_ids=tokens, use_cache=True)
|
| 202 |
logits = out.logits[:, -1, :]
|
| 203 |
|
| 204 |
full_seq = tokens[0].tolist()
|