Upload 27 files
Browse files- __init__.py +0 -0
- app.py +389 -0
- attention_viz.py +379 -0
- best_model.pt +3 -0
- concept_vectors.py +368 -0
- config_T16.py +119 -0
- config_T32.py +119 -0
- config_T4.py +119 -0
- config_T64.py +119 -0
- config_T8.py +119 -0
- d3pm_model_cross_attention.py +271 -0
- d3pm_model_encoder_decoder.py +227 -0
- dataset.py +152 -0
- forward_process.py +21 -0
- inference.py +300 -0
- kv_cache_benchmark.py +233 -0
- quality_classifier.py +514 -0
- reverse_process.py +302 -0
- reverse_process1.py +154 -0
- reverse_process2.py +275 -0
- run_analysis.py +407 -0
- sanskrit_model.py +61 -0
- scheduler.py +34 -0
- semantic_drift.py +279 -0
- step_ablation.py +389 -0
- tokenizer.py +222 -0
- train_all.sh +28 -0
__init__.py
ADDED
|
File without changes
|
app.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from config import CONFIG
|
| 10 |
+
from inference import load_model, run_inference, _build_tokenizers, _resolve_device
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
RESULTS_DIR = "generated_results"
|
| 14 |
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def discover_checkpoints():
    """Scan the known result directories for trained checkpoints.

    Looks for ``<root>/<experiment>/best_model.pt`` under each candidate
    root directory and returns one descriptor dict per checkpoint found,
    with keys ``label``, ``path``, ``experiment`` and ``root``.
    """
    roots = ("ablation_results", "results7", "results")
    discovered = []
    for root in roots:
        if not os.path.isdir(root):
            continue
        for name in sorted(os.listdir(root)):
            candidate = os.path.join(root, name, "best_model.pt")
            if os.path.exists(candidate):
                discovered.append(
                    {
                        "label": f"{name} [{root}]",
                        "path": candidate,
                        "experiment": name,
                        "root": root,
                    }
                )
    return discovered
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def default_checkpoint_label():
    """Pick the initial value for the checkpoint dropdown.

    Prefers the T4 ablation checkpoint when it exists; otherwise falls
    back to the first discovered checkpoint, or ``None`` when none exist.
    """
    available = discover_checkpoints()
    if not available:
        return None
    for entry in available:
        if entry["path"].endswith("ablation_results/T4/best_model.pt"):
            return entry["label"]
    return available[0]["label"]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def checkpoint_map():
    """Map each checkpoint label to its descriptor dict (fresh scan)."""
    mapping = {}
    for entry in discover_checkpoints():
        mapping[entry["label"]] = entry
    return mapping
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def infer_model_type(experiment_name: str, root: str = "") -> str:
    """Derive the architecture name from the experiment/root naming scheme.

    Ablation runs are always cross-attention D3PMs; otherwise the
    experiment-name prefix decides. Falls back to the global CONFIG value
    when no prefix matches.
    """
    if root == "ablation_results":
        return "d3pm_cross_attention"
    known_prefixes = (
        "d3pm_cross_attention",
        "d3pm_encoder_decoder",
        "baseline_cross_attention",
        "baseline_encoder_decoder",
    )
    for prefix in known_prefixes:
        if experiment_name.startswith(prefix):
            return prefix
    return CONFIG["model_type"]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
    """Decide whether the run was trained with negative examples.

    Ablation runs never use negatives; otherwise the experiment name
    encodes the flag as ``_neg_True`` / ``_neg_False``. Falls back to the
    global CONFIG value when neither marker is present.
    """
    if root == "ablation_results":
        return False
    for marker, flag in (("_neg_True", True), ("_neg_False", False)):
        if marker in experiment_name:
            return flag
    return CONFIG["data"]["include_negative_examples"]
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def build_runtime_cfg(ckpt_path: str):
    """Build a per-checkpoint runtime config from the checkpoint's path.

    Starts from a deep copy of the global CONFIG, then overrides the
    model type and the negative-example flag inferred from the directory
    names. For ablation checkpoints whose directory is named like ``T16``
    the diffusion step count is parsed from the name and applied to both
    the model and inference settings.

    Args:
        ckpt_path: path shaped ``<root>/<experiment>/best_model.pt``.

    Returns:
        Tuple of (cfg dict, resolved torch device, experiment name).
    """
    # Path layout is <root>/<experiment>/best_model.pt.
    experiment = os.path.basename(os.path.dirname(ckpt_path))
    root = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path)))
    # Deep copy so the shared global CONFIG is never mutated.
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = infer_model_type(experiment, root=root)
    cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
    # Ablation runs encode the diffusion step count in the dir name, e.g. "T16".
    if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
        t_val = int(experiment[1:])
        cfg["model"]["diffusion_steps"] = t_val
        cfg["inference"]["num_steps"] = t_val
    device = _resolve_device(cfg)
    return cfg, device, experiment
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def load_selected_model(checkpoint_label):
    """Load the checkpoint behind *checkpoint_label* and build its tokenizers.

    Returns a 4-tuple wired to the Gradio outputs:
        bundle     : dict kept in gr.State (model, tokenizers, cfg, paths)
        status     : markdown status line
        model_info : dict rendered in the "Loaded Model Info" JSON panel
        num_steps  : default inference step count for the slider

    Raises:
        gr.Error: if the label no longer maps to an existing checkpoint.
    """
    mapping = checkpoint_map()
    if checkpoint_label not in mapping:
        raise gr.Error("Selected checkpoint was not found. Refresh the dropdown.")

    ckpt_path = mapping[checkpoint_label]["path"]
    cfg, device, experiment = build_runtime_cfg(ckpt_path)
    # NOTE(review): load_model returns cfg as well — presumably adjusted
    # from checkpoint metadata; confirm against inference.load_model.
    model, cfg = load_model(ckpt_path, cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    # Everything generate_from_ui() needs, kept in gr.State between events.
    bundle = {
        "ckpt_path": ckpt_path,
        "experiment": experiment,
        "device": str(device),
        "cfg": cfg,
        "model": model,
        "src_tok": src_tok,
        "tgt_tok": tgt_tok,
    }

    # Read-only summary shown to the user in the JSON panel.
    model_info = {
        "checkpoint": ckpt_path,
        "experiment": experiment,
        "model_type": cfg["model_type"],
        "include_negatives": cfg["data"]["include_negative_examples"],
        "device": str(device),
        "max_seq_len": cfg["model"]["max_seq_len"],
        "diffusion_steps": cfg["model"]["diffusion_steps"],
        "d_model": cfg["model"]["d_model"],
        "n_layers": cfg["model"]["n_layers"],
        "n_heads": cfg["model"]["n_heads"],
    }
    status = f"Loaded `{experiment}` on `{device}`."
    return bundle, status, model_info, cfg["inference"]["num_steps"]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def apply_preset(preset_name):
    """Return slider values for a named preset.

    The tuple order matches the Gradio outputs:
    (temperature, top_k, repetition_penalty, diversity_penalty, num_steps).
    Unknown names fall back to the "Balanced" preset.
    """
    balanced = (0.70, 40, 1.20, 0.0, 64)
    presets = {
        "Manual": (0.70, 40, 1.20, 0.0, 64),
        "Literal": (0.60, 20, 1.25, 0.0, 64),
        "Balanced": balanced,
        "Creative": (0.85, 80, 1.20, 0.2, 64),
    }
    return presets.get(preset_name, balanced)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def task_notes_md():
    """Markdown body for the "Task Details" accordion in the UI."""
    notes = """
### Task Notes

**Task 1: KV Cache**
- Benchmark encoder caching vs standard generation.
- Best for engineering evaluation, not language quality evaluation.

**Task 2: Attention + Drift**
- Shows internal attention maps and output stabilization over diffusion steps.
- Useful for diagnostics and mentor discussion of model behavior.

**Task 3: Concept Vectors**
- Experimental PCA steering over decoder hidden states.
- Current outputs are exploratory, not strong semantic evidence yet.

**Task 4: Step Ablation**
- Requires retraining separate checkpoints for each diffusion step count.
- Use this UI for generation only; ablation analysis runs from `analysis/run_analysis.py`.

**Task 5: Quality Guidance**
- Advanced experimental feature in the analysis pipeline.
- Not exposed in this UI because the current evidence is still under validation.
"""
    return notes
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def save_generation(experiment, record):
    """Append one generation record to the day's JSON log for *experiment*.

    Log files are named ``<experiment>_ui_<YYYYMMDD>.json`` inside
    RESULTS_DIR and contain a JSON list of record dicts.

    Fix: a corrupt, truncated or hand-edited log file used to crash the
    UI handler via ``json.load``; it now falls back to a fresh list so
    generation keeps working (the unreadable log is overwritten).

    Args:
        experiment: experiment name used in the log filename.
        record: JSON-serializable dict describing one generation.

    Returns:
        Path of the log file written.
    """
    ts = datetime.now().strftime("%Y%m%d")
    path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
    existing = []
    if os.path.exists(path):
        try:
            with open(path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
            # Guard against a log whose top-level value is not a list.
            if isinstance(loaded, list):
                existing = loaded
        except (json.JSONDecodeError, OSError):
            # Unreadable log: start fresh rather than breaking generation.
            pass
    existing.append(record)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(existing, f, ensure_ascii=False, indent=2)
    return path
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def clean_generated_text(text: str, max_consecutive: int = 2, max_occurrence_ratio: float = 0.15) -> str:
    """Trim pathological repetition from diffusion output.

    Two passes over the whitespace-normalized tokens: first cap immediate
    repeats of the same token at *max_consecutive*, then cap any single
    token's total count at ``max(3, len * max_occurrence_ratio)`` (the
    common collapse failure mode). Danda punctuation is re-attached to the
    preceding word before returning.
    """
    tokens = " ".join(text.split()).split()
    if not tokens:
        return ""

    # Pass 1: cap consecutive repeats of the same token.
    deduped = []
    last = None
    streak = 0
    for tok in tokens:
        streak = streak + 1 if tok == last else 1
        last = tok
        if streak <= max_consecutive:
            deduped.append(tok)

    # Pass 2: cap total occurrences of any single token.
    budget = max(3, int(len(deduped) * max_occurrence_ratio))
    seen = {}
    kept = []
    for tok in deduped:
        seen[tok] = seen.get(tok, 0) + 1
        if seen[tok] <= budget:
            kept.append(tok)

    # Re-attach danda / double-danda to the preceding word.
    result = " ".join(kept).replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(result.split())
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def generate_from_ui(
    model_bundle,
    input_text,
    temperature,
    top_k,
    repetition_penalty,
    diversity_penalty,
    num_steps,
    clean_output,
):
    """Run one generation from the UI controls and log the result.

    Args mirror the Gradio inputs: *model_bundle* is the gr.State dict
    built by load_selected_model(); the sampling knobs come from the
    sliders; *clean_output* toggles clean_generated_text().

    Returns:
        (output_text, status_markdown, record_dict) for the UI outputs.

    Raises:
        gr.Error: if no model is loaded or the input box is empty.
    """
    if not model_bundle:
        raise gr.Error("Load a model first.")
    if not input_text.strip():
        raise gr.Error("Enter input text first.")

    # Deep copy so slider values never leak back into the stored bundle cfg.
    cfg = copy.deepcopy(model_bundle["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)

    src_tok = model_bundle["src_tok"]
    tgt_tok = model_bundle["tgt_tok"]
    device = torch.device(model_bundle["device"])

    # Batch of one: [1, src_len] token ids on the model's device.
    input_ids = torch.tensor(
        [src_tok.encode(input_text.strip())],
        dtype=torch.long,
        device=device,
    )
    out = run_inference(model_bundle["model"], input_ids, cfg)
    # Drop special tokens: ids <= 4 are assumed reserved (pad/bos/eos/mask
    # etc.) — TODO confirm against the target tokenizer's special-token ids.
    clean = [x for x in out[0].tolist() if x > 4]
    raw_output_text = tgt_tok.decode(clean).strip()
    output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
    if not output_text:
        output_text = "(empty output)"

    # Persist full provenance (inputs + all sampling knobs) for later review.
    record = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "experiment": model_bundle["experiment"],
        "checkpoint": model_bundle["ckpt_path"],
        "input_text": input_text,
        "raw_output_text": raw_output_text,
        "output_text": output_text,
        "clean_output": bool(clean_output),
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    log_path = save_generation(model_bundle["experiment"], record)
    status = f"Generated with `{model_bundle['experiment']}`. Saved to `{log_path}`."
    return output_text, status, record
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
with gr.Blocks(title="Sanskrit D3PM Studio") as demo:
    # gr.State keeps the loaded-model bundle alive between event calls.
    model_state = gr.State(None)

    gr.Markdown(
        """
        # Sanskrit D3PM Studio

        Load any available checkpoint, generate Devanagari output from Roman/IAST Sanskrit,
        and inspect the settings used for evaluation or demos.
        """
    )

    # ── Checkpoint selection row ──────────────────────────────────────
    with gr.Row():
        with gr.Column(scale=2):
            checkpoint_dropdown = gr.Dropdown(
                label="Available Checkpoints",
                choices=list(checkpoint_map().keys()),
                value=default_checkpoint_label(),
                interactive=True,
            )
        with gr.Column(scale=1):
            refresh_btn = gr.Button("Refresh List")
            load_btn = gr.Button("Load Model", variant="primary")

    load_status = gr.Markdown("Select a checkpoint and load it.")
    model_info = gr.JSON(label="Loaded Model Info")

    # ── Generation row: text I/O left, sampling knobs right ───────────
    with gr.Row():
        with gr.Column(scale=2):
            input_text = gr.Textbox(
                label="Input Text (Roman / IAST Sanskrit)",
                placeholder="dharmo rakṣati rakṣitaḥ",
                lines=4,
            )
            output_text = gr.Textbox(
                label="Generated Output (Devanagari)",
                lines=6,
                interactive=False,
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            preset = gr.Radio(
                ["Manual", "Literal", "Balanced", "Creative"],
                value="Balanced",
                label="Inference Preset",
            )
            # Slider defaults match the "Balanced" preset in apply_preset().
            temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
            top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
            repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
            diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
            num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
            clean_output = gr.Checkbox(value=True, label="Clean Output (dedupe loops)")

    run_status = gr.Markdown("")
    run_record = gr.JSON(label="Last Generation Metadata")

    with gr.Accordion("Task Details and Evaluation Notes", open=False):
        task_notes = gr.Markdown(task_notes_md())

    gr.Examples(
        examples=[
            ["dharmo rakṣati rakṣitaḥ"],
            ["satyameva jayate"],
            ["ahaṃ brahmāsmi"],
            ["yatra nāryastu pūjyante"],
        ],
        inputs=[input_text],
        label="Quick Examples",
    )

    def refresh_checkpoints():
        """Re-scan disk and rebuild the dropdown choices in place."""
        choices = list(checkpoint_map().keys())
        value = choices[0] if choices else None
        return gr.Dropdown(choices=choices, value=value)

    # ── Event wiring ──────────────────────────────────────────────────
    refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
    load_btn.click(
        fn=load_selected_model,
        inputs=[checkpoint_dropdown],
        outputs=[model_state, load_status, model_info, num_steps],
    )
    preset.change(
        fn=apply_preset,
        inputs=[preset],
        outputs=[temperature, top_k, repetition_penalty, diversity_penalty, num_steps],
    )
    # Same handler for the Generate button and Enter in the input box.
    generate_btn.click(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
    input_text.submit(
        fn=generate_from_ui,
        inputs=[
            model_state,
            input_text,
            temperature,
            top_k,
            repetition_penalty,
            diversity_penalty,
            num_steps,
            clean_output,
        ],
        outputs=[output_text, run_status, run_record],
    )
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
if __name__ == "__main__":
    # Honor an explicit GRADIO_SERVER_PORT; otherwise let Gradio pick.
    try:
        port = int(os.environ["GRADIO_SERVER_PORT"])
    except KeyError:
        port = None
    demo.launch(server_name="127.0.0.1", server_port=port, share=False)
|
attention_viz.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/attention_viz.py
|
| 3 |
+
==========================
|
| 4 |
+
Task 2: Attention weight capture and visualization across diffusion steps.
|
| 5 |
+
|
| 6 |
+
How it works (no retraining needed):
|
| 7 |
+
MultiHeadAttention now has two attributes:
|
| 8 |
+
- capture_weights: bool — set True to start storing weights
|
| 9 |
+
- last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
|
| 10 |
+
|
| 11 |
+
AttentionCapture:
|
| 12 |
+
- Sets capture_weights=True on all cross-attention layers
|
| 13 |
+
- Hooks into generate_cached() to record weights at every diffusion step
|
| 14 |
+
- Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]}
|
| 15 |
+
|
| 16 |
+
Visualization:
|
| 17 |
+
- plot_attn_heatmap(): shows src→tgt alignment at a single step
|
| 18 |
+
- plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
|
| 19 |
+
- plot_all_layers(): grid of heatmaps per layer at a given step
|
| 20 |
+
|
| 21 |
+
Usage:
|
| 22 |
+
from analysis.attention_viz import AttentionCapture, plot_attn_heatmap
|
| 23 |
+
|
| 24 |
+
capturer = AttentionCapture(model)
|
| 25 |
+
weights = capturer.capture(src_ids, src_tokens, tgt_tokens)
|
| 26 |
+
plot_attn_heatmap(weights, step=0, layer=0, src_tokens=..., tgt_tokens=...)
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import numpy as np
|
| 31 |
+
import os
|
| 32 |
+
from typing import List, Dict, Optional
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ── Attention capture ─────────────────────────────────────────────────
|
| 36 |
+
|
| 37 |
+
class AttentionCapture:
    """
    Captures cross-attention weights from all decoder layers at every
    diffusion step during cached generation.

    Works by:
      1. Setting capture_weights=True on each DecoderBlock.cross_attn
      2. Running the cached generation loop (encoder runs once via KV cache)
      3. After each denoising step, reading last_attn_weights from each layer
      4. Storing as {t_val: list_of_layer_weights}

    Zero retraining required — uses the flag added to MultiHeadAttention.

    Fixes vs the original:
      - _read_weights() now detaches and moves tensors to CPU before
        .numpy(), so capture also works on CUDA models.
      - torch.nn.functional is imported once per capture() call instead of
        once per diffusion step.
    """

    def __init__(self, model):
        """
        Args:
            model : SanskritModel wrapper (must be D3PMCrossAttention)

        Raises:
            ValueError: if the wrapped model has no cross-attention layers.
        """
        self.model = model
        self.inner = model.model  # D3PMCrossAttention
        self._cross_attns = []

        # Collect all cross-attention modules from decoder blocks.
        if hasattr(self.inner, 'decoder_blocks'):
            for block in self.inner.decoder_blocks:
                if hasattr(block, 'cross_attn'):
                    self._cross_attns.append(block.cross_attn)

        if not self._cross_attns:
            raise ValueError(
                "No cross-attention layers found. "
                "AttentionCapture only works with D3PMCrossAttention."
            )

        print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")

    def _enable(self):
        """Turn on weight capture for all cross-attention layers."""
        for ca in self._cross_attns:
            ca.capture_weights = True

    def _disable(self):
        """Turn off weight capture (restores zero overhead)."""
        for ca in self._cross_attns:
            ca.capture_weights = False
            ca.last_attn_weights = None

    def _read_weights(self) -> List[np.ndarray]:
        """
        Read current last_attn_weights from all layers.

        Each layer stores [B, n_heads, Lq, Lk]; heads are averaged to
        produce one [B, Lq, Lk] array per layer.
        """
        weights = []
        for ca in self._cross_attns:
            if ca.last_attn_weights is not None:
                # Average over attention heads → [B, Lq, Lk].
                # detach().cpu() so this also works for CUDA tensors and
                # tensors that still carry autograd history.
                w = ca.last_attn_weights.detach().float().mean(dim=1).cpu()
                weights.append(w.numpy())
        return weights

    @torch.no_grad()
    def capture(
        self,
        src: torch.Tensor,
        capture_every: int = 10,
    ) -> Dict[int, List[np.ndarray]]:
        """
        Run full generation while capturing attention every `capture_every` steps.

        Args:
            src           : [1, src_len] or [B, src_len] IAST token ids
            capture_every : capture weights every N steps (default 10).
                            Use 1 to capture every step (slow, high memory).

        Returns:
            step_weights : dict mapping t_val → list of [B, Lq, Lk] arrays,
                           one per decoder layer; keys run T-1, T-1-N, ..., 0.
        """
        import torch.nn.functional as F  # hoisted: was re-imported every loop step

        if src.dim() == 1:
            src = src.unsqueeze(0)

        inner = self.inner
        T = inner.scheduler.num_timesteps
        device = src.device

        # KV cache: encode source once.
        memory, src_pad_mask = inner.encode_source(src)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Start fully masked; x0_est is refined at every denoising step.
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None

        step_weights: Dict[int, List[np.ndarray]] = {}

        self._enable()
        try:
            inner.eval()
            for t_val in range(T - 1, -1, -1):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (t_val == 0)

                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

                # Capture at this step if scheduled or it's the last step.
                if (T - 1 - t_val) % capture_every == 0 or is_last:
                    step_weights[t_val] = self._read_weights()

                # NOTE(review): temperature 0.8 is hard-coded here — confirm
                # whether it should track the inference config instead.
                probs = F.softmax(logits / 0.8, dim=-1)
                x0_est = torch.argmax(probs, dim=-1) if is_last else \
                    _multinomial_sample(probs)
                hint = x0_est

        finally:
            self._disable()  # always restore — even if exception raised

        print(f"Captured attention at {len(step_weights)} steps "
              f"({len(self._cross_attns)} layers each).")
        return step_weights
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
    """Draw one categorical sample per position from [B, L, V] probabilities."""
    batch, length, vocab = probs.shape
    # Flatten positions into rows, floor tiny probabilities, renormalize.
    rows = probs.reshape(batch * length, vocab).clamp(min=1e-9)
    rows = rows / rows.sum(dim=-1, keepdim=True)
    draws = torch.multinomial(rows, 1)
    return draws.squeeze(-1).reshape(batch, length)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ── Visualization ─────────────────────────────────────────────────────
|
| 180 |
+
|
| 181 |
+
def plot_attn_heatmap(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    layer: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
    title: Optional[str] = None,
):
    """
    Plot cross-attention heatmap for a single step and layer.

    X-axis = source (IAST) tokens
    Y-axis = target (Devanagari) positions
    Color  = attention weight (brighter = stronger attention)

    Args:
        step_weights : output of AttentionCapture.capture()
        t_val        : which diffusion step to visualize
        layer        : which decoder layer (0 = first, -1 = last)
        src_tokens   : list of IAST token strings for x-axis labels
        tgt_tokens   : list of Devanagari token strings for y-axis labels
        sample_idx   : which batch item to visualize (default 0)
        save_path    : if given, save figure to this path
        title        : custom plot title

    Raises:
        ValueError: if *t_val* was not among the captured steps.
    """
    try:
        # Fix: only pyplot is needed — the unused `matplotlib.ticker`
        # import has been removed.
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    if t_val not in step_weights:
        available = sorted(step_weights.keys())
        raise ValueError(
            f"t_val={t_val} not in captured steps. "
            f"Available: {available[:5]}...{available[-5:]}"
        )

    layers = step_weights[t_val]
    weights = layers[layer][sample_idx]  # [Lq, Lk]

    # Trim to actual token lengths.
    n_src = min(len(src_tokens), weights.shape[1])
    n_tgt = min(len(tgt_tokens), weights.shape[0])
    weights = weights[:n_tgt, :n_src]

    fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
    im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')

    ax.set_xticks(range(n_src))
    ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
    ax.set_yticks(range(n_tgt))
    ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)

    ax.set_xlabel("Source (IAST)", fontsize=11)
    ax.set_ylabel("Target position (Devanagari)", fontsize=11)

    plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
    ax.set_title(plot_title, fontsize=12, pad=10)

    plt.colorbar(im, ax=ax, label="Attention weight")
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def plot_attn_evolution(
    step_weights: Dict[int, List[np.ndarray]],
    src_token_idx: int,
    tgt_token_idx: int,
    layer: int = -1,
    sample_idx: int = 0,
    src_token_str: str = "",
    tgt_token_str: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot how attention between one specific src↔tgt token pair evolves
    across all captured diffusion steps (T → 0).

    Reveals whether a token pair is 'locked' (stable from early steps)
    or 'flexible' (weight fluctuates until final steps).

    Args:
        step_weights  : output of AttentionCapture.capture()
        src_token_idx : index of source token to track
        tgt_token_idx : index of target position to track
        layer         : decoder layer index
        sample_idx    : batch item
        src_token_str : string label for the source token (for plot title)
        tgt_token_str : string label for the target token (for plot title)
        save_path     : if given, save figure to this path
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    t_vals = sorted(step_weights.keys(), reverse=True)  # T-1 → 0
    weights = []

    for t_val in t_vals:
        layers = step_weights[t_val]
        w = layers[layer][sample_idx]  # [Lq, Lk]
        if tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]:
            weights.append(w[tgt_token_idx, src_token_idx])
        else:
            # Out-of-range indices plot as zero rather than raising.
            weights.append(0.0)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(range(len(t_vals)), weights, linewidth=1.5, color='steelblue')
    ax.fill_between(range(len(t_vals)), weights, alpha=0.2, color='steelblue')

    # Mark every ~10th step on x-axis to keep labels readable.
    step_labels = [str(t) if i % max(1, len(t_vals)//10) == 0 else ""
                   for i, t in enumerate(t_vals)]
    ax.set_xticks(range(len(t_vals)))
    ax.set_xticklabels(step_labels, fontsize=8)
    ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
    ax.set_ylabel("Attention weight", fontsize=11)

    pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
    ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
    ax.set_xlim(0, len(t_vals) - 1)
    ax.set_ylim(0, None)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def plot_all_layers(
    step_weights: Dict[int, List[np.ndarray]],
    t_val: int,
    src_tokens: List[str],
    tgt_tokens: List[str],
    sample_idx: int = 0,
    save_path: Optional[str] = None,
):
    """
    Plot attention heatmaps for ALL decoder layers at a single diffusion step.
    Shows how different layers specialize their attention patterns.

    Args:
        step_weights : output of AttentionCapture.capture(); maps diffusion
                       step → list of per-layer weight arrays
        t_val        : diffusion step whose captured weights to plot
        src_tokens   : source-side token labels (x axis)
        tgt_tokens   : target-side token labels (y axis)
        sample_idx   : batch item to visualise
        save_path    : if given, save the figure instead of showing it
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib to use visualization functions.")
        return

    layers = step_weights[t_val]
    n_layers = len(layers)
    # Grid layout: at most 4 columns, enough rows to fit every layer
    n_cols = min(4, n_layers)
    n_rows = (n_layers + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 5, n_rows * 4))
    # plt.subplots returns a scalar Axes for a 1x1 grid, an array otherwise;
    # normalise to a flat sequence either way
    axes = np.array(axes).flatten() if n_layers > 1 else [axes]

    # Clip label counts to the actual attention-matrix dimensions
    n_src = min(len(src_tokens), layers[0][sample_idx].shape[1])
    n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0])

    for i, (ax, layer_w) in enumerate(zip(axes, layers)):
        w = layer_w[sample_idx][:n_tgt, :n_src]
        # NOTE(review): vmax is per-layer (w.max()), so colour intensity is
        # NOT comparable across subplots — confirm this is intended.
        # `im` is kept but unused (no shared colorbar is drawn).
        im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest',
                       vmin=0, vmax=w.max())
        ax.set_title(f"Layer {i}", fontsize=10)
        ax.set_xticks(range(n_src))
        ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
        ax.set_yticks(range(n_tgt))
        ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)

    # Hide any grid cells beyond the last layer
    for ax in axes[n_layers:]:
        ax.set_visible(False)

    fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
    plt.tight_layout()

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
best_model.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b1baa03568c2bed42621da115f6e6971411b59cc9dec6b58cf8f2ed87ba2e770
|
| 3 |
+
size 1077681643
|
concept_vectors.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/concept_vectors.py
|
| 3 |
+
============================
|
| 4 |
+
Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
|
| 5 |
+
|
| 6 |
+
No retraining required. Uses decoder hidden states already computed
|
| 7 |
+
during generate_cached() — stored in model.model._last_hidden after
|
| 8 |
+
each forward_cached() call.
|
| 9 |
+
|
| 10 |
+
Steps:
|
| 11 |
+
1. Collect hidden states from N examples at a fixed diffusion step
|
| 12 |
+
2. Pool sequence dimension → [N, d_model] representation per example
|
| 13 |
+
3. PCA → find principal directions in concept space
|
| 14 |
+
4. Identify "diversity direction" (PC that best separates short/long outputs)
|
| 15 |
+
5. Steer: at inference, shift hidden states along diversity direction
|
| 16 |
+
before the output head projection
|
| 17 |
+
6. Generate at 5 points along the direction, measure output diversity
|
| 18 |
+
|
| 19 |
+
Key insight: the diversity direction is found purely from model outputs
|
| 20 |
+
(no human annotation needed). We use output length as a proxy:
|
| 21 |
+
short output → low diversity (model collapsed to simple token)
|
| 22 |
+
long output → high diversity (model exploring more of the space)
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import numpy as np
|
| 29 |
+
from typing import List, Dict, Optional, Tuple
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ── Hidden state collection ───────────────────────────────────────────
|
| 33 |
+
|
| 34 |
+
@torch.no_grad()
def collect_hidden_states(
    model,
    src_list: List[torch.Tensor],
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 1000,
    tgt_tokenizer=None,
) -> Tuple[np.ndarray, List[str]]:
    """
    Run the cached reverse process on a list of source tensors, collecting
    the decoder hidden state at timestep t_capture for each sample.

    Args:
        model        : SanskritModel (D3PMCrossAttention)
        src_list     : list of [1, src_len] (or [src_len]) tensors, one per sample
        t_capture    : which diffusion step to capture hidden states at;
                       0 = final (clean), T-1 = noisy start
        temperature  : sampling temperature
        top_k        : top-k filter (0 disables filtering)
        max_samples  : cap at this many samples
        tgt_tokenizer: optional tokenizer; when given, outputs are decoded to
                       strings, otherwise space-joined token ids are stored

    Returns:
        hidden_matrix : np.ndarray [N, d_model] — pooled hidden states
        output_texts  : list of N output strings (for diversity analysis),
                        aligned 1:1 with the rows of hidden_matrix
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    hidden_list = []
    output_list = []

    n = min(len(src_list), max_samples)
    print(f"Collecting hidden states from {n} examples at t={t_capture}...")

    for i, src in enumerate(src_list[:n]):
        if i % 100 == 0:
            print(f"  {i}/{n}")

        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Encode the source once per sample (KV cache for all diffusion steps)
        memory, src_pad_mask = inner.encode_source(src)

        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        captured_hidden = None

        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (t_val == 0)

            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint, inference_mode=True,
            )

            # Capture hidden state at the requested step (model stashes it
            # in _last_hidden after each forward_cached call)
            if t_val == t_capture and hasattr(inner, '_last_hidden'):
                captured_hidden = inner._last_hidden.detach().cpu()

            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                V = logits.shape[-1]
                if top_k < V:
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Greedy at the final step, stochastic sampling otherwise
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

        # Pool hidden state over non-PAD positions → [d_model]
        if captured_hidden is not None:
            non_pad = (x0_est[0] > 1).cpu()  # assumes PAD/MASK ids ≤ 1 — TODO confirm
            if non_pad.sum() > 0:
                h = captured_hidden[0][non_pad].mean(dim=0)  # [d_model]
            else:
                h = captured_hidden[0].mean(dim=0)
            hidden_list.append(h.numpy())

            # BUGFIX: `ids` used to be computed but never stored, so
            # output_list was always returned empty despite the docstring.
            # Keep output_list aligned with hidden_list (append together).
            ids = [x for x in x0_est[0].tolist() if x > 4]
            if tgt_tokenizer is not None:
                output_list.append(tgt_tokenizer.decode(ids).strip())
            else:
                output_list.append(" ".join(map(str, ids)))

    print(f"Collected {len(hidden_list)} hidden states.")
    return np.stack(hidden_list), output_list
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
# ── PCA on hidden states ──────────────────────────────────────────────
|
| 130 |
+
|
| 131 |
+
def fit_pca(
    hidden_matrix: np.ndarray,
    n_components: int = 50,
) -> object:
    """
    Fit a PCA model to the pooled hidden-state matrix.

    Args:
        hidden_matrix : [N, d_model]
        n_components  : upper bound on retained components; capped by both
                        the sample count (N-1) and the feature dimension

    Returns:
        fitted sklearn PCA object
    """
    from sklearn.decomposition import PCA

    n_samples, n_features = hidden_matrix.shape
    n_comp = min(n_components, n_samples - 1, n_features)

    pca = PCA(n_components=n_comp)
    pca.fit(hidden_matrix)

    explained_pct = pca.explained_variance_ratio_.sum() * 100
    print(f"PCA fit: {n_comp} components explain "
          f"{explained_pct:.1f}% of variance.")
    return pca
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def find_diversity_direction(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
) -> Tuple[np.ndarray, int, float]:
    """
    Find the PCA direction that best correlates with output diversity
    (measured by output length as proxy).

    Projects hidden states into PCA space, then selects the PC whose
    scores have the highest |Spearman correlation| with output lengths.

    Args:
        hidden_matrix  : [N, d_model] pooled hidden states
        output_lengths : N output lengths (diversity proxy)
        pca            : fitted PCA-like object (uses .transform and .components_)

    Returns:
        direction : np.ndarray [d_model] — unit-norm diversity direction
        best_pc   : int   — index of the selected principal component
        best_corr : float — |Spearman r| between that PC and output length

    Note:
        The original signature claimed a bare np.ndarray return, but the
        function has always returned this 3-tuple; the annotation and
        docstring are now fixed to match.
    """
    from scipy.stats import spearmanr

    projected = pca.transform(hidden_matrix)  # [N, n_components]
    lengths = np.array(output_lengths)

    correlations = []
    for pc_idx in range(projected.shape[1]):
        r, _ = spearmanr(projected[:, pc_idx], lengths)
        # spearmanr yields NaN for constant inputs; treat as zero correlation
        # so it can never win the argmax below.
        correlations.append(0.0 if np.isnan(r) else abs(r))

    best_pc = int(np.argmax(correlations))
    print(f"Diversity direction: PC {best_pc} "
          f"(|r|={correlations[best_pc]:.3f} with output length)")

    # Map the chosen PC back to the original d_model space, unit-normalised
    direction = pca.components_[best_pc]  # [d_model]
    direction = direction / (np.linalg.norm(direction) + 1e-8)
    return direction, best_pc, correlations[best_pc]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ── Steered generation ────────────────────────────────────────────────
|
| 190 |
+
|
| 191 |
+
@torch.no_grad()
def generate_steered(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    alpha: float = 0.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """
    Generate output while steering hidden states along the diversity direction.

    At each diffusion step, after the decoder runs, the hidden state is
    shifted by alpha * direction before projecting to logits.

        alpha > 0 → push toward high-diversity output
        alpha < 0 → push toward low-diversity output
        alpha = 0 → standard generation (no steering)

    Args:
        model     : SanskritModel (D3PMCrossAttention)
        src       : [1, src_len] IAST token ids
        direction : [d_model] diversity direction from find_diversity_direction()
        alpha     : steering strength
        temperature / top_k: sampling params

    Returns:
        x0_est : [1, tgt_len] generated token ids
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)

    # Encode once; the whole target starts fully masked
    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        # Inline re-implementation of forward_cached so the hidden state can
        # be intercepted between the decoder stack and the output head.
        # NOTE(review): must stay in sync with the model's forward_cached().
        PAD = 1  # unused here; kept for parity with the original forward code
        tgt_pad_mask = None  # inference_mode: no target padding mask

        # Re-noise the current estimate for step t (identity at t == 0).
        # NOTE(review): assumes q_sample returns a pair whose second element
        # is the noised id tensor — confirm against forward_process.q_sample.
        _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
                     (None, x0_est)
        x = inner.tgt_embed(x_t_ids)
        t_norm = t.float() / inner.scheduler.num_timesteps
        t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)  # broadcast time embedding over positions

        # Gated self-conditioning on the previous step's estimate
        if hint is not None:
            hint_emb = inner.tgt_embed(hint)
            gate = inner.hint_gate(x)
            x = x + gate * hint_emb

        for block in inner.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        # ── STEERING: shift hidden states along diversity direction ───
        if alpha != 0.0:
            x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)

        # Project to logits using the head
        logits = inner.head(x)

        logits = logits / max(temperature, 1e-8)  # guard against temperature == 0
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

        probs = F.softmax(logits, dim=-1)
        # Greedy decode at the final step, stochastic sampling otherwise
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

    return x0_est
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def generate_diversity_spectrum(
    model,
    src: torch.Tensor,
    direction: np.ndarray,
    tgt_tokenizer,
    alphas: Optional[List[float]] = None,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict[float, str]:
    """
    Generate outputs at several points along the diversity direction.

    Args:
        model        : SanskritModel (D3PMCrossAttention)
        src          : [1, src_len] IAST token ids
        direction    : [d_model] diversity direction
        tgt_tokenizer: tokenizer used to decode generated ids
        alphas       : steering strengths (negative = low diversity,
                       positive = high); defaults to [-2, -1, 0, 1, 2]
        temperature / top_k: sampling params

    Returns:
        dict mapping alpha → decoded Devanagari string
    """
    # BUGFIX: the default was a mutable list literal; use a None sentinel
    # so the default can never be mutated across calls.
    if alphas is None:
        alphas = [-2.0, -1.0, 0.0, 1.0, 2.0]

    results = {}
    for alpha in alphas:
        out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
        # ids ≤ 4 are treated as special tokens here, matching the filtering
        # used elsewhere in this module — TODO confirm against the tokenizer.
        ids = [x for x in out_ids[0].tolist() if x > 4]
        text = tgt_tokenizer.decode(ids).strip()
        results[alpha] = text
        print(f"  alpha={alpha:+.1f} → {text}")
    return results
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def plot_pca_space(
    hidden_matrix: np.ndarray,
    output_lengths: List[int],
    pca: object,
    diversity_pc: int,
    save_path: Optional[str] = None,
):
    """
    Scatter plot of examples in PC0 vs PC1 space, coloured by output length,
    plus the cumulative explained-variance curve with the diversity PC marked.

    Args:
        hidden_matrix  : [N, d_model] pooled hidden states
        output_lengths : N output lengths (colour values)
        pca            : fitted PCA object (uses .transform and
                         .explained_variance_ratio_)
        diversity_pc   : 0-based index of the diversity principal component
        save_path      : if given, save the figure instead of showing it
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib.")
        return

    projected = pca.transform(hidden_matrix)  # [N, n_pc]
    lengths = np.array(output_lengths)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: PC0 vs PC1 coloured by length
    ax = axes[0]
    sc = ax.scatter(projected[:, 0], projected[:, 1],
                    c=lengths, cmap='viridis', alpha=0.6, s=15)
    plt.colorbar(sc, ax=ax, label="Output length (chars)")
    ax.set_xlabel(f"PC0 ({pca.explained_variance_ratio_[0]*100:.1f}%)", fontsize=10)
    ax.set_ylabel(f"PC1 ({pca.explained_variance_ratio_[1]*100:.1f}%)", fontsize=10)
    ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)

    # Right: cumulative explained variance; x axis counts components from 1
    ax2 = axes[1]
    cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
    ax2.plot(range(1, len(cumvar)+1), cumvar, linewidth=1.5, color='steelblue')
    # BUGFIX: diversity_pc is 0-based but the curve is plotted at x = 1..n
    # ("Number of PCs"), so the marker belongs at diversity_pc + 1 — the
    # old axvline(diversity_pc) was off by one.
    ax2.axvline(diversity_pc + 1, color='coral', linestyle='--', label=f"Diversity PC={diversity_pc}")
    ax2.set_xlabel("Number of PCs", fontsize=10)
    ax2.set_ylabel("Cumulative variance (%)", fontsize=10)
    ax2.set_title("PCA explained variance", fontsize=11)
    ax2.legend()
    ax2.set_ylim(0, 102)

    plt.tight_layout()
    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def _sample(probs):
|
| 365 |
+
B, L, V = probs.shape
|
| 366 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 367 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 368 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
config_T16.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ablation config: T=16 diffusion steps.
# NOTE(review): the in-file default for DIFFUSION_STEPS is 128 — the "T=16"
# in the name only holds when the caller (e.g. train_all.sh) exports
# DIFFUSION_STEPS=16. Confirm the driver script sets it.
import os
import torch


def _get_env_int(name, default):
    # Integer environment override; falls back to `default` when unset.
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    # String environment override with fallback.
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
# Inference never uses more than 64 reverse steps, and never more than T
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    # (Older commented-out 384/512-wide model + training presets removed;
    # see git history if those ablation settings are needed again.)

    'diffusion': {
        'mask_token_id': 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    'model': {
        'src_vocab_size': 16000,  # Roman/IAST BPE vocab
        'tgt_vocab_size': 16000,  # Devanagari BPE vocab
        'd_model': 1024,          # model width (earlier runs used 384/512)
        'n_heads': 8,             # 1024 / 8 = 128 head_dim
        'd_ff': 4096,             # 4 × d_model
        'n_layers': 8,
        'dropout': 0.2,
        'max_seq_len': 80,
        'diffusion_steps': DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    'training': {
        'epochs': 20,
        'batch_size': 32,
        'accum_steps': 2,        # gradient accumulation → effective batch = 64
        'lr': 7e-5,              # lowered from 6e-4 for stability at this width
        'label_smoothing': 0.1,  # reduces overconfidence
        'patience': 4,           # early stop after 4 non-improving epochs
        'l1_lambda': 1e-7,       # very light L1 regularisation
        'device': TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    'inference': {
        'num_steps': INFERENCE_STEPS,
        'temperature': 0.7,        # slightly lower = more confident output
        'top_k': 40,
        'repetition_penalty': 1.2,
        # NOTE(review): the old comment said "keep off" but the value is 0.5
        # — confirm whether the diversity penalty should be active.
        'diversity_penalty': 0.5,
    },
}
|
config_T32.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ablation config: T=32 diffusion steps.
# NOTE(review): the in-file default for DIFFUSION_STEPS is 128 — the "T=32"
# in the name only holds when the caller (e.g. train_all.sh) exports
# DIFFUSION_STEPS=32. Confirm the driver script sets it.
import os
import torch


def _get_env_int(name, default):
    # Integer environment override; falls back to `default` when unset.
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    # String environment override with fallback.
    return os.environ.get(name, default)


# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
# Inference never uses more than 64 reverse steps, and never more than T
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    # (Older commented-out 384/512-wide model + training presets removed;
    # see git history if those ablation settings are needed again.)

    'diffusion': {
        'mask_token_id': 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    'model': {
        'src_vocab_size': 16000,  # Roman/IAST BPE vocab
        'tgt_vocab_size': 16000,  # Devanagari BPE vocab
        'd_model': 1024,          # model width (earlier runs used 384/512)
        'n_heads': 8,             # 1024 / 8 = 128 head_dim
        'd_ff': 4096,             # 4 × d_model
        'n_layers': 8,
        'dropout': 0.2,
        'max_seq_len': 80,
        'diffusion_steps': DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    'training': {
        'epochs': 20,
        'batch_size': 32,
        'accum_steps': 2,        # gradient accumulation → effective batch = 64
        'lr': 7e-5,              # lowered from 6e-4 for stability at this width
        'label_smoothing': 0.1,  # reduces overconfidence
        'patience': 4,           # early stop after 4 non-improving epochs
        'l1_lambda': 1e-7,       # very light L1 regularisation
        'device': TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    'inference': {
        'num_steps': INFERENCE_STEPS,
        'temperature': 0.7,        # slightly lower = more confident output
        'top_k': 40,
        'repetition_penalty': 1.2,
        # NOTE(review): the old comment said "keep off" but the value is 0.5
        # — confirm whether the diversity penalty should be active.
        'diversity_penalty': 0.5,
    },
}
|
config_T4.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ablation config: T=4 diffusion steps
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _get_env_int(name, default):
|
| 7 |
+
value = os.environ.get(name)
|
| 8 |
+
return int(value) if value is not None else default
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _get_env_str(name, default):
|
| 12 |
+
return os.environ.get(name, default)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
|
| 16 |
+
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
|
| 17 |
+
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
|
| 18 |
+
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
|
| 19 |
+
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
|
| 20 |
+
TRAIN_DEVICE = _get_env_str(
|
| 21 |
+
"TRAIN_DEVICE",
|
| 22 |
+
"mps" if torch.backends.mps.is_available() else "cpu",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
CONFIG = {
|
| 26 |
+
"model_type": MODEL,
|
| 27 |
+
|
| 28 |
+
"data": {
|
| 29 |
+
"include_negative_examples": NEGATIVES,
|
| 30 |
+
"dataset_size": 60000,
|
| 31 |
+
},
|
| 32 |
+
|
| 33 |
+
# "model": {
|
| 34 |
+
# "vocab_size": 16000,
|
| 35 |
+
# "max_seq_len": 80,
|
| 36 |
+
# "diffusion_steps": 10,
|
| 37 |
+
# "d_model": 384,
|
| 38 |
+
# "n_layers": 6,
|
| 39 |
+
# "n_heads": 6,
|
| 40 |
+
# "d_ff": 1536,
|
| 41 |
+
# "dropout": 0.15
|
| 42 |
+
# },
|
| 43 |
+
#
|
| 44 |
+
# "diffusion": {
|
| 45 |
+
# "mask_token_id": 0
|
| 46 |
+
# },
|
| 47 |
+
#
|
| 48 |
+
# "training": {
|
| 49 |
+
# "batch_size": 32,
|
| 50 |
+
# "epochs": 10,
|
| 51 |
+
# "lr": 2e-4,
|
| 52 |
+
# "label_smoothing": 0.05,
|
| 53 |
+
# "precision": "float32",
|
| 54 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 55 |
+
# "early_stopping_patience": 3
|
| 56 |
+
# }
|
| 57 |
+
# "model": {
|
| 58 |
+
# "vocab_size": 16000,
|
| 59 |
+
# "max_seq_len": 96, # Optimized for GRETIL slokas
|
| 60 |
+
# "diffusion_steps": 16, # Use 16 steps (better than 8)
|
| 61 |
+
# "d_model": 512, # Wider model learns faster
|
| 62 |
+
# "n_layers": 8,
|
| 63 |
+
# "n_heads": 8,
|
| 64 |
+
# "d_ff": 2048,
|
| 65 |
+
# "dropout": 0.1
|
| 66 |
+
# },
|
| 67 |
+
#
|
| 68 |
+
# "diffusion": {
|
| 69 |
+
# "mask_token_id": 0
|
| 70 |
+
# },
|
| 71 |
+
#
|
| 72 |
+
# "training": {
|
| 73 |
+
# "batch_size": 32,
|
| 74 |
+
# "epochs": 20, # 20 is enough with these tweaks
|
| 75 |
+
# "lr": 4e-4, # Higher LR + Warmup for speed
|
| 76 |
+
# "label_smoothing": 0.15, # Increased for 16k vocab stability
|
| 77 |
+
# "precision": "float32",
|
| 78 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 79 |
+
# "early_stopping_patience": 5
|
| 80 |
+
# }
|
| 81 |
+
'diffusion': {
|
| 82 |
+
'mask_token_id': 0, # [MASK] = ID 0, fixed by tokenizer
|
| 83 |
+
},
|
| 84 |
+
|
| 85 |
+
# ── Model architecture ────────────────────────────────────────────
|
| 86 |
+
'model': {
|
| 87 |
+
# 'vocab_size': 16000,
|
| 88 |
+
'src_vocab_size': 16000, # Roman/IAST BPE vocab
|
| 89 |
+
'tgt_vocab_size': 16000, # Devanagari BPE vocab
|
| 90 |
+
'd_model': 1024,#512, # was 384 — kept same, shared embeds save params
|
| 91 |
+
'n_heads': 8, # 384 / 6 = 64 head_dim
|
| 92 |
+
'd_ff': 4096, #2048, #1536, # 4 × d_model
|
| 93 |
+
'n_layers': 8,#4,
|
| 94 |
+
'dropout': 0.2,
|
| 95 |
+
'max_seq_len': 80,
|
| 96 |
+
'diffusion_steps': DIFFUSION_STEPS,
|
| 97 |
+
},
|
| 98 |
+
|
| 99 |
+
# ── Training ──────────────────────────────────────────────────────
|
| 100 |
+
'training': {
|
| 101 |
+
'epochs': 20, # Target: 0.71→0.83-0.85 in 5 epochs
|
| 102 |
+
'batch_size': 32,
|
| 103 |
+
'accum_steps': 2, # effective batch = 64
|
| 104 |
+
'lr': 7e-5,#6e-4, # raised from 3e-4; warmup protects first steps
|
| 105 |
+
'label_smoothing': 0.1, # was 0.0; reduces overconfidence (gap 1.7 nats)
|
| 106 |
+
'patience': 4, # early stop after 4 non-improving epochs
|
| 107 |
+
'l1_lambda': 1e-7, # very light L1
|
| 108 |
+
'device': TRAIN_DEVICE,
|
| 109 |
+
},
|
| 110 |
+
|
| 111 |
+
# ── Inference (used during val BERTScore and generate()) ──────────
|
| 112 |
+
'inference': {
|
| 113 |
+
'num_steps': INFERENCE_STEPS,
|
| 114 |
+
'temperature': 0.7, # slightly lower = more confident output
|
| 115 |
+
'top_k': 40,
|
| 116 |
+
'repetition_penalty': 1.2,
|
| 117 |
+
'diversity_penalty': 0.5, # keep off; global-mean penalty is conservative
|
| 118 |
+
},
|
| 119 |
+
}
|
config_T64.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ablation config: T=64 diffusion steps
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _get_env_int(name, default):
|
| 7 |
+
value = os.environ.get(name)
|
| 8 |
+
return int(value) if value is not None else default
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _get_env_str(name, default):
|
| 12 |
+
return os.environ.get(name, default)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
|
| 16 |
+
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
|
| 17 |
+
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
|
| 18 |
+
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
|
| 19 |
+
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
|
| 20 |
+
TRAIN_DEVICE = _get_env_str(
|
| 21 |
+
"TRAIN_DEVICE",
|
| 22 |
+
"mps" if torch.backends.mps.is_available() else "cpu",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
CONFIG = {
|
| 26 |
+
"model_type": MODEL,
|
| 27 |
+
|
| 28 |
+
"data": {
|
| 29 |
+
"include_negative_examples": NEGATIVES,
|
| 30 |
+
"dataset_size": 60000,
|
| 31 |
+
},
|
| 32 |
+
|
| 33 |
+
# "model": {
|
| 34 |
+
# "vocab_size": 16000,
|
| 35 |
+
# "max_seq_len": 80,
|
| 36 |
+
# "diffusion_steps": 10,
|
| 37 |
+
# "d_model": 384,
|
| 38 |
+
# "n_layers": 6,
|
| 39 |
+
# "n_heads": 6,
|
| 40 |
+
# "d_ff": 1536,
|
| 41 |
+
# "dropout": 0.15
|
| 42 |
+
# },
|
| 43 |
+
#
|
| 44 |
+
# "diffusion": {
|
| 45 |
+
# "mask_token_id": 0
|
| 46 |
+
# },
|
| 47 |
+
#
|
| 48 |
+
# "training": {
|
| 49 |
+
# "batch_size": 32,
|
| 50 |
+
# "epochs": 10,
|
| 51 |
+
# "lr": 2e-4,
|
| 52 |
+
# "label_smoothing": 0.05,
|
| 53 |
+
# "precision": "float32",
|
| 54 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 55 |
+
# "early_stopping_patience": 3
|
| 56 |
+
# }
|
| 57 |
+
# "model": {
|
| 58 |
+
# "vocab_size": 16000,
|
| 59 |
+
# "max_seq_len": 96, # Optimized for GRETIL slokas
|
| 60 |
+
# "diffusion_steps": 16, # Use 16 steps (better than 8)
|
| 61 |
+
# "d_model": 512, # Wider model learns faster
|
| 62 |
+
# "n_layers": 8,
|
| 63 |
+
# "n_heads": 8,
|
| 64 |
+
# "d_ff": 2048,
|
| 65 |
+
# "dropout": 0.1
|
| 66 |
+
# },
|
| 67 |
+
#
|
| 68 |
+
# "diffusion": {
|
| 69 |
+
# "mask_token_id": 0
|
| 70 |
+
# },
|
| 71 |
+
#
|
| 72 |
+
# "training": {
|
| 73 |
+
# "batch_size": 32,
|
| 74 |
+
# "epochs": 20, # 20 is enough with these tweaks
|
| 75 |
+
# "lr": 4e-4, # Higher LR + Warmup for speed
|
| 76 |
+
# "label_smoothing": 0.15, # Increased for 16k vocab stability
|
| 77 |
+
# "precision": "float32",
|
| 78 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 79 |
+
# "early_stopping_patience": 5
|
| 80 |
+
# }
|
| 81 |
+
'diffusion': {
|
| 82 |
+
'mask_token_id': 0, # [MASK] = ID 0, fixed by tokenizer
|
| 83 |
+
},
|
| 84 |
+
|
| 85 |
+
# ── Model architecture ────────────────────────────────────────────
|
| 86 |
+
'model': {
|
| 87 |
+
# 'vocab_size': 16000,
|
| 88 |
+
'src_vocab_size': 16000, # Roman/IAST BPE vocab
|
| 89 |
+
'tgt_vocab_size': 16000, # Devanagari BPE vocab
|
| 90 |
+
'd_model': 1024,#512, # was 384 — kept same, shared embeds save params
|
| 91 |
+
'n_heads': 8, # 384 / 6 = 64 head_dim
|
| 92 |
+
'd_ff': 4096, #2048, #1536, # 4 × d_model
|
| 93 |
+
'n_layers': 8,#4,
|
| 94 |
+
'dropout': 0.2,
|
| 95 |
+
'max_seq_len': 80,
|
| 96 |
+
'diffusion_steps': DIFFUSION_STEPS,
|
| 97 |
+
},
|
| 98 |
+
|
| 99 |
+
# ── Training ──────────────────────────────────────────────────────
|
| 100 |
+
'training': {
|
| 101 |
+
'epochs': 20, # Target: 0.71→0.83-0.85 in 5 epochs
|
| 102 |
+
'batch_size': 32,
|
| 103 |
+
'accum_steps': 2, # effective batch = 64
|
| 104 |
+
'lr': 7e-5,#6e-4, # raised from 3e-4; warmup protects first steps
|
| 105 |
+
'label_smoothing': 0.1, # was 0.0; reduces overconfidence (gap 1.7 nats)
|
| 106 |
+
'patience': 4, # early stop after 4 non-improving epochs
|
| 107 |
+
'l1_lambda': 1e-7, # very light L1
|
| 108 |
+
'device': TRAIN_DEVICE,
|
| 109 |
+
},
|
| 110 |
+
|
| 111 |
+
# ── Inference (used during val BERTScore and generate()) ──────────
|
| 112 |
+
'inference': {
|
| 113 |
+
'num_steps': INFERENCE_STEPS,
|
| 114 |
+
'temperature': 0.7, # slightly lower = more confident output
|
| 115 |
+
'top_k': 40,
|
| 116 |
+
'repetition_penalty': 1.2,
|
| 117 |
+
'diversity_penalty': 0.5, # keep off; global-mean penalty is conservative
|
| 118 |
+
},
|
| 119 |
+
}
|
config_T8.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ablation config: T=8 diffusion steps
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _get_env_int(name, default):
|
| 7 |
+
value = os.environ.get(name)
|
| 8 |
+
return int(value) if value is not None else default
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def _get_env_str(name, default):
|
| 12 |
+
return os.environ.get(name, default)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# 🎛️ BASH-CONTROLLED SWITCHES (Defaults if run manually)
|
| 16 |
+
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
|
| 17 |
+
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
|
| 18 |
+
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
|
| 19 |
+
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
|
| 20 |
+
TRAIN_DEVICE = _get_env_str(
|
| 21 |
+
"TRAIN_DEVICE",
|
| 22 |
+
"mps" if torch.backends.mps.is_available() else "cpu",
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
CONFIG = {
|
| 26 |
+
"model_type": MODEL,
|
| 27 |
+
|
| 28 |
+
"data": {
|
| 29 |
+
"include_negative_examples": NEGATIVES,
|
| 30 |
+
"dataset_size": 60000,
|
| 31 |
+
},
|
| 32 |
+
|
| 33 |
+
# "model": {
|
| 34 |
+
# "vocab_size": 16000,
|
| 35 |
+
# "max_seq_len": 80,
|
| 36 |
+
# "diffusion_steps": 10,
|
| 37 |
+
# "d_model": 384,
|
| 38 |
+
# "n_layers": 6,
|
| 39 |
+
# "n_heads": 6,
|
| 40 |
+
# "d_ff": 1536,
|
| 41 |
+
# "dropout": 0.15
|
| 42 |
+
# },
|
| 43 |
+
#
|
| 44 |
+
# "diffusion": {
|
| 45 |
+
# "mask_token_id": 0
|
| 46 |
+
# },
|
| 47 |
+
#
|
| 48 |
+
# "training": {
|
| 49 |
+
# "batch_size": 32,
|
| 50 |
+
# "epochs": 10,
|
| 51 |
+
# "lr": 2e-4,
|
| 52 |
+
# "label_smoothing": 0.05,
|
| 53 |
+
# "precision": "float32",
|
| 54 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 55 |
+
# "early_stopping_patience": 3
|
| 56 |
+
# }
|
| 57 |
+
# "model": {
|
| 58 |
+
# "vocab_size": 16000,
|
| 59 |
+
# "max_seq_len": 96, # Optimized for GRETIL slokas
|
| 60 |
+
# "diffusion_steps": 16, # Use 16 steps (better than 8)
|
| 61 |
+
# "d_model": 512, # Wider model learns faster
|
| 62 |
+
# "n_layers": 8,
|
| 63 |
+
# "n_heads": 8,
|
| 64 |
+
# "d_ff": 2048,
|
| 65 |
+
# "dropout": 0.1
|
| 66 |
+
# },
|
| 67 |
+
#
|
| 68 |
+
# "diffusion": {
|
| 69 |
+
# "mask_token_id": 0
|
| 70 |
+
# },
|
| 71 |
+
#
|
| 72 |
+
# "training": {
|
| 73 |
+
# "batch_size": 32,
|
| 74 |
+
# "epochs": 20, # 20 is enough with these tweaks
|
| 75 |
+
# "lr": 4e-4, # Higher LR + Warmup for speed
|
| 76 |
+
# "label_smoothing": 0.15, # Increased for 16k vocab stability
|
| 77 |
+
# "precision": "float32",
|
| 78 |
+
# "device": "mps" if torch.backends.mps.is_available() else "cpu",
|
| 79 |
+
# "early_stopping_patience": 5
|
| 80 |
+
# }
|
| 81 |
+
'diffusion': {
|
| 82 |
+
'mask_token_id': 0, # [MASK] = ID 0, fixed by tokenizer
|
| 83 |
+
},
|
| 84 |
+
|
| 85 |
+
# ── Model architecture ────────────────────────────────────────────
|
| 86 |
+
'model': {
|
| 87 |
+
# 'vocab_size': 16000,
|
| 88 |
+
'src_vocab_size': 16000, # Roman/IAST BPE vocab
|
| 89 |
+
'tgt_vocab_size': 16000, # Devanagari BPE vocab
|
| 90 |
+
'd_model': 1024,#512, # was 384 — kept same, shared embeds save params
|
| 91 |
+
'n_heads': 8, # 384 / 6 = 64 head_dim
|
| 92 |
+
'd_ff': 4096, #2048, #1536, # 4 × d_model
|
| 93 |
+
'n_layers': 8,#4,
|
| 94 |
+
'dropout': 0.2,
|
| 95 |
+
'max_seq_len': 80,
|
| 96 |
+
'diffusion_steps': DIFFUSION_STEPS,
|
| 97 |
+
},
|
| 98 |
+
|
| 99 |
+
# ── Training ──────────────────────────────────────────────────────
|
| 100 |
+
'training': {
|
| 101 |
+
'epochs': 20, # Target: 0.71→0.83-0.85 in 5 epochs
|
| 102 |
+
'batch_size': 32,
|
| 103 |
+
'accum_steps': 2, # effective batch = 64
|
| 104 |
+
'lr': 7e-5,#6e-4, # raised from 3e-4; warmup protects first steps
|
| 105 |
+
'label_smoothing': 0.1, # was 0.0; reduces overconfidence (gap 1.7 nats)
|
| 106 |
+
'patience': 4, # early stop after 4 non-improving epochs
|
| 107 |
+
'l1_lambda': 1e-7, # very light L1
|
| 108 |
+
'device': TRAIN_DEVICE,
|
| 109 |
+
},
|
| 110 |
+
|
| 111 |
+
# ── Inference (used during val BERTScore and generate()) ──────────
|
| 112 |
+
'inference': {
|
| 113 |
+
'num_steps': INFERENCE_STEPS,
|
| 114 |
+
'temperature': 0.7, # slightly lower = more confident output
|
| 115 |
+
'top_k': 40,
|
| 116 |
+
'repetition_penalty': 1.2,
|
| 117 |
+
'diversity_penalty': 0.5, # keep off; global-mean penalty is conservative
|
| 118 |
+
},
|
| 119 |
+
}
|
d3pm_model_cross_attention.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
d3pm_model_cross_attention.py — Cross-Script + Generation-Fixed
|
| 3 |
+
=================================================================
|
| 4 |
+
INPUT : quote_text tokens (Roman script, src_vocab_size)
|
| 5 |
+
OUTPUT : quote_devanagari tokens (Devanagari script, tgt_vocab_size)
|
| 6 |
+
|
| 7 |
+
src_embed uses src_vocab_size (Roman BPE)
|
| 8 |
+
tgt_embed uses tgt_vocab_size (Devanagari BPE)
|
| 9 |
+
head outputs tgt_vocab_size (predict Devanagari tokens)
|
| 10 |
+
Weight tying: head <-> tgt_embed only (NOT src_embed)
|
| 11 |
+
|
| 12 |
+
Generation bugs fixed:
|
| 13 |
+
BUG 1 - tgt_pad_mask suppressed during inference
|
| 14 |
+
BUG 2 - q_sample skipped at t=0
|
| 15 |
+
BUG 3 - time embedding before hint_gate
|
| 16 |
+
BUG 4 - diversity penalty uses global mean not var
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 24 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SinusoidalPositionalEncoding(nn.Module):
    """Adds the fixed sine/cosine positional table of Vaswani et al. to its input.

    The table is precomputed for ``max_len`` positions and stored as a
    non-trainable buffer of shape (1, max_len, d_model); ``forward`` adds the
    first L rows to a (B, L, d_model) input. Assumes even ``d_model``.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        positions = torch.arange(0, max_len).unsqueeze(1).float()
        # Geometric frequency ladder: 10000^(-2i/d_model) for even channels i.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # x: (B, L, d_model) — truncate the table to the actual sequence length.
        return x + self.pe[:, :x.size(1), :]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SanskritEmbeddings(nn.Module):
    """Token embedding followed by sinusoidal positional encoding."""

    def __init__(self, vocab_size, d_model, max_seq_len):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = SinusoidalPositionalEncoding(d_model, max_seq_len)
        # Alias kept: the output heads tie weights via `token_embedding`.
        self.token_embedding = self.token_emb

    def forward(self, tokens):
        # tokens: (B, L) int ids -> (B, L, d_model) position-encoded embeddings
        embedded = self.token_emb(tokens)
        return self.pos_enc(embedded)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention with an optional boolean pad mask.

    ``mask`` is (B, Lk) with True at positions to be hidden; it is broadcast
    over heads and query positions.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _heads(self, proj, x, batch, length):
        # (B, L, d_model) -> (B, n_heads, L, head_dim)
        return proj(x).view(batch, length, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        batch, q_len, _ = q.size()
        k_len = k.size(1)
        query = self._heads(self.q_proj, q, batch, q_len)
        key = self._heads(self.k_proj, k, batch, k_len)
        value = self._heads(self.v_proj, v, batch, k_len)
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Hide padded keys from every head and every query position.
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, value)
        context = context.transpose(1, 2).contiguous().view(batch, q_len, self.d_model)
        return self.out_proj(context)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder layer: self-attention then feed-forward."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask=None):
        attended = self.mha(x, x, x, mask=pad_mask)
        x = self.norm1(x + attended)
        return self.norm2(x + self.ff(x))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: self-attention, cross-attention over the
    encoder memory, then feed-forward."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_pad_mask=None, src_pad_mask=None):
        # Self-attention over the (possibly masked) target sequence.
        self_out = self.self_attn(x, x, x, mask=tgt_pad_mask)
        x = self.norm1(x + self_out)
        # Cross-attention: queries from decoder, keys/values from encoder memory.
        cross_out = self.cross_attn(x, memory, memory, mask=src_pad_mask)
        x = self.norm2(x + cross_out)
        return self.norm3(x + self.ff(x))
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class D3PMCrossAttention(nn.Module):
    """Cross-script D3PM denoiser.

    Encodes Roman/IAST source tokens and denoises a corrupted Devanagari
    target conditioned on the encoder memory, a timestep embedding, and an
    optionally gated hint of the previous x0 estimate. The output head is
    weight-tied to the *target* embedding only (never the source embedding).

    Generation bug fixes preserved from the original:
      BUG 1 - target pad mask suppressed during inference
      BUG 2 - q_sample skipped at t=0
      BUG 3 - time embedding applied before the hint gate
      BUG 4 - diversity penalty subtracts the global mean (not variance)
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        d = cfg['model']['d_model']
        nhead = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        drop = cfg['model']['dropout']
        seqlen = cfg['model']['max_seq_len']
        nlayer = cfg['model']['n_layers']
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])

        # Separate vocabularies: Roman source vs Devanagari target.
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)]
        )
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)]
        )

        self.time_mlp = nn.Sequential(nn.Linear(1, d // 4), nn.SiLU(), nn.Linear(d // 4, d))
        self.hint_gate = nn.Sequential(nn.Linear(d, d), nn.Sigmoid())

        # Predict Devanagari tokens; tie the head to tgt_embed.
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None, inference_mode=False):
        PAD = 1
        src_pad_mask = src.eq(PAD)
        # BUG 1 FIX: never mask target positions while sampling.
        tgt_pad_mask = None if inference_mode else tgt.eq(PAD)

        # Encode the Roman source once per call.
        memory = self.src_embed(src)
        for enc in self.encoder_blocks:
            memory = enc(memory, pad_mask=src_pad_mask)

        # BUG 2 FIX: at the final denoising step (t == 0) keep tgt untouched.
        if inference_mode and bool((t == 0).all()):
            x_t_ids = tgt
        else:
            _, x_t_ids = self.forward_process.q_sample(tgt, t)

        x = self.tgt_embed(x_t_ids)

        # BUG 3 FIX: inject the time embedding BEFORE computing the hint gate,
        # so the gate becomes time-aware.
        t_scaled = t.float() / self.scheduler.num_timesteps
        x = x + self.time_mlp(t_scaled.unsqueeze(-1)).unsqueeze(1)

        if x0_hint is not None:
            gate = self.hint_gate(x)  # time-aware gate
            x = x + gate * self.tgt_embed(x0_hint)

        for dec in self.decoder_blocks:
            x = dec(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(self, src, num_steps=None, temperature=0.8, top_k=50,
                 repetition_penalty=1.2, diversity_penalty=0.0):
        """Iteratively denoise from an all-[MASK] target; returns (B, L) ids."""
        if src.dim() == 1:
            src = src.unsqueeze(0)
        device = src.device
        batch, length = src.shape
        total_t = self.scheduler.num_timesteps
        wanted = num_steps or total_t
        stride = max(1, total_t // wanted)
        schedule = list(range(total_t - 1, -1, -stride))
        if schedule[-1] != 0:
            schedule.append(0)  # always finish with a t=0 argmax pass

        x0_est = torch.full((batch, length), self.mask_token_id,
                            dtype=torch.long, device=device)
        hint = None

        self.eval()
        for idx, t_val in enumerate(schedule):
            t = torch.full((batch,), t_val, dtype=torch.long, device=device)
            final_step = idx == len(schedule) - 1
            logits, _ = self.forward(src, x0_est, t, x0_hint=hint, inference_mode=True)
            if repetition_penalty != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = _apply_diversity_penalty_fixed(logits, diversity_penalty)  # BUG 4 FIX
            logits = logits / max(temperature, 1e-5)
            if top_k > 0:
                logits = _top_k_filter(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            if final_step:
                x0_est = torch.argmax(probs, dim=-1)
            else:
                x0_est = _batch_multinomial(probs)
            hint = x0_est
        return x0_est
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class BaselineCrossAttention(nn.Module):
    """Non-diffusion encoder-decoder baseline built from the same blocks.

    ``t`` and ``x0_hint`` are accepted (and ignored) for API parity with the
    diffusion model so the training loop can call either class uniformly.
    """

    def __init__(self, cfg):
        super().__init__()
        model_cfg = cfg['model']
        d = model_cfg['d_model']
        nhead = model_cfg['n_heads']
        d_ff = model_cfg['d_ff']
        drop = model_cfg['dropout']
        seqlen = model_cfg['max_seq_len']
        nlayer = model_cfg['n_layers']
        src_vocab = model_cfg.get('src_vocab_size', model_cfg['vocab_size'])
        tgt_vocab = model_cfg.get('tgt_vocab_size', model_cfg['vocab_size'])
        self.src_embed = SanskritEmbeddings(src_vocab, d, seqlen)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d, seqlen)
        self.encoder_blocks = nn.ModuleList(
            [EncoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)]
        )
        self.decoder_blocks = nn.ModuleList(
            [DecoderBlock(d, nhead, d_ff, drop) for _ in range(nlayer)]
        )
        self.head = nn.Linear(d, tgt_vocab, bias=False)
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t=None, x0_hint=None):
        PAD = 1
        src_pad_mask = src == PAD
        tgt_pad_mask = tgt == PAD
        memory = self.src_embed(src)
        for enc in self.encoder_blocks:
            memory = enc(memory, pad_mask=src_pad_mask)
        x = self.tgt_embed(tgt)
        for dec in self.decoder_blocks:
            x = dec(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
        return (self.head(x),)

    @torch.no_grad()
    def generate(self, src, max_len=None, start_token_id=2, **kwargs):
        """Greedy autoregressive decoding; returns (B, max_len) token ids."""
        if max_len is None:
            max_len = src.size(1)
        batch = src.size(0)
        device = src.device
        src_pad_mask = src == 1
        memory = self.src_embed(src)
        for enc in self.encoder_blocks:
            memory = enc(memory, pad_mask=src_pad_mask)
        ys = torch.full((batch, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for dec in self.decoder_blocks:
                x = dec(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
            next_tok = torch.argmax(self.head(x)[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_tok], dim=1)
        return ys[:, 1:max_len + 1]
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# helpers
|
| 250 |
+
def _top_k_filter(logits, k):
|
| 251 |
+
B, L, V = logits.shape
|
| 252 |
+
if k >= V: return logits
|
| 253 |
+
topk_vals, _ = torch.topk(logits, k, dim=-1)
|
| 254 |
+
return logits.masked_fill(logits < topk_vals[..., -1].unsqueeze(-1), float('-inf'))
|
| 255 |
+
|
| 256 |
+
def _batch_multinomial(probs):
|
| 257 |
+
B, L, V = probs.shape
|
| 258 |
+
flat = probs.view(B*L, V) + 1e-9
|
| 259 |
+
return torch.multinomial(flat/flat.sum(-1,keepdim=True), 1).squeeze(-1).view(B, L)
|
| 260 |
+
|
| 261 |
+
def _apply_repetition_penalty(logits, prev_tokens, penalty):
|
| 262 |
+
for b in range(logits.shape[0]):
|
| 263 |
+
for tid in set(prev_tokens[b].tolist()):
|
| 264 |
+
if tid > 4: logits[b, :, tid] = logits[b, :, tid] / penalty
|
| 265 |
+
return logits
|
| 266 |
+
|
| 267 |
+
def _apply_diversity_penalty(logits, penalty): # legacy wrong version
|
| 268 |
+
return logits + penalty * logits.var(dim=-1, keepdim=True)
|
| 269 |
+
|
| 270 |
+
def _apply_diversity_penalty_fixed(logits, penalty): # correct version
|
| 271 |
+
return logits - penalty * logits.mean(dim=1, keepdim=True)
|
d3pm_model_encoder_decoder.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from diffusion.scheduler import OptimizedCosineScheduler
|
| 4 |
+
from diffusion.forward_process import AbsorbingForwardProcess
|
| 5 |
+
# Import shared classes to guarantee identical architectures
|
| 6 |
+
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
|
| 7 |
+
class DecoderBlock(nn.Module):
    """Decoder layer with self-attention, restored cross-attention over the
    encoder memory, and a ReLU feed-forward; post-norm residuals throughout."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)  # ← restored
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)  # ← restored (for cross-attn residual)

    def forward(self, x, memory, tgt_pad_mask=None):
        # 1. Self-attention on the target sequence.
        self_out = self.self_attn(x, x, x, mask=tgt_pad_mask)
        x = self.norm1(x + self_out)
        # 2. Cross-attention: queries from decoder, keys/values from memory.
        cross_out = self.cross_attn(x, memory, memory)
        x = self.norm2(x + cross_out)
        # 3. Position-wise feed-forward.
        return self.norm3(x + self.ff(x))
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DecoderBlockNoCrossAttn(nn.Module):
    """Kept for reference — NOT used by D3PMEncoderDecoder.

    Self-attention-only decoder layer; padding and causal masks are
    OR-combined when both are supplied.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        """Apply self-attention with the union of the provided masks, then FFN."""
        # Fold every non-None mask together with bitwise OR; None if neither given.
        combined_mask = None
        for candidate in (tgt_pad_mask, causal_mask):
            if candidate is None:
                continue
            combined_mask = candidate if combined_mask is None else (combined_mask | candidate)
        x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
        return self.norm2(x + self.ff(x))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ============================================================
|
| 56 |
+
# 1. D3PM Encoder-Decoder Model
|
| 57 |
+
# ============================================================
|
| 58 |
+
class D3PMEncoderDecoder(nn.Module):
    """D3PM encoder-decoder for Roman-IAST → Devanagari generation.

    The encoder reads the clean source sequence; the decoder denoises a
    diffusion-corrupted target conditioned on encoder memory via
    cross-attention.  Attribute names match the saved checkpoint layout.

    Fixes vs. the original:
      * forward_process now receives the configured mask token id (the
        original fell back to AbsorbingForwardProcess's default mask_id=0,
        which can disagree with cfg['diffusion']['mask_token_id']).
      * repetition penalty in generate() now divides positive logits and
        multiplies negative ones; dividing all logits made already-unlikely
        repeated tokens MORE likely.
      * generate() stops early once no masked positions remain (pure
        speedup — later iterations could not change x0_est).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        # Fall back to the shared vocab size when separate src/tgt vocabs
        # are not configured.
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)

        self.scheduler = OptimizedCosineScheduler(cfg)
        # BUGFIX: pass the configured mask id so training-time corruption and
        # generation-time mask filling agree on the same token.
        self.forward_process = AbsorbingForwardProcess(
            self.scheduler, mask_id=self.mask_token_id
        )

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        # DecoderBlock has cross-attention — matches saved checkpoint.
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        # Scalar timestep → d_model embedding, broadcast over target positions.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )
        self.head = nn.Linear(d_model, tgt_vocab)
        # Weight tying: output projection shares the target embedding table.
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None):
        """Predict logits over the clean target.

        Args:
            src: (B, Ls) source token ids (Roman IAST).
            tgt: (B, Lt) CLEAN target ids — corrupted internally via q_sample.
            t:   (B,) diffusion timestep per sample.
            x0_hint: optional previous x0 estimate for self-conditioning.

        Returns:
            (logits, None) — logits shaped (B, Lt, tgt_vocab).
        """
        src_pad_mask = (src == 1)  # PAD id is 1 throughout this repo
        tgt_pad_mask = (tgt == 1)

        # Encode source (Roman IAST).
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # Corrupt target with the forward diffusion process.
        _, x_t_ids = self.forward_process.q_sample(tgt, t)

        # Self-conditioning: reveal ~50% of still-masked positions using the
        # previous step's estimate.
        if x0_hint is not None:
            hint_prob = 0.5
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)

        x = self.tgt_embed(x_t_ids)
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)

        # Decode with cross-attention over encoder memory.
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)

        return self.head(x), None

    @torch.no_grad()
    def generate(
        self,
        src,
        num_steps=None,
        temperature=0.75,
        top_k=50,
        repetition_penalty=1.15,
        diversity_penalty=0.0,
    ):
        """
        Iterative D3PM reverse diffusion — same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.
        """
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = 1

        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)

        for step in range(T - 1, -1, -1):
            still = (x0_est == mask_id)
            if not still.any():
                # Every position already decided — further steps are no-ops.
                break

            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()

            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty (CTRL-style): divide positive logits,
            # multiply negative ones, so repeats always become LESS likely.
            if repetition_penalty != 1.0:
                for b in range(B):
                    for tok in set(x0_est[b].tolist()):
                        if tok > pad_id:
                            seen = logits[b, :, tok]
                            logits[b, :, tok] = torch.where(
                                seen > 0,
                                seen / repetition_penalty,
                                seen * repetition_penalty,
                            )

            # Diversity penalty (suppress tokens that are likely everywhere).
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature + top-k sampling.
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = torch.softmax(logits, dim=-1)
            # Only update positions that are still masked.
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            x0_est = torch.where(still, sample, x0_est)

        return x0_est
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# ============================================================
|
| 179 |
+
# 2. Baseline Encoder-Decoder Model (unchanged)
|
| 180 |
+
# ============================================================
|
| 181 |
+
class BaselineEncoderDecoder(nn.Module):
    """Non-diffusion baseline: standard encoder-decoder with greedy decoding.

    Uses the same embedding / encoder / decoder building blocks as the D3PM
    model so parameter counts are comparable; the output head is weight-tied
    to the target embedding table.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
        # Weight tying: head shares the target embedding matrix.
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, ):
        """Teacher-forced forward pass; returns (B, Lt, vocab) logits.

        PAD id 1 is used to build both padding masks.
        """
        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        x = self.tgt_embed(tgt)
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
        return self.head(x)

    @torch.no_grad()
    def generate(self, src, max_len=80, start_token_id=2):
        """Greedy autoregressive decoding for exactly `max_len` steps.

        Returns generated ids with the start token stripped.
        NOTE(review): the decoder is run without a causal mask, so each step
        re-attends bidirectionally over the prefix generated so far — confirm
        this matches how the baseline was trained.
        """
        batch_size, device = src.size(0), src.device
        src_pad_mask = (src == 1)
        # Encode the source once; reused for every decoding step.
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id
        for _ in range(max_len):
            # Re-run the full decoder over the growing prefix (no KV cache).
            x = self.tgt_embed(ys)
            for block in self.decoder_blocks:
                x = block(x, memory, tgt_pad_mask=None)
            logits = self.head(x)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
        return ys[:, 1:]
|
dataset.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
dataset.py — Cross-Script Translation Fix
|
| 3 |
+
==========================================
|
| 4 |
+
INPUT : quote_text (Roman/IAST transliteration of Sanskrit)
|
| 5 |
+
TARGET : quote_devanagari (Devanagari script)
|
| 6 |
+
|
| 7 |
+
This is the CORRECT task: the model learns to transliterate / translate
|
| 8 |
+
Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping
|
| 9 |
+
(far better than devanagari→devanagari reconstruction which teaches nothing).
|
| 10 |
+
|
| 11 |
+
KEY CHANGES from original:
|
| 12 |
+
1. _input_field = 'quote_text' (was 'quote_devanagari')
|
| 13 |
+
2. _target_field = 'quote_devanagari' (unchanged)
|
| 14 |
+
3. Separate source/target tokenizers — Roman and Devanagari have
|
| 15 |
+
completely different character sets; a shared BPE vocab forces the
|
| 16 |
+
model to learn both scripts in one embedding table, which wastes
|
| 17 |
+
capacity and confuses the attention mechanism.
|
| 18 |
+
4. Negative example generation fixed — reversal now operates on
|
| 19 |
+
DEVANAGARI target only (not accidentally on Roman source).
|
| 20 |
+
5. curriculum_sort uses target length (Devanagari) for difficulty proxy.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset
|
| 24 |
+
from torch.utils.data import Dataset
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import random
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class OptimizedSanskritDataset(Dataset):
    """Roman-IAST → Devanagari dataset with curriculum ordering and optional
    corrupted "negative" targets (see module docstring for rationale)."""

    def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
                 src_tokenizer=None, tgt_tokenizer=None):
        """
        Args:
            tokenizer     : shared tokenizer (legacy — used if src/tgt not provided)
            src_tokenizer : tokenizer for quote_text (Roman script)
            tgt_tokenizer : tokenizer for quote_devanagari (Devanagari script)
                            If None, falls back to shared `tokenizer`.
        """
        from config import CONFIG
        self.cfg = cfg or CONFIG
        self.max_len = max_len
        self.pad_id = 1  # fixed PAD id convention across the repo
        self.mask_id = self.cfg['diffusion']['mask_token_id']
        self.include_negatives = self.cfg['data']['include_negative_examples']

        # ── Tokenizer setup ───────────────────────────────────────────
        # Support both legacy (shared) and new (separate src/tgt) tokenizers
        self.src_tokenizer = src_tokenizer or tokenizer
        self.tgt_tokenizer = tgt_tokenizer or tokenizer

        if self.src_tokenizer is None:
            raise ValueError("Provide at least one tokenizer.")

        print(f"📥 Loading '{split}' split …")
        raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
        cols = raw.column_names

        # ── Field selection ───────────────────────────────────────────
        if 'quote_text' in cols and 'quote_devanagari' in cols:
            # CORRECT setup: Roman input → Devanagari output
            self._input_field = 'quote_text'
            self._target_field = 'quote_devanagari'
            print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
        elif 'sentence1' in cols and 'sentence2' in cols:
            # PAWS paraphrase pairs fallback
            self._input_field = 'sentence1'
            self._target_field = 'sentence2'
            print(" Format: PAWS sentence pairs ✓")
        else:
            # Last resort: same field both sides
            self._input_field = 'quote_devanagari'
            self._target_field = 'quote_devanagari'
            print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")

        # ── Filter empty rows ─────────────────────────────────────────
        # Some rows have empty quote_text — skip them
        raw = raw.filter(
            lambda ex: (
                bool(ex[self._input_field].strip()) and
                bool(ex[self._target_field].strip())
            )
        )
        print(f" After empty-filter: {len(raw)} samples")

        self.dataset = raw

        # Curriculum ordering only for training; eval splits keep file order.
        if split == 'train':
            self.dataset = self._curriculum_sort()

        print(f"✅ {len(self.dataset)} samples loaded.")

    # ── Encoding ──────────────────────────────────────────────────────

    def _encode_src(self, text):
        """Encode source (Roman) text."""
        # Truncate to max_len, then right-pad with pad_id to exactly max_len.
        ids = self.src_tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
        return t

    def _encode_tgt(self, text):
        """Encode target (Devanagari) text."""
        # Same truncate-then-pad scheme as _encode_src, with the tgt tokenizer.
        ids = self.tgt_tokenizer.encode(text)[:self.max_len]
        t = torch.tensor(ids, dtype=torch.long)
        t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
        return t

    # ── Curriculum ────────────────────────────────────────────────────

    def _curriculum_sort(self):
        """Short, common Devanagari targets first → long, rare targets last."""
        scores = []
        for s in self.dataset:
            text = s[self._target_field]
            length = len(text.split())
            # Unique-character ratio as a cheap rarity proxy: higher ratio →
            # more distinct characters → treated as "easier" (lower score).
            rarity_score = len(set(text)) / max(1, len(text))
            scores.append(length * (1 - rarity_score))
        order = sorted(range(len(self.dataset)), key=lambda i: scores[i])
        return self.dataset.select(order)

    # ── Item ──────────────────────────────────────────────────────────

    def __len__(self):
        """Number of (filtered) samples."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return tokenized src/tgt tensors plus raw texts; optionally a
        chunk-reversed negative target when configured."""
        sample = self.dataset[idx]

        src_text = sample[self._input_field].strip()
        tgt_text = sample[self._target_field].strip()

        input_ids = self._encode_src(src_text)  # Roman encoded with src_tokenizer
        target_ids = self._encode_tgt(tgt_text)  # Devanagari encoded with tgt_tokenizer

        out = {
            'input_ids': input_ids,
            'target_ids': target_ids,
            'input_text': src_text,
            'target_text': tgt_text,
        }

        if self.include_negatives:
            neg_ids = target_ids.clone()
            # Reverse a random chunk of the DEVANAGARI target
            non_pad = (neg_ids != self.pad_id).sum().item()
            if non_pad > 4:
                # Pick two distinct indices within the non-pad prefix and flip
                # the span between them.
                i1, i2 = sorted(random.sample(range(non_pad), 2))
                neg_ids[i1:i2] = torch.flip(neg_ids[i1:i2], dims=[0])
            out['negative_target_ids'] = neg_ids

        return out
|
forward_process.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
forward_process.py — Verified Correct (no changes needed)
|
| 3 |
+
===========================================================
|
| 4 |
+
Absorbing (mask) diffusion. PAD never masked. At t=0 alpha=1.0 exactly
|
| 5 |
+
so x_t == x_0 (nothing masked). Works correctly with the fixed scheduler.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class AbsorbingForwardProcess:
    """Absorbing-state (mask) forward diffusion.

    Each token of x_0 is independently replaced by the mask token with
    probability 1 - alpha_t (alpha_t from the scheduler); PAD positions are
    never corrupted.  At alpha_t == 1.0 nothing is masked, so x_t == x_0.
    """

    def __init__(self, scheduler, mask_id=0, pad_id=1):
        self.scheduler = scheduler
        self.mask_id = mask_id
        self.pad_id = pad_id

    def q_sample(self, x_0, t):
        """Corrupt clean ids x_0 at timesteps t; returns (x_0, x_t)."""
        # Per-sample keep probability, broadcast over sequence positions.
        keep_prob = self.scheduler.get_alpha(t).to(x_0.device).view(-1, 1)
        noise = torch.rand(x_0.shape, device=x_0.device)
        mask_fill = torch.full_like(x_0, self.mask_id)
        x_t = torch.where(noise > keep_prob, mask_fill, x_0)
        # PAD is invariant under the forward process: PAD stays PAD always.
        x_t = torch.where(x_0 == self.pad_id, x_0, x_t)
        return x_0, x_t
|
inference.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py
|
| 3 |
+
============
|
| 4 |
+
Correct D3PM inference for Sanskrit paraphrase generation.
|
| 5 |
+
|
| 6 |
+
The model's forward() takes CLEAN tgt and noises it internally.
|
| 7 |
+
So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
|
| 8 |
+
letting the model noise it and then predict a cleaner version.
|
| 9 |
+
|
| 10 |
+
Also includes: robust checkpoint loading (auto-detects architecture
|
| 11 |
+
from saved weights — no CONFIG mismatch crashes).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
import os, sys
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from torch.utils.data import DataLoader, Subset
|
| 19 |
+
|
| 20 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 21 |
+
from config import CONFIG
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ── Checkpoint loader ─────────────────────────────────────────────────
|
| 25 |
+
|
| 26 |
+
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """
    Auto-detect architecture from checkpoint weight shapes, then load.
    Never fails due to CONFIG vs checkpoint mismatch.

    Returns:
        (model, cfg) — the loaded SanskritModel and the cfg copy whose
        'model' section was overwritten with the detected hyperparameters.
    """
    import copy
    from model.sanskrit_model import SanskritModel

    # Work on a deep copy so the caller's CONFIG is never mutated.
    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location='cpu')

    # d_model + vocab_size — read off the source embedding matrix shape.
    ek = 'model.src_embed.token_emb.weight'
    if ek in state:
        vocab, d = state[ek].shape
        cfg['model']['vocab_size'] = vocab
        cfg['model']['d_model'] = d
        cfg['model']['d_ff'] = d * 4  # assumes the conventional 4x FFN ratio

    # n_layers — highest encoder block index present, plus one.
    ids = {int(k.split('.')[2]) for k in state if k.startswith('model.encoder_blocks.')}
    if ids:
        cfg['model']['n_layers'] = max(ids) + 1

    # max_seq_len — length dimension of the stored positional encoding.
    pk = 'model.src_embed.pos_enc.pe'
    if pk in state:
        cfg['model']['max_seq_len'] = state[pk].shape[1]

    # n_heads — keep the configured value if it divides d_model, otherwise
    # fall back to the largest candidate that does.
    d = cfg['model']['d_model']
    h = cfg['model'].get('n_heads', 6)
    if d % h != 0:
        h = next(x for x in [8, 6, 4, 2, 1] if d % x == 0)
    cfg['model']['n_heads'] = h

    print(f"🔍 Detected: d_model={cfg['model']['d_model']}, "
          f"n_layers={cfg['model']['n_layers']}, "
          f"max_seq_len={cfg['model']['max_seq_len']}, "
          f"n_heads={cfg['model']['n_heads']}")

    model = SanskritModel(cfg).to(device)
    # strict=False so older checkpoints without newly added modules still load.
    missing, unexpected = model.load_state_dict(
        torch.load(ckpt_path, map_location=device), strict=False
    )

    # hint_gate may be absent in older checkpoints — initialise safely
    allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
    real_missing = [k for k in missing if k not in allowed]
    if real_missing:
        print(f"⚠️ Missing keys: {real_missing[:3]} …")
    if unexpected:
        print(f"⚠️ Unexpected keys: {unexpected[:3]} …")
    if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
        with torch.no_grad():
            # Identity init when square (pass-through gate), Xavier otherwise.
            w = model.model.hint_gate[0].weight
            torch.nn.init.zeros_(model.model.hint_gate[0].bias)
            torch.nn.init.eye_(w) if w.shape[0] == w.shape[1] \
                else torch.nn.init.xavier_uniform_(w)
        print("ℹ️ hint_gate initialised to identity (not in checkpoint).")

    print("✅ Model loaded.")
    return model, cfg
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Core inference function ───────────────────────────────────────────
|
| 92 |
+
|
| 93 |
+
def run_inference(model, input_ids, cfg):
    """
    Correct D3PM iterative refinement.

    x0_est starts as all [MASK].  Each step: model(src, x0_est, t) noises
    x0_est internally, then predicts a cleaner version; x0_est is updated
    every step and fed back as the self-conditioning hint.

    Fix vs. original: the sampling helper functions were re-imported inside
    the per-timestep loop on every iteration — the imports are now hoisted
    to the top of the function (behavior unchanged).

    Args:
        model: SanskritModel wrapper; model.model is the inner D3PM module.
        input_ids: (B, L) source token ids on the target device.
        cfg: full config dict; cfg['inference'] supplies sampling knobs.

    Returns:
        (B, L) tensor of predicted target token ids.
    """
    from model.d3pm_model_cross_attention import (
        _apply_repetition_penalty,
        _apply_diversity_penalty,
        _top_k_filter,
        _batch_multinomial,
    )

    inf = cfg['inference']
    device = input_ids.device
    B, L = input_ids.shape

    inner = model.model
    T = inner.scheduler.num_timesteps
    steps = inf['num_steps']  # must equal T (set in config)
    step_size = max(1, T // steps)
    timesteps = list(range(T - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)  # always finish with a t=0 denoising pass

    mask_id = inner.mask_token_id
    x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
    hint = None

    model.eval()
    with torch.no_grad():
        for step_idx, t_val in enumerate(timesteps):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (step_idx == len(timesteps) - 1)

            logits, _ = model(input_ids, x0_est, t, x0_hint=hint)

            # Penalties
            if inf['repetition_penalty'] != 1.0:
                logits = _apply_repetition_penalty(
                    logits, x0_est, inf['repetition_penalty']
                )
            if inf['diversity_penalty'] > 0.0:
                logits = _apply_diversity_penalty(logits, inf['diversity_penalty'])

            # Temperature + top-k filtering.
            logits = logits / max(inf['temperature'], 1e-5)
            if inf['top_k'] > 0:
                logits = _top_k_filter(logits, inf['top_k'])

            probs = F.softmax(logits, dim=-1)

            # Final step is deterministic (argmax); earlier steps sample.
            if is_last:
                x0_est = torch.argmax(probs, dim=-1)
            else:
                x0_est = _batch_multinomial(probs)

            hint = x0_est

    return x0_est
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ── Interactive demo ──────────────────────────────────────────────────
|
| 154 |
+
|
| 155 |
+
def interactive_demo():
    """Interactive REPL: read a verse from stdin, print its paraphrase.

    Loads the checkpoint implied by cfg['model_type'] and the negatives
    flag; exits on EOF, Ctrl-C, empty input, or quit/exit/q.
    """
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    # Checkpoint path is derived from model type + negatives flag.
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = f"results/{model_name}_neg_{has_neg}/best_model.pt"

    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")

    # load_model may rewrite cfg['model'] from the checkpoint shapes.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    # Fall back to 1 if the tokenizer has no explicit [PAD] entry.
    PAD_ID = tokenizer.tokenizer.token_to_id('[PAD]') or 1
    MASK_ID = cfg['diffusion']['mask_token_id']

    print("\n" + "="*55)
    print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
    print("="*55 + "\n")

    while True:
        try:
            text = input("INPUT > ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not text or text.lower() in ('quit', 'exit', 'q'):
            break

        # Single-example batch, truncated to the model's max sequence length.
        ids = torch.tensor(
            [tokenizer.encode(text)[:cfg['model']['max_seq_len']]],
            dtype=torch.long, device=device
        )
        out = run_inference(model, ids, cfg)
        # Strip mask/pad tokens before decoding for display.
        clean = [i for i in out[0].tolist() if i not in (MASK_ID, PAD_ID)]
        print(f"PARAPHRASE → {tokenizer.decode(clean).strip()}\n")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ── Batch evaluation ──────────────────────────────────────────────────
|
| 197 |
+
|
| 198 |
+
def batch_evaluate(sample_size=500):
    """Generate paraphrases for up to `sample_size` test examples, compute
    BLEU and BERTScore, and write samples + metrics to a results file.

    Metrics silently stay 0.0 if nltk / evaluate (bertscore) are
    unavailable or fail.  Returns (all_preds, all_refs).
    """
    from data.dataset import OptimizedSanskritDataset
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    exp_dir = f"results/{model_name}_neg_{has_neg}"
    ckpt = f"{exp_dir}/best_model.pt"

    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt}")

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    PAD_ID = tokenizer.tokenizer.token_to_id('[PAD]') or 1
    MASK_ID = cfg['diffusion']['mask_token_id']

    def collate(batch):
        # Keep raw texts alongside the tensors for metric computation.
        return {
            'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
            'target_text': [b['target_text'] for b in batch],
            'input_text': [b['input_text'] for b in batch],
        }

    dataset = OptimizedSanskritDataset('test', tokenizer, cfg['model']['max_seq_len'], cfg)
    indices = list(range(min(sample_size, len(dataset))))
    loader = DataLoader(
        Subset(dataset, indices),
        batch_size=cfg['training']['batch_size'],
        shuffle=False, collate_fn=collate
    )

    all_preds, all_refs, all_inputs = [], [], []
    print(f"⏳ Generating {len(indices)} paraphrases …")

    for batch in tqdm(loader):
        ids = batch['input_ids'].to(device)
        out = run_inference(model, ids, cfg)
        for i in range(out.size(0)):
            # Strip mask/pad before decoding predictions.
            clean = [x for x in out[i].tolist() if x not in (MASK_ID, PAD_ID)]
            all_preds.append(tokenizer.decode(clean).strip())
            all_refs.append(batch['target_text'][i].strip())
            all_inputs.append(batch['input_text'][i].strip())

    # Metrics
    bleu_score, bert_f1 = 0.0, 0.0
    try:
        from nltk.translate.bleu_score import corpus_bleu
        bleu_score = corpus_bleu(
            [[r.split()] for r in all_refs],
            [p.split() for p in all_preds]
        )
    except Exception:
        pass  # nltk missing or degenerate corpus — keep 0.0

    try:
        import evaluate as hf_eval
        # NOTE(review): lang='hi' is the closest supported BERTScore language
        # for Devanagari output — confirm this is intended for Sanskrit.
        res = hf_eval.load('bertscore').compute(
            predictions=all_preds, references=all_refs, lang='hi'
        )
        bert_f1 = sum(res['f1']) / len(res['f1'])
    except Exception:
        pass  # evaluate/bertscore unavailable — keep 0.0

    # Save
    out_path = f"{exp_dir}/evaluation_results.txt"
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(f"Model : {model_name}\n")
        f.write(f"Negatives: {has_neg}\n")
        f.write(f"Steps : {cfg['inference']['num_steps']}\n")
        f.write(f"Temp : {cfg['inference']['temperature']}\n")
        f.write(f"RepPen : {cfg['inference']['repetition_penalty']}\n")
        f.write(f"DivPen : {cfg['inference']['diversity_penalty']}\n")
        f.write(f"BLEU : {bleu_score:.4f}\n")
        f.write(f"BERTScore: {bert_f1:.4f}\n\n")
        f.write("=== SAMPLES ===\n")
        for i in range(min(20, len(all_preds))):
            f.write(f"IN : {all_inputs[i]}\n")
            f.write(f"REF : {all_refs[i]}\n")
            f.write(f"PRED: {all_preds[i]}\n")
            f.write("-" * 60 + "\n")

    print(f"\n✅ Results → {out_path}")
    print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
    return all_preds, all_refs
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
if __name__ == '__main__':
    # CLI entry point: interactive demo by default, batch evaluation on request.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['demo', 'eval'], default='demo')
    parser.add_argument('--samples', type=int, default=500)
    cli = parser.parse_args()

    if cli.mode == 'eval':
        batch_evaluate(cli.samples)
    else:
        interactive_demo()
|
kv_cache_benchmark.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/kv_cache_benchmark.py
|
| 3 |
+
================================
|
| 4 |
+
Task 1: Benchmark KV cache vs standard generate().
|
| 5 |
+
|
| 6 |
+
Measures:
|
| 7 |
+
- Wall-clock time for generate() vs generate_cached()
|
| 8 |
+
- Encoder time as % of total generation time (before/after)
|
| 9 |
+
- Speedup ratio at src_len = 16, 32, 64 tokens
|
| 10 |
+
|
| 11 |
+
How it works:
|
| 12 |
+
Standard generate():
|
| 13 |
+
For each of T=128 steps:
|
| 14 |
+
src → encoder → memory → decoder → logits (encoder runs 128 times)
|
| 15 |
+
|
| 16 |
+
generate_cached():
|
| 17 |
+
src → encoder → memory (once)
|
| 18 |
+
For each of T=128 steps:
|
| 19 |
+
cached_memory → decoder → logits (encoder runs 1 time)
|
| 20 |
+
|
| 21 |
+
Expected speedup:
|
| 22 |
+
If encoder = 30% of per-step time:
|
| 23 |
+
Saved = 127/128 * 30% ≈ 29.7% of total time
|
| 24 |
+
If encoder = 50% of per-step time:
|
| 25 |
+
Saved ≈ 49.6% of total time
|
| 26 |
+
|
| 27 |
+
Usage:
|
| 28 |
+
python -m analysis.kv_cache_benchmark
|
| 29 |
+
or:
|
| 30 |
+
from analysis.kv_cache_benchmark import run_benchmark
|
| 31 |
+
results = run_benchmark(model, src_tokenizer, device)
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
import time
from typing import Dict, List, Optional

import numpy as np
import torch
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
|
| 41 |
+
"""Create a random source tensor of given length."""
|
| 42 |
+
# Random real tokens (ids 5..src_vocab-1), padded to src_len
|
| 43 |
+
ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
|
| 44 |
+
return ids
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
|
| 48 |
+
"""
|
| 49 |
+
Time a zero-argument callable.
|
| 50 |
+
Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
|
| 51 |
+
"""
|
| 52 |
+
# Warmup
|
| 53 |
+
for _ in range(n_warmup):
|
| 54 |
+
fn()
|
| 55 |
+
if torch.cuda.is_available():
|
| 56 |
+
torch.cuda.synchronize()
|
| 57 |
+
elif torch.backends.mps.is_available():
|
| 58 |
+
torch.mps.synchronize()
|
| 59 |
+
|
| 60 |
+
times = []
|
| 61 |
+
for _ in range(n_runs):
|
| 62 |
+
start = time.perf_counter()
|
| 63 |
+
fn()
|
| 64 |
+
if torch.cuda.is_available():
|
| 65 |
+
torch.cuda.synchronize()
|
| 66 |
+
elif torch.backends.mps.is_available():
|
| 67 |
+
torch.mps.synchronize()
|
| 68 |
+
times.append(time.perf_counter() - start)
|
| 69 |
+
|
| 70 |
+
return float(np.mean(times))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def benchmark_encoder_cost(
    model,
    src: torch.Tensor,
) -> Dict[str, float]:
    """
    Measure encoder time relative to one full (encoder + decoder) forward pass.

    Args:
        model : wrapper whose .model is a D3PMCrossAttention exposing
                encode_source() / forward_cached() / forward()
        src   : [B, src_len] source token ids on the model's device

    Returns:
        encoder_s   : seconds for one encoder call
        decoder_s   : seconds for one cached decoder step
        full_step_s : seconds for one non-cached step (encoder + decoder)
        encoder_pct : 100 * encoder_s / full_step_s
                      (fix: previous docstring claimed
                      encoder_s / (encoder_s + full_step_s) * 100,
                      which is not what is computed)

    Raises:
        ValueError: if the inner model does not support the cached path.
    """
    inner = model.model
    if not hasattr(inner, 'encode_source'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")

    device = src.device
    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Fully-masked target at t=0: a representative decoder input.
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    t = torch.zeros(B, dtype=torch.long, device=device)

    # Time encoder alone
    encoder_s = _time_fn(lambda: inner.encode_source(src))

    # Pre-compute memory once so decoder timing excludes the encoder.
    memory, src_pad_mask = inner.encode_source(src)

    # Time one decoder step (cached path: reuses precomputed memory)
    decoder_s = _time_fn(
        lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
                                     inference_mode=True)
    )

    # Time one full step (non-cached = encoder + decoder)
    full_s = _time_fn(
        lambda: inner.forward(src, x0_est, t, inference_mode=True)
    )

    # Guard against division by zero on pathologically fast runs.
    encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9)

    return {
        "encoder_s": encoder_s,
        "decoder_s": decoder_s,
        "full_step_s": full_s,
        "encoder_pct": encoder_pct,
    }
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def run_benchmark(
    model,
    src_tokenizer,
    device: torch.device,
    src_lens: Optional[List[int]] = None,
    n_runs: int = 5,
) -> Dict:
    """
    Full benchmark: compare generate() vs generate_cached() at multiple src lengths.

    Args:
        model         : SanskritModel (D3PMCrossAttention)
        src_tokenizer : SanskritSourceTokenizer (kept for interface
                        compatibility; inputs are synthesized from the
                        embedding vocab size)
        device        : torch.device for the synthetic inputs
        src_lens      : source lengths to benchmark; defaults to
                        [16, 32, 64]. (Fix: was a mutable default list —
                        now a None sentinel.)
        n_runs        : number of timing runs per condition

    Returns:
        {src_len: {"standard_s", "cached_s", "speedup", "encoder_pct"}}
    """
    if src_lens is None:
        src_lens = [16, 32, 64]

    inner = model.model
    if not hasattr(inner, 'generate_cached'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")

    # Vocab size read off the source embedding table so synthetic ids are valid.
    src_vocab = inner.src_embed.token_emb.weight.shape[0]
    results = {}

    print("\n" + "=" * 65)
    print(" KV CACHE BENCHMARK")
    print("=" * 65)
    print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
          f"{'speedup':>8} {'encoder%':>9}")
    print("-" * 65)

    for src_len in src_lens:
        src = _make_src(src_len, src_vocab, device)

        # Encoder cost breakdown
        enc_cost = benchmark_encoder_cost(model, src)

        # Time standard generate() — encoder runs T times
        def run_standard():
            return inner.generate(src, temperature=0.8, top_k=40)

        # Time generate_cached() — encoder runs once
        def run_cached():
            return inner.generate_cached(src, temperature=0.8, top_k=40)

        t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
        t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
        speedup = t_standard / max(t_cached, 1e-9)

        results[src_len] = {
            "standard_s": t_standard,
            "cached_s": t_cached,
            "speedup": speedup,
            "encoder_pct": enc_cost["encoder_pct"],
        }

        print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
              f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")

    print("=" * 65)
    print(f"\n Encoder cost = % of one full forward pass")
    print(f" Speedup = standard_time / cached_time")
    print(f" Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)")

    return results
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def print_summary(results: Dict):
    """Pretty-print per-src_len speedup, time saved, and encoder share."""
    print("\n SUMMARY")
    print(" -------")
    for length, row in results.items():
        # time saved relative to the standard path, as a percentage
        time_saved = (1.0 - 1.0 / row["speedup"]) * 100
        print(f" src_len={length}: {row['speedup']:.2f}x speedup "
              f"({time_saved:.1f}% time saved, "
              f"encoder was {row['encoder_pct']:.1f}% of total)")
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
    # Stand-alone entry point: load the best checkpoint for the configured
    # model variant and run the KV-cache benchmark on it.
    import sys, os
    # Make the repository root importable before the project imports below.
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from config import CONFIG
    from inference import load_model
    from models.tokenizer import SanskritSourceTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    # Checkpoint path encodes model type and whether negative examples were used.
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"

    if not os.path.exists(ckpt):
        print(f"No checkpoint at {ckpt}. Train first.")
        sys.exit(1)

    # NOTE(review): load_model appears to return an updated cfg alongside the
    # model — confirm against inference.load_model's contract.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    src_tokenizer = SanskritSourceTokenizer(
        vocab_size = cfg['model'].get('src_vocab_size', 500),
        max_len = cfg['model']['max_seq_len'],
    )

    results = run_benchmark(model, src_tokenizer, device)
    print_summary(results)
|
quality_classifier.py
ADDED
|
@@ -0,0 +1,514 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/quality_classifier.py
|
| 3 |
+
================================
|
| 4 |
+
Task 5: Classifier-Free Guidance for Paraphrase Quality Control
|
| 5 |
+
|
| 6 |
+
Two steps — only Step 2 requires training a SMALL model (not the main D3PM):
|
| 7 |
+
|
| 8 |
+
STEP 1 — Collect training data (no training):
|
| 9 |
+
Run existing model on val set, record (hidden_state, CER) pairs.
|
| 10 |
+
Hidden states come from model.model._last_hidden after forward_cached().
|
| 11 |
+
CER score = quality label (lower CER = higher quality).
|
| 12 |
+
|
| 13 |
+
STEP 2 — Train quality classifier:
|
| 14 |
+
Small 2-layer MLP: d_model → 64 → 1
|
| 15 |
+
Input: pooled decoder hidden state [B, d_model]
|
| 16 |
+
Output: predicted quality score in [0, 1] (1 = high quality)
|
| 17 |
+
Loss: MSE against normalized CER labels
|
| 18 |
+
Training time: ~5-10 minutes on CPU for 10k examples
|
| 19 |
+
|
| 20 |
+
STEP 3 — Guided inference (no retraining):
|
| 21 |
+
At each diffusion step, use classifier gradient to shift logits:
|
| 22 |
+
guided_logits = logits + λ * ∂(quality_score)/∂(logits)
|
| 23 |
+
Higher λ → model biased toward high-quality outputs
|
| 24 |
+
λ=0 → standard generation (no guidance)
|
| 25 |
+
|
| 26 |
+
Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
import numpy as np
|
| 33 |
+
import os
|
| 34 |
+
import json
|
| 35 |
+
from typing import List, Dict, Optional, Tuple
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ── Quality classifier architecture ──────────────────────────────────
|
| 39 |
+
|
| 40 |
+
class QualityClassifier(nn.Module):
    """
    Small MLP that scores transliteration quality from decoder hidden states.

    Layers: d_model → 128 → 64 → 1, sigmoid output.
    Accepts either an already-pooled vector [B, d_model] or a full
    sequence [B, tgt_len, d_model], which is mean-pooled over the
    sequence axis. The output is a quality score in [0, 1], where 1
    means high quality. Roughly 10k parameters — trains in minutes on CPU.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Return [B, 1] quality scores; pools [B, L, d_model] inputs over L."""
        pooled = hidden.mean(dim=1) if hidden.dim() == 3 else hidden
        return self.net(pooled)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# ── Training data collection ──────────────────────────────────────────
|
| 81 |
+
|
| 82 |
+
@torch.no_grad()
def collect_quality_data(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 5000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect (hidden_state, quality_score) pairs for classifier training.

    For each sample:
      1. Run the cached reverse-diffusion loop on src
      2. Capture the decoder hidden state at step t=t_capture
      3. Compute CER between the decoded output and the reference
      4. Quality = max(0, 1 - CER), i.e. clamped to [0, 1]

    Args:
        model         : SanskritModel
        src_list      : list of [1, src_len] tensors
        ref_list      : list of reference Devanagari strings
        tgt_tokenizer : SanskritTargetTokenizer
        t_capture     : which diffusion step to capture hidden states at
                        (0 = final denoising step)
        temperature   : softmax temperature for per-step sampling
        top_k         : keep only the top-k logits per position (0 = off)
        max_samples   : cap number of training examples

    Returns:
        hidden_matrix : np.ndarray [N, d_model]
        quality_scores: np.ndarray [N] values in [0, 1]

    Note: samples where no hidden state was captured (e.g. the model has no
    _last_hidden attribute) are silently skipped, so N may be < len(src_list).
    """
    inner = model.model
    T = inner.scheduler.num_timesteps  # total reverse-diffusion steps
    device = next(inner.parameters()).device

    hidden_list = []
    quality_list = []
    n = min(len(src_list), max_samples)

    def cer(pred, ref):
        # Character error rate = edit distance / reference length.
        if not ref:
            return 1.0
        def ed(s1, s2):
            # Levenshtein distance with a single rolling row (O(len(s2)) memory).
            m, n = len(s1), len(s2)
            dp = list(range(n + 1))
            for i in range(1, m + 1):
                prev, dp[0] = dp[0], i
                for j in range(1, n + 1):
                    temp = dp[j]
                    dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
                    prev = temp
            return dp[n]
        return ed(pred, ref) / max(len(ref), 1)

    print(f"Collecting quality data from {n} examples...")
    for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
        if i % 200 == 0:
            print(f" {i}/{n}")

        # Accept both [src_len] and [1, src_len] inputs.
        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        B = src.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id

        # Encode the source once; decoder steps reuse the cached memory.
        memory, src_pad_mask = inner.encode_source(src)
        # Start the reverse process from an all-MASK target.
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        h_cap = None

        # Reverse diffusion: t = T-1 ... 0.
        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            is_last = (t_val == 0)

            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint, inference_mode=True,
            )

            # NOTE(review): assumes forward_cached stashes decoder hidden
            # states in inner._last_hidden — confirm in the model class.
            if t_val == t_capture and hasattr(inner, '_last_hidden'):
                h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu()  # [d_model]

            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                V = logits.shape[-1]
                if top_k < V:
                    # Mask out everything below the k-th largest logit.
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Final step is greedy; intermediate steps sample.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

        if h_cap is None:
            continue  # no hidden state captured for this sample — skip it

        # Drop ids <= 4 — presumably special tokens (PAD/BOS/EOS/MASK);
        # verify against the tokenizer's id layout.
        ids = [x for x in x0_est[0].tolist() if x > 4]
        pred = tgt_tokenizer.decode(ids).strip()
        q = max(0.0, 1.0 - cer(pred, ref))  # quality = 1 - CER

        hidden_list.append(h_cap.numpy())
        quality_list.append(q)

    print(f"Collected {len(hidden_list)} quality examples.")
    print(f"Quality stats: mean={np.mean(quality_list):.3f} "
          f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")

    return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _sample(probs):
|
| 196 |
+
B, L, V = probs.shape
|
| 197 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 198 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 199 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ── Training ──────────────────────────────────────────────────────────
|
| 203 |
+
|
| 204 |
+
def train_quality_classifier(
    hidden_matrix: np.ndarray,
    quality_scores: np.ndarray,
    d_model: int,
    epochs: int = 30,
    batch_size: int = 64,
    lr: float = 1e-3,
    val_frac: float = 0.1,
    save_path: Optional[str] = None,
) -> QualityClassifier:
    """
    Fit the QualityClassifier on collected (hidden, quality) pairs with MSE.

    A random permutation splits the data into train/val; the checkpoint
    with the lowest validation loss is restored before returning.

    Args:
        hidden_matrix  : [N, d_model] from collect_quality_data()
        quality_scores : [N] quality labels in [0, 1]
        d_model        : hidden dimension
        epochs         : number of passes over the training split
        batch_size     : minibatch size
        lr             : Adam learning rate
        val_frac       : fraction held out for validation
        save_path      : if given, save best classifier weights here

    Returns:
        trained QualityClassifier (best-validation weights loaded)
    """
    device = torch.device("cpu")  # tiny model: CPU is plenty

    features = torch.tensor(hidden_matrix, dtype=torch.float32)
    labels = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)

    total = len(features)
    n_val = max(1, int(total * val_frac))
    order = torch.randperm(total)
    X_val, y_val = features[order[:n_val]], labels[order[:n_val]]
    X_train, y_train = features[order[n_val:]], labels[order[n_val:]]

    clf = QualityClassifier(d_model).to(device)
    optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

    print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
    print(f"Train: {len(X_train)} Val: {len(X_val)}")

    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        clf.train()
        shuffle = torch.randperm(len(X_train))
        running = 0.0
        n_batches = 0

        for start in range(0, len(X_train), batch_size):
            sel = shuffle[start:start + batch_size]
            pred = clf(X_train[sel])
            loss = F.mse_loss(pred, y_train[sel])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running += loss.item()
            n_batches += 1

        # Validation pass (no gradients).
        clf.eval()
        with torch.no_grad():
            val_loss = F.mse_loss(clf(X_val), y_val).item()

        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f" Ep {epoch+1:3d} train={running/n_batches:.4f} val={val_loss:.4f}")

        # Track best checkpoint by validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = {k: v.clone() for k, v in clf.state_dict().items()}

    if best_state:
        clf.load_state_dict(best_state)
    print(f" Best val loss: {best_val_loss:.4f}")

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        torch.save(clf.state_dict(), save_path)
        print(f" Classifier saved: {save_path}")

    return clf
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# ── Guided inference ──────────────────────────────────────────────────
|
| 292 |
+
|
| 293 |
+
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: QualityClassifier,
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """
    Classifier-guided generation.

    At each diffusion step:
      1. Run the decoder manually → logits, hidden states
      2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
      3. Project gradient back to logit space via the output head (approximate)
      4. guided_logits = logits + λ * gradient_signal
      5. Sample from guided_logits

    guidance_scale λ:
        0.0 → no guidance (standard cached generation path)
        0.5 → weak guidance
        1.0 → moderate guidance (recommended starting point)
        2.0 → strong guidance (may reduce diversity)
        3.0 → very strong (may collapse to repetitive output)

    Args:
        model          : SanskritModel (frozen — only used for inference)
        src            : [1, src_len] IAST token ids (a bare [src_len] is
                         also accepted and unsqueezed)
        classifier     : trained QualityClassifier
        guidance_scale : λ — guidance strength
        temperature    : softmax temperature for per-step sampling
        top_k          : keep only the top-k logits per position (0 = off)

    Returns:
        x0_est : [1, tgt_len] generated token ids
    """
    inner = model.model
    T = inner.scheduler.num_timesteps  # total reverse-diffusion steps
    device = next(inner.parameters()).device
    clf_device = next(classifier.parameters()).device

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Encode the source once; reverse process starts from an all-MASK target.
    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    classifier.eval()

    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        if guidance_scale > 0.0:
            # Need gradients for classifier guidance
            with torch.enable_grad():
                # Re-implements the cached decoder forward inline so the
                # hidden states are available as a differentiable leaf.
                PAD = 1  # NOTE(review): unused — dead code?
                if t_val > 0:
                    # Re-noise the current estimate to timestep t.
                    _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
                else:
                    x_t_ids = x0_est

                x = inner.tgt_embed(x_t_ids)
                # Timestep conditioning: normalized t through a small MLP.
                t_norm = t.float() / T
                t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
                x = x + t_emb.unsqueeze(1)

                # Gated injection of the previous step's estimate.
                if hint is not None:
                    hint_emb = inner.tgt_embed(hint)
                    gate = inner.hint_gate(x)
                    x = x + gate * hint_emb

                for block in inner.decoder_blocks:
                    x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)

                # hidden: [B, tgt_len, d_model] — detach from graph for clf
                # NOTE(review): if clf_device != device, .to(clf_device) makes
                # `hidden` a NON-leaf tensor and hidden.grad will be None below
                # — this only works when both live on the same device; verify.
                hidden = x.detach().requires_grad_(True).to(clf_device)

                # Classifier quality score
                quality = classifier(hidden)  # [B, 1]
                quality.sum().backward()

                # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
                grad = hidden.grad.to(device)  # [B, tgt_len, d_model]

                # Project gradient to logit space via output head weight
                # logit_grad ≈ grad @ head.weight   [B, tgt_len, tgt_vocab]
                logit_grad = grad @ inner.head.weight.T

                # Compute standard logits (no gradient needed)
                with torch.no_grad():
                    logits = inner.head(x)

                # Apply guidance
                logits = logits + guidance_scale * logit_grad

        else:
            # λ = 0: plain cached forward, no gradients needed at all.
            with torch.no_grad():
                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

        with torch.no_grad():
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                V = logits.shape[-1]
                if top_k < V:
                    # Mask out everything below the k-th largest logit.
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Final step is greedy; intermediate steps sample.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
            hint = x0_est

    return x0_est
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def _sample_no_grad(probs):
|
| 418 |
+
B, L, V = probs.shape
|
| 419 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 420 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 421 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# ── Guidance scale sweep ──────────────────────────────────────────────
|
| 425 |
+
|
| 426 |
+
def sweep_guidance_scales(
    model,
    classifier: QualityClassifier,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    scales: Optional[List[float]] = None,
    n_samples: int = 50,
    device: Optional[torch.device] = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Evaluate mean CER and output diversity at each guidance scale.

    For every λ in `scales`, generates guided outputs for up to
    `n_samples` sources, scores them with character error rate against
    the references, and records the fraction of unique outputs as a
    diversity proxy. Saves a quality-diversity tradeoff plot (when
    matplotlib is available) and a JSON dump of the results.

    Args:
        model         : SanskritModel (frozen)
        classifier    : trained QualityClassifier used for guidance
        src_list      : source id tensors ([src_len] or [1, src_len])
        ref_list      : reference Devanagari strings, aligned with src_list
        tgt_tokenizer : target tokenizer with .decode()
        scales        : guidance scales to sweep; defaults to
                        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]. (Fix: was a
                        mutable default list — now a None sentinel; the
                        device parameter annotation is also Optional now.)
        n_samples     : max number of (src, ref) pairs per scale
        device        : device for generation; defaults to the model's
        output_dir    : directory for plot + JSON outputs

    Returns:
        {scale: {"mean_cer": float, "diversity": float}}
    """
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]

    def cer(pred, ref):
        # Character error rate = edit distance / reference length.
        if not ref:
            return 1.0
        def ed(s1, s2):
            # Levenshtein distance with a single rolling row.
            m, n = len(s1), len(s2)
            dp = list(range(n + 1))
            for i in range(1, m + 1):
                prev, dp[0] = dp[0], i
                for j in range(1, n + 1):
                    temp = dp[j]
                    dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
                    prev = temp
            return dp[n]
        return ed(pred, ref) / max(len(ref), 1)

    device = device or next(model.parameters()).device
    results = {}
    n = min(n_samples, len(src_list))

    print("\nGuidance scale sweep...")
    for scale in scales:
        cer_list = []
        output_set = []
        for src, ref in zip(src_list[:n], ref_list[:n]):
            if src.dim() == 1:
                src = src.unsqueeze(0)
            out = generate_guided(model, src.to(device), classifier,
                                  guidance_scale=scale)
            # Drop ids <= 4 — presumably special tokens; verify against
            # the tokenizer's id layout.
            ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(ids).strip()
            cer_list.append(cer(pred, ref))
            output_set.append(pred)

        mean_cer = float(np.mean(cer_list))

        # Self-diversity: unique outputs / total (proxy for diversity)
        unique_frac = len(set(output_set)) / max(len(output_set), 1)

        results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
        print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")

    # Plot — skipped cleanly when matplotlib is not installed.
    os.makedirs(output_dir, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        sc_list = sorted(results.keys())
        cers = [results[s]["mean_cer"] for s in sc_list]
        diversities = [results[s]["diversity"] for s in sc_list]

        ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
        ax1.set_xlabel("Guidance scale λ", fontsize=10)
        ax1.set_ylabel("CER (↓ better)", fontsize=10)
        ax1.set_title("Quality vs guidance scale", fontsize=10)

        ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
        ax2.set_xlabel("Guidance scale λ", fontsize=10)
        ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
        ax2.set_title("Diversity vs guidance scale", fontsize=10)

        plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
        plt.tight_layout()
        path = os.path.join(output_dir, "guidance_scale_sweep.png")
        plt.savefig(path, dpi=150, bbox_inches='tight')
        plt.close()
        print(f" Saved: {path}")
    except ImportError:
        pass

    with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)

    return results
|
reverse_process.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Two bugs fixed from the original:
|
| 5 |
+
|
| 6 |
+
BUG 1 (critical): generate_beam() passed x_t (noisy) as `tgt` to model.
|
| 7 |
+
The model does q_sample(tgt, t) internally — so x_t got double-noised.
|
| 8 |
+
Fix: pass x0_estimate (current clean guess) as tgt. Model noises it correctly.
|
| 9 |
+
|
| 10 |
+
BUG 2: apply_diversity_penalty used logits.var(dim=-1) — this adds the
|
| 11 |
+
variance of each position's own distribution back to itself, which is
|
| 12 |
+
mathematically meaningless and just injects noise.
|
| 13 |
+
Fix: penalize tokens that are uniformly high-probability across ALL positions
|
| 14 |
+
(global common tokens). This genuinely promotes diversity.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ReverseDiffusion:
    """Reverse (denoising) process for a discrete diffusion (D3PM) seq2seq model.

    Two sampling strategies are provided:

    * ``generate``      -- iterative x0-refinement: the current clean estimate
                           ``x0_est`` is fed back to the model each step (the
                           model re-noises it internally via q_sample).
    * ``generate_beam`` -- beam-style expansion over per-step top-k tokens.
                           Kept for experimentation; ``generate`` is preferred.
    """

    def __init__(self, scheduler):
        # scheduler must expose `num_timesteps` (total diffusion steps T).
        self.scheduler = scheduler

    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width=3,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """
        Single reverse step with temperature + penalties.

        Args:
            model:      diffusion model; called as ``model(condition, x_t, t)``
                        and expected to return ``(logits, _)``.
            x_t:        current token ids, shape (L,) or (B, L).
            t:          timestep(s), scalar or (B,).
            condition:  source/conditioning token ids, (L,) or (B, L).
            beam_width: number of top-k candidates expanded per position.
            temperature: softmax temperature (<1 sharpens, >1 flattens).
            repetition_penalty: forwarded to ``apply_repetition_penalty``.
            diversity_penalty:  forwarded to ``apply_diversity_penalty``.

        Returns:
            List of ``(next_tokens, score)`` tuples, one per k, where
            ``score`` is the summed log-probability of the chosen tokens.
        """
        with torch.no_grad():

            # ---- Shape safety: promote 1-D inputs to a batch of one ----
            if x_t.dim() == 1:
                x_t = x_t.unsqueeze(0)

            if condition.dim() == 1:
                condition = condition.unsqueeze(0)

            if t.dim() == 0:
                t = t.unsqueeze(0)

            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # ---- Model forward ----
            logits, _ = model(condition, x_t, t)

            # ---- Temperature scaling ----
            logits = logits / temperature

            # ---- Repetition penalty ----
            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(
                    logits, x_t, repetition_penalty
                )

            # ---- Diversity penalty ----
            if diversity_penalty > 0:
                logits = apply_diversity_penalty(
                    logits, diversity_penalty
                )

            probs = F.softmax(logits, dim=-1)

            B, L, V = probs.shape

            # ---- Top-k beam expansion ----
            topk_probs, topk_ids = torch.topk(
                probs, beam_width, dim=-1
            )

            candidates = []

            for k in range(beam_width):
                # Candidate k takes the k-th best token at every position.
                next_tokens = topk_ids[:, :, k]
                score = torch.log(
                    topk_probs[:, :, k] + 1e-9
                ).sum()
                candidates.append((next_tokens, score))

            return candidates

    def generate_beam(
        self,
        model,
        condition,
        beam_width=3,
        num_steps=None,
        temperature=1.0,
        repetition_penalty=1.2,
        diversity_penalty=0.3
    ):
        """
        Beam-search reverse diffusion with temperature.

        Starts every beam from an all-[MASK] canvas and, at each reverse
        step, expands each beam with the per-position top-k candidates from
        ``p_sample_step``, keeping the ``beam_width`` best by score.

        Returns:
            Token tensor (B, L) of the highest-scoring beam.
        """
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device

        if condition.dim() == 1:
            condition = condition.unsqueeze(0)

        B, L = condition.shape

        # Initialization: start from MASK at every position.
        x_init = torch.full(
            (B, L),
            fill_value=model.mask_token_id,
            dtype=torch.long,
            device=device
        )

        beams = [(x_init, 0.0)]

        for step in reversed(range(num_steps)):

            new_beams = []

            for x_t, score in beams:

                t_tensor = torch.full(
                    (B,),
                    step,
                    dtype=torch.long,
                    device=device
                )

                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    beam_width,
                    temperature,
                    repetition_penalty,
                    diversity_penalty
                )

                for tokens, new_score in candidates:
                    new_beams.append(
                        (tokens, score + new_score)
                    )

            # ---- Keep top beams ----
            new_beams = sorted(
                new_beams,
                key=lambda x: x[1],
                reverse=True
            )

            beams = new_beams[:beam_width]

        best_tokens, best_score = beams[0]
        return best_tokens

    def generate(
        self,
        model,
        condition,
        num_steps=None,
        temperature=0.8,
        top_k=50,
        repetition_penalty=1.2,
        diversity_penalty=0.0,
    ):
        """
        Correct D3PM iterative refinement.

        x0_est starts as all [MASK].
        Each step: forward(src=condition, tgt=x0_est, t)
            -> model applies q_sample(x0_est, t) internally
            -> predicts cleaner x0
            -> x0_est updated

        diversity_penalty: reduces probability of tokens that are
        globally dominant across all sequence positions.

        Improvement vs. previous revision: the ``inspect.signature`` probe
        for the optional ``x0_hint`` parameter is performed ONCE before the
        loop instead of re-inspecting the model on every diffusion step
        (the result is loop-invariant).
        """
        import inspect

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        T = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = model.mask_token_id
        # Start: know nothing -> all MASK is our initial clean estimate
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None

        # Probe the model signature once (hoisted out of the step loop).
        accepts_hint = 'x0_hint' in inspect.signature(model.forward).parameters

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # KEY: pass x0_est as tgt — model noises it internally
                if accepts_hint:
                    outputs = model(condition, x0_est, t, x0_hint=hint)
                else:
                    outputs = model(condition, x0_est, t)

                logits = outputs[0] if isinstance(outputs, tuple) else outputs

                # Repetition penalty: down-weight tokens already in sequence
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(logits, x0_est, repetition_penalty)

                # Diversity penalty: reduce globally dominant tokens
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                # Temperature + top-k
                logits = logits / max(temperature, 1e-5)
                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic at intermediate steps, deterministic at the last
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# ── Penalty functions ─────────────────────────────────────────────────
|
| 253 |
+
|
| 254 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """
    Down-weight tokens that already appear in the current sequence.
    Prevents मनो मनो मनो repetition loops.

    Bug fix: the previous version unconditionally DIVIDED by ``penalty``.
    For a NEGATIVE logit, dividing by penalty > 1 moves it toward zero and
    therefore *increases* that token's probability — the opposite of a
    penalty. Following the CTRL / Hugging Face convention: positive logits
    are divided by ``penalty`` and negative logits are multiplied by it,
    so repeated tokens are always suppressed.

    penalty=1.0 -> no effect
    penalty=1.2 -> mild suppression of repeated tokens
    penalty=2.0 -> strong suppression

    Args:
        logits:      (B, L, V) tensor, modified in place and returned.
        prev_tokens: (B, L') tensor of token ids already in the sequence.
        penalty:     float >= 1.0.
    """
    B, L, V = logits.shape
    for b in range(B):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:  # don't penalize special tokens (ids 0-4)
                col = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    col > 0, col / penalty, col * penalty
                )
    return logits
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def apply_diversity_penalty(logits, penalty=0.5):
    """
    Suppress tokens that are uniformly strong across every sequence
    position, nudging the sampler toward less globally common tokens and
    thereby increasing output diversity.

    Implementation: average the logits over the sequence axis to obtain a
    per-token "global dominance" score of shape [B, 1, V], then subtract a
    scaled copy of that score from the logits at every position. Tokens
    that are high everywhere get pushed down everywhere.

    penalty=0.0 -> no diversity enforcement
    penalty=0.5 -> moderate diversity
    penalty=1.0 -> strong diversity (may hurt coherence)
    """
    # Per-token mean over the sequence dimension: [B, 1, V]
    dominance = logits.mean(dim=1).unsqueeze(1)
    # Globally common tokens are suppressed at every position.
    return logits.sub(dominance.mul(penalty))
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def top_k_filter(logits, k):
    """Keep only each position's k highest logits; set the rest to -inf.

    Returns the input tensor unchanged when k covers the whole vocabulary.
    """
    vocab_size = logits.shape[-1]
    if k >= vocab_size:
        return logits
    # k-th largest value per position acts as the cutoff.
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_value, float('-inf'))
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def batch_multinomial(probs):
    """Draw one token id per (batch, position) from a [B, L, V] distribution.

    torch.multinomial only accepts 2-D input, so the batch and sequence
    dimensions are flattened, sampled, then restored. A tiny epsilon keeps
    all-zero rows valid before renormalisation.
    """
    batch, seq_len, vocab = probs.shape
    rows = probs.view(batch * seq_len, vocab) + 1e-9
    rows = rows / rows.sum(dim=-1, keepdim=True)
    picks = torch.multinomial(rows, 1)
    return picks.view(batch, seq_len)
|
reverse_process1.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ReverseDiffusion:
    """
    Stable reverse diffusion with:
    - Beam search
    - Self conditioning
    - Temperature sampling
    - Repetition penalty
    - Diversity penalty

    NOTE(review): legacy variant kept for reference; see the notes on the
    penalty methods below before reusing this class.
    """

    def __init__(self, scheduler):
        # scheduler must expose `num_timesteps` (total diffusion steps).

        self.scheduler = scheduler

        # Sampling hyper-parameters; generate_beam() has no kwargs for
        # these, so callers overwrite the attributes directly.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.length_penalty = 1.0

    # ------------------------------------------------
    # penalties
    # ------------------------------------------------

    def apply_repetition_penalty(self, logits, tokens):
        # Divides the logit column of every token id present in `tokens`
        # (per batch row), at ALL positions, by the penalty.
        #
        # NOTE(review): dividing a NEGATIVE logit by penalty > 1 moves it
        # toward zero, i.e. *increases* that token's probability — the
        # opposite of a penalty. Special tokens are also not excluded here
        # (other revisions in this repo skip token_id <= 4). Confirm before
        # reusing.

        B, L, V = logits.shape

        for b in range(B):

            used = set(tokens[b].tolist())

            for token_id in used:
                logits[b, :, token_id] /= self.repetition_penalty

        return logits

    def apply_diversity_penalty(self, logits):
        # NOTE(review): var(dim=-1, keepdim=True) is a per-position scalar,
        # so this adds the SAME constant to every vocabulary entry at a
        # position. Softmax is shift-invariant, so downstream sampling
        # probabilities are unchanged — effectively a no-op. Later
        # revisions in this repo replace it with global-mean suppression.

        if self.diversity_penalty == 0:
            return logits

        logits_var = logits.var(dim=-1, keepdim=True)
        return logits + self.diversity_penalty * logits_var

    # ------------------------------------------------
    # single reverse step
    # ------------------------------------------------

    def p_sample_step(self, model, x_t, t, condition, self_cond=None, beam_width=3):
        # One reverse step: model forward, temperature + penalties, then a
        # per-position top-k expansion. Returns a list of
        # (tokens, summed-log-prob score) candidate tuples, one per k.
        # `self_cond` (previous step's best tokens) is forwarded to the
        # model as its 4th positional argument.

        with torch.no_grad():

            logits, hidden = model(condition, x_t, t, self_cond)

            logits = logits / self.temperature

            logits = self.apply_repetition_penalty(logits, x_t)
            logits = self.apply_diversity_penalty(logits)

            probs = F.softmax(logits, dim=-1)

            B, L, V = probs.shape

            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)

            candidates = []

            for k in range(beam_width):

                # Candidate k takes the k-th best token at EVERY position
                # simultaneously (degenerate beam expansion).
                tokens = topk_ids[:, :, k]

                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()

                candidates.append((tokens, score))

            return candidates

    # ------------------------------------------------
    # beam reverse diffusion
    # ------------------------------------------------

    def generate_beam(self, model, condition, beam_width=3, num_steps=None):
        # Beam-search reverse diffusion. Returns the (B, L) token tensor of
        # the best-scoring beam after all reverse steps.

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device

        if condition.dim() == 1:
            condition = condition.unsqueeze(0)

        B, L = condition.shape

        # ------------------------------------------------
        # BETTER LATENT INITIALIZATION
        # ------------------------------------------------

        # Start from the source tokens with ~50% of positions replaced by
        # [MASK]. NOTE(review): torch.rand_like makes this initialization
        # non-deterministic across calls unless the RNG is seeded.
        x_init = condition.clone()

        mask = torch.rand_like(x_init.float()) < 0.5
        x_init[mask] = model.mask_token_id

        beams = [(x_init, 0.0)]

        self_cond = None

        for step in reversed(range(num_steps)):

            new_beams = []

            for x_t, score in beams:

                t_tensor = torch.full(
                    (B,),
                    step,
                    dtype=torch.long,
                    device=device
                )

                candidates = self.p_sample_step(
                    model,
                    x_t,
                    t_tensor,
                    condition,
                    self_cond,
                    beam_width
                )

                for tokens, new_score in candidates:

                    # Length normalisation; tokens.shape[1] is constant (L)
                    # within a call, so this rescales all scores uniformly.
                    length_norm = tokens.shape[1] ** self.length_penalty

                    final_score = (score + new_score) / length_norm

                    new_beams.append((tokens, final_score))

            new_beams = sorted(
                new_beams,
                key=lambda x: x[1],
                reverse=True
            )

            beams = new_beams[:beam_width]

            # self conditioning
            self_cond = beams[0][0]

        best_tokens, best_score = beams[0]

        return best_tokens
