bhsinghgrid commited on
Commit
0a4e1d5
·
verified ·
1 Parent(s): 9124d18

Update step_ablation generate fallback

Browse files
Files changed (1) hide show
  1. analysis/step_ablation.py +6 -2
analysis/step_ablation.py CHANGED
@@ -448,7 +448,11 @@ def evaluate_all_models(models: Dict[int, object],
448
  src = src.unsqueeze(0)
449
 
450
  with torch.no_grad():
451
- out = model.model.generate_cached(src.to(device))
 
 
 
 
452
 
453
  ids = [x for x in out[0].tolist() if x > 4]
454
  pred = tgt_tokenizer.decode(ids).strip()
@@ -579,4 +583,4 @@ def run_task4(models, src_list, ref_list, tgt_tokenizer):
579
  with open("analysis/outputs/task4_report.txt", "w") as f:
580
  f.write(f"Optimal diffusion steps = {knee_T}\n")
581
 
582
- return knee_T
 
448
  src = src.unsqueeze(0)
449
 
450
  with torch.no_grad():
451
+ if hasattr(model, "model") and hasattr(model.model, "generate_cached"):
452
+ out = model.model.generate_cached(src.to(device))
453
+ else:
454
+ # Fallback for wrappers that only expose top-level generate.
455
+ out = model.generate(src.to(device))
456
 
457
  ids = [x for x in out[0].tolist() if x > 4]
458
  pred = tgt_tokenizer.decode(ids).strip()
 
583
  with open("analysis/outputs/task4_report.txt", "w") as f:
584
  f.write(f"Optimal diffusion steps = {knee_T}\n")
585
 
586
+ return knee_T