Spaces:
Sleeping
Sleeping
Commit ·
a745a5e
0
Parent(s):
Deploy Streamlit app
Browse files- .gitattributes +35 -0
- .gitignore +57 -0
- .streamlit/config.toml +8 -0
- Dockerfile +20 -0
- README.md +149 -0
- app.py +469 -0
- app/streamlit_app.py +145 -0
- beam_search_experiments.py +139 -0
- create_subset_20k.py +16 -0
- dataset_384.py +46 -0
- dataset_advanced.py +97 -0
- dataset_git.py +68 -0
- dataset_vit_gpt2.py +77 -0
- evaluate.py +149 -0
- plot/beam_experiment_plot.py +27 -0
- plot/caption_length_analysis.py +58 -0
- requirements.txt +20 -0
- src/__init__.py +7 -0
- src/data/__init__.py +4 -0
- src/data/coco_384_dataset.py +52 -0
- src/data/coco_advanced_dataset.py +98 -0
- src/data/coco_vit_gpt2_dataset.py +87 -0
- src/evaluation/__init__.py +4 -0
- src/evaluation/cider_eval.py +65 -0
- src/streamlit_app.py +40 -0
- src/training/__init__.py +4 -0
- src/training/train_phase1.py +168 -0
- src/training/train_phase2.py +170 -0
- src/utils/__init__.py +4 -0
- src/utils/data_subset.py +64 -0
- train_blip_20k_384.py +103 -0
- train_data_experiments.py +225 -0
- train_git.py +127 -0
- train_phase2.py +219 -0
- train_vit_gpt2.py +194 -0
- uploadtohf.py +48 -0
.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.mypy_cache/
|
| 5 |
+
.ruff_cache/
|
| 6 |
+
.DS_Store
|
| 7 |
+
|
| 8 |
+
# Local datasets
|
| 9 |
+
train2017/
|
| 10 |
+
val2017/
|
| 11 |
+
annotations/*.jsonl
|
| 12 |
+
|
| 13 |
+
# Local model artifacts (don't commit)
|
| 14 |
+
saved_model_phase1/
|
| 15 |
+
saved_model_phase2/
|
| 16 |
+
saved_vit_gpt2/
|
| 17 |
+
saved_git_model/
|
| 18 |
+
saved_model_phase2.bak/
|
| 19 |
+
saved_vit_gpt2.bak/
|
| 20 |
+
saved_git_model.bak/
|
| 21 |
+
|
| 22 |
+
# Hugging Face cache (optional)
|
| 23 |
+
.cache/
|
| 24 |
+
hf_cache/
|
| 25 |
+
hf_home/
|
| 26 |
+
# virtual environment
|
| 27 |
+
.venv/
|
| 28 |
+
|
| 29 |
+
# python cache
|
| 30 |
+
__pycache__/
|
| 31 |
+
*.pyc
|
| 32 |
+
|
| 33 |
+
# datasets
|
| 34 |
+
train2017/
|
| 35 |
+
val2017/
|
| 36 |
+
annotations/
|
| 37 |
+
|
| 38 |
+
# trained models
|
| 39 |
+
saved_model/
|
| 40 |
+
|
| 41 |
+
saved_model_20k_384/
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# checkpoints
|
| 46 |
+
checkpoints/
|
| 47 |
+
checkpoints_20k_384/
|
| 48 |
+
|
| 49 |
+
# mac files
|
| 50 |
+
.DS_Store
|
| 51 |
+
|
| 52 |
+
# Hugging Face Spaces: avoid binaries in git
|
| 53 |
+
*.png
|
| 54 |
+
*.jpg
|
| 55 |
+
*.jpeg
|
| 56 |
+
*.gif
|
| 57 |
+
*.webp
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[server]
|
| 2 |
+
headless = true
|
| 3 |
+
enableCORS = false
|
| 4 |
+
enableXsrfProtection = true
|
| 5 |
+
|
| 6 |
+
[browser]
|
| 7 |
+
gatherUsageStats = false
|
| 8 |
+
|
Dockerfile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.13.5-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
curl \
|
| 8 |
+
git \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt ./
|
| 12 |
+
COPY src/ ./src/
|
| 13 |
+
|
| 14 |
+
RUN pip3 install -r requirements.txt
|
| 15 |
+
|
| 16 |
+
EXPOSE 8501
|
| 17 |
+
|
| 18 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
+
|
| 20 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image Captioning (Streamlit)
|
| 2 |
+
|
| 3 |
+
This repo hosts a Streamlit app (`app.py`) that compares multiple image-captioning models.
|
| 4 |
+
|
| 5 |
+
## Why your models should NOT be inside the app repo
|
| 6 |
+
|
| 7 |
+
Fine-tuned checkpoints are large. Public hosting (Hugging Face Spaces / Streamlit Cloud) works best when:
|
| 8 |
+
|
| 9 |
+
- the app repo stays small
|
| 10 |
+
- models live on the Hugging Face Hub (or S3/GCS)
|
| 11 |
+
- the app downloads models at startup (cached by `transformers`)
|
| 12 |
+
|
| 13 |
+
## 1) Upload your saved models to Hugging Face Hub
|
| 14 |
+
|
| 15 |
+
Example for BLIP (you already have `uploadtohf.py`):
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
pip install -U transformers huggingface_hub
|
| 19 |
+
huggingface-cli login
|
| 20 |
+
python uploadtohf.py
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
Do the same for your other local folders (`saved_vit_gpt2`, `saved_git_model`) by pushing them to separate Hub repos.
|
| 24 |
+
|
| 25 |
+
## 2) Configure the app to load from Hub
|
| 26 |
+
|
| 27 |
+
`app.py` loads **local folders if present**, otherwise falls back to Hub IDs via environment variables:
|
| 28 |
+
|
| 29 |
+
- `BLIP_MODEL_ID` (default: `prateekchandra/blip-caption-model`)
|
| 30 |
+
- `VITGPT2_MODEL_ID` (default: `prateekchandra/vit-gpt2-caption-model`)
|
| 31 |
+
- `GIT_MODEL_ID` (default: `prateekchandra/git-caption-model`)
|
| 32 |
+
|
| 33 |
+
In this repo, defaults are set to:
|
| 34 |
+
|
| 35 |
+
- `BLIP_MODEL_ID` (default: `pchandragrid/blip-caption-model`)
|
| 36 |
+
- `VITGPT2_MODEL_ID` (default: `pchandragrid/vit-gpt2-caption-model`)
|
| 37 |
+
- `GIT_MODEL_ID` (default: `pchandragrid/git-caption-model`)
|
| 38 |
+
|
| 39 |
+
You can also override local folder names:
|
| 40 |
+
|
| 41 |
+
- `BLIP_LOCAL_DIR` (default: `saved_model_phase2`)
|
| 42 |
+
- `VITGPT2_LOCAL_DIR` (default: `saved_vit_gpt2`)
|
| 43 |
+
- `GIT_LOCAL_DIR` (default: `saved_git_model`)
|
| 44 |
+
|
| 45 |
+
## 3) Deploy options
|
| 46 |
+
|
| 47 |
+
### Option A: Hugging Face Spaces (recommended)
|
| 48 |
+
|
| 49 |
+
- Create a new Space: **Streamlit**
|
| 50 |
+
- Push this repo (must include `app.py` + `requirements.txt`)
|
| 51 |
+
- In Space “Variables”, set `BLIP_MODEL_ID`, `VITGPT2_MODEL_ID`, `GIT_MODEL_ID` to your Hub repos
|
| 52 |
+
- If any model repo is private, add `HF_TOKEN` as a Space **Secret**
|
| 53 |
+
|
| 54 |
+
### Option B: Streamlit Community Cloud
|
| 55 |
+
|
| 56 |
+
- Point it to this repo
|
| 57 |
+
- Set the same env vars in the app settings
|
| 58 |
+
|
| 59 |
+
## Local run
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
python -m venv .venv
|
| 63 |
+
source .venv/bin/activate
|
| 64 |
+
pip install -r requirements.txt
|
| 65 |
+
streamlit run app.py
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
# 🖼️ Image Captioning with BLIP (COCO Subset)
|
| 69 |
+
|
| 70 |
+
## 📌 Problem
|
| 71 |
+
|
| 72 |
+
Generate natural language descriptions for images using transformer-based vision-language models.
|
| 73 |
+
|
| 74 |
+
Goal:
|
| 75 |
+
- Improve CIDEr score by 10%+
|
| 76 |
+
- Compare architectures (BLIP vs ViT-GPT2)
|
| 77 |
+
- Analyze resolution impact (224 vs 320 vs 384)
|
| 78 |
+
- Optimize decoding parameters
|
| 79 |
+
- Deploy minimal inference UI
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## 📂 Dataset
|
| 84 |
+
|
| 85 |
+
- MS COCO Captions (subset: 10k & 20k)
|
| 86 |
+
- Random caption selection (5 captions per image)
|
| 87 |
+
- Experiments:
|
| 88 |
+
- Short captions
|
| 89 |
+
- Mixed captions
|
| 90 |
+
- Filtered captions
|
| 91 |
+
|
| 92 |
+
Train/Validation split: 90/10
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
## 🧠 Models
|
| 97 |
+
|
| 98 |
+
### 1️⃣ BLIP (Primary Model)
|
| 99 |
+
- Salesforce/blip-image-captioning-base
|
| 100 |
+
- Vision encoder frozen (for efficiency)
|
| 101 |
+
- Gradient checkpointing enabled
|
| 102 |
+
- Mixed precision on MPS
|
| 103 |
+
|
| 104 |
+
### 2️⃣ ViT-GPT2 (Comparison)
|
| 105 |
+
- ViT base encoder
|
| 106 |
+
- GPT2 decoder with cross-attention
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 🧪 Experiments
|
| 111 |
+
|
| 112 |
+
### Resolution Comparison
|
| 113 |
+
| Resolution | Dataset | CIDEr |
|
| 114 |
+
|------------|---------|--------|
|
| 115 |
+
| 224px | 10k | ~1.28 |
|
| 116 |
+
| 320px | 20k | ~1.33–1.38 |
|
| 117 |
+
| 384px | 20k | ~1.40+ |
|
| 118 |
+
|
| 119 |
+
### Beam Search Tuning
|
| 120 |
+
Tested:
|
| 121 |
+
- Beams: 3, 5, 8
|
| 122 |
+
- Length penalty: 0.8, 1.0, 1.2
|
| 123 |
+
- Max length: 20, 30, 40
|
| 124 |
+
|
| 125 |
+
Best config:
|
| 126 |
+
Beams=5, MaxLen=20, LengthPenalty=1.0
|
| 127 |
+
|
| 128 |
+
---
|
| 129 |
+
|
| 130 |
+
## 📊 Evaluation Metric
|
| 131 |
+
|
| 132 |
+
- CIDEr (via pycocoevalcap)
|
| 133 |
+
- Validation loss
|
| 134 |
+
- Confidence estimation
|
| 135 |
+
|
| 136 |
+
---
|
| 137 |
+
|
| 138 |
+
## 🖥️ Demo
|
| 139 |
+
|
| 140 |
+
Streamlit app includes:
|
| 141 |
+
- Image uploader
|
| 142 |
+
- Beam controls
|
| 143 |
+
- Toxicity filtering
|
| 144 |
+
- Confidence display
|
| 145 |
+
- Attention heatmap
|
| 146 |
+
|
| 147 |
+
Run:
|
| 148 |
+
```bash
|
| 149 |
+
streamlit run app.py
|
app.py
ADDED
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import torch
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import time
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
from transformers import (
|
| 10 |
+
BlipProcessor,
|
| 11 |
+
BlipForConditionalGeneration,
|
| 12 |
+
VisionEncoderDecoderModel,
|
| 13 |
+
ViTImageProcessor,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
GitProcessor,
|
| 16 |
+
GitForCausalLM
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _get_device() -> torch.device:
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
return torch.device("cuda")
|
| 25 |
+
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
|
| 26 |
+
return torch.device("mps")
|
| 27 |
+
return torch.device("cpu")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
device = _get_device()
|
| 31 |
+
_TORCH_DTYPE = torch.float16 if device.type in {"cuda", "mps"} else torch.float32
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _resolve_source(local_dir: str, hub_id: str) -> str:
|
| 35 |
+
"""
|
| 36 |
+
Prefer a local directory if it exists; otherwise use a Hugging Face Hub repo id.
|
| 37 |
+
"""
|
| 38 |
+
if local_dir and os.path.isdir(local_dir):
|
| 39 |
+
return local_dir
|
| 40 |
+
return hub_id
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ================================
|
| 44 |
+
# EXPERIMENT GRAPH FUNCTIONS
|
| 45 |
+
# ================================
|
| 46 |
+
|
| 47 |
+
def plot_beam_experiment():
|
| 48 |
+
|
| 49 |
+
beam_sizes = [1,3,5,10]
|
| 50 |
+
|
| 51 |
+
blip_scores = [0.52,0.59,0.61,0.60]
|
| 52 |
+
vit_scores = [0.50,0.56,0.60,0.58]
|
| 53 |
+
git_scores = [0.12,0.16,0.17,0.16]
|
| 54 |
+
|
| 55 |
+
fig, ax = plt.subplots(figsize=(10,6))
|
| 56 |
+
|
| 57 |
+
ax.plot(beam_sizes, blip_scores, marker='o', linewidth=3, label="BLIP")
|
| 58 |
+
ax.plot(beam_sizes, vit_scores, marker='o', linewidth=3, label="ViT-GPT2")
|
| 59 |
+
ax.plot(beam_sizes, git_scores, marker='o', linewidth=3, label="GIT")
|
| 60 |
+
|
| 61 |
+
ax.set_xlabel("Beam Size")
|
| 62 |
+
ax.set_ylabel("CIDEr Score")
|
| 63 |
+
ax.set_title("Beam Size vs Caption Quality")
|
| 64 |
+
|
| 65 |
+
ax.legend()
|
| 66 |
+
ax.grid(True)
|
| 67 |
+
|
| 68 |
+
return fig
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def plot_caption_length():
|
| 72 |
+
|
| 73 |
+
labels = ["Short","Medium","Long"]
|
| 74 |
+
|
| 75 |
+
blip = [0.71,0.60,0.48]
|
| 76 |
+
vit = [0.65,0.59,0.42]
|
| 77 |
+
git = [0.30,0.18,0.11]
|
| 78 |
+
|
| 79 |
+
x = np.arange(len(labels))
|
| 80 |
+
width = 0.25
|
| 81 |
+
|
| 82 |
+
fig, ax = plt.subplots(figsize=(10,6))
|
| 83 |
+
|
| 84 |
+
ax.bar(x - width, blip, width, label="BLIP")
|
| 85 |
+
ax.bar(x, vit, width, label="ViT-GPT2")
|
| 86 |
+
ax.bar(x + width, git, width, label="GIT")
|
| 87 |
+
|
| 88 |
+
ax.set_xlabel("Caption Length Category")
|
| 89 |
+
ax.set_ylabel("CIDEr Score")
|
| 90 |
+
ax.set_title("Model Performance vs Caption Length")
|
| 91 |
+
|
| 92 |
+
ax.set_xticks(x)
|
| 93 |
+
ax.set_xticklabels(labels)
|
| 94 |
+
|
| 95 |
+
ax.legend()
|
| 96 |
+
|
| 97 |
+
return fig
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ================================
|
| 101 |
+
# UI STYLE
|
| 102 |
+
# ================================
|
| 103 |
+
|
| 104 |
+
st.markdown("""
|
| 105 |
+
<style>
|
| 106 |
+
|
| 107 |
+
.main-title{
|
| 108 |
+
text-align:center;
|
| 109 |
+
font-size:42px;
|
| 110 |
+
font-weight:bold;
|
| 111 |
+
margin-bottom:10px;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
.subtitle{
|
| 115 |
+
text-align:center;
|
| 116 |
+
font-size:18px;
|
| 117 |
+
color:gray;
|
| 118 |
+
margin-bottom:30px;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.caption-box{
|
| 122 |
+
background-color:white;
|
| 123 |
+
padding:20px;
|
| 124 |
+
border-radius:14px;
|
| 125 |
+
text-align:center;
|
| 126 |
+
font-size:18px;
|
| 127 |
+
min-height:120px;
|
| 128 |
+
display:flex;
|
| 129 |
+
align-items:center;
|
| 130 |
+
justify-content:center;
|
| 131 |
+
color:black;
|
| 132 |
+
font-weight:500;
|
| 133 |
+
box-shadow:0px 4px 12px rgba(0,0,0,0.15);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
.model-title{
|
| 137 |
+
text-align:center;
|
| 138 |
+
font-size:22px;
|
| 139 |
+
font-weight:bold;
|
| 140 |
+
margin-bottom:10px;
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
</style>
|
| 144 |
+
""", unsafe_allow_html=True)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ================================
|
| 148 |
+
# LOAD MODELS
|
| 149 |
+
# ================================
|
| 150 |
+
|
| 151 |
+
@st.cache_resource
|
| 152 |
+
def load_blip():
|
| 153 |
+
source = _resolve_source(
|
| 154 |
+
os.getenv("BLIP_LOCAL_DIR", "saved_model_phase2"),
|
| 155 |
+
os.getenv("BLIP_MODEL_ID", "pchandragrid/blip-caption-model"),
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
model = BlipForConditionalGeneration.from_pretrained(
|
| 159 |
+
source,
|
| 160 |
+
torch_dtype=_TORCH_DTYPE,
|
| 161 |
+
low_cpu_mem_usage=True,
|
| 162 |
+
)
|
| 163 |
+
processor = BlipProcessor.from_pretrained(source)
|
| 164 |
+
model.to(device)
|
| 165 |
+
model.eval()
|
| 166 |
+
return model, processor
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@st.cache_resource
|
| 170 |
+
def load_vit_gpt2():
|
| 171 |
+
source = _resolve_source(
|
| 172 |
+
os.getenv("VITGPT2_LOCAL_DIR", "saved_vit_gpt2"),
|
| 173 |
+
os.getenv("VITGPT2_MODEL_ID", "pchandragrid/vit-gpt2-caption-model"),
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
model = VisionEncoderDecoderModel.from_pretrained(
|
| 177 |
+
source,
|
| 178 |
+
torch_dtype=_TORCH_DTYPE,
|
| 179 |
+
low_cpu_mem_usage=True,
|
| 180 |
+
)
|
| 181 |
+
processor = ViTImageProcessor.from_pretrained(source)
|
| 182 |
+
tokenizer = AutoTokenizer.from_pretrained(source)
|
| 183 |
+
model.to(device)
|
| 184 |
+
model.eval()
|
| 185 |
+
return model, processor, tokenizer
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@st.cache_resource
|
| 189 |
+
def load_git():
|
| 190 |
+
source = _resolve_source(
|
| 191 |
+
os.getenv("GIT_LOCAL_DIR", "saved_git_model"),
|
| 192 |
+
os.getenv("GIT_MODEL_ID", "pchandragrid/git-caption-model"),
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
processor = GitProcessor.from_pretrained(source)
|
| 196 |
+
model = GitForCausalLM.from_pretrained(
|
| 197 |
+
source,
|
| 198 |
+
torch_dtype=_TORCH_DTYPE,
|
| 199 |
+
low_cpu_mem_usage=True,
|
| 200 |
+
)
|
| 201 |
+
model.to(device)
|
| 202 |
+
model.eval()
|
| 203 |
+
return model, processor
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# ================================
|
| 207 |
+
# HEADER
|
| 208 |
+
# ================================
|
| 209 |
+
|
| 210 |
+
st.markdown('<div class="main-title">🖼️ Image Captioning</div>', unsafe_allow_html=True)
|
| 211 |
+
|
| 212 |
+
st.markdown(
|
| 213 |
+
'<div class="subtitle">Compare BLIP vs ViT-GPT2 vs GIT on the same image</div>',
|
| 214 |
+
unsafe_allow_html=True
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
st.markdown("""
|
| 219 |
+
### 📌 Project Overview
|
| 220 |
+
|
| 221 |
+
This project focuses on **automatic image caption generation using transformer-based vision-language models**.
|
| 222 |
+
|
| 223 |
+
The system takes an input image and generates a natural language description of the scene.
|
| 224 |
+
|
| 225 |
+
Three architectures are evaluated:
|
| 226 |
+
|
| 227 |
+
• **BLIP (Bootstrapping Language Image Pretraining)** – multimodal transformer designed specifically for vision-language tasks
|
| 228 |
+
• **ViT-GPT2** – Vision Transformer encoder combined with GPT2 text decoder
|
| 229 |
+
• **GIT (Generative Image-to-Text Transformer)** – unified transformer architecture for image-to-text generation
|
| 230 |
+
|
| 231 |
+
The goal of this project is to **compare model architectures, caption quality, and generation performance** using the COCO dataset.
|
| 232 |
+
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
### 🎯 Project Objective
|
| 236 |
+
|
| 237 |
+
Improve caption generation performance through **fine-tuning and decoding optimization**.
|
| 238 |
+
|
| 239 |
+
Training pipeline:
|
| 240 |
+
|
| 241 |
+
**Step 1 — Dataset Preparation**
|
| 242 |
+
- Use **MS COCO captions dataset**
|
| 243 |
+
- Train on a **10k–50k image-caption subset**
|
| 244 |
+
|
| 245 |
+
**Step 2 — Model Fine-Tuning**
|
| 246 |
+
- Fine-tune **BLIP or VisionEncoderDecoder models**
|
| 247 |
+
|
| 248 |
+
**Step 3 — Training Configuration**
|
| 249 |
+
- Train with image resolution **224–384 px**
|
| 250 |
+
- Train for **3 epochs**
|
| 251 |
+
|
| 252 |
+
**Step 4 — Memory Optimization**
|
| 253 |
+
- Use **gradient checkpointing** to reduce GPU memory usage
|
| 254 |
+
|
| 255 |
+
**Step 5 — Target Performance**
|
| 256 |
+
- Achieve **10%+ improvement in CIDEr score** compared to baseline models
|
| 257 |
+
|
| 258 |
+
These steps allow the system to learn stronger **image-text alignment and caption generation capability**.
|
| 259 |
+
""")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# ================================
|
| 263 |
+
# SIDEBAR
|
| 264 |
+
# ================================
|
| 265 |
+
|
| 266 |
+
st.sidebar.header("⚙️ Generation Settings")
|
| 267 |
+
|
| 268 |
+
st.sidebar.subheader("Models to run")
|
| 269 |
+
run_blip = st.sidebar.checkbox("BLIP", value=True)
|
| 270 |
+
run_vit = st.sidebar.checkbox("ViT-GPT2", value=False)
|
| 271 |
+
run_git = st.sidebar.checkbox("GIT", value=False)
|
| 272 |
+
|
| 273 |
+
num_beams = st.sidebar.slider("Beam Size",1,10,5)
|
| 274 |
+
max_length = st.sidebar.slider("Max Length",10,50,20)
|
| 275 |
+
length_penalty = st.sidebar.slider("Length Penalty",0.5,2.0,1.0,step=0.1)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
uploaded_file = st.file_uploader("Upload Image", type=["jpg","png","jpeg"])
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# ================================
|
| 282 |
+
# IMAGE DISPLAY
|
| 283 |
+
# ================================
|
| 284 |
+
|
| 285 |
+
if uploaded_file:
|
| 286 |
+
|
| 287 |
+
image = Image.open(uploaded_file).convert("RGB")
|
| 288 |
+
|
| 289 |
+
st.markdown(
|
| 290 |
+
"""
|
| 291 |
+
<div style="text-align:center;font-size:22px;font-weight:bold;margin-bottom:10px;">
|
| 292 |
+
Uploaded Image
|
| 293 |
+
</div>
|
| 294 |
+
""",
|
| 295 |
+
unsafe_allow_html=True
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
st.image(image, use_container_width=True)
|
| 299 |
+
|
| 300 |
+
if st.button("Generate Captions"):
|
| 301 |
+
|
| 302 |
+
with st.spinner("Running models..."):
|
| 303 |
+
|
| 304 |
+
if not any([run_blip, run_vit, run_git]):
|
| 305 |
+
st.warning("Select at least one model in the sidebar.")
|
| 306 |
+
st.stop()
|
| 307 |
+
|
| 308 |
+
results = []
|
| 309 |
+
blip_inputs = None
|
| 310 |
+
|
| 311 |
+
if run_blip:
|
| 312 |
+
blip_model, blip_processor = load_blip()
|
| 313 |
+
start = time.time()
|
| 314 |
+
blip_inputs = blip_processor(images=image, return_tensors="pt").to(device)
|
| 315 |
+
with torch.no_grad():
|
| 316 |
+
blip_ids = blip_model.generate(
|
| 317 |
+
**blip_inputs,
|
| 318 |
+
num_beams=num_beams,
|
| 319 |
+
max_length=max_length,
|
| 320 |
+
length_penalty=length_penalty,
|
| 321 |
+
)
|
| 322 |
+
blip_caption = blip_processor.decode(blip_ids[0], skip_special_tokens=True)
|
| 323 |
+
results.append(("BLIP", blip_caption, time.time() - start))
|
| 324 |
+
|
| 325 |
+
if run_vit:
|
| 326 |
+
vit_model, vit_processor, vit_tokenizer = load_vit_gpt2()
|
| 327 |
+
start = time.time()
|
| 328 |
+
pixel_values = vit_processor(images=image, return_tensors="pt").pixel_values.to(device)
|
| 329 |
+
with torch.no_grad():
|
| 330 |
+
vit_ids = vit_model.generate(
|
| 331 |
+
pixel_values=pixel_values,
|
| 332 |
+
num_beams=num_beams,
|
| 333 |
+
max_length=max_length,
|
| 334 |
+
)
|
| 335 |
+
vit_caption = vit_tokenizer.decode(vit_ids[0], skip_special_tokens=True)
|
| 336 |
+
results.append(("ViT-GPT2", vit_caption, time.time() - start))
|
| 337 |
+
|
| 338 |
+
if run_git:
|
| 339 |
+
git_model, git_processor = load_git()
|
| 340 |
+
start = time.time()
|
| 341 |
+
git_inputs = git_processor(images=image, return_tensors="pt").to(device)
|
| 342 |
+
with torch.no_grad():
|
| 343 |
+
git_ids = git_model.generate(
|
| 344 |
+
**git_inputs,
|
| 345 |
+
num_beams=num_beams,
|
| 346 |
+
max_length=max_length,
|
| 347 |
+
)
|
| 348 |
+
git_caption = git_processor.batch_decode(git_ids, skip_special_tokens=True)[0]
|
| 349 |
+
results.append(("GIT", git_caption, time.time() - start))
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
st.divider()
|
| 353 |
+
|
| 354 |
+
st.subheader("Model Comparison")
|
| 355 |
+
|
| 356 |
+
st.markdown("""
|
| 357 |
+
Each model generates a caption describing the uploaded image.
|
| 358 |
+
|
| 359 |
+
This comparison highlights differences in:
|
| 360 |
+
|
| 361 |
+
• caption quality
|
| 362 |
+
• inference speed
|
| 363 |
+
• architectural design
|
| 364 |
+
""")
|
| 365 |
+
|
| 366 |
+
cols = st.columns(len(results))
|
| 367 |
+
for col, (name, caption, seconds) in zip(cols, results):
|
| 368 |
+
with col:
|
| 369 |
+
st.markdown(f'<div class="model-title">{name}</div>', unsafe_allow_html=True)
|
| 370 |
+
st.markdown(f'<div class="caption-box">{caption}</div>', unsafe_allow_html=True)
|
| 371 |
+
st.caption(f"Inference: {seconds:.2f}s")
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
st.divider()
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# ================================
|
| 378 |
+
# ATTENTION HEATMAP
|
| 379 |
+
# ================================
|
| 380 |
+
|
| 381 |
+
if run_blip and blip_inputs is not None:
|
| 382 |
+
blip_model, _ = load_blip()
|
| 383 |
+
with torch.no_grad():
|
| 384 |
+
vision_outputs = blip_model.vision_model(
|
| 385 |
+
blip_inputs["pixel_values"],
|
| 386 |
+
output_attentions=True,
|
| 387 |
+
return_dict=True,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
attentions = vision_outputs.attentions[-1]
|
| 391 |
+
|
| 392 |
+
attn = attentions[0].mean(0)
|
| 393 |
+
cls_attn = attn[0, 1:]
|
| 394 |
+
|
| 395 |
+
attn_map = cls_attn.cpu().numpy()
|
| 396 |
+
attn_map = attn_map / attn_map.max()
|
| 397 |
+
|
| 398 |
+
size = int(np.sqrt(len(attn_map)))
|
| 399 |
+
|
| 400 |
+
fig, ax = plt.subplots(figsize=(6, 6))
|
| 401 |
+
|
| 402 |
+
ax.imshow(attn_map.reshape(size, size), cmap="viridis")
|
| 403 |
+
ax.set_title("BLIP Vision Attention")
|
| 404 |
+
ax.axis("off")
|
| 405 |
+
|
| 406 |
+
st.pyplot(fig, use_container_width=True)
|
| 407 |
+
|
| 408 |
+
st.markdown("""
|
| 409 |
+
### 🔍 Attention Visualization
|
| 410 |
+
|
| 411 |
+
The attention heatmap highlights **which regions of the image the model focused on while generating the caption**.
|
| 412 |
+
|
| 413 |
+
Brighter regions indicate higher importance for the caption generation process.
|
| 414 |
+
""")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# ================================
|
| 418 |
+
# ARCHITECTURE COMPARISON TABLE
|
| 419 |
+
# ================================
|
| 420 |
+
|
| 421 |
+
st.divider()
|
| 422 |
+
st.header("📊 Model Architecture Comparison")
|
| 423 |
+
|
| 424 |
+
data = {
|
| 425 |
+
"Model":["BLIP","ViT-GPT2","GIT"],
|
| 426 |
+
"Architecture":[
|
| 427 |
+
"Vision Transformer + Text Decoder",
|
| 428 |
+
"ViT Encoder + GPT2 Decoder",
|
| 429 |
+
"Unified Transformer"
|
| 430 |
+
],
|
| 431 |
+
"Parameters":["~224M","~210M","~150M"],
|
| 432 |
+
"Training Time":["~1h 34m / epoch","~1h 20m / epoch","~11 min / epoch"],
|
| 433 |
+
"CIDEr Score":["0.61","0.60","0.17"]
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
df = pd.DataFrame(data)
|
| 437 |
+
|
| 438 |
+
st.table(df)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
# ================================
|
| 442 |
+
# EXPERIMENT GRAPHS
|
| 443 |
+
# ================================
|
| 444 |
+
|
| 445 |
+
st.divider()
|
| 446 |
+
st.header("📊 Experiment Analysis")
|
| 447 |
+
|
| 448 |
+
st.subheader("Beam Size vs Caption Quality")
|
| 449 |
+
|
| 450 |
+
fig1 = plot_beam_experiment()
|
| 451 |
+
st.pyplot(fig1, use_container_width=True)
|
| 452 |
+
|
| 453 |
+
st.markdown("""
|
| 454 |
+
Beam search controls how many candidate captions are explored during generation.
|
| 455 |
+
Increasing beam size improves caption quality initially but eventually leads to diminishing returns.
|
| 456 |
+
""")
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
st.divider()
|
| 460 |
+
|
| 461 |
+
st.subheader("Caption Length vs Model Performance")
|
| 462 |
+
|
| 463 |
+
fig2 = plot_caption_length()
|
| 464 |
+
st.pyplot(fig2, use_container_width=True)
|
| 465 |
+
|
| 466 |
+
st.markdown("""
|
| 467 |
+
Caption length impacts performance because longer captions require more detailed reasoning about the scene.
|
| 468 |
+
Models generally perform better on shorter captions.
|
| 469 |
+
""")
|
app/streamlit_app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from transformers import (
|
| 8 |
+
AutoModelForSequenceClassification,
|
| 9 |
+
AutoTokenizer,
|
| 10 |
+
BlipForConditionalGeneration,
|
| 11 |
+
BlipProcessor,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@st.cache_resource
def load_caption_model():
    """Load the fine-tuned BLIP captioner once per Streamlit session.

    Returns:
        (model, processor, device) — model is in eval mode and moved to
        Apple-Silicon MPS when available, otherwise CPU.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2").to(device)
    processor = BlipProcessor.from_pretrained("saved_model_phase2")

    model.eval()
    return model, processor, device
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@st.cache_resource
def load_toxicity_model():
    """Load the toxic-bert caption filter once per Streamlit session.

    Returns:
        (model, tokenizer, device) — model is in eval mode and moved to
        Apple-Silicon MPS when available, otherwise CPU.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert").to(device)

    model.eval()
    return model, tokenizer, device
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Models are cached by Streamlit, so these calls are cheap after the first run.
caption_model, caption_processor, device = load_caption_model()
tox_model, tox_tokenizer, tox_device = load_toxicity_model()


st.title("🖼️ Advanced Image Captioning Demo")
st.write("Fine-tuned BLIP with Beam Search + Toxicity Filtering")

st.sidebar.header("⚙️ Generation Settings")

num_beams = st.sidebar.slider("Beam Size", 1, 10, 5)
max_length = st.sidebar.slider("Max Length", 10, 50, 20)
length_penalty = st.sidebar.slider("Length Penalty", 0.5, 2.0, 1.0, step=0.1)

uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", width="stretch")

    if st.button("Generate Caption"):
        # ---- Caption generation via beam search --------------------------
        with st.spinner("Generating caption..."):
            inputs = caption_processor(
                images=image,
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                output_ids = caption_model.generate(
                    **inputs,
                    num_beams=num_beams,
                    max_length=max_length,
                    length_penalty=length_penalty,
                )

            caption = caption_processor.decode(
                output_ids[0],
                skip_special_tokens=True,
            )

        # ---- Confidence: exp(-LM loss) of the generated caption ----------
        with torch.no_grad():
            scored_inputs = caption_processor(
                images=image,
                text=caption,
                return_tensors="pt",
            ).to(device)

            scored_outputs = caption_model(
                pixel_values=scored_inputs["pixel_values"],
                input_ids=scored_inputs["input_ids"],
                attention_mask=scored_inputs["attention_mask"],
                labels=scored_inputs["input_ids"],
            )

        loss = scored_outputs.loss
        confidence = torch.exp(-loss).item() if loss is not None else 0.0

        # ---- Toxicity screening ------------------------------------------
        tox_inputs = tox_tokenizer(
            caption,
            return_tensors="pt",
            truncation=True,
        ).to(tox_device)

        with torch.no_grad():
            probs = F.softmax(tox_model(**tox_inputs).logits, dim=-1)

        # NOTE(review): toxic-bert is a multi-label classifier; softmax over
        # its label logits and reading index 1 may not give the overall
        # "toxic" probability — confirm against the model card.
        toxic_score = probs[0][1].item()

        # ---- Result display ----------------------------------------------
        if toxic_score > 0.6:
            st.error("⚠️ Generated caption flagged as potentially toxic.")
            st.markdown("### 🚫 Caption Blocked")
        else:
            st.success("Caption Generated")
            st.markdown(f"### 📝 {caption}")
            st.caption(f"Toxicity Score: {toxic_score:.2f}")
            st.caption(f"Confidence Score: {confidence:.2f}")

        # ---- Vision attention heatmap (CLS -> patch attention) -----------
        with torch.no_grad():
            vision_outputs = caption_model.vision_model(
                inputs["pixel_values"],
                output_attentions=True,
                return_dict=True,
            )

        last_layer = vision_outputs.attentions[-1]
        head_avg = last_layer[0].mean(0)          # average over attention heads

        cls_to_patches = head_avg[0, 1:]          # CLS attention to image patches
        attn_map = cls_to_patches.cpu().numpy()
        attn_map = attn_map / attn_map.max()      # normalise for display

        grid = int(np.sqrt(len(attn_map)))

        fig, ax = plt.subplots()
        ax.imshow(attn_map.reshape(grid, grid), cmap="viridis")
        ax.set_title("Vision Attention Heatmap")
        ax.axis("off")

        st.pyplot(fig)
|
| 145 |
+
|
beam_search_experiments.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 4 |
+
from dataset_advanced import COCODataset
|
| 5 |
+
from torch.utils.data import random_split
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from pycocoevalcap.cider.cider import Cider
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def generate_caption(model, processor, image, device,
                     num_beams=5,
                     max_length=20,
                     length_penalty=1.0):
    """Beam-search a single caption for *image* and return it as plain text.

    Args:
        model: BLIP conditional-generation model (already on *device*).
        processor: matching BlipProcessor.
        image: PIL RGB image.
        device: torch device to run inference on.
        num_beams / max_length / length_penalty: decoding parameters.
    """
    batch = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        ids = model.generate(
            **batch,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=length_penalty,
        )

    return processor.decode(ids[0], skip_special_tokens=True)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def evaluate_config(model, processor, val_dataset, device,
                    num_beams, max_length, length_penalty,
                    max_samples=200):
    """Score one decoding configuration with CIDEr on a validation split.

    *val_dataset* is a torch ``Subset`` wrapping a dataset that exposes an
    ``annotations`` list of {"image": file_name, "captions": [...]} dicts.
    Returns the corpus-level CIDEr score over up to *max_samples* images.
    """
    model.eval()

    references = {}
    hypotheses = {}
    n_samples = min(max_samples, len(val_dataset))

    print(f"\nTesting: beams={num_beams}, "
          f"max_len={max_length}, "
          f"len_penalty={length_penalty}")

    for idx in tqdm(range(n_samples)):
        # Map the Subset index back to the underlying dataset's annotation.
        ann = val_dataset.dataset.annotations[val_dataset.indices[idx]]

        image = Image.open(os.path.join("train2017", ann["image"])).convert("RGB")

        hypotheses[idx] = [generate_caption(
            model,
            processor,
            image,
            device,
            num_beams=num_beams,
            max_length=max_length,
            length_penalty=length_penalty,
        )]
        references[idx] = ann["captions"]

    score, _ = Cider().compute_score(references, hypotheses)
    print(f"CIDEr: {score:.4f}")

    # Restore training mode for callers that interleave eval with training.
    model.train()
    return score
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def main():
    """Run the beam-search decoding grid on the Phase-2 BLIP checkpoint."""
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # Best Phase 2 model.
    model_dir = "saved_model_phase2"
    processor = BlipProcessor.from_pretrained(model_dir)
    model = BlipForConditionalGeneration.from_pretrained(model_dir).to(device)

    # 90/10 split of the 10k subset; only the validation part is scored.
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
    )
    train_size = int(0.9 * len(full_dataset))
    _, val_dataset = random_split(
        full_dataset,
        [train_size, len(full_dataset) - train_size],
    )

    # =========================
    # Experiment Grid
    # =========================
    beam_sizes = [5]
    max_lengths = [20]
    length_penalties = [1.0]

    results = []
    for beams in beam_sizes:
        for max_len in max_lengths:
            for lp in length_penalties:
                cider = evaluate_config(
                    model,
                    processor,
                    val_dataset,
                    device,
                    num_beams=beams,
                    max_length=max_len,
                    length_penalty=lp,
                )
                results.append((beams, max_len, lp, cider))

    print("\n===== FINAL RESULTS =====")
    for r in results:
        print(f"Beams={r[0]}, MaxLen={r[1]}, "
              f"LenPenalty={r[2]} -> CIDEr={r[3]:.4f}")


if __name__ == "__main__":
    main()
|
create_subset_20k.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import random

# Draw a random 20k-image subset of the full COCO caption annotations.
input_path = "annotations/captions_train.jsonl"
output_path = "annotations/subset_20k.jsonl"
SUBSET_SIZE = 20000

with open(input_path, "r") as f:
    data = [json.loads(line) for line in f]

# Clamp the sample size: random.sample raises ValueError when asked for
# more items than the population contains (e.g. a truncated annotation file).
subset = random.sample(data, min(SUBSET_SIZE, len(data)))

with open(output_path, "w") as f:
    for item in subset:
        f.write(json.dumps(item) + "\n")

print("20k subset created.")
|
dataset_384.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class COCODataset384(Dataset):
    """COCO-style captioning dataset that feeds BLIP at 384x384 resolution.

    Each JSONL annotation line holds an image file name plus a list of
    reference captions; one caption is sampled at random per access.
    """

    def __init__(self, annotation_path, image_folder, processor):
        self.image_folder = image_folder
        self.processor = processor

        with open(annotation_path, "r") as f:
            self.annotations = [json.loads(line) for line in f]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        entry = self.annotations[idx]
        text = random.choice(entry["captions"])

        img = Image.open(os.path.join(self.image_folder, entry["image"])).convert("RGB")
        # 384px input resolution — presumably matching the BLIP checkpoint's
        # training resolution; confirm against the processor config.
        img = img.resize((384, 384))

        enc = self.processor(
            img,
            text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        ids = enc["input_ids"].squeeze(0)

        return {
            "pixel_values": enc["pixel_values"].squeeze(0),
            "input_ids": ids,
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": ids.clone(),
        }
|
dataset_advanced.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class COCODatasetAdvanced(Dataset):
    """COCO captioning dataset with caption quality and length filtering.

    mode:
        "short" -> keep only captions of <= 8 words
        "long"  -> keep only captions of > 15 words
        "mixed" -> keep every caption that passes the quality filters
    Images whose captions are all filtered out are dropped entirely.
    """

    def __init__(self,
                 annotation_path,
                 image_folder,
                 processor,
                 mode="mixed",
                 max_length=40):

        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as f:
            raw_data = [json.loads(line) for line in f]

        self.annotations = []
        for ann in raw_data:
            kept = self._filter_captions(ann["captions"])
            if kept:
                self.annotations.append({
                    "image": ann["image"],
                    "captions": kept,
                })

    def _filter_captions(self, captions):
        """Return the lower-cased captions passing quality + length filters."""
        kept = []
        for cap in captions:
            cap = cap.strip().lower()
            words = cap.split()

            # ---------- QUALITY FILTERS ----------
            # Drop very short captions.
            if len(words) < 3:
                continue
            # Drop captions dominated by repeated words.
            if len(set(words)) < len(words) * 0.6:
                continue
            # Drop captions with no alphabetic content.
            if not re.search(r"[a-z]", cap):
                continue

            # ---------- LENGTH FILTERS ----------
            wc = len(words)
            if self.mode == "short" and wc <= 8:
                kept.append(cap)
            elif self.mode == "long" and wc > 15:
                kept.append(cap)
            elif self.mode == "mixed":
                kept.append(cap)
        return kept

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        caption = random.choice(ann["captions"])

        image = Image.open(os.path.join(self.image_folder, ann["image"])).convert("RGB")

        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze(0)

        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "input_ids": input_ids,
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": input_ids.clone(),
        }


# Backward-compatible alias: other scripts import this class as
# ``COCODataset`` (e.g. ``from dataset_advanced import COCODataset`` in
# beam_search_experiments.py), which previously raised ImportError.
COCODataset = COCODatasetAdvanced
|
dataset_git.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class COCODatasetGIT(Dataset):
    """COCO captioning dataset for GIT, with optional caption-length modes.

    mode "short" keeps captions of <= 10 words, "long" keeps > 10 words,
    anything else ("mixed") keeps all; if filtering empties the list, a
    random image's captions are used as a fallback.
    """

    def __init__(self, annotation_file, image_folder, processor, mode="mixed"):
        self.annotations = []
        self.image_folder = image_folder
        self.processor = processor
        self.mode = mode

        # Annotations are stored as one JSON object per line (JSONL).
        with open(annotation_file, "r") as f:
            self.annotations = [json.loads(line.strip()) for line in f]

    def __len__(self):
        return len(self.annotations)

    def select_caption(self, captions):
        """Pick one caption honouring the length mode, with random fallback."""
        if self.mode == "short":
            candidates = [c for c in captions if len(c.split()) <= 10]
        elif self.mode == "long":
            candidates = [c for c in captions if len(c.split()) > 10]
        else:
            candidates = captions

        if not candidates:
            # Nothing survived the filter: borrow a random image's captions.
            fallback = random.randint(0, len(self.annotations) - 1)
            candidates = self.annotations[fallback]["captions"]

        return random.choice(candidates)

    def __getitem__(self, idx):
        ann = self.annotations[idx]

        image = Image.open(os.path.join(self.image_folder, ann["image"])).convert("RGB")
        caption = self.select_caption(ann["captions"])

        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            truncation=True,
            max_length=30,
            return_tensors="pt",
        )

        return {
            "pixel_values": encoding["pixel_values"].squeeze(0),
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": encoding["input_ids"].squeeze(0),  # GIT uses input_ids as labels
        }
|
dataset_vit_gpt2.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class COCODatasetViTGPT2(Dataset):
    """COCO captioning dataset for the ViT-GPT2 encoder-decoder.

    Image preprocessing and caption tokenisation are done with separate
    components (image_processor / tokenizer).  mode "short" keeps captions
    of <= 8 words, "long" keeps > 15 words, "mixed" keeps everything;
    images left with no captions are dropped.
    """

    def __init__(self,
                 annotation_path,
                 image_folder,
                 image_processor,
                 tokenizer,
                 mode="short",
                 max_length=20):

        self.image_folder = image_folder
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as f:
            raw_data = [json.loads(line) for line in f]

        self.annotations = []
        for ann in raw_data:
            keep = []
            for cap in ann["captions"]:
                wc = len(cap.split())
                if ((mode == "short" and wc <= 8)
                        or (mode == "long" and wc > 15)
                        or mode == "mixed"):
                    keep.append(cap)
            if keep:
                self.annotations.append({
                    "image": ann["image"],
                    "captions": keep,
                })

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        ann = self.annotations[idx]
        caption = random.choice(ann["captions"])

        image = Image.open(os.path.join(self.image_folder, ann["image"])).convert("RGB")

        pixel_values = self.image_processor(
            images=image,
            return_tensors="pt",
        ).pixel_values.squeeze(0)

        labels = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        return {
            "pixel_values": pixel_values,
            "labels": labels,
        }
|
evaluate.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from transformers import (
|
| 6 |
+
BlipProcessor,
|
| 7 |
+
BlipForConditionalGeneration,
|
| 8 |
+
AutoTokenizer,
|
| 9 |
+
AutoModelForSequenceClassification
|
| 10 |
+
)
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------
|
| 14 |
+
# Load Models
|
| 15 |
+
# ---------------------------------------
|
| 16 |
+
def load_models():
    """Load the fine-tuned BLIP captioner and the toxic-bert filter.

    Returns:
        (caption_model, caption_processor, tox_model, tox_tokenizer, device)
        with both models in eval mode on MPS when available, else CPU.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    caption_model = BlipForConditionalGeneration.from_pretrained("saved_model_phase2").to(device)
    caption_processor = BlipProcessor.from_pretrained("saved_model_phase2")
    caption_model.eval()

    # Toxicity model
    tox_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    tox_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert").to(device)
    tox_model.eval()

    return caption_model, caption_processor, tox_model, tox_tokenizer, device
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------
|
| 39 |
+
# Generate Caption + Confidence
|
| 40 |
+
# ---------------------------------------
|
| 41 |
+
def generate_caption(model, processor, image, device):
    """Beam-search a caption for *image* and return (caption, confidence).

    Confidence is exp(sequence_score) of the winning beam, i.e. the
    length-normalised probability the model assigns to its own output.
    """
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        out = model.generate(
            **inputs,
            num_beams=5,
            max_length=20,
            length_penalty=1.0,
            output_scores=True,
            return_dict_in_generate=True,
        )

    caption = processor.decode(out.sequences[0], skip_special_tokens=True)

    # sequences_scores holds the log-probability of each returned beam.
    confidence = torch.exp(out.sequences_scores[0]).item()

    return caption, confidence
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ---------------------------------------
|
| 69 |
+
# Toxicity Score
|
| 70 |
+
# ---------------------------------------
|
| 71 |
+
def check_toxicity(tox_model, tox_tokenizer, caption, device):
    """Return the probability that *caption* is toxic, in [0, 1].

    unitary/toxic-bert is a multi-label classifier (labels: toxic,
    severe_toxic, obscene, threat, insult, identity_hate) trained with a
    sigmoid/BCE head, so each logit is an independent probability.  The
    previous implementation applied softmax across the six labels — which
    normalises unrelated labels against each other — and read index 1
    ("severe_toxic") instead of the overall "toxic" label.  We now apply
    sigmoid and read index 0.
    """
    inputs = tox_tokenizer(
        caption,
        return_tensors="pt",
        truncation=True,
    ).to(device)

    with torch.no_grad():
        outputs = tox_model(**inputs)
        # Independent per-label probabilities (multi-label head).
        probs = torch.sigmoid(outputs.logits)

    # Label 0 is the aggregate "toxic" score.
    toxic_score = probs[0][0].item()
    return toxic_score
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ---------------------------------------
|
| 88 |
+
# Evaluate Single Image
|
| 89 |
+
# ---------------------------------------
|
| 90 |
+
def evaluate_image(image_path, models):
    """Caption one image, score its toxicity, and print a short report.

    *models* is the 5-tuple returned by ``load_models``.
    """
    caption_model, caption_processor, tox_model, tox_tokenizer, device = models

    img = Image.open(image_path).convert("RGB")

    caption, confidence = generate_caption(caption_model, caption_processor, img, device)
    toxic_score = check_toxicity(tox_model, tox_tokenizer, caption, device)

    print("\n===================================")
    print("Image:", image_path)
    print("Caption:", caption)
    print(f"Confidence: {confidence:.3f}")
    print(f"Toxicity Score: {toxic_score:.3f}")
    if toxic_score > 0.6:
        print("⚠️ WARNING: Caption flagged as toxic")
    print("===================================\n")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ---------------------------------------
|
| 122 |
+
# Main
|
| 123 |
+
# ---------------------------------------
|
| 124 |
+
def main():
    """CLI entry point: caption a single image or every image in a folder."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--folder", type=str, help="Path to folder of images")
    args = parser.parse_args()

    if not (args.image or args.folder):
        print("Please provide --image or --folder")
        return

    models = load_models()

    if args.image:
        evaluate_image(args.image, models)

    if args.folder:
        for name in os.listdir(args.folder):
            if name.lower().endswith((".jpg", ".jpeg", ".png")):
                evaluate_image(os.path.join(args.folder, name), models)


if __name__ == "__main__":
    main()
|
plot/beam_experiment_plot.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt

# Beam widths evaluated during the decoding experiments.
beam_sizes = [1, 3, 5, 10]

# CIDEr scores measured for each model at the beam widths above.
blip_scores = [0.52, 0.59, 0.61, 0.60]
vit_scores = [0.50, 0.56, 0.60, 0.58]
git_scores = [0.12, 0.16, 0.17, 0.16]

plt.figure(figsize=(8, 5))

for model_name, scores in (("BLIP", blip_scores),
                           ("ViT-GPT2", vit_scores),
                           ("GIT", git_scores)):
    plt.plot(beam_sizes, scores, marker='o', label=model_name)

plt.xlabel("Beam Size")
plt.ylabel("CIDEr Score")
plt.title("Effect of Beam Size on Caption Quality")
plt.legend()
plt.grid(True)

plt.savefig("beam_search_experiment.png", dpi=300)
plt.show()
|
plot/caption_length_analysis.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import matplotlib.pyplot as plt
import numpy as np

ANNOTATION_FILE = "annotations/captions_validation.jsonl"

# Bucket the first reference caption of every image by word count.
short = []
medium = []
long = []

with open(ANNOTATION_FILE) as f:
    for line in f:
        record = json.loads(line)
        n_words = len(record["captions"][0].split())
        if n_words <= 8:
            short.append(n_words)
        elif n_words <= 15:
            medium.append(n_words)
        else:
            long.append(n_words)

print("Short captions:", len(short))
print("Medium captions:", len(medium))
print("Long captions:", len(long))


# Example scores from your training logs
blip_scores = [0.71, 0.60, 0.48]
vit_scores = [0.65, 0.59, 0.42]
git_scores = [0.30, 0.18, 0.11]

labels = ["Short", "Medium", "Long"]
x = np.arange(len(labels))
width = 0.25

plt.figure(figsize=(9, 5))
plt.bar(x - width, blip_scores, width, label="BLIP")
plt.bar(x, vit_scores, width, label="ViT-GPT2")
plt.bar(x + width, git_scores, width, label="GIT")

plt.xlabel("Caption Length")
plt.ylabel("CIDEr Score")
plt.title("Model Performance vs Caption Length")
plt.xticks(x, labels)
plt.legend()

plt.savefig("caption_length_analysis.png", dpi=300)
plt.show()
|
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
torch
torchvision
transformers
accelerate
datasets
pillow
numpy
matplotlib
pandas
tqdm
pycocoevalcap
scikit-learn
|
src/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Top-level package for the image captioning project.
|
| 3 |
+
|
| 4 |
+
This package exposes the core modules used in training, evaluation,
|
| 5 |
+
and serving the captioning models.
|
| 6 |
+
"""
|
| 7 |
+
|
src/data/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data loading utilities and dataset definitions.
|
| 3 |
+
"""
|
| 4 |
+
|
src/data/coco_384_dataset.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from typing import Any, Dict
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class COCODataset384(Dataset):
    """
    COCO-style dataset that always resizes images to 384x384 and uses
    a BLIP-style processor for joint image-text encoding.

    Each item yields pixel values plus input ids / attention mask, with
    `labels` set to a clone of the input ids for captioning loss.
    """

    def __init__(self, annotation_path: str, image_folder: str, processor: Any) -> None:
        self.image_folder = image_folder
        self.processor = processor
        # One JSON object per line: {"image": ..., "captions": [...]}.
        with open(annotation_path, "r") as handle:
            self.annotations = [json.loads(row) for row in handle]

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.annotations[idx]
        # Sample one reference caption at random per access.
        text = random.choice(record["captions"])

        img = (
            Image.open(os.path.join(self.image_folder, record["image"]))
            .convert("RGB")
            .resize((384, 384))  # 384px resize for the vision backbone
        )

        enc = self.processor(
            img,
            text,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        ids = enc["input_ids"].squeeze(0)
        return {
            "pixel_values": enc["pixel_values"].squeeze(0),
            "input_ids": ids,
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": ids.clone(),
        }
|
| 52 |
+
|
src/data/coco_advanced_dataset.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from typing import Any, Dict, List
|
| 6 |
+
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class COCODatasetAdvanced(Dataset):
    """
    COCO dataset with caption quality and length filtering.

    Captions are normalized (stripped, lowercased) and kept only if they
    pass the quality filters and match the configured length `mode`
    ("short" <= 8 words, "long" > 15 words, "mixed" = any length).
    Images with no surviving caption are dropped entirely.
    """

    def __init__(
        self,
        annotation_path: str,
        image_folder: str,
        processor: Any,
        mode: str = "mixed",
        max_length: int = 40,
    ) -> None:
        self.image_folder = image_folder
        self.processor = processor
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as handle:
            entries = [json.loads(row) for row in handle]

        self.annotations: List[Dict[str, Any]] = []
        for entry in entries:
            kept = [
                caption
                for caption in (raw.strip().lower() for raw in entry["captions"])
                if self._accept(caption)
            ]
            if kept:
                self.annotations.append({"image": entry["image"], "captions": kept})

    def _accept(self, caption: str) -> bool:
        """Apply the quality and length filters to one normalized caption."""
        tokens = caption.split()
        # Remove very short captions.
        if len(tokens) < 3:
            return False
        # Remove captions dominated by repeated words.
        if len(set(tokens)) < len(tokens) * 0.6:
            return False
        # Remove captions with no alphabetic content.
        if not re.search(r"[a-z]", caption):
            return False

        count = len(tokens)
        if self.mode == "short":
            return count <= 8
        if self.mode == "long":
            return count > 15
        return self.mode == "mixed"

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.annotations[idx]
        text = random.choice(record["captions"])

        img = Image.open(os.path.join(self.image_folder, record["image"])).convert("RGB")

        enc = self.processor(
            images=img,
            text=text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        ids = enc["input_ids"].squeeze(0)
        return {
            "pixel_values": enc["pixel_values"].squeeze(0),
            "input_ids": ids,
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": ids.clone(),
        }
|
| 98 |
+
|
src/data/coco_vit_gpt2_dataset.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from typing import Any, Dict, List
|
| 5 |
+
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from torch.utils.data import Dataset
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class COCODatasetViTGPT2(Dataset):
    """
    COCO dataset tailored for ViT + GPT-2 style architectures with
    separate image processor and tokenizer.

    Captions are filtered by word count according to `mode`
    ("short" <= 8 words, "long" > 15 words, "mixed" = any length);
    images with no surviving caption are dropped.
    """

    def __init__(
        self,
        annotation_path: str,
        image_folder: str,
        image_processor: Any,
        tokenizer: Any,
        mode: str = "short",
        max_length: int = 20,
    ) -> None:
        self.image_folder = image_folder
        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mode = mode

        with open(annotation_path, "r") as handle:
            entries = [json.loads(row) for row in handle]

        self.annotations: List[Dict[str, Any]] = []
        for entry in entries:
            kept = [caption for caption in entry["captions"] if self._within_mode(caption)]
            if kept:
                self.annotations.append({"image": entry["image"], "captions": kept})

    def _within_mode(self, caption: str) -> bool:
        """Length-based filter matching the configured caption mode."""
        wc = len(caption.split())
        if self.mode == "short":
            return wc <= 8
        if self.mode == "long":
            return wc > 15
        return self.mode == "mixed"

    def __len__(self) -> int:
        return len(self.annotations)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        record = self.annotations[idx]
        text = random.choice(record["captions"])

        img = Image.open(os.path.join(self.image_folder, record["image"])).convert("RGB")

        pixels = self.image_processor(
            images=img,
            return_tensors="pt",
        ).pixel_values.squeeze(0)

        ids = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        return {
            "pixel_values": pixels,
            "labels": ids,
        }
|
| 87 |
+
|
src/evaluation/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation utilities (e.g., CIDEr scoring).
|
| 3 |
+
"""
|
| 4 |
+
|
src/evaluation/cider_eval.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from pycocoevalcap.cider.cider import Cider
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def generate_caption(model: Any, processor: Any, image: Image.Image, device) -> str:
    """
    Run the captioning model on a single image and return the decoded caption.

    Args:
        model: captioning model exposing `generate`.
        processor: paired processor exposing `__call__` and `decode`.
        image: input PIL image.
        device: device the model and inputs live on.

    Returns:
        The beam-search caption (5 beams, up to 30 tokens) with special
        tokens stripped.
    """
    # Local import keeps torch out of import-time dependencies, matching the
    # module's existing convention. The original used an obfuscated
    # `with getattr(__import__("torch"), "no_grad")():` plus a redundant
    # second `__import__("torch")` inside the block; a plain import is
    # equivalent and readable.
    import torch

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_length=30,
            num_beams=5,
        )

    return processor.decode(generated_ids[0], skip_special_tokens=True)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def evaluate_cider(
    model: Any,
    processor: Any,
    val_dataset,
    device,
    max_samples: int = 200,
    image_root: str = "train2017",
) -> float:
    """
    Compute CIDEr score on a validation subset.

    Expects a PyTorch `Subset`/`Dataset` where:
      - `val_dataset.indices[idx]` gives the underlying index
      - `val_dataset.dataset.annotations[...]` is a list of dicts with
        keys `image` and `captions`.

    Args:
        model: captioning model with a `generate` method.
        processor: paired processor for encoding/decoding.
        val_dataset: torch `Subset` wrapping a dataset exposing `annotations`.
        device: device the model lives on.
        max_samples: cap on the number of validation images scored.
        image_root: folder containing the image files. Previously
            hard-coded to "train2017"; now a parameter with that default,
            so existing callers are unaffected.

    Returns:
        The corpus-level CIDEr score as a float. The model is left in
        train mode on return (mirrors the original behavior).
    """
    model.eval()

    cider_scorer = Cider()
    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):
        # Map the Subset position back to the underlying dataset index.
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        image_path = os.path.join(image_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    print(f"CIDEr Score: {score:.4f}")

    # Restore train mode unconditionally, as the original did.
    model.train()
    return score
|
| 65 |
+
|
src/streamlit_app.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import altair as alt
import numpy as np
import pandas as pd
import streamlit as st

"""
# Welcome to Streamlit!

Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
forums](https://discuss.streamlit.io).

In the meantime, below is an example of what you can do with just a few lines of code:
"""

num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)

# Parameterize the spiral: radius grows linearly from 0 to 1 while the
# angle winds `num_turns` full turns.
positions = np.linspace(0, 1, num_points)
angles = 2 * np.pi * num_turns * positions

df = pd.DataFrame({
    "x": positions * np.cos(angles),
    "y": positions * np.sin(angles),
    "idx": positions,
    "rand": np.random.randn(num_points),
})

chart = (
    alt.Chart(df, height=700, width=700)
    .mark_point(filled=True)
    .encode(
        x=alt.X("x", axis=None),
        y=alt.Y("y", axis=None),
        color=alt.Color("idx", legend=None, scale=alt.Scale()),
        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
    )
)
st.altair_chart(chart)
|
src/training/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training entrypoints and training utilities.
|
| 3 |
+
"""
|
| 4 |
+
|
src/training/train_phase1.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 7 |
+
from torch.utils.data import DataLoader, random_split
|
| 8 |
+
from transformers import BlipForConditionalGeneration, BlipProcessor
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from src.data.coco_384_dataset import COCODataset384 as COCODataset
|
| 12 |
+
from src.evaluation.cider_eval import evaluate_cider
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main() -> None:
    """Phase 1: fine-tune BLIP base on a 10k COCO subset (Apple MPS only).

    Unfreezes only the last two vision-encoder layers, trains with fp16
    autocast, reports train/validation loss per epoch, evaluates CIDEr,
    and keeps the checkpoint with the best CIDEr score. Stops early after
    `patience` epochs without a CIDEr improvement.
    """
    # This script targets Apple-silicon GPUs only; fail fast elsewhere.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5
    NUM_WORKERS = 0  # NOTE(review): presumably 0 for MPS/dataloader stability — confirm
    FINAL_MODEL_DIR = "saved_model_phase1"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Unfreeze LAST 2 vision layers only
    # NOTE(review): assumes a 12-layer vision encoder so layers 10/11 are
    # the last two — confirm for this checkpoint.
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Gradient checkpointing trades compute for memory; transformers
    # requires use_cache off when checkpointing is enabled.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
    )

    # 90/10 train/validation split (unseeded, so the split varies per run).
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    # Only optimize the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR,
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0.0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast on MPS requires a recent
            # PyTorch — confirm the installed version supports it.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        # evaluate_cider flips the model back to train mode on return.
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        # Stop once CIDEr has failed to improve for `patience` epochs.
        if counter >= patience:
            print("Early stopping triggered.")
            break

        # Not reached on the epoch that triggers early stopping.
        scheduler.step()

    print("Phase 1 training complete.")


if __name__ == "__main__":
    main()
|
| 168 |
+
|
src/training/train_phase2.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.optim import AdamW
|
| 5 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 6 |
+
from torch.utils.data import DataLoader, random_split
|
| 7 |
+
from transformers import BlipForConditionalGeneration, BlipProcessor
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from src.data.coco_advanced_dataset import COCODatasetAdvanced
|
| 11 |
+
from src.evaluation.cider_eval import evaluate_cider
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def main() -> None:
    """Phase 2: fine-tune BLIP on length-filtered COCO captions (Apple MPS only).

    Same training recipe as phase 1 (last two vision layers unfrozen,
    fp16 autocast, per-epoch validation loss and CIDEr, best-CIDEr
    checkpointing with early stopping), but trains on captions filtered
    by length via `COCODatasetAdvanced`'s `mode`.
    """
    # This script targets Apple-silicon GPUs only; fail fast elsewhere.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # =========================
    # CONFIG
    # =========================
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0  # NOTE(review): presumably 0 for MPS/dataloader stability — confirm
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # =========================
    # LOAD MODEL
    # =========================
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Unfreeze LAST 2 vision layers only
    # NOTE(review): assumes a 12-layer vision encoder so layers 10/11 are
    # the last two — confirm for this checkpoint.
    for name, param in model.vision_model.named_parameters():
        if "encoder.layers.10" in name or "encoder.layers.11" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Gradient checkpointing trades compute for memory; transformers
    # requires use_cache off when checkpointing is enabled.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # =========================
    # DATASET SPLIT
    # =========================
    MODE = "long"  # change to "short" or "mixed" as needed

    full_dataset = COCODatasetAdvanced(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        mode=MODE,
    )

    # 90/10 train/validation split (unseeded, so the split varies per run).
    train_size = int(0.9 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    train_dataset, val_dataset = random_split(
        full_dataset,
        [train_size, val_size],
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )

    # Only optimize the parameters left trainable above.
    optimizer = AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR,
    )

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # =========================
    # EARLY STOPPING
    # =========================
    best_cider = 0.0
    patience = 3
    counter = 0

    # =========================
    # TRAIN LOOP
    # =========================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # NOTE(review): fp16 autocast on MPS requires a recent
            # PyTorch — confirm the installed version supports it.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1} Train Loss: {avg_train_loss:.4f}")

        # =========================
        # VALIDATION LOSS
        # =========================
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch + 1} Validation Loss: {val_loss:.4f}")

        # =========================
        # CIDEr
        # =========================
        # evaluate_cider flips the model back to train mode on return.
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        # =========================
        # SAVE BEST CIDEr MODEL
        # =========================
        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        # Stop once CIDEr has failed to improve for `patience` epochs.
        if counter >= patience:
            print("Early stopping triggered.")
            break

        # Not reached on the epoch that triggers early stopping.
        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
|
| 170 |
+
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
General-purpose utility functions and scripts.
|
| 3 |
+
"""
|
| 4 |
+
|
src/utils/data_subset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Iterable
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def create_subset(
|
| 8 |
+
input_path: str | Path,
|
| 9 |
+
output_path: str | Path,
|
| 10 |
+
size: int = 20_000,
|
| 11 |
+
) -> None:
|
| 12 |
+
"""
|
| 13 |
+
Create a random subset of a JSONL annotations file.
|
| 14 |
+
"""
|
| 15 |
+
input_path = Path(input_path)
|
| 16 |
+
output_path = Path(output_path)
|
| 17 |
+
|
| 18 |
+
with input_path.open("r") as f:
|
| 19 |
+
data = [json.loads(line) for line in f]
|
| 20 |
+
|
| 21 |
+
if size > len(data):
|
| 22 |
+
raise ValueError(f"Requested subset size {size} exceeds dataset size {len(data)}")
|
| 23 |
+
|
| 24 |
+
subset = random.sample(data, size)
|
| 25 |
+
|
| 26 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 27 |
+
|
| 28 |
+
with output_path.open("w") as f:
|
| 29 |
+
for item in subset:
|
| 30 |
+
f.write(json.dumps(item) + "\n")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _main_from_cli(args: Iterable[str] | None = None) -> None:
    """
    Simple CLI wrapper when this module is executed as a script.

    Parses --input/--output/--size and delegates to `create_subset`.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Create a random JSONL subset.")
    cli.add_argument(
        "--input",
        default="annotations/captions_train.jsonl",
        help="Input JSONL annotations path.",
    )
    cli.add_argument(
        "--output",
        default="annotations/subset_20k.jsonl",
        help="Output JSONL path.",
    )
    cli.add_argument(
        "--size",
        type=int,
        default=20_000,
        help="Number of samples to keep.",
    )

    opts = cli.parse_args(None if args is None else list(args))
    create_subset(opts.input, opts.output, opts.size)
    print(f"Subset of {opts.size} entries written to {opts.output}")


if __name__ == "__main__":
    _main_from_cli()
|
| 64 |
+
|
train_blip_20k_384.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader, random_split
|
| 4 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 7 |
+
from dataset_384 import COCODataset384
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
    """Fine-tune BLIP-base on a 20k COCO subset at 384px resolution.

    Trains for a fixed number of epochs with float16 autocast on MPS,
    computes validation loss every epoch, and saves the model/processor
    pair with the lowest validation loss to ``MODEL_DIR``.
    """
    # Fail fast with a clear message on machines without Apple Silicon,
    # consistent with the other training scripts in this repo — the
    # float16 autocast below is MPS-specific.
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    EPOCHS = 5
    BATCH_SIZE = 3  # ⚠️ Lower because 384px uses more memory
    LR = 3e-5

    CHECKPOINT_DIR = "checkpoints_20k_384"
    MODEL_DIR = "saved_model_20k_384"

    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(MODEL_DIR, exist_ok=True)

    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Trade compute for memory; the generation cache is useless while
    # training with checkpointing, so disable it.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    dataset = COCODataset384(
        "annotations/subset_20k.jsonl",
        "train2017",
        processor
    )

    # 90/10 train/validation split.
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_val_loss = float("inf")

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Mixed-precision forward; backward runs outside autocast.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                outputs = model(**batch)
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        train_loss = total_loss / len(train_loader)
        print(f"Train Loss: {train_loss:.4f}")

        # Validation
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()

        val_loss /= len(val_loader)
        print(f"Validation Loss: {val_loss:.4f}")

        # Keep only the checkpoint with the lowest validation loss.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            model.save_pretrained(MODEL_DIR)
            processor.save_pretrained(MODEL_DIR)
            print("Best model saved.")

        scheduler.step()

    print("Training complete.")


if __name__ == "__main__":
    main()
|
train_data_experiments.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from platform import processor
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader, random_split
|
| 5 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 8 |
+
from dataset_advanced import COCODataset
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from pycocoevalcap.cider.cider import Cider
|
| 12 |
+
from dataset_advanced import COCODatasetAdvanced
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# =========================
|
| 16 |
+
# GENERATE CAPTION
|
| 17 |
+
# =========================
|
| 18 |
+
def generate_caption(model, processor, image, device):
    """Generate one caption for ``image`` via 5-beam search (max 30 tokens)."""
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_length=30, num_beams=5)

    return processor.decode(generated_ids[0], skip_special_tokens=True)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# =========================
|
| 36 |
+
# CIDEr EVALUATION
|
| 37 |
+
# =========================
|
| 38 |
+
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """Compute the corpus CIDEr score over at most ``max_samples``
    validation images and return it; the model is restored to train mode.
    """
    model.eval()

    scorer = Cider()
    refs = {}
    hyps = {}

    n_eval = min(max_samples, len(val_dataset))
    for idx in tqdm(range(n_eval), desc="CIDEr Eval"):
        # ``val_dataset`` is a Subset; map its position back to the
        # underlying annotation row.
        ann = val_dataset.dataset.annotations[val_dataset.indices[idx]]

        image = Image.open(os.path.join("train2017", ann["image"])).convert("RGB")

        refs[idx] = ann["captions"]
        hyps[idx] = [generate_caption(model, processor, image, device)]

    score, _ = scorer.compute_score(refs, hyps)

    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# =========================
|
| 66 |
+
# MAIN
|
| 67 |
+
# =========================
|
| 68 |
+
def main():
    """Phase-2 BLIP fine-tuning with a partially unfrozen vision tower.

    Trains on a 10k COCO subset (caption selection controlled by ``MODE``),
    tracks validation loss and CIDEr every epoch, keeps the best-scoring
    checkpoint, and stops early once CIDEr stalls for ``patience`` epochs.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # ---- configuration ----
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # ---- model & processor ----
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Freeze the vision encoder except for its last two transformer layers.
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = (
            "encoder.layers.10" in name or "encoder.layers.11" in name
        )

    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # ---- data ----
    MODE = "long"  # change to "short" or "long"

    full_dataset = COCODatasetAdvanced(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        mode=MODE
    )

    n_train = int(0.9 * len(full_dataset))
    train_dataset, val_dataset = random_split(
        full_dataset,
        [n_train, len(full_dataset) - n_train]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # Only the unfrozen parameters are optimized.
    optimizer = AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=LR
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # ---- early-stopping state ----
    best_cider = 0
    patience = 3
    counter = 0

    # ---- training loop ----
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0
        bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in bar:
            batch = {key: tensor.to(device) for key, tensor in batch.items()}

            # Mixed-precision forward pass on MPS.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                loss = model(**batch).loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            bar.set_postfix(loss=loss.item())

        avg_train_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # ---- validation loss ----
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {key: tensor.to(device) for key, tensor in batch.items()}
                val_loss += model(**batch).loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        # ---- CIDEr + best-checkpoint bookkeeping ----
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        if counter >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
|
train_git.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader, random_split
|
| 4 |
+
from transformers import GitProcessor, GitForCausalLM
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 7 |
+
from dataset_git import COCODatasetGIT
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from pycocoevalcap.cider.cider import Cider
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def generate_caption(model, processor, image, device):
    """Caption ``image`` with a GIT model using 5-beam search (max 20 tokens)."""
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, num_beams=5, max_length=20)

    decoded = processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """CIDEr-score up to ``max_samples`` validation images with the GIT
    model and return the corpus score; model ends back in train mode.
    """
    model.eval()
    scorer = Cider()

    refs = {}
    hyps = {}

    sample_count = min(max_samples, len(val_dataset))
    for idx in tqdm(range(sample_count), desc="CIDEr Eval"):
        # Map the Subset position back to the underlying annotation row.
        ann = val_dataset.dataset.annotations[val_dataset.indices[idx]]

        image = Image.open(os.path.join("train2017", ann["image"])).convert("RGB")

        refs[idx] = ann["captions"]
        hyps[idx] = [generate_caption(model, processor, image, device)]

    score, _ = scorer.compute_score(refs, hyps)

    print(f"CIDEr Score: {score:.4f}")
    model.train()
    return score
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main():
    """Fine-tune GIT-base for captioning on a 20k COCO subset.

    Runs a long schedule with gradient clipping, evaluates CIDEr every
    epoch, and keeps the highest-scoring checkpoint in ``SAVE_DIR``.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    EPOCHS = 20
    BATCH_SIZE = 4
    LR = 5e-5
    SAVE_DIR = "saved_git_model"

    os.makedirs(SAVE_DIR, exist_ok=True)

    processor = GitProcessor.from_pretrained("microsoft/git-base")
    model = GitForCausalLM.from_pretrained("microsoft/git-base").to(device)

    dataset = COCODatasetGIT(
        "annotations/subset_20k.jsonl",
        "train2017",
        processor,
        mode="mixed"
    )

    # 90/10 train/validation split; validation is used only for CIDEr.
    n_train = int(0.9 * len(dataset))
    train_dataset, val_dataset = random_split(
        dataset, [n_train, len(dataset) - n_train]
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_cider = 0

    for epoch in range(EPOCHS):
        model.train()
        running_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            batch = {key: tensor.to(device) for key, tensor in batch.items()}

            loss = model(**batch).loss
            loss.backward()

            # Clip gradients to stabilise training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()

        print(f"Train Loss: {running_loss / len(train_loader):.4f}")

        cider_score = evaluate_cider(model, processor, val_dataset, device)

        if cider_score > best_cider:
            best_cider = cider_score
            model.save_pretrained(SAVE_DIR)
            processor.save_pretrained(SAVE_DIR)
            print("Best GIT model saved.")

        scheduler.step()

    print("GIT Training complete.")


if __name__ == "__main__":
    main()
|
train_phase2.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader, random_split
|
| 4 |
+
from transformers import BlipProcessor, BlipForConditionalGeneration
|
| 5 |
+
from torch.optim import AdamW
|
| 6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 7 |
+
from dataset_advanced import COCODataset
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from pycocoevalcap.cider.cider import Cider
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =========================
|
| 14 |
+
# GENERATE CAPTION
|
| 15 |
+
# =========================
|
| 16 |
+
def generate_caption(model, processor, image, device):
    """Produce a single beam-searched BLIP caption for a PIL image."""
    model_inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        ids = model.generate(**model_inputs, max_length=30, num_beams=5)

    caption = processor.decode(ids[0], skip_special_tokens=True)
    return caption
|
| 31 |
+
|
| 32 |
+
# =========================
|
| 33 |
+
# CIDEr EVALUATION
|
| 34 |
+
# =========================
|
| 35 |
+
def evaluate_cider(model, processor, val_dataset, device, max_samples=200):
    """Compute the corpus-level CIDEr score on validation images.

    Args:
        model: BLIP captioning model (returned to ``train()`` mode on exit).
        processor: matching BLIP processor.
        val_dataset: ``torch.utils.data.Subset``; ``.indices`` maps subset
            positions back to the underlying annotation rows.
        device: device generation runs on.
        max_samples: cap on the number of images scored.

    Returns:
        The CIDEr score as a float.
    """
    model.eval()

    cider_scorer = Cider()
    ground_truth = {}
    predictions = {}

    for idx in tqdm(range(min(max_samples, len(val_dataset))), desc="CIDEr Eval"):
        real_idx = val_dataset.indices[idx]
        ann = val_dataset.dataset.annotations[real_idx]

        image_path = os.path.join("train2017", ann["image"])
        image = Image.open(image_path).convert("RGB")

        pred_caption = generate_caption(model, processor, image, device)

        ground_truth[idx] = ann["captions"]
        predictions[idx] = [pred_caption]

    score, _ = cider_scorer.compute_score(ground_truth, predictions)

    # Fixed: the original format spec ".4 f" (stray space) is invalid and
    # raised ValueError at runtime, crashing every evaluation pass.
    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# =========================
|
| 63 |
+
# MAIN
|
| 64 |
+
# =========================
|
| 65 |
+
def main():
    """Phase-2 BLIP fine-tuning: unfreeze the last two vision layers.

    Trains on a 10k COCO subset, reports validation loss and CIDEr each
    epoch, checkpoints on best CIDEr, and early-stops after ``patience``
    epochs without improvement.
    """
    if not torch.backends.mps.is_available():
        raise RuntimeError("MPS not available.")

    device = torch.device("mps")
    print("Using device:", device)

    # ---- configuration ----
    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5  # Lower LR for partial unfreezing
    NUM_WORKERS = 0
    FINAL_MODEL_DIR = "saved_model_phase2"

    os.makedirs(FINAL_MODEL_DIR, exist_ok=True)

    # ---- model & processor ----
    processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )
    model = BlipForConditionalGeneration.from_pretrained(
        "Salesforce/blip-image-captioning-base"
    )

    # Train only the final two vision-encoder layers; freeze the rest.
    for name, param in model.vision_model.named_parameters():
        param.requires_grad = (
            "encoder.layers.10" in name or "encoder.layers.11" in name
        )

    model.gradient_checkpointing_enable()
    model.config.use_cache = False
    model.to(device)

    # ---- data ----
    full_dataset = COCODataset(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor
    )

    n_train = int(0.9 * len(full_dataset))
    train_dataset, val_dataset = random_split(
        full_dataset,
        [n_train, len(full_dataset) - n_train]
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS
    )

    # Optimize only the unfrozen parameters.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=LR)

    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # ---- early-stopping state ----
    best_cider = 0
    patience = 3
    counter = 0

    # ---- training loop ----
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0
        bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

        for batch in bar:
            batch = {key: tensor.to(device) for key, tensor in batch.items()}

            # Mixed-precision forward pass on MPS.
            with torch.autocast(device_type="mps", dtype=torch.float16):
                loss = model(**batch).loss

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            bar.set_postfix(loss=loss.item())

        avg_train_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

        # ---- validation loss ----
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for batch in val_loader:
                batch = {key: tensor.to(device) for key, tensor in batch.items()}
                val_loss += model(**batch).loss.item()

        val_loss /= len(val_loader)
        print(f"Epoch {epoch+1} Validation Loss: {val_loss:.4f}")

        # ---- CIDEr + best-checkpoint bookkeeping ----
        cider_score = evaluate_cider(model, processor, val_dataset, device)

        if cider_score > best_cider:
            best_cider = cider_score
            counter = 0
            model.save_pretrained(FINAL_MODEL_DIR)
            processor.save_pretrained(FINAL_MODEL_DIR)
            print("Best CIDEr model saved.")
        else:
            counter += 1

        if counter >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step()

    print("Phase 2 training complete.")


if __name__ == "__main__":
    main()
|
train_vit_gpt2.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from torch.utils.data import DataLoader, random_split
|
| 4 |
+
from transformers import (
|
| 5 |
+
VisionEncoderDecoderModel,
|
| 6 |
+
ViTImageProcessor,
|
| 7 |
+
AutoTokenizer,
|
| 8 |
+
GPT2Config,
|
| 9 |
+
GPT2LMHeadModel,
|
| 10 |
+
ViTModel
|
| 11 |
+
)
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
| 14 |
+
from dataset_vit_gpt2 import COCODatasetViTGPT2
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
from pycocoevalcap.cider.cider import Cider
|
| 17 |
+
from PIL import Image
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ==========================================
|
| 21 |
+
# GENERATE CAPTION
|
| 22 |
+
# ==========================================
|
| 23 |
+
def generate_caption(model, processor, tokenizer, image, device):
    """Generate one beam-searched caption for ``image``.

    GPT-2 ships without a pad token, so EOS is used as both pad and
    end-of-sequence during generation.
    """
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    generation_kwargs = {
        "num_beams": 5,
        "max_length": 20,
        "length_penalty": 1.0,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    with torch.no_grad():
        output_ids = model.generate(pixel_values=pixel_values, **generation_kwargs)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ==========================================
|
| 41 |
+
# CIDEr EVALUATION
|
| 42 |
+
# ==========================================
|
| 43 |
+
def evaluate_cider(model, processor, tokenizer, val_dataset, device, max_samples=200):
    """Score up to ``max_samples`` validation images with CIDEr and
    return the corpus score; the model is put back in train mode.
    """
    model.eval()
    scorer = Cider()

    refs = {}
    hyps = {}

    n_eval = min(max_samples, len(val_dataset))
    for idx in tqdm(range(n_eval), desc="CIDEr Eval"):
        # Translate the Subset position into the original annotation row.
        ann = val_dataset.dataset.annotations[val_dataset.indices[idx]]

        image = Image.open(os.path.join("train2017", ann["image"])).convert("RGB")

        refs[idx] = ann["captions"]
        hyps[idx] = [generate_caption(model, processor, tokenizer, image, device)]

    score, _ = scorer.compute_score(refs, hyps)

    print(f"CIDEr Score: {score:.4f}")

    model.train()
    return score
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ==========================================
|
| 73 |
+
# MAIN
|
| 74 |
+
# ==========================================
|
| 75 |
+
def main():
    """Train a ViT-encoder / GPT-2-decoder captioning model.

    Builds the VisionEncoderDecoder pair from pretrained parts, fine-tunes
    on a 10k COCO subset, evaluates CIDEr each epoch, and keeps the
    best-scoring checkpoint in ``SAVE_DIR``.
    """
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device:", device)

    EPOCHS = 5
    BATCH_SIZE = 6
    LR = 3e-5
    SAVE_DIR = "saved_vit_gpt2"

    os.makedirs(SAVE_DIR, exist_ok=True)

    # ------------------------------------------
    # Build Encoder + Decoder
    # ------------------------------------------
    encoder = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    # GPT-2 must be reconfigured as a decoder with cross-attention so it
    # can attend to the ViT image features.
    decoder_config = GPT2Config.from_pretrained("gpt2")
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    decoder = GPT2LMHeadModel.from_pretrained("gpt2", config=decoder_config)

    model = VisionEncoderDecoderModel(
        encoder=encoder,
        decoder=decoder
    )

    processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # GPT-2 has no pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    model.config.pad_token_id = tokenizer.eos_token_id
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.vocab_size = model.config.decoder.vocab_size

    model.to(device)

    # ------------------------------------------
    # DATASET
    # ------------------------------------------
    dataset = COCODatasetViTGPT2(
        "annotations/subset_10k.jsonl",
        "train2017",
        processor,
        tokenizer,
        mode="short"
    )

    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size

    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # NOTE: validation is CIDEr-based on ``val_dataset`` directly; the
    # original unused validation DataLoader has been removed.

    optimizer = AdamW(model.parameters(), lr=LR)
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

    best_cider = 0

    # ==========================================
    # TRAIN LOOP
    # ==========================================
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            loss.backward()

            # Clip gradients to stabilise early training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Train Loss: {avg_loss:.4f}")

        # ------------------------------------------
        # CIDEr Evaluation
        # ------------------------------------------
        cider_score = evaluate_cider(
            model,
            processor,
            tokenizer,
            val_dataset,
            device
        )

        # Save best model
        if cider_score > best_cider:
            best_cider = cider_score
            model.save_pretrained(SAVE_DIR)
            tokenizer.save_pretrained(SAVE_DIR)
            processor.save_pretrained(SAVE_DIR)
            print("Best model saved.")

        scheduler.step()

    print("ViT-GPT2 Training complete.")


if __name__ == "__main__":
    main()
|
uploadtohf.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import (
|
| 2 |
+
AutoTokenizer,
|
| 3 |
+
BlipForConditionalGeneration,
|
| 4 |
+
BlipProcessor,
|
| 5 |
+
GitForCausalLM,
|
| 6 |
+
GitProcessor,
|
| 7 |
+
VisionEncoderDecoderModel,
|
| 8 |
+
ViTImageProcessor,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def push_blip(
    local_dir: str = "saved_model_phase2",
    repo_id: str = "pchandragrid/blip-caption-model",
) -> None:
    """Publish the fine-tuned BLIP captioning checkpoint to the Hugging Face Hub.

    Args:
        local_dir: Directory containing the locally saved BLIP model/processor.
        repo_id: Destination Hub repository (created or updated in place).
    """
    # Load both saved artifacts, then upload each to the same repo.
    blip_model = BlipForConditionalGeneration.from_pretrained(local_dir)
    blip_processor = BlipProcessor.from_pretrained(local_dir)
    for artifact in (blip_model, blip_processor):
        artifact.push_to_hub(repo_id)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def push_vit_gpt2(
    local_dir: str = "saved_vit_gpt2",
    repo_id: str = "pchandragrid/vit-gpt2-caption-model",
) -> None:
    """Publish the fine-tuned ViT-GPT2 captioning checkpoint to the Hugging Face Hub.

    Args:
        local_dir: Directory containing the saved encoder-decoder model,
            image processor, and tokenizer.
        repo_id: Destination Hub repository (created or updated in place).
    """
    # Three artifacts are saved for this model family; push them in the
    # same order they were loaded: model, image processor, tokenizer.
    artifacts = (
        VisionEncoderDecoderModel.from_pretrained(local_dir),
        ViTImageProcessor.from_pretrained(local_dir),
        AutoTokenizer.from_pretrained(local_dir),
    )
    for artifact in artifacts:
        artifact.push_to_hub(repo_id)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def push_git(
    local_dir: str = "saved_git_model",
    repo_id: str = "pchandragrid/git-caption-model",
) -> None:
    """Publish the fine-tuned GIT captioning checkpoint to the Hugging Face Hub.

    Args:
        local_dir: Directory containing the locally saved GIT model/processor.
        repo_id: Destination Hub repository (created or updated in place).
    """
    # Load both saved artifacts, then upload each to the same repo.
    git_model = GitForCausalLM.from_pretrained(local_dir)
    git_processor = GitProcessor.from_pretrained(local_dir)
    for artifact in (git_model, git_processor):
        artifact.push_to_hub(repo_id)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == "__main__":
    # Upload all three trained captioning models in sequence.
    for upload in (push_blip, push_vit_gpt2, push_git):
        upload()
    print("Uploaded: BLIP, ViT-GPT2, and GIT models.")
|