|
reverse_process2.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
reverse_process.py — Final Correct Version
|
| 3 |
+
=============================================
|
| 4 |
+
|
| 5 |
+
KEY PRINCIPLE: generate() must be byte-for-byte identical to run_inference()
|
| 6 |
+
in inference.py, which is what produced BERTScore 0.75 at validation.
|
| 7 |
+
|
| 8 |
+
CRITICAL BUG IN PREVIOUS VERSION:
|
| 9 |
+
We passed inference_mode=True to the model, but the model was NEVER
|
| 10 |
+
called with inference_mode=True during training or validation.
|
| 11 |
+
run_inference() (the validated path) does:
|
| 12 |
+
model(input_ids, x0_est, t, x0_hint=hint)
|
| 13 |
+
→ inference_mode defaults to False.
|
| 14 |
+
|
| 15 |
+
With inference_mode=True the model does two things differently:
|
| 16 |
+
1. tgt_pad_mask = None (training used tgt_pad_mask = tgt==PAD)
|
| 17 |
+
2. Skips q_sample at t=0 (training always called q_sample)
|
| 18 |
+
The model was never trained to handle these conditions → garbage output.
|
| 19 |
+
|
| 20 |
+
Fix: do NOT pass inference_mode. Let it default to False, exactly
|
| 21 |
+
as run_inference() did.
|
| 22 |
+
|
| 23 |
+
BUGS FIXED (vs original reverse_process.py)
|
| 24 |
+
--------------------------------------------
|
| 25 |
+
BUG 1 generate_beam() used for D3PM → all-Ṛ repetition.
|
| 26 |
+
Use generate() (iterative refinement) from app1.py instead.
|
| 27 |
+
BUG 2 apply_diversity_penalty used logits.var() → noise injection.
|
| 28 |
+
Fixed to logits - penalty * logits.mean(dim=1) — global suppression.
|
| 29 |
+
BUG 3 x0_hint (self-conditioning) never passed to model.
|
| 30 |
+
Fixed: generate() passes x0_hint=hint every step.
|
| 31 |
+
BUG 4 params not forwarded from generate_beam() to p_sample_step().
|
| 32 |
+
Fixed in generate_beam() (kept for reference, not for production use).
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
import torch
|
| 36 |
+
import torch.nn.functional as F
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ReverseDiffusion:
    """Reverse diffusion sampler whose ``generate`` mirrors the validated
    ``run_inference()`` path in inference.py (see the module docstring for
    why ``inference_mode`` must never be passed to the model).

    ``p_sample_step`` / ``generate_beam`` are retained for experiments only.
    """

    def __init__(self, scheduler):
        # scheduler must expose `num_timesteps` (total diffusion steps T).
        self.scheduler = scheduler

        # Attribute-style defaults for backward compat with any code
        # that sets reverse_diffusion.temperature = 0.9 etc.
        # generate() prefers explicit kwargs and falls back to these.
        self.temperature = 0.75
        self.repetition_penalty = 1.15
        self.diversity_penalty = 0.0
        self.top_k = 50

    # ------------------------------------------------------------------ #
    # generate — CORRECT D3PM iterative refinement                       #
    # Exact equivalent of run_inference() in inference.py                #
    # ------------------------------------------------------------------ #
    def generate(
        self,
        model,
        condition,
        num_steps = None,
        temperature = None,
        top_k = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        D3PM iterative refinement — identical to run_inference() in inference.py,
        which is the validated path (BERTScore 0.75).

        Algorithm:
            x0_est = all [MASK]
            for t = T-1 down to 0:
                logits = model(src, x0_est, t, x0_hint=hint)
                    (inference_mode NOT passed — defaults to False,
                     exactly matching training/validation)
                apply penalties, temperature, top_k
                if t > 0: x0_est = multinomial(softmax(logits))   # stochastic
                if t = 0: x0_est = argmax(softmax(logits))        # deterministic
                hint = x0_est

        Any kwarg left as None falls back to the instance attribute set in
        __init__ (backward-compat with attribute-style configuration).
        """
        # Resolve: explicit kwarg > object attribute
        temperature = temperature if temperature is not None else self.temperature
        top_k = top_k if top_k is not None else self.top_k
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty

        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1:
            condition = condition.unsqueeze(0)
        B, L = condition.shape

        # Subsample the T timesteps down to ~num_steps, always ending at 0.
        T = self.scheduler.num_timesteps
        step_size = max(1, T // num_steps)
        timesteps = list(range(T - 1, -1, -step_size))
        if timesteps[-1] != 0:
            timesteps.append(0)

        mask_id = model.mask_token_id
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        hint = None  # self-conditioning hint: previous step's estimate

        model.eval()
        with torch.no_grad():
            for step_idx, t_val in enumerate(timesteps):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = (step_idx == len(timesteps) - 1)

                # ── CRITICAL: do NOT pass inference_mode ──────────────────
                # inference_mode defaults to False inside SanskritModel /
                # D3PMCrossAttention. This matches run_inference() exactly.
                # Passing inference_mode=True changes tgt_pad_mask and
                # q_sample behaviour — the model was never trained for that.
                logits, _ = model(condition, x0_est, t, x0_hint=hint)

                # Repetition penalty
                if repetition_penalty != 1.0:
                    logits = apply_repetition_penalty(
                        logits, x0_est, repetition_penalty
                    )

                # Diversity penalty (correct: global mean suppression)
                if diversity_penalty > 0.0:
                    logits = apply_diversity_penalty(logits, diversity_penalty)

                # Temperature scaling (clamped to avoid division by zero)
                logits = logits / max(temperature, 1e-5)

                if top_k > 0:
                    logits = top_k_filter(logits, top_k)

                probs = F.softmax(logits, dim=-1)

                # Stochastic at every step except the last (argmax at t=0)
                if is_last:
                    x0_est = torch.argmax(probs, dim=-1)
                else:
                    x0_est = batch_multinomial(probs)

                hint = x0_est

        return x0_est  # (B, L)

    # ------------------------------------------------------------------ #
    # p_sample_step — used by generate_beam (not for production)         #
    # ------------------------------------------------------------------ #
    def p_sample_step(
        self,
        model,
        x_t,
        t,
        condition,
        beam_width = 3,
        temperature = 1.0,
        repetition_penalty = 1.2,
        diversity_penalty = 0.3,
        x0_hint = None,
    ):
        # Single reverse step: model forward, temperature + penalties, then
        # per-position top-k expansion. Returns a list of
        # (next_tokens, summed-log-prob score) tuples, one per k.
        with torch.no_grad():
            # Shape safety: promote 1-D inputs to a batch of one.
            if x_t.dim() == 1: x_t = x_t.unsqueeze(0)
            if condition.dim() == 1: condition = condition.unsqueeze(0)
            if t.dim() == 0: t = t.unsqueeze(0)
            if t.shape[0] != x_t.shape[0]:
                t = t.expand(x_t.shape[0])

            # No inference_mode — matches training convention
            logits, _ = model(condition, x_t, t, x0_hint=x0_hint)

            logits = logits / max(temperature, 1e-5)

            if repetition_penalty != 1.0:
                logits = apply_repetition_penalty(logits, x_t, repetition_penalty)
            if diversity_penalty > 0.0:
                logits = apply_diversity_penalty(logits, diversity_penalty)

            probs = F.softmax(logits, dim=-1)
            B, L, V = probs.shape

            topk_probs, topk_ids = torch.topk(probs, beam_width, dim=-1)
            candidates = []
            for k in range(beam_width):
                # Candidate k: k-th best token at EVERY position at once.
                next_tokens = topk_ids[:, :, k]
                score = torch.log(topk_probs[:, :, k] + 1e-9).sum()
                candidates.append((next_tokens, score))
            return candidates

    # ------------------------------------------------------------------ #
    # generate_beam — kept for reference; NOT the correct D3PM method    #
    # ------------------------------------------------------------------ #
    def generate_beam(
        self,
        model,
        condition,
        beam_width = 3,
        num_steps = None,
        temperature = None,
        repetition_penalty = None,
        diversity_penalty = None,
    ):
        """
        WARNING: do NOT call this from app1.py for D3PM generation.
        generate_beam() forces every position to the same top-k token
        → all-Ṛ / all-rud repetition. Use generate() instead.
        Kept only for experimental reference.
        """
        # Resolve kwargs against instance attributes (same rule as generate).
        temperature = temperature if temperature is not None else self.temperature
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.repetition_penalty
        diversity_penalty = diversity_penalty if diversity_penalty is not None else self.diversity_penalty
        if num_steps is None:
            num_steps = self.scheduler.num_timesteps

        device = condition.device
        if condition.dim() == 1: condition = condition.unsqueeze(0)
        B, L = condition.shape

        # All-[MASK] initial canvas.
        x_init = torch.full((B, L), fill_value=model.mask_token_id,
                            dtype=torch.long, device=device)
        beams = [(x_init, 0.0)]
        best_hint = None  # self-conditioning: best beam from previous step

        for step in reversed(range(num_steps)):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            new_beams = []
            for x_t, score in beams:
                candidates = self.p_sample_step(
                    model, x_t, t_tensor, condition,
                    beam_width = beam_width,
                    temperature = temperature,
                    repetition_penalty = repetition_penalty,
                    diversity_penalty = diversity_penalty,
                    x0_hint = best_hint,
                )
                for tokens, new_score in candidates:
                    new_beams.append((tokens, score + new_score.item()))

            # Keep the beam_width highest-scoring beams.
            new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_width]
            best_hint = beams[0][0]

        return beams[0][0]  # (B, L)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ── Penalty helpers ────────────────────────────────────────────────────────
|
| 245 |
+
|
| 246 |
+
def apply_repetition_penalty(logits, prev_tokens, penalty=1.2):
    """Down-weight tokens already present in the sequence.

    Bug fix: the previous version unconditionally DIVIDED by ``penalty``,
    which for a NEGATIVE logit moves it toward zero and *increases* its
    probability. Per the CTRL / Hugging Face convention, positive logits
    are divided by ``penalty`` and negative logits multiplied by it, so
    repeated tokens are always suppressed. Token ids <= 4 (specials) are
    never penalized.

    Args:
        logits:      (B, L, V) tensor, modified in place and returned.
        prev_tokens: (B, L') tensor of token ids already in the sequence.
        penalty:     float >= 1.0 (1.0 means no effect).
    """
    for b in range(logits.shape[0]):
        for token_id in set(prev_tokens[b].tolist()):
            if token_id > 4:
                col = logits[b, :, token_id]
                logits[b, :, token_id] = torch.where(
                    col > 0, col / penalty, col * penalty
                )
    return logits
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def apply_diversity_penalty(logits, penalty=0.3):
    """Suppress globally dominant tokens.

    Subtracts `penalty` times each token's mean logit (averaged over the
    sequence dimension) from every position:

        logits -= penalty * mean(logits, dim=1)

    A token that scores high at *every* position is pushed down more than
    one that is only locally preferred.
    """
    per_token_mean = logits.mean(dim=1, keepdim=True)  # [B, 1, V]
    penalized = logits - per_token_mean * penalty
    return penalized
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def top_k_filter(logits, k):
    """Keep only the top-k logits per position; mask the rest to -inf.

    Args:
        logits : [B, L, V] logits.
        k      : number of candidates to keep per position. k <= 0 or
                 k >= V disables filtering (previously k <= 0 crashed
                 inside torch.topk / the [..., -1] index).

    Returns:
        logits with sub-threshold entries replaced by -inf. Ties with the
        k-th value are all kept (strict `<` comparison).
    """
    B, L, V = logits.shape
    if k <= 0 or k >= V:
        return logits
    topk_vals, _ = torch.topk(logits, k, dim=-1)
    kth = topk_vals[..., -1].unsqueeze(-1)  # k-th largest per position
    return logits.masked_fill(logits < kth, float('-inf'))
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def batch_multinomial(probs):
    """Draw one sample per (batch, position) from a [B, L, V] probability tensor.

    Flattens to [B*L, V] rows, adds a tiny epsilon so degenerate all-zero
    rows stay valid, renormalizes, and samples one index per row.
    Returns a [B, L] LongTensor of sampled token ids.
    """
    B, L, V = probs.shape
    flat_rows = probs.reshape(B * L, V) + 1e-9
    normalizer = flat_rows.sum(dim=-1, keepdim=True)
    flat_rows = flat_rows / normalizer
    drawn = torch.multinomial(flat_rows, 1).squeeze(-1)
    return drawn.view(B, L)
|
run_analysis.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/run_analysis.py
|
| 3 |
+
=========================
|
| 4 |
+
Entry point for all 5 tasks.
|
| 5 |
+
|
| 6 |
+
Tasks:
|
| 7 |
+
Task 1 — KV Cache benchmark (no retraining)
|
| 8 |
+
Task 2 — Attention viz + drift (no retraining)
|
| 9 |
+
Task 3 — Concept vectors + PCA steer (no retraining)
|
| 10 |
+
Task 4 — Step ablation (REQUIRES retraining for each T)
|
| 11 |
+
Task 5 — Classifier-free guidance (trains small 10k-param classifier)
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python analysis/run_analysis.py --task 1
|
| 15 |
+
python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
|
| 16 |
+
python analysis/run_analysis.py --task 3
|
| 17 |
+
python analysis/run_analysis.py --task 4 --phase generate_configs
|
| 18 |
+
python analysis/run_analysis.py --task 4 --phase analyze
|
| 19 |
+
python analysis/run_analysis.py --task 5
|
| 20 |
+
python analysis/run_analysis.py --task all --input "satyameva jayate"
|
| 21 |
+
|
| 22 |
+
Output files: analysis/outputs/
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
import os, sys, argparse, json
|
| 27 |
+
import numpy as np
|
| 28 |
+
|
| 29 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 30 |
+
from config import CONFIG
|
| 31 |
+
from inference import load_model
|
| 32 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 33 |
+
|
| 34 |
+
OUTPUT_DIR = "analysis/outputs"
|
| 35 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ── Shared loader ─────────────────────────────────────────────────────
|
| 39 |
+
|
| 40 |
+
def load_everything(cfg, device):
    """Load the trained model plus the source/target tokenizers.

    Resolves the checkpoint path from the config's model type and its
    negative-examples flag, loads the model in eval mode, then builds
    both tokenizers from the (possibly updated) config that load_model
    returns.

    Raises:
        FileNotFoundError: if the expected checkpoint file is missing.
    """
    checkpoint = "results7/{}_neg_{}/best_model.pt".format(
        cfg['model_type'], cfg['data']['include_negative_examples'])
    if not os.path.exists(checkpoint):
        raise FileNotFoundError(f"No checkpoint at {checkpoint}. Train first.")

    model, cfg = load_model(checkpoint, cfg, device)
    model.eval()

    model_cfg = cfg['model']
    src_tok = SanskritSourceTokenizer(
        vocab_size=model_cfg.get('src_vocab_size', 500),
        max_len=model_cfg['max_seq_len'])
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=model_cfg.get('tgt_vocab_size', 500),
        max_len=model_cfg['max_seq_len'])
    return model, src_tok, tgt_tok, cfg
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_val_data(cfg, src_tok, tgt_tok, n=500):
    """Load validation set as (src_tensors, ref_strings, input_strings).

    Rebuilds a deterministic 80/20 split (random_state=42); presumably
    this mirrors the split used at training time so the returned examples
    are held out — TODO confirm against the trainer.

    Args:
        cfg     : global config dict.
        src_tok : source tokenizer.
        tgt_tok : target tokenizer.
        n       : maximum number of validation examples to return.

    Returns:
        src_list : list of [1, seq_len] input-id tensors (batch dim added).
        ref_list : list of reference target strings.
        inp_list : list of raw input strings.
    """
    from Data.data import OptimizedSanskritDataset
    # NOTE: removed the unused `from torch.utils.data import Subset`.
    from sklearn.model_selection import train_test_split

    dataset = OptimizedSanskritDataset(
        'train', max_len=cfg['model']['max_seq_len'],
        cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
    total = min(cfg['data']['dataset_size'], len(dataset))
    # Fixed seed → the same val_idx every run.
    _, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
    val_idx = val_idx[:n]

    src_list, ref_list, inp_list = [], [], []
    for i in val_idx:
        item = dataset[i]
        src_list.append(item['input_ids'].unsqueeze(0))  # add batch dim
        ref_list.append(item['target_text'])
        inp_list.append(item['input_text'])
    return src_list, ref_list, inp_list
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ── Task 1 ────────────────────────────────────────────────────────────
|
| 80 |
+
|
| 81 |
+
def run_task1(model, src_tok, device):
    """Task 1: benchmark KV-cached generation against standard generation.

    Skips unless the wrapped model exposes generate_cached (only the
    D3PMCrossAttention architecture does). Writes a fixed-width results
    table to analysis/outputs/task1_kv_cache.txt.

    Fix: the report contains a non-ASCII em dash, so the file is now
    opened with encoding="utf-8" (consistent with the task 2/3 reports);
    previously this crashed on platforms with a non-UTF-8 default locale.
    """
    print("\n" + "="*65)
    print(" TASK 1 — KV Cache Benchmark")
    print("="*65)
    if not hasattr(model.model, 'generate_cached'):
        print(" SKIP: not D3PMCrossAttention.")
        return
    from analysis.kv_cache_benchmark import run_benchmark, print_summary
    results = run_benchmark(model, src_tok, device, src_lens=[16, 32, 64])
    print_summary(results)
    path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write("TASK 1 — KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
        f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
                f"{'speedup':>8} {'encoder%':>9}\n")
        for src_len, r in results.items():
            f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
                    f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
    print(f" Saved: {path}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ── Task 2 ────────────────────────────────────────────────────────────
|
| 103 |
+
|
| 104 |
+
def run_task2(model, src_tok, tgt_tok, device, input_text):
    """Task 2: attention visualization + semantic-drift analysis.

    Encodes `input_text`, captures cross-attention weights at sampled
    diffusion steps, generates the output, renders heatmap/evolution
    plots, computes the drift curve vs the final output, and writes a
    text report. Skips unless the wrapped model exposes encode_source.

    Fix: generation was called as `model.generate_cached(...)` even
    though the guard (and everything else in this file) treats `model`
    as the SanskritModel wrapper, which only defines forward/generate —
    the cached generator lives on the inner model, so the call is now
    `model.model.generate_cached(...)`.
    """
    print("\n" + "="*65)
    print(" TASK 2 — Attention Visualization + Semantic Drift")
    print("="*65)
    print(f" Input: {input_text}")
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    src_ids = src_tok.encode(input_text)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
    src_chars = list(input_text.strip())

    from analysis.attention_viz import (AttentionCapture, plot_attn_heatmap,
                                        plot_attn_evolution, plot_all_layers)
    from analysis.semantic_drift import (capture_intermediate_outputs,
                                         compute_drift, compute_token_stability,
                                         plot_drift_curve)

    # Attention capture: record cross-attention weights every 10 steps.
    print(" Capturing attention weights...")
    capturer = AttentionCapture(model)
    step_weights = capturer.capture(src_tensor, capture_every=10)

    with torch.no_grad():
        # FIX: generate_cached is defined on the inner model, not the wrapper.
        out_ids = model.model.generate_cached(src_tensor)
    tgt_ids = [x for x in out_ids[0].tolist() if x > 4]  # drop special tokens (ids 0-4)
    tgt_text = tgt_tok.decode(tgt_ids).strip()
    tgt_chars = list(tgt_text)
    print(f" Output: {tgt_text}")

    # Heatmaps at the noisiest captured step and at the final step.
    first_t = max(step_weights.keys())
    plot_attn_heatmap(step_weights, t_val=first_t, layer=0,
                      src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                      save_path=os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"),
                      title=f"Attention t={first_t} (noisy) Layer 0")
    plot_attn_heatmap(step_weights, t_val=0, layer=0,
                      src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                      save_path=os.path.join(OUTPUT_DIR, "task2_attn_t0.png"),
                      title="Attention t=0 (final) Layer 0")
    plot_all_layers(step_weights, t_val=0,
                    src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
                    save_path=os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"))
    if len(src_chars) > 0 and len(tgt_chars) > 0:
        plot_attn_evolution(step_weights, src_token_idx=0, tgt_token_idx=0,
                            layer=0, src_token_str=src_chars[0], tgt_token_str=tgt_chars[0],
                            save_path=os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"))

    # Semantic drift: CER of intermediate outputs vs the final output.
    print(" Computing semantic drift...")
    step_outputs, final_out = capture_intermediate_outputs(
        model, src_tensor, tgt_tok, capture_every=5)
    drift = compute_drift(step_outputs, final_out)
    stab = compute_token_stability(step_outputs, final_out, tgt_tok)
    plot_drift_curve(drift, src_text=input_text,
                     save_path=os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"))

    print(f" Lock-in timestep: t={drift['lock_in_t']}")
    print(f" Mean position lock-in: t={stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}")

    report = os.path.join(OUTPUT_DIR, "task2_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "="*50 + "\n\n")
        f.write(f"Input : {input_text}\nOutput : {final_out}\n\n")
        f.write(f"Lock-in t : {drift['lock_in_t']}\n")
        f.write(f"Mean pos lock-in : {stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}\n\n")
        f.write("Step → Output → CER-to-final\n" + "-"*60 + "\n")
        for tv, cer in zip(drift["t_vals"], drift["cer_to_final"]):
            f.write(f" t={tv:4d} | {step_outputs.get(tv,'')[:40]:40s} | {cer:.4f}\n")
    print(f" Report: {report}")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# ── Task 3 ────────────────────────────────────────────────────────────
|
| 177 |
+
|
| 178 |
+
def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list):
    """Task 3: concept vectors + PCA steering.

    Collects hidden states over the validation set, fits a PCA, finds
    the principal component most correlated with output length (the
    "diversity direction"), plots the concept space, and generates a
    spectrum of steered outputs for the first example.

    Fix: generation was called as `model.generate_cached(...)` even
    though the guard treats `model` as the SanskritModel wrapper, which
    only defines forward/generate — the call is now
    `model.model.generate_cached(...)`.
    """
    print("\n" + "="*65)
    print(" TASK 3 — Concept Vectors + PCA Steering")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    from analysis.concept_vectors import (collect_hidden_states, fit_pca,
                                          find_diversity_direction,
                                          generate_diversity_spectrum,
                                          plot_pca_space)

    # Collect hidden states from the validation set (at t=0).
    n = min(500, len(src_list))
    print(f" Collecting hidden states from {n} examples...")
    hidden, _ = collect_hidden_states(
        model, src_list[:n], t_capture=0, max_samples=n)

    # Output length per example — the supervised signal for the direction.
    lengths = []
    for src in src_list[:n]:
        with torch.no_grad():
            # FIX: generate_cached lives on the inner model, not the wrapper.
            out = model.model.generate_cached(src.to(device))
        ids = [x for x in out[0].tolist() if x > 4]  # drop special tokens
        lengths.append(len(tgt_tok.decode(ids)))

    # Fit PCA + find the length-correlated (diversity) direction.
    pca = fit_pca(hidden, n_components=min(50, n-1))
    direction, best_pc, corr = find_diversity_direction(hidden, lengths, pca)

    # Plot the concept space colored by output length.
    plot_pca_space(hidden, lengths, pca, best_pc,
                   save_path=os.path.join(OUTPUT_DIR, "task3_concept_space.png"))

    # Steered generation spectrum for the first example.
    # NOTE(review): assumes src_list is non-empty — IndexError otherwise.
    print("\n Diversity spectrum for first example:")
    src0 = src_list[0]
    inp0 = src_tok.decode([x for x in src0[0].tolist() if x > 4])
    print(f" Input: {inp0}")
    spectrum = generate_diversity_spectrum(
        model, src0.to(device), direction, tgt_tok,
        alphas=[-2.0, -1.0, 0.0, 1.0, 2.0])

    # Persist the direction for reuse, then write the report.
    np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)

    report = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
        f.write(f"PCA: {pca.n_components_} components, "
                f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
        f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with output length)\n\n")
        f.write("Diversity spectrum:\n")
        for alpha, text in sorted(spectrum.items()):
            f.write(f" alpha={alpha:+.1f} → {text}\n")
    print(f" Report: {report}")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ── Task 4 ────────────────────────────────────────────────────────────
|
| 236 |
+
|
| 237 |
+
def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
              src_list, ref_list):
    """Task 4: diffusion-step ablation (requires retraining per T).

    Two phases:
      phase == "generate_configs": write per-T training configs and print
          the manual training steps; no analysis happens yet.
      phase == "analyze": evaluate every retrained T-variant found under
          ablation_results/T*/best_model.pt and render the 3D plot.

    In BOTH phases the adversarial robustness test at the bottom runs on
    the already-loaded model (no retraining needed).
    """
    print("\n" + "="*65)
    print(f" TASK 4 — Step Ablation (phase={phase})")
    print("="*65)

    from analysis.step_ablation import (generate_ablation_configs,
        run_ablation_analysis, plot_ablation_3d, run_adversarial_test)

    if phase == "generate_configs":
        print(" Generating ablation configs...")
        generate_ablation_configs(output_dir="ablation_configs")
        print("\n NEXT STEPS:")
        print(" 1. bash ablation_configs/train_all.sh")
        print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")

    elif phase == "analyze":
        # Only analyze the T-variants whose checkpoints actually exist.
        existing = [T for T in [4, 8, 16, 32, 64]
                    if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
        if not existing:
            print(" No ablation models found at ablation_results/T*/best_model.pt")
            print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
            print(" Then: bash ablation_configs/train_all.sh")
            return

        print(f" Found models for T={existing}")
        # Evaluate each T on a 200-example slice of the validation set.
        results = run_ablation_analysis(
            ablation_dir="ablation_results", base_cfg=cfg,
            src_list=src_list[:200], ref_list=ref_list[:200],
            tgt_tokenizer=tgt_tok, device=device,
            output_dir=OUTPUT_DIR)
        plot_ablation_3d(results,
            save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))

    # Adversarial robustness always runs on existing model (no retraining)
    print("\n Running adversarial robustness test...")
    # Decode the first 50 source tensors back to text (ids 0-4 are special).
    inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
                 for s in src_list[:50]]
    run_adversarial_test(
        model, src_tok, tgt_tok,
        test_inputs=inp_texts, test_refs=ref_list[:50],
        device=device, output_dir=OUTPUT_DIR)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# ── Task 5 ────────────────────────────────────────────────────────────
|
| 283 |
+
|
| 284 |
+
def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list):
    """Task 5: classifier-free guidance with a small quality classifier.

    Pipeline: (1) collect (hidden-state, quality) pairs — cached as .npz;
    (2) train or load the small quality classifier — cached as .pt;
    (3) sweep guidance scales λ and report CER/diversity per scale.
    Skips unless the wrapped model exposes encode_source.

    Fix: the report contains λ / ← / — characters but was opened without
    encoding="utf-8", crashing on non-UTF-8 default locales; it is now
    opened with encoding="utf-8", consistent with the task 2/3 reports.
    """
    print("\n" + "="*65)
    print(" TASK 5 — Classifier-Free Guidance")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" SKIP: not D3PMCrossAttention.")
        return

    from analysis.quality_classifier import (
        QualityClassifier, collect_quality_data,
        train_quality_classifier, sweep_guidance_scales)

    clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
    d_model = cfg['model']['d_model']

    # Step 1: collect or load training data (expensive → cached on disk).
    data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
    if os.path.exists(data_path):
        print(" Loading cached quality data...")
        data = np.load(data_path)
        hidden = data["hidden"]
        quality = data["quality"]
    else:
        print(" Collecting quality data (this takes a few minutes)...")
        n = min(2000, len(src_list))
        hidden, quality = collect_quality_data(
            model, src_list[:n], ref_list[:n], tgt_tok,
            t_capture=0, max_samples=n)
        np.savez(data_path, hidden=hidden, quality=quality)
        print(f" Saved quality data: {data_path}")

    # Step 2: train or load classifier (also cached).
    if os.path.exists(clf_path):
        print(f" Loading cached classifier: {clf_path}")
        clf = QualityClassifier(d_model)
        clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
        clf.eval()
    else:
        print(" Training quality classifier...")
        clf = train_quality_classifier(
            hidden, quality, d_model=d_model,
            epochs=30, batch_size=64, lr=1e-3,
            save_path=clf_path)
        clf.eval()

    # Step 3: guidance scale sweep over a 50-example slice.
    print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
    n_sweep = min(50, len(src_list))
    results = sweep_guidance_scales(
        model, clf, src_list[:n_sweep], ref_list[:n_sweep],
        tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
        n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)

    # Optimal scale = lowest mean CER.
    best_scale = min(results, key=lambda s: results[s]["mean_cer"])
    print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
          f"CER={results[best_scale]['mean_cer']:.4f}")

    report = os.path.join(OUTPUT_DIR, "task5_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
        f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
        f.write(f"Training samples : {len(hidden)}\n\n")
        f.write("Guidance scale sweep:\n")
        f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10}\n")
        f.write(" " + "-"*28 + "\n")
        for s in sorted(results.keys()):
            r = results[s]
            marker = " ← optimal" if s == best_scale else ""
            f.write(f" {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f}{marker}\n")
    print(f" Report: {report}")
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# ── Main ──────────────────────────────────────────────────────────────
|
| 358 |
+
|
| 359 |
+
def main():
    """Parse CLI args, load the model and data once, then run the chosen tasks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        choices=["1","2","3","4","5","all"], default="all")
    parser.add_argument("--input",
                        default="dharmo rakṣati rakṣitaḥ",
                        help="IAST input text for Task 2")
    parser.add_argument("--phase",
                        choices=["generate_configs", "analyze"], default="analyze",
                        help="Task 4 phase: generate_configs (before training) or analyze (after)")
    args = parser.parse_args()

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])

    print("Loading model and tokenizers...")
    model, src_tok, tgt_tok, cfg = load_everything(cfg, device)

    # Validation data is only consumed by Tasks 3, 4 and 5.
    if args.task in ("3", "4", "5", "all"):
        print("Loading validation data...")
        src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
    else:
        src_list, ref_list, inp_list = [], [], []

    selected = ["1","2","3","4","5"] if args.task == "all" else [args.task]

    # Dispatch table keeps each task's argument wiring in one place.
    runners = {
        "1": lambda: run_task1(model, src_tok, device),
        "2": lambda: run_task2(model, src_tok, tgt_tok, device, args.input),
        "3": lambda: run_task3(model, src_tok, tgt_tok, device, src_list, ref_list),
        "4": lambda: run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
                               src_list, ref_list),
        "5": lambda: run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list),
    }
    for task in selected:
        runners[task]()

    print(f"\n{'='*65}")
    print(f" All outputs saved to: {OUTPUT_DIR}/")
    print("="*65)
