Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +35 -35
- Dockerfile +23 -23
- LICENSE +0 -0
- README.md +202 -0
- app.py +105 -0
- inference.py +111 -0
- models/decoder.py +24 -0
- models/encoder.pth +3 -0
- models/encoder.py +24 -0
- models/vocabulary.json +0 -0
- requirements.txt +6 -0
- utils/helpers.py +14 -0
- utils/transforms.py +3 -0
- utils/vocab.py +41 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,35 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
CHANGED
|
FROM python:3.13.5-slim

WORKDIR /app

# Build tools for packages that need compilation; curl for the healthcheck.
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Install dependencies first so code edits don't invalidate the pip layer.
COPY requirements.txt ./
RUN pip3 install -r requirements.txt

# FIX(review): the original copied src/ plus resnet18_cifar10_finetuned.pth and
# launched src/image_classifier_app.py — leftovers from a different project.
# This Space's entry point is app.py with models/ and utils/ (see README).
COPY app.py inference.py ./
COPY models/ ./models/
COPY utils/ ./utils/

# Hugging Face Spaces expects the app to listen on port 7860.
EXPOSE 7860

HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health

# XSRF/CORS disabled so the Space iframe can talk to Streamlit.
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableCORS=false", "--server.enableXsrfProtection=false"]
|
LICENSE
ADDED
|
File without changes
|
README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
\# AI Image Caption Generator
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
A deep learning–based image captioning system built using a \*\*ResNet50 encoder\*\* and an \*\*LSTM decoder\*\*. The model generates natural language descriptions for uploaded images.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
\## Architecture
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
\* \*\*Encoder:\*\* ResNet50 (frozen backbone)
|
| 22 |
+
|
| 23 |
+
\* \*\*Decoder:\*\* LSTM-based sequence generator
|
| 24 |
+
|
| 25 |
+
\* \*\*Training Dataset:\*\* Flickr8k
|
| 26 |
+
|
| 27 |
+
\* \*\*Inference Framework:\*\* Streamlit
|
| 28 |
+
|
| 29 |
+
\* \*\*Evaluation Metric:\*\* SacreBLEU
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
The encoder extracts high-level visual features, which are then passed to the decoder to generate captions word by word.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
---
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
\## How It Works
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
1\. User uploads an image.
|
| 46 |
+
|
| 47 |
+
2\. Image is preprocessed and passed through the ResNet50 encoder.
|
| 48 |
+
|
| 49 |
+
3\. Extracted feature vector is fed into the LSTM decoder.
|
| 50 |
+
|
| 51 |
+
4\. Caption is generated using temperature-based sampling.
|
| 52 |
+
|
| 53 |
+
5\. If the image belongs to the Flickr8k dataset, BLEU metrics are displayed.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
---
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
\## Features
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
\* Temperature-controlled caption generation
|
| 66 |
+
|
| 67 |
+
\* SacreBLEU evaluation
|
| 68 |
+
|
| 69 |
+
\* N-gram precision breakdown (1–4 gram)
|
| 70 |
+
|
| 71 |
+
\* Clean Streamlit interface
|
| 72 |
+
|
| 73 |
+
\* Fully CPU-compatible deployment
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
\## Project Structure
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
```
app.py
inference.py
models/
    encoder.py
    decoder.py
    encoder.pth
    decoder.pth
    vocabulary.json
utils/
    transforms.py
    vocab.py
    helpers.py
requirements.txt
```
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
\## Model Details
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
\* Encoder weights size: ~92 MB
|
| 126 |
+
|
| 127 |
+
\* Decoder weights size: ~32 MB
|
| 128 |
+
|
| 129 |
+
\* Full encoder backbone included in state\_dict
|
| 130 |
+
|
| 131 |
+
\* Inference runs on CPU
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
\## Limitations
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
\* Trained on Flickr8k (8,000 images)
|
| 144 |
+
|
| 145 |
+
\* Performs best on outdoor scenes, people, and animals
|
| 146 |
+
|
| 147 |
+
\* May generalize poorly to unseen domains
|
| 148 |
+
|
| 149 |
+
\* CPU inference can be slow (2–5 seconds per image)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
---
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
\## Setup (Local)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
|
| 163 |
+
pip install -r requirements.txt
|
| 164 |
+
|
| 165 |
+
streamlit run app.py
|
| 166 |
+
|
| 167 |
+
```
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
---
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
\## Deployment
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
This project is deployed on \*\*Hugging Face Spaces\*\* using Streamlit.
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
---
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
\## License
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
MIT License
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
app.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import pandas as pd
import streamlit as st
from PIL import Image
from models.encoder import EncoderCNN
from models.decoder import DecoderRNN
from utils.vocab import Vocabulary
from torchvision import transforms as T  # NOTE(review): unused here; kept to avoid changing the module's imports
from utils.helpers import ENCODER_MODEL_PATH, DECODER_MODEL_PATH, VOCAB_FILE_PATH, CAPTIONS_TKN_PATH
from inference import sample_with_temp, sample  # NOTE(review): `sample` (greedy decoding) is currently unused
from utils.transforms import transforms
import sacrebleu


@st.cache_resource
def load_models():
    """Load the reference captions, vocabulary and trained encoder/decoder.

    Cached once per process by Streamlit, so UI reruns do not reload weights.

    Returns:
        tuple: (encoder, decoder, vocab, device, captions) where ``captions``
        is a DataFrame with at least 'image' and 'caption' columns.
    """
    captions = pd.read_csv(CAPTIONS_TKN_PATH).drop('tokens', axis=1)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Vocabulary saved at training time; sizes below must match it.
    vocab = Vocabulary(load_path=VOCAB_FILE_PATH)

    # Initialize models (embed=256, hidden=512 — the training configuration).
    encoder = EncoderCNN(256).to(device)
    decoder = DecoderRNN(len(vocab), 256, 512).to(device)

    # Load trained weights.
    encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location=device))
    decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location=device))

    encoder.eval()
    decoder.eval()

    return encoder, decoder, vocab, device, captions


# --- Main App Logic ---
encoder, decoder, vocab, device, captions = load_models()
act_caps = []
caption = ''
st.title("📸 AI Image Captioner")

temp = st.slider("Sampling Temperature", min_value=0.1, max_value=2.0, value=0.8, step=0.1)
st.info("Higher temperature = more creative/random. Lower = more predictable.")

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    img = Image.open(uploaded_file).convert('RGB')
    st.image(img, caption='Uploaded Image', width=300)

    # Preprocess exactly as the pretrained ResNet50 expects.
    img_tensor = transforms(img).unsqueeze(0).to(device)

    st.subheader("Actual Captions:")
    # Reference captions exist only for images from the Flickr8k set.
    act_caps = captions[captions['image'] == uploaded_file.name]['caption'].tolist()
    st.success(" \n".join(act_caps))

    with torch.no_grad():
        encoder_out = encoder(img_tensor)
        # `temp` comes from the slider above.
        caption = sample_with_temp(encoder_out, decoder, vocab, temp=temp)

    st.subheader("Generated Caption:")
    st.success(caption)

    if act_caps:
        # sacrebleu's corpus API takes a list of hypotheses and a list of
        # reference *streams*, each stream holding one reference per hypothesis.
        # FIX: the original passed [act_caps] — a single stream of N sentences
        # against 1 hypothesis — which mismatches stream lengths whenever an
        # image has more than one reference caption.
        refs = [[c] for c in act_caps]
        sys = [caption]

        bleu = sacrebleu.corpus_bleu(sys, refs)

        st.subheader("Evaluation Metrics:")
        st.metric(label="SacreBLEU Score", value=f"{bleu.score:.2f}")
        # 50 BLEU is treated as a "full bar" for display purposes only.
        st.progress(min(bleu.score / 50, 1.0))

        # N-gram precision breakdown: bleu.precisions is [p1, p2, p3, p4].
        cols = st.columns(4)
        for i, p in enumerate(bleu.precisions):
            cols[i].markdown(f"{i+1}-gram")
            cols[i].write(f"{p:.1f}%")

        with st.expander("What do these mean?"):
            st.write("""
            - **1-gram**: Individual word accuracy (Vocabulary).
            - **2-gram**: Fluency of word pairs.
            - **4-gram**: Capturing longer phrases/sentence structure.
            """)
    else:
        st.info("Upload an image from the Flickr8k set to see BLEU metrics.")

st.header('About this Project')
st.markdown("""
This AI model generates descriptive captions for uploaded images using a **ResNet50 + LSTM** architecture.

* **Encoder:** Pre-trained ResNet50 (Frozen) extracts high-level visual features.
* **Decoder:** A Long Short-Term Memory (LSTM) network trained for 10 epochs.
* **Dataset:** Trained on the **Flickr8k dataset** (8,000 images).

⚠️ **Note:** Because the model was trained on a specific, small-scale dataset with a frozen backbone, it performs satisfactory on outdoor scenes, people, and animals. It may produce unexpected results for images significantly different from the Flickr8k distribution.
""")
|
inference.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from utils.transforms import transforms
|
| 5 |
+
from utils.vocab import Vocabulary
|
| 6 |
+
from utils.helpers import VOCAB_FILE_PATH, ENCODER_MODEL_PATH, DECODER_MODEL_PATH
|
| 7 |
+
from models.encoder import EncoderCNN
|
| 8 |
+
from models.decoder import DecoderRNN
|
| 9 |
+
import PIL.Image as Image
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def sample(features, decoder, vocab, max_len=20):
    """Greedy caption decoding: take the argmax token at every step.

    Args:
        features: encoder output; its ``.device`` decides tensor placement.
        decoder: callable returning ``(logits, hidden)``; features are passed
            only on the first call, afterwards only the previous token.
        vocab: mapping supporting both str->index and index->str lookups.
        max_len: maximum number of generated words.

    Returns:
        The generated caption as a space-joined string (no special tokens).
    """
    dev = features.device
    sos = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(dev)  # shape (1, 1)
    outputs, hidden = decoder(features, sos)  # outputs: (1, 1, vocab_size)
    words = []
    for _step in range(max_len):
        token = outputs.argmax(2)
        candidate = vocab[token.item()]  # .item() converts tensor -> int key
        if candidate == '<EOS>':
            break
        words.append(candidate)
        # Subsequent steps: no features, carry the hidden state forward.
        outputs, hidden = decoder(None, token, hidden)
    return ' '.join(words)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def beam_sample(features, decoder, vocab, beam_size=5, max_len=30):
    """Beam-search decoding: keep the ``beam_size`` highest log-probability
    partial captions at each step and return the best complete one.

    Args:
        features: encoder output; its ``.device`` decides tensor placement.
        decoder: callable returning ``(logits, hidden)``.
        vocab: mapping supporting str->index and index->str lookups.
        beam_size: number of hypotheses kept per step (also the per-beam fanout).
        max_len: hard cap on caption length.

    Returns:
        The best caption as a space-joined string, <SOS>/<EOS> stripped.
    """
    device = features.device
    # Each beam is a tuple: (log_score, token sequence, LSTM hidden state)
    start_token = torch.tensor([vocab['<SOS>']]).to(device)
    beams = [(0, [start_token.item()], None)]

    for _ in range(max_len):
        candidates = []
        for score, seq, hidden in beams:
            # Finished beams are carried over unchanged.
            if seq[-1] == vocab['<EOS>']:
                candidates.append((score, seq, hidden))
                continue

            # Predict the next word from this beam's last token.
            curr_word = torch.tensor([seq[-1]]).unsqueeze(0).to(device)
            # Features seed the hidden state only on the very first step.
            # NOTE(review): this reuses the outer loop variable `_` as the
            # step index — correct here, but fragile if the loop is refactored.
            feat_input = features if _ == 0 else None

            outputs, next_hidden = decoder(feat_input, curr_word, hidden)

            # Log-probabilities so beam scores can be summed across steps.
            log_probs = F.log_softmax(outputs.squeeze(1), dim=1)
            top_probs, top_idxs = log_probs.topk(beam_size)

            # Expand this beam with its beam_size best continuations.
            for i in range(beam_size):
                candidates.append((score + top_probs[0][i].item(),
                                   seq + [top_idxs[0][i].item()],
                                   next_hidden))

        # Sort by score and keep top k
        beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam_size]

        # Stop early once every surviving beam has emitted <EOS>.
        if all(s[-1] == vocab['<EOS>'] for _, s, _ in beams):
            break

    # Return the best sequence with the special tokens stripped.
    best_seq = beams[0][1]
    return ' '.join([vocab[idx] for idx in best_seq if idx not in [vocab['<SOS>'], vocab['<EOS>']]])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def sample_with_temp(features, decoder, vocab, temp=0.8, max_len=30):
    """Stochastic decoding: sample each next token from the temperature-scaled
    softmax instead of taking the argmax.

    Args:
        features: encoder output; its ``.device`` decides tensor placement.
        decoder: callable returning ``(logits, hidden)``.
        vocab: mapping supporting str->index and index->str lookups.
        temp: softmax temperature; lower values are closer to greedy decoding.
        max_len: maximum number of generated words.

    Returns:
        The sampled caption as a space-joined string (no special tokens).
    """
    dev = features.device
    start = torch.tensor([vocab['<SOS>']]).unsqueeze(0).to(dev)
    outputs, hidden = decoder(features, start)  # outputs: (1, 1, vocab_size)
    words = []
    for _step in range(max_len):
        # Temperature-scale the logits, then sample from the distribution.
        scaled = outputs.squeeze(1) / temp
        dist = F.softmax(scaled, dim=-1)
        choice = torch.multinomial(dist, 1)
        token = vocab[choice.item()]
        if token == '<EOS>':
            break
        words.append(token)
        outputs, hidden = decoder(None, choice, hidden)
    return ' '.join(words)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def _run_demo():
    """Smoke-test the captioner on a local Flickr8k image (dev machines only)."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    vocabulary = Vocabulary(load_path=VOCAB_FILE_PATH)

    encoder = EncoderCNN(256).to(device)
    decoder = DecoderRNN(len(vocabulary), 256, 512).to(device)

    encoder.load_state_dict(torch.load(ENCODER_MODEL_PATH, map_location=device))
    decoder.load_state_dict(torch.load(DECODER_MODEL_PATH, map_location=device))

    encoder.eval()
    decoder.eval()

    img = Image.open('data/flickr_data/Images/3718892835_a3e74a3417.jpg').convert('RGB')
    img = transforms(img).unsqueeze(0).to(device)
    encoder_out = encoder(img)
    print('sample_with_temp: ', sample_with_temp(encoder_out, decoder, vocabulary))
    # print('sample: ', sample(encoder_out, decoder, vocabulary))
    # print('beam_sample: ', beam_sample(encoder_out, decoder, vocabulary))


# FIX: this demo previously ran at module import time, so app.py's
# `from inference import ...` re-loaded the models a second time and crashed
# when the hard-coded local image path was absent (as on the deployed Space).
if __name__ == '__main__':
    _run_demo()
|
models/decoder.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class DecoderRNN(nn.Module):
    """LSTM caption decoder.

    The encoder's feature vector initialises the LSTM hidden/cell states via
    two linear projections; after that, each step feeds back only tokens.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1, padding_idx=0):
        super(DecoderRNN, self).__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size, padding_idx=padding_idx)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Project the image feature vector into the initial LSTM state.
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, features, captions, hidden=None):
        """Run one decoding pass.

        Args:
            features: (B, embed_size) image features; required when ``hidden``
                is None, ignored otherwise (callers pass None after step one).
            captions: (B, seqlen) token indices.
            hidden: optional (h, c) state from the previous step.

        Returns:
            (logits of shape (B, seqlen, vocab_size), new (h, c) state).
        """
        # FIX: identity check instead of `hidden == None` — `hidden` is a
        # tuple of tensors, and `is None` is both idiomatic and unambiguous.
        if hidden is None:
            # (B, embed_size) -> (1, B, hidden_size) -> (num_layers, B, hidden_size)
            h0 = self.init_h(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            c0 = self.init_c(features).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
            hidden = (h0, c0)

        embeddings = self.embed(captions)                # training: (B, T, E) | inference: (B, 1, E)
        outputs, hidden = self.lstm(embeddings, hidden)  # (B, T, hidden_size)
        outputs = self.linear(outputs)                   # (B, T, vocab_size)
        return outputs, hidden
|
models/encoder.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:15c2fd0e6a24e1b58ba66cd7f9754d9d9befd885e1aea762d0016b6d5f8d351c
|
| 3 |
+
size 96454389
|
models/encoder.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.models import resnet50, ResNet50_Weights
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class EncoderCNN(nn.Module):
    """ResNet50 image encoder producing a fixed-size embedding.

    NOTE(review): pretrained ImageNet weights are requested only when
    ``fine_tune`` is True; with ``fine_tune=False`` the backbone starts
    randomly initialised and is expected to be filled by a later
    ``load_state_dict`` — confirm this is intentional before training anew.
    """

    def __init__(self, embed_size, fine_tune=False):
        super(EncoderCNN, self).__init__()
        weights = ResNet50_Weights.DEFAULT if fine_tune else None
        resnet = resnet50(weights=weights)

        # Freeze the whole backbone, then optionally unfreeze the last stage.
        for p in resnet.parameters():
            p.requires_grad = False
        if fine_tune:
            for p in resnet.layer4.parameters():
                p.requires_grad = True

        # Keep everything except the classification head.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.fc = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(num_features=embed_size, momentum=0.01)

    def forward(self, images):
        """Map a batch of images (B, C, H, W) to embeddings (B, embed_size)."""
        pooled = self.resnet(images)                # (B, 2048, 1, 1) after avgpool
        flat = pooled.reshape(pooled.shape[0], -1)  # (B, 2048)
        return self.bn(self.fc(flat))
|
| 24 |
+
|
models/vocabulary.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NOTE(review): the original exact pins (pandas==3.0.1, Pillow==12.1.1,
# sacrebleu==2.6.0, streamlit==1.55.0, torch==2.10.0, torchvision==0.25.0)
# do not appear to correspond to released versions — verify against PyPI.
# Minimum bounds used instead so the image builds reproducibly.
pandas>=2.2
Pillow>=10.0
sacrebleu>=2.4
streamlit>=1.35
torch>=2.2
torchvision>=0.17
|
utils/helpers.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os


# Central path constants shared by app.py, inference.py and the training code.
# All paths are relative to the repository root / Space working directory.
DATA_DIR = 'data'
CAPTIONS_FILE_PATH = os.path.join(DATA_DIR, 'flickr_data/captions.txt')  # raw Flickr8k captions
CAPTIONS_TKN_PATH = os.path.join(DATA_DIR, 'captions_tokenized.csv')     # tokenized captions CSV
IMAGES_PATH = os.path.join(DATA_DIR, 'flickr_data/Images')
# NOTE(review): duplicate of CAPTIONS_TKN_PATH (and "TOKANIZED" is a typo);
# kept for backward compatibility — migrate callers to CAPTIONS_TKN_PATH.
TOKANIZED_CAPTIONS = os.path.join(DATA_DIR, 'captions_tokenized.csv')

MODELS_PATH = 'models'
ENCODER_MODEL_PATH = os.path.join(MODELS_PATH, 'encoder.pth')
DECODER_MODEL_PATH = os.path.join(MODELS_PATH, 'decoder.pth')
VOCAB_FILE_PATH = os.path.join(MODELS_PATH, 'vocabulary.json')
|
utils/transforms.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision.models import ResNet50_Weights

# The exact preprocessing pipeline the pretrained ResNet50 expects
# (resize/center-crop plus ImageNet mean/std normalisation), reused by both
# the Streamlit app and the inference script.
# NOTE: this name shadows `torchvision.transforms` for importers of this module.
transforms = ResNet50_Weights.DEFAULT.transforms()
|
utils/vocab.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
from collections import Counter
|
| 4 |
+
|
| 5 |
+
class Vocabulary():
    """Bidirectional token <-> index mapping for the caption model.

    Built either from a tokenized captions frame (``df['tokens']`` holding
    lists of tokens) or loaded from a JSON file produced at training time.
    """

    # Reserved indices 0-3; order matters (the decoder uses padding_idx=0).
    SPECIAL_TOKENS = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]

    def __init__(self, df=None, load_path=None, min_freq=1):
        """Create the vocabulary.

        Args:
            df: object with a 'tokens' column/key of token lists (used only
                when ``load_path`` is not given).
            load_path: path to a JSON file mapping token -> index.
            min_freq: minimum corpus frequency for a token to be kept.
        """
        if load_path:
            with open(load_path, 'r') as f:
                self.stoi = json.load(f)
        else:
            # One-pass counting; cheaper than df.explode for large datasets.
            counts = Counter(token for tokens in df['tokens'] for token in tokens)
            self.stoi = {tok: i for i, tok in enumerate(self.SPECIAL_TOKENS)}
            for token, freq in counts.items():
                if freq >= min_freq:
                    self.stoi[token] = len(self.stoi)

        # Inverse mapping for index -> token lookups.
        self.itos = {i: s for s, i in self.stoi.items()}

    def __len__(self):
        return len(self.stoi)

    def __getitem__(self, key):
        """str -> index (unknown tokens map to <UNK>); int -> token string."""
        if isinstance(key, str):
            return self.stoi.get(key, self.stoi['<UNK>'])
        elif isinstance(key, int):
            return self.itos.get(key, '<UNK>')
        # FIX: previously fell through and returned None silently, which
        # surfaced later as a confusing failure in ' '.join(...).
        raise TypeError(f'Vocabulary keys must be str or int, got {type(key).__name__}')

    def numericalize(self, tokens):
        """Convert a list of token strings to a list of indices."""
        return [self[token] for token in tokens]

    def texualize(self, indices):
        """Convert indices back to token strings (name kept for compatibility)."""
        return [self[idx] for idx in indices]
|
| 41 |
+
|