"""CaptionIQ — Streamlit front-end for VGG + attention-LSTM image captioning.

The sidebar offers three modes:

* Standard       — upload (or randomly pick) an image and caption it.
* Compare Models — caption one image with every backbone side by side.
* Demo Gallery   — browse and caption Flickr8K sample images.

Streamlit re-executes this whole script on every widget interaction, so the
module-level side effects below (the keras monkey-patch, the ``sys.path``
edit, inference, history appends) are guarded to be idempotent across reruns.
"""
import colorsys
import csv
import datetime
import hashlib
import html
import io
import json
import os
import random
import sys
from collections import Counter

import numpy as np

# NOTE(review): this is set before Streamlit is imported, but under
# ``streamlit run`` the server configuration is presumably loaded before this
# script executes, so it may have no effect — confirm, or move the setting to
# .streamlit/config.toml.
os.environ.setdefault("STREAMLIT_SERVER_FILE_WATCHER_TYPE", "none")

import pandas as pd
import streamlit as st
from PIL import Image, ImageDraw, ImageFilter
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from tensorflow.keras.models import Model as KerasModel
import keras.src.ops.operation as _keras_op
from keras.src.ops.numpy import NotEqual as _NotEqual

# Make the project-local ``src`` package importable regardless of the working
# directory. This must run before the first ``from src...`` import below; the
# membership check keeps reruns from prepending the same entry again.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
if _APP_DIR not in sys.path:
    sys.path.insert(0, _APP_DIR)


def _install_from_config_patch() -> None:
    """Make keras ``Operation.from_config`` ignore a ``quantization_config`` key.

    Presumably this lets models serialised by a Keras version that writes that
    key load on the installed version — TODO confirm. The patch is applied at
    most once: keras stays imported across Streamlit reruns, so wrapping the
    already-patched method again on each rerun would nest one wrapper per rerun.
    """
    current = _keras_op.Operation.from_config
    if getattr(current, "_captioniq_patched", False):
        return
    original = current.__func__

    def _patched_from_config(cls, config):
        config = dict(config)  # leave the caller's config dict untouched
        config.pop("quantization_config", None)
        return original(cls, config)

    _patched_from_config._captioniq_patched = True
    _keras_op.Operation.from_config = classmethod(_patched_from_config)


_install_from_config_patch()

from src.model import BahdanauAttention

_CUSTOM_OBJECTS = {"NotEqual": _NotEqual, "BahdanauAttention": BahdanauAttention}

from src.config import (
    IMAGE_SIZE, BEAM_WIDTH,
    VGG16_MODEL_FILE, VGG19_MODEL_FILE, TOKENIZER_FILE,
    BLEU_RESULTS_FILE, VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
    FLICKR_IMAGES_DIR, CAPTIONS_FILE, START_TOKEN, END_TOKEN,
)
from src.utils import load_tokenizer, load_features, load_captions
from src.engine import CaptionEngine

# ─── Constants ───────────────────────────────────────────────────────────────

# Function words ignored by the word-frequency views. Words of two characters
# or fewer are dropped separately in _content_words.
_STOP_WORDS = frozenset(
    {"a", "an", "the", "is", "are", "in", "on", "of", "to", "and", "or", "with", "at", "by"}
)

# Inline style per progress-step state (see display_step_progress); inline
# because the page stylesheet is empty in this copy.
_STEP_STYLES = {"done": "opacity:1;", "active": "font-weight:700;", "step": "opacity:0.45;"}

_ARCHITECTURE_MD = f"""
**CaptionIQ** — CNN + Attention LSTM

**1. Feature Extraction**
- VGG16/VGG19 backbone (ImageNet)
- `block5_pool` → 7×7×512 spatial map
- Reshaped to 49×512 region tokens

**2. Attention Decoder**
- Bahdanau attention over 49 regions
- Word embedding layer (256-dim)
- LSTM with 512 hidden units
- Dropout 0.3 for regularisation

**3. Inference**
- Ensemble beam search (width={BEAM_WIDTH})
- Length-penalty + repetition control
- Softmax-normalised confidence scores

**Dataset**: Flickr8K (8 000 images, 5 captions each)
"""

