Spaces:
Running
Running
| """ | |
| analysis/run_analysis.py | |
| ========================= | |
| Entry point for all 5 tasks. | |
| Tasks: | |
| Task 1 — KV Cache benchmark (no retraining) | |
| Task 2 — Attention viz + drift (no retraining) | |
| Task 3 — Concept vectors + PCA steer (no retraining) | |
| Task 4 — Step ablation (REQUIRES retraining for each T) | |
| Task 5 — Classifier-free guidance (trains small 10k-param classifier) | |
| Usage: | |
| python analysis/run_analysis.py --task 1 | |
| python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ" | |
| python analysis/run_analysis.py --task 3 | |
| python analysis/run_analysis.py --task 4 --phase generate_configs | |
| python analysis/run_analysis.py --task 4 --phase analyze | |
| python analysis/run_analysis.py --task 5 | |
| python analysis/run_analysis.py --task all --input "satyameva jayate" | |
| Output files: analysis/outputs/ | |
| """ | |
| import copy | |
| import torch | |
| import os, sys, argparse, json | |
| import numpy as np | |
| import time | |
| import gc | |
| import tracemalloc | |
| import threading | |
| import resource | |
| from difflib import SequenceMatcher | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| try: | |
| import psutil | |
| except Exception: | |
| psutil = None | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from config import CONFIG | |
| from inference import load_model, _decode_with_cleanup, _iast_to_deva | |
| from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer | |
# Directory (relative to the current working directory) where every task
# writes its reports and plots; created eagerly at import time.
OUTPUT_DIR = "analysis/outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Keep caches writable/project-local for laptops and sandboxed runners.
# setdefault() leaves any value the user already exported untouched.
# NOTE(review): these variables are set after matplotlib and the project
# modules were imported above; a library that reads them at import time may
# not pick them up -- confirm, and consider moving this block earlier.
_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
os.environ.setdefault("HF_HOME", os.path.join(_ROOT, ".hf_cache"))
os.environ.setdefault("HF_DATASETS_CACHE", os.path.join(_ROOT, ".hf_cache", "datasets"))
os.environ.setdefault("HF_HUB_CACHE", os.path.join(_ROOT, ".hf_cache", "hub"))
os.environ.setdefault("MPLCONFIGDIR", os.path.join(_ROOT, ".mplconfig"))
# Make sure every cache directory named above exists before first use.
for _p in [
    os.environ["HF_HOME"],
    os.environ["HF_DATASETS_CACHE"],
    os.environ["HF_HUB_CACHE"],
    os.environ["MPLCONFIGDIR"],
]:
    os.makedirs(_p, exist_ok=True)
def _process_mem_mb() -> float:
    """Return a best-effort estimate of this process's resident memory in MiB.

    Sources are tried in order; any failure falls through to the next one:

    1. ``psutil`` current RSS (only when the optional import succeeded).
    2. ``/proc/self/statm`` current RSS.
    3. ``resource.getrusage`` maximum RSS. This is a high-water mark rather
       than the current value.

    Returns:
        Memory in MiB, or ``0.0`` when no source could be read.
    """
    if psutil is not None:
        try:
            return float(psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024))
        except Exception:
            # Deliberate best-effort: fall through to the next source.
            pass
    # Linux fallback: /proc/self/statm current RSS pages.
    try:
        with open("/proc/self/statm", "r", encoding="utf-8") as f:
            parts = f.read().strip().split()
        if len(parts) >= 2:
            rss_pages = int(parts[1])
            page_size = os.sysconf("SC_PAGE_SIZE")
            return float(rss_pages * page_size / (1024 * 1024))
    except Exception:
        pass
    # Unix fallback: max RSS from resource. The unit is platform-defined:
    # macOS reports bytes, Linux reports kilobytes. Decide by platform rather
    # than by magnitude, which misclassifies very large or very small values.
    try:
        ru = float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
        if sys.platform == "darwin":
            return ru / (1024 * 1024)
        return ru / 1024.0
    except Exception:
        return 0.0
| # ── Shared loader ───────────────────────────────────────────────────── | |
def infer_model_type_from_checkpoint(ckpt_path: str) -> str:
    """Guess the model type from substrings of a checkpoint path.

    The comparison is case-insensitive and treats ``\\`` and ``/`` as the same
    separator so Windows-style paths match too.  Any path under
    ``ablation_results/t*`` is treated as ``d3pm_cross_attention``.

    Args:
        ckpt_path: Path to a checkpoint file.

    Returns:
        One of the four known model-type names when it appears in the path,
        otherwise ``CONFIG["model_type"]``.
    """
    name = ckpt_path.lower().replace("\\", "/")
    if "ablation_results/t" in name:
        return "d3pm_cross_attention"
    # Order matters only for paths that contain more than one name; it mirrors
    # the precedence of the original if-chain.
    for model_type in (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    ):
        if model_type in name:
            return model_type
    return CONFIG["model_type"]
def infer_include_negative_from_checkpoint(ckpt_path: str) -> bool:
    """Guess whether a checkpoint was trained with negative examples.

    The comparison is case-insensitive and treats ``\\`` and ``/`` as the same
    separator so Windows-style paths match too.

    Args:
        ckpt_path: Path to a checkpoint file.

    Returns:
        ``True`` / ``False`` when the path contains ``_neg_true`` /
        ``_neg_false``; ``False`` for anything under ``ablation_results/t*``;
        otherwise ``CONFIG["data"]["include_negative_examples"]``.
    """
    name = ckpt_path.lower().replace("\\", "/")
    if "_neg_true" in name:
        return True
    if "_neg_false" in name or "ablation_results/t" in name:
        return False
    return CONFIG["data"]["include_negative_examples"]
def load_everything(cfg, device, ckpt_override=None):
    """Locate a checkpoint, load the model, and build both tokenizers.

    Args:
        cfg: Project config dict; ``cfg['model_type']`` and
            ``cfg['data']['include_negative_examples']`` drive the default
            checkpoint search order.
        device: Device passed through to ``load_model``.
        ckpt_override: Explicit checkpoint path.  When given, the default
            candidates are not searched.

    Returns:
        ``(model, src_tok, tgt_tok, cfg)`` where ``cfg`` is the config returned
        by ``load_model`` and the model has been put in eval mode.

    Raises:
        FileNotFoundError: The override path does not exist, or none of the
            default candidate paths exists.
    """
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    candidates = [
        f"results7/{model_name}_neg_{has_neg}/best_model.pt",
        f"results/{model_name}_neg_{has_neg}/best_model.pt",
        f"results7/{model_name}_neg_True/best_model.pt",
        f"results/{model_name}_neg_True/best_model.pt",
        f"results7/{model_name}_neg_False/best_model.pt",
        f"results/{model_name}_neg_False/best_model.pt",
        "ablation_results/T4/best_model.pt",
        "ablation_results/T8/best_model.pt",
    ]
    if ckpt_override:
        ckpt = ckpt_override
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"Checkpoint override not found: {ckpt}")
    else:
        # next() yields None when nothing exists; check that explicitly, since
        # os.path.exists(None) raises TypeError instead of returning False.
        ckpt = next((p for p in candidates if os.path.exists(p)), None)
        if ckpt is None:
            raise FileNotFoundError(f"No checkpoint found. Checked: {candidates}")

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    max_len = cfg['model']['max_seq_len']
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 500),
        max_len=max_len)
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 500),
        max_len=max_len)
    return model, src_tok, tgt_tok, cfg
