| """ |
| analysis/step_ablation.py |
| ========================== |
Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
| |
| Two-phase workflow (retraining IS required for different T values): |
| |
PHASE 1 — Generate configs + train (run once per T value):
| python analysis/step_ablation.py --phase generate_configs |
    # Creates configs: ablation_configs/config_T4.py, config_T8.py, config_T16.py, config_T32.py, config_T64.py
| # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config) |
| |
PHASE 2 — Analyze trained models (no retraining needed):
| python analysis/step_ablation.py --phase analyze |
| # Loads each trained model, generates 200 paraphrases, computes CER |
| # Produces 3D plot: X=steps, Y=generation_speed, Z=CER |
| |
| Why retraining is needed: |
| A model trained with T=128 learns to denoise from x_t~Uniform[0,128]. |
Running it with T=4 means the model only sees t ∈ {0, 1, 2, 3} — which it
| was never trained on at those scales. Outputs are meaningless. |
| You must train a separate model for each T value. |
| |
| Also implements adversarial robustness test (no retraining): |
| Takes your existing T=128 model and tests whether corrupted IAST |
| inputs (typos, character swaps) cause proportional output degradation. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import os |
| import sys |
| import time |
| import json |
| import copy |
| from typing import List, Dict, Optional |
|
|
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
| |
|
|
# Step counts to ablate.  Each T requires its own trained model (see module doc).
T_VALUES = [4, 8, 16, 32, 64]


def generate_ablation_configs(base_config_path: str = "config.py",
                              output_dir: str = "ablation_configs",
                              base_steps: int = 128) -> None:
    """
    Generate one config file per T value in ``T_VALUES``.

    Each config is a copy of the base config with its step count
    (``diffusion_steps`` / ``num_steps``, originally ``base_steps``)
    rewritten to T.  Also emits a ``train_all.sh`` convenience script.

    Args:
        base_config_path: Path to the base config source file.
        output_dir:       Directory receiving the generated configs + script.
        base_steps:       Step count present in the base config to replace
                          (default 128; was previously hard-coded).

    After running this, train each model:
        for T in 4 8 16 32 64; do
            cp ablation_configs/config_T${T}.py config.py
            python train.py
            mv results7/d3pm_cross_attention_neg_False \
               ablation_results/T${T}
        done
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(base_config_path, "r") as f:
        base_src = f.read()

    # Both key names and both quote styles may appear in the config source;
    # previously this was four copy-pasted .replace() calls.
    key_templates = [
        '"diffusion_steps": {}',
        "'diffusion_steps': {}",
        '"num_steps": {}',
        "'num_steps': {}",
    ]

    for T in T_VALUES:
        cfg_src = base_src
        for tmpl in key_templates:
            cfg_src = cfg_src.replace(tmpl.format(base_steps), tmpl.format(T))
        out_path = os.path.join(output_dir, f"config_T{T}.py")
        with open(out_path, "w") as f:
            f.write(f"# Ablation config: T={T} diffusion steps\n")
            f.write(cfg_src)
        print(f" Wrote: {out_path}")

    # Convenience script: trains every T sequentially and archives artifacts.
    shell_script = os.path.join(output_dir, "train_all.sh")
    with open(shell_script, "w") as f:
        f.write("#!/bin/bash\n")
        f.write("# Run this script to train all ablation models\n\n")
        for T in T_VALUES:
            f.write(f"echo '=== Training T={T} ==='\n")
            f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
            f.write(f"python train.py\n")
            f.write(f"mkdir -p ablation_results/T{T}\n")
            f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
                    f"ablation_results/T{T}/best_model.pt\n")
            f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
                    f"ablation_results/T{T}/train.log\n\n")
    os.chmod(shell_script, 0o755)
    print(f"\nTraining script: {shell_script}")
    print(f"Run: bash {shell_script}")
|
|
|
|
| |
|
|
def compute_cer(pred: str, ref: str) -> float:
    """
    Character Error Rate: Levenshtein distance between ``pred`` and ``ref``
    normalised by the reference length.

    Edge cases (fix: the old code returned 1.0 even when both strings
    were empty, i.e. a perfect match was scored as fully wrong):
        - empty ref, empty pred     -> 0.0
        - empty ref, non-empty pred -> 1.0 (avoids division by zero)
    """
    if not ref:
        return 0.0 if not pred else 1.0

    def edit_distance(s1: str, s2: str) -> int:
        # Single-row DP: dp[j] is the distance between s1[:i] and s2[:j].
        m, n = len(s1), len(s2)
        dp = list(range(n + 1))
        for i in range(1, m + 1):
            prev, dp[0] = dp[0], i
            for j in range(1, n + 1):
                temp = dp[j]
                dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
                prev = temp
        return dp[n]

    return edit_distance(pred, ref) / max(len(ref), 1)
