import gradio as gr import torch from src.utils import load_experiment from interpretability_demo import generate_rollout_for_demo from data.transforms import build_coco_transform from huggingface_hub import snapshot_download import os def load_model(): repo_id = "evanec/coco-demo" local_dir = snapshot_download(repo_id) base_path = os.path.join( local_dir, "experiment_54_clipencoder_last4_t5base_full", "20251204_140137" ) return base_path # Load Model + Preprocessing CHECKPOINT = load_model() device = "cuda" if torch.cuda.is_available() else "cpu" model, tokenizer, meta, config = load_experiment(CHECKPOINT, device=device) image_size = config["model"]["image_size"] preprocess = build_coco_transform(image_size) def step_prev(step): return max(step - 1, 0) def step_next(step, max_step): return min(step + 1, max_step) # Backend Logic def run_full_rollout(img, max_tokens, alpha): data = generate_rollout_for_demo( model, tokenizer, img, preprocess, device=device, max_new_tokens=max_tokens, alpha=alpha ) caption = data["caption"] avg_rollout = data["avg"]["frames"] heads_rollout = data["heads"]["frames"] labels = data["avg"]["labels"] if len(avg_rollout) == 0: return caption, None, None, None, 0 max_step = len(avg_rollout) - 1 return caption, avg_rollout[0], avg_rollout, heads_rollout, labels, max_step def update_display(step, mode, avg_rollout, heads_rollout, labels): if avg_rollout is None: return gr.update(visible=True, value=None), "", gr.update(visible=False) step = max(0, min(step, len(avg_rollout) - 1)) label = labels[step] if mode == "Averaged": return ( gr.update(visible=True, value=avg_rollout[step]), # show averaged label, gr.update(visible=False) # hide gallery ) # All Heads mode frames = heads_rollout[step] # list of PIL images return ( gr.update(visible=False), # hide averaged label, gr.update(visible=True, value=frames) # show gallery ) # Gradio UI with gr.Blocks( title="Team Coco — Image Captioning + Cross-Attention Viz", css=""" .token-box textarea { font-size: 22px !important; line-height: 1.5 !important; height: 70px !important; width: 200px !important; } """ ) as demo: gr.Markdown("## CS7643 Coco Project: Image Captioning + Cross-Attention Visualization") gr.Markdown(""" ### About This Demo This Space showcases a multimodal captioning model built from: - **CLIP ViT-B/16 encoder (last 4 layers fine-tuned)** - **T5-Base decoder (fully fine-tuned)** This is a mid-tier model from within our experiments, but due to compute limitations on the demo server, this configuration provides a good balance between performance, speed, and interpretability, while capturing the core behavior of the full model. ### What You Can Do Here **1. Generate a caption** Upload any image and the model will produce a text caption autoregressively. **2. Visualize cross-attention heatmaps** For each generated token, you can explore attention patterns between the image and text: - **Averaged Mode:** Shows attention averaged across all decoder heads. - **All Heads Mode:** Displays a grid of heatmaps, one per attention head. **3. Inspect Token-by-Token Reasoning** Use the slider to move through the caption token sequence and observe how attention shifts as the model generates words. **4. Adjust Generation & Visualization Controls** - **Max Tokens:** Limits caption length. - **Alpha:** Controls heatmap overlay transparency. This demo is meant to provide insight into how the model sees and why it generates specific words, making the captioning process far more interpretable than a standard black-box model. """) with gr.Row(): input_img = gr.Image(type="pil", label="Upload Image",value="36384.jpg", scale=0, image_mode="RGB", interactive=True) with gr.Column(): max_tokens = gr.Slider(1, 64, value=32, step=1, label="Max Tokens") alpha = gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Overlay Transparency") run_btn = gr.Button("Generate Caption + Heatmaps", variant="primary") caption_out = gr.Textbox(label="Generated Caption") mode = gr.Radio( choices=["Averaged", "All Heads"], value="Averaged", label="Attention Heads" ) step_slider = gr.Slider( minimum=0, maximum=0, value=0, step=1, label="Token", visible=False, interactive=True ) with gr.Row(): prev_btn = gr.Button("◀ Prev") next_btn = gr.Button("Next ▶") gr.Markdown("## Cross-Attention Heatmap") with gr.Row(): attention_img = gr.Image( label="Averaged Attention Overlay", visible=True, container=False, scale=1 ) attention_label = gr.Textbox( label="Token", interactive=False, elem_classes=["token-box"], scale=1 ) head_gallery = gr.Gallery( label="All Heads", visible=False, columns=6, height="auto" ) avg_state = gr.State() heads_state = gr.State() labels_state = gr.State() max_step_state = gr.State() # Run Rollout run_btn.click( fn=run_full_rollout, inputs=[input_img, max_tokens, alpha], outputs=[ caption_out, attention_img, avg_state, heads_state, labels_state, max_step_state ] ).then( lambda ms: gr.update(visible=True, maximum=ms, value=0), inputs=max_step_state, outputs=step_slider ) # Updates on Step Change step_slider.change( fn=update_display, inputs=[step_slider, mode, avg_state, heads_state, labels_state], outputs=[attention_img, attention_label, head_gallery] ) # Previous button prev_btn.click( step_prev, inputs=step_slider, outputs=step_slider ).then( update_display, inputs=[step_slider, mode, avg_state, heads_state, labels_state], outputs=[attention_img, attention_label, head_gallery] ) # Next button next_btn.click( step_next, inputs=[step_slider, max_step_state], outputs=step_slider ).then( update_display, inputs=[step_slider, mode, avg_state, heads_state, labels_state], outputs=[attention_img, attention_label, head_gallery] ) # Updates on Mode Change mode.change( fn=update_display, inputs=[step_slider, mode, avg_state, heads_state, labels_state], outputs=[attention_img, attention_label, head_gallery] ) demo.load( fn=run_full_rollout, inputs=[input_img, max_tokens, alpha], outputs=[ caption_out, attention_img, avg_state, heads_state, labels_state, max_step_state ] ).then( lambda ms: gr.update(visible=True, maximum=ms, value=0), inputs=max_step_state, outputs=step_slider ) demo.launch()