""" app.py ====== VLM Caption Lab — Premium Streamlit Demo Features: • Sidebar — Weight Source: Base / Fine-tuned (Best) / Fine-tuned (Latest) • Sidebar — Architecture selector, Generation Mode, Advanced Controls • Tab 1 — Caption: Single model captioning with weight selection • Tab 2 — Compare: Side-by-side 4-model comparison (same image, same config) • Tab 3 — Results: Pre-computed benchmark comparison tables """ import os import warnings import torch import streamlit as st from PIL import Image from models.blip_tuner import generate_with_mask warnings.filterwarnings("ignore", message="urllib3 v2 only supports OpenSSL") warnings.filterwarnings("ignore", category=UserWarning, message=".*use_fast.*") # ───────────────────────────────────────────────────────────────────────────── # Page Config & CSS # ───────────────────────────────────────────────────────────────────────────── st.set_page_config( page_title="VLM Caption Lab", page_icon="🔬", layout="wide", initial_sidebar_state="expanded", ) st.markdown(""" """, unsafe_allow_html=True) # ───────────────────────────────────────────────────────────────────────────── # Architecture Info & Constants # ───────────────────────────────────────────────────────────────────────────── ARCH_INFO = { "BLIP (Multimodal Mixture Attention)": ( "🔵 BLIP uses a Mixture-of-Encoder-Decoder (MED) architecture. " "Gated cross-attention is injected between self-attention and FFN layers." ), "ViT-GPT2 (Standard Cross-Attention)": ( "🟣 ViT-GPT2: every GPT-2 text token attends to all " "197 ViT patch embeddings via full cross-attention at every decoder layer." ), "GIT (Zero Cross-Attention)": ( "🟠 GIT abandons cross-attention entirely. Image patches are " "concatenated to the front of the token sequence; no cross-attention block." ), "Custom VLM (Shakespeare Prefix)": ( "🟢 Custom VLM fuses a frozen ViT with a Shakespeare char-level " "decoder via a single trainable Linear(768→384) projection." 
), } MODEL_KEYS = [ "BLIP (Multimodal Mixture Attention)", "ViT-GPT2 (Standard Cross-Attention)", "GIT (Zero Cross-Attention)", "Custom VLM (Shakespeare Prefix)", ] MODEL_SHORT = { "BLIP (Multimodal Mixture Attention)": "BLIP", "ViT-GPT2 (Standard Cross-Attention)": "ViT-GPT2", "GIT (Zero Cross-Attention)": "GIT", "Custom VLM (Shakespeare Prefix)": "Custom VLM", } MODEL_BADGE = { "BLIP (Multimodal Mixture Attention)": "badge-blue", "ViT-GPT2 (Standard Cross-Attention)": "badge-purple", "GIT (Zero Cross-Attention)": "badge-orange", "Custom VLM (Shakespeare Prefix)": "badge-green", } MODEL_CA_TYPE = { "BLIP (Multimodal Mixture Attention)": "Gated MED Cross-Attention", "ViT-GPT2 (Standard Cross-Attention)": "Full Cross-Attention", "GIT (Zero Cross-Attention)": "Self-Attention Prefix", "Custom VLM (Shakespeare Prefix)": "Linear Bridge Prefix", } WEIGHT_TAG_CLASS = {"base": "wt-base", "best": "wt-best", "latest": "wt-latest"} WEIGHT_LABEL = {"base": "Base", "best": "Best", "latest": "Latest"} DEFAULT_OUTPUT_ROOT = "./outputs" DEFAULT_SHAKESPEARE_FILE = "./input.txt" DEFAULT_SHAKESPEARE_WEIGHTS = "./shakespeare_transformer.pt" WEIGHTS_REPO_ID = os.getenv("WEIGHTS_REPO_ID", "griddev/vlm-caption-weights") WEIGHTS_CACHE_DIR = os.getenv("WEIGHTS_CACHE_DIR", "./weights_bundle") def _resolve_weight_paths(): output_root = DEFAULT_OUTPUT_ROOT shakespeare_file = DEFAULT_SHAKESPEARE_FILE shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS local_ready = ( os.path.isdir(output_root) and os.path.exists(shakespeare_file) and os.path.exists(shakespeare_weights) ) if local_ready: return output_root, shakespeare_file, shakespeare_weights try: from huggingface_hub import snapshot_download snapshot_download( repo_id=WEIGHTS_REPO_ID, repo_type="model", local_dir=WEIGHTS_CACHE_DIR, local_dir_use_symlinks=False, allow_patterns=[ "outputs/*", "outputs/**/*", "input.txt", "shakespeare_transformer.pt", ], ) candidate_output_root = os.path.join(WEIGHTS_CACHE_DIR, "outputs") 
candidate_shakespeare_file = os.path.join(WEIGHTS_CACHE_DIR, "input.txt") candidate_shakespeare_weights = os.path.join( WEIGHTS_CACHE_DIR, "shakespeare_transformer.pt" ) if os.path.isdir(candidate_output_root): output_root = candidate_output_root if os.path.exists(candidate_shakespeare_file): shakespeare_file = candidate_shakespeare_file if os.path.exists(candidate_shakespeare_weights): shakespeare_weights = candidate_shakespeare_weights except Exception as e: print(f"⚠️ Could not download fine-tuned weights from {WEIGHTS_REPO_ID}: {e}") return output_root, shakespeare_file, shakespeare_weights OUTPUT_ROOT, SHAKESPEARE_FILE, SHAKESPEARE_WEIGHTS_PATH = _resolve_weight_paths() # ───────────────────────────────────────────────────────────────────────────── # Device # ───────────────────────────────────────────────────────────────────────────── def get_device(): if torch.backends.mps.is_available(): return torch.device("mps") if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") # ───────────────────────────────────────────────────────────────────────────── # Weight Loading Helpers # ───────────────────────────────────────────────────────────────────────────── def _has_finetuned(model_dir, subdir): """Check if a fine-tuned checkpoint exists for a given model + subdir.""" path = os.path.join(OUTPUT_ROOT, model_dir, subdir) return os.path.isdir(path) and len(os.listdir(path)) > 0 def _ckpt_path(model_dir, subdir): return os.path.join(OUTPUT_ROOT, model_dir, subdir) # ───────────────────────────────────────────────────────────────────────────── # Cached Model Loaders (with weight_source support) # ───────────────────────────────────────────────────────────────────────────── @st.cache_resource(show_spinner=False) def load_blip(weight_source="base"): from transformers import BlipProcessor, BlipForConditionalGeneration device = get_device() processor = BlipProcessor.from_pretrained( "Salesforce/blip-image-captioning-base", use_fast=True) model 
= BlipForConditionalGeneration.from_pretrained( "Salesforce/blip-image-captioning-base") if weight_source != "base": ckpt = _ckpt_path("blip", weight_source) if os.path.isdir(ckpt) and os.listdir(ckpt): try: loaded = BlipForConditionalGeneration.from_pretrained(ckpt) model.load_state_dict(loaded.state_dict()) del loaded except Exception as e: print(f"⚠️ Could not load BLIP {weight_source} weights: {e}") model.to(device).eval() return processor, model, device @st.cache_resource(show_spinner=False) def load_vit_gpt2(weight_source="base"): from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer device = get_device() model_id = "nlpconnect/vit-gpt2-image-captioning" processor = ViTImageProcessor.from_pretrained(model_id, use_fast=True) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token model = VisionEncoderDecoderModel.from_pretrained(model_id) model.config.decoder_start_token_id = tokenizer.bos_token_id model.config.pad_token_id = tokenizer.pad_token_id if weight_source != "base": ckpt = _ckpt_path("vit_gpt2", weight_source) if os.path.isdir(ckpt) and os.listdir(ckpt): try: loaded = VisionEncoderDecoderModel.from_pretrained(ckpt) model.load_state_dict(loaded.state_dict()) del loaded except Exception as e: print(f"⚠️ Could not load ViT-GPT2 {weight_source} weights: {e}") model.to(device).eval() return processor, tokenizer, model, device @st.cache_resource(show_spinner=False) def load_git(weight_source="base"): from transformers import AutoProcessor, AutoModelForCausalLM device = get_device() model_id = "microsoft/git-base-coco" processor = AutoProcessor.from_pretrained(model_id, use_fast=True) model = AutoModelForCausalLM.from_pretrained(model_id) if weight_source != "base": ckpt = _ckpt_path("git", weight_source) if os.path.isdir(ckpt) and os.listdir(ckpt): try: loaded = AutoModelForCausalLM.from_pretrained(ckpt) model.load_state_dict(loaded.state_dict()) del loaded except Exception as e: 
print(f"⚠️ Could not load GIT {weight_source} weights: {e}") model.to(device).eval() return processor, model, device @st.cache_resource(show_spinner=False) def load_custom_vlm(weight_source="base"): from models.custom_vlm import CustomVLM, build_char_vocab from config import CFG device = get_device() cfg = CFG() cfg.output_root = OUTPUT_ROOT cfg.shakespeare_file = SHAKESPEARE_FILE cfg.shakespeare_weights_path = SHAKESPEARE_WEIGHTS_PATH if not os.path.exists(cfg.shakespeare_file): return None, None, None, None, device with open(cfg.shakespeare_file, "r", encoding="utf-8") as f: text = f.read() _, char_to_idx, idx_to_char, vocab_size = build_char_vocab(text) model = CustomVLM( vocab_size=vocab_size, text_embed_dim=cfg.text_embed_dim, n_heads=cfg.n_heads, n_layers=cfg.n_layers, block_size=cfg.block_size, dropout=cfg.dropout, ) # Always load Shakespeare weights first shakes_path = getattr(cfg, "shakespeare_weights_path", "./shakespeare_transformer.pt") if os.path.exists(shakes_path): model.load_shakespeare_weights(shakes_path) # Then load fine-tuned checkpoint if requested if weight_source != "base": ckpt_path = os.path.join(cfg.output_root, "custom_vlm", weight_source, "custom_vlm.pt") if os.path.exists(ckpt_path): state = torch.load(ckpt_path, map_location="cpu") own_state = model.state_dict() filtered = {k: v for k, v in state["model_state"].items() if k in own_state and own_state[k].shape == v.shape} model.load_state_dict(filtered, strict=False) else: # Even for base, try loading best weights as fallback for subdir in ["best", "latest"]: candidate = os.path.join(cfg.output_root, "custom_vlm", subdir, "custom_vlm.pt") if os.path.exists(candidate): state = torch.load(candidate, map_location="cpu") own_state = model.state_dict() filtered = {k: v for k, v in state["model_state"].items() if k in own_state and own_state[k].shape == v.shape} model.load_state_dict(filtered, strict=False) break model.to(device).eval() return model, char_to_idx, idx_to_char, vocab_size, 
device @st.cache_resource(show_spinner=False) def load_toxicity_filter(): from transformers import AutoModelForSequenceClassification, AutoTokenizer tox_id = "unitary/toxic-bert" tok = AutoTokenizer.from_pretrained(tox_id) mdl = AutoModelForSequenceClassification.from_pretrained(tox_id) mdl.eval() return tok, mdl # ───────────────────────────────────────────────────────────────────────────── # Toxicity Check # ───────────────────────────────────────────────────────────────────────────── def is_toxic(text, tox_tok, tox_mdl): inputs = tox_tok(text, return_tensors="pt", truncation=True, max_length=512) with torch.no_grad(): outputs = tox_mdl(**inputs) scores = torch.sigmoid(outputs.logits).squeeze() if isinstance(scores, torch.Tensor) and scores.dim() > 0: return (scores > 0.5).any().item() return scores.item() > 0.5 # ───────────────────────────────────────────────────────────────────────────── # Ablation Mask Builder # ───────────────────────────────────────────────────────────────────────────── def build_mask_for_mode(ui_mode, device): N = 197 if ui_mode == "Baseline (Full Attention)": return torch.ones(1, N, dtype=torch.long, device=device), False elif ui_mode == "Random Patch Dropout (50%)": mask = torch.ones(1, N, dtype=torch.long, device=device) spatial_indices = torch.randperm(196)[:98] + 1 mask[0, spatial_indices] = 0 return mask, False elif ui_mode == "Center-Focus (Inner 8×8)": GRID, INNER, offset = 14, 8, 3 keep = set() for row in range(offset, offset + INNER): for col in range(offset, offset + INNER): keep.add(row * GRID + col + 1) mask = torch.zeros(1, N, dtype=torch.long, device=device) mask[0, 0] = 1 for idx in keep: if idx < N: mask[0, idx] = 1 return mask, False elif ui_mode == "Squint (Global Pool)": return None, True return torch.ones(1, N, dtype=torch.long, device=device), False # ───────────────────────────────────────────────────────────────────────────── # Caption Generation (single model) # 
# ─────────────────────────────────────────────────────────────────────────────
# Caption Generation (single model)
# ─────────────────────────────────────────────────────────────────────────────
def generate_caption(model_name, gen_mode, image_pil,
                     num_beams=4, max_new_tokens=50,
                     length_penalty=1.0, weight_source="base"):
    """Generate a caption for one image with the selected architecture.

    Args:
        model_name: One of the long display names in MODEL_KEYS.
        gen_mode: Ablation mode label (see `build_mask_for_mode`); only
            BLIP and ViT-GPT2 honor it.
        image_pil: PIL image to caption.
        num_beams / max_new_tokens / length_penalty: Generation controls.
        weight_source: "base" / "best" / "latest", forwarded to the loaders.

    Returns:
        The generated caption string, stripped of surrounding whitespace
        (or a placeholder message when the Custom VLM is unavailable).
    """
    device = get_device()
    with torch.no_grad():
        if model_name == "BLIP (Multimodal Mixture Attention)":
            processor, model, device = load_blip(weight_source)
            inputs = processor(images=image_pil, return_tensors="pt").to(device)
            mask, is_squint = build_mask_for_mode(gen_mode, device)
            if is_squint:
                # Squint: run the vision tower once, then collapse all patch
                # states to their mean → decoder sees only [CLS, pooled].
                vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
                hs = vision_out.last_hidden_state
                pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
                captions = generate_with_mask(
                    model, processor, device=device,
                    encoder_hidden_states=pooled,
                    encoder_attention_mask=torch.ones(1, 2, dtype=torch.long, device=device),
                    max_new_tokens=max_new_tokens, num_beams=num_beams,
                )
            else:
                captions = generate_with_mask(
                    model, processor, device=device,
                    pixel_values=inputs["pixel_values"],
                    encoder_attention_mask=mask,
                    max_new_tokens=max_new_tokens, num_beams=num_beams,
                )
            caption = captions[0]

        elif model_name == "ViT-GPT2 (Standard Cross-Attention)":
            from transformers.modeling_outputs import BaseModelOutput
            processor, tokenizer, model, device = load_vit_gpt2(weight_source)
            inputs = processor(images=image_pil, return_tensors="pt").to(device)
            mask, is_squint = build_mask_for_mode(gen_mode, device)
            if is_squint:
                # Same pooling trick: hand generate() pre-pooled encoder outputs.
                enc_out = model.encoder(pixel_values=inputs["pixel_values"])
                hs = enc_out.last_hidden_state
                pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1)
                out = model.generate(
                    encoder_outputs=BaseModelOutput(last_hidden_state=pooled),
                    decoder_start_token_id=tokenizer.bos_token_id,
                    max_new_tokens=max_new_tokens,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )
            else:
                out = model.generate(
                    **inputs,
                    attention_mask=mask,
                    max_new_tokens=max_new_tokens,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                )
            caption = tokenizer.decode(out[0], skip_special_tokens=True)

        elif model_name == "GIT (Zero Cross-Attention)":
            # GIT has no cross-attention, so ablation masks do not apply.
            processor, model, device = load_git(weight_source)
            inputs = processor(images=image_pil, return_tensors="pt").to(device)
            out = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=num_beams,
                length_penalty=length_penalty,
            )
            caption = processor.batch_decode(out, skip_special_tokens=True)[0]

        elif model_name == "Custom VLM (Shakespeare Prefix)":
            vlm, char_to_idx, idx_to_char, vocab_size, device = load_custom_vlm(weight_source)
            if vlm is None:
                return "[Custom VLM not available — train first with: python train.py --model custom]"
            from transformers import ViTImageProcessor
            image_processor = ViTImageProcessor.from_pretrained(
                "google/vit-base-patch16-224-in21k", use_fast=True)
            pv = image_processor(images=image_pil, return_tensors="pt")["pixel_values"].to(device)
            if num_beams > 1:
                caption = vlm.generate_beam(pv, char_to_idx, idx_to_char,
                                            max_new_tokens=max_new_tokens,
                                            num_beams=num_beams,
                                            length_penalty=length_penalty)
            else:
                caption = vlm.generate(pv, char_to_idx, idx_to_char,
                                       max_new_tokens=max_new_tokens)
        else:
            caption = "Unknown model."
    return caption.strip()