def load_val_data(cfg, src_tok, tgt_tok, n=500):
    """Return up to ``n`` held-out examples as three parallel lists.

    The 'train' split is loaded, capped at ``cfg['data']['dataset_size']``,
    and split 80/20 with a fixed seed; the first ``n`` indices of the 20%
    side are used.

    Returns:
        ``(src_list, ref_list, inp_list)``: each item's ``input_ids`` with a
        leading dimension added via ``unsqueeze(0)``, its ``target_text``,
        and its ``input_text``.
    """
    from data.dataset import OptimizedSanskritDataset
    from torch.utils.data import Subset
    from sklearn.model_selection import train_test_split

    max_len = cfg['model']['max_seq_len']
    dataset = OptimizedSanskritDataset(
        'train', max_len=max_len, cfg=cfg,
        src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)

    usable = min(cfg['data']['dataset_size'], len(dataset))
    split = train_test_split(list(range(usable)), train_size=0.8, random_state=42)
    held_out = split[1][:n]

    src_list, ref_list, inp_list = [], [], []
    for idx in held_out:
        example = dataset[idx]
        src_list.append(example['input_ids'].unsqueeze(0))
        ref_list.append(example['target_text'])
        inp_list.append(example['input_text'])
    return src_list, ref_list, inp_list
| def _generate_ids_compat(model, src, num_steps=None, temperature=0.8, top_k=40, | |
| repetition_penalty=1.2, diversity_penalty=0.0): | |
| kwargs = dict(temperature=temperature, top_k=top_k) | |
| if num_steps is not None: | |
| kwargs["num_steps"] = int(num_steps) | |
| if repetition_penalty is not None: | |
| kwargs["repetition_penalty"] = float(repetition_penalty) | |
| if diversity_penalty is not None: | |
| kwargs["diversity_penalty"] = float(diversity_penalty) | |
| try: | |
| return model.generate(src, **kwargs) | |
| except TypeError: | |
| # Some model variants expose reduced generate() kwargs. | |
| slim = {k: kwargs[k] for k in ["temperature", "top_k", "num_steps"] if k in kwargs} | |
| try: | |
| return model.generate(src, **slim) | |
| except TypeError: | |
| return model.generate(src) | |
| def _decode_ids(tgt_tok, out_ids, src_text=None, inf_cfg=None): | |
| ids = [] | |
| for x in out_ids[0].tolist(): | |
| # stop at PAD/SEP once decoding started | |
| if x in (1, 4) and ids: | |
| break | |
| if x > 4: | |
| ids.append(x) | |
| if src_text is not None and inf_cfg is not None: | |
| txt = _decode_with_cleanup(tgt_tok, ids, src_text, inf_cfg) | |
| else: | |
| txt = tgt_tok.decode(ids).strip() | |
| return txt, ids | |
| def _cer(a: str, b: str) -> float: | |
| m, n = len(a), len(b) | |
| if m == 0 and n == 0: | |
| return 0.0 | |
| dp = list(range(n + 1)) | |
| for i in range(1, m + 1): | |
| prev, dp[0] = dp[0], i | |
| for j in range(1, n + 1): | |
| tmp = dp[j] | |
| dp[j] = prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1]) | |
| prev = tmp | |
| return float(dp[n]) / max(1, m, n) | |
| # ── Task 1 ──────────────────────────────────────────────────────────── | |
def run_task1(model, src_tok, device):
    """Benchmark standard vs. cached generation and write the Task 1 artifacts.

    For source lengths 16/32/64 the mean wall-clock time of the standard and
    (when available) ``generate_cached`` paths is measured, plus the share of
    one encoder pass relative to one cached decoder step.  A memory comparison
    at ``src_len=64`` follows, using the first source that works: CUDA peak
    allocation, process RSS polling, ``tracemalloc``, then the torch CPU
    profiler.

    Writes ``task1_time_comparison.png``, ``task1_speedup.png``,
    ``task1_encoder_cost.png`` and ``task1_kv_cache.txt`` into ``OUTPUT_DIR``.

    Args:
        model: Wrapper exposing ``generate`` (and optionally
            ``generate_cached``) with the network on ``model.model``.
        src_tok: Unused; kept for a uniform task signature.
        device: Device on which the random source batches are created.
    """
    print("\n" + "="*65)
    print(" TASK 1 β KV Cache Benchmark")
    print("="*65)

    src_vocab = model.model.src_embed.token_emb.weight.shape[0]
    src_lens = [16, 32, 64]
    n_runs = 3
    has_cached = hasattr(model, "generate_cached")
    on_cuda = torch.cuda.is_available() and str(device).startswith("cuda")
    if not has_cached:
        print(" Compatibility mode: generate_cached() unavailable; running standard benchmark only.")

    def _run_std(src):
        # NOTE(review): num_steps is left at the model default here while the
        # cached path uses 64 -- confirm both run the same number of steps.
        return _generate_ids_compat(model, src, temperature=0.8, top_k=40)

    def _run_cached(src):
        return model.generate_cached(
            src, num_steps=64, temperature=0.8, top_k=40,
            repetition_penalty=1.2, diversity_penalty=0.0
        )

    def _sync():
        # CUDA kernels are launched asynchronously; without a sync the timer
        # can stop before the work has actually finished.
        if on_cuda:
            torch.cuda.synchronize(device)

    def _timeit(fn, runs=n_runs):
        """Mean wall-clock seconds of ``fn`` over ``runs`` calls."""
        vals = []
        for _ in range(runs):
            _sync()
            t0 = time.perf_counter()
            fn()
            _sync()
            vals.append(time.perf_counter() - t0)
        return float(np.mean(vals))

    def _trace_peak_bytes(fn, repeat=8):
        """Peak Python-allocator bytes (tracemalloc) across ``repeat`` calls."""
        gc.collect()
        tracemalloc.start()
        try:
            for _ in range(max(1, int(repeat))):
                fn()
            _, peak = tracemalloc.get_traced_memory()
        finally:
            # Never leave tracing enabled if fn() raises.
            tracemalloc.stop()
        return int(peak)

    def _torch_cpu_mem_bytes(fn):
        """Sum of positive self-CPU-memory profiler events; 0 on any failure."""
        try:
            from torch.profiler import profile, ProfilerActivity
            with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=False) as prof:
                fn()
            mem = 0
            for ev in prof.key_averages():
                try:
                    mem += max(0, int(getattr(ev, "self_cpu_memory_usage", 0)))
                except Exception:
                    pass
            return int(mem)
        except Exception:
            return 0

    def _peak_rss_while(fn, poll_s=0.01):
        """Run ``fn`` while a daemon thread polls RSS; return (base, peak, delta) in MiB."""
        done = {"v": False}
        peak = {"v": _process_mem_mb()}

        def _poll():
            while not done["v"]:
                peak["v"] = max(peak["v"], _process_mem_mb())
                time.sleep(poll_s)

        th = threading.Thread(target=_poll, daemon=True)
        gc.collect()
        base = _process_mem_mb()
        th.start()
        try:
            fn()
        finally:
            done["v"] = True
            th.join(timeout=0.1)
        gc.collect()
        return base, peak["v"], max(0.0, peak["v"] - base)

    def _allocator_fallback_note(src, unavailable_note):
        """tracemalloc comparison, then torch CPU profiler, else ``unavailable_note``."""
        peak_std = _trace_peak_bytes(lambda: _run_std(src), repeat=10)
        peak_cache = _trace_peak_bytes(lambda: _run_cached(src), repeat=10)
        # Ignore extremely small peaks; they are noise for tensor-heavy paths.
        if peak_std >= 256 * 1024:
            red = 100.0 * (peak_std - peak_cache) / peak_std
            return (
                f"Py alloc peak reduction: {red:.1f}% @ src_len=64 "
                f"(std={peak_std/1024**2:.1f}MB, cache={peak_cache/1024**2:.1f}MB)"
            )
        cpu_std = _torch_cpu_mem_bytes(lambda: _run_std(src))
        cpu_cache = _torch_cpu_mem_bytes(lambda: _run_cached(src))
        if cpu_std > 0:
            red = 100.0 * (cpu_std - cpu_cache) / max(cpu_std, 1)
            return (
                f"Torch CPU mem-event reduction: {red:.1f}% @ src_len=64 "
                f"(std={cpu_std/1024**2:.1f}MB, cache={cpu_cache/1024**2:.1f}MB)"
            )
        return unavailable_note

    # ---- Timing sweep -------------------------------------------------
    results = {}
    for src_len in src_lens:
        src = torch.randint(5, src_vocab, (1, src_len), device=device)
        t_std = _timeit(lambda: _run_std(src))
        if has_cached:
            t_cache = _timeit(lambda: _run_cached(src))
            speedup = t_std / max(t_cache, 1e-9)
        else:
            t_cache = t_std
            speedup = 1.0

        # Encoder cost estimate: one encode_source pass vs one cached step.
        if hasattr(model.model, "encode_source") and hasattr(model.model, "forward_cached"):
            memory, src_pad = model.model.encode_source(src)
            x = torch.full((1, src_len), model.model.mask_token_id, dtype=torch.long, device=device)
            t = torch.full((1,), max(0, model.model.scheduler.num_timesteps - 1), dtype=torch.long, device=device)
            t_enc = _timeit(lambda: model.model.encode_source(src))
            t_step = _timeit(lambda: model.model.forward_cached(memory, src_pad, x, t, x0_hint=None, inference_mode=True))
            encoder_pct = (t_enc / max(t_enc + t_step, 1e-9)) * 100.0
        else:
            encoder_pct = 0.0

        results[src_len] = dict(
            standard_s=t_std,
            cached_s=t_cache,
            speedup=speedup,
            encoder_pct=encoder_pct,
        )
        print(f" src_len={src_len:>3d} standard={t_std:.3f}s cached={t_cache:.3f}s speedup={speedup:.2f}x encoder%={encoder_pct:.1f}")

    # ---- Memory profiling (GPU preferred, CPU/MPS fallbacks) ----------
    # Every comparison needs the cached path, so check for it first; the CUDA
    # branch used to call generate_cached() unconditionally.
    if not has_cached:
        mem_note = "Profiler unavailable (cached path missing)"
    elif on_cuda:
        src = torch.randint(5, src_vocab, (1, 64), device=device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = _run_std(src)
        m_std = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = _run_cached(src)
        m_cache = torch.cuda.max_memory_allocated(device)
        mem_red = 100.0 * (m_std - m_cache) / max(m_std, 1)
        mem_note = f"GPU peak alloc reduction: {mem_red:.1f}% @ src_len=64"
    elif _process_mem_mb() > 0.0:
        src = torch.randint(5, src_vocab, (1, 64), device=device)
        _, p_std, d_std = _peak_rss_while(lambda: _run_std(src))
        _, p_c, d_c = _peak_rss_while(lambda: _run_cached(src))
        if d_std > 0.0:
            mem_red = 100.0 * (d_std - d_c) / d_std
            mem_note = (
                f"RSS peak reduction: {mem_red:.1f}% @ src_len=64 "
                f"(std_peak={p_std:.1f}MB, cache_peak={p_c:.1f}MB)"
            )
        else:
            # Secondary fallback: Python allocator peak (always available).
            mem_note = _allocator_fallback_note(
                src, "Memory estimate unavailable (RSS/tracemalloc/torch-profiler flat)"
            )
    else:
        # Final fallback (CPU-safe): Python allocation peak via tracemalloc.
        # This does not include all native tensor allocator memory, but still
        # gives a consistent relative signal when psutil/CUDA stats are absent.
        src = torch.randint(5, src_vocab, (1, 64), device=device)
        mem_note = _allocator_fallback_note(
            src, "Py alloc peak too small/noisy to estimate (no psutil/CUDA profiler)"
        )
    print(f" Memory reduction: {mem_note}")

    # ---- Subtask graphs -----------------------------------------------
    lens = sorted(results.keys())

    def _save_line_plot(filename, series, ylabel, title):
        """Plot each ``(values, label)`` series against ``lens`` and save it."""
        plt.figure(figsize=(7, 4))
        for values, label in series:
            plt.plot(lens, values, marker="o", label=label)
        plt.xlabel("Source length")
        plt.ylabel(ylabel)
        plt.title(title)
        if any(label is not None for _, label in series):
            plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, filename), dpi=150, bbox_inches="tight")
        plt.close()

    _save_line_plot(
        "task1_time_comparison.png",
        [
            ([results[n]["standard_s"] for n in lens], "standard"),
            ([results[n]["cached_s"] for n in lens], "cached"),
        ],
        "Time (s)",
        "Task1: Generation Time (Standard vs Cached)",
    )
    _save_line_plot(
        "task1_speedup.png",
        [([results[n]["speedup"] for n in lens], None)],
        "Speedup (x)",
        "Task1: KV-Cache Speedup",
    )
    _save_line_plot(
        "task1_encoder_cost.png",
        [([results[n]["encoder_pct"] for n in lens], None)],
        "Encoder cost (%)",
        "Task1: Encoder Cost Share",
    )

    # ---- Text report --------------------------------------------------
    path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
    # Explicit UTF-8: the header contains non-ASCII text and the Task 2 report
    # is already written this way.
    with open(path, "w", encoding="utf-8") as f:
        f.write("TASK 1 β KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
        f.write(f"has_generate_cached={has_cached}\n")
        f.write(f"memory_profile={mem_note}\n\n")
        f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
                f"{'speedup':>8} {'encoder%':>9}\n")
        for src_len, r in results.items():
            f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
                    f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
        f.write("\nSaved graphs:\n")
        f.write(" - task1_time_comparison.png\n")
        f.write(" - task1_speedup.png\n")
        f.write(" - task1_encoder_cost.png\n")
    print(f" Saved: {path}")