|
sanskrit_model.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
sanskrit_model.py — Fixed
|
| 3 |
+
===========================
|
| 4 |
+
Added inference_mode parameter to forward() so reverse_process.py can
|
| 5 |
+
pass inference_mode=True without a TypeError.
|
| 6 |
+
|
| 7 |
+
The wrapper introspects each inner model's signature and only passes
|
| 8 |
+
kwargs that model actually accepts — safe across all four architectures.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import inspect
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SanskritModel(nn.Module):
    """Thin wrapper that instantiates one of four architectures.

    forward() and generate() introspect the inner model's signature and
    pass through only the kwargs it actually accepts, so callers may
    always supply x0_hint / inference_mode (or sampling kwargs) without
    risking a TypeError on architectures that lack them.
    """

    def __init__(self, cfg):
        super().__init__()
        model_type = cfg['model_type']

        # Imports are deferred so only the selected architecture is loaded.
        if model_type == 'd3pm_cross_attention':
            from model.d3pm_model_cross_attention import D3PMCrossAttention
            self.model = D3PMCrossAttention(cfg)
        elif model_type == 'd3pm_encoder_decoder':
            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
            self.model = D3PMEncoderDecoder(cfg)
        elif model_type == 'baseline_cross_attention':
            from model.d3pm_model_cross_attention import BaselineCrossAttention
            self.model = BaselineCrossAttention(cfg)
        elif model_type == 'baseline_encoder_decoder':
            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
            self.model = BaselineEncoderDecoder(cfg)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        """Dispatch to the inner model, forwarding only supported kwargs.

        Baselines without a timestep parameter are called without `t`.
        """
        accepted = inspect.signature(self.model.forward).parameters
        extra = {name: value
                 for name, value in (('x0_hint', x0_hint),
                                     ('inference_mode', inference_mode))
                 if name in accepted}
        if 't' in accepted:
            return self.model(input_ids, target_ids, t, **extra)
        return self.model(input_ids, target_ids, **extra)

    @torch.no_grad()
    def generate(self, src, **kwargs):
        """Generate from the inner model, silently dropping unsupported kwargs."""
        accepted = inspect.signature(self.model.generate).parameters
        supported = {k: v for k, v in kwargs.items() if k in accepted}
        return self.model.generate(src, **supported)
