# devflow / inference.py
# bhsinghgrid's picture
# Upload 27 files
# f8437ec verified
"""
inference.py
============
Correct D3PM inference for Sanskrit paraphrase generation.
The model's forward() takes CLEAN tgt and noises it internally.
So inference passes x0_estimate (starting all-[MASK]) as tgt each step,
letting the model noise it and then predict a cleaner version.
Also includes: robust checkpoint loading (auto-detects architecture
from saved weights — no CONFIG mismatch crashes).
"""
import torch
import torch.nn.functional as F
import os, sys
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config import CONFIG
# ── Checkpoint loader ─────────────────────────────────────────────────
def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """
    Build a SanskritModel whose architecture matches a checkpoint and load it.

    Architecture settings are inferred from the key names and tensor shapes
    found in the checkpoint and written into a deep copy of ``base_cfg``
    (the caller's dict is left untouched). Weights are then loaded
    non-strictly so that missing / unexpected keys are reported, not fatal.

    Args:
        ckpt_path: Path of a checkpoint readable by ``torch.load``. It is
            treated as a flat mapping of parameter name -> tensor.
        base_cfg: Baseline configuration; its ``'model'`` section supplies
            any value that cannot be inferred from the checkpoint.
        device: Device the constructed model is moved to.

    Returns:
        A ``(model, cfg)`` tuple: the loaded model and the adjusted config.
    """
    import copy
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    model_cfg = cfg['model']

    # NOTE(security): torch.load unpickles the file; only open trusted checkpoints.
    # The file is read once, on CPU, and the same dict is reused for loading below.
    state = torch.load(ckpt_path, map_location='cpu')

    # vocab_size and d_model come from the source embedding table's shape.
    embed_key = 'model.src_embed.token_emb.weight'
    if embed_key in state:
        vocab_size, d_model = state[embed_key].shape
        model_cfg['vocab_size'] = vocab_size
        model_cfg['d_model'] = d_model
        # d_ff is not read from the checkpoint; it is assumed to be 4 * d_model.
        model_cfg['d_ff'] = d_model * 4

    # n_layers: highest encoder block index present in the key names, plus one.
    block_ids = {
        int(key.split('.')[2])
        for key in state
        if key.startswith('model.encoder_blocks.')
    }
    if block_ids:
        model_cfg['n_layers'] = max(block_ids) + 1

    # max_seq_len: second dimension of the positional-encoding buffer.
    pos_key = 'model.src_embed.pos_enc.pe'
    if pos_key in state:
        model_cfg['max_seq_len'] = state[pos_key].shape[1]

    # n_heads cannot be read from weight shapes: keep the configured value
    # (default 6) if it divides d_model, else take the first candidate that does.
    d_model = model_cfg['d_model']
    n_heads = model_cfg.get('n_heads', 6)
    if d_model % n_heads != 0:
        n_heads = next(x for x in [8, 6, 4, 2, 1] if d_model % x == 0)
    model_cfg['n_heads'] = n_heads

    print(f"🔍 Detected: d_model={model_cfg['d_model']}, "
          f"n_layers={model_cfg['n_layers']}, "
          f"max_seq_len={model_cfg['max_seq_len']}, "
          f"n_heads={model_cfg['n_heads']}")

    model = SanskritModel(cfg).to(device)
    missing, unexpected = model.load_state_dict(state, strict=False)

    # hint_gate may be absent in older checkpoints; that is not worth a warning.
    allowed = {'model.hint_gate.0.weight', 'model.hint_gate.0.bias'}
    real_missing = [k for k in missing if k not in allowed]
    if real_missing:
        print(f"⚠️ Missing keys: {real_missing[:3]} …")
    if unexpected:
        print(f"⚠️ Unexpected keys: {unexpected[:3]} …")

    # Give a hint_gate that was not in the checkpoint a deterministic start.
    if hasattr(model.model, 'hint_gate') and 'model.hint_gate.0.weight' in missing:
        gate = model.model.hint_gate[0]
        with torch.no_grad():
            torch.nn.init.zeros_(gate.bias)
            if gate.weight.shape[0] == gate.weight.shape[1]:
                torch.nn.init.eye_(gate.weight)
                print("ℹ️ hint_gate initialised to identity (not in checkpoint).")
            else:
                # A non-square weight has no identity; use Xavier and say so.
                torch.nn.init.xavier_uniform_(gate.weight)
                print("ℹ️ hint_gate initialised with Xavier-uniform (not in checkpoint).")

    print("✅ Model loaded.")
    return model, cfg
