# Streamlit app: AI Image Captioner (ResNet50 encoder + LSTM decoder).
import os

import pandas as pd
import sacrebleu
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
# from torchvision import transforms as T

from decoder import DecoderRNN
from encoder import EncoderCNN
from inference import sample_with_temp, sample
from utils.helpers import VOCAB_PATH, CAPTIONS_PATH, IMAGE_DIR
from utils.transforms import transforms
from utils.vocab import Vocabulary
@st.cache_resource
def load_models():
    """Load the captioning pipeline once per Streamlit server process.

    Without ``st.cache_resource`` Streamlit re-runs the whole script on every
    widget interaction, which would re-download and re-load both model
    checkpoints each time. Caching keeps a single copy alive across reruns.

    Returns:
        tuple: ``(encoder, decoder, vocab, device, captions)`` where
        ``captions`` is the ground-truth caption DataFrame read from
        ``CAPTIONS_PATH``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Ground-truth captions and the vocabulary built during training.
    captions = pd.read_csv(CAPTIONS_PATH)
    vocab = Vocabulary(load_path=VOCAB_PATH)

    # Architecture must match the training configuration:
    # 256-dim image embedding, 512-dim LSTM hidden state.
    encoder = EncoderCNN(256).to(device)
    decoder = DecoderRNN(len(vocab), 256, 512).to(device)

    # Fetch trained weights from the Hub (hf_hub_download caches locally).
    repo_id = "Sher1988/image-classifier-weights"
    encoder_path = hf_hub_download(repo_id=repo_id, filename="encoder.pth")
    decoder_path = hf_hub_download(repo_id=repo_id, filename="decoder.pth")

    # map_location keeps CPU-only hosts working with GPU-trained weights.
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device))

    # Inference only: disable dropout / batch-norm updates.
    encoder.eval()
    decoder.eval()
    return encoder, decoder, vocab, device, captions
# --- Sidebar Configuration ---
# Offers a picker over the bundled Flickr8k example images, with a thumbnail
# preview of the current selection. Falls back to "None" when IMAGE_DIR is
# missing so the main logic below can rely on selected_img_name existing.
st.sidebar.header("Select an Example Image")
if os.path.exists(IMAGE_DIR):
    # Sort so the picker order is stable across platforms/filesystems
    # (os.listdir returns entries in arbitrary order).
    available_images = sorted(
        f for f in os.listdir(IMAGE_DIR)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    selected_img_name = st.sidebar.selectbox("Choose from Flickr8k:", ["None"] + available_images)
    # Thumbnail preview of the current sidebar pick.
    if selected_img_name != "None":
        img_path = os.path.join(IMAGE_DIR, selected_img_name)
        st.sidebar.image(Image.open(img_path), caption="Sidebar Selection Preview", use_container_width=True)
else:
    st.sidebar.warning("Image directory not found. Please check IMAGE_DIR path.")
    selected_img_name = "None"
# --- Main App Logic ---
encoder, decoder, vocab, device, captions = load_models()

# Defaults so later references are always bound, even with no image chosen.
act_caps = []
caption = ''

st.title("📸 AI Image Captioner")
temp = st.slider("Sampling Temperature", min_value=0.0, max_value=0.8, value=0.1, step=0.1)
st.info("Higher temperature = more creative/random. Lower temperature = more predictable.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Resolve which image to caption: a fresh upload takes priority over the
# sidebar selection; otherwise fall back to the sidebar pick (if any).
img, img_name = None, None
if uploaded_file is not None:
    img = Image.open(uploaded_file).convert('RGB')
    img_name = uploaded_file.name
elif selected_img_name != "None":
    source_path = os.path.join(IMAGE_DIR, selected_img_name)
    img = Image.open(source_path).convert('RGB')
    img_name = selected_img_name
# If we have an image (from either source), run the model
if img is not None:
    st.image(img, caption=f'Selected: {img_name}', width=300)

    # Preprocess into the batched tensor the encoder expects (adds batch dim).
    img_tensor = transforms(img).unsqueeze(0).to(device)

    # Ground-truth captions exist only for images from the Flickr8k CSV.
    act_caps = captions[captions['image'] == img_name]['caption'].tolist()
    if act_caps:
        st.subheader("Actual Captions:")
        st.success(" \n".join(act_caps))
    else:
        st.info("No ground truth captions found for this image in the CSV.")

    # Inference only — no gradients needed.
    with torch.no_grad():
        encoder_out = encoder(img_tensor)
        # Pass the 'temp' variable from the slider here
        caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)

    st.subheader("Generated Caption:")
    st.success(caption)

    if act_caps:
        # sacrebleu expects hypotheses as a list of strings and references as
        # a list of reference *streams*, each the same length as the
        # hypothesis list. With one hypothesis and N ground-truth captions
        # that is [[cap1], [cap2], ...]. Passing [act_caps] (one stream of N
        # sentences for 1 hypothesis) is a length mismatch and raises at
        # runtime whenever an image has multiple reference captions.
        hypothesis = [caption]
        references = [[ref] for ref in act_caps]
        bleu = sacrebleu.corpus_bleu(hypothesis, references)

        st.subheader("Evaluation Metrics:")
        st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
        # Scale against 50 so a "good" caption score fills the bar.
        st.progress(min(bleu.score / 50, 1.0))

        # N-gram precision breakdown: bleu.precisions is [p1, p2, p3, p4].
        cols = st.columns(4)
        for i, p in enumerate(bleu.precisions):
            cols[i].markdown(f"{i+1}-gram")
            cols[i].write(f"{p:.1f}%")

        # Brief explanation
        with st.expander("What do these mean?"):
            st.write("""
            - **1-gram**: Individual word accuracy (Vocabulary).
            - **2-gram**: Fluency of word pairs.
            - **4-gram**: Capturing longer phrases/sentence structure.
            """)
    else:
        st.info("Upload an image from the Flickr8k set to see BLEU metrics.")
# Static project description rendered at the bottom of the page.
ABOUT_TEXT = """
This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.
* **Encoder:** Pre-trained ResNet50 (Frozen) extracts high-level visual features.
* **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
* **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).
⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactory on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
"""

st.header('About this Project')
st.markdown(ABOUT_TEXT)