bhsinghgrid commited on
Commit
f52024c
Β·
verified Β·
1 Parent(s): 0a4e1d5

Update run_analysis task4 compatibility

Browse files
Files changed (1) hide show
  1. analysis/run_analysis.py +57 -29
analysis/run_analysis.py CHANGED
@@ -275,43 +275,71 @@ def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
275
  print(f" TASK 4 β€” Step Ablation (phase={phase})")
276
  print("="*65)
277
 
278
- from analysis.step_ablation import (generate_ablation_configs,
279
- run_ablation_analysis, plot_ablation_3d, run_adversarial_test)
 
 
 
 
 
 
 
280
 
281
  if phase == "generate_configs":
282
- print(" Generating ablation configs...")
283
- generate_ablation_configs(output_dir="ablation_configs")
284
- print("\n NEXT STEPS:")
285
- print(" 1. bash ablation_configs/train_all.sh")
286
- print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
287
-
288
- elif phase == "analyze":
289
- # Check which models exist
 
 
 
 
290
  existing = [T for T in [4, 8, 16, 32, 64]
291
  if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
292
  if not existing:
293
  print(" No ablation models found at ablation_results/T*/best_model.pt")
294
- print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
295
- print(" Then: bash ablation_configs/train_all.sh")
296
  return
297
-
298
  print(f" Found models for T={existing}")
299
- results = run_ablation_analysis(
300
- ablation_dir="ablation_results", base_cfg=cfg,
301
- src_list=src_list[:200], ref_list=ref_list[:200],
302
- tgt_tokenizer=tgt_tok, device=device,
303
- output_dir=OUTPUT_DIR)
304
- plot_ablation_3d(results,
305
- save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
306
-
307
- # Adversarial robustness always runs on existing model (no retraining)
308
- print("\n Running adversarial robustness test...")
309
- inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
310
- for s in src_list[:50]]
311
- run_adversarial_test(
312
- model, src_tok, tgt_tok,
313
- test_inputs=inp_texts, test_refs=ref_list[:50],
314
- device=device, output_dir=OUTPUT_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
 
317
  # ── Task 5 ────────────────────────────────────────────────────────────
 
275
  print(f" TASK 4 β€” Step Ablation (phase={phase})")
276
  print("="*65)
277
 
278
+ import analysis.step_ablation as step_ablation
279
+
280
+ # Legacy API
281
+ has_legacy = all(hasattr(step_ablation, fn) for fn in [
282
+ "generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d"
283
+ ])
284
+
285
+ # New API
286
+ has_new = hasattr(step_ablation, "run_task4")
287
 
288
  if phase == "generate_configs":
289
+ if has_legacy:
290
+ print(" Generating ablation configs...")
291
+ step_ablation.generate_ablation_configs(output_dir="ablation_configs")
292
+ print("\n NEXT STEPS:")
293
+ print(" 1. bash ablation_configs/train_all.sh")
294
+ print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
295
+ return
296
+ print(" This step_ablation version does not expose config generation helpers.")
297
+ print(" Use your latest ablation training script/config pipeline directly.")
298
+ return
299
+
300
+ if phase == "analyze":
301
  existing = [T for T in [4, 8, 16, 32, 64]
302
  if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
303
  if not existing:
304
  print(" No ablation models found at ablation_results/T*/best_model.pt")
 
 
305
  return
 
306
  print(f" Found models for T={existing}")
307
+
308
+ if has_legacy:
309
+ results = step_ablation.run_ablation_analysis(
310
+ ablation_dir="ablation_results", base_cfg=cfg,
311
+ src_list=src_list[:200], ref_list=ref_list[:200],
312
+ tgt_tokenizer=tgt_tok, device=device,
313
+ output_dir=OUTPUT_DIR)
314
+ step_ablation.plot_ablation_3d(
315
+ results, save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
316
+ elif has_new:
317
+ from inference import load_model as _load_model
318
+ models = {}
319
+ for T in existing:
320
+ ckpt = f"ablation_results/T{T}/best_model.pt"
321
+ cfg_t = copy.deepcopy(cfg)
322
+ cfg_t["model"]["diffusion_steps"] = T
323
+ cfg_t["inference"]["num_steps"] = T
324
+ m_t, _ = _load_model(ckpt, cfg_t, device)
325
+ m_t.eval()
326
+ models[T] = m_t
327
+ knee_t = step_ablation.run_task4(
328
+ models, src_list[:200], ref_list[:200], tgt_tok)
329
+ print(f" New pipeline suggested optimal T={knee_t}")
330
+ else:
331
+ print(" Unsupported step_ablation API; please sync analysis/step_ablation.py")
332
+ return
333
+
334
+ # Optional adversarial robustness (legacy helper only)
335
+ if hasattr(step_ablation, "run_adversarial_test"):
336
+ print("\n Running adversarial robustness test...")
337
+ inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
338
+ for s in src_list[:50]]
339
+ step_ablation.run_adversarial_test(
340
+ model, src_tok, tgt_tok,
341
+ test_inputs=inp_texts, test_refs=ref_list[:50],
342
+ device=device, output_dir=OUTPUT_DIR)
343
 
344
 
345
  # ── Task 5 ────────────────────────────────────────────────────────────