Spaces:
Sleeping
Sleeping
Update step_ablation generate fallback
Browse files
analysis/step_ablation.py
CHANGED
|
@@ -448,7 +448,11 @@ def evaluate_all_models(models: Dict[int, object],
|
|
| 448 |
src = src.unsqueeze(0)
|
| 449 |
|
| 450 |
with torch.no_grad():
|
| 451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
ids = [x for x in out[0].tolist() if x > 4]
|
| 454 |
pred = tgt_tokenizer.decode(ids).strip()
|
|
@@ -579,4 +583,4 @@ def run_task4(models, src_list, ref_list, tgt_tokenizer):
|
|
| 579 |
with open("analysis/outputs/task4_report.txt", "w") as f:
|
| 580 |
f.write(f"Optimal diffusion steps = {knee_T}\n")
|
| 581 |
|
| 582 |
-
return knee_T
|
|
|
|
| 448 |
src = src.unsqueeze(0)
|
| 449 |
|
| 450 |
with torch.no_grad():
|
| 451 |
+
if hasattr(model, "model") and hasattr(model.model, "generate_cached"):
|
| 452 |
+
out = model.model.generate_cached(src.to(device))
|
| 453 |
+
else:
|
| 454 |
+
# Fallback for wrappers that only expose top-level generate.
|
| 455 |
+
out = model.generate(src.to(device))
|
| 456 |
|
| 457 |
ids = [x for x in out[0].tolist() if x > 4]
|
| 458 |
pred = tgt_tokenizer.decode(ids).strip()
|
|
|
|
| 583 |
with open("analysis/outputs/task4_report.txt", "w") as f:
|
| 584 |
f.write(f"Optimal diffusion steps = {knee_T}\n")
|
| 585 |
|
| 586 |
+
return knee_T
|