"""Streamlit demo app: generate captions for images with a ResNet50 + LSTM model.

Loads trained encoder/decoder weights from the Hugging Face Hub, lets the user
pick a Flickr8k example or upload an image, generates a caption with
temperature sampling, and scores it against ground-truth captions with
SacreBLEU when they are available.
"""

import os

import pandas as pd
import sacrebleu
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from decoder import DecoderRNN
from encoder import EncoderCNN
from inference import sample_with_temp, sample
from utils.helpers import VOCAB_PATH, CAPTIONS_PATH, IMAGE_DIR
from utils.transforms import transforms
from utils.vocab import Vocabulary


@st.cache_resource
def load_models():
    """Load captions, vocabulary, and the trained encoder/decoder.

    Cached by Streamlit (``st.cache_resource``) so the expensive model
    construction and weight download run only once per server process.

    Returns:
        tuple: ``(encoder, decoder, vocab, device, captions)`` where
        ``captions`` is the ground-truth caption DataFrame with columns
        ``image`` and ``caption``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Ground-truth captions and the vocabulary built at training time.
    captions = pd.read_csv(CAPTIONS_PATH)
    vocab = Vocabulary(load_path=VOCAB_PATH)

    # Initialize models: embed size 256, LSTM hidden size 512.
    encoder = EncoderCNN(256).to(device)
    decoder = DecoderRNN(len(vocab), 256, 512).to(device)

    # Download trained weights from the Hugging Face Hub.
    # (This assignment was accidentally commented out in the pasted source;
    # hf_hub_download below requires it.)
    repo_id = "Sher1988/image-classifier-weights"
    encoder_path = hf_hub_download(repo_id=repo_id, filename="encoder.pth")
    decoder_path = hf_hub_download(repo_id=repo_id, filename="decoder.pth")

    # Load weights and switch to inference mode.
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))
    encoder.eval()
    decoder.eval()

    return encoder, decoder, vocab, device, captions


# --- Sidebar Configuration ---
st.sidebar.header("Select an Example Image")

if os.path.exists(IMAGE_DIR):
    available_images = [
        f for f in os.listdir(IMAGE_DIR)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]
    selected_img_name = st.sidebar.selectbox(
        "Choose from Flickr8k:", ["None"] + available_images
    )

    # Preview thumbnail for the chosen example image.
    if selected_img_name != "None":
        img_path = os.path.join(IMAGE_DIR, selected_img_name)
        st.sidebar.image(
            Image.open(img_path),
            caption="Sidebar Selection Preview",
            use_container_width=True,
        )
else:
    # Original string literal was broken by a raw newline; rejoined here.
    st.sidebar.warning("Image directory not found. Please check IMAGE_DIR path.")
    selected_img_name = "None"

# --- Main App Logic ---
encoder, decoder, vocab, device, captions = load_models()

act_caps = []
caption = ''

st.title("📸 AI Image Captioner")

temp = st.slider("Sampling Temperature", min_value=0.0, max_value=0.8, value=0.1, step=0.1)
st.info("Higher temperature = more creative/random. Lower temperature = more predictable.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Determine which image to process: an upload wins over the sidebar selection.
img = None
img_name = None

if uploaded_file is not None:
    img = Image.open(uploaded_file).convert('RGB')
    img_name = uploaded_file.name
elif selected_img_name != "None":
    img_path = os.path.join(IMAGE_DIR, selected_img_name)
    img = Image.open(img_path).convert('RGB')
    img_name = selected_img_name

# If we have an image (from either source), run the model.
if img is not None:
    st.image(img, caption=f'Selected: {img_name}', width=300)

    # Preprocess into a single-image batch on the model's device.
    img_tensor = transforms(img).unsqueeze(0).to(device)

    # Ground-truth captions for the selected image name (empty for uploads
    # that are not part of the Flickr8k CSV).
    act_caps = captions[captions['image'] == img_name]['caption'].tolist()
    if act_caps:
        st.subheader("Actual Captions:")
        st.success(" \n".join(act_caps))
    else:
        st.info("No ground truth captions found for this image in the CSV.")

    with torch.no_grad():
        encoder_out = encoder(img_tensor)
        # Pass the 'temp' value from the slider through to the sampler.
        caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)

    st.subheader("Generated Caption:")
    st.success(caption)

    if act_caps:
        # sacrebleu.corpus_bleu(hypotheses, reference_streams): each reference
        # stream must be parallel to the hypothesis list. With ONE hypothesis
        # and N ground-truth captions, that means N streams of length 1 —
        # NOT one stream of N sentences (the original shape, which mismatches
        # the single-element hypothesis list).
        refs = [[c] for c in act_caps]
        hyp = [caption]  # renamed from 'sys' to avoid shadowing the stdlib module name
        bleu = sacrebleu.corpus_bleu(hyp, refs)

        st.subheader("Evaluation Metrics:")
        st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
        st.progress(min(bleu.score / 50, 1.0))

        # N-gram precision breakdown: bleu.precisions is [p1, p2, p3, p4].
        cols = st.columns(4)
        for i, p in enumerate(bleu.precisions):
            cols[i].markdown(f"{i+1}-gram")
            cols[i].write(f"{p:.1f}%")

        # Brief explanation of the metrics for non-expert users.
        with st.expander("What do these mean?"):
            st.write("""
            - **1-gram**: Individual word accuracy (Vocabulary).
            - **2-gram**: Fluency of word pairs.
            - **4-gram**: Capturing longer phrases/sentence structure.
            """)
    else:
        st.info("Upload an image from the Flickr8k set to see BLEU metrics.")

st.header('About this Project')
st.markdown("""
This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.

* **Encoder:** Pre-trained ResNet50 (Frozen) extracts high-level visual features.
* **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
* **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).

⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactory on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
""")