PhysiQuanty committed on
Commit
ca1b436
·
verified ·
1 Parent(s): 4571215

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +1 -1
inference.py CHANGED
@@ -198,7 +198,7 @@ def main() -> None:
198
  with torch.no_grad():
199
  for _ in range(int(args.max_new_tokens)):
200
  # full forward sur toute la séquence, sans cache
201
- out = m(input_ids=tokens, use_cache=False)
202
  logits = out.logits[:, -1, :]
203
 
204
  full_seq = tokens[0].tolist()
 
198
  with torch.no_grad():
199
  for _ in range(int(args.max_new_tokens)):
200
  # full forward sur toute la séquence, sans cache
201
+ out = m(input_ids=tokens, use_cache=True)
202
  logits = out.logits[:, -1, :]
203
 
204
  full_seq = tokens[0].tolist()