Spaces:
Sleeping
Sleeping
Update run_analysis task4 compatibility
Browse files- analysis/run_analysis.py +57 -29
analysis/run_analysis.py
CHANGED
|
@@ -275,43 +275,71 @@ def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
|
|
| 275 |
print(f" TASK 4 β Step Ablation (phase={phase})")
|
| 276 |
print("="*65)
|
| 277 |
|
| 278 |
-
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
if phase == "generate_configs":
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
existing = [T for T in [4, 8, 16, 32, 64]
|
| 291 |
if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
|
| 292 |
if not existing:
|
| 293 |
print(" No ablation models found at ablation_results/T*/best_model.pt")
|
| 294 |
-
print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
|
| 295 |
-
print(" Then: bash ablation_configs/train_all.sh")
|
| 296 |
return
|
| 297 |
-
|
| 298 |
print(f" Found models for T={existing}")
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
|
| 317 |
# ββ Task 5 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
|
|
|
| 275 |
print(f" TASK 4 β Step Ablation (phase={phase})")
|
| 276 |
print("="*65)
|
| 277 |
|
| 278 |
+
import analysis.step_ablation as step_ablation
|
| 279 |
+
|
| 280 |
+
# Legacy API
|
| 281 |
+
has_legacy = all(hasattr(step_ablation, fn) for fn in [
|
| 282 |
+
"generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d"
|
| 283 |
+
])
|
| 284 |
+
|
| 285 |
+
# New API
|
| 286 |
+
has_new = hasattr(step_ablation, "run_task4")
|
| 287 |
|
| 288 |
if phase == "generate_configs":
|
| 289 |
+
if has_legacy:
|
| 290 |
+
print(" Generating ablation configs...")
|
| 291 |
+
step_ablation.generate_ablation_configs(output_dir="ablation_configs")
|
| 292 |
+
print("\n NEXT STEPS:")
|
| 293 |
+
print(" 1. bash ablation_configs/train_all.sh")
|
| 294 |
+
print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
|
| 295 |
+
return
|
| 296 |
+
print(" This step_ablation version does not expose config generation helpers.")
|
| 297 |
+
print(" Use your latest ablation training script/config pipeline directly.")
|
| 298 |
+
return
|
| 299 |
+
|
| 300 |
+
if phase == "analyze":
|
| 301 |
existing = [T for T in [4, 8, 16, 32, 64]
|
| 302 |
if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
|
| 303 |
if not existing:
|
| 304 |
print(" No ablation models found at ablation_results/T*/best_model.pt")
|
|
|
|
|
|
|
| 305 |
return
|
|
|
|
| 306 |
print(f" Found models for T={existing}")
|
| 307 |
+
|
| 308 |
+
if has_legacy:
|
| 309 |
+
results = step_ablation.run_ablation_analysis(
|
| 310 |
+
ablation_dir="ablation_results", base_cfg=cfg,
|
| 311 |
+
src_list=src_list[:200], ref_list=ref_list[:200],
|
| 312 |
+
tgt_tokenizer=tgt_tok, device=device,
|
| 313 |
+
output_dir=OUTPUT_DIR)
|
| 314 |
+
step_ablation.plot_ablation_3d(
|
| 315 |
+
results, save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
|
| 316 |
+
elif has_new:
|
| 317 |
+
from inference import load_model as _load_model
|
| 318 |
+
models = {}
|
| 319 |
+
for T in existing:
|
| 320 |
+
ckpt = f"ablation_results/T{T}/best_model.pt"
|
| 321 |
+
cfg_t = copy.deepcopy(cfg)
|
| 322 |
+
cfg_t["model"]["diffusion_steps"] = T
|
| 323 |
+
cfg_t["inference"]["num_steps"] = T
|
| 324 |
+
m_t, _ = _load_model(ckpt, cfg_t, device)
|
| 325 |
+
m_t.eval()
|
| 326 |
+
models[T] = m_t
|
| 327 |
+
knee_t = step_ablation.run_task4(
|
| 328 |
+
models, src_list[:200], ref_list[:200], tgt_tok)
|
| 329 |
+
print(f" New pipeline suggested optimal T={knee_t}")
|
| 330 |
+
else:
|
| 331 |
+
print(" Unsupported step_ablation API; please sync analysis/step_ablation.py")
|
| 332 |
+
return
|
| 333 |
+
|
| 334 |
+
# Optional adversarial robustness (legacy helper only)
|
| 335 |
+
if hasattr(step_ablation, "run_adversarial_test"):
|
| 336 |
+
print("\n Running adversarial robustness test...")
|
| 337 |
+
inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
|
| 338 |
+
for s in src_list[:50]]
|
| 339 |
+
step_ablation.run_adversarial_test(
|
| 340 |
+
model, src_tok, tgt_tok,
|
| 341 |
+
test_inputs=inp_texts, test_refs=ref_list[:50],
|
| 342 |
+
device=device, output_dir=OUTPUT_DIR)
|
| 343 |
|
| 344 |
|
| 345 |
# ββ Task 5 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|