|
|
|
|
def evaluate_model(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    n_samples: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict:
    """
    Generate up to ``n_samples`` outputs and measure CER plus wall-clock
    generation speed.

    Returns a dict with:
        mean_cer         : average CER over the evaluated samples
        std_cer          : standard deviation of the per-sample CERs
        generation_s     : total wall-clock seconds for all generations
        speed_per_sample : seconds per sample
        cer_list         : per-sample CER values
        n_samples        : number of samples actually evaluated
    """
    device = next(model.parameters()).device
    count = min(n_samples, len(src_list))
    scores: List[float] = []

    t0 = time.perf_counter()
    for src, ref in zip(src_list[:count], ref_list[:count]):
        # Ensure a batch dimension on the source tensor.
        batch = src.unsqueeze(0) if src.dim() == 1 else src

        with torch.no_grad():
            # Prefer the cached sampler when the wrapped model exposes one.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(
                    batch.to(device), temperature=temperature, top_k=top_k
                )
            else:
                out = model.generate(
                    batch.to(device), temperature=temperature, top_k=top_k
                )

        # ids <= 4 are dropped — presumably special tokens (pad/bos/eos);
        # NOTE(review): confirm against the tokenizer's vocab layout.
        kept = [tok for tok in out[0].tolist() if tok > 4]
        decoded = tgt_tokenizer.decode(kept).strip()
        scores.append(compute_cer(decoded, ref))

    total = time.perf_counter() - t0

    return {
        "mean_cer": float(np.mean(scores)),
        "std_cer": float(np.std(scores)),
        "generation_s": total,
        "speed_per_sample": total / max(count, 1),
        "cer_list": scores,
        "n_samples": count,
    }
|
|
|
|
def run_ablation_analysis(
    ablation_dir: str = "ablation_results",
    base_cfg: dict = None,
    src_list: List[torch.Tensor] = None,
    ref_list: List[str] = None,
    tgt_tokenizer = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Load each trained ablation model and evaluate it.

    Expects ``ablation_results/T{N}/best_model.pt`` for each T in
    ``T_VALUES``; missing checkpoints are skipped with a notice.  Writes
    a JSON summary (per-sample CER lists excluded) to ``output_dir`` and
    returns the full results dict keyed by T.
    """
    from inference import load_model

    results = {}
    for steps in T_VALUES:
        ckpt_path = os.path.join(ablation_dir, f"T{steps}", "best_model.pt")
        if not os.path.exists(ckpt_path):
            print(f" SKIP T={steps}: no checkpoint at {ckpt_path}")
            continue

        print(f"\nEvaluating T={steps}...")
        # The inference config must match the step count this model was
        # trained with (see module docstring for why).
        cfg = copy.deepcopy(base_cfg)
        cfg['model']['diffusion_steps'] = steps
        cfg['inference']['num_steps'] = steps

        model, cfg = load_model(ckpt_path, cfg, device)
        model.eval()

        metrics = evaluate_model(
            model, src_list, ref_list, tgt_tokenizer, n_samples=200
        )
        results[steps] = metrics
        print(f" T={steps} CER={metrics['mean_cer']:.4f} "
              f"speed={metrics['speed_per_sample']:.3f}s/sample")

        del model  # release the model before loading the next checkpoint

    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "ablation_results.json")
    # Strip the bulky per-sample CER lists from the serialized summary.
    summary = {
        str(steps): {key: val for key, val in metrics.items() if key != 'cer_list'}
        for steps, metrics in results.items()
    }
    with open(results_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\nResults saved: {results_path}")

    return results
|
|
|
|
def plot_ablation_3d(
    results: Dict,
    save_path: Optional[str] = None,
):
    """
    3D plot: X=diffusion_steps, Y=generation_speed(s/sample), Z=CER.
    Also produces a 2D summary plot annotating the "knee" (the T value
    reached by the largest single CER drop between consecutive steps).

    Args:
        results:   Mapping T -> metrics dict containing 'mean_cer' and
                   'speed_per_sample' (as built by run_ablation_analysis).
        save_path: If given, the figure is saved there; else plt.show().

    Returns None.  No-ops with a message when matplotlib is unavailable
    or when ``results`` is empty (fix: previously an empty dict raised
    ValueError at ``max(cers)``).
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
    except ImportError:
        print("pip install matplotlib.")
        return

    if not results:
        print("plot_ablation_3d: no results to plot.")
        return

    T_list = sorted(results.keys())
    cers = [results[T]["mean_cer"] for T in T_list]
    speeds = [results[T]["speed_per_sample"] for T in T_list]

    fig = plt.figure(figsize=(14, 5))

    # Left panel: 3D scatter of (T, speed, CER), colored by CER.
    ax3d = fig.add_subplot(121, projection='3d')
    ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
    for T, s, c in zip(T_list, speeds, cers):
        ax3d.text(T, s, c, f"T={T}", fontsize=8)
    ax3d.set_xlabel("Diffusion steps T", fontsize=9)
    ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
    # Fix: label text was mojibake ("CER (β better)") in the original.
    ax3d.set_zlabel("CER (lower = better)", fontsize=9)
    ax3d.set_title("T vs speed vs CER", fontsize=10)

    # Right panel: CER vs T with per-point value annotations.
    ax2d = fig.add_subplot(122)
    ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
    for T, c in zip(T_list, cers):
        ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
                      xytext=(0, 8), fontsize=8, ha='center')

    # Knee detection: largest CER drop between consecutive T values.
    if len(T_list) >= 3:
        drops = [cers[i] - cers[i+1] for i in range(len(cers)-1)]
        knee_i = int(np.argmax(drops))
        knee_T = T_list[knee_i + 1]
        ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
                     label=f"Knee at T={knee_T}")
        ax2d.legend(fontsize=9)

    ax2d.set_xlabel("Diffusion steps T", fontsize=10)
    ax2d.set_ylabel("CER (lower = better)", fontsize=10)
    ax2d.set_title("CER vs diffusion steps", fontsize=10)
    ax2d.set_ylim(0, max(cers) * 1.1)

    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
