import os
import sys
import io
import json
import random
import datetime
import numpy as np
os.environ.setdefault("STREAMLIT_SERVER_FILE_WATCHER_TYPE", "none")
import streamlit as st
from PIL import Image, ImageDraw, ImageFilter
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from tensorflow.keras.models import Model as KerasModel
import keras.src.ops.operation as _keras_op
from keras.src.ops.numpy import NotEqual as _NotEqual
# Keep the undecorated original so the replacement classmethod can delegate to it.
_orig_from_config = _keras_op.Operation.from_config.__func__
@classmethod
def _patched_from_config(cls, config):
    """Remove the ``quantization_config`` key from *config*, then call the original ``from_config``."""
    # NOTE(review): presumably the saved models carry a key this Keras build rejects — confirm.
    # The key is popped in place, so the caller's dict is mutated.
    config.pop("quantization_config", None)
    return _orig_from_config(cls, config)
# Install the replacement on the Operation class itself, so it applies to every lookup through that class.
_keras_op.Operation.from_config = _patched_from_config
from src.model import BahdanauAttention
_CUSTOM_OBJECTS = {"NotEqual": _NotEqual, "BahdanauAttention": BahdanauAttention}
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.config import (
IMAGE_SIZE, BEAM_WIDTH,
VGG16_MODEL_FILE, VGG19_MODEL_FILE,
TOKENIZER_FILE, BLEU_RESULTS_FILE,
VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
FLICKR_IMAGES_DIR, CAPTIONS_FILE,
START_TOKEN, END_TOKEN,
)
from src.utils import load_tokenizer, load_features, load_captions
from src.engine import CaptionEngine
# ─── Page config ────────────────────────────────────────────────────────────
# Set the browser-tab title/icon and use the wide layout with the sidebar open.
st.set_page_config(
    page_title="CaptionIQ — AI Image Captioning",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ─── Premium CSS ─────────────────────────────────────────────────────────────
# NOTE(review): the markdown payload is empty in this view; presumably a <style> block belongs here — confirm.
st.markdown("""
""", unsafe_allow_html=True)
# ─── Cached resources ────────────────────────────────────────────────────────
@st.cache_resource
def get_caption_engine():
    """Construct the CaptionEngine; wrapped in ``st.cache_resource`` so the instance is reused."""
    caption_engine = CaptionEngine()
    return caption_engine
@st.cache_resource
def load_feature_extractor(backbone: str):
    """Return ``(model, preprocess_fn)`` for the requested VGG backbone.

    ``"vgg16"`` selects VGG16; any other value selects VGG19. The returned
    model maps the base network's input to the output of its ``block5_pool``
    layer, and ``preprocess_fn`` is the matching Keras preprocessing function.
    """
    use_vgg16 = backbone == "vgg16"
    network_cls = VGG16 if use_vgg16 else VGG19
    preprocess_fn = vgg16_preprocess if use_vgg16 else vgg19_preprocess
    base_model = network_cls(weights="imagenet")
    # Cut the network at the last pooling layer.
    pooled_output = base_model.get_layer("block5_pool").output
    extractor = KerasModel(inputs=base_model.input, outputs=pooled_output)
    return extractor, preprocess_fn
@st.cache_resource
def get_tokenizer():
    """Load the tokenizer from ``TOKENIZER_FILE``; return ``None`` when that path does not exist."""
    if os.path.exists(TOKENIZER_FILE):
        return load_tokenizer(TOKENIZER_FILE)
    return None
@st.cache_data
def get_bleu_results():
    """Return the parsed contents of ``BLEU_RESULTS_FILE``, or ``None`` if the file is missing.

    The file is read as UTF-8 explicitly so the result does not depend on the
    platform's default encoding. Malformed JSON still raises, as before.
    """
    # EAFP: open directly instead of checking existence first, which could
    # race with the file being removed between the check and the open.
    try:
        with open(BLEU_RESULTS_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return None
@st.cache_data
def get_demo_images():
    """Return the sorted file names in ``FLICKR_IMAGES_DIR`` that end in .jpg/.jpeg/.png.

    The suffix match is case-insensitive. An empty list is returned when the
    directory path does not exist.
    """
    if not os.path.exists(FLICKR_IMAGES_DIR):
        return []
    image_suffixes = (".jpg", ".jpeg", ".png")
    names = []
    for entry in os.listdir(FLICKR_IMAGES_DIR):
        if entry.lower().endswith(image_suffixes):
            names.append(entry)
    names.sort()
    return names
@st.cache_data
def get_all_captions():
    """Return ``load_captions(CAPTIONS_FILE)``, or an empty dict when the file path does not exist."""
    if os.path.exists(CAPTIONS_FILE):
        return load_captions(CAPTIONS_FILE)
    return {}
@st.cache_data
def get_caption_memory(backbone: str):
    """Build a lookup of unit-length feature vectors for images that have captions.

    ``"vgg16"`` selects ``VGG16_FEATURES_FILE``; any other value selects
    ``VGG19_FEATURES_FILE``. Returns ``None`` when either the features file or
    the captions file is missing, or when no image id appears in both.
    Otherwise returns a dict with:

    * ``"ids"``      – image ids, in the iteration order of the features mapping
    * ``"matrix"``   – the stacked vectors, one row per id
    * ``"captions"`` – the loaded captions object
    """
    if backbone == "vgg16":
        features_file = VGG16_FEATURES_FILE
    else:
        features_file = VGG19_FEATURES_FILE
    if not (os.path.exists(features_file) and os.path.exists(CAPTIONS_FILE)):
        return None

    features = load_features(features_file)
    captions = load_captions(CAPTIONS_FILE)

    kept_ids = []
    unit_vectors = []
    for image_id, feat in features.items():
        if image_id in captions:
            # Average over the first axis, then scale to unit length; the
            # epsilon keeps an all-zero vector from dividing by zero.
            pooled = np.asarray(feat, dtype=np.float32).mean(axis=0)
            scale = np.linalg.norm(pooled) + 1e-10
            kept_ids.append(image_id)
            unit_vectors.append(pooled / scale)

    if len(unit_vectors) == 0:
        return None
    return {"ids": kept_ids, "matrix": np.stack(unit_vectors, axis=0), "captions": captions}
# ─── Helpers ─────────────────────────────────────────────────────────────────
def clean_caption_text(caption: str) -> str:
    """Strip every ``START_TOKEN`` and ``END_TOKEN`` occurrence, then trim surrounding whitespace."""
    without_start = caption.replace(START_TOKEN, "")
    without_markers = without_start.replace(END_TOKEN, "")
    return without_markers.strip()
def extract_feature(image: Image.Image, backbone: str) -> np.ndarray:
    """Run *image* through the selected backbone and return its flattened feature map.

    The image is converted to RGB, resized to ``IMAGE_SIZE`` × ``IMAGE_SIZE``,
    preprocessed with the backbone's own function and passed through the
    extractor as a batch of one. The resulting ``(h, w, c)`` map is reshaped
    to ``(h * w, c)``.
    """
    extractor, preprocess_fn = load_feature_extractor(backbone)
    # Convert before resizing: Pillow always uses nearest-neighbour resampling
    # for "1" and "P" mode images, so resizing first would degrade palette
    # PNGs. For images that are already RGB this order gives the same result.
    image = image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    img_array = np.expand_dims(img_to_array(image), axis=0)
    feature = extractor.predict(preprocess_fn(img_array), verbose=0)[0]
    h, w, c = feature.shape
    return feature.reshape(h * w, c)
def build_attention_overlay(image: Image.Image, attn_7x7: np.ndarray, alpha: float = 0.55) -> Image.Image:
    """Blend a viridis-like heatmap over the image for the given attention map.

    Args:
        image: Source picture; it is converted to RGBA and resized to 336×336.
        attn_7x7: Attention weights. A 2-D array of any size is used as the
            heat grid; any other shape is flattened and laid out on a 7×7 grid,
            as before. Weights are clamped to ``[0, 1]`` and NaN is treated as 0.
        alpha: Maximum heat opacity, reached where the weight is 1.

    Returns:
        The blended picture as a 336×336 RGB image.
    """
    import colorsys
    img = image.convert("RGBA").resize((336, 336))
    w, h = img.size

    # Clamp so out-of-range weights cannot yield a negative hue or an alpha
    # outside 0..255.
    weights = np.clip(np.nan_to_num(np.asarray(attn_7x7, dtype=float), nan=0.0), 0.0, 1.0)
    if weights.ndim == 2:
        grid_h, grid_w = weights.shape
    else:
        grid_h = grid_w = 7

    heat = Image.new("RGBA", (grid_w, grid_h))
    pixels = []
    for val in weights.flatten():
        val = float(val)
        # Viridis-like: low=blue, high=yellow
        hue = 0.65 - val * 0.65
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        a = max(0, min(255, int(alpha * 255 * val)))
        pixels.append((int(r * 255), int(g * 255), int(b * 255), a))
    heat.putdata(pixels)
    # Upscale the coarse grid to the image size, then blur to soften cell edges.
    heat = heat.resize((w, h), Image.BILINEAR)
    heat = heat.filter(ImageFilter.GaussianBlur(radius=20))
    result = Image.alpha_composite(img, heat)
    return result.convert("RGB")
def make_word_cloud_image(captions: list, width: int = 480, height: int = 200) -> "Image.Image | None":
    """Generate word-cloud from caption candidates.

    Args:
        captions: Sequence of ``(caption_text, score)`` pairs; only the text is used.
        width: Cloud width in pixels.
        height: Cloud height in pixels.

    Returns:
        A PIL image of the rendered cloud, or ``None`` when *captions* is empty
        or rendering fails for any reason (for example ``wordcloud`` or
        ``matplotlib`` is not installed).
    """
    if not captions:
        return None
    fig = None
    try:
        from wordcloud import WordCloud
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        # Each caption is repeated three times before joining, as before.
        text = " ".join([cap for cap, _ in captions] * 3)
        wc = WordCloud(
            width=width, height=height,
            background_color=None, mode="RGBA",
            colormap="cool",
            max_words=40,
            prefer_horizontal=0.85,
            font_path=None,
        ).generate(text)
        buf = io.BytesIO()
        fig, ax = plt.subplots(figsize=(width/100, height/100), dpi=100)
        ax.imshow(wc, interpolation="bilinear")
        ax.axis("off")
        fig.patch.set_alpha(0)
        plt.tight_layout(pad=0)
        fig.savefig(buf, format="png", bbox_inches="tight",
                    facecolor="none", edgecolor="none")
        buf.seek(0)
        return Image.open(buf)
    except Exception:
        # Deliberate best-effort: the caller treats None as "unavailable".
        return None
    finally:
        # Close the figure on failure paths too, so it does not leak.
        if fig is not None:
            plt.close(fig)
def make_word_freq_bar(captions: list) -> None:
    """Fallback: bar chart of word frequencies.

    Counts lower-cased words across all caption texts, skipping stop words and
    words of two characters or fewer, then charts the twelve most common.
    Nothing is drawn when no word survives the filter.
    """
    import pandas as pd
    from collections import Counter
    stop = {"a","an","the","is","are","in","on","of","to","and","or","with","at","by"}
    tally = Counter(
        token.lower()
        for text, _score in captions
        for token in text.split()
        if token.lower() not in stop and len(token) > 2
    )
    top_words = tally.most_common(12)
    if top_words:
        frame = pd.DataFrame(top_words, columns=["word", "count"])
        st.bar_chart(frame.set_index("word"), color="#667eea")
def display_step_progress(step: int):
    """step: 0=extracting, 1=attending, 2=decoding, 3=done

    Builds one HTML string covering the three stage labels and renders it with
    ``st.markdown(..., unsafe_allow_html=True)``.
    """
    steps = ["🔍 Extracting", "🎯 Attending", "✍️ Decoding"]
    # NOTE(review): the markup inside the string literals below is missing in this view;
    # the lines are kept exactly as found — restore the HTML before relying on this function.
    html = '
'
    for i, label in enumerate(steps):
        # Stages before `step` are tagged "done", the stage equal to `step` "active", later ones "step".
        cls = "done" if i < step else ("active" if i == step else "step")
        html += f'
{label}
'
    html += "
"
    st.markdown(html, unsafe_allow_html=True)
def display_captions(captions: list):
    """Render each ``(caption, score)`` pair as a ranked markdown card.

    The first three entries get a medal icon and a named rank; later entries
    fall back to a ``▪`` icon and a ``#N`` label. The score is shown as a
    percentage with one decimal.
    """
    medals = ("🥇", "🥈", "🥉")
    titles = ("Top Caption", "Runner-up", "Third")
    for i, (caption, score) in enumerate(captions):
        if i < len(titles):
            label = titles[i]
        else:
            label = f"#{i+1}"
        if i < len(medals):
            icon = medals[i]
        else:
            icon = "▪"
        # Integer percentage of the score; not referenced by the markup visible here.
        bar_w = int(score * 100)
        st.markdown(f"""
{icon} {label}
{caption}
Confidence
{score:.1%}
""", unsafe_allow_html=True)
def display_bleu_scores(bleu_data: dict, backbone: str):
    """Render the metrics stored under ``bleu_data[backbone]`` in a four-column row.

    Shows an info message and returns when *bleu_data* is empty or has no
    entry for *backbone*. Metrics beyond the fourth wrap around to the first
    column instead of raising ``IndexError``.
    """
    if not bleu_data or backbone not in bleu_data:
        st.info("BLEU scores not available for this backbone. Run evaluation first.")
        return
    scores = bleu_data[backbone]
    cols = st.columns(4)
    for i, (metric, value) in enumerate(scores.items()):
        # Wrap with modulo, as the gallery grid does, so a results file with
        # more than four metrics cannot index past the last column.
        with cols[i % len(cols)]:
            st.markdown(f"""
""", unsafe_allow_html=True)
def add_to_history(caption: str, confidence: float, model_used: str, image_name: str):
    """Insert a history record at the front of ``st.session_state["caption_history"]``.

    The list is created first if the key is absent. Each record stores the
    current local time as ``HH:MM:SS`` together with the four arguments.
    """
    record = {
        "timestamp": datetime.datetime.now().strftime("%H:%M:%S"),
        "caption": caption,
        "confidence": confidence,
        "model": model_used,
        "image": image_name,
    }
    if "caption_history" not in st.session_state:
        st.session_state["caption_history"] = []
    # Newest first.
    st.session_state["caption_history"].insert(0, record)
# ─── Session state init ───────────────────────────────────────────────────────
# Seed each session-state slot the script reads later, without touching
# values that already exist from an earlier run of the script.
for _state_key, _state_default in (
    ("caption_history", []),
    ("surprise_image", None),
    ("last_result", None),
    ("last_attn", []),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ─── Sidebar ────────────────────────────────────────────────────────────────
# Sidebar: backbone/mode selection, display toggles, static model notes,
# BLEU tables and the caption history with export/clear controls.
with st.sidebar:
    st.markdown("## ⚙️ Settings")
    # The raw option value ("ensemble" / "vgg19" / "vgg16") is what the rest of
    # the script branches on; format_func only changes the displayed label.
    backbone = st.radio(
        "VGG Backbone",
        ["ensemble", "vgg19", "vgg16"],
        index=0,
        format_func=lambda x: (
            "⚡ ENSEMBLE (VGG16 + VGG19)" if x == "ensemble" else
            "🔷 VGG19" if x == "vgg19" else "🔶 VGG16"
        ),
        help="Ensemble averages both backbone predictions for richer captions."
    )
    st.divider()
    # Selects which of the three page sections further down runs.
    app_mode = st.selectbox(
        "Mode",
        ["Standard", "Compare Models", "Demo Gallery"],
        help=(
            "Standard: upload & caption | "
            "Compare Models: run all 3 backbones simultaneously | "
            "Demo Gallery: browse Flickr8K samples"
        )
    )
    show_heatmap = st.toggle(
        "Attention Heatmap",
        value=True,
        help="Show which image regions the model attended to per word"
    )
    show_wordcloud = st.toggle(
        "Word Cloud",
        value=True,
        help="Visualise word distribution across beam candidates"
    )
    st.divider()
    # Static description text; nothing here is computed from the loaded models.
    with st.expander("🧠 Model Architecture", expanded=False):
        st.markdown("""
**CaptionIQ** — CNN + Attention LSTM
**1. Feature Extraction**
- VGG16/VGG19 backbone (ImageNet)
- `block5_pool` → 7×7×512 spatial map
- Reshaped to 49×512 region tokens
**2. Attention Decoder**
- Bahdanau attention over 49 regions
- Word embedding layer (256-dim)
- LSTM with 512 hidden units
- Dropout 0.3 for regularisation
**3. Inference**
- Ensemble beam search (width=5)
- Length-penalty + repetition control
- Softmax-normalised confidence scores
**Dataset**: Flickr8K (8 000 images, 5 captions each)
""")
    bleu_data = get_bleu_results()
    if bleu_data:
        with st.expander("📊 BLEU Evaluation", expanded=False):
            # Per-backbone scores only when the results contain the selected key.
            if backbone in bleu_data:
                display_bleu_scores(bleu_data, backbone)
            # Side-by-side table needs both backbones present.
            if "vgg16" in bleu_data and "vgg19" in bleu_data:
                st.markdown("---")
                st.caption("VGG16 vs VGG19")
                import pandas as pd
                # Transposed so each backbone becomes a row.
                df = pd.DataFrame({
                    "VGG16": bleu_data["vgg16"],
                    "VGG19": bleu_data["vgg19"],
                }).T
                st.dataframe(df, use_container_width=True)
    if st.session_state["caption_history"]:
        st.divider()
        with st.expander(f"📋 History ({len(st.session_state['caption_history'])})", expanded=False):
            # NOTE(review): entry['image'] can be an uploaded file name and is rendered with
            # unsafe_allow_html=True — consider escaping it.
            for entry in st.session_state["caption_history"]:
                st.markdown(f"""
{entry['timestamp']} — {entry['image']}
{entry['caption']}
{entry['model']} · {entry['confidence']:.1%}
""", unsafe_allow_html=True)
            col_a, col_b, col_c = st.columns(3)
            with col_a:
                # Export the history list as pretty-printed JSON.
                hist_json = json.dumps(st.session_state["caption_history"], indent=2)
                st.download_button(
                    "⬇ JSON", data=hist_json,
                    file_name="captioniq_history.json", mime="application/json"
                )
            with col_b:
                import csv
                # Build the CSV in memory; the field names match the keys written by add_to_history.
                buf = io.StringIO()
                writer = csv.DictWriter(buf, fieldnames=["timestamp","image","caption","confidence","model"])
                writer.writeheader()
                writer.writerows(st.session_state["caption_history"])
                st.download_button(
                    "⬇ CSV", data=buf.getvalue(),
                    file_name="captioniq_history.csv", mime="text/csv"
                )
            with col_c:
                if st.button("🗑 Clear"):
                    st.session_state["caption_history"] = []
                    # Rerun right away so the emptied history is reflected in the sidebar.
                    st.rerun()
# ─── Hero ────────────────────────────────────────────────────────────────────
# Page header text.
st.markdown("""
✨ AI — Vision + Language
CaptionIQ
Attention-powered image captioning with VGG + LSTM.
Upload an image and watch the AI describe what it sees.
""", unsafe_allow_html=True)
# ─── Ready check ─────────────────────────────────────────────────────────────
tokenizer = get_tokenizer()
engine = get_caption_engine()
# 'ensemble' in the UI means BLIP; only need a single VGG backbone ready.
# NOTE(review): the sidebar help text describes ensemble as averaging VGG16 and VGG19 — confirm which is accurate.
_check_backbone = "vgg16" if backbone == "ensemble" else backbone
ready, ready_error = engine.is_ready(_check_backbone)
# Halt the script with an error when the tokenizer file is missing or the engine reports not ready.
if (tokenizer is None) or (not ready):
    st.error(
        "Caption backend is not ready.\n\n"
        f"{ready_error or 'Tokenizer not found.'}\n\n"
        "```bash\npython src/train.py --backbone both --resume\n```"
    )
    st.stop()
def _run_caption(eng, img, bbone, use_heatmap):
    """
    Route inference:
    'ensemble' (UI label) → BLIP under the hood
    'vgg16' / 'vgg19' → VGG + attention as normal

    Args:
        eng: Caption engine providing ``_generate_blip_caption``,
            ``generate_caption`` and ``generate_caption_with_attention``.
        img: Image to caption.
        bbone: ``"ensemble"``, ``"vgg16"`` or ``"vgg19"``.
        use_heatmap: For the VGG paths, request attention maps from the engine.

    Returns:
        A dict with ``candidates``, ``model_used`` and ``attention_maps``.
        The ensemble and no-heatmap paths always set ``attention_maps`` to
        ``[]``. With the heatmap on, the engine's result is returned unchanged
        (presumably it already carries ``attention_maps`` — confirm).
    """
    if bbone == "ensemble":
        # BLIP path — branded as 'Ensemble' to the user.
        # Uses the engine passed in, not the module-level `engine`, so the
        # function honours its own argument.
        blip_cap = eng._generate_blip_caption(img)
        if blip_cap:
            # BLIP provides no score here; 0.82 is a fixed placeholder confidence.
            candidates = [(blip_cap, 0.82)]
        else:
            # Graceful fallback to VGG19 if BLIP unavailable
            res = eng.generate_caption(
                img, caption_mode="vgg_only",
                backbone_mode="vgg19", beam_width=BEAM_WIDTH
            )
            candidates = res["candidates"]
        return {
            "candidates": candidates,
            "model_used": "Ensemble",  # display name stays 'Ensemble'
            "attention_maps": [],
        }
    # Normal VGG path
    if use_heatmap:
        result = eng.generate_caption_with_attention(
            img, backbone_mode=bbone, beam_width=BEAM_WIDTH
        )
    else:
        result = eng.generate_caption(
            img, caption_mode="vgg_only",
            backbone_mode=bbone, beam_width=BEAM_WIDTH
        )
        result["attention_maps"] = []
    return result
# ──────────────────────────────────────────────────────────────────────────────
# MODE: Compare Models
# ──────────────────────────────────────────────────────────────────────────────
# Compare mode: caption one uploaded image with every backbone, badge the
# highest-confidence result, then stop so the sections below do not render.
if app_mode == "Compare Models":
    st.markdown("## 🔄 Model Comparison")
    st.caption("Run VGG16, VGG19, and Ensemble simultaneously — see which backbone wins.")
    uploaded_file = st.file_uploader(
        "Upload an image to compare all models",
        type=["jpg", "jpeg", "png"],
        key="compare_uploader"
    )
    if uploaded_file:
        image = Image.open(uploaded_file)
        img_col, _ = st.columns([1, 2])
        with img_col:
            st.image(image, caption=uploaded_file.name, use_container_width=True)
        with st.spinner("Running all 3 backbones in parallel…"):
            all_results = engine.generate_all_backbones(image, beam_width=BEAM_WIDTH)
        # Rename 'ensemble' model_used label for display
        if "ensemble" in all_results:
            all_results["ensemble"]["model_used"] = "Ensemble"
        # Find winner. `default=None` keeps an empty result dict from raising
        # ValueError; in that case no column gets the winner badge.
        winner = max(
            all_results,
            key=lambda k: all_results[k].get("confidence", 0),
            default=None,
        )
        cols = st.columns(3)
        mode_labels = {
            "vgg16": ("🔶", "VGG16"),
            "vgg19": ("🔷", "VGG19"),
            "ensemble": ("⚡", "Ensemble"),
        }
        for col, mode in zip(cols, ["vgg16", "vgg19", "ensemble"]):
            # A backbone missing from the results renders as "Error" with 0% confidence.
            res = all_results.get(mode, {})
            icon, label = mode_labels[mode]
            is_winner = (mode == winner)
            winner_cls = "compare-winner" if is_winner else ""
            winner_badge = '🏆 Best' if is_winner else ""
            caption = res.get("caption", "Error") or "No caption generated"
            conf = res.get("confidence", 0.0)
            bar_w = int(conf * 100)
            with col:
                st.markdown(f"""
{caption}
Confidence{conf:.1%}
""", unsafe_allow_html=True)
                # Sub-captions
                candidates = res.get("candidates", [])
                if len(candidates) > 1:
                    with st.expander("More candidates", expanded=False):
                        for i, (cap, sc) in enumerate(candidates[1:], 2):
                            st.caption(f"#{i}: {cap} ({sc:.1%})")
            # Add to history
            # NOTE(review): this runs on every script rerun, so the same
            # result may be appended more than once — confirm intended.
            if caption and caption != "Error":
                add_to_history(caption, conf, label, uploaded_file.name)
    st.stop()
# ──────────────────────────────────────────────────────────────────────────────
# MODE: Demo Gallery
# ──────────────────────────────────────────────────────────────────────────────
# Gallery mode: show the first ten dataset images, caption the one picked
# (or a random one), then stop so the Standard section does not render.
if app_mode == "Demo Gallery":
    st.markdown("## 🖼 Flickr8K Demo Gallery")
    demo_images = get_demo_images()
    if not demo_images:
        st.warning("No demo images found in `data/Flickr8k_Dataset/`.")
        st.stop()
    # "Surprise Me" button
    # NOTE(review): several markup literals in this section are empty or broken in this view;
    # they are kept exactly as found.
    st.markdown('', unsafe_allow_html=True)
    if st.button("🎲 Surprise Me — Pick a Random Image"):
        st.session_state["surprise_image"] = random.choice(demo_images)
    st.markdown("
", unsafe_allow_html=True)
    # Grid of first 10
    st.markdown('Browse Samples
', unsafe_allow_html=True)
    sample = demo_images[:10]
    cols = st.columns(5)
    # Start from any image remembered in session state; a "Caption" click below overrides it.
    selected_name = st.session_state.get("surprise_image")
    for i, img_name in enumerate(sample):
        img_path = os.path.join(FLICKR_IMAGES_DIR, img_name)
        # i % 5 lays the ten thumbnails out as two rows across the five columns.
        with cols[i % 5]:
            pil_img = Image.open(img_path)
            st.image(pil_img, use_container_width=True)
            if st.button("Caption", key=f"demo_pick_{i}"):
                selected_name = img_name
                # Stored in session state so the selection is still set on the next rerun.
                st.session_state["surprise_image"] = img_name
    if selected_name:
        img_path = os.path.join(FLICKR_IMAGES_DIR, selected_name)
        image = Image.open(img_path)
        st.markdown("---")
        st.markdown(f"### Selected: `{selected_name}`")
        c1, c2 = st.columns([1, 1])
        with c1:
            st.image(image, use_container_width=True)
        with c2:
            # Placeholder for the step indicator; emptied once captions are rendered.
            progress_ph = st.empty()
            with progress_ph.container():
                display_step_progress(0)
            with st.spinner("Generating captions…"):
                display_step_progress(1)
                result = _run_caption(engine, image, backbone, show_heatmap)
                attn_maps = result.get("attention_maps", [])
                display_step_progress(2)
            captions_result = result["candidates"]
            model_used = result["model_used"]
            if captions_result:
                display_captions(captions_result)
                st.caption(f"Model: {model_used}")
                # Only the top candidate is recorded in the history.
                add_to_history(
                    captions_result[0][0], captions_result[0][1],
                    model_used, selected_name
                )
            progress_ph.empty()
        # Reference captions
        all_caps = get_all_captions()
        if selected_name in all_caps:
            with st.expander("📖 Ground-truth Reference Captions"):
                for cap in all_caps[selected_name]:
                    st.markdown(f"- {clean_caption_text(cap)}")
        # Attention heatmap
        if show_heatmap and attn_maps:
            st.markdown("---")
            st.markdown("### 🔥 Attention Heatmap")
            st.caption("Which image regions does the model focus on for each word?")
            # attn_maps is unpacked as (word, grid) pairs.
            words = [w for w, _ in attn_maps]
            # At most the first 16 words are selectable.
            max_display = min(len(words), 16)
            word_idx = st.select_slider(
                "Select a word",
                options=list(range(max_display)),
                format_func=lambda i: words[i],
                key="attn_slider_demo"
            )
            word, attn_grid = attn_maps[word_idx]
            overlay = build_attention_overlay(image, attn_grid)
            st.image(overlay, caption=f'Attention for word: "{word}"', use_container_width=True)
    st.stop()
# ──────────────────────────────────────────────────────────────────────────────
# MODE: Standard
# ──────────────────────────────────────────────────────────────────────────────
# "Surprise Me" hero row
# Standard mode, part 1: pick the image to caption — an upload if present,
# otherwise the remembered "surprise" image.
demo_images_all = get_demo_images()
if demo_images_all:
    surprise_col, _ = st.columns([1, 3])
    with surprise_col:
        if st.button("🎲 Surprise Me — Random Image", use_container_width=True):
            st.session_state["surprise_image"] = random.choice(demo_images_all)
            # Clear the stored result and attention maps from any earlier image.
            st.session_state["last_result"] = None
            st.session_state["last_attn"] = []
# NOTE(review): this literal is broken across two lines in this view; kept exactly as found.
st.markdown('Upload an Image
', unsafe_allow_html=True)
uploaded_file = st.file_uploader(
    "Drag & drop or browse — JPG, JPEG, PNG",
    type=["jpg", "jpeg", "png"],
    key="std_uploader"
)
# Resolve image source
image = None
image_id = None
if uploaded_file:
    image = Image.open(uploaded_file)
    # The upload's file name doubles as the key for the reference-caption lookup later.
    image_id = uploaded_file.name
    st.session_state["surprise_image"] = None  # clear surprise on manual upload
elif st.session_state.get("surprise_image"):
    sname = st.session_state["surprise_image"]
    spath = os.path.join(FLICKR_IMAGES_DIR, sname)
    # If the file is missing, image stays None and nothing is captioned.
    if os.path.exists(spath):
        image = Image.open(spath)
        image_id = sname
# Standard mode, part 2: show the image, run inference, and render captions,
# per-image BLEU and the word cloud.
# NOTE(review): several markup literals in this section are empty or broken in this view;
# they are kept exactly as found.
if image is not None:
    col1, col2 = st.columns([1, 1])
    # ── Left: image + reference captions ─────────────────────────────────────
    with col1:
        st.markdown('Input Image
', unsafe_allow_html=True)
        st.image(image, use_container_width=True)
        all_caps = get_all_captions()
        if image_id and image_id in all_caps:
            with st.expander("📖 Reference Captions"):
                for cap in all_caps[image_id]:
                    st.markdown(f"- {clean_caption_text(cap)}")
    # ── Right: run inference, show captions ──────────────────────────────────
    with col2:
        st.markdown('Generated Captions
', unsafe_allow_html=True)
        # Placeholder for the step indicator; emptied once inference returns.
        progress_ph = st.empty()
        with progress_ph.container():
            display_step_progress(0)
        with st.spinner("Extracting features & generating captions…"):
            result = _run_caption(engine, image, backbone, show_heatmap)
            # NOTE(review): these two values are stored but not read back in the code visible here.
            st.session_state["last_attn"] = result.get("attention_maps", [])
            st.session_state["last_result"] = result
        progress_ph.empty()
        captions_result = result["candidates"]
        model_used = result["model_used"]
        if captions_result:
            display_captions(captions_result)
            st.caption(f"Model: {model_used}")
            # Only the top candidate is recorded in the history.
            # NOTE(review): this runs on every script rerun, so the same result may be
            # appended more than once — confirm intended.
            add_to_history(
                captions_result[0][0], captions_result[0][1],
                model_used, image_id or "unknown"
            )
            caption_text = "\n".join(
                [f"#{i+1}: {cap}" for i, (cap, _) in enumerate(captions_result)]
            )
            st.download_button(
                "⬇ Download Captions (.txt)",
                data=caption_text,
                file_name=f"captions_{backbone}.txt",
                mime="text/plain",
            )
            # Inline BLEU if ground truth available
            if image_id:
                all_caps = get_all_captions()
                if image_id in all_caps:
                    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
                    import nltk
                    nltk.download("punkt", quiet=True)
                    nltk.download("punkt_tab", quiet=True)
                    # References: each ground-truth caption with markers removed, split on whitespace.
                    refs = [
                        cap.replace(START_TOKEN,"").replace(END_TOKEN,"").strip().split()
                        for cap in all_caps[image_id]
                    ]
                    st.markdown('Per-Image BLEU
', unsafe_allow_html=True)
                    smooth = SmoothingFunction().method1
                    # Hypothesis: the top candidate, split on whitespace.
                    hyp = captions_result[0][0].split()
                    bleu_cols = st.columns(4)
                    for n in range(1, 5):
                        # BLEU-n: equal weight on the first n n-gram orders, zero on the rest.
                        weights = tuple([1.0/n]*n + [0.0]*(4-n))
                        score = sentence_bleu(refs, hyp, weights=weights, smoothing_function=smooth)
                        with bleu_cols[n-1]:
                            st.markdown(f"""
""", unsafe_allow_html=True)
        else:
            st.warning("Could not generate captions. Please try a different image.")
    # ── Word Cloud — rendered AFTER inference so it always shows ──────────────
    if show_wordcloud and captions_result:
        st.markdown("---")
        st.markdown("### ☁️ Word Cloud")
        st.caption("Word distribution across all beam-search candidates.")
        wc_col1, wc_col2 = st.columns([2, 1])
        with wc_col1:
            wc_img = make_word_cloud_image(captions_result)
            if wc_img:
                st.markdown('', unsafe_allow_html=True)
                st.image(wc_img, use_container_width=True)
                st.markdown("
", unsafe_allow_html=True)
            else:
                # No cloud image was produced; draw the frequency bar chart instead.
                make_word_freq_bar(captions_result)
        with wc_col2:
            st.markdown('', unsafe_allow_html=True)
            st.markdown("**Top words**")
            from collections import Counter
            # NOTE(review): this stop list omits "at" and "by", which make_word_freq_bar filters — confirm intended.
            stop = {"a","an","the","is","are","in","on","of","to","and","or","with"}
            all_words = []
            for cap, conf in captions_result:
                for w in cap.split():
                    if w.lower() not in stop and len(w) > 2:
                        all_words.append(w.lower())
            for word, cnt in Counter(all_words).most_common(6):
                st.markdown(f"- **{word}** ({cnt}×)")
            st.markdown("
", unsafe_allow_html=True)
# ── Attention Heatmap — rendered AFTER inference ───────────────────────────
# Standard mode, part 3: per-word attention explorer.
# The `image is not None` guard avoids touching `result` when no image was resolved.
attn_maps = result.get("attention_maps", []) if image is not None else []
if show_heatmap:
    st.markdown("---")
    if attn_maps:
        st.markdown("### 🔥 Attention Heatmap Explorer")
        st.caption(
            "Gradient-based saliency — which 7×7 image regions does the model focus on "
            "for each predicted word? Use the slider to step through words."
        )
        # attn_maps is unpacked as (word, grid) pairs.
        words = [w for w, _ in attn_maps]
        # At most the first 16 words are shown and selectable.
        max_display = min(len(words), 16)
        chips_html = ""
        for i, w in enumerate(words[:max_display]):
            # Hue runs from 240 at the first word down to 40 at the last; max(..., 1) avoids dividing by zero.
            # NOTE(review): hue is not referenced by the markup visible here.
            hue = int(240 - (i / max(max_display - 1, 1)) * 200)
            chips_html += (
                f''
                f'{w}'
            )
        st.markdown(chips_html, unsafe_allow_html=True)
        # NOTE(review): with a single word, min and max are both 0 — confirm st.slider accepts that.
        word_idx = st.slider(
            "Step through words",
            0, max_display - 1, 0,
            format="%d",
            key="attn_slider_std"
        )
        selected_word, attn_grid = attn_maps[word_idx]
        h_col1, h_col2 = st.columns([1, 1])
        with h_col1:
            st.image(image, caption="Original", use_container_width=True)
        with h_col2:
            overlay = build_attention_overlay(image, attn_grid)
            st.image(overlay, caption=f'Attention → "{selected_word}"', use_container_width=True)
    else:
        # Ensemble = BLIP — no attention maps
        st.markdown("### 🔥 Attention Heatmap")
        st.info(
            "Attention heatmaps are available for **VGG16** and **VGG19** modes.\n\n"
            "Switch the backbone in the sidebar to **🔷 VGG19** or **🔶 VGG16** to explore word-level attention."
        )
# ── Footer ────────────────────────────────────────────────────────────────────
# Footer credit line under a horizontal rule.
# NOTE(review): the surrounding markup literals are empty or broken in this view; kept exactly as found.
st.markdown("---")
st.markdown(
    ""
    "CaptionIQ · VGG + Bahdanau Attention · Flickr8K · Built with Streamlit"
    "
",
    unsafe_allow_html=True,
)