| # ── Task 2 ──────────────────────────────────────────────────────────── | |
def _lock_in_step(t_vals_desc, cer_vals, threshold=0.05):
    """Return the largest step from which CER-to-final stays within ``threshold``.

    ``t_vals_desc`` is ordered from the largest step down to the final one and
    ``cer_vals`` holds the matching CER values.  The walk starts at the final
    step (whose CER to itself is 0) and moves toward larger steps, stopping at
    the first one whose CER exceeds the threshold.
    """
    lock_in = t_vals_desc[-1]
    for t_val, cer_v in zip(reversed(t_vals_desc), reversed(cer_vals)):
        if cer_v > threshold:
            break
        lock_in = t_val
    return lock_in


def _run_task2_drift_only(model, src, tgt_tok, input_text, inf_cfg):
    """Step-sweep drift analysis for checkpoints without ``encode_source``.

    Generates with a descending list of step counts, measures CER of each
    output against the 1-step output, and writes ``task2_semantic_drift.png``
    and ``task2_report.txt`` into ``OUTPUT_DIR``.
    """
    # Keep steps <= scheduler horizon for this checkpoint to avoid backend aborts.
    t_sched = int(getattr(getattr(model.model, "scheduler", object()), "num_timesteps", 64))
    # Stability guard for some checkpoints/backends: keep sweep moderate.
    t_max = min(t_sched, 64)
    step_list = []
    for s in [t_max, 48, 32, 24, 16, 8, 4, 1]:
        s = max(1, min(int(s), t_max))
        if s not in step_list:
            step_list.append(s)

    outs = {}
    for s in step_list:
        out = _generate_ids_compat(model, src, num_steps=s, temperature=0.8, top_k=40)
        outs[s], _ = _decode_ids(tgt_tok, out, src_text=input_text, inf_cfg=inf_cfg)

    # The clamp above floors every candidate at 1, so key 1 always exists.
    final = outs[1]
    drift = [(_cer(outs[s], final), s) for s in step_list]

    # Plot drift
    plt.figure(figsize=(8, 4))
    plt.plot([s for _, s in drift], [c for c, _ in drift], marker='o')
    plt.gca().invert_xaxis()
    plt.xlabel("Generation steps")
    plt.ylabel("CER to 1-step output")
    plt.title("Task2 Semantic Drift (Compatibility Mode)")
    plt.tight_layout()
    plot_path = os.path.join(OUTPUT_DIR, "task2_semantic_drift.png")
    plt.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close()

    report = os.path.join(OUTPUT_DIR, "task2_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 2 β COMPATIBILITY REPORT\n")
        f.write("="*40 + "\n")
        f.write("Cross-attention capture unavailable for this checkpoint.\n")
        f.write(f"Input: {input_text}\n")
        f.write(f"Reference final (1 step): {final}\n\n")
        for cer_v, s in drift:
            f.write(f"steps={s:>3d} CER_to_final={cer_v:.4f} output={outs[s][:120]}\n")
    print(f" Output(final@1): {final}")
    print(f" Report: {report}")
    print(f" Saved: {plot_path}")


