Sher1988 committed
Commit 38e36bb · verified · 1 Parent(s): 5588fb5

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile CHANGED
@@ -1,24 +1,24 @@
- FROM python:3.13.5-slim
-
- WORKDIR /app
-
- RUN apt-get update && apt-get install -y \
- build-essential \
- curl \
- git \
- && rm -rf /var/lib/apt/lists/*
-
- COPY requirements.txt ./
- COPY src/ ./src/
- COPY resnet18_cifar10_finetuned.pth ./
-
- RUN pip3 install -r requirements.txt
-
- # Change EXPOSE to 7860
- EXPOSE 7860
-
- # Update HEALTHCHECK to use 7860
- HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
-
- # Add the XSRF and CORS disable flags to ENTRYPOINT
+ FROM python:3.13.5-slim
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ && rm -rf /var/lib/apt/lists/*
+
+ COPY requirements.txt ./
+ COPY src/ ./src/
+ COPY resnet18_cifar10_finetuned.pth ./
+
+ RUN pip3 install -r requirements.txt
+
+ # Change EXPOSE to 7860
+ EXPOSE 7860
+
+ # Update HEALTHCHECK to use 7860
+ HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
+
+ # Add the XSRF and CORS disable flags to ENTRYPOINT
  ENTRYPOINT ["streamlit", "run", "src/image_classifier_app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableCORS=false", "--server.enableXsrfProtection=false"]
LICENSE ADDED
File without changes
README.md ADDED
@@ -0,0 +1,202 @@
+ # AI Image Caption Generator
+
+ A deep learning–based image captioning system built with a **ResNet50 encoder** and an **LSTM decoder**. The model generates natural-language descriptions for uploaded images.
+
+ ## Architecture
+
+ * **Encoder:** ResNet50 (frozen backbone)
+ * **Decoder:** LSTM-based sequence generator
+ * **Training Dataset:** Flickr8k
+ * **Inference Framework:** Streamlit
+ * **Evaluation Metric:** SacreBLEU
+
+ The encoder extracts high-level visual features, which are then passed to the decoder to generate captions word by word.
+
+ ---
+
+ ## How It Works
+
+ 1. The user uploads an image.
+ 2. The image is preprocessed and passed through the ResNet50 encoder.
+ 3. The extracted feature vector is fed into the LSTM decoder.
+ 4. A caption is generated using temperature-based sampling.
+ 5. If the image belongs to the Flickr8k dataset, BLEU metrics are displayed.
+
+ ---
+
+ ## Features
+
+ * Temperature-controlled caption generation
+ * SacreBLEU evaluation
+ * N-gram precision breakdown (1–4 gram)
+ * Clean Streamlit interface
+ * Fully CPU-compatible deployment
+
+ ---
+
+ ## Project Structure
+
+ ```
+ app.py
+ inference.py
+ models/
+   encoder.py
+   decoder.py
+   encoder.pth
+   decoder.pth
+   vocabulary.json
+ utils/
+   transforms.py
+   vocab.py
+   helpers.py
+ requirements.txt
+ ```
+
+ ---
+
+ ## Model Details
+
+ * Encoder weights size: ~92 MB
+ * Decoder weights size: ~32 MB
+ * Full encoder backbone included in the state_dict
+ * Inference runs on CPU
+
+ ---
+
+ ## Limitations
+
+ * Trained on Flickr8k (8,000 images)
+ * Performs best on outdoor scenes, people, and animals
+ * May generalize poorly to unseen domains
+ * CPU inference can be slow (2–5 seconds per image)
+
+ ---
+
+ ## Setup (Local)
+
+ ```bash
+ pip install -r requirements.txt
+ streamlit run app.py
+ ```
+
+ ---
+
+ ## Deployment
+
+ This project is deployed on **Hugging Face Spaces** using Streamlit.
+
+ ---
+
+ ## License
+
+ MIT License
app.py ADDED
@@ -0,0 +1,105 @@
+ import torch
+ import pandas as pd
+ import streamlit as st
+ from PIL import Image
+ from models.encoder import EncoderCNN
+ from models.decoder import DecoderRNN
+ from utils.vocab import Vocabulary
+ from torchvision import transforms as T
+ from utils.helpers import ENCODER_MODEL_PATH, DECODER_MODEL_PATH, VOCAB_FILE_PATH, CAPTIONS_TKN_PATH
+ from inference import sample_with_temp, sample
+ from utils.transforms import transforms
+ import sacrebleu
+ # ... (your other imports)
+
+ @st.cache_resource
+ def load_models():
+     captions = pd.read_csv(CAPTIONS_TKN_PATH).drop('tokens', axis=1)
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     # Load vocabulary
+     vocab = Vocabulary(load_path=VOCAB_FILE_PATH)
+
+     # Initialize models
+     encoder = EncoderCNN(256).to(device)
+     decoder = DecoderRNN(len(vocab), 256, 512).to(device)
+
+     # Load weights
+     encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location=device))
+     decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location=device))
+
+     encoder.eval()
+     decoder.eval()
+
+     return encoder, decoder, vocab, device, captions
+
+
+ # --- Main App Logic ---
+ encoder, decoder, vocab, device, captions = load_models()
+ act_caps = []
+ caption = ''
+ st.title("📸 AI Image Captioner")
+
+ temp = st.slider("Sampling Temperature", min_value=0.1, max_value=2.0, value=0.8, step=0.1)
+ st.info("Higher temperature = more creative/random. Lower = more predictable.")
+
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
+ if uploaded_file is not None:
+     img = Image.open(uploaded_file).convert('RGB')
+     st.image(img, caption='Uploaded Image', width=300)
+
+     # Preprocess with the same transforms the encoder backbone expects (utils/transforms.py)
+     img_tensor = transforms(img).unsqueeze(0).to(device)
+
+     st.subheader("Actual Captions:")
+     act_caps = captions[captions['image'] == uploaded_file.name]['caption'].tolist()
+     st.success(" \n".join(act_caps))
+
+     with torch.no_grad():
+         encoder_out = encoder(img_tensor)
+         # Pass the 'temp' value chosen on the slider
+         caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)
+
+     st.subheader("Generated Caption:")
+     st.success(caption)
+
+     if act_caps:
+         # sacrebleu expects a list of hypothesis strings plus one reference stream
+         # per reference; each stream must have the same length as the hypothesis list,
+         # so every ground-truth caption becomes its own single-element stream.
+         refs = [[c] for c in act_caps]
+         sys = [caption]
+
+         bleu = sacrebleu.corpus_bleu(sys, refs)
+
+         st.subheader("Evaluation Metrics:")
+         st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
+         st.progress(min(bleu.score / 50, 1.0))
+
+         # N-gram precision breakdown; bleu.precisions is a list: [p1, p2, p3, p4]
+         cols = st.columns(4)
+         for i, p in enumerate(bleu.precisions):
+             cols[i].markdown(f"{i+1}-gram")
+             cols[i].write(f"{p:.1f}%")
+
+         # Brief explanation
+         with st.expander("What do these mean?"):
+             st.write("""
+ - **1-gram**: Individual word accuracy (Vocabulary).
+ - **2-gram**: Fluency of word pairs.
+ - **4-gram**: Capturing longer phrases / sentence structure.
+ """)
+     else:
+         st.info("Upload an image from the Flickr8k set to see BLEU metrics.")
+
+ st.header('About this Project')
+ st.markdown("""
+ This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.
+
+ * **Encoder:** Pre-trained ResNet50 (frozen) extracts high-level visual features.
+ * **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
+ * **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).
+
+ ⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactorily on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
+ """)
inference.py ADDED
@@ -0,0 +1,111 @@
+ import torch
+ import torch.nn.functional as F
+
+ from utils.transforms import transforms
+ from utils.vocab import Vocabulary
+ from utils.helpers import VOCAB_FILE_PATH, ENCODER_MODEL_PATH, DECODER_MODEL_PATH
+ from models.encoder import EncoderCNN
+ from models.decoder import DecoderRNN
+ import PIL.Image as Image
+
+
+ def sample(features, decoder, vocab, max_len=20):
+     """Greedy decoding: always pick the most probable next word."""
+     device = features.device
+     result_caption = []
+     word_idx = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(device)  # shape (1, 1)
+     outputs, hidden = decoder(features, word_idx)  # outputs (1, 1, vocab_size)
+     for _ in range(max_len):
+         predicted = outputs.argmax(2)
+         word = vocab[predicted.item()]  # .item() extracts the numerical value from the tensor
+         if word == '<EOS>':
+             break
+         result_caption.append(word)
+         # Pass features=None and the previous hidden state
+         outputs, hidden = decoder(None, predicted, hidden)
+     return ' '.join(result_caption)
+
+
+ def beam_sample(features, decoder, vocab, beam_size=5, max_len=30):
+     """Beam search decoding: keep the beam_size best partial captions at each step."""
+     device = features.device
+     # Each beam is (log_score, sequence, hidden_state)
+     start_token = torch.tensor([vocab['<SOS>']]).to(device)
+     beams = [(0, [start_token.item()], None)]
+
+     for step in range(max_len):
+         candidates = []
+         for score, seq, hidden in beams:
+             if seq[-1] == vocab['<EOS>']:
+                 candidates.append((score, seq, hidden))
+                 continue
+
+             # Predict the next word
+             curr_word = torch.tensor([seq[-1]]).unsqueeze(0).to(device)
+             # Use the image features only on the first step
+             feat_input = features if step == 0 else None
+
+             outputs, next_hidden = decoder(feat_input, curr_word, hidden)
+
+             # Get log probabilities
+             log_probs = F.log_softmax(outputs.squeeze(1), dim=1)
+             top_probs, top_idxs = log_probs.topk(beam_size)
+
+             for i in range(beam_size):
+                 candidates.append((score + top_probs[0][i].item(),
+                                    seq + [top_idxs[0][i].item()],
+                                    next_hidden))
+
+         # Sort by score and keep the top k
+         beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam_size]
+
+         # Stop if all beams end in <EOS>
+         if all(s[-1] == vocab['<EOS>'] for _, s, _ in beams):
+             break
+
+     # Return the best sequence (minus special tokens)
+     best_seq = beams[0][1]
+     return ' '.join([vocab[idx] for idx in best_seq if idx not in [vocab['<SOS>'], vocab['<EOS>']]])
+
+
+ def sample_with_temp(features, decoder, vocab, temp=0.8, max_len=30):
+     """Stochastic decoding: sample each word from a temperature-scaled distribution."""
+     device = features.device
+     result_caption = []
+     word_idx = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(device)
+     outputs, hidden = decoder(features, word_idx)  # outputs (1, 1, vocab_size)
+     for _ in range(max_len):
+         # Apply temperature to the logits
+         logits = outputs.squeeze(1) / temp
+         probs = F.softmax(logits, dim=-1)
+         # Sample from the distribution instead of taking the argmax
+         predicted = torch.multinomial(probs, 1)
+         word = vocab[predicted.item()]
+         if word == '<EOS>':
+             break
+         result_caption.append(word)
+         outputs, hidden = decoder(None, predicted, hidden)
+     return ' '.join(result_caption)
+
+
+ if __name__ == '__main__':
+     # Quick smoke test, guarded so that importing this module (e.g. from app.py)
+     # does not load the models and run inference at import time.
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     vocabulary = Vocabulary(load_path=VOCAB_FILE_PATH)
+     # img = Image.open(COLAB_DATA_FOLDER + 'Images/' + '141140165_9002a04f19.jpg').convert('RGB')
+
+     encoder = EncoderCNN(256).to(device)
+     decoder = DecoderRNN(len(vocabulary), 256, 512).to(device)
+
+     encoder_state_dict = torch.load(ENCODER_MODEL_PATH, map_location=device)
+     decoder_state_dict = torch.load(DECODER_MODEL_PATH, map_location=device)
+
+     encoder.load_state_dict(encoder_state_dict)
+     decoder.load_state_dict(decoder_state_dict)
+
+     encoder.eval()
+     decoder.eval()
+
+     img = Image.open('data/flickr_data/Images/3718892835_a3e74a3417.jpg').convert('RGB')
+     img = transforms(img).unsqueeze(0).to(device)
+     encoder_out = encoder(img)
+     print('sample_with_temp: ', sample_with_temp(encoder_out, decoder, vocabulary))
+     # print('sample: ', sample(encoder_out, decoder, vocabulary))
+     # print('beam_sample: ', beam_sample(encoder_out, decoder, vocabulary))
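
To make the `temp` parameter of `sample_with_temp` concrete: dividing the logits by the temperature before the softmax sharpens the distribution for values below 1 and flattens it for values above 1. A tiny self-contained illustration with made-up logits (no model required):

```python
import torch
import torch.nn.functional as F

# Toy logits over a 5-word vocabulary (illustrative values only).
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])

for temp in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temp, dim=-1)
    print(f"temp={temp}:", [round(p, 3) for p in probs.squeeze().tolist()])

# temp=0.5 concentrates mass on the top word (predictable captions);
# temp=2.0 spreads it out (more varied, occasionally incoherent captions).
```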
models/decoder.py ADDED
@@ -0,0 +1,24 @@
+ import torch.nn as nn
+
+
+ class DecoderRNN(nn.Module):
+     def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
+         super(DecoderRNN, self).__init__()
+         self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=padding_idx)
+         self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+
+         # Project the image embedding into the initial LSTM hidden and cell states
+         self.init_h = nn.Linear(embed_size, hidden_size)
+         self.init_c = nn.Linear(embed_size, hidden_size)
+
+     def forward(self, features, captions, hidden=None):
+         if hidden is None:
+             # Dataflow: (B, embed_size) -> (B, hidden_size) -> (1, B, hidden_size) -> (num_layers, B, hidden_size)
+             h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
+             c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
+             hidden = (h0, c0)
+
+         embeddings = self.embed(captions)                # (B, seqlen) -> training: (B, seqlen, embed_size) | inference: (B, 1, embed_size)
+         outputs, hidden = self.lstm(embeddings, hidden)  # training: (B, seqlen, hidden_size) | inference: (B, 1, hidden_size)
+         outputs = self.linear(outputs)                   # training: (B, seqlen, vocab_size) | inference: (B, 1, vocab_size)
+         return outputs, hidden
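
A quick shape check for the decoder, using random tensors and the 256/512 sizes that `app.py` passes in (the vocabulary size of 3000 is just a placeholder):

```python
import torch
from models.decoder import DecoderRNN

vocab_size, embed_size, hidden_size = 3000, 256, 512  # vocab_size is illustrative
decoder = DecoderRNN(vocab_size, embed_size, hidden_size)

# Training-style call: image features plus the full caption; the hidden state
# is initialised from the features because hidden=None.
features = torch.randn(4, embed_size)             # (B, embed_size) from the encoder
captions = torch.randint(0, vocab_size, (4, 12))  # (B, seqlen) token ids
outputs, hidden = decoder(features, captions)
print(outputs.shape)  # torch.Size([4, 12, 3000])

# Inference-style call: one token at a time, reusing the previous hidden state.
next_token = outputs[:, -1:].argmax(-1)           # (B, 1)
outputs, hidden = decoder(None, next_token, hidden)
print(outputs.shape)  # torch.Size([4, 1, 3000])
```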
models/encoder.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:15c2fd0e6a24e1b58ba66cd7f9754d9d9befd885e1aea762d0016b6d5f8d351c
+ size 96454389
models/encoder.py ADDED
@@ -0,0 +1,24 @@
+ from torchvision.models import resnet50, ResNet50_Weights
+ import torch.nn as nn
+
+
+ class EncoderCNN(nn.Module):
+     def __init__(self, embed_size, fine_tune=False):
+         super(EncoderCNN, self).__init__()
+         # ImageNet weights are downloaded only when fine-tuning; for pure inference the
+         # full backbone is restored from the saved state_dict anyway.
+         resnet = resnet50(weights=ResNet50_Weights.DEFAULT if fine_tune else None)
+         for param in resnet.parameters():
+             param.requires_grad = False
+         if fine_tune:
+             for param in resnet.layer4.parameters():
+                 param.requires_grad = True
+         backbone = list(resnet.children())[:-1]  # drop the final classification layer
+
+         self.resnet = nn.Sequential(*backbone)
+         self.fc = nn.Linear(resnet.fc.in_features, embed_size)
+         self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)
+
+     def forward(self, images):                              # (B, C, H, W)
+         features = self.resnet(images)                      # (B, 2048, 1, 1)
+         features = features.reshape(features.shape[0], -1)  # flatten to (B, 2048) so the fc layer sees its expected input size
+         return self.bn(self.fc(features))                   # (B, embed_size)
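
The matching shape check for the encoder (random input batch; `eval()` is called first, as `app.py` does, so BatchNorm uses its running statistics):

```python
import torch
from models.encoder import EncoderCNN

encoder = EncoderCNN(embed_size=256)  # fine_tune=False: no ImageNet download needed
encoder.eval()

images = torch.randn(2, 3, 224, 224)  # (B, C, H, W), as produced by utils/transforms.py
with torch.no_grad():
    features = encoder(images)
print(features.shape)  # torch.Size([2, 256])
```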
models/vocabulary.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ pandas==3.0.1
+ Pillow==12.1.1
+ sacrebleu==2.6.0
+ streamlit==1.55.0
+ torch==2.10.0
+ torchvision==0.25.0
utils/helpers.py ADDED
@@ -0,0 +1,14 @@
+ # from enum import Enum
+ import os
+
+
+ DATA_DIR = 'data'
+ CAPTIONS_FILE_PATH = os.path.join(DATA_DIR, 'flickr_data/captions.txt')
+ CAPTIONS_TKN_PATH = os.path.join(DATA_DIR, 'captions_tokenized.csv')
+ IMAGES_PATH = os.path.join(DATA_DIR, 'flickr_data/Images')
+ TOKANIZED_CAPTIONS = os.path.join(DATA_DIR, 'captions_tokenized.csv')
+
+ MODELS_PATH = 'models'
+ ENCODER_MODEL_PATH = os.path.join(MODELS_PATH, 'encoder.pth')
+ DECODER_MODEL_PATH = os.path.join(MODELS_PATH, 'decoder.pth')
+ VOCAB_FILE_PATH = os.path.join(MODELS_PATH, 'vocabulary.json')
utils/transforms.py ADDED
@@ -0,0 +1,3 @@
+ from torchvision.models import ResNet50_Weights
+
+ # Reuse the preprocessing pipeline the pretrained ResNet50 backbone expects
+ transforms = ResNet50_Weights.DEFAULT.transforms()
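
`ResNet50_Weights.DEFAULT.transforms()` returns the preprocessing preset the pretrained backbone expects; for current torchvision weights this is roughly a resize, a 224×224 centre crop, conversion to a float tensor, and ImageNet mean/std normalisation. A quick way to inspect and apply it:

```python
from PIL import Image
from torchvision.models import ResNet50_Weights

preprocess = ResNet50_Weights.DEFAULT.transforms()
print(preprocess)  # prints the resize/crop sizes and the normalisation mean/std

# Applying it to a PIL image yields the (C, H, W) float tensor the encoder expects;
# add a batch dimension before calling the model, as app.py does.
img = Image.new("RGB", (640, 480))     # placeholder image, for illustration only
tensor = preprocess(img).unsqueeze(0)  # (1, 3, 224, 224)
print(tensor.shape)
```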
utils/vocab.py ADDED
@@ -0,0 +1,41 @@
+ import json
+ from collections import Counter
+
+
+ class Vocabulary():
+     SPECIAL_TOKENS = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
+
+     def __init__(self, df=None, load_path=None, min_freq=1):
+         if load_path:
+             with open(load_path, 'r') as f:
+                 self.stoi = json.load(f)
+         else:
+             # token_freq = df.explode('tokens').value_counts()
+             # More efficient than df.explode for large datasets
+             counts = Counter([token for tokens in df['tokens'] for token in tokens])
+             self.stoi = {tok: i for i, tok in enumerate(self.SPECIAL_TOKENS)}
+             for token, freq in counts.items():
+                 if freq >= min_freq:
+                     self.stoi[token] = len(self.stoi)
+
+         self.itos = {i: s for s, i in self.stoi.items()}
+
+     def __len__(self):
+         return len(self.stoi)
+
+     def __getitem__(self, key):
+         if isinstance(key, str):
+             return self.stoi.get(key, self.stoi['<UNK>'])
+         elif isinstance(key, int):
+             return self.itos.get(key, '<UNK>')
+
+     def numericalize(self, tokens):
+         return [self[token] for token in tokens]
+
+     def texualize(self, indices):
+         return [self[idx] for idx in indices]
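
A small round-trip example of the `Vocabulary` class, using a toy DataFrame in place of the real tokenised captions:

```python
import pandas as pd
from utils.vocab import Vocabulary

# Toy stand-in for the tokenised caption DataFrame used during training.
df = pd.DataFrame({"tokens": [["a", "dog", "runs"], ["a", "cat", "sleeps"]]})

vocab = Vocabulary(df=df, min_freq=1)
print(len(vocab))            # 4 special tokens + 5 unique words = 9

ids = vocab.numericalize(["a", "dog", "flies"])
print(ids)                   # "flies" is unseen, so it maps to the <UNK> index
print(vocab.texualize(ids))  # ['a', 'dog', '<UNK>']
```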