|
|
|
| |
|
|
def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
    """
    Introduce random corruption into IAST text:
        - adjacent-character swap
        - character deletion
        - random character insertion

    Intended rates are roughly 5% to 20% for robustness testing.

    Args:
        text:            Input string.
        corruption_rate: Fraction of characters to corrupt.  A rate <= 0
                         returns ``text`` unchanged — this keeps the 0%
                         baseline in run_adversarial_test genuinely clean.

    Returns:
        The corrupted string (uses the module-global ``random`` state).
    """
    import random

    # Fix: the old max(1, ...) floor corrupted at least one character even
    # at corruption_rate == 0, poisoning the uncorrupted baseline.  Also
    # guard empty input, where randint(0, -1) would raise ValueError.
    if corruption_rate <= 0 or not text:
        return text

    chars = list(text)
    n_corrupt = max(1, int(len(chars) * corruption_rate))

    for _ in range(n_corrupt):
        op = random.choice(['swap', 'delete', 'insert'])
        pos = random.randint(0, len(chars) - 1)

        if op == 'swap' and pos < len(chars) - 1:
            chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
        elif op == 'delete' and len(chars) > 1:
            chars.pop(pos)
        elif op == 'insert':
            chars.insert(pos, random.choice('abcdeimnostu'))

    return "".join(chars)
|
|
|
|
@torch.no_grad()
def run_adversarial_test(
    model,
    src_tokenizer,
    tgt_tokenizer,
    test_inputs: List[str],
    test_refs: List[str],
    corruption_rates: Optional[List[float]] = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Test whether CER degrades proportionally with IAST input corruption.
    Uses the existing trained model — no retraining required.

    Args:
        model:            Trained wrapper model (exposes .generate / .model).
        src_tokenizer:    Encodes (corrupted) IAST inputs to id lists.
        tgt_tokenizer:    Decodes model output ids to text.
        test_inputs:      Clean IAST source strings.
        test_refs:        Reference target strings, one per input.
        corruption_rates: Corruption fractions to sweep; defaults to
                          [0.0, 0.05, 0.10, 0.15, 0.20].
        device:           Inference device; inferred from the model if None.
        output_dir:       Where the plot and JSON artifacts are written.

    Returns:
        Dict mapping corruption_rate -> mean CER.
    """
    # Fix: the default was a mutable list literal (shared across calls);
    # use the None-sentinel idiom instead.
    if corruption_rates is None:
        corruption_rates = [0.0, 0.05, 0.10, 0.15, 0.20]
    device = device or next(model.parameters()).device
    results = {}

    print("\nAdversarial robustness test...")
    for rate in corruption_rates:
        cer_list = []
        for text, ref in zip(test_inputs, test_refs):
            corrupted = corrupt_iast(text, rate)
            ids = src_tokenizer.encode(corrupted)
            src = torch.tensor([ids], dtype=torch.long, device=device)

            # Prefer the cached sampler when the wrapped model exposes one.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(src)
            else:
                out = model.generate(src)

            # ids <= 4 are dropped — presumably special tokens; confirm
            # against the tokenizer's vocab layout.
            pred_ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(pred_ids).strip()
            cer_list.append(compute_cer(pred, ref))

        mean_cer = float(np.mean(cer_list))
        results[rate] = mean_cer
        # Fix: the arrow in this log line was mojibake in the original.
        print(f" corruption={rate*100:.0f}% -> CER={mean_cer:.4f}")

    # Plot CER vs corruption rate (silently skipped without matplotlib).
    os.makedirs(output_dir, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(8, 4))
        rates = [r * 100 for r in corruption_rates]
        cers = [results[r] for r in corruption_rates]
        ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
        ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
        ax.set_ylabel("CER", fontsize=11)
        ax.set_title("Model robustness to IAST input corruption", fontsize=11)
        ax.set_ylim(0, max(cers) * 1.2)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
                    dpi=150, bbox_inches='tight')
        plt.close()
        print(f" Saved: {output_dir}/adversarial_robustness.png")
    except ImportError:
        pass

    with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)

    return results
|
|