# ─── Page config ────────────────────────────────────────────────────────────
st.set_page_config(
    page_title="CaptionIQ — AI Image Captioning",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ─── Premium CSS ─────────────────────────────────────────────────────────────
# NOTE(review): the stylesheet is empty in this copy. The class names used in
# the markup below (caption-card, metric-card, hero, history-item, …) expect
# rules here; only bars, step states and word chips carry inline styles.
st.markdown(""" """, unsafe_allow_html=True)


# ─── Cached resources ────────────────────────────────────────────────────────
@st.cache_resource
def get_caption_engine():
    """Return the process-wide CaptionEngine instance (created once)."""
    return CaptionEngine()


@st.cache_resource
def load_feature_extractor(backbone: str):
    """Return ``(model, preprocess_fn)`` emitting the VGG ``block5_pool`` map.

    ``"vgg16"`` selects VGG16; any other value selects VGG19.
    """
    if backbone == "vgg16":
        base_model = VGG16(weights="imagenet")
        preprocess_fn = vgg16_preprocess
    else:
        base_model = VGG19(weights="imagenet")
        preprocess_fn = vgg19_preprocess
    model = KerasModel(
        inputs=base_model.input,
        outputs=base_model.get_layer("block5_pool").output,
    )
    return model, preprocess_fn


@st.cache_resource
def get_tokenizer():
    """Load the caption tokenizer, or return None when the file is missing."""
    if not os.path.exists(TOKENIZER_FILE):
        return None
    return load_tokenizer(TOKENIZER_FILE)


@st.cache_data
def get_bleu_results():
    """Return the parsed BLEU results JSON, or None when it does not exist."""
    if not os.path.exists(BLEU_RESULTS_FILE):
        return None
    with open(BLEU_RESULTS_FILE, "r", encoding="utf-8") as f:
        return json.load(f)


@st.cache_data
def get_demo_images():
    """Sorted file names of the .jpg/.jpeg/.png images in FLICKR_IMAGES_DIR."""
    if not os.path.exists(FLICKR_IMAGES_DIR):
        return []
    return sorted(
        f for f in os.listdir(FLICKR_IMAGES_DIR)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    )


@st.cache_data
def get_all_captions():
    """Return ``load_captions(CAPTIONS_FILE)``, or ``{}`` if the file is missing."""
    if not os.path.exists(CAPTIONS_FILE):
        return {}
    return load_captions(CAPTIONS_FILE)


@st.cache_data
def get_caption_memory(backbone: str):
    """Index stored image features for images that have captions.

    Each image's features are averaged over axis 0 and L2-normalised.

    Returns:
        ``{"ids": [...], "matrix": (N, D) array, "captions": {...}}`` or None
        when the feature/caption files are missing or no image qualifies.
    """
    features_file = VGG16_FEATURES_FILE if backbone == "vgg16" else VGG19_FEATURES_FILE
    if not os.path.exists(features_file) or not os.path.exists(CAPTIONS_FILE):
        return None
    features = load_features(features_file)
    captions = load_captions(CAPTIONS_FILE)
    ids, vectors = [], []
    for image_id, feat in features.items():
        if image_id not in captions:
            continue
        vec = np.asarray(feat, dtype=np.float32).mean(axis=0)
        norm = np.linalg.norm(vec) + 1e-10  # epsilon avoids dividing by zero
        ids.append(image_id)
        vectors.append(vec / norm)
    if not vectors:
        return None
    return {"ids": ids, "matrix": np.stack(vectors, axis=0), "captions": captions}


# ─── Helpers ─────────────────────────────────────────────────────────────────
def clean_caption_text(caption: str) -> str:
    """Remove the start/end sequence tokens and surrounding whitespace."""
    return caption.replace(START_TOKEN, "").replace(END_TOKEN, "").strip()


def extract_feature(image: Image.Image, backbone: str) -> np.ndarray:
    """Return *image*'s ``block5_pool`` features reshaped to ``(H * W, C)``."""
    extractor, preprocess_fn = load_feature_extractor(backbone)
    # Convert before resizing: PIL resamples palette ("P") images with
    # nearest-neighbour only, so converting first gives proper filtering.
    image = image.convert("RGB").resize((IMAGE_SIZE, IMAGE_SIZE))
    img_array = np.expand_dims(img_to_array(image), axis=0)
    feature = extractor.predict(preprocess_fn(img_array), verbose=0)[0]
    h, w, c = feature.shape
    return feature.reshape(h * w, c)


def build_attention_overlay(image: Image.Image, attn_7x7: np.ndarray,
                            alpha: float = 0.55) -> Image.Image:
    """Blend a blue→yellow heatmap of one word's attention grid over *image*.

    Args:
        image: Source image (any PIL mode).
        attn_7x7: Attention weights. A 2-D array is used with its own shape; a
            flat array must have a square length (e.g. 49 → 7×7).
        alpha: Opacity reached at the strongest region (0–1).

    Returns:
        A 336×336 RGB image.

    Weights are clipped to be non-negative and divided by their peak, so the
    strongest region always maps to full colour and opacity. Without this,
    weights summing to 1 over 49 regions would be almost invisible, and
    weights above 1 would overflow the 0–255 alpha range.
    """
    weights = np.asarray(attn_7x7, dtype=np.float64)
    base = image.convert("RGBA").resize((336, 336))
    if weights.size == 0:
        return base.convert("RGB")
    if weights.ndim != 2:
        side = int(round(np.sqrt(weights.size)))
        weights = weights.reshape(side, side)
    weights = np.clip(weights, 0.0, None)
    peak = float(weights.max())
    if peak > 0:
        weights = weights / peak

    grid_h, grid_w = weights.shape
    heat = Image.new("RGBA", (grid_w, grid_h))
    pixels = []
    for val in weights.ravel():
        hue = 0.65 - float(val) * 0.65  # low → blue, high → yellow
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        pixels.append((int(r * 255), int(g * 255), int(b * 255), int(alpha * 255 * float(val))))
    heat.putdata(pixels)
    heat = heat.resize(base.size, Image.BILINEAR).filter(ImageFilter.GaussianBlur(radius=20))
    return Image.alpha_composite(base, heat).convert("RGB")


def _content_words(captions: list) -> list:
    """Lower-cased words of ``(caption, score)`` pairs, minus stop/short words."""
    return [
        word.lower()
        for caption, _ in captions
        for word in caption.split()
        if word.lower() not in _STOP_WORDS and len(word) > 2
    ]


def make_word_cloud_image(captions: list, width: int = 480, height: int = 200):
    """Render a word cloud of the caption candidates.

    Best-effort: returns None when the optional ``wordcloud`` package is not
    installed or rendering fails (for example no usable words), so the caller
    can fall back to make_word_freq_bar.
    """
    try:
        from wordcloud import WordCloud
    except ImportError:
        return None
    text = " ".join([cap for cap, _ in captions] * 3)
    if not text.strip():
        return None
    try:
        cloud = WordCloud(
            width=width,
            height=height,
            background_color=None,
            mode="RGBA",
            colormap="cool",
            max_words=40,
            prefer_horizontal=0.85,
            font_path=None,
        ).generate(text)
        return cloud.to_image()
    except Exception:  # deliberate best-effort: the UI falls back to a bar chart
        return None


def make_word_freq_bar(captions: list) -> None:
    """Fallback: bar chart of the 12 most common content words."""
    freq = Counter(_content_words(captions)).most_common(12)
    if not freq:
        return
    df = pd.DataFrame(freq, columns=["word", "count"])
    st.bar_chart(df.set_index("word"), color="#667eea")


def _confidence_bar(percent) -> str:
    """HTML for a horizontal bar filled to *percent* (clamped to 0–100)."""
    pct = max(0, min(100, int(percent)))
    return (
        '<div class="conf-track" style="background:rgba(128,128,128,0.25);'
        'border-radius:4px;height:6px;margin-top:4px;">'
        f'<div class="conf-fill" style="width:{pct}%;height:6px;'
        'border-radius:4px;background:#667eea;"></div></div>'
    )


def _metric_card(value, label) -> str:
    """HTML for a small card showing *value* (4 decimals) above *label*."""
    return (
        '<div class="metric-card">'
        f'<div class="metric-value">{value:.4f}</div>'
        f'<div class="metric-label">{html.escape(str(label))}</div>'
        "</div>"
    )


def _section_label(text: str) -> None:
    """Render a small section heading."""
    st.markdown(f'<div class="section-label">{html.escape(text)}</div>', unsafe_allow_html=True)


def display_step_progress(step: int):
    """Render the Extracting → Attending → Decoding indicator.

    step: 0=extracting, 1=attending, 2=decoding, 3=done. Steps before *step*
    render as done, the step equal to it as active, later ones as pending.
    """
    steps = ["🔍 Extracting", "🎯 Attending", "✍️ Decoding"]
    items = []
    for i, label in enumerate(steps):
        cls = "done" if i < step else ("active" if i == step else "step")
        items.append(
            f'<span class="step-item {cls}" style="margin-right:1rem;{_STEP_STYLES[cls]}">'
            f"{label}</span>"
        )
    st.markdown(f'<div class="step-progress">{"".join(items)}</div>', unsafe_allow_html=True)


def display_captions(captions: list):
    """Render ranked caption cards with a confidence bar.

    Args:
        captions: ``(caption, score)`` pairs, best first. ``score`` is shown as
            a percentage, so it is expected to lie in [0, 1].
    """
    rank_icons = ["🥇", "🥈", "🥉"]
    rank_labels = ["Top Caption", "Runner-up", "Third"]
    for i, (caption, score) in enumerate(captions):
        label = rank_labels[i] if i < len(rank_labels) else f"#{i+1}"
        icon = rank_icons[i] if i < len(rank_icons) else "▪"
        # Escaped: model output may contain tokens such as "<unk>" that would
        # otherwise be swallowed as HTML tags.
        st.markdown(
            '<div class="caption-card">'
            f'<div class="caption-rank">{icon} {label}</div>'
            f'<div class="caption-text">{html.escape(str(caption))}</div>'
            f'<div class="caption-conf">Confidence {score:.1%}</div>'
            f"{_confidence_bar(score * 100)}"
            "</div>",
            unsafe_allow_html=True,
        )


def display_bleu_scores(bleu_data: dict, backbone: str):
    """Render one metric card per BLEU score stored for *backbone*."""
    if not bleu_data or backbone not in bleu_data:
        st.info("BLEU scores not available for this backbone. Run evaluation first.")
        return
    scores = bleu_data[backbone]
    # At least 4 columns keeps the original layout; more metrics get more
    # columns instead of raising IndexError.
    cols = st.columns(max(4, len(scores)))
    for col, (metric, value) in zip(cols, scores.items()):
        with col:
            st.markdown(_metric_card(value, metric), unsafe_allow_html=True)


def _render_per_image_bleu(reference_captions: list, hypothesis: str) -> None:
    """Show smoothed sentence-level BLEU-1..4 of *hypothesis* vs. the references.

    Both sides are tokenised with ``str.split()``, so no NLTK tokenizer data
    (e.g. ``punkt``) needs to be downloaded.
    """
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

    refs = [clean_caption_text(cap).split() for cap in reference_captions]
    hyp = hypothesis.split()
    smooth = SmoothingFunction().method1
    _section_label("Per-Image BLEU")
    bleu_cols = st.columns(4)
    for n in range(1, 5):
        weights = tuple([1.0 / n] * n + [0.0] * (4 - n))
        score = sentence_bleu(refs, hyp, weights=weights, smoothing_function=smooth)
        with bleu_cols[n - 1]:
            st.markdown(_metric_card(score, f"BLEU-{n}"), unsafe_allow_html=True)


def add_to_history(caption: str, confidence: float, model_used: str, image_name: str):
    """Prepend a result to ``st.session_state["caption_history"]``.

    ``confidence`` is stored as a builtin float so the history stays
    JSON-serialisable for the download button (numpy float32 is not).
    """
    st.session_state.setdefault("caption_history", []).insert(0, {
        "timestamp": datetime.datetime.now().strftime("%H:%M:%S"),
        "caption": caption,
        "confidence": float(confidence),
        "model": model_used,
        "image": image_name,
    })


def _upload_fingerprint(uploaded_file) -> str:
    """SHA-256 of an uploaded file's bytes, used as an inference cache key."""
    return hashlib.sha256(uploaded_file.getvalue()).hexdigest()


def _get_or_run(slot: str, key, run_fn):
    """Run *run_fn* once per distinct *key* and reuse its result on reruns.

    The key and result are kept in ``st.session_state[f"{slot}_key"]`` and
    ``st.session_state[f"{slot}_result"]``; setting the result to None forces a
    recompute.

    Returns:
        ``(result, is_new)``. ``is_new`` is True only on the run that actually
        computed the result, so callers can record history exactly once.
    """
    key_name, result_name = f"{slot}_key", f"{slot}_result"
    cached = st.session_state.get(result_name)
    if cached is not None and st.session_state.get(key_name) == key:
        return cached, False
    result = run_fn()
    st.session_state[key_name] = key
    st.session_state[result_name] = result
    return result, True


# ─── Session state init ───────────────────────────────────────────────────────
if "caption_history" not in st.session_state:
    st.session_state["caption_history"] = []
if "surprise_image" not in st.session_state:
    st.session_state["surprise_image"] = None
if "last_result" not in st.session_state:
    st.session_state["last_result"] = None
if "last_attn" not in st.session_state:
    st.session_state["last_attn"] = []

# ─── Sidebar ────────────────────────────────────────────────────────────────
with st.sidebar:
    st.markdown("## ⚙️ Settings")
    # NOTE(review): _run_caption serves "ensemble" with BLIP rather than an
    # average of VGG16 and VGG19, so this label and help text misdescribe it.
    backbone = st.radio(
        "VGG Backbone",
        ["ensemble", "vgg19", "vgg16"],
        index=0,
        format_func=lambda x: (
            "⚡ ENSEMBLE (VGG16 + VGG19)" if x == "ensemble"
            else "🔷 VGG19" if x == "vgg19"
            else "🔶 VGG16"
        ),
        help="Ensemble averages both backbone predictions for richer captions.",
    )

    st.divider()

    app_mode = st.selectbox(
        "Mode",
        ["Standard", "Compare Models", "Demo Gallery"],
        help=(
            "Standard: upload & caption | "
            "Compare Models: run all 3 backbones simultaneously | "
            "Demo Gallery: browse Flickr8K samples"
        ),
    )

    show_heatmap = st.toggle(
        "Attention Heatmap",
        value=True,
        help="Show which image regions the model attended to per word",
    )
    show_wordcloud = st.toggle(
        "Word Cloud",
        value=True,
        help="Visualise word distribution across beam candidates",
    )

    st.divider()

    with st.expander("🧠 Model Architecture", expanded=False):
        st.markdown(_ARCHITECTURE_MD)

    bleu_data = get_bleu_results()
    if bleu_data:
        with st.expander("📊 BLEU Evaluation", expanded=False):
            if backbone in bleu_data:
                display_bleu_scores(bleu_data, backbone)
            if "vgg16" in bleu_data and "vgg19" in bleu_data:
                st.markdown("---")
                st.caption("VGG16 vs VGG19")
                df = pd.DataFrame({
                    "VGG16": bleu_data["vgg16"],
                    "VGG19": bleu_data["vgg19"],
                }).T
                st.dataframe(df, use_container_width=True)

    history = st.session_state["caption_history"]
    if history:
        st.divider()
        with st.expander(f"📋 History ({len(history)})", expanded=False):
            for entry in history:
                # Escaped: image names come from user uploads.
                st.markdown(
                    '<div class="history-item">'
                    f'<div class="history-meta">{html.escape(str(entry["timestamp"]))} — '
                    f'{html.escape(str(entry["image"]))}</div>'
                    f'<div class="history-caption">{html.escape(str(entry["caption"]))}</div>'
                    f'<div class="history-meta">{html.escape(str(entry["model"]))} · '
                    f'{entry["confidence"]:.1%}</div>'
                    "</div>",
                    unsafe_allow_html=True,
                )
            col_a, col_b, col_c = st.columns(3)
            with col_a:
                st.download_button(
                    "⬇ JSON",
                    data=json.dumps(history, indent=2),
                    file_name="captioniq_history.json",
                    mime="application/json",
                )
            with col_b:
                buf = io.StringIO()
                writer = csv.DictWriter(
                    buf, fieldnames=["timestamp", "image", "caption", "confidence", "model"]
                )
                writer.writeheader()
                writer.writerows(history)
                st.download_button(
                    "⬇ CSV",
                    data=buf.getvalue(),
                    file_name="captioniq_history.csv",
                    mime="text/csv",
                )
            with col_c:
                if st.button("🗑 Clear"):
                    st.session_state["caption_history"] = []
                    st.rerun()

# ─── Hero ────────────────────────────────────────────────────────────────────
st.markdown(
    '<div class="hero">'
    '<div class="hero-badge">✨ AI — Vision + Language</div>'
    '<h1 class="hero-title">CaptionIQ</h1>'
    '<p class="hero-sub">Attention-powered image captioning with VGG + LSTM.<br>'
    "Upload an image and watch the AI describe what it sees.</p>"
    "</div>",
    unsafe_allow_html=True,
)

# ─── Ready check ─────────────────────────────────────────────────────────────
tokenizer = get_tokenizer()
engine = get_caption_engine()

# 'ensemble' in the UI means BLIP; only a single VGG backbone needs to be ready.
# NOTE(review): _run_caption's BLIP fallback uses VGG19, but VGG16 is checked here.
_check_backbone = "vgg16" if backbone == "ensemble" else backbone
ready, ready_error = engine.is_ready(_check_backbone)

if tokenizer is None or not ready:
    reason = ready_error or (
        "Tokenizer not found." if tokenizer is None else "Caption engine is not ready."
    )
    st.error(
        "Caption backend is not ready.\n\n"
        f"{reason}\n\n"
        "```bash\npython src/train.py --backbone both --resume\n```"
    )
    st.stop()


def _run_caption(eng, img, bbone, use_heatmap):
    """Caption *img* with the backbone selected in the UI.

    Routing:
      * ``"ensemble"`` (UI label) → BLIP via ``eng._generate_blip_caption``;
        if that returns nothing, falls back to VGG19 beam search.
      * ``"vgg16"`` / ``"vgg19"`` → VGG + attention decoder; per-word attention
        maps are requested only when *use_heatmap* is set.

    Returns:
        A dict with ``candidates`` (``(caption, score)`` pairs, best first) and
        ``model_used``, plus ``attention_maps`` (``(word, grid)`` pairs) on the
        ensemble and non-heatmap paths; callers read it with ``.get``.
    """
    if bbone == "ensemble":
        blip_cap = eng._generate_blip_caption(img)
        if blip_cap:
            # NOTE(review): 0.82 is a fixed placeholder shown to users as a
            # confidence score, not a value produced by the model.
            candidates = [(blip_cap, 0.82)]
        else:
            # Graceful fallback to VGG19 if BLIP is unavailable.
            res = eng.generate_caption(
                img, caption_mode="vgg_only", backbone_mode="vgg19", beam_width=BEAM_WIDTH
            )
            candidates = res["candidates"]
        return {
            "candidates": candidates,
            "model_used": "Ensemble",  # display name stays 'Ensemble'
            "attention_maps": [],
        }

    if use_heatmap:
        return eng.generate_caption_with_attention(
            img, backbone_mode=bbone, beam_width=BEAM_WIDTH
        )
    result = eng.generate_caption(
        img, caption_mode="vgg_only", backbone_mode=bbone, beam_width=BEAM_WIDTH
    )
    result["attention_maps"] = []
    return result


# ──────────────────────────────────────────────────────────────────────────────
# MODE: Compare Models
# ──────────────────────────────────────────────────────────────────────────────
if app_mode == "Compare Models":
    st.markdown("## 🔄 Model Comparison")
    st.caption("Run VGG16, VGG19, and Ensemble simultaneously — see which backbone wins.")

    uploaded_file = st.file_uploader(
        "Upload an image to compare all models",
        type=["jpg", "jpeg", "png"],
        key="compare_uploader",
    )

    if uploaded_file:
        image = Image.open(uploaded_file)
        img_col, _ = st.columns([1, 2])
        with img_col:
            st.image(image, caption=uploaded_file.name, use_container_width=True)

        def _compare_all():
            with st.spinner("Running all 3 backbones in parallel…"):
                return engine.generate_all_backbones(image, beam_width=BEAM_WIDTH)

        all_results, is_new = _get_or_run(
            "compare", _upload_fingerprint(uploaded_file), _compare_all
        )
        if not all_results:
            st.warning("No backbone returned a result for this image.")
            st.stop()

        if "ensemble" in all_results:
            all_results["ensemble"]["model_used"] = "Ensemble"

        winner = max(all_results, key=lambda k: all_results[k].get("confidence", 0))

        mode_labels = {
            "vgg16": ("🔶", "VGG16"),
            "vgg19": ("🔷", "VGG19"),
            "ensemble": ("⚡", "Ensemble"),
        }
        cols = st.columns(3)
        for col, mode in zip(cols, ["vgg16", "vgg19", "ensemble"]):
            res = all_results.get(mode, {})
            icon, label = mode_labels[mode]
            is_winner = mode == winner
            card_cls = "compare-card compare-winner" if is_winner else "compare-card"
            card_style = "border:2px solid #667eea;border-radius:8px;padding:8px;" if is_winner else ""
            winner_badge = '<span class="winner-badge">🏆 Best</span>' if is_winner else ""
            raw_caption = res.get("caption")
            caption = res.get("caption", "Error") or "No caption generated"
            conf = res.get("confidence", 0.0)

            with col:
                st.markdown(
                    f'<div class="{card_cls}" style="{card_style}">'
                    f'<div class="compare-title">{icon} {label} {winner_badge}</div>'
                    f'<div class="caption-text">{html.escape(str(caption))}</div>'
                    f'<div class="caption-conf">Confidence {conf:.1%}</div>'
                    f"{_confidence_bar(conf * 100)}"
                    "</div>",
                    unsafe_allow_html=True,
                )
                candidates = res.get("candidates", [])
                if len(candidates) > 1:
                    with st.expander("More candidates", expanded=False):
                        for i, (cap, sc) in enumerate(candidates[1:], 2):
                            st.caption(f"#{i}: {cap} ({sc:.1%})")

            # Only real captions, and only once per distinct upload.
            if is_new and raw_caption and raw_caption != "Error":
                add_to_history(raw_caption, conf, label, uploaded_file.name)

    st.stop()

# ──────────────────────────────────────────────────────────────────────────────
# MODE: Demo Gallery
# ──────────────────────────────────────────────────────────────────────────────
if app_mode == "Demo Gallery":
    st.markdown("## 🖼 Flickr8K Demo Gallery")
    demo_images = get_demo_images()

    if not demo_images:
        st.warning("No demo images found in `data/Flickr8k_Dataset/`.")
        st.stop()

    if st.button("🎲 Surprise Me — Pick a Random Image"):
        st.session_state["surprise_image"] = random.choice(demo_images)

    _section_label("Browse Samples")
    cols = st.columns(5)
    selected_name = st.session_state.get("surprise_image")
    for i, img_name in enumerate(demo_images[:10]):
        with cols[i % 5]:
            st.image(Image.open(os.path.join(FLICKR_IMAGES_DIR, img_name)), use_container_width=True)
            if st.button("Caption", key=f"demo_pick_{i}"):
                selected_name = img_name
                st.session_state["surprise_image"] = img_name

    img_path = os.path.join(FLICKR_IMAGES_DIR, selected_name) if selected_name else None
    if img_path and os.path.exists(img_path):
        image = Image.open(img_path)
        st.markdown("---")
        st.markdown(f"### Selected: `{selected_name}`")
        c1, c2 = st.columns([1, 1])
        with c1:
            st.image(image, use_container_width=True)
        with c2:
            progress_ph = st.empty()

            def _caption_demo():
                with progress_ph.container():
                    display_step_progress(1)
                with st.spinner("Generating captions…"):
                    return _run_caption(engine, image, backbone, show_heatmap)

            result, is_new = _get_or_run(
                "demo", (selected_name, backbone, show_heatmap), _caption_demo
            )
            progress_ph.empty()

            attn_maps = result.get("attention_maps", [])
            captions_result = result["candidates"]
            model_used = result["model_used"]

            if captions_result:
                display_captions(captions_result)
                st.caption(f"Model: {model_used}")
                if is_new:
                    add_to_history(
                        captions_result[0][0], captions_result[0][1], model_used, selected_name
                    )

            all_caps = get_all_captions()
            if selected_name in all_caps:
                with st.expander("📖 Ground-truth Reference Captions"):
                    for cap in all_caps[selected_name]:
                        st.markdown(f"- {clean_caption_text(cap)}")

        if show_heatmap and attn_maps:
            st.markdown("---")
            st.markdown("### 🔥 Attention Heatmap")
            st.caption("Which image regions does the model focus on for each word?")
            words = [w for w, _ in attn_maps]
            max_display = min(len(words), 16)
            # A slider needs at least two positions.
            if max_display > 1:
                word_idx = st.select_slider(
                    "Select a word",
                    options=list(range(max_display)),
                    format_func=lambda i: words[i],
                    key="attn_slider_demo",
                )
            else:
                word_idx = 0
            word, attn_grid = attn_maps[word_idx]
            overlay = build_attention_overlay(image, attn_grid)
            st.image(overlay, caption=f'Attention for word: "{word}"', use_container_width=True)

    st.stop()

# ──────────────────────────────────────────────────────────────────────────────
# MODE: Standard
# ──────────────────────────────────────────────────────────────────────────────

demo_images_all = get_demo_images()
if demo_images_all:
    surprise_col, _ = st.columns([1, 3])
    with surprise_col:
        if st.button("🎲 Surprise Me — Random Image", use_container_width=True):
            st.session_state["surprise_image"] = random.choice(demo_images_all)
            st.session_state["last_result"] = None
            st.session_state["last_attn"] = []

_section_label("Upload an Image")
uploaded_file = st.file_uploader(
    "Drag & drop or browse — JPG, JPEG, PNG",
    type=["jpg", "jpeg", "png"],
    key="std_uploader",
)

# Resolve the image source: a manual upload wins over a "Surprise Me" pick.
image = None
image_id = None
image_key = None
if uploaded_file:
    image = Image.open(uploaded_file)
    image_id = uploaded_file.name
    image_key = ("upload", _upload_fingerprint(uploaded_file))
    st.session_state["surprise_image"] = None  # clear surprise on manual upload
elif st.session_state.get("surprise_image"):
    sname = st.session_state["surprise_image"]
    spath = os.path.join(FLICKR_IMAGES_DIR, sname)
    if os.path.exists(spath):
        image = Image.open(spath)
        image_id = sname
        image_key = ("demo", sname)

if image is not None:
    col1, col2 = st.columns([1, 1])

    # ── Left: image + reference captions ─────────────────────────────────────
    with col1:
        _section_label("Input Image")
        st.image(image, use_container_width=True)
        all_caps = get_all_captions()
        if image_id and image_id in all_caps:
            with st.expander("📖 Reference Captions"):
                for cap in all_caps[image_id]:
                    st.markdown(f"- {clean_caption_text(cap)}")

    # ── Right: run inference (once per image/settings), show captions ────────
    with col2:
        _section_label("Generated Captions")
        progress_ph = st.empty()

        def _caption_standard():
            with progress_ph.container():
                display_step_progress(0)
            with st.spinner("Extracting features & generating captions…"):
                return _run_caption(engine, image, backbone, show_heatmap)

        result, is_new = _get_or_run(
            "last", (image_key, backbone, show_heatmap), _caption_standard
        )
        progress_ph.empty()
        st.session_state["last_attn"] = result.get("attention_maps", [])

        captions_result = result["candidates"]
        model_used = result["model_used"]

        if captions_result:
            display_captions(captions_result)
            st.caption(f"Model: {model_used}")
            if is_new:
                add_to_history(
                    captions_result[0][0], captions_result[0][1], model_used, image_id or "unknown"
                )

            caption_text = "\n".join(
                f"#{i + 1}: {cap}" for i, (cap, _) in enumerate(captions_result)
            )
            st.download_button(
                "⬇ Download Captions (.txt)",
                data=caption_text,
                file_name=f"captions_{backbone}.txt",
                mime="text/plain",
            )

            # Inline BLEU if ground truth is available for this image name.
            all_caps = get_all_captions()
            if image_id and image_id in all_caps:
                _render_per_image_bleu(all_caps[image_id], captions_result[0][0])
        else:
            st.warning("Could not generate captions. Please try a different image.")

    # ── Word Cloud — rendered AFTER inference so it always shows ──────────────
    if show_wordcloud and captions_result:
        st.markdown("---")
        st.markdown("### ☁️ Word Cloud")
        st.caption("Word distribution across all beam-search candidates.")
        wc_col1, wc_col2 = st.columns([2, 1])
        with wc_col1:
            wc_img = make_word_cloud_image(captions_result)
            if wc_img is not None:
                st.image(wc_img, use_container_width=True)
            else:
                make_word_freq_bar(captions_result)
        with wc_col2:
            st.markdown("**Top words**")
            for word, cnt in Counter(_content_words(captions_result)).most_common(6):
                st.markdown(f"- **{word}** ({cnt}×)")

    # ── Attention Heatmap — rendered AFTER inference ───────────────────────────
    attn_maps = result.get("attention_maps", [])
    if show_heatmap:
        st.markdown("---")
        if attn_maps:
            st.markdown("### 🔥 Attention Heatmap Explorer")
            st.caption(
                "Gradient-based saliency — which 7×7 image regions does the model focus on "
                "for each predicted word? Use the slider to step through words."
            )
            words = [w for w, _ in attn_maps]
            max_display = min(len(words), 16)
            chips = []
            for i, w in enumerate(words[:max_display]):
                hue = int(240 - (i / max(max_display - 1, 1)) * 200)
                chips.append(
                    '<span class="word-chip" style="display:inline-block;margin:2px;'
                    'padding:2px 8px;border-radius:10px;color:#fff;'
                    f'background:hsl({hue},70%,45%);">{html.escape(str(w))}</span>'
                )
            st.markdown("".join(chips), unsafe_allow_html=True)

            # st.slider rejects min_value == max_value, so one word needs no slider.
            if max_display > 1:
                word_idx = st.slider(
                    "Step through words", 0, max_display - 1, 0,
                    format="%d", key="attn_slider_std",
                )
            else:
                word_idx = 0
            selected_word, attn_grid = attn_maps[word_idx]

            h_col1, h_col2 = st.columns([1, 1])
            with h_col1:
                st.image(image, caption="Original", use_container_width=True)
            with h_col2:
                overlay = build_attention_overlay(image, attn_grid)
                st.image(overlay, caption=f'Attention → "{selected_word}"', use_container_width=True)
        else:
            # Ensemble = BLIP — no attention maps
            st.markdown("### 🔥 Attention Heatmap")
            st.info(
                "Attention heatmaps are available for **VGG16** and **VGG19** modes.\n\n"
                "Switch the backbone in the sidebar to **🔷 VGG19** or **🔶 VGG16** to explore word-level attention."
            )

# ── Footer ────────────────────────────────────────────────────────────────────
st.markdown("---")
st.markdown(
    '<div class="footer" style="text-align:center;opacity:0.6;font-size:0.85rem;">'
    "CaptionIQ · VGG + Bahdanau Attention · Flickr8K · Built with Streamlit"
    "</div>",
    unsafe_allow_html=True,
)