# caption-gen / app.py — Streamlit image-captioning demo (Sher1988)
# Hugging Face Space revision: 49efe0e (verified)
import torch
import pandas as pd
import streamlit as st
from PIL import Image
from encoder import EncoderCNN
from decoder import DecoderRNN
from utils.vocab import Vocabulary
#from torchvision import transforms as T
from utils.helpers import VOCAB_PATH, CAPTIONS_PATH, IMAGE_DIR
from utils.transforms import transforms
from inference import sample_with_temp, sample
import sacrebleu
import os
from huggingface_hub import hf_hub_download
@st.cache_resource
def load_models():
    """Build the encoder/decoder pair, pull the trained weights from the
    Hugging Face Hub, and return everything the app needs for inference.

    Cached by Streamlit so the download and model construction happen once
    per server process.

    Returns:
        (encoder, decoder, vocab, device, captions) where `captions` is the
        ground-truth caption DataFrame loaded from CAPTIONS_PATH.
    """
    run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Ground-truth captions plus the vocabulary built at training time.
    caption_df = pd.read_csv(CAPTIONS_PATH)
    vocabulary = Vocabulary(load_path=VOCAB_PATH)

    # Same sizes as training: 256-dim embeddings, 512 LSTM hidden units.
    cnn = EncoderCNN(256).to(run_device)
    rnn = DecoderRNN(len(vocabulary), 256, 512).to(run_device)

    # Fetch the trained checkpoints from the Hub.
    weights_repo = "Sher1988/image-classifier-weights"
    enc_ckpt = hf_hub_download(repo_id=weights_repo, filename="encoder.pth")
    dec_ckpt = hf_hub_download(repo_id=weights_repo, filename="decoder.pth")

    cnn.load_state_dict(torch.load(enc_ckpt, map_location=run_device))
    rnn.load_state_dict(torch.load(dec_ckpt, map_location=run_device))

    # Inference only — freeze dropout / batch-norm behaviour.
    cnn.eval()
    rnn.eval()
    return cnn, rnn, vocabulary, run_device, caption_df
# --- Sidebar: example-image picker ---
st.sidebar.header("Select an Example Image")
if os.path.exists(IMAGE_DIR):
    # Offer every image file found in the example directory.
    choices = [
        name
        for name in os.listdir(IMAGE_DIR)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    selected_img_name = st.sidebar.selectbox("Choose from Flickr8k:", ["None"] + choices)
    # Thumbnail preview of whatever was picked.
    if selected_img_name != "None":
        preview_path = os.path.join(IMAGE_DIR, selected_img_name)
        st.sidebar.image(
            Image.open(preview_path),
            caption="Sidebar Selection Preview",
            use_container_width=True,
        )
else:
    # No example directory — warn and fall back to "nothing selected".
    st.sidebar.warning("Image directory not found. Please check IMAGE_DIR path.")
    selected_img_name = "None"
# --- Main page: controls and image selection ---
encoder, decoder, vocab, device, captions = load_models()

act_caps = []
caption = ''

st.title("📸 AI Image Captioner")
temp = st.slider("Sampling Temperature", min_value=0.0, max_value=0.8, value=0.1, step=0.1)
st.info("Higher temperature = more creative/random. Lower temperature = more predictable.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Resolve which image to caption: an upload takes priority over the
# sidebar example; if neither is present both stay None.
img, img_name = None, None
if uploaded_file is not None:
    img = Image.open(uploaded_file).convert('RGB')
    img_name = uploaded_file.name
elif selected_img_name != "None":
    img = Image.open(os.path.join(IMAGE_DIR, selected_img_name)).convert('RGB')
    img_name = selected_img_name
# If we have an image (from either source), caption it and score the result.
if img is not None:
    st.image(img, caption=f'Selected: {img_name}', width=300)

    # Preprocess into the batched tensor shape the encoder expects (1, C, H, W).
    img_tensor = transforms(img).unsqueeze(0).to(device)

    # Ground-truth captions for this image (Flickr8k provides several per image).
    act_caps = captions[captions['image'] == img_name]['caption'].tolist()
    if act_caps:
        st.subheader("Actual Captions:")
        st.success(" \n".join(act_caps))
    else:
        st.info("No ground truth captions found for this image in the CSV.")

    with torch.no_grad():
        encoder_out = encoder(img_tensor)
        # Sample a caption with the slider-selected temperature.
        caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)

    st.subheader("Generated Caption:")
    st.success(caption)

    if act_caps:
        # sacrebleu.corpus_bleu takes hypotheses as a list of strings and
        # references as a list of reference *streams*, each stream the same
        # length as the hypothesis list. With one hypothesis and k references
        # the correct shape is [[ref_1], [ref_2], ..., [ref_k]]; passing
        # [act_caps] directly would be a single stream of k sentences and
        # mis-align hypothesis/reference lengths.
        hyps = [caption]  # renamed from `sys` to avoid shadowing the stdlib module
        refs = [[ref] for ref in act_caps]
        bleu = sacrebleu.corpus_bleu(hyps, refs)

        st.subheader("Evaluation Metrics:")
        st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
        # Scale so a score of 50+ fills the bar.
        st.progress(min(bleu.score / 50, 1.0))

        # N-gram precision breakdown: bleu.precisions == [p1, p2, p3, p4].
        cols = st.columns(4)
        for i, p in enumerate(bleu.precisions):
            cols[i].markdown(f"{i+1}-gram")
            cols[i].write(f"{p:.1f}%")

        # Brief explanation of the metrics for non-expert users.
        with st.expander("What do these mean?"):
            st.write("""
- **1-gram**: Individual word accuracy (Vocabulary).
- **2-gram**: Fluency of word pairs.
- **4-gram**: Capturing longer phrases/sentence structure.
""")
    else:
        st.info("Upload an image from the Flickr8k set to see BLEU metrics.")
# --- About section: static project description ---
st.header('About this Project')
st.markdown("""
This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.
* **Encoder:** Pre-trained ResNet50 (Frozen) extracts high-level visual features.
* **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
* **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).
⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactorily on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
""")