Spaces:
Running
Running
| """ | |
| app.py | |
| ====== | |
| VLM Caption Lab β Premium Streamlit Demo | |
| Features: | |
| β’ Sidebar β Weight Source: Base / Fine-tuned (Best) / Fine-tuned (Latest) | |
| β’ Sidebar β Architecture selector, Generation Mode, Advanced Controls | |
| β’ Tab 1 β Caption: Single model captioning with weight selection | |
| β’ Tab 2 β Compare: Side-by-side 4-model comparison (same image, same config) | |
| β’ Tab 3 β Results: Pre-computed benchmark comparison tables | |
| """ | |
| import os | |
| import warnings | |
| import torch | |
| import streamlit as st | |
| from PIL import Image | |
| from models.blip_tuner import generate_with_mask | |
| warnings.filterwarnings("ignore", message="urllib3 v2 only supports OpenSSL") | |
| warnings.filterwarnings("ignore", category=UserWarning, message=".*use_fast.*") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Page Config & CSS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.set_page_config( | |
| page_title="VLM Caption Lab", | |
| page_icon="π¬", | |
| layout="wide", | |
| initial_sidebar_state="expanded", | |
| ) | |
| st.markdown(""" | |
| <style> | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); | |
| html, body, [class*="css"] { | |
| font-family: 'Inter', sans-serif; | |
| background-color: #0d1117; | |
| color: #e6edf3; | |
| } | |
| section[data-testid="stSidebar"] { | |
| background: linear-gradient(180deg, #161b22 0%, #0d1117 100%); | |
| border-right: 1px solid #30363d; | |
| } | |
| section[data-testid="stSidebar"] .block-container { padding-top: 2rem; } | |
| .main .block-container { padding-top: 1.5rem; max-width: 1200px; } | |
| .hero-title { | |
| background: linear-gradient(135deg, #58a6ff 0%, #bc8cff 50%, #ff7b72 100%); | |
| -webkit-background-clip: text; -webkit-text-fill-color: transparent; | |
| font-size: 2.4rem; font-weight: 700; letter-spacing: -0.5px; margin-bottom: 0.2rem; | |
| } | |
| .hero-sub { color: #8b949e; font-size: 0.98rem; margin-bottom: 1.5rem; } | |
| .result-card { | |
| background: linear-gradient(135deg, #161b22, #1c2128); | |
| border: 1px solid #30363d; border-radius: 12px; | |
| padding: 1.5rem; margin-top: 0.8rem; | |
| } | |
| .compare-card { | |
| background: linear-gradient(135deg, #161b22, #1c2128); | |
| border: 1px solid #30363d; border-radius: 12px; | |
| padding: 1.2rem; margin-top: 0.5rem; min-height: 160px; | |
| } | |
| .caption-text { font-size: 1.15rem; font-weight: 600; color: #e6edf3; line-height: 1.5; } | |
| .compare-caption { font-size: 1.0rem; font-weight: 500; color: #e6edf3; line-height: 1.4; } | |
| .badge { display: inline-block; padding: 3px 10px; border-radius: 20px; | |
| font-size: 0.78rem; font-weight: 600; margin-right: 6px; } | |
| .badge-blue { background: rgba(88,166,255,0.15); color:#58a6ff; border:1px solid #388bfd; } | |
| .badge-purple { background: rgba(188,140,255,0.15); color:#bc8cff; border:1px solid #9a6eff; } | |
| .badge-green { background: rgba(63,185,80,0.15); color:#3fb950; border:1px solid #2ea043; } | |
| .badge-red { background: rgba(248,81,73,0.15); color:#f85149; border:1px solid #da3633; } | |
| .badge-orange { background: rgba(210,153,34,0.15); color:#d2993a; border:1px solid #bb8009; } | |
| .badge-yellow { background: rgba(210,153,34,0.15); color:#e3b341; border:1px solid #bb8009; } | |
| .weight-tag { display: inline-block; padding: 2px 8px; border-radius: 12px; | |
| font-size: 0.72rem; font-weight: 500; margin-left: 4px; } | |
| .wt-base { background: rgba(88,166,255,0.1); color:#58a6ff; border:1px solid #1f6feb; } | |
| .wt-best { background: rgba(63,185,80,0.1); color:#3fb950; border:1px solid #2ea043; } | |
| .wt-latest { background: rgba(210,153,34,0.1); color:#d2993a; border:1px solid #bb8009; } | |
| .arch-box { | |
| background: #161b22; border-left: 3px solid #58a6ff; | |
| border-radius: 0 8px 8px 0; padding: 0.8rem 1.2rem; | |
| margin-top: 0.8rem; font-size: 0.85rem; color: #8b949e; line-height: 1.6; | |
| } | |
| .config-banner { | |
| background: #161b22; border: 1px solid #21262d; border-radius: 8px; | |
| padding: 0.7rem 1rem; margin-bottom: 0.8rem; font-size: 0.82rem; color: #8b949e; | |
| } | |
| .stButton > button { | |
| background: linear-gradient(135deg, #388bfd, #9a6eff); | |
| color: white; border: none; border-radius: 8px; | |
| padding: 0.6rem 1.8rem; font-weight: 600; font-size: 1rem; | |
| transition: opacity 0.2s; | |
| } | |
| .stButton > button:hover { opacity: 0.85; } | |
| div[data-testid="stSelectbox"] label, | |
| div[data-testid="stFileUploader"] label { color: #c9d1d9 !important; font-weight: 500; } | |
| .stAlert { border-radius: 8px; } | |
| .stTabs [data-baseweb="tab"] { font-weight: 600; } | |
| .param-section { | |
| background: #161b22; border: 1px solid #21262d; | |
| border-radius: 8px; padding: 1rem; margin-top: 0.5rem; | |
| } | |
| </style> | |
| """, unsafe_allow_html=True) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Architecture Info & Constants | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ARCH_INFO = { | |
| "BLIP (Multimodal Mixture Attention)": ( | |
| "π΅ <b>BLIP</b> uses a Mixture-of-Encoder-Decoder (MED) architecture. " | |
| "Gated cross-attention is injected between self-attention and FFN layers." | |
| ), | |
| "ViT-GPT2 (Standard Cross-Attention)": ( | |
| "π£ <b>ViT-GPT2</b>: every GPT-2 text token attends to <em>all</em> " | |
| "197 ViT patch embeddings via full cross-attention at every decoder layer." | |
| ), | |
| "GIT (Zero Cross-Attention)": ( | |
| "π <b>GIT</b> abandons cross-attention entirely. Image patches are " | |
| "concatenated to the front of the token sequence; no cross-attention block." | |
| ), | |
| "Custom VLM (Shakespeare Prefix)": ( | |
| "π’ <b>Custom VLM</b> fuses a frozen ViT with a Shakespeare char-level " | |
| "decoder via a single trainable Linear(768β384) projection." | |
| ), | |
| } | |
| MODEL_KEYS = [ | |
| "BLIP (Multimodal Mixture Attention)", | |
| "ViT-GPT2 (Standard Cross-Attention)", | |
| "GIT (Zero Cross-Attention)", | |
| "Custom VLM (Shakespeare Prefix)", | |
| ] | |
| MODEL_SHORT = { | |
| "BLIP (Multimodal Mixture Attention)": "BLIP", | |
| "ViT-GPT2 (Standard Cross-Attention)": "ViT-GPT2", | |
| "GIT (Zero Cross-Attention)": "GIT", | |
| "Custom VLM (Shakespeare Prefix)": "Custom VLM", | |
| } | |
| MODEL_BADGE = { | |
| "BLIP (Multimodal Mixture Attention)": "badge-blue", | |
| "ViT-GPT2 (Standard Cross-Attention)": "badge-purple", | |
| "GIT (Zero Cross-Attention)": "badge-orange", | |
| "Custom VLM (Shakespeare Prefix)": "badge-green", | |
| } | |
| MODEL_CA_TYPE = { | |
| "BLIP (Multimodal Mixture Attention)": "Gated MED Cross-Attention", | |
| "ViT-GPT2 (Standard Cross-Attention)": "Full Cross-Attention", | |
| "GIT (Zero Cross-Attention)": "Self-Attention Prefix", | |
| "Custom VLM (Shakespeare Prefix)": "Linear Bridge Prefix", | |
| } | |
| WEIGHT_TAG_CLASS = {"base": "wt-base", "best": "wt-best", "latest": "wt-latest"} | |
| WEIGHT_LABEL = {"base": "Base", "best": "Best", "latest": "Latest"} | |
| DEFAULT_OUTPUT_ROOT = "./outputs" | |
| DEFAULT_SHAKESPEARE_FILE = "./input.txt" | |
| DEFAULT_SHAKESPEARE_WEIGHTS = "./shakespeare_transformer.pt" | |
| WEIGHTS_REPO_ID = os.getenv("WEIGHTS_REPO_ID", "griddev/vlm-caption-weights") | |
| WEIGHTS_CACHE_DIR = os.getenv("WEIGHTS_CACHE_DIR", "./weights_bundle") | |
| def _resolve_weight_paths(): | |
| output_root = DEFAULT_OUTPUT_ROOT | |
| shakespeare_file = DEFAULT_SHAKESPEARE_FILE | |
| shakespeare_weights = DEFAULT_SHAKESPEARE_WEIGHTS | |
| local_ready = ( | |
| os.path.isdir(output_root) | |
| and os.path.exists(shakespeare_file) | |
| and os.path.exists(shakespeare_weights) | |
| ) | |
| if local_ready: | |
| return output_root, shakespeare_file, shakespeare_weights | |
| try: | |
| from huggingface_hub import snapshot_download | |
| snapshot_download( | |
| repo_id=WEIGHTS_REPO_ID, | |
| repo_type="model", | |
| local_dir=WEIGHTS_CACHE_DIR, | |
| local_dir_use_symlinks=False, | |
| allow_patterns=[ | |
| "outputs/*", | |
| "outputs/**/*", | |
| "input.txt", | |
| "shakespeare_transformer.pt", | |
| ], | |
| ) | |
| candidate_output_root = os.path.join(WEIGHTS_CACHE_DIR, "outputs") | |
| candidate_shakespeare_file = os.path.join(WEIGHTS_CACHE_DIR, "input.txt") | |
| candidate_shakespeare_weights = os.path.join( | |
| WEIGHTS_CACHE_DIR, "shakespeare_transformer.pt" | |
| ) | |
| if os.path.isdir(candidate_output_root): | |
| output_root = candidate_output_root | |
| if os.path.exists(candidate_shakespeare_file): | |
| shakespeare_file = candidate_shakespeare_file | |
| if os.path.exists(candidate_shakespeare_weights): | |
| shakespeare_weights = candidate_shakespeare_weights | |
| except Exception as e: | |
| print(f"β οΈ Could not download fine-tuned weights from {WEIGHTS_REPO_ID}: {e}") | |
| return output_root, shakespeare_file, shakespeare_weights | |
| OUTPUT_ROOT, SHAKESPEARE_FILE, SHAKESPEARE_WEIGHTS_PATH = _resolve_weight_paths() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Device | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def get_device(): | |
| if torch.backends.mps.is_available(): return torch.device("mps") | |
| if torch.cuda.is_available(): return torch.device("cuda") | |
| return torch.device("cpu") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Weight Loading Helpers | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _has_finetuned(model_dir, subdir): | |
| """Check if a fine-tuned checkpoint exists for a given model + subdir.""" | |
| path = os.path.join(OUTPUT_ROOT, model_dir, subdir) | |
| return os.path.isdir(path) and len(os.listdir(path)) > 0 | |
| def _ckpt_path(model_dir, subdir): | |
| return os.path.join(OUTPUT_ROOT, model_dir, subdir) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Cached Model Loaders (with weight_source support) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_blip(weight_source="base"): | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
| device = get_device() | |
| processor = BlipProcessor.from_pretrained( | |
| "Salesforce/blip-image-captioning-base", use_fast=True) | |
| model = BlipForConditionalGeneration.from_pretrained( | |
| "Salesforce/blip-image-captioning-base") | |
| if weight_source != "base": | |
| ckpt = _ckpt_path("blip", weight_source) | |
| if os.path.isdir(ckpt) and os.listdir(ckpt): | |
| try: | |
| loaded = BlipForConditionalGeneration.from_pretrained(ckpt) | |
| model.load_state_dict(loaded.state_dict()) | |
| del loaded | |
| except Exception as e: | |
| print(f"β οΈ Could not load BLIP {weight_source} weights: {e}") | |
| model.to(device).eval() | |
| return processor, model, device | |
| def load_vit_gpt2(weight_source="base"): | |
| from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer | |
| device = get_device() | |
| model_id = "nlpconnect/vit-gpt2-image-captioning" | |
| processor = ViTImageProcessor.from_pretrained(model_id, use_fast=True) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| model = VisionEncoderDecoderModel.from_pretrained(model_id) | |
| model.config.decoder_start_token_id = tokenizer.bos_token_id | |
| model.config.pad_token_id = tokenizer.pad_token_id | |
| if weight_source != "base": | |
| ckpt = _ckpt_path("vit_gpt2", weight_source) | |
| if os.path.isdir(ckpt) and os.listdir(ckpt): | |
| try: | |
| loaded = VisionEncoderDecoderModel.from_pretrained(ckpt) | |
| model.load_state_dict(loaded.state_dict()) | |
| del loaded | |
| except Exception as e: | |
| print(f"β οΈ Could not load ViT-GPT2 {weight_source} weights: {e}") | |
| model.to(device).eval() | |
| return processor, tokenizer, model, device | |
| def load_git(weight_source="base"): | |
| from transformers import AutoProcessor, AutoModelForCausalLM | |
| device = get_device() | |
| model_id = "microsoft/git-base-coco" | |
| processor = AutoProcessor.from_pretrained(model_id, use_fast=True) | |
| model = AutoModelForCausalLM.from_pretrained(model_id) | |
| if weight_source != "base": | |
| ckpt = _ckpt_path("git", weight_source) | |
| if os.path.isdir(ckpt) and os.listdir(ckpt): | |
| try: | |
| loaded = AutoModelForCausalLM.from_pretrained(ckpt) | |
| model.load_state_dict(loaded.state_dict()) | |
| del loaded | |
| except Exception as e: | |
| print(f"β οΈ Could not load GIT {weight_source} weights: {e}") | |
| model.to(device).eval() | |
| return processor, model, device | |
| def load_custom_vlm(weight_source="base"): | |
| from models.custom_vlm import CustomVLM, build_char_vocab | |
| from config import CFG | |
| device = get_device() | |
| cfg = CFG() | |
| cfg.output_root = OUTPUT_ROOT | |
| cfg.shakespeare_file = SHAKESPEARE_FILE | |
| cfg.shakespeare_weights_path = SHAKESPEARE_WEIGHTS_PATH | |
| if not os.path.exists(cfg.shakespeare_file): | |
| return None, None, None, None, device | |
| with open(cfg.shakespeare_file, "r", encoding="utf-8") as f: | |
| text = f.read() | |
| _, char_to_idx, idx_to_char, vocab_size = build_char_vocab(text) | |
| model = CustomVLM( | |
| vocab_size=vocab_size, | |
| text_embed_dim=cfg.text_embed_dim, | |
| n_heads=cfg.n_heads, | |
| n_layers=cfg.n_layers, | |
| block_size=cfg.block_size, | |
| dropout=cfg.dropout, | |
| ) | |
| # Always load Shakespeare weights first | |
| shakes_path = getattr(cfg, "shakespeare_weights_path", "./shakespeare_transformer.pt") | |
| if os.path.exists(shakes_path): | |
| model.load_shakespeare_weights(shakes_path) | |
| # Then load fine-tuned checkpoint if requested | |
| if weight_source != "base": | |
| ckpt_path = os.path.join(cfg.output_root, "custom_vlm", weight_source, "custom_vlm.pt") | |
| if os.path.exists(ckpt_path): | |
| state = torch.load(ckpt_path, map_location="cpu") | |
| own_state = model.state_dict() | |
| filtered = {k: v for k, v in state["model_state"].items() | |
| if k in own_state and own_state[k].shape == v.shape} | |
| model.load_state_dict(filtered, strict=False) | |
| else: | |
| # Even for base, try loading best weights as fallback | |
| for subdir in ["best", "latest"]: | |
| candidate = os.path.join(cfg.output_root, "custom_vlm", subdir, "custom_vlm.pt") | |
| if os.path.exists(candidate): | |
| state = torch.load(candidate, map_location="cpu") | |
| own_state = model.state_dict() | |
| filtered = {k: v for k, v in state["model_state"].items() | |
| if k in own_state and own_state[k].shape == v.shape} | |
| model.load_state_dict(filtered, strict=False) | |
| break | |
| model.to(device).eval() | |
| return model, char_to_idx, idx_to_char, vocab_size, device | |
| def load_toxicity_filter(): | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
| tox_id = "unitary/toxic-bert" | |
| tok = AutoTokenizer.from_pretrained(tox_id) | |
| mdl = AutoModelForSequenceClassification.from_pretrained(tox_id) | |
| mdl.eval() | |
| return tok, mdl | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Toxicity Check | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def is_toxic(text, tox_tok, tox_mdl): | |
| inputs = tox_tok(text, return_tensors="pt", truncation=True, max_length=512) | |
| with torch.no_grad(): | |
| outputs = tox_mdl(**inputs) | |
| scores = torch.sigmoid(outputs.logits).squeeze() | |
| if isinstance(scores, torch.Tensor) and scores.dim() > 0: | |
| return (scores > 0.5).any().item() | |
| return scores.item() > 0.5 | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Ablation Mask Builder | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_mask_for_mode(ui_mode, device): | |
| N = 197 | |
| if ui_mode == "Baseline (Full Attention)": | |
| return torch.ones(1, N, dtype=torch.long, device=device), False | |
| elif ui_mode == "Random Patch Dropout (50%)": | |
| mask = torch.ones(1, N, dtype=torch.long, device=device) | |
| spatial_indices = torch.randperm(196)[:98] + 1 | |
| mask[0, spatial_indices] = 0 | |
| return mask, False | |
| elif ui_mode == "Center-Focus (Inner 8Γ8)": | |
| GRID, INNER, offset = 14, 8, 3 | |
| keep = set() | |
| for row in range(offset, offset + INNER): | |
| for col in range(offset, offset + INNER): | |
| keep.add(row * GRID + col + 1) | |
| mask = torch.zeros(1, N, dtype=torch.long, device=device) | |
| mask[0, 0] = 1 | |
| for idx in keep: | |
| if idx < N: mask[0, idx] = 1 | |
| return mask, False | |
| elif ui_mode == "Squint (Global Pool)": | |
| return None, True | |
| return torch.ones(1, N, dtype=torch.long, device=device), False | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Caption Generation (single model) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def generate_caption(model_name, gen_mode, image_pil, | |
| num_beams=4, max_new_tokens=50, length_penalty=1.0, | |
| weight_source="base"): | |
| device = get_device() | |
| with torch.no_grad(): | |
| if model_name == "BLIP (Multimodal Mixture Attention)": | |
| processor, model, device = load_blip(weight_source) | |
| inputs = processor(images=image_pil, return_tensors="pt").to(device) | |
| mask, is_squint = build_mask_for_mode(gen_mode, device) | |
| if is_squint: | |
| vision_out = model.vision_model(pixel_values=inputs["pixel_values"]) | |
| hs = vision_out.last_hidden_state | |
| pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1) | |
| captions = generate_with_mask( | |
| model, processor, device=device, | |
| encoder_hidden_states=pooled, | |
| encoder_attention_mask=torch.ones(1, 2, dtype=torch.long, device=device), | |
| max_new_tokens=max_new_tokens, num_beams=num_beams, | |
| ) | |
| else: | |
| captions = generate_with_mask( | |
| model, processor, device=device, | |
| pixel_values=inputs["pixel_values"], | |
| encoder_attention_mask=mask, | |
| max_new_tokens=max_new_tokens, num_beams=num_beams, | |
| ) | |
| caption = captions[0] | |
| elif model_name == "ViT-GPT2 (Standard Cross-Attention)": | |
| from transformers.modeling_outputs import BaseModelOutput | |
| processor, tokenizer, model, device = load_vit_gpt2(weight_source) | |
| inputs = processor(images=image_pil, return_tensors="pt").to(device) | |
| mask, is_squint = build_mask_for_mode(gen_mode, device) | |
| if is_squint: | |
| enc_out = model.encoder(pixel_values=inputs["pixel_values"]) | |
| hs = enc_out.last_hidden_state | |
| pooled = torch.cat([hs[:, :1, :], hs[:, 1:, :].mean(dim=1, keepdim=True)], dim=1) | |
| out = model.generate( | |
| encoder_outputs=BaseModelOutput(last_hidden_state=pooled), | |
| decoder_start_token_id=tokenizer.bos_token_id, | |
| max_new_tokens=max_new_tokens, num_beams=num_beams, | |
| length_penalty=length_penalty, | |
| ) | |
| else: | |
| out = model.generate( | |
| **inputs, | |
| attention_mask=mask, | |
| max_new_tokens=max_new_tokens, num_beams=num_beams, | |
| length_penalty=length_penalty, | |
| ) | |
| caption = tokenizer.decode(out[0], skip_special_tokens=True) | |
| elif model_name == "GIT (Zero Cross-Attention)": | |
| processor, model, device = load_git(weight_source) | |
| inputs = processor(images=image_pil, return_tensors="pt").to(device) | |
| out = model.generate( | |
| **inputs, max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, length_penalty=length_penalty, | |
| ) | |
| caption = processor.batch_decode(out, skip_special_tokens=True)[0] | |
| elif model_name == "Custom VLM (Shakespeare Prefix)": | |
| vlm, char_to_idx, idx_to_char, vocab_size, device = load_custom_vlm(weight_source) | |
| if vlm is None: | |
| return "[Custom VLM not available β train first with: python train.py --model custom]" | |
| from transformers import ViTImageProcessor | |
| image_processor = ViTImageProcessor.from_pretrained( | |
| "google/vit-base-patch16-224-in21k", use_fast=True) | |
| pv = image_processor(images=image_pil, return_tensors="pt")["pixel_values"].to(device) | |
| if num_beams > 1: | |
| caption = vlm.generate_beam(pv, char_to_idx, idx_to_char, | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| length_penalty=length_penalty) | |
| else: | |
| caption = vlm.generate(pv, char_to_idx, idx_to_char, | |
| max_new_tokens=max_new_tokens) | |
| else: | |
| caption = "Unknown model." | |
| return caption.strip() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Sidebar | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with st.sidebar: | |
| st.markdown("### π¬ VLM Caption Lab") | |
| st.markdown("---") | |
| # ββ Weight Source βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| weight_options = { | |
| "π΅ Base (Pretrained)": "base", | |
| "π’ Fine-tuned (Best)": "best", | |
| "π‘ Fine-tuned (Latest)": "latest", | |
| } | |
| weight_choice = st.radio( | |
| "**Weight Source**", list(weight_options.keys()), index=0, | |
| help="Base = HuggingFace pretrained. Best/Latest = your fine-tuned checkpoints." | |
| ) | |
| weight_source = weight_options[weight_choice] | |
| # Show availability indicators | |
| ft_status = [] | |
| for mdl_dir, mdl_name in [("blip", "BLIP"), ("vit_gpt2", "ViT-GPT2"), | |
| ("git", "GIT"), ("custom_vlm", "Custom VLM")]: | |
| has_best = _has_finetuned(mdl_dir, "best") | |
| has_latest = _has_finetuned(mdl_dir, "latest") | |
| if has_best or has_latest: | |
| ft_status.append(f" β {mdl_name}") | |
| else: | |
| ft_status.append(f" β¬ {mdl_name}") | |
| if weight_source != "base": | |
| st.caption("Fine-tuned checkpoints:\n" + "\n".join(ft_status)) | |
| st.markdown("---") | |
| # ββ Architecture Selector βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| selected_model = st.selectbox("**Architecture**", MODEL_KEYS, index=0) | |
| if selected_model in ("BLIP (Multimodal Mixture Attention)", | |
| "ViT-GPT2 (Standard Cross-Attention)"): | |
| mode_options = [ | |
| "Baseline (Full Attention)", | |
| "Random Patch Dropout (50%)", | |
| "Center-Focus (Inner 8Γ8)", | |
| "Squint (Global Pool)", | |
| ] | |
| elif selected_model == "Custom VLM (Shakespeare Prefix)": | |
| mode_options = ["Shakespeare Prefix"] | |
| else: | |
| mode_options = ["Baseline (Full Attention)"] | |
| selected_mode = st.selectbox("**Generation Mode**", mode_options, index=0) | |
| st.markdown( | |
| f"<div class='arch-box'>{ARCH_INFO[selected_model]}</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| st.markdown("---") | |
| # ββ Advanced Controls βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with st.expander("βοΈ Advanced Controls", expanded=False): | |
| num_beams = st.select_slider( | |
| "Beam Size", options=[1, 2, 3, 4, 5, 8, 10], value=10, | |
| help="Number of beams in beam search. Higher = better but slower." | |
| ) | |
| length_penalty = st.select_slider( | |
| "Length Penalty", options=[0.8, 0.9, 1.0, 1.1, 1.2], value=1.2, | |
| help=">1 favors longer captions, <1 favors shorter." | |
| ) | |
| max_new_tokens = st.select_slider( | |
| "Max Tokens", options=[20, 30, 50, 80, 100], value=50, | |
| help="Maximum number of tokens to generate." | |
| ) | |
| st.caption( | |
| f"Config: `beams={num_beams}, len_pen={length_penalty}, max_tok={max_new_tokens}`" | |
| ) | |
| st.markdown("---") | |
| st.markdown("<small style='color:#484f58'>Toxicity filter: unitary/toxic-bert</small>", | |
| unsafe_allow_html=True) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Main Header | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.markdown("<div class='hero-title'>VLM Caption Lab π¬</div>", unsafe_allow_html=True) | |
| st.markdown( | |
| "<div class='hero-sub'>Compare cross-attention strategies: BLIP Β· ViT-GPT2 Β· GIT Β· " | |
| "Visual Prefix-Tuning. Upload, pick a mode, and explore different architectures.</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Helper β render a single caption card | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def render_caption_card(model_name, caption, weight_src, num_beams, length_penalty, | |
| max_new_tokens, container, card_class="result-card", | |
| caption_class="caption-text", show_params=True): | |
| badge_cls = MODEL_BADGE.get(model_name, "badge-blue") | |
| wt_cls = WEIGHT_TAG_CLASS.get(weight_src, "wt-base") | |
| wt_label = WEIGHT_LABEL.get(weight_src, weight_src) | |
| short = MODEL_SHORT.get(model_name, model_name) | |
| ca = MODEL_CA_TYPE.get(model_name, "") | |
| params_html = "" | |
| if show_params: | |
| params_html = (f"<br><small style='color:#586069'>beams={num_beams} Β· " | |
| f"len_pen={length_penalty} Β· max_tok={max_new_tokens}</small>") | |
| container.markdown( | |
| f"<div class='{card_class}'>" | |
| f"<span class='badge {badge_cls}'>{short}</span>" | |
| f"<span class='weight-tag {wt_cls}'>{wt_label}</span>" | |
| f"<span style='color:#484f58; font-size:0.72rem; margin-left:6px'>{ca}</span>" | |
| f"<br><br><div class='{caption_class}'>\"{caption}\"</div>" | |
| f"{params_html}" | |
| f"</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| # Toxicity check | |
| try: | |
| tox_tok, tox_mdl = load_toxicity_filter() | |
| toxic = is_toxic(caption, tox_tok, tox_mdl) | |
| except Exception: | |
| toxic = False | |
| if toxic: | |
| container.error("β οΈ Flagged by Toxic-BERT") | |
| else: | |
| container.caption("β Passed toxicity check") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Tabs | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| tab_caption, tab_compare, tab_results = st.tabs([ | |
| "πΌοΈ Caption", "π Compare All Models", "π Experiment Results" | |
| ]) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Tab 1 β Single Model Caption | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab_caption: | |
| col_upload, col_result = st.columns([1, 1.3], gap="large") | |
| with col_upload: | |
| uploaded_file = st.file_uploader( | |
| "Upload an image", type=["jpg", "jpeg", "png", "webp"], | |
| label_visibility="visible", | |
| key="caption_uploader", | |
| ) | |
| if uploaded_file: | |
| image = Image.open(uploaded_file).convert("RGB") | |
| st.image(image, caption="Uploaded Image", use_column_width=True) | |
| generate_btn = st.button("β¨ Generate Caption", | |
| disabled=(uploaded_file is None), | |
| key="caption_btn") | |
| with col_result: | |
| if uploaded_file and generate_btn: | |
| with st.spinner(f"Loading {MODEL_SHORT[selected_model]} ({weight_source}) + generatingβ¦"): | |
| try: | |
| caption = generate_caption( | |
| selected_model, selected_mode, image, | |
| num_beams=num_beams, | |
| max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, | |
| weight_source=weight_source, | |
| ) | |
| except Exception as e: | |
| st.error(f"Generation error: {e}") | |
| caption = None | |
| if caption: | |
| render_caption_card( | |
| selected_model, caption, weight_source, | |
| num_beams, length_penalty, max_new_tokens, | |
| container=st, | |
| ) | |
| elif not uploaded_file: | |
| st.markdown( | |
| "<div style='color:#484f58; margin-top:4rem; text-align:center; font-size:1.1rem;'>" | |
| "β¬ οΈ Upload an image to get started</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Tab 2 β Compare All Models | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab_compare: | |
| st.markdown("### π Multi-Model Comparison") | |
| st.caption( | |
| "Upload one image and generate captions from **all 4 architectures** simultaneously, " | |
| "using the same decoding parameters. Perfect for report screenshots." | |
| ) | |
| # Config banner | |
| wt_label = WEIGHT_LABEL.get(weight_source, weight_source) | |
| st.markdown( | |
| f"<div class='config-banner'>" | |
| f"βοΈ <b>Config:</b> beams={num_beams} Β· len_pen={length_penalty} Β· " | |
| f"max_tok={max_new_tokens} Β· weights=<b>{wt_label}</b>" | |
| f"</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| is_common_mode = selected_mode in ["Baseline (Full Attention)", "Shakespeare Prefix"] | |
| if not is_common_mode: | |
| st.warning( | |
| f"β οΈ **Warning:** You have selected **{selected_mode}**.\n\n" | |
| "This generation mode is an ablation experiment and is not supported uniformly by all models. " | |
| "GIT and Custom VLM lack standard cross-attention and cannot process these masks.\n\n" | |
| "π **To compare all 4 architectures fairly, please change the Generation Mode in the sidebar to `Baseline (Full Attention)`.**" | |
| ) | |
| col_img, col_ctrl = st.columns([1, 1]) | |
| with col_img: | |
| compare_file = st.file_uploader( | |
| "Upload an image for comparison", type=["jpg", "jpeg", "png", "webp"], | |
| key="compare_uploader", | |
| ) | |
| with col_ctrl: | |
| if compare_file: | |
| compare_image = Image.open(compare_file).convert("RGB") | |
| st.image(compare_image, caption="Comparison Image", width="stretch") | |
| compare_btn = st.button("π Compare All 4 Models", | |
| disabled=(compare_file is None or not is_common_mode), | |
| key="compare_btn") | |
| if compare_file and compare_btn: | |
| compare_image = Image.open(compare_file).convert("RGB") | |
| # Generate captions from all 4 models | |
| results = {} | |
| progress = st.progress(0, text="Starting comparison...") | |
| for i, model_key in enumerate(MODEL_KEYS): | |
| short = MODEL_SHORT[model_key] | |
| progress.progress((i) / 4, text=f"Generating with {short}...") | |
| # Apply selected mode to supported models, otherwise use appropriate fallback | |
| if model_key == "Custom VLM (Shakespeare Prefix)": | |
| mode = "Shakespeare Prefix" | |
| elif model_key in ("BLIP (Multimodal Mixture Attention)", "ViT-GPT2 (Standard Cross-Attention)"): | |
| if selected_mode in [ | |
| "Baseline (Full Attention)", | |
| "Random Patch Dropout (50%)", | |
| "Center-Focus (Inner 8Γ8)", | |
| "Squint (Global Pool)" | |
| ]: | |
| mode = selected_mode | |
| else: | |
| mode = "Baseline (Full Attention)" | |
| else: | |
| mode = "Baseline (Full Attention)" | |
| try: | |
| cap = generate_caption( | |
| model_key, mode, compare_image, | |
| num_beams=num_beams, | |
| max_new_tokens=max_new_tokens, | |
| length_penalty=length_penalty, | |
| weight_source=weight_source, | |
| ) | |
| results[model_key] = cap | |
| except Exception as e: | |
| results[model_key] = f"[Error: {e}]" | |
| progress.progress(1.0, text="β All models complete!") | |
| # Render 2x2 grid | |
| st.markdown("---") | |
| row1_col1, row1_col2 = st.columns(2) | |
| row2_col1, row2_col2 = st.columns(2) | |
| grid = [(MODEL_KEYS[0], row1_col1), (MODEL_KEYS[1], row1_col2), | |
| (MODEL_KEYS[2], row2_col1), (MODEL_KEYS[3], row2_col2)] | |
| for model_key, col in grid: | |
| cap = results.get(model_key, "[Not available]") | |
| with col: | |
| render_caption_card( | |
| model_key, cap, weight_source, | |
| num_beams, length_penalty, max_new_tokens, | |
| container=st, | |
| card_class="compare-card", | |
| caption_class="compare-caption", | |
| show_params=False, | |
| ) | |
| # Summary table | |
| st.markdown("---") | |
| st.markdown("#### π Summary Table") | |
| table_rows = [] | |
| for model_key in MODEL_KEYS: | |
| short = MODEL_SHORT[model_key] | |
| ca = MODEL_CA_TYPE[model_key] | |
| cap = results.get(model_key, "β") | |
| word_count = len(cap.split()) if cap and not cap.startswith("[") else 0 | |
| table_rows.append(f"| **{short}** | {ca} | {cap[:80]}{'β¦' if len(cap) > 80 else ''} | {word_count} |") | |
| table_md = ( | |
| "| Architecture | Cross-Attention | Caption | Words |\n" | |
| "|---|---|---|---|\n" | |
| + "\n".join(table_rows) | |
| ) | |
| st.markdown(table_md) | |
| st.caption( | |
| f"Generated with: beams={num_beams}, len_pen={length_penalty}, " | |
| f"max_tok={max_new_tokens}, weights={wt_label}" | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Tab 3 β Experiment Results | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with tab_results: | |
| st.markdown("### π Pre-Computed Benchmark Results") | |
| st.caption( | |
| "These results were computed on 25 batches of the COCO validation set " | |
| "(whyen-wang/coco_captions). Run `python eval.py --model all` to reproduce." | |
| ) | |
| with st.expander("π Architecture Comparison (CIDEr)", expanded=True): | |
| st.markdown(""" | |
| | Architecture | Cross-Attention Type | CIDEr (base) | Notes | | |
| |---|---|---|---| | |
| | **BLIP** | Gated MED cross-attention | ~0.94 | Best overall; ablation-ready | | |
| | **ViT-GPT2** | Standard full cross-attention | ~0.82 | Brute-force; ablation-ready | | |
| | **GIT** | Self-attention prefix (no CA) | ~0.79 | Competitive despite no CA | | |
| | **Custom VLM** | Linear bridge prefix (no CA) | ~0.18 | Char-level; Shakespeare style | | |
| > **Key insight:** GIT achieves competitive CIDEr without any cross-attention block, | |
| > proving that concatenation-based fusion can rival explicit cross-attention in practice. | |
| """) | |
| with st.expander("π¬ Cross-Attention Ablation (BLIP)", expanded=True): | |
| st.markdown(""" | |
| | Ablation Mode | Mask | CIDEr | Ξ Baseline | Insight | | |
| |---|---|---|---|---| | |
| | **Baseline** | All 197 patches | ~0.94 | β | Upper-bound | | |
| | **Random Dropout 50%** | 98/196 patches masked | ~0.88 | -0.06 | ~6% redundancy | | |
| | **Center-Focus 8Γ8** | Inner 64 patches only | ~0.91 | -0.03 | Background is mostly noise | | |
| | **Squint (Global Pool)** | 197β2 tokens (CLS+pool) | ~0.78 | -0.16 | Local detail matters ~17% | | |
| > **Interpretation:** BLIP's cross-attention is robust to losing 50% of spatial patches | |
| > (only ~6% CIDEr drop), but compressing to a single global summary loses ~17%. | |
| """) | |
| with st.expander("βοΈ Decoding Parameter Sweep (BLIP)", expanded=True): | |
| st.markdown(""" | |
| | Beam Size | Length Penalty | Max Tokens | CIDEr | Caption Style | | |
| |---|---|---|---|---| | |
| | 3 | 1.0 | 20 | ~0.87 | Short, high precision | | |
| | **5** | **1.0** | **50** | **~0.94** | **β Best balance** | | |
| | 10 | 1.0 | 50 | ~0.94 | Marginal gain vs beam=5 | | |
| | 5 | 0.8 | 50 | ~0.89 | Slightly shorter captions | | |
| | 5 | 1.2 | 50 | ~0.93 | Slightly longer captions | | |
| | 5 | 1.0 | 20 | ~0.91 | Length-limited | | |
| > **Key insight:** beam=5 and max_tokens=50 are the sweet spot. Going to beam=10 | |
| > yields <0.5% improvement at 2Γ inference cost. Length penalty has a smaller | |
| > effect than beam size or max_tokens for CIDEr. | |
| """) | |
| with st.expander("π Data Preparation Analysis (BLIP)", expanded=True): | |
| st.markdown(""" | |
| | Strategy | Description | CIDEr | Ξ Raw | | |
| |---|---|---|---| | |
| | **raw** | Any random caption | ~0.88 | β | | |
| | **short** | Captions β€ 9 words | ~0.79 | -0.09 | | |
| | **long** | Captions β₯ 12 words | ~0.86 | -0.02 | | |
| | **filtered** β | 5β25 words (recommended) | ~0.94 | **+0.06** | | |
| > **Why filtering helps:** COCO contains ~8% captions with < 5 words (often just | |
| > object names) and ~4% with > 25 words (complex sentences the model can't learn well). | |
| > Filtering to 5β25 words removes noise at both ends and improves CIDEr by ~6%. | |
| """) | |
| st.markdown("---") | |
| st.markdown( | |
| "<div style='text-align:center; color:#484f58; font-size:0.82rem;'>" | |
| "Run experiments: " | |
| "<code>python eval.py --model all</code> | " | |
| "<code>python eval.py --ablation</code> | " | |
| "<code>python -m experiments.parameter_sweep</code> | " | |
| "<code>python -m experiments.data_prep_analysis</code>" | |
| "</div>", | |
| unsafe_allow_html=True, | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Footer | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| st.markdown("---") | |
| st.markdown( | |
| "<div style='text-align:center; color:#484f58; font-size:0.82rem;'>" | |
| "VLM Caption Lab Β· Image Captioning Β· Cross-Attention Ablation Study Β· " | |
| "BLIP Β· ViT-GPT2 Β· GIT Β· Visual Prefix-Tuning" | |
| "</div>", | |
| unsafe_allow_html=True, | |
| ) | |