def run_task2(model, src_tok, tgt_tok, device, input_text, cfg, corpus_inputs=None):
    """Capture cross-attention over the diffusion trajectory and report drift.

    When ``model.model`` has no ``encode_source`` only a step-sweep drift
    analysis is produced.  Otherwise attention weights and per-step outputs
    are captured via ``AttentionCapture`` and the function writes attention
    heatmaps, an attention-evolution plot, a CER-to-final drift plot, a
    source-alignment bar chart, an optional TF-IDF vs attention chart, and
    ``task2_report.txt`` into ``OUTPUT_DIR``.

    Args:
        model: Wrapper with the network on ``model.model``.
        src_tok: Source tokenizer (``encode`` / ``decode``).
        tgt_tok: Target tokenizer (``decode``).
        device: Device for the encoded source tensor.
        input_text: Source sentence to analyse.
        cfg: Config dict; ``cfg["inference"]`` is forwarded to the cleanup
            decoder.
        corpus_inputs: Optional list of source strings, used as the TF-IDF
            corpus and for a small multi-sample sanity score.
    """
    print("\n" + "="*65)
    print(" TASK 2 β Attention Visualization + Semantic Drift")
    print("="*65)
    print(f" Input: {input_text}")

    compat_mode = not hasattr(model.model, 'encode_source')
    if compat_mode:
        print(" Compatibility mode: attention hooks unavailable; running semantic-drift-only analysis.")

    src_ids = src_tok.encode(input_text)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
    inf_cfg = cfg.get("inference", {"temperature": 0.8, "top_k": 40})

    if compat_mode:
        _run_task2_drift_only(model, src_tensor, tgt_tok, input_text, inf_cfg)
        return

    from analysis.attention_viz import (
        AttentionCapture,
        compute_trajectory_metrics,
        analyze_token_stability,
        tfidf_attention_correlation,
    )

    # Attention capture
    print(" Capturing attention weights...")
    capturer = AttentionCapture(model)
    step_weights, step_outputs_ids = capturer.run(src_tensor)

    # Decode every captured step twice: raw (for drift) and cleaned (for display).
    decoded = {}
    decoded_raw = {}
    for t_val, ids_t in step_outputs_ids.items():
        raw_txt, ids = _decode_ids(tgt_tok, ids_t)
        clean_txt = _decode_with_cleanup(tgt_tok, ids, input_text, inf_cfg)
        decoded_raw[t_val] = (raw_txt, ids)
        decoded[t_val] = (clean_txt, ids)

    final_step = min(decoded.keys())
    final_out, final_ids = decoded[final_step]
    final_out_raw = decoded_raw[final_step][0]

    src_labels = []
    for sid in src_ids[:20]:
        tok = src_tok.decode([sid]).strip()
        src_labels.append(tok if tok else f"id{sid}")
    tgt_labels = [f"y{i}" for i in range(min(20, len(final_ids)))]
    print(f" Output: {final_out}")

    # Heatmaps at t=max and t=0, layer 0.
    # NOTE(review): indexing step_weights[0] assumes step 0 is always captured.
    first_t = max(step_weights.keys())
    w_first = step_weights[first_t][0][0]
    w0 = step_weights[0][0][0]
    n_src = min(len(src_labels), w_first.shape[1], 20)
    n_tgt = min(len(tgt_labels), w_first.shape[0], 20)

    def _save_heatmap(weights, title, filename):
        plt.figure(figsize=(max(8, n_src * 0.35), max(6, n_tgt * 0.3)))
        plt.imshow(weights[:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
        plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
        plt.yticks(range(n_tgt), tgt_labels[:n_tgt], fontsize=8)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, filename), dpi=150, bbox_inches="tight")
        plt.close()

    _save_heatmap(w_first, f"Attention t={first_t} Layer 0", f"task2_attn_t{first_t}.png")
    _save_heatmap(w0, "Attention t=0 Layer 0", "task2_attn_t0.png")

    # All layers at t=0
    layers = step_weights[0]
    n_layers = len(layers)
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols
    _, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.2))
    axes = np.array(axes).reshape(-1)
    for i, layer_w in enumerate(layers):
        ax = axes[i]
        ax.imshow(layer_w[0][:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
        ax.set_title(f"Layer {i}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
    # Hide the unused cells of the grid.
    for i in range(n_layers, len(axes)):
        axes[i].axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Attention evolution for src[0] -> tgt[0]
    t_vals_desc = sorted(step_weights.keys(), reverse=True)
    evo = []
    for t_val in t_vals_desc:
        w = step_weights[t_val][0][0]
        evo.append(float(w[0, 0]) if w.shape[0] > 0 and w.shape[1] > 0 else 0.0)
    plt.figure(figsize=(10, 3.5))
    plt.plot(range(len(t_vals_desc)), evo, marker="o")
    plt.xlabel("Captured step index (Tβ0)")
    plt.ylabel("Attention weight")
    plt.title("Attention Evolution (src0βtgt0)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Drift (CER to final across steps) on RAW decoded trajectory to expose true diffusion.
    t_vals = sorted(decoded.keys(), reverse=True)
    cer_vals = [_cer(decoded_raw[t][0], final_out_raw) for t in t_vals]
    plt.figure(figsize=(8, 4))
    plt.plot(t_vals, cer_vals, marker="o")
    plt.gca().invert_xaxis()
    plt.xlabel("Diffusion step")
    plt.ylabel("CER to final")
    plt.title("Task2 Semantic Drift")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Source alignment proxy (avg attention on source positions at t=0, last layer)
    last_layer_t0 = step_weights[0][-1][0]
    src_align = last_layer_t0.mean(axis=0)[:n_src]
    plt.figure(figsize=(8, 3))
    plt.bar(np.arange(len(src_align)), src_align)
    plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
    plt.title("Source Alignment Importance (t=0, last layer)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task2_source_alignment.png"), dpi=150, bbox_inches="tight")
    plt.close()

    stability = analyze_token_stability(step_weights)
    n_locked = sum(1 for v in stability.values() if v == "LOCKED")
    n_flex = sum(1 for v in stability.values() if v == "FLEXIBLE")

    tfidf_info = tfidf_attention_correlation(input_text, step_weights, corpus_texts=corpus_inputs)
    tfidf_corr = tfidf_info.get("corr")
    tfidf_status = tfidf_info.get("status", "UNKNOWN")

    ref_text = _iast_to_deva(input_text)
    traj = compute_trajectory_metrics(
        step_outputs_ids,
        tgt_tok,
        reference_text=ref_text,
    )
    # Keep trajectory semantic scoring on raw decoded text to avoid masking drift.
    for row in traj:
        raw_txt = decoded_raw.get(row["step"], ("", []))[0]
        if raw_txt:
            sim = max(0.0, 1.0 - _cer(raw_txt, ref_text))
            row["text"] = raw_txt
            row["bert"] = sim
            row["drift"] = 1.0 - sim

    # TF-IDF vs attention graph (subtask visualization)
    tfidf_vec = np.asarray(tfidf_info.get("tfidf_scores", []), dtype=np.float32)
    attn_vec = np.asarray(tfidf_info.get("attn_scores", []), dtype=np.float32)
    labels = list(tfidf_info.get("tokens", []))
    m = min(len(tfidf_vec), len(attn_vec), len(labels), 20)
    if m > 0:
        x = np.arange(m)
        plt.figure(figsize=(8, 3.5))
        tf_part = tfidf_vec[:m]
        at_part = attn_vec[:m]
        # Scale each series by its own max magnitude so the bars are comparable.
        tf_norm = tf_part / (np.max(np.abs(tf_part)) + 1e-9)
        at_norm = at_part / (np.max(np.abs(at_part)) + 1e-9)
        w = 0.4
        plt.bar(x - w/2, tf_norm, width=w, label="tfidf(norm)")
        plt.bar(x + w/2, at_norm, width=w, label="attn(norm)")
        plt.xlabel("Source token")
        plt.ylabel("Normalized score")
        plt.title("Task2: TF-IDF vs Attention Stability")
        plt.xticks(x, labels[:m], rotation=45, ha="right", fontsize=8)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task2_tfidf_vs_attention.png"), dpi=150, bbox_inches="tight")
        plt.close()

    # The previous scan started at the final step, whose CER to itself is 0,
    # so it always reported the final step.  Walk back from the final step and
    # keep the largest t for which the output has stayed within the threshold.
    lock_in_t = _lock_in_step(t_vals, cer_vals, threshold=0.05)

    if tfidf_corr is not None and abs(float(tfidf_corr)) < 0.10:
        tfidf_status = "WEAK"
    has_semantic = any(float(r.get("bert", 0.0)) > 0.05 for r in traj)

    # Degeneracy score on final output
    toks = [t for t in final_out.split() if t]
    uniq_ratio = len(set(toks)) / max(1, len(toks))
    degenerate = (len(toks) >= 8 and uniq_ratio < 0.35)

    # Small multi-sample stability check (prevents overclaim from one example)
    multi_scores = []
    if corpus_inputs:
        sample_texts = [s for s in corpus_inputs[:8] if isinstance(s, str) and s.strip()]
        multi_steps = min(16, cfg.get("inference", {}).get("num_steps", 16))
        for txt in sample_texts:
            src_i = torch.tensor([src_tok.encode(txt)], dtype=torch.long, device=device)
            out_i = _generate_ids_compat(model, src_i, num_steps=multi_steps,
                                         temperature=0.8, top_k=40)
            pred_i, _ = _decode_ids(tgt_tok, out_i)
            multi_scores.append(max(0.0, 1.0 - _cer(pred_i, _iast_to_deva(txt))))
    multi_sem = float(np.mean(multi_scores)) if multi_scores else 0.0

    quality_status = (
        "VALID"
        if len(final_out.strip()) > 0 and n_flex + n_locked > 0 and has_semantic and not degenerate and multi_sem >= 0.05
        else "WEAK"
    )
    corr_txt = f"{tfidf_corr:.4f}" if tfidf_corr is not None else "N/A"

    report = os.path.join(OUTPUT_DIR, "task2_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 2 β ATTENTION + DRIFT REPORT\n" + "=" * 50 + "\n\n")
        f.write(f"Input : {input_text}\n")
        f.write(f"Output: {final_out}\n\n")
        f.write(f"Captured steps: {len(t_vals)}\n")
        f.write(f"Analysis quality: {quality_status}\n")
        f.write(f"Final output uniq-ratio: {uniq_ratio:.3f}\n")
        f.write(f"Degenerate output: {'YES' if degenerate else 'NO'}\n")
        f.write(f"Multi-sample semantic score (n<={len(multi_scores)}): {multi_sem:.4f}\n")
        f.write(f"Lock-in step (CER<=0.05): t={lock_in_t}\n")
        f.write(f"Locked tokens: {n_locked} Flexible tokens: {n_flex}\n")
        f.write(f"TF-IDF vs attention stability corr: {corr_txt}\n")
        f.write(f"TF-IDF status: {tfidf_status}\n\n")
        f.write("Saved graphs:\n")
        f.write(" - task2_attn_t*.png / task2_all_layers_t0.png\n")
        f.write(" - task2_attn_evolution.png\n")
        f.write(" - task2_semantic_drift.png\n")
        f.write(" - task2_source_alignment.png\n")
        f.write(" - task2_tfidf_vs_attention.png\n\n")
        f.write("Step trajectory (first 10 rows)\n")
        f.write("-" * 60 + "\n")
        for row in traj[:10]:
            f.write(f"t={row['step']:>3d} bert={row['bert']:.4f} drift={row['drift']:.4f} text={row['text'][:60]}\n")

    print(f" Lock-in timestep: t={lock_in_t}")
    print(f" Locked/Flexible: {n_locked}/{n_flex}")
    print(f" TF-IDF corr: {corr_txt} ({tfidf_status})")
    print(f" Report: {report}")
| # ── Task 3 ──────────────────────────────────────────────────────────── | |
def _task3_save_figure(filename):
    """Finalize the current matplotlib figure, write it under OUTPUT_DIR, and close it.

    Returns the path the figure was written to.
    """
    path = os.path.join(OUTPUT_DIR, filename)
    plt.tight_layout()
    plt.savefig(path, dpi=150, bbox_inches="tight")
    plt.close()
    return path


def _run_task3_compat(model, tgt_tok, device, src_list):
    """Task 3 fallback: PCA over padded output-token ids instead of hidden states.

    Generates one output per source with ``_generate_ids_compat``, pads or
    truncates the decoded ids to 64 entries, fits a PCA on those vectors and
    writes ``task3_concept_space.png`` plus ``task3_report.txt`` to OUTPUT_DIR.

    Args:
        model: Model wrapper forwarded to ``_generate_ids_compat``.
        tgt_tok: Target tokenizer forwarded to ``_decode_ids``.
        device: Device each source item is moved to via ``.to(device)``.
        src_list: Source items; at most the first 60 are used.
    """
    print(" Compatibility mode: using output-token statistics for PCA concept proxy.")
    # Keep compatibility run lightweight/stable on constrained backends.
    n = min(60, len(src_list))
    if n < 2:
        # With fewer than 2 rows np.stack / PCA(n_components<=0) would raise.
        print(f" Need at least 2 samples for PCA (got {n}); skipping Task 3.")
        return

    from sklearn.decomposition import PCA

    feats, lens = [], []
    for src in src_list[:n]:
        out = _generate_ids_compat(model, src.to(device), num_steps=8, temperature=0.8, top_k=40)
        txt, ids = _decode_ids(tgt_tok, out)
        head = ids[:64]
        # Zero-pad so every feature vector has exactly 64 entries.
        feats.append(np.array(head + [0] * (64 - len(head)), dtype=np.float32))
        lens.append(len(txt))

    X = np.stack(feats)
    pca = PCA(n_components=min(10, X.shape[0] - 1, X.shape[1]))
    Z = pca.fit_transform(X)

    plt.figure(figsize=(6, 5))
    second_axis = Z[:, 1] if Z.shape[1] > 1 else np.zeros_like(Z[:, 0])
    sc = plt.scatter(Z[:, 0], second_axis, c=lens, cmap="viridis", s=14)
    plt.colorbar(sc, label="Output length")
    plt.title("Task3 Concept Proxy Space (Compatibility Mode)")
    img = _task3_save_figure("task3_concept_space.png")

    corr = float(np.corrcoef(Z[:, 0], np.array(lens))[0, 1]) if len(lens) > 2 else 0.0
    if not np.isfinite(corr):
        # corrcoef yields NaN when either series is constant; report 0 like the main path.
        corr = 0.0

    rep = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(rep, "w", encoding="utf-8") as f:
        f.write("TASK 3 — COMPATIBILITY REPORT\n")
        f.write("="*40 + "\n")
        f.write("Hidden-state capture unavailable; used output-token vector proxy.\n")
        f.write(f"Samples: {n}\n")
        f.write(f"PC1-length correlation: {corr:.4f}\n")
    print(f" Saved: {img}")
    print(f" Report: {rep}")


def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list, n_samples=500):
    """Run Task 3: PCA over captured hidden states plus a diversity-steering sweep.

    If ``model.model`` has no ``encode_source`` attribute, the output-token
    compatibility fallback is used instead. Otherwise hidden states are
    collected with ``analysis.concept_vectors.collect_hidden_states``
    (``t_capture=0``), a PCA is fitted, a diversity direction is derived, and a
    steering spectrum over alphas -2..2 is generated for up to five sources.

    Files written to OUTPUT_DIR: ``task3_concept_space.png``,
    ``task3_pca_explained_variance.png``, ``task3_diversity_curve.png``,
    ``task3_diversity_direction.npy`` and ``task3_report.txt``.

    Args:
        model: Loaded model wrapper (must expose ``.model``).
        src_tok: Unused here; kept so all task runners share a signature.
        tgt_tok: Target tokenizer forwarded to the analysis helpers.
        device: Device each source item is moved to before generation.
        src_list: Source items (objects supporting ``.to(device)``).
        ref_list: Unused here; kept so all task runners share a signature.
        n_samples: Upper bound on the number of sources used for PCA.

    Returns:
        None. Skips with a message when fewer than two samples are available.
    """
    print("\n" + "="*65)
    print(" TASK 3 — Concept Vectors + PCA Steering")
    print("="*65)

    if not hasattr(model.model, 'encode_source'):
        _run_task3_compat(model, tgt_tok, device, src_list)
        return

    n = min(max(1, int(n_samples)), len(src_list))
    if n < 2:
        # n_components=min(50, n-1) would be <= 0 and PCA cannot be fitted.
        print(f" Need at least 2 samples for PCA (got {n}); skipping Task 3.")
        return

    from analysis.concept_vectors import (
        collect_hidden_states, fit_pca, find_diversity_direction, generate_diversity_spectrum
    )

    # Collect hidden states from val set
    print(f" Collecting hidden states from {n} examples...")
    hidden, _texts, lengths = collect_hidden_states(
        model, src_list[:n], tgt_tok, t_capture=0, max_samples=n
    )

    # Fit PCA + find diversity direction
    pca = fit_pca(hidden, n_components=min(50, n - 1))
    direction = find_diversity_direction(hidden, lengths, pca)
    proj = pca.transform(hidden)
    corr = float(np.corrcoef(proj[:, 0], np.array(lengths))[0, 1]) if len(lengths) > 2 else 0.0
    if not np.isfinite(corr):
        corr = 0.0
    # The correlation above is always measured against the first component.
    best_pc = 0

    # Plot concept space
    plt.figure(figsize=(8, 6))
    second_axis = proj[:, 1] if proj.shape[1] > 1 else np.zeros_like(proj[:, 0])
    sc = plt.scatter(proj[:, 0], second_axis, c=lengths, cmap="viridis", s=14)
    plt.colorbar(sc, label="Output diversity proxy")
    plt.title("Task3 Concept Space")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    _task3_save_figure("task3_concept_space.png")

    # Subtask graph: explained variance by PCA components
    ev = pca.explained_variance_ratio_
    k = min(20, len(ev))
    plt.figure(figsize=(8, 3.5))
    plt.bar(np.arange(k), ev[:k])
    plt.xlabel("PC index")
    plt.ylabel("Explained variance ratio")
    plt.title("Task3: PCA Explained Variance (Top Components)")
    _task3_save_figure("task3_pca_explained_variance.png")

    # Generate diversity spectrum on multiple seeds for more stable conclusions
    seed_k = min(5, len(src_list))
    uniq_list, sem_list, all_spectra = [], [], []
    for src_i in src_list[:seed_k]:
        spec_i = generate_diversity_spectrum(
            model, src_i.to(device), direction, tgt_tok,
            alphas=[-2.0, -1.0, 0.0, 1.0, 2.0]
        )
        all_spectra.append(spec_i)
        spec_texts = [t for _, t in sorted(spec_i.items())]
        uniq_list.append(len(set(spec_texts)) / max(1, len(spec_texts)))
        # Index 2 of the alpha-sorted texts is the alpha=0.0 output when all
        # five requested alphas are present in the returned mapping.
        pivot = spec_texts[2] if len(spec_texts) >= 3 else (spec_texts[0] if spec_texts else "")
        # Character-level similarity of every non-empty output to the pivot.
        sims = [SequenceMatcher(None, txt, pivot).ratio() for txt in spec_texts if txt]
        sem_list.append(float(np.mean(sims)) if sims else 0.0)

    uniq_ratio = float(np.mean(uniq_list)) if uniq_list else 0.0
    semantic_stability = float(np.mean(sem_list)) if sem_list else 0.0
    steering_valid = (abs(corr) >= 0.20) and (uniq_ratio >= 0.55) and (semantic_stability >= 0.40)

    # use first seed spectrum for visualization table
    spectrum = all_spectra[0] if all_spectra else {}

    # Subtask graph: alpha vs decoded length
    a_vals = sorted(spectrum.keys())
    l_vals = [len(spectrum[a]) for a in a_vals]
    plt.figure(figsize=(7, 3.5))
    plt.plot(a_vals, l_vals, marker="o")
    plt.xlabel("Steering alpha")
    plt.ylabel("Output length")
    plt.title("Task3: Diversity Steering Curve")
    _task3_save_figure("task3_diversity_curve.png")

    # Save diversity direction + results
    np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)
    report = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
        f.write(f"PCA: {pca.n_components_} components, "
                f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
        # The label is |r|, so print the magnitude rather than the signed value.
        f.write(f"Diversity PC: {best_pc} (|r|={abs(corr):.3f} with diversity proxy)\n\n")
        f.write(f"Direction validity: {'VALID' if steering_valid else 'WEAK'}\n")
        f.write(f"Spectrum unique ratio (mean over {seed_k} seeds): {uniq_ratio:.3f}\n")
        f.write(f"Spectrum semantic stability (mean over {seed_k} seeds): {semantic_stability:.3f}\n\n")
        f.write("Saved graphs:\n")
        f.write(" - task3_concept_space.png\n")
        f.write(" - task3_pca_explained_variance.png\n")
        f.write(" - task3_diversity_curve.png\n\n")
        f.write("Diversity spectrum:\n")
        for alpha, text in sorted(spectrum.items()):
            f.write(f" alpha={alpha:+.1f} → {text}\n")
    print(f" Report: {report}")
# ── Task 4 ────────────────────────────────────────────────────────────
def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
              src_list, ref_list, n_samples=200):
    """Run Task 4: diffusion-step ablation (config generation or analysis).

    Supports two generations of ``analysis.step_ablation``: the legacy API
    (``generate_ablation_configs`` / ``run_ablation_analysis`` /
    ``plot_ablation_3d``) and the newer single ``run_task4`` entry point.
    The legacy API is preferred when both are present.

    Phases:
        ``generate_configs``: write ablation configs (legacy API only).
        ``analyze``: evaluate the checkpoints found at
            ``ablation_results/T{4,8,16,32,64}/best_model.pt``. Setting the
            environment variable ``TASK4_ONLY_T`` to an integer restricts the
            analysis to that single T.

    Args:
        phase: ``"generate_configs"`` or ``"analyze"``.
        model: Base model; only used by the optional adversarial test.
        src_tok: Source tokenizer; only used by the optional adversarial test.
        tgt_tok: Target tokenizer forwarded to the ablation helpers.
        device: Device used for loading and evaluating ablation models.
        cfg: Base config; deep-copied per T in the new-API branch.
        src_list: Validation sources.
        ref_list: Validation references aligned with ``src_list``.
        n_samples: Number of validation pairs to evaluate in either API branch.

    Returns:
        None.
    """
    print("\n" + "="*65)
    print(f" TASK 4 — Step Ablation (phase={phase})")
    print("="*65)
    import analysis.step_ablation as step_ablation

    # Legacy API
    legacy_fns = ("generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d")
    has_legacy = all(hasattr(step_ablation, fn) for fn in legacy_fns)
    # New API
    has_new = hasattr(step_ablation, "run_task4")

    if phase == "generate_configs":
        if not has_legacy:
            print(" This step_ablation version does not expose config generation helpers.")
            print(" Use your latest ablation training script/config pipeline directly.")
            return
        print(" Generating ablation configs...")
        step_ablation.generate_ablation_configs(output_dir="ablation_configs")
        print("\n NEXT STEPS:")
        print(" 1. bash ablation_configs/train_all.sh")
        print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
        return

    if phase != "analyze":
        print(f" Unknown phase {phase!r}; expected 'generate_configs' or 'analyze'.")
        return

    existing = [T for T in [4, 8, 16, 32, 64]
                if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
    only_t = os.environ.get("TASK4_ONLY_T")
    if only_t:
        if only_t.isdigit():
            t_req = int(only_t)
            existing = [T for T in existing if T == t_req]
        else:
            # Surface the typo instead of silently analysing every T.
            print(f" Ignoring TASK4_ONLY_T={only_t!r}: expected an integer.")
    if not existing:
        print(" No ablation models found at ablation_results/T*/best_model.pt")
        return
    print(f" Found models for T={existing}")

    # Both API branches honour the caller's sample budget.
    n_eval = max(1, int(n_samples))

    if has_legacy:
        results = step_ablation.run_ablation_analysis(
            ablation_dir="ablation_results", base_cfg=cfg,
            src_list=src_list[:n_eval], ref_list=ref_list[:n_eval],
            tgt_tokenizer=tgt_tok, device=device,
            output_dir=OUTPUT_DIR)
        step_ablation.plot_ablation_3d(
            results, save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
    elif has_new:
        from inference import load_model as _load_model
        models = {}
        for T in existing:
            ckpt = f"ablation_results/T{T}/best_model.pt"
            # Each checkpoint was trained for its own step count, so both the
            # model and inference settings are overridden on a private copy.
            cfg_t = copy.deepcopy(cfg)
            cfg_t["model"]["diffusion_steps"] = T
            cfg_t["inference"]["num_steps"] = T
            m_t, _ = _load_model(ckpt, cfg_t, device)
            m_t.eval()
            models[T] = m_t
        knee_t = step_ablation.run_task4(
            models, src_list[:n_eval], ref_list[:n_eval], tgt_tok,
            output_dir=OUTPUT_DIR, n_samples=n_eval)
        print(f" New pipeline suggested optimal T={knee_t}")
    else:
        print(" Unsupported step_ablation API; please sync analysis/step_ablation.py")
        return

    # Optional adversarial robustness (legacy helper only)
    if hasattr(step_ablation, "run_adversarial_test"):
        print("\n Running adversarial robustness test...")
        # Ids <= 4 are dropped before decoding -- presumably special tokens;
        # TODO confirm against the source tokenizer's vocabulary.
        inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
                     for s in src_list[:50]]
        step_ablation.run_adversarial_test(
            model, src_tok, tgt_tok,
            test_inputs=inp_texts, test_refs=ref_list[:50],
            device=device, output_dir=OUTPUT_DIR)
# ── Task 5 ────────────────────────────────────────────────────────────
def _task5_compat_sweep(model, tgt_tok, device, src_list, ref_list, n, penalty_per_lambda):
    """Sweep repetition penalty as a stand-in for the guidance scale λ.

    For each λ in a fixed grid the repetition penalty is set to
    ``1.0 + penalty_per_lambda * λ``; mean CER and mean unique-token ratio over
    the first ``n`` (source, reference) pairs are printed, plotted to
    ``task5_quality_diversity_tradeoff.png`` and written to
    ``task5_report.txt`` in OUTPUT_DIR.

    Args:
        model: Model wrapper forwarded to ``_generate_ids_compat``.
        tgt_tok: Target tokenizer forwarded to ``_decode_ids``.
        device: Device each source item is moved to via ``.to(device)``.
        src_list: Validation sources.
        ref_list: Reference strings aligned with ``src_list``.
        n: Number of leading pairs to evaluate (callers pass a value >= 1).
        penalty_per_lambda: Repetition-penalty increment per unit of λ.
    """
    scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    pairs = list(zip(src_list[:n], ref_list[:n]))
    results = []
    for lam in scales:
        rep_pen = 1.0 + penalty_per_lambda * lam
        cer_vals, uniq_vals = [], []
        for src, ref in pairs:
            out = _generate_ids_compat(
                model, src.to(device), num_steps=8, temperature=0.8, top_k=40,
                repetition_penalty=rep_pen, diversity_penalty=0.0
            )
            txt, ids = _decode_ids(tgt_tok, out)
            cer_vals.append(_cer(txt, ref))
            # Diversity proxy: share of distinct token ids in the output.
            uniq_vals.append(len(set(ids)) / max(1, len(ids)))
        mean_cer = float(np.mean(cer_vals))
        mean_div = float(np.mean(uniq_vals))
        results.append((lam, mean_cer, mean_div))
        print(f" λ={lam:.1f} CER={mean_cer:.4f} diversity={mean_div:.3f}")

    # Subtask graph: quality-diversity tradeoff
    x = [r[1] for r in results]
    y = [r[2] for r in results]
    plt.figure(figsize=(6, 4))
    plt.plot(x, y, marker="o")
    for lam, xi, yi in results:
        plt.text(xi, yi, f"λ={lam:.1f}", fontsize=8)
    plt.xlabel("CER (lower is better)")
    plt.ylabel("Diversity")
    plt.title("Task5: Quality-Diversity Tradeoff")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
    plt.close()

    rep = os.path.join(OUTPUT_DIR, "task5_report.txt")
    with open(rep, "w", encoding="utf-8") as f:
        f.write("TASK 5 — COMPATIBILITY REPORT\n")
        f.write("="*40 + "\n")
        f.write("Guidance classifier path unavailable; λ mapped to repetition penalty.\n\n")
        for lam, cer_v, div_v in results:
            f.write(f"lambda={lam:.1f} CER={cer_v:.4f} diversity={div_v:.3f}\n")
        f.write("\nSaved graphs:\n")
        f.write(" - task5_quality_diversity_tradeoff.png\n")
    print(f" Report: {rep}")


def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list, task5_samples=500):
    """Run Task 5: quality-classifier guidance sweep (with compatibility fallbacks).

    Three paths, tried in order:
      1. ``model.model`` lacks ``encode_source``: repetition-penalty sweep over
         up to 100 pairs (penalty step 0.15 per unit of λ).
      2. ``analysis.quality_classifier`` cannot be imported: the same sweep
         over up to 50 pairs (penalty step 0.2 per unit of λ).
      3. Otherwise: collect (or load cached) quality data, train (or load a
         cached) ``QualityClassifier``, run ``sweep_guidance_scales`` and write
         ``task5_report.txt`` with the scale minimising
         ``mean_cer - 0.05 * diversity`` marked as optimal.

    Cached artefacts in OUTPUT_DIR (``task5_quality_data.npz``,
    ``task5_quality_classifier.pt``) are reused whenever they exist.

    Args:
        model: Loaded model wrapper (must expose ``.model``).
        src_tok: Unused here; kept so all task runners share a signature.
        tgt_tok: Target tokenizer forwarded to the helpers.
        device: Device used for generation and the guidance sweep.
        cfg: Config; ``cfg['model']['d_model']`` sizes the classifier.
        src_list: Validation sources.
        ref_list: Validation references aligned with ``src_list``.
        task5_samples: Upper bound on the number of pairs used.

    Returns:
        None. Skips with a message when no validation pairs are available.
    """
    print("\n" + "="*65)
    print(" TASK 5 — Classifier-Free Guidance")
    print("="*65)

    n_avail = min(int(task5_samples), len(src_list), len(ref_list))
    if n_avail <= 0:
        # Averaging over zero samples would only produce NaN-filled reports.
        print(" No validation pairs available for Task 5; skipping.")
        return

    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: classifier-guidance unavailable; sweeping decoding controls.")
        _task5_compat_sweep(model, tgt_tok, device, src_list, ref_list,
                            n=min(100, n_avail), penalty_per_lambda=0.15)
        return

    try:
        from analysis.quality_classifier import (
            QualityClassifier, collect_quality_data,
            train_quality_classifier, sweep_guidance_scales)
    except Exception as exc:
        # Deliberate best-effort fallback, but say why so the mismatch can be fixed.
        print(" Quality-classifier API mismatch; using compatibility sweep.")
        print(f" Reason: {type(exc).__name__}: {exc}")
        _task5_compat_sweep(model, tgt_tok, device, src_list, ref_list,
                            n=min(50, n_avail), penalty_per_lambda=0.2)
        return

    clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
    d_model = cfg['model']['d_model']

    # Step 1: collect or load training data
    data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
    if os.path.exists(data_path):
        print(" Loading cached quality data...")
        # Context manager closes the underlying .npz file handle.
        with np.load(data_path) as data:
            hidden = data["hidden"]
            quality = data["quality"]
    else:
        print(" Collecting quality data (this takes a few minutes)...")
        hidden, quality = collect_quality_data(
            model, src_list[:n_avail], ref_list[:n_avail], tgt_tok,
            t_capture=0, max_samples=n_avail)
        np.savez(data_path, hidden=hidden, quality=quality)
        print(f" Saved quality data: {data_path}")

    # Step 2: train or load classifier
    if os.path.exists(clf_path):
        print(f" Loading cached classifier: {clf_path}")
        clf = QualityClassifier(d_model)
        # NOTE(review): torch.load unpickles the file; only load caches this
        # script produced itself.
        clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
        clf.eval()
    else:
        print(" Training quality classifier...")
        clf = train_quality_classifier(
            hidden, quality, d_model=d_model,
            epochs=30, batch_size=64, lr=1e-3,
            save_path=clf_path)
        clf.eval()

    # Step 3: guidance scale sweep
    print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
    n_sweep = min(80, n_avail)
    results = sweep_guidance_scales(
        model, clf, src_list[:n_sweep], ref_list[:n_sweep],
        tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
        n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)
    if not results:
        print(" Guidance sweep returned no results; skipping report.")
        return

    # Find optimal scale (quality + anti-collapse diversity)
    def _score(s):
        r = results[s]
        return r["mean_cer"] - 0.05 * r.get("diversity", 0.0)

    best_scale = min(results, key=_score)
    print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
          f"CER={results[best_scale]['mean_cer']:.4f}")

    report = os.path.join(OUTPUT_DIR, "task5_report.txt")
    # utf-8 is required: the report contains non-ASCII characters that a
    # locale-default encoding may be unable to represent.
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
        f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
        f.write(f"Training samples : {len(hidden)}\n\n")
        f.write("Guidance scale sweep:\n")
        f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10} {'d2':>6} {'sBLEU':>8}\n")
        f.write(" " + "-"*52 + "\n")
        for s in sorted(results.keys()):
            r = results[s]
            marker = " ← optimal" if s == best_scale else ""
            f.write(
                f" {s:>6.1f} {r['mean_cer']:>8.4f} {r.get('diversity', 0.0):>10.3f} "
                f"{r.get('distinct2', 0.0):>6.3f} {r.get('self_bleu', 0.0):>8.3f}{marker}\n"
            )
    print(f" Report: {report}")
