import os
import streamlit as st
import torch
import matplotlib.pyplot as plt
import numpy as np
import time
import pandas as pd
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
VisionEncoderDecoderModel,
ViTImageProcessor,
AutoTokenizer,
GitProcessor,
GitForCausalLM
)
from PIL import Image
def _get_device() -> torch.device:
    """Choose the torch device for inference.

    Preference order: CUDA if available, then Apple MPS if this torch build
    exposes the backend and reports it available, otherwise CPU.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        mps = getattr(torch.backends, "mps", None)
        backend = "mps" if mps is not None and mps.is_available() else "cpu"
    return torch.device(backend)
# Device shared by every model in this script, chosen once at import time.
device = _get_device()

# Weights are loaded in half precision on accelerators and full precision otherwise.
if device.type in {"cuda", "mps"}:
    _TORCH_DTYPE = torch.float16
else:
    _TORCH_DTYPE = torch.float32
def _resolve_source(local_dir: str, hub_id: str) -> str:
    """Return ``local_dir`` if it names an existing directory, else ``hub_id``.

    An empty (falsy) ``local_dir`` always falls back to ``hub_id``, which is
    expected to be a Hugging Face Hub repo id.
    """
    has_local_copy = bool(local_dir) and os.path.isdir(local_dir)
    return local_dir if has_local_copy else hub_id
# ================================
# EXPERIMENT GRAPH FUNCTIONS
# ================================
def plot_beam_experiment():
    """Build the "Beam Size vs Caption Quality" line chart.

    Plots one line per model from score tables hard-coded in this function
    (the y-axis is labelled "CIDEr Score").

    Returns:
        The matplotlib figure; the caller is responsible for rendering it.
    """
    beam_widths = [1, 3, 5, 10]
    # Legend label -> one score per entry in ``beam_widths``; drawn in this order.
    scores_by_model = {
        "BLIP": [0.52, 0.59, 0.61, 0.60],
        "ViT-GPT2": [0.50, 0.56, 0.60, 0.58],
        "GIT": [0.12, 0.16, 0.17, 0.16],
    }

    figure, axes = plt.subplots(figsize=(10, 6))
    for model_name, scores in scores_by_model.items():
        axes.plot(beam_widths, scores, marker="o", linewidth=3, label=model_name)

    axes.set_xlabel("Beam Size")
    axes.set_ylabel("CIDEr Score")
    axes.set_title("Beam Size vs Caption Quality")
    axes.legend()
    axes.grid(True)
    return figure
def plot_caption_length():
    """Build the "Model Performance vs Caption Length" grouped bar chart.

    Draws three side-by-side bars (one per model) for each caption-length
    category, using score tables hard-coded in this function.

    Returns:
        The matplotlib figure; the caller is responsible for rendering it.
    """
    categories = ["Short", "Medium", "Long"]
    bar_width = 0.25
    positions = np.arange(len(categories))
    # (legend label, scores per category, horizontal shift from the tick position)
    series = (
        ("BLIP", [0.71, 0.60, 0.48], -bar_width),
        ("ViT-GPT2", [0.65, 0.59, 0.42], 0),
        ("GIT", [0.30, 0.18, 0.11], bar_width),
    )

    figure, axes = plt.subplots(figsize=(10, 6))
    for model_name, scores, shift in series:
        axes.bar(positions + shift, scores, bar_width, label=model_name)

    axes.set_xlabel("Caption Length Category")
    axes.set_ylabel("CIDEr Score")
    axes.set_title("Model Performance vs Caption Length")
    axes.set_xticks(positions)
    axes.set_xticklabels(categories)
    axes.legend()
    return figure
# ================================
# UI STYLE
# ================================
# Emit a raw-HTML markdown element intended to carry page-level styling.
# NOTE(review): the payload is only a newline here — presumably the CSS was
# lost somewhere upstream; confirm against the original app before editing.
st.markdown("""
""", unsafe_allow_html=True)
# ================================
# LOAD MODELS
# ================================
@st.cache_resource
def load_blip():
    """Load the BLIP captioning model and its processor.

    The checkpoint is the directory named by ``BLIP_LOCAL_DIR`` (default
    ``saved_model_phase2``) when that directory exists, otherwise the Hub repo
    named by ``BLIP_MODEL_ID``. The result is cached via ``st.cache_resource``.

    Returns:
        ``(model, processor)`` with the model moved to ``device`` and put in
        eval mode.
    """
    local_dir = os.getenv("BLIP_LOCAL_DIR", "saved_model_phase2")
    hub_id = os.getenv("BLIP_MODEL_ID", "pchandragrid/blip-caption-model")
    checkpoint = _resolve_source(local_dir, hub_id)

    model = BlipForConditionalGeneration.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=_TORCH_DTYPE,
    )
    processor = BlipProcessor.from_pretrained(checkpoint)

    model.to(device)
    model.eval()
    return model, processor
@st.cache_resource
def load_vit_gpt2():
    """Load the ViT-GPT2 captioning model with its image processor and tokenizer.

    The checkpoint is the directory named by ``VITGPT2_LOCAL_DIR`` (default
    ``saved_vit_gpt2``) when that directory exists, otherwise the Hub repo
    named by ``VITGPT2_MODEL_ID``. The result is cached via
    ``st.cache_resource``.

    Returns:
        ``(model, processor, tokenizer)`` with the model moved to ``device``
        and put in eval mode.
    """
    local_dir = os.getenv("VITGPT2_LOCAL_DIR", "saved_vit_gpt2")
    hub_id = os.getenv("VITGPT2_MODEL_ID", "pchandragrid/vit-gpt2-caption-model")
    checkpoint = _resolve_source(local_dir, hub_id)

    model = VisionEncoderDecoderModel.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=_TORCH_DTYPE,
    )
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    model.to(device)
    model.eval()
    return model, processor, tokenizer
@st.cache_resource
def load_git():
    """Load the GIT captioning model and its processor.

    The checkpoint is the directory named by ``GIT_LOCAL_DIR`` (default
    ``saved_git_model``) when that directory exists, otherwise the Hub repo
    named by ``GIT_MODEL_ID``. The result is cached via ``st.cache_resource``.

    Returns:
        ``(model, processor)`` with the model moved to ``device`` and put in
        eval mode.
    """
    local_dir = os.getenv("GIT_LOCAL_DIR", "saved_git_model")
    hub_id = os.getenv("GIT_MODEL_ID", "pchandragrid/git-caption-model")
    checkpoint = _resolve_source(local_dir, hub_id)

    # The processor is loaded before the model, as in the other direction for BLIP.
    processor = GitProcessor.from_pretrained(checkpoint)
    model = GitForCausalLM.from_pretrained(
        checkpoint,
        low_cpu_mem_usage=True,
        torch_dtype=_TORCH_DTYPE,
    )

    model.to(device)
    model.eval()
    return model, processor
# ================================
# HEADER
# ================================
# Page title and subtitle.
# The literals previously used single quotes across several physical lines,
# which Python rejects as an unterminated string; the newlines are now written
# as escapes so the rendered text is unchanged.
# NOTE(review): both calls pass unsafe_allow_html=True but carry no markup —
# presumably wrapper tags were lost upstream; confirm against the original app.
st.markdown(
    "\n🖼️ Image Captioning\n",
    unsafe_allow_html=True,
)
st.markdown(
    "Compare BLIP vs ViT-GPT2 vs GIT on the same image\n",
    unsafe_allow_html=True,
)
# ================================
# SIDEBAR
# ================================
# Sidebar: which models to run and the generation settings read by the
# caption-generation section further down.
with st.sidebar:
    st.header("⚙️ Generation Settings")
    st.subheader("Models to run")
    run_blip = st.checkbox("BLIP", value=True)
    run_vit = st.checkbox("ViT-GPT2", value=False)
    run_git = st.checkbox("GIT", value=False)
    num_beams = st.slider("Beam Size", min_value=1, max_value=10, value=5)
    max_length = st.slider("Max Length", min_value=10, max_value=50, value=20)
    length_penalty = st.slider(
        "Length Penalty", min_value=0.5, max_value=2.0, value=1.0, step=0.1
    )

# The uploader lives in the main page area, not in the sidebar.
uploaded_file = st.file_uploader("Upload Image", type=["jpg", "png", "jpeg"])
# ================================
# IMAGE DISPLAY
# ================================
# Main workflow: show the uploaded image, run the selected captioning models
# on demand, lay the captions out side by side, and (for BLIP) draw a heatmap
# of the vision encoder's attention.
if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.markdown(
        "\nUploaded Image\n",
        unsafe_allow_html=True,
    )
    st.image(image, use_container_width=True)

    if st.button("Generate Captions"):
        # Validate the selection up front so nothing is loaded needlessly.
        if not (run_blip or run_vit or run_git):
            st.warning("Select at least one model in the sidebar.")
            st.stop()

        # The sidebar settings apply to every model; previously the length
        # penalty was only forwarded to BLIP.
        generation_kwargs = {
            "num_beams": num_beams,
            "max_length": max_length,
            "length_penalty": length_penalty,
        }

        results = []  # (model name, caption, elapsed seconds)
        blip_pixel_values = None  # kept for the attention heatmap below

        with st.spinner("Running models..."):
            if run_blip:
                blip_model, blip_processor = load_blip()
                # perf_counter is monotonic, unlike wall-clock time.time().
                start = time.perf_counter()
                blip_inputs = blip_processor(images=image, return_tensors="pt")
                # The models are loaded in _TORCH_DTYPE (float16 on cuda/mps),
                # while processors emit float32; cast so the dtypes agree.
                blip_pixel_values = blip_inputs["pixel_values"].to(
                    device=device, dtype=_TORCH_DTYPE
                )
                with torch.no_grad():
                    blip_ids = blip_model.generate(
                        pixel_values=blip_pixel_values,
                        **generation_kwargs,
                    )
                blip_caption = blip_processor.decode(
                    blip_ids[0], skip_special_tokens=True
                )
                results.append(("BLIP", blip_caption, time.perf_counter() - start))

            if run_vit:
                vit_model, vit_processor, vit_tokenizer = load_vit_gpt2()
                start = time.perf_counter()
                vit_inputs = vit_processor(images=image, return_tensors="pt")
                vit_pixel_values = vit_inputs["pixel_values"].to(
                    device=device, dtype=_TORCH_DTYPE
                )
                with torch.no_grad():
                    vit_ids = vit_model.generate(
                        pixel_values=vit_pixel_values,
                        **generation_kwargs,
                    )
                vit_caption = vit_tokenizer.decode(
                    vit_ids[0], skip_special_tokens=True
                )
                results.append(("ViT-GPT2", vit_caption, time.perf_counter() - start))

            if run_git:
                git_model, git_processor = load_git()
                start = time.perf_counter()
                git_inputs = git_processor(images=image, return_tensors="pt")
                git_pixel_values = git_inputs["pixel_values"].to(
                    device=device, dtype=_TORCH_DTYPE
                )
                with torch.no_grad():
                    git_ids = git_model.generate(
                        pixel_values=git_pixel_values,
                        **generation_kwargs,
                    )
                git_caption = git_processor.batch_decode(
                    git_ids, skip_special_tokens=True
                )[0]
                results.append(("GIT", git_caption, time.perf_counter() - start))

        st.divider()
        st.subheader("Model Comparison")
        st.markdown(
            "\n"
            "Each model generates a caption describing the uploaded image.\n"
            "This comparison highlights differences in:\n"
            "• caption quality\n"
            "• inference speed\n"
            "• architectural design\n"
        )

        # One column per model that actually ran.
        cols = st.columns(len(results))
        for col, (name, caption, seconds) in zip(cols, results):
            with col:
                st.markdown(f"{name}\n", unsafe_allow_html=True)
                st.markdown(f"{caption}\n", unsafe_allow_html=True)
                st.caption(f"Inference: {seconds:.2f}s")

        st.divider()

        # ================================
        # ATTENTION HEATMAP
        # ================================
        if blip_pixel_values is not None:
            blip_model, _ = load_blip()
            with torch.no_grad():
                vision_outputs = blip_model.vision_model(
                    blip_pixel_values,
                    output_attentions=True,
                    return_dict=True,
                )

            # Last entry of the returned attentions, first image in the batch,
            # averaged over its leading axis (presumably heads — confirm).
            last_layer = vision_outputs.attentions[-1]
            head_mean = last_layer[0].mean(0)
            # Row 0 without its first column; presumably CLS-to-patch weights.
            cls_attn = head_mean[0, 1:]
            # Upcast before leaving torch so half-precision runs normalise in float32.
            attn_map = cls_attn.float().cpu().numpy()
            peak = attn_map.max()
            if peak > 0:  # avoid dividing by zero on a degenerate map
                attn_map = attn_map / peak

            # Assumes the number of patches is a perfect square.
            size = int(np.sqrt(len(attn_map)))
            fig, ax = plt.subplots(figsize=(6, 6))
            ax.imshow(attn_map.reshape(size, size), cmap="viridis")
            ax.set_title("BLIP Vision Attention")
            ax.axis("off")
            st.pyplot(fig, use_container_width=True)
            # pyplot keeps every figure alive until closed; release it so
            # repeated reruns of the script do not accumulate figures.
            plt.close(fig)

            st.markdown(
                "\n"
                "### 🔍 Attention Visualization\n"
                "The attention heatmap highlights **which regions of the image the model focused on while generating the caption**.\n"
                "Brighter regions indicate higher importance for the caption generation process.\n"
            )
# ================================
# ARCHITECTURE COMPARISON TABLE
# ================================
# Static reference section: a model comparison table and two experiment charts.
st.divider()
tab1, tab2 = st.tabs(["📊 Model Architecture Comparison", "📊 Experiment Analysis"])

with tab1:
    st.header("Model Architecture Comparison")
    # Hard-coded summary table, one row per model.
    data = {
        "Model": ["BLIP", "ViT-GPT2", "GIT"],
        "Architecture": [
            "Vision Transformer + Text Decoder",
            "ViT Encoder + GPT2 Decoder",
            "Unified Transformer",
        ],
        "Parameters": ["~224M", "~210M", "~150M"],
        "Training Time": ["~1h 34m / epoch", "~1h 20m / epoch", "~11 min / epoch"],
        "CIDEr Score": ["0.61", "0.60", "0.17"],
    }
    df = pd.DataFrame(data)
    st.table(df)

with tab2:
    st.header("Experiment Analysis")

    st.subheader("Beam Size vs Caption Quality")
    fig1 = plot_beam_experiment()
    st.pyplot(fig1, use_container_width=True)
    # This section runs on every script rerun; close each figure once rendered
    # so pyplot's global registry does not grow without bound.
    plt.close(fig1)
    st.markdown(
        "\n"
        "Beam search controls how many candidate captions are explored during generation.\n"
        "Increasing beam size improves caption quality initially but eventually leads to diminishing returns.\n"
    )

    st.divider()

    st.subheader("Caption Length vs Model Performance")
    fig2 = plot_caption_length()
    st.pyplot(fig2, use_container_width=True)
    plt.close(fig2)
    st.markdown(
        "\n"
        "Caption length impacts performance because longer captions require more detailed reasoning about the scene.\n"
        "Models generally perform better on shorter captions.\n"
    )