Sher1988 committed on
Commit
c00940f
·
verified ·
1 Parent(s): ab1e6d3

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -146
app.py DELETED
@@ -1,146 +0,0 @@
1
- import torch
2
- import pandas as pd
3
- import streamlit as st
4
- from PIL import Image
5
-
6
- from encoder import EncoderCNN
7
- from decoder import DecoderRNN
8
- from utils.vocab import Vocabulary
9
- #from torchvision import transforms as T
10
- from utils.helpers import VOCAB_PATH, CAPTIONS_PATH, IMAGE_DIR
11
- from utils.transforms import transforms
12
- from inference import sample_with_temp, sample
13
- import sacrebleu
14
- import os
15
- from huggingface_hub import hf_hub_download
16
-
17
@st.cache_resource
def load_models():
    """Build and cache the full captioning pipeline.

    Downloads the encoder/decoder weights from the Hugging Face Hub,
    loads the vocabulary and the ground-truth caption CSV, and puts the
    models in eval mode on the best available device.

    Returns:
        tuple: (encoder, decoder, vocab, device, captions) where
            encoder is an EncoderCNN, decoder a DecoderRNN, vocab a
            Vocabulary, device a torch.device, and captions a pandas
            DataFrame read from CAPTIONS_PATH.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load captions and vocab
    captions = pd.read_csv(CAPTIONS_PATH)
    vocab = Vocabulary(load_path=VOCAB_PATH)

    # Initialize Models (embed size 256, LSTM hidden size 512)
    encoder = EncoderCNN(256).to(device)
    decoder = DecoderRNN(len(vocab), 256, 512).to(device)

    # Fetch pretrained weights from the Hub (cached locally after first run)
    repo_id = "Sher1988/image-classifier-weights"
    encoder_path = hf_hub_download(repo_id=repo_id, filename="encoder.pth")
    decoder_path = hf_hub_download(repo_id=repo_id, filename="decoder.pth")

    # Load weights. weights_only=True restricts unpickling to tensors/state
    # dicts, avoiding arbitrary-code execution from a downloaded checkpoint.
    encoder.load_state_dict(torch.load(encoder_path, map_location=device, weights_only=True))
    decoder.load_state_dict(torch.load(decoder_path, map_location=device, weights_only=True))

    # Inference mode: freezes dropout/batch-norm behavior
    encoder.eval()
    decoder.eval()

    return encoder, decoder, vocab, device, captions
43
-
44
-
45
# --- Sidebar Configuration ---
st.sidebar.header("Select an Example Image")

if not os.path.exists(IMAGE_DIR):
    # No example directory on disk: warn and fall back to the sentinel choice.
    st.sidebar.warning("Image directory not found. Please check IMAGE_DIR path.")
    selected_img_name = "None"
else:
    image_extensions = ('.jpg', '.jpeg', '.png')
    available_images = [
        name for name in os.listdir(IMAGE_DIR)
        if name.lower().endswith(image_extensions)
    ]
    selected_img_name = st.sidebar.selectbox("Choose from Flickr8k:", ["None"] + available_images)

    # Show a preview thumbnail of the chosen example image.
    if selected_img_name != "None":
        preview_path = os.path.join(IMAGE_DIR, selected_img_name)
        st.sidebar.image(Image.open(preview_path), caption="Sidebar Selection Preview", use_container_width=True)
59
-
60
# --- Main App Logic ---
encoder, decoder, vocab, device, captions = load_models()
act_caps = []
caption = ''
st.title("📸 AI Image Captioner")

# NOTE(review): the slider allows temp == 0.0 — confirm sample_with_temp
# handles a zero temperature (e.g. falls back to greedy decoding) rather
# than dividing by it.
temp = st.slider("Sampling Temperature", min_value=0.0, max_value=0.8, value=0.1, step=0.1)
st.info("Higher temperature = more creative/random. Lower temperature = more predictable.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Determine which image to process: an upload takes priority over the
# sidebar selection.
img = None
img_name = None

if uploaded_file is not None:
    img = Image.open(uploaded_file).convert('RGB')
    img_name = uploaded_file.name
elif selected_img_name != "None":
    img_path = os.path.join(IMAGE_DIR, selected_img_name)
    img = Image.open(img_path).convert('RGB')
    img_name = selected_img_name

# If we have an image (from either source), run the model
if img is not None:
    st.image(img, caption=f'Selected: {img_name}', width=300)

    # Preprocess into a single-image batch on the model's device.
    img_tensor = transforms(img).unsqueeze(0).to(device)

    # Ground-truth captions for this image name (empty for uploads that
    # are not part of the Flickr8k CSV).
    act_caps = captions[captions['image'] == img_name]['caption'].tolist()

    if act_caps:
        st.subheader("Actual Captions:")
        st.success(" \n".join(act_caps))
    else:
        st.info("No ground truth captions found for this image in the CSV.")

    with torch.no_grad():
        encoder_out = encoder(img_tensor)
        # Pass the 'temp' variable from the slider here
        caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)

    st.subheader("Generated Caption:")
    st.success(caption)

    if act_caps:
        # sacrebleu.corpus_bleu(hyps, refs) expects hyps as a list of N
        # hypothesis strings and refs as a list of reference *streams*,
        # each stream also of length N. With one hypothesis and k ground-
        # truth captions, that is k single-element streams. The previous
        # shape, [act_caps] (one stream of k strings vs. 1 hypothesis),
        # was malformed.
        hypotheses = [caption]
        references = [[ref] for ref in act_caps]

        bleu = sacrebleu.corpus_bleu(hypotheses, references)

        st.subheader("Evaluation Metrics:")
        st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
        # Scale the progress bar so a score of 50+ fills it completely.
        st.progress(min(bleu.score / 50, 1.0))

        # N-gram precision breakdown: bleu.precisions is [p1, p2, p3, p4].
        cols = st.columns(4)
        for i, p in enumerate(bleu.precisions):
            cols[i].markdown(f"{i+1}-gram")
            cols[i].write(f"{p:.1f}%")

        # Brief explanation
        with st.expander("What do these mean?"):
            st.write("""
            - **1-gram**: Individual word accuracy (Vocabulary).
            - **2-gram**: Fluency of word pairs.
            - **4-gram**: Capturing longer phrases/sentence structure.
            """)
    else:
        st.info("Upload an image from the Flickr8k set to see BLEU metrics.")
136
-
137
# --- About Section ---
# Static project description rendered at the bottom of the page.
st.header('About this Project')
st.markdown("""
This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.

* **Encoder:** Pre-trained ResNet50 (Frozen) extracts high-level visual features.
* **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
* **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).

⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactory on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
""")