# ── Main ──────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse arguments, load the model, run the tasks.

    Rebinds the module-level ``OUTPUT_DIR`` to ``--output_dir`` so every task
    runner writes there. When ``--checkpoint`` points into a directory named
    ``T<number>``, that number overrides the configured diffusion/inference
    step counts. The configured device falls back to CPU when MPS or CUDA is
    requested but unavailable.
    """
    global OUTPUT_DIR
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        choices=["1", "2", "3", "4", "5", "all"], default="all")
    parser.add_argument("--input",
                        default="dharmo rakṣati rakṣitaḥ",
                        help="IAST input text for Task 2")
    parser.add_argument("--phase",
                        choices=["generate_configs", "analyze"], default="analyze",
                        help="Task 4 phase: generate_configs (before training) or analyze (after)")
    parser.add_argument("--checkpoint", default=None,
                        help="Optional explicit checkpoint path")
    parser.add_argument("--output_dir", default="analysis/outputs",
                        help="Output directory for reports/figures")
    parser.add_argument("--task4_samples", type=int, default=50,
                        help="Samples for Task 4 dry/full evaluation")
    parser.add_argument("--task3_samples", type=int, default=500,
                        help="Samples for Task 3 hidden-state collection")
    parser.add_argument("--task5_samples", type=int, default=500,
                        help="Samples for Task 5 classifier data + sweep")
    args = parser.parse_args()

    OUTPUT_DIR = args.output_dir
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    cfg = copy.deepcopy(CONFIG)
    if args.checkpoint:
        cfg["model_type"] = infer_model_type_from_checkpoint(args.checkpoint)
        cfg["data"]["include_negative_examples"] = infer_include_negative_from_checkpoint(args.checkpoint)
        # Ablation checkpoints live in ".../T<steps>/..."; recover the step count.
        ckpt_name = os.path.basename(os.path.dirname(args.checkpoint))
        if ckpt_name.startswith("T") and ckpt_name[1:].isdigit():
            t_val = int(ckpt_name[1:])
            cfg["model"]["diffusion_steps"] = t_val
            cfg["inference"]["num_steps"] = t_val

    requested = cfg["training"]["device"]
    # torch.backends.mps does not exist on older torch builds.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    if requested == "mps" and not mps_ok:
        requested = "cpu"
    elif requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    device = torch.device(requested)

    def _print_footer():
        print(f"\n{'='*65}")
        print(f" All outputs saved to: {OUTPUT_DIR}/")
        print("="*65)

    if args.task == "4" and args.phase == "generate_configs":
        # Config generation happens before ablation training and touches
        # neither the model nor validation data, so skip loading both.
        run_task4(args.phase, None, None, None, device, cfg,
                  [], [], n_samples=args.task4_samples)
        _print_footer()
        return

    print("Loading model and tokenizers...")
    model, src_tok, tgt_tok, cfg = load_everything(cfg, device, ckpt_override=args.checkpoint)

    # Load val data for tasks that need corpus/context (Tasks 2, 3, 4, 5)
    needs_data = args.task in ("2", "3", "4", "5", "all")
    if needs_data:
        print("Loading validation data...")
        # Load enough examples to satisfy the largest requested sample count
        # (assumes ``n`` is the number of validation examples to load -- TODO confirm).
        n_val = max(500, args.task3_samples, args.task4_samples, args.task5_samples)
        src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=n_val)
    else:
        src_list, ref_list, inp_list = [], [], []

    runners = {
        "1": lambda: run_task1(model, src_tok, device),
        "2": lambda: run_task2(model, src_tok, tgt_tok, device, args.input, cfg,
                               corpus_inputs=inp_list),
        "3": lambda: run_task3(model, src_tok, tgt_tok, device, src_list, ref_list,
                               n_samples=args.task3_samples),
        "4": lambda: run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
                               src_list, ref_list, n_samples=args.task4_samples),
        "5": lambda: run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list,
                               task5_samples=args.task5_samples),
    }
    tasks = ["1", "2", "3", "4", "5"] if args.task == "all" else [args.task]
    for task in tasks:
        runners[task]()

    _print_footer()
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()