"""Gradio demo: caption an image and visualize per-token cross-attention heatmaps."""
import gradio as gr
import torch

from src.utils import load_experiment
from demo.interpretability_demo import generate_rollout_for_demo
from data.transforms import build_coco_transform

# ---------------------------------------------------------------------------
# Load model + preprocessing (executed once, at module import time)
# ---------------------------------------------------------------------------
CHECKPOINT = "experiment_21_full_vlm_ceiling/20251127_231732"

device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer, meta, config = load_experiment(CHECKPOINT, device=device)

image_size = config["model"]["image_size"]
preprocess = build_coco_transform(image_size)


# ---------------------------------------------------------------------------
# Step navigation helpers
# ---------------------------------------------------------------------------
def step_prev(step):
    """Return the previous token index, never going below 0."""
    return max(int(step) - 1, 0)


def step_next(step, max_step):
    """Return the next token index, capped at ``max_step``.

    ``max_step`` comes from a ``gr.State`` that is ``None`` until a rollout
    has been generated; in that case the step is returned unchanged instead
    of raising ``TypeError`` from ``min(int, None)``.
    """
    if max_step is None:
        return int(step)
    return min(int(step) + 1, max_step)


# ---------------------------------------------------------------------------
# Backend logic
# ---------------------------------------------------------------------------
def run_full_rollout(img, max_tokens, alpha):
    """Caption ``img`` and collect its cross-attention rollout frames.

    Args:
        img: PIL image from the ``gr.Image`` input (``None`` if cleared).
        max_tokens: Maximum number of tokens to generate.
        alpha: Overlay transparency forwarded to the rollout renderer.

    Returns:
        A 6-tuple matching the outputs wired below:
        ``(caption, first_averaged_frame, averaged_frames, per_head_frames,
        token_labels, max_step)``. When no frames are produced, the frame and
        label slots are ``None`` and ``max_step`` is 0.
    """
    if img is None:
        # Nothing to caption (image cleared); keep the output arity intact.
        return "", None, None, None, None, 0

    data = generate_rollout_for_demo(
        model,
        tokenizer,
        img,
        preprocess,
        device=device,
        max_new_tokens=max_tokens,
        alpha=alpha,
    )

    caption = data["caption"]
    avg_rollout = data["avg"]["frames"]
    heads_rollout = data["heads"]["frames"]
    labels = data["avg"]["labels"]

    if len(avg_rollout) == 0:
        # One value per output component (6) -- returning fewer makes Gradio
        # fail with a mismatched-outputs error.
        return caption, None, None, None, None, 0

    max_step = len(avg_rollout) - 1
    return caption, avg_rollout[0], avg_rollout, heads_rollout, labels, max_step


def update_display(step, mode, avg_rollout, heads_rollout, labels):
    """Show the heatmap(s) and token label for ``step`` in the chosen ``mode``.

    Returns updates for ``(attention_img, attention_label, head_gallery)``:
    in "Averaged" mode the single overlay image is shown and the gallery
    hidden; otherwise ("All Heads") the gallery shows ``heads_rollout[step]``
    and the single image is hidden.
    """
    if avg_rollout is None:
        # No rollout yet (or it was empty): clear the image, hide the gallery.
        return gr.update(visible=True, value=None), "", gr.update(visible=False)

    # Clamp to a valid frame index; the slider value may arrive as a float.
    step = max(0, min(int(step), len(avg_rollout) - 1))
    label = labels[step]

    if mode == "Averaged":
        return (
            gr.update(visible=True, value=avg_rollout[step]),  # show averaged
            label,
            gr.update(visible=False),  # hide gallery
        )

    # "All Heads" mode: heads_rollout[step] is the list of per-head frames.
    return (
        gr.update(visible=False),  # hide averaged
        label,
        gr.update(visible=True, value=heads_rollout[step]),  # show gallery
    )


def _reset_step_slider(max_step):
    """Reveal the token slider, resized to the new rollout and reset to 0."""
    return gr.update(visible=True, maximum=max_step, value=0)


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Team Coco — Image Captioning + Cross-Attention Viz",
    css="""
.token-box textarea {
    font-size: 22px !important;
    line-height: 1.5 !important;
    height: 70px !important;
    width: 200px !important;
}
""",
) as demo:
    gr.Markdown("## Image Captioning + Cross-Attention Visualization")

    with gr.Row():
        input_img = gr.Image(
            type="pil",
            label="Upload Image",
            value="demo/36384.jpg",
            scale=0,
            image_mode="RGB",
            interactive=True,
        )

        with gr.Column():
            max_tokens = gr.Slider(1, 64, value=32, step=1, label="Max Tokens")
            alpha = gr.Slider(
                0.0, 1.0, value=0.5, step=0.05, label="Overlay Transparency"
            )
            run_btn = gr.Button("Generate Caption + Heatmaps", variant="primary")
            caption_out = gr.Textbox(label="Generated Caption")
            mode = gr.Radio(
                choices=["Averaged", "All Heads"],
                value="Averaged",
                label="Attention Heads",
            )
            step_slider = gr.Slider(
                minimum=0,
                maximum=0,
                value=0,
                step=1,
                label="Token",
                visible=False,
                interactive=True,
            )
            with gr.Row():
                prev_btn = gr.Button("◀ Prev")
                next_btn = gr.Button("Next ▶")

    gr.Markdown("## Cross-Attention Heatmap")

    with gr.Row():
        attention_img = gr.Image(
            label="Averaged Attention Overlay",
            visible=True,
            container=False,
            scale=1,
        )
        attention_label = gr.Textbox(
            label="Token",
            interactive=False,
            elem_classes=["token-box"],
            scale=1,
        )

    head_gallery = gr.Gallery(
        label="All Heads",
        visible=False,
        columns=6,
        height="auto",
    )

    # Per-session rollout results shared between event handlers.
    avg_state = gr.State()
    heads_state = gr.State()
    labels_state = gr.State()
    max_step_state = gr.State()

    rollout_outputs = [
        caption_out,
        attention_img,
        avg_state,
        heads_state,
        labels_state,
        max_step_state,
    ]
    display_inputs = [step_slider, mode, avg_state, heads_state, labels_state]
    display_outputs = [attention_img, attention_label, head_gallery]

    def _wire_rollout(event_listener):
        """Attach the run -> reset-slider -> refresh-display chain to a trigger.

        The explicit final ``update_display`` step is needed because resetting
        the slider to 0 when it is already 0 may not emit ``change``, which
        would leave the token label / head gallery stale.
        """
        (
            event_listener(
                fn=run_full_rollout,
                inputs=[input_img, max_tokens, alpha],
                outputs=rollout_outputs,
            )
            .then(_reset_step_slider, inputs=max_step_state, outputs=step_slider)
            .then(update_display, inputs=display_inputs, outputs=display_outputs)
        )

    # Run rollout on button click and once when the page loads.
    _wire_rollout(run_btn.click)
    _wire_rollout(demo.load)

    # Updates on step change
    step_slider.change(
        fn=update_display, inputs=display_inputs, outputs=display_outputs
    )
    prev_btn.click(step_prev, inputs=step_slider, outputs=step_slider)
    next_btn.click(
        step_next, inputs=[step_slider, max_step_state], outputs=step_slider
    )

    # Updates on mode change
    mode.change(fn=update_display, inputs=display_inputs, outputs=display_outputs)

# share=True additionally requests a public Gradio share link.
demo.launch(share=True)