|
scheduler.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
scheduler.py — Fixed & Upgraded
|
| 3 |
+
==================================
|
| 4 |
+
Changes:
|
| 5 |
+
1. T=64 (was 16). More timesteps = richer denoising curriculum per epoch.
|
| 6 |
+
2. alpha at t=0 is EXACTLY 1.0 — fixes Bug 2 (final-step re-noise).
|
| 7 |
+
3. sample_timestep samples [0, T-1] including t=0, so model trains on
|
| 8 |
+
fully-clean inputs (learns the identity at t=0 explicitly).
|
| 9 |
+
"""
|
| 10 |
+
import torch, math
|
| 11 |
+
|
| 12 |
+
class OptimizedCosineScheduler:
    """Cosine noise schedule for the discrete diffusion process.

    Guarantees alpha-bar(0) == 1.0 exactly (no re-noising at the final
    step) and clamps alpha-bar(T-1) to at most 0.001 (near-total masking
    at the noisiest step). Timesteps are sampled uniformly over
    [0, T-1] *including* t=0, so the model also trains on clean inputs.
    """

    def __init__(self, cfg, device=None):
        # T diffusion steps and the [MASK] token id come from the config.
        self.num_timesteps = cfg['model']['diffusion_steps']
        self.mask_token_id = cfg['diffusion']['mask_token_id']
        self.device = device or torch.device('cpu')
        self.alphas_cumprod = self._build_schedule().to(self.device)

    def _build_schedule(self):
        """Return the length-T tensor of cumulative keep-probabilities."""
        T = self.num_timesteps
        steps = torch.arange(T + 1, dtype=torch.float32)
        # Cosine schedule with the standard 0.008 offset.
        cos_sq = torch.cos((steps / T + 0.008) / 1.008 * math.pi / 2) ** 2
        bar = cos_sq / cos_sq[0]
        bar = bar[1:]                        # shape [T]
        bar[0] = 1.0                         # exact identity at t=0
        bar[-1] = bar[-1].clamp(max=0.001)   # near-full corruption at t=T-1
        return bar

    def sample_timestep(self, batch_size):
        """Draw a uniform batch of timesteps from [0, T-1] (t=0 included)."""
        return torch.randint(0, self.num_timesteps, (batch_size,))

    def get_alpha(self, t):
        """Look up alpha-bar for timestep tensor `t` (cast to long on-device)."""
        idx = t.to(self.alphas_cumprod.device).long()
        return self.alphas_cumprod[idx]
|
semantic_drift.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/semantic_drift.py
|
| 3 |
+
===========================
|
| 4 |
+
Task 2: Semantic drift metric — how much does the intermediate generation
|
| 5 |
+
diverge from the final output as we walk through diffusion steps T → 0?
|
| 6 |
+
|
| 7 |
+
Metric: CER between x0_estimate at each step vs the final x0 at t=0.
|
| 8 |
+
|
| 9 |
+
A well-trained model should show:
|
| 10 |
+
- High drift at t=T-1 (near-random initial estimate)
|
| 11 |
+
- Rapid decrease in drift around t=T//2 (model finds the right structure)
|
| 12 |
+
- Near-zero drift at t=10 (output is stable, only fine corrections remain)
|
| 13 |
+
|
| 14 |
+
If drift stays high until t=5 then suddenly collapses → model is doing all
|
| 15 |
+
its work in the last few steps → consider reducing T.
|
| 16 |
+
|
| 17 |
+
Also measures:
|
| 18 |
+
- Token stability: fraction of positions that don't change between steps
|
| 19 |
+
- Lock-in time: first step where each position "commits" to its final token
|
| 20 |
+
|
| 21 |
+
No retraining required. Uses generate_cached() with intermediate snapshots.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
import numpy as np
|
| 27 |
+
from typing import List, Dict, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def compute_cer_between(pred: str, ref: str) -> float:
    """Character error rate: Levenshtein(pred, ref) / len(ref).

    An empty reference yields 1.0 when pred is non-empty, else 0.0.
    """
    if not ref:
        return 1.0 if pred else 0.0

    def _levenshtein(a: str, b: str) -> int:
        # Two-row dynamic program over the edit-distance matrix.
        previous = list(range(len(b) + 1))
        for i, ch_a in enumerate(a, start=1):
            current = [i]
            for j, ch_b in enumerate(b, start=1):
                if ch_a == ch_b:
                    current.append(previous[j - 1])
                else:
                    current.append(1 + min(previous[j - 1],
                                           previous[j],
                                           current[j - 1]))
            previous = current
        return previous[-1]

    return _levenshtein(pred, ref) / len(ref)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@torch.no_grad()
def capture_intermediate_outputs(
    model,
    src: torch.Tensor,
    tgt_tokenizer,
    capture_every: int = 5,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Tuple[Dict[int, str], str]:
    """
    Run generation while recording the decoded x0_estimate at every
    `capture_every` diffusion steps.

    Args:
        model         : wrapper whose `.model` attribute is the diffusion net;
                        assumed to expose scheduler / encode_source /
                        forward_cached / max_seq_len / mask_token_id —
                        TODO confirm against the D3PMCrossAttention class
        src           : [src_len] or [1, src_len] IAST token ids (single sample)
        tgt_tokenizer : target tokenizer used to decode intermediate outputs
        capture_every : record a snapshot every N steps (t=0 is always captured)
        temperature   : sampling temperature (clamped away from zero below)
        top_k         : top-k logit filter; <=0 disables filtering

    Returns:
        step_outputs : dict mapping t_val → decoded string at that step
        final_output : decoded string at t=0 (final result)
    """
    # Normalize a single unbatched sample to a batch of one.
    if src.dim() == 1:
        src = src.unsqueeze(0)

    inner = model.model
    T = inner.scheduler.num_timesteps
    device = src.device

    # Encode source once (KV cache)
    memory, src_pad_mask = inner.encode_source(src)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Start fully masked — the absorbing-diffusion initial state.
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    step_outputs: Dict[int, str] = {}
    inner.eval()

    # Reverse process: walk t from T-1 down to 0.
    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = (t_val == 0)

        logits, _ = inner.forward_cached(
            memory, src_pad_mask, x0_est, t,
            x0_hint=hint, inference_mode=True,
        )

        # Temperature scaling; max() guards against division by zero.
        logits = logits / max(temperature, 1e-8)
        if top_k > 0:
            V = logits.shape[-1]
            if top_k < V:
                # Mask every logit strictly below the k-th largest to -inf.
                topk_vals, _ = torch.topk(logits, top_k, dim=-1)
                threshold = topk_vals[..., -1].unsqueeze(-1)
                logits = logits.masked_fill(logits < threshold, float('-inf'))

        probs = F.softmax(logits, dim=-1)
        # Greedy decode on the final step, stochastic sampling otherwise.
        x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
        hint = x0_est

        # Capture at this step
        if (T - 1 - t_val) % capture_every == 0 or is_last:
            ids = [x for x in x0_est[0].tolist() if x > 4]  # drop special ids 0-4
            text = tgt_tokenizer.decode(ids).strip()
            step_outputs[t_val] = text

    final_output = step_outputs.get(0, "")
    return step_outputs, final_output
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _sample(probs):
|
| 126 |
+
B, L, V = probs.shape
|
| 127 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 128 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 129 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compute_drift(
    step_outputs: Dict[int, str],
    final_output: str,
) -> Dict[str, object]:
    """
    Compare every captured intermediate output against the final one.

    Returns a dict with:
        t_vals       : captured timesteps ordered T-1 → 0
        cer_to_final : CER of each step's output vs the final output
                       (0.0 = identical to final, 1.0 = completely different)
        lock_in_t    : earliest t_val from which CER stays <= 0.1
                       (the step at which the output "commits")
        final_output : the t=0 decoded string, echoed back
    """
    t_vals = sorted(step_outputs, reverse=True)  # walk T-1 → 0
    cer_to_final = [compute_cer_between(step_outputs[t], final_output)
                    for t in t_vals]

    # Lock-in: first index whose CER — and every later one — is under 0.1.
    threshold = 0.1
    lock_in_t = 0  # default: never locked in early
    for idx, t in enumerate(t_vals):
        if max(cer_to_final[idx:], default=0.0) <= threshold:
            lock_in_t = t
            break

    return {
        "t_vals": t_vals,
        "cer_to_final": cer_to_final,
        "lock_in_t": lock_in_t,
        "final_output": final_output,
    }
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def compute_token_stability(
    step_outputs: Dict[int, str],
    final_output: str,
    tgt_tokenizer,
) -> Dict[str, object]:
    """
    Token-level stability: for each position, at which diffusion step
    does it first match its final token and stay matched?

    Args:
        step_outputs : dict mapping t_val → decoded string at that step
        final_output : decoded string at t=0
        tgt_tokenizer: object exposing encode(str) -> list[int]

    Returns dict with:
        position_lock_times: t_val at which each position locks in
        mean_lock_t        : average lock-in timestep across positions
        std_lock_t         : standard deviation of lock-in timesteps
        (0.0 statistics when there are no captured steps or no positions —
        previously this produced a NaN from np.mean on an empty list)
    """
    t_vals = sorted(step_outputs.keys(), reverse=True)  # T-1 → 0

    empty = {"position_lock_times": [], "mean_lock_t": 0.0, "std_lock_t": 0.0}
    if not t_vals:
        return empty  # nothing was captured

    final_ids = tgt_tokenizer.encode(final_output)
    L = len(final_ids)

    # Encode every captured step in the same T-1 → 0 order.
    step_ids = [tgt_tokenizer.encode(step_outputs.get(t, "")) for t in t_vals]

    # Pad everything to one common length (1 = PAD); include L so the
    # final sequence can never be left longer than the padded width.
    max_len = max(max(len(s) for s in step_ids), L)
    step_ids = [s + [1] * (max_len - len(s)) for s in step_ids]
    final_ids_padded = final_ids + [1] * (max_len - len(final_ids))

    step_arr = np.array(step_ids)           # [n_steps, max_len]
    final_arr = np.array(final_ids_padded)  # [max_len]

    n_positions = min(L, max_len)
    if n_positions == 0:
        return empty  # final output encodes to nothing

    # For each position: earliest step index i such that the token equals
    # the final token at i and at every subsequent step.
    position_lock_steps = []
    for pos in range(n_positions):
        col = step_arr[:, pos]
        fin = final_arr[pos]
        locked_at = len(t_vals) - 1  # default: only matches at the very end
        for i in range(len(t_vals)):
            if (col[i:] == fin).all():
                locked_at = i
                break
        position_lock_steps.append(t_vals[locked_at])

    return {
        "position_lock_times": position_lock_steps,
        "mean_lock_t": float(np.mean(position_lock_steps)),
        "std_lock_t": float(np.std(position_lock_steps)),
    }
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def plot_drift_curve(
    drift_result: Dict,
    src_text: str = "",
    save_path: Optional[str] = None,
):
    """
    Plot CER-to-final across diffusion steps, marking the lock-in point
    where the model "commits" to its final output.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("pip install matplotlib.")
        return

    t_vals = drift_result["t_vals"]
    cers = drift_result["cer_to_final"]
    lock_t = drift_result["lock_in_t"]
    n = len(t_vals)
    xs = range(n)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(xs, cers, linewidth=1.8, color='coral', label='CER to final')
    ax.fill_between(xs, cers, alpha=0.15, color='coral')

    # Vertical marker at the step where the output locks in.
    if lock_t in t_vals:
        ax.axvline(t_vals.index(lock_t), color='steelblue', linestyle='--',
                   linewidth=1.2, label=f"Lock-in at t={lock_t}")

    # Reference line at the lock-in CER threshold.
    ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)

    # Label roughly ten ticks with the underlying t values.
    tick_positions = list(range(0, n, max(1, n // 10)))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
    ax.set_xlabel("Diffusion step t (T-1 → 0)", fontsize=11)
    ax.set_ylabel("CER vs final output", fontsize=11)
    ax.set_ylim(0, 1.05)
    ax.set_xlim(0, n - 1)
    ax.legend(fontsize=10)

    title = f"Semantic drift"
    if src_text:
        title += f" | src: {src_text[:50]}"
    ax.set_title(title, fontsize=11)
    plt.tight_layout()

    if save_path:
        import os
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
step_ablation.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/step_ablation.py
|
| 3 |
+
==========================
|
| 4 |
+
Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
|
| 5 |
+
|
| 6 |
+
Two-phase workflow (retraining IS required for different T values):
|
| 7 |
+
|
| 8 |
+
PHASE 1 — Generate configs + train (run once per T value):
|
| 9 |
+
python analysis/step_ablation.py --phase generate_configs
|
| 10 |
+
# Creates configs: ablation_configs/T4.py, T8.py, T16.py, T32.py, T64.py
|
| 11 |
+
# Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
|
| 12 |
+
|
| 13 |
+
PHASE 2 — Analyze trained models (no retraining needed):
|
| 14 |
+
python analysis/step_ablation.py --phase analyze
|
| 15 |
+
# Loads each trained model, generates 200 paraphrases, computes CER
|
| 16 |
+
# Produces 3D plot: X=steps, Y=generation_speed, Z=CER
|
| 17 |
+
|
| 18 |
+
Why retraining is needed:
|
| 19 |
+
A model trained with T=128 learns to denoise from x_t~Uniform[0,128].
|
| 20 |
+
Running it with T=4 means the model only sees t∈{0,1,2,3} — which it
|
| 21 |
+
was never trained on at those scales. Outputs are meaningless.
|
| 22 |
+
You must train a separate model for each T value.
|
| 23 |
+
|
| 24 |
+
Also implements adversarial robustness test (no retraining):
|
| 25 |
+
Takes your existing T=128 model and tests whether corrupted IAST
|
| 26 |
+
inputs (typos, character swaps) cause proportional output degradation.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
import numpy as np
|
| 32 |
+
import os
|
| 33 |
+
import sys
|
| 34 |
+
import time
|
| 35 |
+
import json
|
| 36 |
+
import copy
|
| 37 |
+
from typing import List, Dict, Optional
|
| 38 |
+
|
| 39 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ── Phase 1: Config generation ────────────────────────────────────────
|
| 43 |
+
|
| 44 |
+
T_VALUES = [4, 8, 16, 32, 64]

def generate_ablation_configs(base_config_path: str = "config.py",
                              output_dir: str = "ablation_configs"):
    """
    Emit one config file per T value — a copy of the base config with
    diffusion_steps / num_steps rewritten — plus a train_all.sh script.

    After running this, train each model:
        for T in 4 8 16 32 64; do
            cp ablation_configs/config_T${T}.py config.py
            python train.py
            mv results7/d3pm_cross_attention_neg_False \
               ablation_results/T${T}
        done
    """
    os.makedirs(output_dir, exist_ok=True)

    # Read base config
    with open(base_config_path, "r") as f:
        base_src = f.read()

    for T in T_VALUES:
        # Rewrite both config keys under both quoting styles.
        cfg_src = base_src
        for key in ("diffusion_steps", "num_steps"):
            cfg_src = cfg_src.replace(f'"{key}": 128', f'"{key}": {T}')
            cfg_src = cfg_src.replace(f"'{key}': 128", f"'{key}': {T}")

        out_path = os.path.join(output_dir, f"config_T{T}.py")
        with open(out_path, "w") as f:
            f.write(f"# Ablation config: T={T} diffusion steps\n")
            f.write(cfg_src)
        print(f" Wrote: {out_path}")

    # Companion shell script that trains every ablation model in sequence.
    shell_script = os.path.join(output_dir, "train_all.sh")
    with open(shell_script, "w") as f:
        f.write("#!/bin/bash\n")
        f.write("# Run this script to train all ablation models\n\n")
        for T in T_VALUES:
            f.write(f"echo '=== Training T={T} ==='\n")
            f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
            f.write(f"python train.py\n")
            f.write(f"mkdir -p ablation_results/T{T}\n")
            f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
                    f"ablation_results/T{T}/best_model.pt\n")
            f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
                    f"ablation_results/T{T}/train.log\n\n")
    os.chmod(shell_script, 0o755)
    print(f"\nTraining script: {shell_script}")
    print(f"Run: bash {shell_script}")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ── Phase 2: Analysis (after models are trained) ──────────────────────
|
| 111 |
+
|
| 112 |
+
def compute_cer(pred: str, ref: str) -> float:
    """Character error rate of `pred` against `ref` (1.0 when ref is empty)."""
    if not ref:
        return 1.0

    def _lev(a, b):
        # Rolling single-row Levenshtein distance.
        row = list(range(len(b) + 1))
        for i, ca in enumerate(a, start=1):
            diag, row[0] = row[0], i
            for j, cb in enumerate(b, start=1):
                saved = row[j]
                row[j] = diag if ca == cb else 1 + min(diag, row[j], row[j - 1])
                diag = saved
        return row[-1]

    return _lev(pred, ref) / max(len(ref), 1)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def evaluate_model(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    n_samples: int = 200,
    temperature: float = 0.8,
    top_k: int = 40,
) -> Dict:
    """
    Generate up to `n_samples` outputs and compute CER + generation speed.

    Args:
        model        : wrapper exposing .model (and optionally generate_cached)
        src_list     : source token tensors, one per sample
        ref_list     : reference Devanagari strings, aligned with src_list
        tgt_tokenizer: target tokenizer for decoding predictions
        n_samples    : cap on how many samples to evaluate
        temperature  : sampling temperature
        top_k        : top-k filter

    Returns dict with:
        mean_cer        : average CER over samples (0.0 when no samples —
                          previously np.mean([]) produced NaN)
        std_cer         : per-sample CER standard deviation
        generation_s    : total wall-clock seconds for all generations
        speed_per_sample: seconds per sample
        cer_list        : per-sample CER values
        n_samples       : number of samples actually evaluated
    """
    device = next(model.parameters()).device
    n = min(n_samples, len(src_list))
    cer_list = []

    start = time.perf_counter()
    # (was: enumerate with an unused index)
    for src, ref in zip(src_list[:n], ref_list[:n]):
        if src.dim() == 1:
            src = src.unsqueeze(0)  # generation expects a batch dimension

        with torch.no_grad():
            # Prefer the KV-cached generation path when the model offers it.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(
                    src.to(device), temperature=temperature, top_k=top_k
                )
            else:
                out = model.generate(
                    src.to(device), temperature=temperature, top_k=top_k
                )

        ids = [x for x in out[0].tolist() if x > 4]  # drop special ids 0-4
        pred = tgt_tokenizer.decode(ids).strip()
        cer = compute_cer(pred, ref)
        cer_list.append(cer)

    elapsed = time.perf_counter() - start

    return {
        "mean_cer": float(np.mean(cer_list)) if cer_list else 0.0,
        "std_cer": float(np.std(cer_list)) if cer_list else 0.0,
        "generation_s": elapsed,
        "speed_per_sample": elapsed / max(n, 1),
        "cer_list": cer_list,
        "n_samples": n,
    }
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def run_ablation_analysis(
    ablation_dir: str = "ablation_results",
    base_cfg: dict = None,
    src_list: List[torch.Tensor] = None,
    ref_list: List[str] = None,
    tgt_tokenizer = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Load each trained ablation model and evaluate it.

    Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES;
    missing checkpoints are skipped with a message.

    Args:
        ablation_dir : root directory holding T{N}/best_model.pt checkpoints
        base_cfg     : config dict; assumed to contain 'model' and
                       'inference' sub-dicts — TODO confirm schema
        src_list     : source tensors for evaluate_model
        ref_list     : reference strings aligned with src_list
        tgt_tokenizer: target tokenizer for decoding predictions
        device       : device passed through to load_model
        output_dir   : where ablation_results.json is written

    Returns:
        dict mapping T → metrics dict from evaluate_model
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from inference import load_model

    results = {}
    for T in T_VALUES:
        ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
        if not os.path.exists(ckpt):
            print(f" SKIP T={T}: no checkpoint at {ckpt}")
            continue

        print(f"\nEvaluating T={T}...")
        # Deep-copy so each model gets its own config with its T value.
        cfg_T = copy.deepcopy(base_cfg)
        cfg_T['model']['diffusion_steps'] = T
        cfg_T['inference']['num_steps'] = T

        model, cfg_T = load_model(ckpt, cfg_T, device)
        model.eval()

        metrics = evaluate_model(
            model, src_list, ref_list, tgt_tokenizer, n_samples=200
        )
        results[T] = metrics
        print(f" T={T} CER={metrics['mean_cer']:.4f} "
              f"speed={metrics['speed_per_sample']:.3f}s/sample")

        # Free the model before loading the next checkpoint.
        del model

    # Save results (cer_list is dropped to keep the JSON small).
    os.makedirs(output_dir, exist_ok=True)
    results_path = os.path.join(output_dir, "ablation_results.json")
    with open(results_path, "w") as f:
        json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
                   for k, v in results.items()}, f, indent=2)
    print(f"\nResults saved: {results_path}")

    return results
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def plot_ablation_3d(
    results: Dict,
    save_path: Optional[str] = None,
):
    """
    Left panel: 3D scatter of diffusion steps vs speed (s/sample) vs CER.
    Right panel: 2D CER-vs-T curve with an elbow ("knee") marker.
    """
    try:
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D
    except ImportError:
        print("pip install matplotlib.")
        return

    T_list = sorted(results.keys())
    cers = [results[T]["mean_cer"] for T in T_list]
    speeds = [results[T]["speed_per_sample"] for T in T_list]

    fig = plt.figure(figsize=(14, 5))

    # Panel 1: 3D scatter, colored by CER.
    ax3d = fig.add_subplot(121, projection='3d')
    ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
    for T, s, c in zip(T_list, speeds, cers):
        ax3d.text(T, s, c, f"T={T}", fontsize=8)
    ax3d.set_xlabel("Diffusion steps T", fontsize=9)
    ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
    ax3d.set_zlabel("CER (↓ better)", fontsize=9)
    ax3d.set_title("T vs speed vs CER", fontsize=10)

    # Panel 2: CER vs T with per-point value labels.
    ax2d = fig.add_subplot(122)
    ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
    for T, c in zip(T_list, cers):
        ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
                      xytext=(0, 8), fontsize=8, ha='center')

    # Knee = the step with the single largest CER improvement (elbow method).
    if len(T_list) >= 3:
        drops = [a - b for a, b in zip(cers, cers[1:])]
        knee_T = T_list[int(np.argmax(drops)) + 1]
        ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
                     label=f"Knee at T={knee_T}")
        ax2d.legend(fontsize=9)

    ax2d.set_xlabel("Diffusion steps T", fontsize=10)
    ax2d.set_ylabel("CER (lower = better)", fontsize=10)
    ax2d.set_title("CER vs diffusion steps", fontsize=10)
    ax2d.set_ylim(0, max(cers) * 1.1)

    plt.tight_layout()
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
    else:
        plt.show()
    plt.close()
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
# ── Adversarial robustness test (no retraining needed) ───────────────
|
| 299 |
+
|
| 300 |
+
def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
    """
    Introduce random corruption into IAST text:
      - character swap (adjacent chars swapped)
      - character deletion
      - random character insertion

    A corruption_rate of 0.0 (or an empty input) returns the text unchanged,
    so rate 0.0 can serve as the clean baseline. Typical rates: 5% to 20%.
    """
    import random

    # Bug fix: the previous version always applied at least one corruption
    # (max(1, ...)), so even the rate=0.0 baseline was corrupted, and an
    # empty string crashed random.randint.
    if corruption_rate <= 0 or not text:
        return text

    chars = list(text)
    n_corrupt = max(1, int(len(chars) * corruption_rate))

    for _ in range(n_corrupt):
        op = random.choice(['swap', 'delete', 'insert'])
        pos = random.randint(0, len(chars) - 1)

        if op == 'swap' and pos < len(chars) - 1:
            chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
        elif op == 'delete' and len(chars) > 1:
            # len>1 guard means chars never becomes empty mid-loop.
            chars.pop(pos)
        elif op == 'insert':
            chars.insert(pos, random.choice('abcdeimnostu'))

    return "".join(chars)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