# ── Core inference function ───────────────────────────────────────────
def run_inference(model, input_ids, cfg):
    """
    Iteratively refine an all-[MASK] target estimate into output token ids.

    The estimate ``x0_est`` starts as a tensor filled with the model's mask
    token id. At each timestep it is passed to the model together with the
    source ids and the previous estimate as ``x0_hint``; the returned logits
    are post-processed and a new estimate is drawn. Every step except the
    last samples; the last step takes the argmax.

    Args:
        model: Callable as ``model(src, tgt, t, x0_hint=...)`` returning
            ``(logits, extra)``. Must expose ``model.model.scheduler
            .num_timesteps`` and ``model.model.mask_token_id``. It is put
            into eval mode and left there.
        input_ids: 2-D tensor of source token ids (batch, length).
        cfg: Config whose ``'inference'`` section holds ``num_steps``,
            ``repetition_penalty``, ``diversity_penalty``, ``temperature``
            and ``top_k``.

    Returns:
        Tensor of predicted token ids from the final step
        (presumably (batch, length) -- depends on the model's logits shape).
    """
    inf = cfg['inference']
    device = input_ids.device
    B, L = input_ids.shape
    inner = model.model
    T = inner.scheduler.num_timesteps

    # Stride through the T timesteps in roughly `num_steps` jumps. A
    # non-positive num_steps falls back to visiting every timestep: zero
    # used to raise ZeroDivisionError, negatives already gave stride 1.
    steps = inf['num_steps']
    step_size = max(1, T // steps) if steps > 0 else 1
    timesteps = list(range(T - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)  # always finish on t = 0
    last_idx = len(timesteps) - 1

    # Loop-invariant settings, read once.
    rep_penalty = inf['repetition_penalty']
    div_penalty = inf['diversity_penalty']
    temperature = max(inf['temperature'], 1e-5)  # guard against division by ~0
    top_k = inf['top_k']

    # Helpers are imported once, and only when their feature is enabled,
    # so a disabled feature never requires its helper to exist.
    if rep_penalty != 1.0:
        from model.d3pm_model_cross_attention import _apply_repetition_penalty
    if div_penalty > 0.0:
        from model.d3pm_model_cross_attention import _apply_diversity_penalty
    if top_k > 0:
        from model.d3pm_model_cross_attention import _top_k_filter
    if last_idx > 0:
        # Sampling only happens on non-final steps.
        from model.d3pm_model_cross_attention import _batch_multinomial

    mask_id = inner.mask_token_id
    x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
    hint = None  # no previous estimate on the first step

    model.eval()
    with torch.no_grad():
        for step_idx, t_val in enumerate(timesteps):
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            logits, _ = model(input_ids, x0_est, t, x0_hint=hint)

            if rep_penalty != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, rep_penalty)
            if div_penalty > 0.0:
                logits = _apply_diversity_penalty(logits, div_penalty)

            logits = logits / temperature
            if top_k > 0:
                logits = _top_k_filter(logits, top_k)

            probs = F.softmax(logits, dim=-1)
            if step_idx == last_idx:
                x0_est = torch.argmax(probs, dim=-1)  # deterministic final answer
            else:
                x0_est = _batch_multinomial(probs)
            hint = x0_est
    return x0_est
# ── Interactive demo ──────────────────────────────────────────────────
def interactive_demo():
    """
    Run a console loop: read a verse, print the generated paraphrase.

    The checkpoint path is derived from the module-level CONFIG
    (``results/<model_type>_neg_<include_negative_examples>/best_model.pt``).
    The loop ends on EOF, Ctrl-C, an empty line, or ``quit`` / ``exit`` / ``q``.

    Raises:
        FileNotFoundError: If the expected checkpoint does not exist.
    """
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = f"results/{model_name}_neg_{has_neg}/best_model.pt"
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt} — train first.")

    # load_model returns a config adjusted to the checkpoint's architecture.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    # Fall back to 1 only when the lookup yields nothing; `or 1` would also
    # discard a legitimate pad id of 0.
    pad_id = tokenizer.tokenizer.token_to_id('[PAD]')
    if pad_id is None:
        pad_id = 1
    # Ids stripped from the output before decoding.
    special_ids = {cfg['diffusion']['mask_token_id'], pad_id}
    max_len = cfg['model']['max_seq_len']

    print("\n" + "="*55)
    print("Sanskrit D3PM Paraphrase — type verse, get paraphrase")
    print("="*55 + "\n")

    while True:
        try:
            text = input("INPUT > ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not text or text.lower() in ('quit', 'exit', 'q'):
            break

        # Batch of one, truncated to the model's maximum sequence length.
        ids = torch.tensor(
            [tokenizer.encode(text)[:max_len]],
            dtype=torch.long, device=device
        )
        out = run_inference(model, ids, cfg)
        clean = [tok for tok in out[0].tolist() if tok not in special_ids]
        print(f"PARAPHRASE → {tokenizer.decode(clean).strip()}\n")
# ── Batch evaluation ──────────────────────────────────────────────────
def batch_evaluate(sample_size=500):
    """
    Generate paraphrases for the first test examples and score them.

    Loads the checkpoint selected by the module-level CONFIG, runs inference
    over at most ``sample_size`` items of the ``'test'`` split, computes
    corpus BLEU and mean BERTScore F1 (each best-effort), and writes a
    report with up to 20 samples to
    ``results/<model_type>_neg_<flag>/evaluation_results.txt``.

    Args:
        sample_size: Maximum number of test examples to evaluate.

    Returns:
        ``(all_preds, all_refs)``: decoded predictions and reference texts.

    Raises:
        FileNotFoundError: If the expected checkpoint does not exist.
    """
    from data.dataset import OptimizedSanskritDataset
    from model.tokenizer import SanskritTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    exp_dir = f"results/{model_name}_neg_{has_neg}"
    ckpt = f"{exp_dir}/best_model.pt"
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint at {ckpt}")

    # load_model returns a config adjusted to the checkpoint's architecture.
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()

    tokenizer = SanskritTokenizer(cfg['model']['vocab_size'])
    # Fall back to 1 only when the lookup yields nothing; `or 1` would also
    # discard a legitimate pad id of 0.
    pad_id = tokenizer.tokenizer.token_to_id('[PAD]')
    if pad_id is None:
        pad_id = 1
    # Ids stripped from the output before decoding.
    special_ids = {cfg['diffusion']['mask_token_id'], pad_id}

    def collate(batch):
        # Stack the id tensors; keep the raw texts as plain lists.
        return {
            'input_ids': torch.stack([b['input_ids'].long() for b in batch]),
            'target_text': [b['target_text'] for b in batch],
            'input_text': [b['input_text'] for b in batch],
        }

    dataset = OptimizedSanskritDataset('test', tokenizer, cfg['model']['max_seq_len'], cfg)
    indices = list(range(min(sample_size, len(dataset))))
    loader = DataLoader(
        Subset(dataset, indices),
        batch_size=cfg['training']['batch_size'],
        shuffle=False, collate_fn=collate
    )

    all_preds, all_refs, all_inputs = [], [], []
    print(f"⏳ Generating {len(indices)} paraphrases …")
    for batch in tqdm(loader):
        ids = batch['input_ids'].to(device)
        out = run_inference(model, ids, cfg)
        for row, ref, src in zip(out.tolist(), batch['target_text'], batch['input_text']):
            clean = [tok for tok in row if tok not in special_ids]
            all_preds.append(tokenizer.decode(clean).strip())
            all_refs.append(ref.strip())
            all_inputs.append(src.strip())

    # Metrics are best-effort: a missing package or a scoring failure leaves
    # the value at 0.0, but the reason is printed so it is not mistaken for
    # a genuine zero score.
    bleu_score, bert_f1 = 0.0, 0.0
    try:
        from nltk.translate.bleu_score import corpus_bleu
        bleu_score = corpus_bleu(
            [[r.split()] for r in all_refs],
            [p.split() for p in all_preds]
        )
    except Exception as exc:
        print(f"⚠️ BLEU not computed ({type(exc).__name__}: {exc}); reporting 0.0")
    try:
        import evaluate as hf_eval
        res = hf_eval.load('bertscore').compute(
            predictions=all_preds, references=all_refs, lang='hi'
        )
        bert_f1 = sum(res['f1']) / len(res['f1'])
    except Exception as exc:
        print(f"⚠️ BERTScore not computed ({type(exc).__name__}: {exc}); reporting 0.0")

    # Save the report next to the checkpoint.
    out_path = f"{exp_dir}/evaluation_results.txt"
    inf = cfg['inference']
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(f"Model : {model_name}\n")
        f.write(f"Negatives: {has_neg}\n")
        f.write(f"Steps : {inf['num_steps']}\n")
        f.write(f"Temp : {inf['temperature']}\n")
        f.write(f"RepPen : {inf['repetition_penalty']}\n")
        f.write(f"DivPen : {inf['diversity_penalty']}\n")
        f.write(f"BLEU : {bleu_score:.4f}\n")
        f.write(f"BERTScore: {bert_f1:.4f}\n\n")
        f.write("=== SAMPLES ===\n")
        # At most the first 20 examples are written out.
        for src, ref, pred in zip(all_inputs[:20], all_refs[:20], all_preds[:20]):
            f.write(f"IN : {src}\n")
            f.write(f"REF : {ref}\n")
            f.write(f"PRED: {pred}\n")
            f.write("-" * 60 + "\n")

    print(f"\n✅ Results → {out_path}")
    print(f"📊 BLEU: {bleu_score:.4f} | BERTScore: {bert_f1:.4f}")
    return all_preds, all_refs
if __name__ == '__main__':
    # Command-line entry point: `--mode demo` (default) starts the
    # interactive loop, `--mode eval` runs batch evaluation on
    # `--samples` test examples.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['demo', 'eval'], default='demo')
    parser.add_argument('--samples', type=int, default=500)
    cli = parser.parse_args()

    if cli.mode != 'demo':
        batch_evaluate(cli.samples)
    else:
        interactive_demo()