Spaces:
Running
Running
| """ | |
| eval.py | |
| ======= | |
| Unified Evaluator β CIDEr across all four VLM architectures. | |
| This module: | |
| 1. Evaluates each model's baseline CIDEr on the COCO validation set | |
| 2. Delegates ablation studies to experiments/ablation_study.py | |
| 3. Provides a unified cross-model comparison table | |
| Weight Selection (--weights flag): | |
| base β Use pretrained HuggingFace weights (no fine-tuning) | |
| finetuned β Load from outputs/{model}/latest/ | |
| best β Load from outputs/{model}/best/ | |
| Usage: | |
| python eval.py # BLIP base weights | |
| python eval.py --model blip --weights best # BLIP best fine-tuned | |
| python eval.py --model all # All 4 models | |
| python eval.py --model all --weights best # All 4 models, best weights | |
| python eval.py --ablation # BLIP 4-mode ablation | |
| python eval.py --sweep # Decoding parameter sweep | |
| """ | |
| import os | |
| import argparse | |
| import torch | |
| from typing import Optional | |
| from tqdm.auto import tqdm | |
| from pycocoevalcap.cider.cider import Cider | |
| from config import CFG | |
| from data_prep import get_dataloaders, get_dataloaders_for_model | |
| from models.blip_tuner import get_blip_model, load_ckpt, generate_with_mask | |
| from experiments.ablation_study import run_ablation_study | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Device Helper | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_device(): | |
| if torch.backends.mps.is_available(): | |
| return torch.device("mps") | |
| elif torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| return torch.device("cpu") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Weight Loading Helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_weights_dir(cfg, model_name: str, weights: str) -> Optional[str]: | |
| """ | |
| Return the checkpoint directory for the given model and weight selection. | |
| Args: | |
| cfg : CFG instance | |
| model_name : 'blip', 'vit_gpt2', 'git', 'custom' | |
| weights : 'base', 'finetuned', 'best' | |
| Returns: | |
| Absolute path to checkpoint dir, or None for base weights. | |
| """ | |
| if weights == "base": | |
| return None | |
| subdir = "latest" if weights == "finetuned" else "best" | |
| path = os.path.join(cfg.output_root, model_name, subdir) | |
| if os.path.isdir(path) and os.listdir(path): | |
| return path | |
| print(f"β οΈ No {subdir} checkpoint found at {path}. Falling back to base weights.") | |
| return None | |
| def print_weights_banner(model_name: str, weights: str, ckpt_dir: Optional[str]): | |
| """Print a clear banner showing which weights are being used.""" | |
| print("=" * 60) | |
| print(f" Model: {model_name}") | |
| if ckpt_dir: | |
| print(f" Weights: {weights} β {ckpt_dir}") | |
| else: | |
| print(f" Weights: base (pretrained, no fine-tuning)") | |
| print("=" * 60) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # BLIP Baseline CIDEr Evaluation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate_blip(model, processor, dataloader, device, | |
| num_beams=4, max_new_tokens=32, length_penalty=1.0, | |
| eval_batches=25): | |
| """Evaluate BLIP CIDEr score (full attention β no ablation masking).""" | |
| model.eval() | |
| gts, res = {}, {} | |
| with torch.no_grad(): | |
| for i, batch in enumerate(tqdm(dataloader, desc="Eval [BLIP]")): | |
| if i >= eval_batches: | |
| break | |
| pixel_values = batch["pixel_values"].to(device) | |
| B = pixel_values.shape[0] | |
| mask = torch.ones(B, 197, dtype=torch.long, device=device) | |
| decoded = generate_with_mask( | |
| model, processor, device=device, | |
| pixel_values=pixel_values, | |
| encoder_attention_mask=mask, | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| ) | |
| preds = decoded # generate_with_mask already returns decoded strings | |
| labels = batch["labels"].clone() | |
| gt_texts = processor.batch_decode(labels, skip_special_tokens=True) | |
| for j, (p, g) in enumerate(zip(preds, gt_texts)): | |
| k = str(i * len(preds) + j) | |
| res[k] = [p] | |
| gts[k] = [g] | |
| if not gts: | |
| return 0.0 | |
| scorer = Cider() | |
| score, _ = scorer.compute_score(gts, res) | |
| print(f" β CIDEr [BLIP]: {score:.4f}") | |
| return score | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ViT-GPT2 CIDEr Evaluation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate_vit_gpt2(model, tokenizer, dataloader, device, | |
| num_beams=4, max_new_tokens=32, length_penalty=1.0, | |
| eval_batches=25): | |
| """Evaluate ViT-GPT2 CIDEr score.""" | |
| model.eval() | |
| gts, res = {}, {} | |
| with torch.no_grad(): | |
| for i, batch in enumerate(tqdm(dataloader, desc="Eval [ViT-GPT2]")): | |
| if i >= eval_batches: | |
| break | |
| pixel_values = batch["pixel_values"].to(device) | |
| out = model.generate( | |
| pixel_values=pixel_values, | |
| num_beams=num_beams, | |
| max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, | |
| ) | |
| preds = [tokenizer.decode(ids, skip_special_tokens=True) for ids in out] | |
| labels = batch["labels"].clone() | |
| labels[labels == -100] = tokenizer.pad_token_id | |
| gt_texts = tokenizer.batch_decode(labels, skip_special_tokens=True) | |
| for j, (p, g) in enumerate(zip(preds, gt_texts)): | |
| k = str(i * len(preds) + j) | |
| res[k] = [p] | |
| gts[k] = [g] | |
| if not gts: | |
| return 0.0 | |
| scorer = Cider() | |
| score, _ = scorer.compute_score(gts, res) | |
| print(f" β CIDEr [ViT-GPT2]: {score:.4f}") | |
| return score | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GIT CIDEr Evaluation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate_git(model, processor, dataloader, device, | |
| num_beams=4, max_new_tokens=32, length_penalty=1.0, | |
| eval_batches=25): | |
| """Evaluate GIT CIDEr score.""" | |
| model.eval() | |
| gts, res = {}, {} | |
| with torch.no_grad(): | |
| for i, batch in enumerate(tqdm(dataloader, desc="Eval [GIT]")): | |
| if i >= eval_batches: | |
| break | |
| inputs = {k: v.to(device) for k, v in batch.items() | |
| if k in ("pixel_values", "input_ids", "attention_mask")} | |
| out = model.generate( | |
| **inputs, | |
| num_beams=num_beams, | |
| max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, | |
| ) | |
| preds = processor.batch_decode(out, skip_special_tokens=True) | |
| labels = batch["labels"].clone() | |
| labels[labels == -100] = processor.tokenizer.pad_token_id | |
| gt_texts = processor.batch_decode(labels, skip_special_tokens=True) | |
| for j, (p, g) in enumerate(zip(preds, gt_texts)): | |
| k = str(i * len(preds) + j) | |
| res[k] = [p] | |
| gts[k] = [g] | |
| if not gts: | |
| return 0.0 | |
| scorer = Cider() | |
| score, _ = scorer.compute_score(gts, res) | |
| print(f" β CIDEr [GIT]: {score:.4f}") | |
| return score | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Custom VLM CIDEr Evaluation | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate_custom_vlm_cider(model, val_loader, device, | |
| char_to_idx, idx_to_char, | |
| max_new_tokens=80, num_beams=1, | |
| length_penalty=1.0, | |
| eval_batches=20): | |
| """Evaluate CIDEr score for the CustomVLM using autoregressive generation.""" | |
| model.eval() | |
| gts, res = {}, {} | |
| print("\nEvaluating Custom VLM (Visual Prefix-Tuning)...") | |
| with torch.no_grad(): | |
| for i, batch in enumerate(tqdm(val_loader, desc="Eval [CustomVLM]")): | |
| if i >= eval_batches: | |
| break | |
| pixel_values = batch["pixel_values"].to(device) | |
| B = pixel_values.shape[0] | |
| for b in range(B): | |
| pv_single = pixel_values[b:b+1] | |
| if num_beams > 1: | |
| pred = model.generate_beam( | |
| pv_single, char_to_idx, idx_to_char, | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| length_penalty=length_penalty, | |
| ) | |
| else: | |
| pred = model.generate( | |
| pv_single, char_to_idx, idx_to_char, | |
| max_new_tokens=max_new_tokens, | |
| ) | |
| tgt_ids = batch["text_targets"][b].tolist() | |
| gt_text = "".join(idx_to_char.get(idx, "") for idx in tgt_ids if idx != 0) | |
| idx_key = str(i * B + b) | |
| res[idx_key] = [pred.strip()] | |
| gts[idx_key] = [gt_text.strip()] | |
| if not gts: | |
| return 0.0 | |
| scorer = Cider() | |
| score, _ = scorer.compute_score(gts, res) | |
| print(f" β CIDEr [CustomVLM]: {score:.4f}") | |
| return score | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Custom VLM Loader (with weight selection) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_custom_vlm_for_eval(cfg, device, weights="base"): | |
| """ | |
| Load CustomVLM with the specified weight selection. | |
| Args: | |
| weights: 'base' (Shakespeare only), 'finetuned' (latest ckpt), 'best' (best ckpt) | |
| """ | |
| from models.custom_vlm import CustomVLM, build_char_vocab | |
| from data_prep import get_custom_vlm_dataloader | |
| with open(cfg.shakespeare_file, "r") as f: | |
| text = f.read() | |
| _, c2i, i2c, vs = build_char_vocab(text) | |
| model = CustomVLM( | |
| vocab_size=vs, | |
| text_embed_dim=cfg.text_embed_dim, | |
| n_heads=cfg.n_heads, | |
| n_layers=cfg.n_layers, | |
| block_size=cfg.block_size, | |
| dropout=cfg.dropout, | |
| ) | |
| # Always load Shakespeare weights first | |
| if os.path.exists(cfg.shakespeare_weights_path): | |
| model.load_shakespeare_weights(cfg.shakespeare_weights_path) | |
| # Then optionally load fine-tuned weights on top | |
| ckpt_dir = get_weights_dir(cfg, "custom_vlm", weights) | |
| if ckpt_dir: | |
| ckpt_path = os.path.join(ckpt_dir, "custom_vlm.pt") | |
| if os.path.exists(ckpt_path): | |
| state = torch.load(ckpt_path, map_location="cpu") | |
| # Filter shape mismatches gracefully | |
| own_state = model.state_dict() | |
| filtered = {k: v for k, v in state["model_state"].items() | |
| if k in own_state and own_state[k].shape == v.shape} | |
| model.load_state_dict(filtered, strict=False) | |
| print(f" β Loaded fine-tuned weights from {ckpt_path}") | |
| print_weights_banner("Custom VLM", weights, ckpt_dir) | |
| model.to(device).eval() | |
| _, val_loader = get_custom_vlm_dataloader(cfg, c2i) | |
| return model, c2i, i2c, val_loader | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # All-Model Comparison Table | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def evaluate_all_models(cfg, device, weights="base", | |
| num_beams=4, max_new_tokens=32, | |
| length_penalty=1.0, eval_batches=25): | |
| """Run CIDEr evaluation for all four models and print a comparison table.""" | |
| results = {} | |
| # ββ BLIP ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n[1/4] Evaluating BLIP...") | |
| blip_cfg = CFG.load_for_model("blip") | |
| model_b, proc_b = get_blip_model(blip_cfg, device) | |
| ckpt = get_weights_dir(blip_cfg, "blip", weights) | |
| if ckpt: | |
| load_ckpt(model_b, None, None, ckpt) | |
| print_weights_banner("BLIP", weights, ckpt) | |
| _, val_b = get_dataloaders(blip_cfg, proc_b) | |
| results["BLIP"] = evaluate_blip( | |
| model_b, proc_b, val_b, device, | |
| num_beams=num_beams, max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, eval_batches=eval_batches, | |
| ) | |
| del model_b, proc_b | |
| # ββ ViT-GPT2 ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n[2/4] Evaluating ViT-GPT2...") | |
| from models.vit_gpt2_tuner import get_vit_gpt2_model | |
| vg2_cfg = CFG.load_for_model("vit_gpt2") | |
| model_v, proc_v, tok_v = get_vit_gpt2_model(vg2_cfg, device) | |
| ckpt = get_weights_dir(vg2_cfg, "vit_gpt2", weights) | |
| if ckpt: | |
| from transformers import VisionEncoderDecoderModel | |
| finetuned = VisionEncoderDecoderModel.from_pretrained(ckpt) | |
| model_v.load_state_dict(finetuned.state_dict()) | |
| model_v.to(device) | |
| print_weights_banner("ViT-GPT2", weights, ckpt) | |
| _, val_v = get_dataloaders_for_model(vg2_cfg, "vit_gpt2", proc_v, tok_v) | |
| results["ViT-GPT2"] = evaluate_vit_gpt2( | |
| model_v, tok_v, val_v, device, | |
| num_beams=num_beams, max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, eval_batches=eval_batches, | |
| ) | |
| del model_v, proc_v, tok_v | |
| # ββ GIT βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n[3/4] Evaluating GIT...") | |
| from models.git_tuner import get_git_model | |
| git_cfg = CFG.load_for_model("git") | |
| model_g, proc_g = get_git_model(git_cfg, device) | |
| ckpt = get_weights_dir(git_cfg, "git", weights) | |
| if ckpt: | |
| from transformers import AutoModelForCausalLM | |
| finetuned = AutoModelForCausalLM.from_pretrained(ckpt) | |
| model_g.load_state_dict(finetuned.state_dict()) | |
| model_g.to(device) | |
| print_weights_banner("GIT", weights, ckpt) | |
| _, val_g = get_dataloaders_for_model(git_cfg, "git", proc_g) | |
| results["GIT"] = evaluate_git( | |
| model_g, proc_g, val_g, device, | |
| num_beams=num_beams, max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, eval_batches=eval_batches, | |
| ) | |
| del model_g, proc_g | |
| # ββ Custom VLM ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n[4/4] Evaluating Custom VLM...") | |
| vlm_cfg = CFG.load_for_model("custom") | |
| model_c, c2i, i2c, val_c = load_custom_vlm_for_eval(vlm_cfg, device, weights) | |
| results["CustomVLM"] = evaluate_custom_vlm_cider( | |
| model_c, val_c, device, c2i, i2c, | |
| max_new_tokens=80, eval_batches=15, | |
| ) | |
| del model_c | |
| # ββ Summary Table ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| print("\n") | |
| print("=" * 65) | |
| print(f" All-Model CIDEr Comparison | Weights: {weights}") | |
| print(f" Beams={num_beams} MaxTok={max_new_tokens} LenPen={length_penalty}") | |
| print("=" * 65) | |
| print(f" {'Architecture':<22} {'CIDEr':>8} {'CA Type'}") | |
| print(" " + "-" * 61) | |
| ca_types = { | |
| "BLIP": "Gated MED cross-attention", | |
| "ViT-GPT2": "Standard full cross-attention", | |
| "GIT": "Self-attention prefix (no CA)", | |
| "CustomVLM": "Linear bridge prefix (no CA)", | |
| } | |
| for name, score in sorted(results.items(), key=lambda x: -x[1]): | |
| print(f" {name:<22} {score:>8.4f} {ca_types.get(name, '')}") | |
| print("=" * 65) | |
| return results | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Main | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Unified VLM Evaluator") | |
| parser.add_argument( | |
| "--model", type=str, default="blip", | |
| choices=["blip", "vit_gpt2", "git", "custom", "all"], | |
| help="Which model(s) to evaluate", | |
| ) | |
| parser.add_argument( | |
| "--weights", type=str, default="base", | |
| choices=["base", "finetuned", "best"], | |
| help="Which weights to use: base (pretrained), finetuned (latest/), best (best/)", | |
| ) | |
| parser.add_argument("--ablation", action="store_true", | |
| help="Run BLIP 4-mode cross-attention ablation study") | |
| parser.add_argument("--sweep", action="store_true", | |
| help="Run decoding parameter sweep") | |
| parser.add_argument("--num_beams", type=int, default=10) | |
| parser.add_argument("--max_new_tokens", type=int, default=50) | |
| parser.add_argument("--length_penalty", type=float, default=1.2) | |
| parser.add_argument("--eval_batches", type=int, default=25) | |
| args = parser.parse_args() | |
| device = get_device() | |
| print(f"β Device: {device}") | |
| if args.model == "all": | |
| cfg = CFG.load_for_model("blip") | |
| evaluate_all_models( | |
| cfg, device, | |
| weights=args.weights, | |
| num_beams=args.num_beams, | |
| max_new_tokens=args.max_new_tokens, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| return | |
| cfg = CFG.load_for_model(args.model) | |
| if args.model == "blip" or args.ablation: | |
| model, processor = get_blip_model(cfg, device) | |
| ckpt_dir = get_weights_dir(cfg, "blip", args.weights) | |
| if ckpt_dir: | |
| load_ckpt(model, None, None, ckpt_dir) | |
| print_weights_banner("BLIP", args.weights, ckpt_dir) | |
| _, val_loader = get_dataloaders(cfg, processor) | |
| if args.ablation: | |
| run_ablation_study( | |
| model, processor, val_loader, device, cfg, | |
| num_beams=args.num_beams, | |
| max_new_tokens=args.max_new_tokens, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| elif args.sweep: | |
| from experiments.parameter_sweep import run_parameter_sweep | |
| run_parameter_sweep( | |
| "blip", | |
| {"model": model, "processor": processor}, | |
| val_loader, device, | |
| eval_batches=args.eval_batches, | |
| ) | |
| else: | |
| evaluate_blip( | |
| model, processor, val_loader, device, | |
| num_beams=args.num_beams, | |
| max_new_tokens=args.max_new_tokens, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| elif args.model == "vit_gpt2": | |
| from models.vit_gpt2_tuner import get_vit_gpt2_model | |
| model, processor, tokenizer = get_vit_gpt2_model(cfg, device) | |
| ckpt_dir = get_weights_dir(cfg, "vit_gpt2", args.weights) | |
| if ckpt_dir: | |
| from transformers import VisionEncoderDecoderModel | |
| finetuned = VisionEncoderDecoderModel.from_pretrained(ckpt_dir) | |
| model.load_state_dict(finetuned.state_dict()) | |
| model.to(device) | |
| print_weights_banner("ViT-GPT2", args.weights, ckpt_dir) | |
| _, val_loader = get_dataloaders_for_model(cfg, "vit_gpt2", processor, tokenizer) | |
| evaluate_vit_gpt2( | |
| model, tokenizer, val_loader, device, | |
| num_beams=args.num_beams, | |
| max_new_tokens=args.max_new_tokens, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| elif args.model == "git": | |
| from models.git_tuner import get_git_model | |
| model, processor = get_git_model(cfg, device) | |
| ckpt_dir = get_weights_dir(cfg, "git", args.weights) | |
| if ckpt_dir: | |
| from transformers import AutoModelForCausalLM | |
| finetuned = AutoModelForCausalLM.from_pretrained(ckpt_dir) | |
| model.load_state_dict(finetuned.state_dict()) | |
| model.to(device) | |
| print_weights_banner("GIT", args.weights, ckpt_dir) | |
| _, val_loader = get_dataloaders_for_model(cfg, "git", processor) | |
| evaluate_git( | |
| model, processor, val_loader, device, | |
| num_beams=args.num_beams, | |
| max_new_tokens=args.max_new_tokens, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| elif args.model == "custom": | |
| vlm_cfg = CFG.load_for_model("custom") | |
| model, c2i, i2c, val_loader = load_custom_vlm_for_eval( | |
| vlm_cfg, device, args.weights) | |
| evaluate_custom_vlm_cider( | |
| model, val_loader, device, c2i, i2c, | |
| max_new_tokens=80, | |
| num_beams=args.num_beams, | |
| length_penalty=args.length_penalty, | |
| eval_batches=args.eval_batches, | |
| ) | |
| if __name__ == "__main__": | |
| main() | |