@torch.no_grad()
def run_adversarial_test(
    model,
    src_tokenizer,
    tgt_tokenizer,
    test_inputs: List[str],
    test_refs: List[str],
    corruption_rates: Optional[List[float]] = None,
    device: torch.device = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Test whether CER degrades proportionally with IAST input corruption.
    Uses the existing trained model — no retraining.

    Args:
        model           : wrapper exposing .model (and optionally
                          generate_cached) plus .parameters()
        src_tokenizer   : encodes (possibly corrupted) IAST strings
        tgt_tokenizer   : decodes predicted Devanagari ids
        test_inputs     : clean IAST input strings
        test_refs       : reference Devanagari strings, aligned with inputs
        corruption_rates: rates to sweep; None → [0.0, 0.05, 0.10, 0.15, 0.20]
        device          : defaults to the model's device
        output_dir      : where the plot and JSON results are written

    Returns:
        dict mapping corruption rate → mean CER
    """
    # Fix: avoid a mutable default argument; None selects the standard sweep.
    if corruption_rates is None:
        corruption_rates = [0.0, 0.05, 0.10, 0.15, 0.20]

    device = device or next(model.parameters()).device
    results = {}

    print("\nAdversarial robustness test...")
    for rate in corruption_rates:
        cer_list = []
        for text, ref in zip(test_inputs, test_refs):
            corrupted = corrupt_iast(text, rate)
            ids = src_tokenizer.encode(corrupted)
            src = torch.tensor([ids], dtype=torch.long, device=device)

            # Prefer the KV-cached generation path when available.
            if hasattr(model.model, 'generate_cached'):
                out = model.model.generate_cached(src)
            else:
                out = model.generate(src)

            pred_ids = [x for x in out[0].tolist() if x > 4]  # drop special ids
            pred = tgt_tokenizer.decode(pred_ids).strip()
            cer_list.append(compute_cer(pred, ref))

        mean_cer = float(np.mean(cer_list))
        results[rate] = mean_cer
        print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")

    # Save + plot (plot is best-effort: skipped when matplotlib is absent).
    os.makedirs(output_dir, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots(figsize=(8, 4))
        rates = [r * 100 for r in corruption_rates]
        cers = [results[r] for r in corruption_rates]
        ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
        ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
        ax.set_ylabel("CER", fontsize=11)
        ax.set_title("Model robustness to IAST input corruption", fontsize=11)
        ax.set_ylim(0, max(cers) * 1.2)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
                    dpi=150, bbox_inches='tight')
        plt.close()
        print(f" Saved: {output_dir}/adversarial_robustness.png")
    except ImportError:
        pass

    with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)

    return results
|
tokenizer.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tokenizer.py — Dual Tokenizer Fix
|
| 3 |
+
====================================
|
| 4 |
+
Two separate BPE tokenizers:
|
| 5 |
+
|
| 6 |
+
SanskritSourceTokenizer — trained on quote_text (Roman/IAST script)
|
| 7 |
+
SanskritTargetTokenizer — trained on quote_devanagari (Devanagari script)
|
| 8 |
+
|
| 9 |
+
WHY SEPARATE?
|
| 10 |
+
Roman Sanskrit and Devanagari are fundamentally different character sets.
|
| 11 |
+
Roman uses a-z + diacritics (~60 unique chars), Devanagari uses ā-ह + matras
|
| 12 |
+
(~100+ unique chars). A shared BPE tokenizer wastes half its vocab on
|
| 13 |
+
character combos that never cross scripts, and forces the embedding table
|
| 14 |
+
to encode both scripts in one space — confusing the model's cross-attention.
|
| 15 |
+
|
| 16 |
+
With separate tokenizers:
|
| 17 |
+
- src vocab captures Roman subwords cleanly (ā, ś, ṭ, ṃ etc.)
|
| 18 |
+
- tgt vocab captures Devanagari akshara clusters cleanly (क्ष, त्र, etc.)
|
| 19 |
+
- The model learns a true cross-script mapping in its cross-attention
|
| 20 |
+
|
| 21 |
+
SPECIAL TOKENS (same IDs in both):
|
| 22 |
+
[MASK] = 0 ← required by absorbing diffusion
|
| 23 |
+
[PAD] = 1
|
| 24 |
+
[UNK] = 2
|
| 25 |
+
[CLS] = 3
|
| 26 |
+
[SEP] = 4
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from tokenizers import Tokenizer
|
| 30 |
+
from tokenizers.models import BPE
|
| 31 |
+
from tokenizers.trainers import BpeTrainer
|
| 32 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 33 |
+
from datasets import load_dataset
|
| 34 |
+
from pathlib import Path
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
SPECIAL_TOKENS = ["[MASK]", "[PAD]", "[UNK]", "[CLS]", "[SEP]"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _build_bpe(texts, vocab_size):
    """Train a BPE tokenizer over an iterator of strings.

    Special tokens are registered first so [MASK] lands at id=0, which the
    absorbing-diffusion setup requires.
    """
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=SPECIAL_TOKENS,  # [MASK] MUST be first → id=0
        min_frequency=2,
    )
    bpe = Tokenizer(BPE(unk_token="[UNK]"))
    bpe.pre_tokenizer = Whitespace()
    bpe.train_from_iterator(texts, trainer)
    return bpe
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _validate(tok, name):
|
| 54 |
+
mask_id = tok.token_to_id("[MASK]")
|
| 55 |
+
pad_id = tok.token_to_id("[PAD]")
|
| 56 |
+
assert mask_id == 0, f"{name}: [MASK] must be id=0, got {mask_id}"
|
| 57 |
+
assert pad_id == 1, f"{name}: [PAD] must be id=1, got {pad_id}"
|
| 58 |
+
print(f"✅ {name}: [MASK]=0, [PAD]=1 confirmed. Vocab size={tok.get_vocab_size()}")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ── Source tokenizer (Roman/IAST Sanskrit) ────────────────────────────
|
| 62 |
+
|
| 63 |
+
class SanskritSourceTokenizer:
    """
    Tokenizer for quote_text — Roman transliteration of Sanskrit.
    Examples: "dharmo rakṣati rakṣitaḥ", "yatra nāryastu pūjyante"

    Loads a previously trained BPE model from MODEL_PATH if present,
    otherwise trains one from the dataset and saves it there.
    """
    # On-disk location of the serialized tokenizers.Tokenizer model.
    MODEL_PATH = "sanskrit_src_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        # vocab_size is the requested BPE vocabulary size; the trained
        # tokenizer's actual size may be smaller — TODO confirm callers
        # size embeddings from get_vocab_size() rather than __len__.
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # [MASK] is forced to id=0 at training time

        if Path(self.MODEL_PATH).exists():
            print(f"📖 Loading source tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)
        else:
            print("🎓 Training source tokenizer on quote_text …")
            self._train(vocab_size, n_train_samples)

        # Fails loudly if [MASK] != 0 or [PAD] != 1.
        _validate(self.tokenizer, "SrcTokenizer")

    def _train(self, vocab_size, n_samples):
        """Train the BPE model on up to n_samples non-empty quote_text rows."""
        # NOTE(review): dataset id "paws/sanskrit-verses-gretil" — verify it
        # exists on the Hub and has a "quote_text" column.
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        n = min(n_samples, len(dataset))
        texts = [s["quote_text"] for s in dataset.select(range(n))
                 if s["quote_text"].strip()]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Source tokenizer trained on {len(texts)} Roman texts.")

    def encode(self, text):
        """Encode text to a fixed-length id list (truncate/pad to max_len)."""
        ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad = self.tokenizer.token_to_id("[PAD]")
        # Right-pad with [PAD] up to max_len.
        ids += [pad] * max(0, self.max_len - len(ids))
        return ids[:self.max_len]

    def decode(self, ids):
        """Decode ids to text, dropping the five special tokens (ids 0-4)."""
        clean = [i for i in ids if i > 4]  # skip special tokens
        return self.tokenizer.decode(clean)

    def __len__(self):
        # Returns the *configured* vocab size, not the trained tokenizer's
        # actual size — these can differ; see note in __init__.
        return self.vocab_size
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ── Target tokenizer (Devanagari Sanskrit) ───────────────────────────
|
| 108 |
+
|
| 109 |
+
class SanskritTargetTokenizer:
    """
    BPE tokenizer over the ``quote_devanagari`` field — Devanagari script.

    Typical inputs: "धर्मो रक्षति रक्षितः", "यत्र नार्यस्तु पूज्यन्ते".
    Loads a previously trained model from MODEL_PATH when present,
    otherwise trains one from the GRETIL verse dataset.
    """
    MODEL_PATH = "sanskrit_tgt_tokenizer.json"

    def __init__(self, vocab_size=8000, max_len=80, n_train_samples=50000):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # [MASK] is always id 0 (checked by _validate)

        if not Path(self.MODEL_PATH).exists():
            print("🎓 Training target tokenizer on quote_devanagari …")
            self._train(vocab_size, n_train_samples)
        else:
            print(f"📖 Loading target tokenizer from {self.MODEL_PATH} …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)

        _validate(self.tokenizer, "TgtTokenizer")

    def _train(self, vocab_size, n_samples):
        # Train on the first n_samples non-empty Devanagari quotes.
        corpus = load_dataset("paws/sanskrit-verses-gretil", split="train")
        limit = min(n_samples, len(corpus))
        texts = []
        for row in corpus.select(range(limit)):
            if row["quote_devanagari"].strip():
                texts.append(row["quote_devanagari"])
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Target tokenizer trained on {len(texts)} Devanagari texts.")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] up to max_len.
        token_ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        while len(token_ids) < self.max_len:
            token_ids.append(pad_id)
        return token_ids[:self.max_len]

    def decode(self, ids):
        # Ids 0–4 are the special tokens; drop them before decoding.
        kept = []
        for token_id in ids:
            if token_id > 4:
                kept.append(token_id)
        return self.tokenizer.decode(kept)

    # Minimal surface required by BERTScore's tokenizer interface.
    def build_inputs_with_special_tokens(self, token_ids):
        return [tid for tid in token_ids]

    def get_vocab(self):
        return {str(idx): idx for idx in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        return list(map(str, ids))

    def __len__(self):
        return self.vocab_size
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ── Legacy shared tokenizer (kept for backward compat) ───────────────
|
| 164 |
+
|
| 165 |
+
class SanskritTokenizer:
    """
    LEGACY: one shared BPE vocabulary trained on BOTH scripts.

    Still functional, but suboptimal — prefer SanskritSourceTokenizer
    plus SanskritTargetTokenizer for quote_text → quote_devanagari work.
    """
    MODEL_PATH = "sanskrit_tokenizer_m4pro.json"

    def __init__(self, vocab_size=16000, max_len=80):
        self.vocab_size = vocab_size
        self.max_len = max_len
        self.mask_token_id = 0  # [MASK] is always id 0 (checked by _validate)

        if not Path(self.MODEL_PATH).exists():
            print("🎓 Training shared tokenizer on both scripts …")
            self._train(vocab_size)
        else:
            print("📖 Loading shared tokenizer …")
            self.tokenizer = Tokenizer.from_file(self.MODEL_PATH)

        _validate(self.tokenizer, "SharedTokenizer")

    def _train(self, vocab_size):
        dataset = load_dataset("paws/sanskrit-verses-gretil", split="train")
        cap = min(50000, len(dataset))
        # Interleave both scripts row by row: Roman first, then Devanagari.
        texts = [
            row[field]
            for row in dataset.select(range(cap))
            for field in ("quote_text", "quote_devanagari")
            if row[field].strip()
        ]
        self.tokenizer = _build_bpe(texts, vocab_size)
        self.tokenizer.save(self.MODEL_PATH)
        print(f"✅ Shared tokenizer trained ({len(texts)} texts).")

    def encode(self, text):
        # Truncate to max_len, then right-pad with [PAD] up to max_len.
        token_ids = self.tokenizer.encode(text).ids[:self.max_len]
        pad_id = self.tokenizer.token_to_id("[PAD]")
        token_ids += [pad_id] * max(0, self.max_len - len(token_ids))
        return token_ids[:self.max_len]

    def decode(self, ids):
        # Guard against a common batching mistake: passing [[...], [...]].
        if ids and isinstance(ids[0], list):
            raise TypeError("decode() got 2D list — pass a 1D list.")
        # Ids 0–4 are the special tokens; drop them before decoding.
        return self.tokenizer.decode([tid for tid in ids if tid > 4])

    # Minimal surface required by BERTScore's tokenizer interface.
    def build_inputs_with_special_tokens(self, token_ids):
        return list(token_ids)

    def get_vocab(self):
        return {str(idx): idx for idx in range(self.vocab_size)}

    def convert_ids_to_tokens(self, ids):
        return [str(tid) for tid in ids]

    def __len__(self):
        return self.vocab_size
|
train_all.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
set -euo pipefail
# Run this script to train all ablation models

MODEL_TYPE=${MODEL_TYPE:-d3pm_cross_attention}
INCLUDE_NEG=${INCLUDE_NEG:-False}
TRAIN_DEVICE=${TRAIN_DEVICE:-mps}

# One training run per diffusion-step count; inference steps mirror T.
for T in 4 8 16 32 64; do
    echo "=== Training T=${T} ==="
    mkdir -p "ablation_results/T${T}"
    MODEL_TYPE="$MODEL_TYPE" INCLUDE_NEG="$INCLUDE_NEG" TRAIN_DEVICE="$TRAIN_DEVICE" \
        DIFFUSION_STEPS="$T" INFERENCE_NUM_STEPS="$T" \
        TRAIN_OUTPUT_DIR="ablation_results/T${T}" python train.py
done
|
| 28 |
+
|