# ─────────────────────────────────────────────────────────────────────────────
# Sidebar
# ─────────────────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("### 🔬 VLM Caption Lab")
    st.markdown("---")

    # ── Weight Source ─────────────────────────────────────────────────────
    weight_options = {
        "🔵 Base (Pretrained)": "base",
        "🟢 Fine-tuned (Best)": "best",
        "🟡 Fine-tuned (Latest)": "latest",
    }
    weight_choice = st.radio(
        "**Weight Source**",
        list(weight_options.keys()),
        index=0,
        help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints."
    )
    weight_source = weight_options[weight_choice]

    # Show availability indicators
    ft_status = []
    for mdl_dir, mdl_name in [("blip", "BLIP"), ("vit_gpt2", "ViT-GPT2"),
                              ("git", "GIT"), ("custom_vlm", "Custom VLM")]:
        has_best = _has_finetuned(mdl_dir, "best")
        has_latest = _has_finetuned(mdl_dir, "latest")
        if has_best or has_latest:
            ft_status.append(f" ✅ {mdl_name}")
        else:
            ft_status.append(f" ⬜ {mdl_name}")
    if weight_source != "base":
        st.caption("Fine-tuned checkpoints:\n" + "\n".join(ft_status))
    st.markdown("---")

    # ── Architecture Selector ─────────────────────────────────────────────
    selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0)

    # Only the two cross-attention models support the ablation modes;
    # the Custom VLM has its own single mode.
    if selected_model in ("BLIP (Multimodal Mixture Attention)",
                          "ViT-GPT2 (Standard Cross-Attention)"):
        mode_options = [
            "Baseline (Full Attention)",
            "Random Patch Dropout (50%)",
            "Center-Focus (Inner 8×8)",
            "Squint (Global Pool)",
        ]
    elif selected_model == "Custom VLM (Shakespeare Prefix)":
        mode_options = ["Shakespeare Prefix"]
    else:
        mode_options = ["Baseline (Full Attention)"]
    selected_mode = st.selectbox("**Generation Mode**", mode_options, index=0)

    # NOTE(review): the original file was truncated/garbled inside this
    # st.markdown f-string — only the implicit-concatenation command lines
    # survived. Reconstructed as the architecture blurb followed by the
    # benchmark-reproduction commands; verify against the original source.
    st.markdown(
        f"{ARCH_INFO[selected_model]}\n\nReproduce benchmarks: "
        "python eval.py --model all | "
        "python eval.py --ablation | "
        "python -m experiments.parameter_sweep | "
        "python -m experiments.data_prep_analysis"
    )