Spaces:
Running
Running
Deploy VQA Space with model downloader
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env.example +6 -0
- .gitattributes +12 -34
- DATASET_CARD.md +250 -0
- Dockerfile +23 -0
- HOW_TO_RUN.md +255 -0
- PATTERN_MATCHING_FIX.md +86 -0
- QUICK_START.md +196 -0
- README.md +203 -7
- README_COMPLETE.md +530 -0
- SETUP_GUIDE.md +118 -0
- VQA_ENHANCEMENTS.md +298 -0
- __pycache__/backend_api.cpython-312.pyc +0 -0
- __pycache__/conversation_manager.cpython-312.pyc +0 -0
- __pycache__/ensemble_vqa_app.cpython-312.pyc +0 -0
- __pycache__/groq_service.cpython-312.pyc +0 -0
- __pycache__/knowledge_graph_service.cpython-312.pyc +0 -0
- __pycache__/llm_reasoning_service.cpython-312.pyc +0 -0
- __pycache__/model_spatial.cpython-312.pyc +0 -0
- __pycache__/semantic_neurosymbolic_vqa.cpython-312.pyc +0 -0
- architecture_draft.html +89 -0
- architecture_draft.mmd +69 -0
- backend_api.py +341 -0
- continue.py +344 -0
- continued_training_metric.csv +21 -0
- conversation_manager.py +312 -0
- download_models.py +27 -0
- draft_generator.py +112 -0
- ensemble_vqa_app.py +458 -0
- enterprise_architecture.drawio +341 -0
- exp_results/feature_extraction_metric.csv +31 -0
- experiments/__pycache__/train.cpython-312.pyc +0 -0
- experiments/test.py +73 -0
- experiments/train.py +349 -0
- experiments/utils/preprocess.py +164 -0
- experiments/utils/vocab.py +65 -0
- finetune.py +220 -0
- finetune2.py +395 -0
- genvqa-dataset.py +78 -0
- groq_service.py +118 -0
- knowledge_graph_service.py +291 -0
- llm_reasoning_service.py +292 -0
- model.py +224 -0
- model_spatial.py +309 -0
- models/__pycache__/model.cpython-312.pyc +0 -0
- models/model.py +224 -0
- quick_start.bat +71 -0
- requirements_api.txt +14 -0
- scores/feature.txt +77 -0
- scores/score.py +300 -0
- scores/vqa_evaluation_feature.csv +0 -0
.env.example
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Groq API Configuration
|
| 2 |
+
# Get your API key from: https://console.groq.com/keys
|
| 3 |
+
GROQ_API_KEY=your_groq_api_key_here
|
| 4 |
+
|
| 5 |
+
# Optional: Model selection (default: llama-3.3-70b-versatile)
|
| 6 |
+
# GROQ_MODEL=llama-3.3-70b-versatile
|
.gitattributes
CHANGED
|
@@ -1,35 +1,13 @@
|
|
| 1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.
|
| 24 |
-
*.
|
| 25 |
-
*.
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
*.
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.json filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
DATASET_CARD.md
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VQA v2 Curated Dataset for Spatial Reasoning
|
| 2 |
+
|
| 3 |
+
## Dataset Description
|
| 4 |
+
|
| 5 |
+
This is a **curated and balanced subset** of the VQA v2 (Visual Question Answering v2.0) dataset, specifically preprocessed for training visual question answering models with enhanced spatial reasoning capabilities.
|
| 6 |
+
|
| 7 |
+
### Dataset Summary
|
| 8 |
+
|
| 9 |
+
- **Source**: VQA v2 (MSCOCO train2014 split)
|
| 10 |
+
- **Task**: Visual Question Answering
|
| 11 |
+
- **Language**: English
|
| 12 |
+
- **License**: CC BY 4.0 (inherited from VQA v2)
|
| 13 |
+
|
| 14 |
+
### Key Features
|
| 15 |
+
|
| 16 |
+
✨ **Quality-Focused Curation**:
|
| 17 |
+
- Filtered out ambiguous yes/no questions
|
| 18 |
+
- Removed vague questions ("what is in the image", etc.)
|
| 19 |
+
- Answer length limited to 5 words / 30 characters
|
| 20 |
+
- Minimum answer frequency threshold (20 occurrences)
|
| 21 |
+
|
| 22 |
+
🎯 **Balanced Distribution**:
|
| 23 |
+
- Maximum 600 samples per answer class
|
| 24 |
+
- Prevents model bias toward common answers
|
| 25 |
+
- Ensures diverse question-answer coverage
|
| 26 |
+
|
| 27 |
+
📊 **Dataset Statistics**:
|
| 28 |
+
- **Total Q-A pairs**: ~[Your final count from running the script]
|
| 29 |
+
- **Unique answers**: ~[Number of unique answer classes]
|
| 30 |
+
- **Images**: MSCOCO train2014 subset
|
| 31 |
+
- **Format**: JSON + CSV metadata
|
| 32 |
+
|
| 33 |
+
---
|
| 34 |
+
|
| 35 |
+
## Dataset Structure
|
| 36 |
+
|
| 37 |
+
### Data Fields
|
| 38 |
+
|
| 39 |
+
Each sample contains:
|
| 40 |
+
|
| 41 |
+
```json
|
| 42 |
+
{
|
| 43 |
+
"image_id": 123456, // MSCOCO image ID
|
| 44 |
+
"question_id": 789012, // VQA v2 question ID
|
| 45 |
+
"question": "What color is the car?",
|
| 46 |
+
"answer": "red", // Most frequent answer from annotators
|
| 47 |
+
"image_path": "images/COCO_train2014_000000123456.jpg"
|
| 48 |
+
}
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Data Splits
|
| 52 |
+
|
| 53 |
+
- **Training**: Main dataset (recommend 80-90% for training)
|
| 54 |
+
- **Validation**: User-defined split (recommend 10-20% for validation)
|
| 55 |
+
|
| 56 |
+
### File Structure
|
| 57 |
+
|
| 58 |
+
```
|
| 59 |
+
gen_vqa_v2/
|
| 60 |
+
├── images/ # MSCOCO train2014 images
|
| 61 |
+
│ └── COCO_train2014_*.jpg
|
| 62 |
+
├── qa_pairs.json # Question-answer pairs (JSON)
|
| 63 |
+
└── metadata.csv # Same data in CSV format
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Data Preprocessing
|
| 69 |
+
|
| 70 |
+
### Filtering Criteria
|
| 71 |
+
|
| 72 |
+
**Excluded Answers**:
|
| 73 |
+
- Generic responses: `yes`, `no`, `unknown`, `none`, `n/a`, `cant tell`, `not sure`
|
| 74 |
+
|
| 75 |
+
**Excluded Questions**:
|
| 76 |
+
- Ambiguous queries: "what is in the image", "what is this", "what is that", "what do you see"
|
| 77 |
+
|
| 78 |
+
**Answer Constraints**:
|
| 79 |
+
- Maximum 5 words per answer
|
| 80 |
+
- Maximum 30 characters per answer
|
| 81 |
+
- Minimum frequency: 20 occurrences across dataset
|
| 82 |
+
|
| 83 |
+
**Balancing Strategy**:
|
| 84 |
+
- Maximum 600 samples per answer class
|
| 85 |
+
- Prevents over-representation of common answers (e.g., "white", "2")
|
| 86 |
+
|
| 87 |
+
### Preprocessing Script
|
| 88 |
+
|
| 89 |
+
The dataset was generated using `genvqa-dataset.py`:
|
| 90 |
+
|
| 91 |
+
```python
|
| 92 |
+
# Key parameters
|
| 93 |
+
MIN_ANSWER_FREQ = 20 # Minimum answer occurrences
|
| 94 |
+
MAX_SAMPLES_PER_ANSWER = 600 # Class balancing limit
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
---
|
| 98 |
+
|
| 99 |
+
## Intended Use
|
| 100 |
+
|
| 101 |
+
### Primary Use Cases
|
| 102 |
+
|
| 103 |
+
✅ **Training VQA Models**:
|
| 104 |
+
- Visual question answering systems
|
| 105 |
+
- Multimodal vision-language models
|
| 106 |
+
- Spatial reasoning research
|
| 107 |
+
|
| 108 |
+
✅ **Research Applications**:
|
| 109 |
+
- Evaluating spatial understanding in VQA
|
| 110 |
+
- Studying answer distribution bias
|
| 111 |
+
- Benchmarking ensemble architectures
|
| 112 |
+
|
| 113 |
+
### Out-of-Scope Use
|
| 114 |
+
|
| 115 |
+
❌ Medical diagnosis or safety-critical applications
|
| 116 |
+
❌ Surveillance or privacy-invasive systems
|
| 117 |
+
❌ Generating misleading or harmful content
|
| 118 |
+
|
| 119 |
+
---
|
| 120 |
+
|
| 121 |
+
## Dataset Creation
|
| 122 |
+
|
| 123 |
+
### Source Data
|
| 124 |
+
|
| 125 |
+
**VQA v2 Dataset**:
|
| 126 |
+
- **Paper**: [Making the V in VQA Matter](https://arxiv.org/abs/1612.00837)
|
| 127 |
+
- **Authors**: Goyal et al. (2017)
|
| 128 |
+
- **Images**: MSCOCO train2014
|
| 129 |
+
- **Original Size**: 443,757 question-answer pairs (train split)
|
| 130 |
+
|
| 131 |
+
### Curation Rationale
|
| 132 |
+
|
| 133 |
+
This curated subset addresses common VQA training challenges:
|
| 134 |
+
|
| 135 |
+
1. **Bias Reduction**: Limits over-represented answers
|
| 136 |
+
2. **Quality Control**: Removes ambiguous/uninformative samples
|
| 137 |
+
3. **Spatial Focus**: Retains questions requiring spatial reasoning
|
| 138 |
+
4. **Practical Constraints**: Focuses on concise, specific answers
|
| 139 |
+
|
| 140 |
+
### Annotations
|
| 141 |
+
|
| 142 |
+
Annotations are inherited from VQA v2:
|
| 143 |
+
- 10 answers per question from human annotators
|
| 144 |
+
- **Answer selection**: Most frequent answer among annotators
|
| 145 |
+
- **Consensus**: Majority voting for ground truth
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
## Considerations for Using the Data
|
| 150 |
+
|
| 151 |
+
### Social Impact
|
| 152 |
+
|
| 153 |
+
This dataset inherits biases from MSCOCO and VQA v2:
|
| 154 |
+
- **Geographic bias**: Primarily Western/North American scenes
|
| 155 |
+
- **Cultural bias**: Limited representation of global diversity
|
| 156 |
+
- **Object bias**: Common objects over-represented
|
| 157 |
+
|
| 158 |
+
### Limitations
|
| 159 |
+
|
| 160 |
+
⚠️ **Known Issues**:
|
| 161 |
+
- Answer distribution still skewed toward common objects (e.g., "white", "2", "yes")
|
| 162 |
+
- Spatial reasoning questions may be underrepresented
|
| 163 |
+
- Some questions may have multiple valid answers
|
| 164 |
+
|
| 165 |
+
⚠️ **Not Suitable For**:
|
| 166 |
+
- Fine-grained visual reasoning (e.g., "How many stripes on the 3rd zebra?")
|
| 167 |
+
- Rare object recognition
|
| 168 |
+
- Non-English languages
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
## Citation
|
| 173 |
+
|
| 174 |
+
### BibTeX
|
| 175 |
+
|
| 176 |
+
```bibtex
|
| 177 |
+
@inproceedings{goyal2017making,
|
| 178 |
+
title={Making the V in VQA Matter: Elevating the Role of Image Understanding in Visual Question Answering},
|
| 179 |
+
author={Goyal, Yash and Khot, Tejas and Summers-Stay, Douglas and Batra, Dhruv and Parikh, Devi},
|
| 180 |
+
booktitle={CVPR},
|
| 181 |
+
year={2017}
|
| 182 |
+
}
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
### Original VQA v2 Dataset
|
| 186 |
+
|
| 187 |
+
- **Homepage**: https://visualqa.org/
|
| 188 |
+
- **Paper**: https://arxiv.org/abs/1612.00837
|
| 189 |
+
- **License**: CC BY 4.0
|
| 190 |
+
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
## Additional Information
|
| 194 |
+
|
| 195 |
+
### Dataset Curators
|
| 196 |
+
|
| 197 |
+
Curated from VQA v2 by [Your Name/Organization]
|
| 198 |
+
|
| 199 |
+
### Licensing
|
| 200 |
+
|
| 201 |
+
This dataset is released under **CC BY 4.0**, consistent with the original VQA v2 license.
|
| 202 |
+
|
| 203 |
+
### Contact
|
| 204 |
+
|
| 205 |
+
For questions or issues, please contact [your email/GitHub].
|
| 206 |
+
|
| 207 |
+
---
|
| 208 |
+
|
| 209 |
+
## Usage Example
|
| 210 |
+
|
| 211 |
+
### Loading the Dataset
|
| 212 |
+
|
| 213 |
+
```python
|
| 214 |
+
import json
|
| 215 |
+
import pandas as pd
|
| 216 |
+
from PIL import Image
|
| 217 |
+
|
| 218 |
+
# Load metadata
|
| 219 |
+
with open("gen_vqa_v2/qa_pairs.json", "r") as f:
|
| 220 |
+
data = json.load(f)
|
| 221 |
+
|
| 222 |
+
# Or use CSV
|
| 223 |
+
df = pd.read_csv("gen_vqa_v2/metadata.csv")
|
| 224 |
+
|
| 225 |
+
# Access a sample
|
| 226 |
+
sample = data[0]
|
| 227 |
+
image = Image.open(f"gen_vqa_v2/{sample['image_path']}")
|
| 228 |
+
question = sample['question']
|
| 229 |
+
answer = sample['answer']
|
| 230 |
+
|
| 231 |
+
print(f"Q: {question}")
|
| 232 |
+
print(f"A: {answer}")
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
### Training Split
|
| 236 |
+
|
| 237 |
+
```python
|
| 238 |
+
from sklearn.model_selection import train_test_split
|
| 239 |
+
|
| 240 |
+
# 80-20 train-val split
|
| 241 |
+
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
---
|
| 245 |
+
|
| 246 |
+
## Acknowledgments
|
| 247 |
+
|
| 248 |
+
- **VQA v2 Team**: Goyal et al. for the original dataset
|
| 249 |
+
- **MSCOCO Team**: Lin et al. for the image dataset
|
| 250 |
+
- **Community**: Open-source VQA research community
|
Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# System deps
|
| 6 |
+
RUN apt-get update && apt-get install -y \
|
| 7 |
+
git \
|
| 8 |
+
libglib2.0-0 \
|
| 9 |
+
libsm6 \
|
| 10 |
+
libxrender1 \
|
| 11 |
+
libxext6 \
|
| 12 |
+
libgl1-mesa-glx \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# Install Python deps
|
| 16 |
+
COPY requirements_api.txt .
|
| 17 |
+
RUN pip install --no-cache-dir -r requirements_api.txt
|
| 18 |
+
|
| 19 |
+
# Copy all project files
|
| 20 |
+
COPY . .
|
| 21 |
+
|
| 22 |
+
# Download models before starting server
|
| 23 |
+
CMD python download_models.py && uvicorn backend_api:app --host 0.0.0.0 --port 7860
|
HOW_TO_RUN.md
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🚀 How to Run the VQA Mobile App
|
| 2 |
+
|
| 3 |
+
## Quick Overview
|
| 4 |
+
|
| 5 |
+
You now have a complete React Native mobile app for Visual Question Answering! Here's what was created:
|
| 6 |
+
|
| 7 |
+
### ✅ What's Built
|
| 8 |
+
|
| 9 |
+
1. **Backend API** (`backend_api.py`)
|
| 10 |
+
- FastAPI server wrapping your ensemble VQA models
|
| 11 |
+
- Automatic routing between base and spatial models
|
| 12 |
+
- Image upload and question answering endpoints
|
| 13 |
+
|
| 14 |
+
2. **Mobile App** (`ui/` folder)
|
| 15 |
+
- Beautiful React Native app with Expo
|
| 16 |
+
- Google OAuth authentication
|
| 17 |
+
- Camera and gallery image picker
|
| 18 |
+
- Question input and answer display
|
| 19 |
+
- Model routing visualization
|
| 20 |
+
|
| 21 |
+
## 🎯 Running the App (4 Steps)
|
| 22 |
+
|
| 23 |
+
### Step 1: Start the Backend Server
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
# Open PowerShell/Terminal
|
| 27 |
+
cd c:\Users\rdeva\Downloads\vqa_coes
|
| 28 |
+
|
| 29 |
+
# Install API dependencies (FIRST TIME ONLY)
|
| 30 |
+
# If you get import errors, run this:
|
| 31 |
+
pip install fastapi uvicorn python-multipart
|
| 32 |
+
|
| 33 |
+
# Start the server
|
| 34 |
+
python start_backend.py
|
| 35 |
+
# Or: python backend_api.py
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
> **Note**: If you get "ModuleNotFoundError", see [IMPORT_ERRORS_FIX.md](file:///c:/Users/rdeva/Downloads/vqa_coes/IMPORT_ERRORS_FIX.md) for solutions.
|
| 39 |
+
|
| 40 |
+
✅ **Keep this window open!** The server must stay running.
|
| 41 |
+
|
| 42 |
+
You should see:
|
| 43 |
+
```
|
| 44 |
+
🚀 INITIALIZING ENSEMBLE VQA SYSTEM
|
| 45 |
+
✅ Ensemble ready!
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
### Step 2: Configure the Mobile App
|
| 49 |
+
|
| 50 |
+
1. **Find your local IP address:**
|
| 51 |
+
```bash
|
| 52 |
+
ipconfig
|
| 53 |
+
```
|
| 54 |
+
Look for "IPv4 Address" (e.g., `192.168.1.100`)
|
| 55 |
+
|
| 56 |
+
2. **Update the API URL:**
|
| 57 |
+
- Open: `ui\src\config\api.js`
|
| 58 |
+
- Change line 8:
|
| 59 |
+
```javascript
|
| 60 |
+
export const API_BASE_URL = 'http://YOUR_IP_HERE:8000';
|
| 61 |
+
```
|
| 62 |
+
- Example:
|
| 63 |
+
```javascript
|
| 64 |
+
export const API_BASE_URL = 'http://192.168.1.100:8000';
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
### Step 3: Start the Mobile App
|
| 68 |
+
|
| 69 |
+
```bash
|
| 70 |
+
# Open a NEW PowerShell/Terminal window
|
| 71 |
+
cd c:\Users\rdeva\Downloads\vqa_coes\ui
|
| 72 |
+
|
| 73 |
+
# Start Expo
|
| 74 |
+
npm start
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
You'll see a QR code in the terminal.
|
| 78 |
+
|
| 79 |
+
### Step 4: Run on Your Phone
|
| 80 |
+
|
| 81 |
+
1. **Install Expo Go** on your smartphone:
|
| 82 |
+
- [Android - Play Store](https://play.google.com/store/apps/details?id=host.exp.exponent)
|
| 83 |
+
- [iOS - App Store](https://apps.apple.com/app/expo-go/id982107779)
|
| 84 |
+
|
| 85 |
+
2. **Scan the QR code:**
|
| 86 |
+
- Android: Open Expo Go → Scan QR
|
| 87 |
+
- iOS: Open Camera → Scan QR → Tap notification
|
| 88 |
+
|
| 89 |
+
3. **Wait for the app to load** (first time takes ~1-2 minutes)
|
| 90 |
+
|
| 91 |
+
## 📱 Using the App
|
| 92 |
+
|
| 93 |
+
### Option A: Test Without Google Login
|
| 94 |
+
|
| 95 |
+
For quick testing, you can bypass Google authentication:
|
| 96 |
+
|
| 97 |
+
1. Open `ui\App.js`
|
| 98 |
+
2. Find line 23-27 and replace with:
|
| 99 |
+
```javascript
|
| 100 |
+
<Stack.Screen name="Home" component={HomeScreen} />
|
| 101 |
+
```
|
| 102 |
+
3. Save and reload the app (shake phone → Reload)
|
| 103 |
+
|
| 104 |
+
### Option B: Set Up Google Login
|
| 105 |
+
|
| 106 |
+
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
| 107 |
+
2. Create a new project
|
| 108 |
+
3. Enable Google+ API
|
| 109 |
+
4. Create OAuth 2.0 credentials
|
| 110 |
+
5. Update `ui\src\config\google.js` with your client IDs
|
| 111 |
+
|
| 112 |
+
### Testing VQA Functionality
|
| 113 |
+
|
| 114 |
+
1. **Select an image:**
|
| 115 |
+
- Tap "Camera" to take a photo
|
| 116 |
+
- Tap "Gallery" to choose existing image
|
| 117 |
+
|
| 118 |
+
2. **Ask a question:**
|
| 119 |
+
- Type your question (e.g., "What color is the car?")
|
| 120 |
+
- Tap "Ask Question"
|
| 121 |
+
|
| 122 |
+
3. **View the answer:**
|
| 123 |
+
- See the AI-generated answer
|
| 124 |
+
- Check which model was used:
|
| 125 |
+
- 🔍 **Base Model** - General questions
|
| 126 |
+
- 📍 **Spatial Model** - Spatial questions (left, right, above, etc.)
|
| 127 |
+
|
| 128 |
+
## 🧪 Example Questions to Try
|
| 129 |
+
|
| 130 |
+
### General Questions (Base Model 🔍)
|
| 131 |
+
- "What color is the car?"
|
| 132 |
+
- "How many people are in the image?"
|
| 133 |
+
- "What room is this?"
|
| 134 |
+
- "Is there a dog?"
|
| 135 |
+
|
| 136 |
+
### Spatial Questions (Spatial Model 📍)
|
| 137 |
+
- "What is to the right of the table?"
|
| 138 |
+
- "What is above the chair?"
|
| 139 |
+
- "What is next to the door?"
|
| 140 |
+
- "What is on the left side?"
|
| 141 |
+
|
| 142 |
+
## 🔧 Troubleshooting
|
| 143 |
+
|
| 144 |
+
### "Cannot connect to server"
|
| 145 |
+
- ✅ Check backend is running (`python backend_api.py`)
|
| 146 |
+
- ✅ Verify IP address in `api.js` matches your computer's IP
|
| 147 |
+
- ✅ Ensure phone and computer are on the **same WiFi network**
|
| 148 |
+
- ✅ Check Windows Firewall isn't blocking port 8000
|
| 149 |
+
|
| 150 |
+
### "Model not loaded"
|
| 151 |
+
- ✅ Ensure these files exist in `c:\Users\rdeva\Downloads\vqa_coes\`:
|
| 152 |
+
- `vqa_checkpoint.pt`
|
| 153 |
+
- `vqa_spatial_checkpoint.pt`
|
| 154 |
+
- ✅ Check backend terminal for error messages
|
| 155 |
+
|
| 156 |
+
### App won't load on phone
|
| 157 |
+
- ✅ Verify Expo Go is installed
|
| 158 |
+
- ✅ Both devices on same WiFi
|
| 159 |
+
- ✅ Try restarting Expo: Press `Ctrl+C`, then `npm start`
|
| 160 |
+
- ✅ Clear cache: `npm start -- --clear`
|
| 161 |
+
|
| 162 |
+
### Camera/Gallery not working
|
| 163 |
+
- ✅ Grant permissions when prompted
|
| 164 |
+
- ✅ Check phone Settings → App Permissions
|
| 165 |
+
|
| 166 |
+
## 📁 Project Structure
|
| 167 |
+
|
| 168 |
+
```
|
| 169 |
+
vqa_coes/
|
| 170 |
+
├── backend_api.py # FastAPI backend server
|
| 171 |
+
├── ensemble_vqa_app.py # Your existing ensemble system
|
| 172 |
+
├── model_spatial.py # Spatial model
|
| 173 |
+
├── models/model.py # Base model
|
| 174 |
+
├── vqa_checkpoint.pt # Base model weights
|
| 175 |
+
├── vqa_spatial_checkpoint.pt # Spatial model weights
|
| 176 |
+
├── requirements_api.txt # Backend dependencies
|
| 177 |
+
├── QUICK_START.md # This guide
|
| 178 |
+
└── ui/ # Mobile app
|
| 179 |
+
├── App.js # Main app component
|
| 180 |
+
├── app.json # Expo configuration
|
| 181 |
+
├── package.json # Dependencies
|
| 182 |
+
└── src/
|
| 183 |
+
├── config/
|
| 184 |
+
│ ├── api.js # ⚠️ UPDATE YOUR IP HERE
|
| 185 |
+
│ └── google.js # Google OAuth config
|
| 186 |
+
├── contexts/
|
| 187 |
+
│ └── AuthContext.js # Authentication
|
| 188 |
+
├── screens/
|
| 189 |
+
│ ├── LoginScreen.js # Login UI
|
| 190 |
+
│ └── HomeScreen.js # Main VQA UI
|
| 191 |
+
├── services/
|
| 192 |
+
│ └── api.js # API client
|
| 193 |
+
└── styles/
|
| 194 |
+
├── theme.js # Design system
|
| 195 |
+
└── globalStyles.js
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
## 📚 Documentation
|
| 199 |
+
|
| 200 |
+
- **How to Run**: `HOW_TO_RUN.md` (this file)
|
| 201 |
+
- **Full README**: `ui/README.md`
|
| 202 |
+
- **Implementation Details**: See walkthrough artifact
|
| 203 |
+
|
| 204 |
+
## 🎨 Customization
|
| 205 |
+
|
| 206 |
+
### Change Colors
|
| 207 |
+
Edit `ui/src/styles/theme.js`:
|
| 208 |
+
```javascript
|
| 209 |
+
colors: {
|
| 210 |
+
primary: '#6366F1', // Change to your color
|
| 211 |
+
secondary: '#EC4899', // Change to your color
|
| 212 |
+
// ...
|
| 213 |
+
}
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
### Change App Name
|
| 217 |
+
Edit `ui/app.json`:
|
| 218 |
+
```json
|
| 219 |
+
{
|
| 220 |
+
"expo": {
|
| 221 |
+
"name": "Your App Name",
|
| 222 |
+
"slug": "your-app-slug"
|
| 223 |
+
}
|
| 224 |
+
}
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
## 🚢 Next Steps
|
| 228 |
+
|
| 229 |
+
Once everything works:
|
| 230 |
+
|
| 231 |
+
1. **Add Google OAuth** for production
|
| 232 |
+
2. **Create custom icons** (see `ui/assets/ICONS_README.md`)
|
| 233 |
+
3. **Build standalone app**:
|
| 234 |
+
```bash
|
| 235 |
+
npx eas-cli build --platform android
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## 💡 Tips
|
| 239 |
+
|
| 240 |
+
- **Backend must run first** before starting the mobile app
|
| 241 |
+
- **Same WiFi network** is required for phone and computer
|
| 242 |
+
- **First load is slow** - subsequent loads are faster
|
| 243 |
+
- **Shake phone** to access Expo developer menu
|
| 244 |
+
- **Check logs** in both terminals for debugging
|
| 245 |
+
|
| 246 |
+
## 🆘 Need Help?
|
| 247 |
+
|
| 248 |
+
1. Check the troubleshooting section above
|
| 249 |
+
2. Review backend terminal for errors
|
| 250 |
+
3. Check Expo console in terminal
|
| 251 |
+
4. Verify all configuration steps
|
| 252 |
+
|
| 253 |
+
---
|
| 254 |
+
|
| 255 |
+
**Ready to test?** Follow the 4 steps above and start asking questions about images! 🎉
|
PATTERN_MATCHING_FIX.md
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Fix: Removed Hardcoded Patterns from Neuro-Symbolic VQA
|
| 2 |
+
|
| 3 |
+
## Problem Identified
|
| 4 |
+
The `_detect_objects_with_clip()` method in `semantic_neurosymbolic_vqa.py` contained a **predefined list of object categories**, which is essentially pattern matching and defeats the purpose of a truly neuro-symbolic approach.
|
| 5 |
+
|
| 6 |
+
```python
|
| 7 |
+
# ❌ OLD CODE - Hardcoded categories (pattern matching!)
|
| 8 |
+
object_categories = [
|
| 9 |
+
"food", "soup", "noodles", "rice", "meat", "vegetable", "fruit",
|
| 10 |
+
"bowl", "plate", "cup", "glass", "spoon", "fork", "knife", ...
|
| 11 |
+
]
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
This is **not acceptable** because:
|
| 15 |
+
- It limits detection to predefined categories only
|
| 16 |
+
- It's essentially pattern matching, not true neural understanding
|
| 17 |
+
- It violates the neuro-symbolic principle of learning from data
|
| 18 |
+
|
| 19 |
+
## Solution Applied
|
| 20 |
+
|
| 21 |
+
### 1. Deprecated `_detect_objects_with_clip()`
|
| 22 |
+
The method now returns an empty list and warns that it's deprecated:
|
| 23 |
+
|
| 24 |
+
```python
|
| 25 |
+
# ✅ NEW CODE - No predefined lists!
|
| 26 |
+
def _detect_objects_with_clip(self, image_features, image_path=None):
|
| 27 |
+
"""
|
| 28 |
+
NOTE: This method is deprecated in favor of using the VQA model
|
| 29 |
+
directly from ensemble_vqa_app.py.
|
| 30 |
+
"""
|
| 31 |
+
print("⚠️ _detect_objects_with_clip is deprecated")
|
| 32 |
+
print("→ Use VQA model's _detect_multiple_objects() instead")
|
| 33 |
+
return []
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### 2. Updated `answer_with_clip_features()`
|
| 37 |
+
Now **requires** objects to be provided by the VQA model:
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
# ✅ Objects must come from VQA model, not predefined lists
|
| 41 |
+
def answer_with_clip_features(
|
| 42 |
+
self,
|
| 43 |
+
image_features,
|
| 44 |
+
question,
|
| 45 |
+
image_path=None,
|
| 46 |
+
detected_objects: List[str] = None # REQUIRED!
|
| 47 |
+
):
|
| 48 |
+
if not detected_objects:
|
| 49 |
+
print("⚠️ No objects provided - neuro-symbolic reasoning requires VQA-detected objects")
|
| 50 |
+
return None
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### 3. Ensemble VQA Uses True VQA Detection
|
| 54 |
+
The `ensemble_vqa_app.py` already uses `_detect_multiple_objects()` which:
|
| 55 |
+
- Asks the VQA model **open-ended questions** like "What is this?"
|
| 56 |
+
- Uses the model's learned knowledge, not predefined categories
|
| 57 |
+
- Generates objects dynamically based on visual understanding
|
| 58 |
+
|
| 59 |
+
```python
|
| 60 |
+
# ✅ TRUE NEURO-SYMBOLIC APPROACH
|
| 61 |
+
detected_objects = self._detect_multiple_objects(image, model, top_k=5)
|
| 62 |
+
# This asks VQA model: "What is this?", "What food is this?", etc.
|
| 63 |
+
# NO predefined categories!
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Result
|
| 67 |
+
|
| 68 |
+
✅ **Pure Neuro-Symbolic Pipeline**:
|
| 69 |
+
1. **VQA Model** detects objects using learned visual understanding (no predefined lists)
|
| 70 |
+
2. **Wikidata** provides factual knowledge about detected objects
|
| 71 |
+
3. **LLM** performs Chain-of-Thought reasoning on the facts
|
| 72 |
+
4. **No pattern matching** anywhere in the pipeline
|
| 73 |
+
|
| 74 |
+
## Files Modified
|
| 75 |
+
- `semantic_neurosymbolic_vqa.py`:
|
| 76 |
+
- Deprecated `_detect_objects_with_clip()`
|
| 77 |
+
- Updated `answer_with_clip_features()` to require VQA-detected objects
|
| 78 |
+
- Changed knowledge source from "CLIP + Wikidata" to "VQA + Wikidata"
|
| 79 |
+
|
| 80 |
+
## Verification
|
| 81 |
+
The system now uses a **truly neuro-symbolic approach**:
|
| 82 |
+
- ✅ No hardcoded object categories
|
| 83 |
+
- ✅ No predefined patterns
|
| 84 |
+
- ✅ Pure learned visual understanding from VQA model
|
| 85 |
+
- ✅ Symbolic reasoning from Wikidata + LLM
|
| 86 |
+
- ✅ Chain-of-Thought transparency
|
QUICK_START.md
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quick Start Guide - VQA Mobile App
|
| 2 |
+
|
| 3 |
+
This guide will help you get the VQA mobile app running quickly.
|
| 4 |
+
|
| 5 |
+
## Prerequisites Checklist
|
| 6 |
+
|
| 7 |
+
- [ ] Python 3.8+ installed
|
| 8 |
+
- [ ] Node.js 16+ installed
|
| 9 |
+
- [ ] VQA model checkpoints available
|
| 10 |
+
- [ ] Smartphone with Expo Go app installed
|
| 11 |
+
- [ ] Computer and phone on same WiFi network
|
| 12 |
+
|
| 13 |
+
## Step-by-Step Setup
|
| 14 |
+
|
| 15 |
+
### Step 1: Start the Backend Server
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
# Open terminal/PowerShell
|
| 19 |
+
cd c:\Users\rdeva\Downloads\vqa_coes
|
| 20 |
+
|
| 21 |
+
# Install backend dependencies (first time only)
|
| 22 |
+
pip install -r requirements_api.txt
|
| 23 |
+
|
| 24 |
+
# Start the server
|
| 25 |
+
python backend_api.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
**Expected output:**
|
| 29 |
+
```
|
| 30 |
+
🚀 INITIALIZING ENSEMBLE VQA SYSTEM
|
| 31 |
+
⚙️ Device: cuda
|
| 32 |
+
📥 Loading models...
|
| 33 |
+
✅ Ensemble ready!
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
**Important:** Keep this terminal window open! The server must keep running.
|
| 37 |
+
|
| 38 |
+
### Step 2: Find Your Local IP Address
|
| 39 |
+
|
| 40 |
+
**Windows:**
|
| 41 |
+
```bash
|
| 42 |
+
ipconfig
|
| 43 |
+
```
|
| 44 |
+
Look for "IPv4 Address" under your WiFi adapter (e.g., `192.168.1.100`)
|
| 45 |
+
|
| 46 |
+
**Mac/Linux:**
|
| 47 |
+
```bash
|
| 48 |
+
ifconfig
|
| 49 |
+
# or
|
| 50 |
+
ip addr
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Step 3: Configure the Mobile App
|
| 54 |
+
|
| 55 |
+
1. Open `ui/src/config/api.js`
|
| 56 |
+
2. Replace the IP address:
|
| 57 |
+
```javascript
|
| 58 |
+
export const API_BASE_URL = 'http://YOUR_IP_HERE:8000';
|
| 59 |
+
// Example: export const API_BASE_URL = 'http://192.168.1.100:8000';
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
### Step 4: Configure Google OAuth (Optional for Testing)
|
| 63 |
+
|
| 64 |
+
**For testing without Google login**, you can skip this and modify the app to bypass authentication.
|
| 65 |
+
|
| 66 |
+
**For full Google login:**
|
| 67 |
+
|
| 68 |
+
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
| 69 |
+
2. Create a project
|
| 70 |
+
3. Enable Google+ API
|
| 71 |
+
4. Create OAuth 2.0 credentials
|
| 72 |
+
5. Update `ui/src/config/google.js` with your client IDs
|
| 73 |
+
|
| 74 |
+
### Step 5: Start the Mobile App
|
| 75 |
+
|
| 76 |
+
```bash
|
| 77 |
+
# Open a NEW terminal/PowerShell
|
| 78 |
+
cd c:\Users\rdeva\Downloads\vqa_coes\ui
|
| 79 |
+
|
| 80 |
+
# Start Expo
|
| 81 |
+
npm start
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
**Expected output:**
|
| 85 |
+
```
|
| 86 |
+
Metro waiting on exp://192.168.1.100:8081
|
| 87 |
+
› Scan the QR code above with Expo Go (Android) or the Camera app (iOS)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
### Step 6: Run on Your Phone
|
| 91 |
+
|
| 92 |
+
1. **Install Expo Go** on your phone:
|
| 93 |
+
- [Android - Play Store](https://play.google.com/store/apps/details?id=host.exp.exponent)
|
| 94 |
+
- [iOS - App Store](https://apps.apple.com/app/expo-go/id982107779)
|
| 95 |
+
|
| 96 |
+
2. **Scan the QR code**:
|
| 97 |
+
- Android: Open Expo Go app → Scan QR code
|
| 98 |
+
- iOS: Open Camera app → Scan QR code → Tap notification
|
| 99 |
+
|
| 100 |
+
3. **Wait for app to load** (first time may take 1-2 minutes)
|
| 101 |
+
|
| 102 |
+
## Testing Without Google Login
|
| 103 |
+
|
| 104 |
+
If you want to test the VQA functionality without setting up Google OAuth:
|
| 105 |
+
|
| 106 |
+
1. Open `ui/App.js`
|
| 107 |
+
2. Temporarily modify the navigation to always show HomeScreen:
|
| 108 |
+
|
| 109 |
+
```javascript
|
| 110 |
+
// Replace this:
|
| 111 |
+
{user ? (
|
| 112 |
+
<Stack.Screen name="Home" component={HomeScreen} />
|
| 113 |
+
) : (
|
| 114 |
+
<Stack.Screen name="Login" component={LoginScreen} />
|
| 115 |
+
)}
|
| 116 |
+
|
| 117 |
+
// With this:
|
| 118 |
+
<Stack.Screen name="Home" component={HomeScreen} />
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
3. Restart the Expo server
|
| 122 |
+
|
| 123 |
+
## Testing the App
|
| 124 |
+
|
| 125 |
+
### Test 1: General Question (Base Model)
|
| 126 |
+
1. Tap "Gallery" and select an image
|
| 127 |
+
2. Enter question: "What color is the car?"
|
| 128 |
+
3. Tap "Ask Question"
|
| 129 |
+
4. Should show: 🔍 Base Model
|
| 130 |
+
|
| 131 |
+
### Test 2: Spatial Question (Spatial Model)
|
| 132 |
+
1. Select an image with multiple objects
|
| 133 |
+
2. Enter question: "What is to the right of the table?"
|
| 134 |
+
3. Tap "Ask Question"
|
| 135 |
+
4. Should show: 📍 Spatial Model
|
| 136 |
+
|
| 137 |
+
## Troubleshooting
|
| 138 |
+
|
| 139 |
+
### "Cannot connect to server"
|
| 140 |
+
- ✅ Check backend is running
|
| 141 |
+
- ✅ Verify IP address in `api.js` is correct
|
| 142 |
+
- ✅ Ensure phone and computer on same WiFi
|
| 143 |
+
- ✅ Check firewall isn't blocking port 8000
|
| 144 |
+
|
| 145 |
+
### "Model not loaded"
|
| 146 |
+
- ✅ Check checkpoint files are in project root
|
| 147 |
+
- ✅ Verify file names: `vqa_checkpoint.pt` and `vqa_spatial_checkpoint.pt`
|
| 148 |
+
- ✅ Check backend terminal for error messages
|
| 149 |
+
|
| 150 |
+
### App won't load on phone
|
| 151 |
+
- ✅ Ensure Expo Go is installed
|
| 152 |
+
- ✅ Check both devices on same network
|
| 153 |
+
- ✅ Try restarting Expo server (Ctrl+C, then `npm start`)
|
| 154 |
+
- ✅ Clear Expo cache: `npm start -- --clear`
|
| 155 |
+
|
| 156 |
+
### "Permission denied" for camera/gallery
|
| 157 |
+
- ✅ Grant permissions when prompted
|
| 158 |
+
- ✅ Check phone settings → App permissions
|
| 159 |
+
|
| 160 |
+
## Next Steps
|
| 161 |
+
|
| 162 |
+
Once everything works:
|
| 163 |
+
|
| 164 |
+
1. **Set up Google OAuth** for production use
|
| 165 |
+
2. **Customize the UI** in `src/styles/theme.js`
|
| 166 |
+
3. **Add custom icons** in `assets/` folder
|
| 167 |
+
4. **Build standalone app** with `eas build`
|
| 168 |
+
|
| 169 |
+
## Quick Commands Reference
|
| 170 |
+
|
| 171 |
+
```bash
|
| 172 |
+
# Start backend
|
| 173 |
+
cd c:\Users\rdeva\Downloads\vqa_coes
|
| 174 |
+
python backend_api.py
|
| 175 |
+
|
| 176 |
+
# Start mobile app
|
| 177 |
+
cd c:\Users\rdeva\Downloads\vqa_coes\ui
|
| 178 |
+
npm start
|
| 179 |
+
|
| 180 |
+
# Clear Expo cache
|
| 181 |
+
npm start -- --clear
|
| 182 |
+
|
| 183 |
+
# Install new package
|
| 184 |
+
npm install package-name
|
| 185 |
+
|
| 186 |
+
# Check backend health
|
| 187 |
+
curl http://localhost:8000/health
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
## Support
|
| 191 |
+
|
| 192 |
+
If you encounter issues:
|
| 193 |
+
1. Check the main README.md
|
| 194 |
+
2. Review backend terminal logs
|
| 195 |
+
3. Check Expo console for errors
|
| 196 |
+
4. Verify all prerequisites are met
|
README.md
CHANGED
|
@@ -1,10 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# GenVQA — Generative Visual Question Answering
|
| 4 |
+
|
| 5 |
+
**A neuro-symbolic VQA system that detects objects with a neural model, retrieves structured facts from Wikidata, and generates grounded answers with Groq.**
|
| 6 |
+
|
| 7 |
+
[](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/backend-ci.yml)
|
| 8 |
+
[](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/ui-ci.yml)
|
| 9 |
+

|
| 10 |
+

|
| 11 |
+
|
| 12 |
+
</div>
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## Architecture
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 20 |
+
│ CLIENT LAYER │
|
| 21 |
+
│ 📱 Expo Mobile App (React Native) │
|
| 22 |
+
│ • Image upload + question input │
|
| 23 |
+
│ • Displays answer + accessibility description │
|
| 24 |
+
└────────────────────────┬────────────────────────────────────┘
|
| 25 |
+
│ HTTP POST /api/answer
|
| 26 |
+
▼
|
| 27 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 28 |
+
│ BACKEND LAYER (FastAPI) │
|
| 29 |
+
│ backend_api.py │
|
| 30 |
+
│ • Request handling, session management │
|
| 31 |
+
│ • Conversation Manager → multi-turn context tracking │
|
| 32 |
+
└────────────────────────┬────────────────────────────────────┘
|
| 33 |
+
│
|
| 34 |
+
▼
|
| 35 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 36 |
+
│ ROUTING LAYER (ensemble_vqa_app.py) │
|
| 37 |
+
│ │
|
| 38 |
+
│ CLIP encodes question → compares against: │
|
| 39 |
+
│ "reasoning question" vs "visual/perceptual question" │
|
| 40 |
+
│ │
|
| 41 |
+
│ Reasoning? Visual? │
|
| 42 |
+
│ │ │ │
|
| 43 |
+
│ ▼ ▼ │
|
| 44 |
+
│ ┌─────────────────┐ ┌─────────────────────┐ │
|
| 45 |
+
│ │ NEURO-SYMBOLIC │ │ NEURAL VQA PATH │ │
|
| 46 |
+
│ │ │ │ │ │
|
| 47 |
+
│ │ 1. VQA model │ │ VQA model (GRU + │ │
|
| 48 |
+
│ │ detects obj │ │ Attention) predicts │ │
|
| 49 |
+
│ │ │ │ answer directly │ │
|
| 50 |
+
│ │ 2. Wikidata API │ └──────────┬──────────┘ │
|
| 51 |
+
│ │ fetches facts│ │ │
|
| 52 |
+
│ │ (P31, P2101, │ │ │
|
| 53 |
+
│ │ P2054, P186,│ │ │
|
| 54 |
+
│ │ P366 ...) │ │ │
|
| 55 |
+
│ │ │ │ │
|
| 56 |
+
│ │ 3. Groq LLM │ │ │
|
| 57 |
+
│ │ verbalizes │ │ │
|
| 58 |
+
│ │ from facts │ │ │
|
| 59 |
+
│ └─────────┬───────┘ │ │
|
| 60 |
+
│ └──────────────┬──────────┘ │
|
| 61 |
+
└────────────────────────── │ ─────────────────────────────┘
|
| 62 |
+
│
|
| 63 |
+
▼
|
| 64 |
+
┌─────────────────┐
|
| 65 |
+
│ GROQ SERVICE │
|
| 66 |
+
│ Accessibility │
|
| 67 |
+
│ description │
|
| 68 |
+
│ (2 sentences, │
|
| 69 |
+
│ screen-reader │
|
| 70 |
+
│ friendly) │
|
| 71 |
+
└────────┬────────┘
|
| 72 |
+
│
|
| 73 |
+
▼
|
| 74 |
+
JSON response
|
| 75 |
+
{ answer, model_used,
|
| 76 |
+
kg_enhancement,
|
| 77 |
+
wikidata_entity,
|
| 78 |
+
description }
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
| Layer | Component | Role |
|
| 82 |
+
|---|---|---|
|
| 83 |
+
| **Client** | Expo React Native | Image upload, question input, answer display |
|
| 84 |
+
| **API** | FastAPI (`backend_api.py`) | Routing, sessions, conversation state |
|
| 85 |
+
| **Conversation** | `conversation_manager.py` | Multi-turn context, history tracking |
|
| 86 |
+
| **Router** | CLIP (in `ensemble_vqa_app.py`) | Classifies question as reasoning vs visual |
|
| 87 |
+
| **Neural VQA** | GRU + Attention (`model.py`) | Answers visual questions directly from image |
|
| 88 |
+
| **Neuro-Symbolic** | `semantic_neurosymbolic_vqa.py` | VQA detects objects → Wikidata fetches facts → Groq verbalizes |
|
| 89 |
+
| **Accessibility** | `groq_service.py` | Generates spoken-friendly 2-sentence description for every answer |
|
| 90 |
+
|
| 91 |
+
---
|
| 92 |
+
|
| 93 |
+
## Features
|
| 94 |
+
|
| 95 |
+
- 🔍 **Visual Question Answering** — trained on VQAv2, fine-tuned on custom data
|
| 96 |
+
- 🧠 **Neuro-Symbolic Routing** — CLIP semantically classifies questions as _reasoning_ vs _visual_, routes accordingly
|
| 97 |
+
- 🌐 **Live Wikidata Facts** — queries physical properties, categories, materials, uses in real time
|
| 98 |
+
- 🤖 **Groq Verbalization** — Llama 3.3 70B answers from structured facts, not hallucination
|
| 99 |
+
- 💬 **Conversational Support** — multi-turn conversation manager with context tracking
|
| 100 |
+
- 📱 **Expo Mobile UI** — React Native app for iOS/Android/Web
|
| 101 |
+
- ♿ **Accessibility** — Groq generates spoken-friendly descriptions for every answer
|
| 102 |
+
|
| 103 |
+
---
|
| 104 |
+
|
| 105 |
+
## Quick Start
|
| 106 |
+
|
| 107 |
+
### 1 — Backend
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
# Clone and install
|
| 111 |
+
git clone https://github.com/DevaRajan8/Generative-vqa.git
|
| 112 |
+
cd Generative-vqa
|
| 113 |
+
pip install -r requirements_api.txt
|
| 114 |
+
|
| 115 |
+
# Set your Groq API key
|
| 116 |
+
cp .env.example .env
|
| 117 |
+
# Edit .env → GROQ_API_KEY=your_key_here
|
| 118 |
+
|
| 119 |
+
# Start API
|
| 120 |
+
python backend_api.py
|
| 121 |
+
# → http://localhost:8000
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### 2 — Mobile UI
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
cd ui
|
| 128 |
+
npm install
|
| 129 |
+
npx expo start --clear
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
> Scan the QR code with Expo Go, or press `w` for browser.
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## API
|
| 137 |
+
|
| 138 |
+
| Endpoint | Method | Description |
|
| 139 |
+
|---|---|---|
|
| 140 |
+
| `/api/answer` | POST | Answer a question about an uploaded image |
|
| 141 |
+
| `/api/health` | GET | Health check |
|
| 142 |
+
| `/api/conversation/new` | POST | Start a new conversation session |
|
| 143 |
+
|
| 144 |
+
**Example:**
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
curl -X POST http://localhost:8000/api/answer \
|
| 148 |
+
-F "image=@photo.jpg" \
|
| 149 |
+
-F "question=Can this melt?"
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
**Response:**
|
| 153 |
+
|
| 154 |
+
```json
|
| 155 |
+
{
|
| 156 |
+
"answer": "ice",
|
| 157 |
+
"model_used": "neuro-symbolic",
|
| 158 |
+
"kg_enhancement": "Yes — ice can melt. [Wikidata P2101: melting point = 0.0 °C]",
|
| 159 |
+
"knowledge_source": "VQA (neural) + Wikidata (symbolic) + Groq (verbalize)",
|
| 160 |
+
"wikidata_entity": "Q86"
|
| 161 |
+
}
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
---
|
| 165 |
+
|
| 166 |
+
## Project Structure
|
| 167 |
+
|
| 168 |
+
```
|
| 169 |
+
├── backend_api.py # FastAPI server
|
| 170 |
+
├── ensemble_vqa_app.py # VQA orchestrator (routing + inference)
|
| 171 |
+
├── semantic_neurosymbolic_vqa.py # Wikidata KB + Groq verbalizer
|
| 172 |
+
├── groq_service.py # Groq accessibility descriptions
|
| 173 |
+
├── conversation_manager.py # Multi-turn conversation tracking
|
| 174 |
+
├── model.py # VQA model definition
|
| 175 |
+
├── train.py # Training pipeline
|
| 176 |
+
├── ui/ # Expo React Native app
|
| 177 |
+
│ └── src/screens/HomeScreen.js
|
| 178 |
+
└── .github/
|
| 179 |
+
├── workflows/ # CI — backend lint + UI build
|
| 180 |
+
└── ISSUE_TEMPLATE/
|
| 181 |
+
```
|
| 182 |
+
|
| 183 |
---
|
| 184 |
+
|
| 185 |
+
## Environment Variables
|
| 186 |
+
|
| 187 |
+
| Variable | Required | Description |
|
| 188 |
+
|---|---|---|
|
| 189 |
+
| `GROQ_API_KEY` | ✅ | Groq API key — [get one free](https://console.groq.com) |
|
| 190 |
+
| `MODEL_PATH` | optional | Path to VQA checkpoint (default: `vqa_checkpoint.pt`) |
|
| 191 |
+
| `PORT` | optional | API server port (default: `8000`) |
|
| 192 |
+
|
| 193 |
---
|
| 194 |
|
| 195 |
+
## Requirements
|
| 196 |
+
|
| 197 |
+
- Python 3.10+
|
| 198 |
+
- CUDA GPU recommended (CPU works but is slow)
|
| 199 |
+
- Node.js 20+ (for UI)
|
| 200 |
+
- Groq API key (free tier available)
|
| 201 |
+
|
| 202 |
+
---
|
| 203 |
+
|
| 204 |
+
## License
|
| 205 |
+
|
| 206 |
+
MIT © [DevaRajan8](https://github.com/DevaRajan8)
|
README_COMPLETE.md
ADDED
|
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
|
| 3 |
+
# 🧠 GenVQA — Generative Visual Question Answering
|
| 4 |
+
|
| 5 |
+
**A hybrid neuro-symbolic VQA system that intelligently routes between pure neural networks and knowledge-grounded reasoning**
|
| 6 |
+
|
| 7 |
+
</div>
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## Overview
|
| 12 |
+
|
| 13 |
+
GenVQA is an advanced Visual Question Answering system that combines the best of both worlds:
|
| 14 |
+
|
| 15 |
+
- **Neural networks** for perception-based visual questions
|
| 16 |
+
- **Symbolic reasoning** for knowledge-intensive reasoning questions
|
| 17 |
+
|
| 18 |
+
The system automatically classifies incoming questions and routes them to the optimal processing pipeline, ensuring accurate and grounded answers.
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## System Architecture
|
| 23 |
+
|
| 24 |
+
```
|
| 25 |
+
┌──────────────────────────────────────────────────────────────────┐
|
| 26 |
+
│ CLIENT │
|
| 27 |
+
│ Expo React Native App (iOS/Android/Web) │
|
| 28 |
+
│ • Image upload via camera/gallery │
|
| 29 |
+
│ • Question input with suggested prompts │
|
| 30 |
+
│ • Multi-turn conversational interface │
|
| 31 |
+
│ • Google OAuth authentication │
|
| 32 |
+
└───────────────────────────┬──────────────────────────────────────┘
|
| 33 |
+
│ HTTP POST /api/answer
|
| 34 |
+
▼
|
| 35 |
+
┌──────────────────────────────────────────────────────────────────┐
|
| 36 |
+
│ BACKEND API LAYER │
|
| 37 |
+
│ FastAPI (backend_api.py) │
|
| 38 |
+
│ • Request handling & validation │
|
| 39 |
+
│ • Session management & authentication │
|
| 40 |
+
│ • Multi-turn conversation tracking │
|
| 41 |
+
└───────────────────────────┬──────────────────────────────────────┘
|
| 42 |
+
│
|
| 43 |
+
▼
|
| 44 |
+
┌──────────────────────────────────────────────────────────────────┐
|
| 45 |
+
│ INTELLIGENT ROUTING LAYER │
|
| 46 |
+
│ (ensemble_vqa_app.py) │
|
| 47 |
+
│ │
|
| 48 |
+
│ CLIP Semantic Classifier: │
|
| 49 |
+
│ Encodes question → Compares similarity: │
|
| 50 |
+
│ "This is a reasoning question about facts" │
|
| 51 |
+
│ vs │
|
| 52 |
+
│ "This is a visual perception question" │
|
| 53 |
+
│ │
|
| 54 |
+
│ Similarity > threshold?
|
| 55 |
+
│
|
| 56 |
+
│ ├─────────┬────────┐ │
|
| 57 |
+
│ │ │ │ │
|
| 58 |
+
│ REASONING VISUAL SPATIAL │
|
| 59 |
+
│ │ │ │ │
|
| 60 |
+
└─────────────────────┼─────────┼────────┼─────────────────────────┘
|
| 61 |
+
│ │ │
|
| 62 |
+
┌─────────────┘ │ └─────────────┐
|
| 63 |
+
▼ ▼ ▼
|
| 64 |
+
┌──────────────────┐ ┌───────────────────┐ ┌─────────────────┐
|
| 65 |
+
│ NEURO-SYMBOLIC │ │ NEURAL VQA PATH │ │ SPATIAL ADAPTER │
|
| 66 |
+
│ PIPELINE │ │ │ │ PATH │
|
| 67 |
+
│ │ │ CLIP + GRU + │ │ │
|
| 68 |
+
│ ① VQA Model │ │ Attention │ │ Enhanced with │
|
| 69 |
+
│ Detects │ │ │ │ spatial │
|
| 70 |
+
│ Objects │ │ Direct answer │ │ self-attention │
|
| 71 |
+
│ (e.g. "soup") │ │ prediction from │ │ for left/right │
|
| 72 |
+
│ │ │ image features │ │ above/below │
|
| 73 |
+
│ ② Wikidata API │ │ │ │ questions │
|
| 74 |
+
│ Fetches Facts │ │ Outputs: │ │ │
|
| 75 |
+
│ P31: category │ │ "red" │ │ Outputs: │
|
| 76 |
+
│ P186: material│ └───────┬───────────┘ │ "on the left" │
|
| 77 |
+
│ P2101: melting│ │ └────────┬────────┘
|
| 78 |
+
│ P366: use │ │ │
|
| 79 |
+
│ P2054: density│ │ │
|
| 80 |
+
│ │ │ │
|
| 81 |
+
│ ③ Groq LLM │ │ │
|
| 82 |
+
│ Verbalizes │ │ │
|
| 83 |
+
│ from facts │ │ │
|
| 84 |
+
│ (instead of free │ │ │
|
| 86 |
+
│ reasoning) │ │ │
|
| 87 |
+
│ │ │ │
|
| 88 |
+
│ Outputs: │ │ │
|
| 89 |
+
│ "Soup is made of │ │ │
|
| 90 |
+
│ water and │ │ │
|
| 91 |
+
│ vegetables, │ │ │
|
| 92 |
+
│ used for eating"│ │ │
|
| 93 |
+
└────────┬─────────┘ │ │
|
| 94 |
+
│ │ │
|
| 95 |
+
└──────────┬──────────┴────────────────────────┘
|
| 96 |
+
▼
|
| 97 |
+
┌──────────────────────┐
|
| 98 |
+
│ GROQ ACCESSIBILITY │
|
| 99 |
+
│ SERVICE │
|
| 100 |
+
│ │
|
| 101 |
+
│ Generates 2-sentence│
|
| 102 |
+
│ screen-reader │
|
| 103 |
+
│ friendly description│
|
| 104 |
+
│ for every answer │
|
| 105 |
+
└──────────┬───────────┘
|
| 106 |
+
│
|
| 107 |
+
▼
|
| 108 |
+
JSON Response
|
| 109 |
+
{
|
| 110 |
+
"answer": "...",
|
| 111 |
+
"model_used": "neuro_symbolic|base|spatial",
|
| 112 |
+
"confidence": 0.85,
|
| 113 |
+
"kg_enhancement": true/false,
|
| 114 |
+
"wikidata_entity": "Q123456",
|
| 115 |
+
"description": "...",
|
| 116 |
+
"session_id": "..."
|
| 117 |
+
}
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
## Neural vs Neuro-Symbolic: Deep Dive
|
| 123 |
+
|
| 124 |
+
### Neural Pathway
|
| 125 |
+
|
| 126 |
+
**When Used**: Perceptual questions about what's directly visible
|
| 127 |
+
|
| 128 |
+
- _"What color is the car?"_
|
| 129 |
+
- _"How many people are in the image?"_
|
| 130 |
+
- _"Is the dog sitting or standing?"_
|
| 131 |
+
|
| 132 |
+
**Architecture**:
|
| 133 |
+
|
| 134 |
+
```
|
| 135 |
+
Image Input
|
| 136 |
+
│
|
| 137 |
+
▼
|
| 138 |
+
┌─────────────────────────────┐
|
| 139 |
+
│ CLIP Vision Encoder │
|
| 140 |
+
│ (ViT-B/16) │
|
| 141 |
+
│ • Pre-trained on 400M │
|
| 142 |
+
│ image-text pairs │
|
| 143 |
+
│ • 512-dim embeddings │
|
| 144 |
+
└──────────┬──────────────────┘
|
| 145 |
+
│
|
| 146 |
+
▼
|
| 147 |
+
[512-dim vector] ────────────┐
|
| 148 |
+
│
|
| 149 |
+
Question Input │
|
| 150 |
+
│ │
|
| 151 |
+
▼ │
|
| 152 |
+
┌─────────────────────────────┐ │
|
| 153 |
+
│ GPT-2 Text Encoder │ │
|
| 154 |
+
│ (distilgpt2) │ │
|
| 155 |
+
│ • Contextual embeddings │ │
|
| 156 |
+
│ • 768-dim output │ │
|
| 157 |
+
└──────────┬──────────────────┘ │
|
| 158 |
+
│ │
|
| 159 |
+
▼ │
|
| 160 |
+
[768-dim vector] │
|
| 161 |
+
│ │
|
| 162 |
+
▼ │
|
| 163 |
+
┌──────────────┐ │
|
| 164 |
+
│ Linear Proj │ │
|
| 165 |
+
│ 768 → 512 │ │
|
| 166 |
+
└──────┬───────┘ │
|
| 167 |
+
│ │
|
| 168 |
+
└───────────┬───────────┘
|
| 169 |
+
│
|
| 170 |
+
▼
|
| 171 |
+
┌──────────────────────┐
|
| 172 |
+
│ Multimodal Fusion │
|
| 173 |
+
│ • Gated combination │
|
| 174 |
+
│ • 3-layer MLP │
|
| 175 |
+
│ • ReLU + Dropout │
|
| 176 |
+
└──────────┬───────────┘
|
| 177 |
+
│
|
| 178 |
+
▼
|
| 179 |
+
┌──────────────────────┐
|
| 180 |
+
│ GRU Decoder with │
|
| 181 |
+
│ Attention Mechanism │
|
| 182 |
+
│ │
|
| 183 |
+
│ • Hidden: 512-dim │
|
| 184 |
+
│ • 2 layers │
|
| 185 |
+
│ • Seq2seq decoding │
|
| 186 |
+
│ • Attention over │
|
| 187 |
+
│ fused features │
|
| 188 |
+
└──────────┬───────────┘
|
| 189 |
+
│
|
| 190 |
+
▼
|
| 191 |
+
Answer Tokens
|
| 192 |
+
"red car"
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
**Key Components**:
|
| 196 |
+
|
| 197 |
+
- **CLIP**: Zero-shot image understanding, robust to domain shift
|
| 198 |
+
- **GPT-2**: Contextual question encoding
|
| 199 |
+
- **Attention**: Decoder focuses on relevant image regions per word
|
| 200 |
+
- **GRU**: Sequential answer generation with memory
|
| 201 |
+
|
| 202 |
+
**Training**:
|
| 203 |
+
|
| 204 |
+
- Dataset: VQA v2 (curated, balanced subset)
|
| 205 |
+
- Loss: Cross-entropy over answer vocabulary
|
| 206 |
+
- Fine-tuning: Last 2 CLIP layers + full decoder
|
| 207 |
+
- Accuracy: ~39% on general VQA, ~28% on spatial questions
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
### Neuro-Symbolic Pathway (Knowledge-Grounded Reasoning)
|
| 212 |
+
|
| 213 |
+
**When Used**: Questions requiring external knowledge or reasoning
|
| 214 |
+
|
| 215 |
+
- _"Can soup melt?"_
|
| 216 |
+
- _"What is ice cream made of?"_
|
| 217 |
+
- _"Does this float in water?"_
|
| 218 |
+
|
| 219 |
+
**Architecture**:
|
| 220 |
+
|
| 221 |
+
```
|
| 222 |
+
Step 1: NEURAL DETECTION
|
| 223 |
+
─────────────────────────
|
| 224 |
+
Image + Question
|
| 225 |
+
│
|
| 226 |
+
▼
|
| 227 |
+
┌──────────────────────┐
|
| 228 |
+
│ VQA Model │
|
| 229 |
+
│ (same as above) │
|
| 230 |
+
│ │
|
| 231 |
+
│ Predicts: "soup" │
|
| 232 |
+
└──────────┬───────────┘
|
| 233 |
+
│
|
| 234 |
+
▼
|
| 235 |
+
Detected Object
|
| 236 |
+
"soup"
|
| 237 |
+
|
| 238 |
+
Step 2: SYMBOLIC FACT RETRIEVAL
|
| 239 |
+
────────────────────────────────
|
| 240 |
+
"soup"
|
| 241 |
+
│
|
| 242 |
+
▼
|
| 243 |
+
┌──────────────────────────────────┐
|
| 244 |
+
│ Wikidata SPARQL Queries │
|
| 245 |
+
│ │
|
| 246 |
+
│ ① Entity Resolution: │
|
| 247 |
+
│ "soup" → Q41415 (Wikidata ID) │
|
| 248 |
+
│ │
|
| 249 |
+
│ ② Fetch ALL Relevant Properties: │
|
| 250 |
+
│ │
|
| 251 |
+
│ P31 (instance of): │
|
| 252 |
+
│ → "food" │
|
| 253 |
+
│ → "liquid food" │
|
| 254 |
+
│ → "dish" │
|
| 255 |
+
│ │
|
| 256 |
+
│ P186 (made of): │
|
| 257 |
+
│ → "water" │
|
| 258 |
+
│ → "vegetables" │
|
| 259 |
+
│ → "broth" │
|
| 260 |
+
│ │
|
| 261 |
+
│ P366 (used for): │
|
| 262 |
+
│ → "consumption" │
|
| 263 |
+
│ → "nutrition" │
|
| 264 |
+
│ │
|
| 265 |
+
│ P2101 (melting point): │
|
| 266 |
+
│ → (not found) │
|
| 267 |
+
│ │
|
| 268 |
+
│ P2054 (density): │
|
| 269 |
+
│ → ~1000 kg/m³ │
|
| 270 |
+
│ → (floats/sinks calc) │
|
| 271 |
+
│ │
|
| 272 |
+
│ P2777 (flash point): │
|
| 273 |
+
│ → (not found) │
|
| 274 |
+
└──────────────┬───────────────────┘
|
| 275 |
+
│
|
| 276 |
+
▼
|
| 277 |
+
Structured Knowledge Graph
|
| 278 |
+
{
|
| 279 |
+
"entity": "soup (Q41415)",
|
| 280 |
+
"categories": ["food", "liquid"],
|
| 281 |
+
"materials": ["water", "vegetables"],
|
| 282 |
+
"uses": ["consumption"],
|
| 283 |
+
"density": 1000,
|
| 284 |
+
"melting_point": null
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
Step 3: LLM VERBALIZATION (NOT REASONING!)
|
| 288 |
+
───────────────────────────────────────────
|
| 289 |
+
Knowledge Graph
|
| 290 |
+
│
|
| 291 |
+
▼
|
| 292 |
+
┌────────────────────────────────────┐
|
| 293 |
+
│ Groq API │
|
| 294 |
+
│ (Llama 3.3 70B) │
|
| 295 |
+
│ │
|
| 296 |
+
│ System Prompt: │
|
| 297 |
+
│ "You are a fact verbalizer. │
|
| 298 |
+
│ Answer ONLY from provided │
|
| 299 |
+
│ Wikidata facts. Do NOT use │
|
| 300 |
+
│ your training knowledge. │
|
| 301 |
+
│ If facts don't contain the │
|
| 302 |
+
│ answer, say 'unknown from │
|
| 303 |
+
│ available data'." │
|
| 304 |
+
│ │
|
| 305 |
+
│ User Input: │
|
| 306 |
+
│ Question: "Can soup melt?" │
|
| 307 |
+
│ Facts: {structured data above} │
|
| 308 |
+
└────────────┬───────────────────────┘
|
| 309 |
+
│
|
| 310 |
+
▼
|
| 311 |
+
Natural Language Answer
|
| 312 |
+
"According to Wikidata, soup is
|
| 313 |
+
a liquid food made of water and
|
| 314 |
+
vegetables. Since it's already
|
| 315 |
+
liquid, it doesn't have a melting
|
| 316 |
+
point like solids do. It can
|
| 317 |
+
freeze, but not melt."
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
**Critical Design Principle**:
|
| 321 |
+
|
| 322 |
+
> Groq is a **verbalizer**, NOT a reasoner. All reasoning happens in the symbolic layer (Wikidata facts). Groq only translates structured facts into natural language.
|
| 323 |
+
|
| 324 |
+
**Why This Matters**:
|
| 325 |
+
|
| 326 |
+
- **Without facts**: Groq hallucinates from training data
|
| 327 |
+
- **With facts**: Groq grounds answers in real-time data
|
| 328 |
+
- **Result**: Factual accuracy, no made-up information
|
| 329 |
+
|
| 330 |
+
**Knowledge Base Properties Fetched**:

|
| 331 |
+
| Property | Wikidata Code | Example Value |
|
| 332 |
+
|----------|---------------|---------------|
|
| 333 |
+
| Category | P31 | "food", "tool", "animal" |
|
| 334 |
+
| Material | P186 | "metal", "wood", "plastic" |
|
| 335 |
+
| Melting Point | P2101 | 273.15 K (0°C) |
|
| 336 |
+
| Density | P2054 | 917 kg/m³ (floats/sinks) |
|
| 337 |
+
| Use | P366 | "eating", "transportation" |
|
| 338 |
+
| Flash Point | P2777 | 310 K (flammable) |
|
| 339 |
+
| Location | P276 | "ocean", "forest" |
|
| 340 |
+
|
| 341 |
+
---
|
| 342 |
+
|
| 343 |
+
### Spatial Reasoning Pathway
|
| 344 |
+
|
| 345 |
+
**When Used**: Questions about relative positions
|
| 346 |
+
|
| 347 |
+
- _"What is to the left of the car?"_
|
| 348 |
+
- _"Is the cat above or below the table?"_
|
| 349 |
+
|
| 350 |
+
**Architecture Enhancement**:
|
| 351 |
+
|
| 352 |
+
```
|
| 353 |
+
Base VQA Model
|
| 354 |
+
│
|
| 355 |
+
▼
|
| 356 |
+
┌──────────────────────────────┐
|
| 357 |
+
│ Spatial Self-Attention │
|
| 358 |
+
│ • Multi-head attention (8) │
|
| 359 |
+
│ • Learns spatial relations │
|
| 360 |
+
│ • Position-aware weighting │
|
| 361 |
+
└──────────┬───────────────────┘
|
| 362 |
+
│
|
| 363 |
+
▼
|
| 364 |
+
Spatial-aware answer
|
| 365 |
+
"on the left side"
|
| 366 |
+
```
|
| 367 |
+
|
| 368 |
+
**Keyword Triggering**:
|
| 369 |
+
|
| 370 |
+
- Detects: `left`, `right`, `above`, `below`, `top`, `bottom`, `next to`, `behind`, `between`, etc.
|
| 371 |
+
- Routes to spatial adapter model
|
| 372 |
+
- Enhanced accuracy on positional questions
|
| 373 |
+
|
| 374 |
+
---
|
| 375 |
+
|
| 376 |
+
## Intelligent Routing System
|
| 377 |
+
|
| 378 |
+
**CLIP-Based Semantic Routing**:
|
| 379 |
+
|
| 380 |
+
```python
|
| 381 |
+
# Encode question with CLIP
|
| 382 |
+
question_embedding = clip.encode_text(question)
|
| 383 |
+
|
| 384 |
+
# Compare against two templates
|
| 385 |
+
reasoning_prompt = "This is a reasoning question about facts and knowledge"
|
| 386 |
+
visual_prompt = "This is a visual perception question about what you see"
|
| 387 |
+
|
| 388 |
+
reasoning_similarity = cosine_similarity(question_embedding,
|
| 389 |
+
clip.encode_text(reasoning_prompt))
|
| 390 |
+
visual_similarity = cosine_similarity(question_embedding,
|
| 391 |
+
clip.encode_text(visual_prompt))
|
| 392 |
+
|
| 393 |
+
# Route decision
|
| 394 |
+
if reasoning_similarity > visual_similarity + THRESHOLD:
|
| 395 |
+
route_to_neuro_symbolic()
|
| 396 |
+
elif contains_spatial_keywords(question):
|
| 397 |
+
route_to_spatial_adapter()
|
| 398 |
+
else:
|
| 399 |
+
route_to_base_neural()
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
**Routing Logic**:
|
| 403 |
+
|
| 404 |
+
1. **Neuro-Symbolic** if CLIP classifies as reasoning (>0.6 similarity)
|
| 405 |
+
2. **Spatial** if contains spatial keywords (`left`, `right`, `above`, etc.)
|
| 406 |
+
3. **Base Neural** for all other visual perception questions
|
| 407 |
+
|
| 408 |
+
---
|
| 409 |
+
|
| 410 |
+
## Multi-Turn Conversation Support
|
| 411 |
+
|
| 412 |
+
**Conversation Manager Features**:
|
| 413 |
+
|
| 414 |
+
- Session tracking with UUID
|
| 415 |
+
- Context retention across turns
|
| 416 |
+
- Pronoun resolution (`it`, `this`, `that` → previous object)
|
| 417 |
+
- Automatic session expiry (30 min timeout)
|
| 418 |
+
|
| 419 |
+
**Example Conversation**:
|
| 420 |
+
|
| 421 |
+
```
|
| 422 |
+
Turn 1:
|
| 423 |
+
User: "What is this?"
|
| 424 |
+
VQA: "A red car"
|
| 425 |
+
Objects: ["car"]
|
| 426 |
+
|
| 427 |
+
Turn 2:
|
| 428 |
+
User: "Can it float?" # "it" = "car"
|
| 429 |
+
System: Resolves "it" → "car"
|
| 430 |
+
VQA: [Neuro-Symbolic] "According to Wikidata, cars are made
|
| 431 |
+
of metal and plastic with density around 800-1000 kg/m³,
|
| 432 |
+
which is close to water. Most cars would sink."
|
| 433 |
+
|
| 434 |
+
Turn 3:
|
| 435 |
+
User: "What color is it again?" # Still referring to car
|
| 436 |
+
VQA: [Neural] "red" # From Turn 1 context
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
---
|
| 440 |
+
|
| 441 |
+
## Quick Start
|
| 442 |
+
|
| 443 |
+
### Prerequisites
|
| 444 |
+
|
| 445 |
+
- Python 3.10+
|
| 446 |
+
- CUDA GPU (recommended, 4GB+ VRAM)
|
| 447 |
+
- Node.js 16+ (for mobile UI)
|
| 448 |
+
- Groq API key ([get one free](https://console.groq.com))
|
| 449 |
+
|
| 450 |
+
### Backend Setup
|
| 451 |
+
|
| 452 |
+
```bash
|
| 453 |
+
# 1. Clone repository
|
| 454 |
+
git clone https://github.com/YourUsername/vqa_coes.git
|
| 455 |
+
cd vqa_coes
|
| 456 |
+
|
| 457 |
+
# 2. Install dependencies
|
| 458 |
+
pip install -r requirements_api.txt
|
| 459 |
+
|
| 460 |
+
# 3. Set environment variables
|
| 461 |
+
echo "GROQ_API_KEY=your_groq_api_key_here" > .env
|
| 462 |
+
|
| 463 |
+
# 4. Download model checkpoints (if not included)
|
| 464 |
+
# Ensure these files exist in project root:
|
| 465 |
+
# - vqa_checkpoint.pt (base model)
|
| 466 |
+
# - vqa_spatial_checkpoint.pt (spatial model)
|
| 467 |
+
|
| 468 |
+
# 5. Start API server
|
| 469 |
+
python backend_api.py
|
| 470 |
+
|
| 471 |
+
# Server will start at http://localhost:8000
|
| 472 |
+
```
|
| 473 |
+
|
| 474 |
+
### Mobile UI Setup
|
| 475 |
+
|
| 476 |
+
```bash
|
| 477 |
+
# 1. Navigate to UI folder
|
| 478 |
+
cd ui
|
| 479 |
+
|
| 480 |
+
# 2. Install dependencies
|
| 481 |
+
npm install
|
| 482 |
+
|
| 483 |
+
# 3. Configure API endpoint
|
| 484 |
+
# Edit ui/src/config/api.js
|
| 485 |
+
# Change: export const API_BASE_URL = 'http://YOUR_LOCAL_IP:8000';
|
| 486 |
+
|
| 487 |
+
# 4. Start Expo
|
| 488 |
+
npx expo start --clear
|
| 489 |
+
|
| 490 |
+
# Scan QR code with Expo Go app, or press 'w' for web
|
| 491 |
+
```
|
| 492 |
+
|
| 493 |
+
---
|
| 494 |
+
|
| 495 |
+
## 🔧 API Reference
|
| 496 |
+
|
| 497 |
+
### POST `/api/answer`
|
| 498 |
+
|
| 499 |
+
Answer a visual question with optional conversation context.
|
| 500 |
+
|
| 501 |
+
**Request**:
|
| 502 |
+
|
| 503 |
+
```bash
|
| 504 |
+
curl -X POST http://localhost:8000/api/answer \
|
| 505 |
+
-F "image=@photo.jpg" \
|
| 506 |
+
-F "question=Can this float in water?" \
|
| 507 |
+
-F "session_id=optional-uuid-here"
|
| 508 |
+
```
|
| 509 |
+
|
| 510 |
+
**Response**:
|
| 511 |
+
|
| 512 |
+
```json
|
| 513 |
+
{
|
| 514 |
+
"answer": "According to Wikidata, this object has a density of 917 kg/m³, which is less than water (1000 kg/m³), so it would float.",
|
| 515 |
+
"model_used": "neuro_symbolic",
|
| 516 |
+
"confidence": 0.87,
|
| 517 |
+
"kg_enhancement": true,
|
| 518 |
+
"wikidata_entity": "Q41576",
|
| 519 |
+
"description": "The object appears to be made of ice. Based on its physical properties from scientific data, it would float on water due to lower density.",
|
| 520 |
+
"session_id": "550e8400-e29b-41d4-a716-446655440000",
|
| 521 |
+
"conversation_turn": 2
|
| 522 |
+
}
```
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
## 📄 License
|
| 526 |
+
|
| 527 |
+
MIT License - see LICENSE file for details
|
| 528 |
+
|
| 529 |
+
---
|
SETUP_GUIDE.md
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VQA Accessibility Enhancement - Setup Guide
|
| 2 |
+
|
| 3 |
+
## Backend Setup
|
| 4 |
+
|
| 5 |
+
### 1. Install Python Dependencies
|
| 6 |
+
```bash
|
| 7 |
+
cd path\to\vqa_coes
|
| 8 |
+
pip install -r requirements_api.txt
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
### 2. Configure Groq API Key
|
| 12 |
+
|
| 13 |
+
1. Get your Groq API key from: https://console.groq.com/keys
|
| 14 |
+
2. Create a `.env` file in the project root:
|
| 15 |
+
```bash
|
| 16 |
+
copy .env.example .env
|
| 17 |
+
```
|
| 18 |
+
3. Edit `.env` and add your API key:
|
| 19 |
+
```
|
| 20 |
+
GROQ_API_KEY=your_actual_groq_api_key_here
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
### 3. Start Backend Server
|
| 24 |
+
```bash
|
| 25 |
+
python backend_api.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
The server will start on `http://localhost:8000`
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Frontend Setup
|
| 33 |
+
|
| 34 |
+
### 1. Install Node Dependencies
|
| 35 |
+
```bash
|
| 36 |
+
cd ui
|
| 37 |
+
npm install
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
This will install the new `expo-speech` package for text-to-speech functionality.
|
| 41 |
+
|
| 42 |
+
### 2. Start Expo App
|
| 43 |
+
```bash
|
| 44 |
+
npm start
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
Then:
|
| 48 |
+
- Press `a` for Android emulator
|
| 49 |
+
- Press `i` for iOS simulator
|
| 50 |
+
- Scan QR code with Expo Go app for physical device
|
| 51 |
+
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
## Testing the Features
|
| 55 |
+
|
| 56 |
+
### Image Display Fix
|
| 57 |
+
1. Open the app
|
| 58 |
+
2. Tap "Camera" or "Gallery" to select an image
|
| 59 |
+
3. **Expected**: Image should display correctly (no blank screen)
|
| 60 |
+
|
| 61 |
+
### LLM Description Feature
|
| 62 |
+
1. Upload an image
|
| 63 |
+
2. Enter a question (e.g., "What color is the car?")
|
| 64 |
+
3. Tap "Ask Question"
|
| 65 |
+
4. **Expected**:
|
| 66 |
+
- Original answer appears in the "Answer" card
|
| 67 |
+
- "Accessible Description" card appears below with 2-sentence description
|
| 68 |
+
- Speaker icon button is visible
|
| 69 |
+
|
| 70 |
+
### Text-to-Speech
|
| 71 |
+
1. After getting an answer with description
|
| 72 |
+
2. Tap the speaker icon (🔊) in the "Accessible Description" card
|
| 73 |
+
3. **Expected**: The description is read aloud
|
| 74 |
+
4. Tap the stop icon (⏹️) to stop playback
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
## Troubleshooting
|
| 79 |
+
|
| 80 |
+
### Backend Issues
|
| 81 |
+
|
| 82 |
+
**Groq API Key Error**
|
| 83 |
+
```
|
| 84 |
+
ValueError: Groq API key not found
|
| 85 |
+
```
|
| 86 |
+
**Solution**: Make sure `.env` file exists with `GROQ_API_KEY=your_key`
|
| 87 |
+
|
| 88 |
+
**Models Not Loading**
|
| 89 |
+
```
|
| 90 |
+
❌ Base checkpoint not found
|
| 91 |
+
```
|
| 92 |
+
**Solution**: Ensure `vqa_checkpoint.pt` and `vqa_spatial_checkpoint.pt` are in the project root
|
| 93 |
+
|
| 94 |
+
### Frontend Issues
|
| 95 |
+
|
| 96 |
+
**Image Not Displaying**
|
| 97 |
+
- Make sure you've run `npm install` to get the latest `expo-image` package
|
| 98 |
+
- Check console logs for image URI format issues
|
| 99 |
+
|
| 100 |
+
**Text-to-Speech Not Working**
|
| 101 |
+
- Ensure device volume is turned up
|
| 102 |
+
- Check that `expo-speech` package is installed
|
| 103 |
+
- On iOS simulator, speech may not work (test on physical device)
|
| 104 |
+
|
| 105 |
+
**Cannot Connect to Backend**
|
| 106 |
+
- Verify backend is running on port 8000
|
| 107 |
+
- Update `ui/src/config/api.js` with correct backend URL
|
| 108 |
+
- For physical devices, use ngrok or your computer's local IP
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## Features Summary
|
| 113 |
+
|
| 114 |
+
✅ **Fixed**: Image display issue (using expo-image instead of react-native Image)
|
| 115 |
+
✅ **Added**: Groq LLM integration for 2-sentence descriptions
|
| 116 |
+
✅ **Added**: Text-to-speech accessibility feature
|
| 117 |
+
✅ **Added**: Visual distinction between raw answer and description
|
| 118 |
+
✅ **Added**: Fallback mode when Groq API is unavailable
|
VQA_ENHANCEMENTS.md
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VQA Enhancements: LLM Reasoning & Conversational VQA
|
| 2 |
+
|
| 3 |
+
This document describes the two major enhancements added to the VQA system.
|
| 4 |
+
|
| 5 |
+
## 🧠 Feature 1: LLM-Driven Reasoning Engine
|
| 6 |
+
|
| 7 |
+
### Overview
|
| 8 |
+
Replaced hardcoded if/else rules with **Groq LLM Chain-of-Thought reasoning** for intelligent deductive reasoning from Wikidata facts.
|
| 9 |
+
|
| 10 |
+
### What Changed
|
| 11 |
+
**Before**: Hardcoded rules in `semantic_neurosymbolic_vqa.py`
|
| 12 |
+
```python
|
| 13 |
+
if 'melt' in question:
|
| 14 |
+
check material properties...
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
**After**: LLM-driven reasoning
|
| 18 |
+
```python
|
| 19 |
+
reasoning_result = llm_service.reason_with_facts(
|
| 20 |
+
object_name="candle",
|
| 21 |
+
facts={"materials": ["wax"], "categories": ["light source"]},
|
| 22 |
+
question="Can this melt?"
|
| 23 |
+
)
|
| 24 |
+
# Returns: Chain-of-Thought reasoning + answer
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### Benefits
|
| 28 |
+
- ✅ Handles complex questions like "Would this survive a fire?"
|
| 29 |
+
- ✅ Provides transparent reasoning chains
|
| 30 |
+
- ✅ More flexible and generalizable
|
| 31 |
+
- ✅ Automatic fallback to rule-based reasoning if LLM fails
|
| 32 |
+
|
| 33 |
+
### Example
|
| 34 |
+
**Question**: "Can this melt?"
|
| 35 |
+
**Object**: Candle
|
| 36 |
+
**Facts**: Material: wax, Category: light source
|
| 37 |
+
|
| 38 |
+
**LLM Reasoning Chain**:
|
| 39 |
+
1. The object is a candle
|
| 40 |
+
2. It is made of wax
|
| 41 |
+
3. Wax has a low melting point (~60°C)
|
| 42 |
+
4. Therefore, yes, it can melt at moderate temperatures
|
| 43 |
+
|
| 44 |
+
**Answer**: "Yes, the candle can melt because it's made of wax, which has a low melting point."
|
| 45 |
+
|
| 46 |
+
### Files Added/Modified
|
| 47 |
+
- **NEW**: `llm_reasoning_service.py` - LLM reasoning with Chain-of-Thought
|
| 48 |
+
- **MODIFIED**: `semantic_neurosymbolic_vqa.py` - Integrated LLM reasoning
|
| 49 |
+
- **MODIFIED**: `backend_api.py` - Added reasoning_chain to API responses
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## 💬 Feature 2: Conversational VQA
|
| 54 |
+
|
| 55 |
+
### Overview
|
| 56 |
+
Added **multi-turn conversation support** with context management and pronoun resolution.
|
| 57 |
+
|
| 58 |
+
### What Changed
|
| 59 |
+
**Before**: Single-shot Q&A with no context
|
| 60 |
+
```
|
| 61 |
+
User: "What is this?" → System: "A red apple."
|
| 62 |
+
User: "Is it healthy?" → System: "What is 'it'?" ❌
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
**After**: Multi-turn conversations
|
| 66 |
+
```
|
| 67 |
+
User: "What is this?" → System: "A red apple."
|
| 68 |
+
User: "Is it healthy?" → System: "Yes, apples are rich in fiber..." ✅
|
| 69 |
+
(System knows "it" = apple)
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### Benefits
|
| 73 |
+
- ✅ Natural follow-up questions
|
| 74 |
+
- ✅ Context-aware pronoun resolution
|
| 75 |
+
- ✅ Session management with auto-expiration
|
| 76 |
+
- ✅ Conversation history tracking
|
| 77 |
+
|
| 78 |
+
### Example Conversation
|
| 79 |
+
```
|
| 80 |
+
Turn 1:
|
| 81 |
+
Q: "What is this?"
|
| 82 |
+
A: "A red apple"
|
| 83 |
+
Objects: ["apple"]
|
| 84 |
+
|
| 85 |
+
Turn 2:
|
| 86 |
+
Q: "Is it healthy?"
|
| 87 |
+
Resolved: "Is apple healthy?"
|
| 88 |
+
A: "Yes, apples are rich in fiber and vitamins"
|
| 89 |
+
|
| 90 |
+
Turn 3:
|
| 91 |
+
Q: "What color is it?"
|
| 92 |
+
Resolved: "What color is apple?"
|
| 93 |
+
A: "Red"
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Files Added/Modified
|
| 97 |
+
- **NEW**: `conversation_manager.py` - Multi-turn conversation management
|
| 98 |
+
- **MODIFIED**: `ensemble_vqa_app.py` - Added `answer_conversational()` method
|
| 99 |
+
- **MODIFIED**: `backend_api.py` - Added conversation endpoints
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## 🚀 API Endpoints
|
| 104 |
+
|
| 105 |
+
### Existing Endpoint (Enhanced)
|
| 106 |
+
**POST** `/api/answer`
|
| 107 |
+
- Now includes `reasoning_chain` in response
|
| 108 |
+
- Backward compatible
|
| 109 |
+
|
| 110 |
+
### New Conversation Endpoints
|
| 111 |
+
|
| 112 |
+
**POST** `/api/conversation/answer`
|
| 113 |
+
- Multi-turn conversation support
|
| 114 |
+
- Request: `image`, `question`, `session_id` (optional)
|
| 115 |
+
- Response includes:
|
| 116 |
+
- `session_id` - For continuing conversation
|
| 117 |
+
- `resolved_question` - Question with pronouns resolved
|
| 118 |
+
- `conversation_context` - Previous turns, objects, etc.
|
| 119 |
+
- `reasoning_chain` - LLM reasoning steps (if applicable)
|
| 120 |
+
|
| 121 |
+
**GET** `/api/conversation/{session_id}/history`
|
| 122 |
+
- Get full conversation history
|
| 123 |
+
- Returns all turns with timestamps
|
| 124 |
+
|
| 125 |
+
**DELETE** `/api/conversation/{session_id}`
|
| 126 |
+
- Clear conversation session
|
| 127 |
+
- Useful for starting fresh
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## 📋 Usage Examples
|
| 132 |
+
|
| 133 |
+
### Example 1: LLM Reasoning (Python)
|
| 134 |
+
```python
|
| 135 |
+
from llm_reasoning_service import get_llm_reasoning_service
|
| 136 |
+
|
| 137 |
+
service = get_llm_reasoning_service()
|
| 138 |
+
|
| 139 |
+
result = service.reason_with_facts(
|
| 140 |
+
object_name="ice cream",
|
| 141 |
+
facts={
|
| 142 |
+
"materials": ["milk", "sugar", "cream"],
|
| 143 |
+
"categories": ["frozen dessert"]
|
| 144 |
+
},
|
| 145 |
+
question="Would this survive in the desert?"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
print(result['answer'])
|
| 149 |
+
# "No, ice cream would not survive in the desert because..."
|
| 150 |
+
|
| 151 |
+
print(result['reasoning_chain'])
|
| 152 |
+
# ["Ice cream is a frozen dessert", "Deserts are hot...", ...]
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
### Example 2: Conversational VQA (API)
|
| 156 |
+
```bash
|
| 157 |
+
# Turn 1: Ask what it is
|
| 158 |
+
curl -X POST http://localhost:8000/api/conversation/answer \
|
| 159 |
+
-F "image=@apple.jpg" \
|
| 160 |
+
-F "question=What is this?"
|
| 161 |
+
|
| 162 |
+
# Response: {"session_id": "abc123", "answer": "apple", ...}
|
| 163 |
+
|
| 164 |
+
# Turn 2: Follow-up question with pronoun
|
| 165 |
+
curl -X POST http://localhost:8000/api/conversation/answer \
|
| 166 |
+
-F "image=@apple.jpg" \
|
| 167 |
+
-F "question=Is it healthy?" \
|
| 168 |
+
-F "session_id=abc123"
|
| 169 |
+
|
| 170 |
+
# Response: {
|
| 171 |
+
# "resolved_question": "Is apple healthy?",
|
| 172 |
+
# "answer": "Yes, apples are healthy",
|
| 173 |
+
# "conversation_context": {"turn_number": 2, ...}
|
| 174 |
+
# }
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Example 3: Conversational VQA (Python)
|
| 178 |
+
```python
|
| 179 |
+
from ensemble_vqa_app import ProductionEnsembleVQA
|
| 180 |
+
|
| 181 |
+
ensemble = ProductionEnsembleVQA(
|
| 182 |
+
base_checkpoint="vqa_checkpoint.pt",
|
| 183 |
+
spatial_checkpoint="vqa_spatial_checkpoint.pt"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Turn 1
|
| 187 |
+
result1 = ensemble.answer_conversational(
|
| 188 |
+
image_path="apple.jpg",
|
| 189 |
+
question="What is this?",
|
| 190 |
+
verbose=True
|
| 191 |
+
)
|
| 192 |
+
session_id = result1['session_id']
|
| 193 |
+
print(f"Answer: {result1['answer']}") # "apple"
|
| 194 |
+
|
| 195 |
+
# Turn 2 - pronoun resolution
|
| 196 |
+
result2 = ensemble.answer_conversational(
|
| 197 |
+
image_path="apple.jpg",
|
| 198 |
+
question="Is it healthy?",
|
| 199 |
+
session_id=session_id,
|
| 200 |
+
verbose=True
|
| 201 |
+
)
|
| 202 |
+
print(f"Resolved: {result2['resolved_question']}") # "Is apple healthy?"
|
| 203 |
+
print(f"Answer: {result2['answer']}") # "Yes, apples are healthy"
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
---
|
| 207 |
+
|
| 208 |
+
## ⚙️ Configuration
|
| 209 |
+
|
| 210 |
+
### Environment Variables
|
| 211 |
+
```bash
|
| 212 |
+
# Required for LLM reasoning
|
| 213 |
+
GROQ_API_KEY=your_groq_api_key_here
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
### Session Timeout
|
| 217 |
+
Conversations expire after **30 minutes** of inactivity (configurable in `ConversationManager`).
|
| 218 |
+
|
| 219 |
+
---
|
| 220 |
+
|
| 221 |
+
## 🧪 Testing
|
| 222 |
+
|
| 223 |
+
Run the test suite:
|
| 224 |
+
```bash
|
| 225 |
+
python test_vqa_enhancements.py
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
Tests include:
|
| 229 |
+
- ✅ LLM reasoning with various question types
|
| 230 |
+
- ✅ Conversation manager pronoun resolution
|
| 231 |
+
- ✅ Session management and expiration
|
| 232 |
+
- ✅ Integration with existing VQA system
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## 🔄 Backward Compatibility
|
| 237 |
+
|
| 238 |
+
**All existing functionality remains intact:**
|
| 239 |
+
- ✅ Original `/api/answer` endpoint works unchanged
|
| 240 |
+
- ✅ Single-shot Q&A still supported
|
| 241 |
+
- ✅ Spatial routing unchanged
|
| 242 |
+
- ✅ Neuro-symbolic fallback preserved
|
| 243 |
+
|
| 244 |
+
**New features are opt-in:**
|
| 245 |
+
- Use `/api/conversation/answer` for multi-turn
|
| 246 |
+
- LLM reasoning activates automatically for reasoning questions
|
| 247 |
+
- Fallback to rule-based if LLM unavailable
|
| 248 |
+
|
| 249 |
+
---
|
| 250 |
+
|
| 251 |
+
## 📊 Architecture
|
| 252 |
+
|
| 253 |
+
```
|
| 254 |
+
User Question
|
| 255 |
+
↓
|
| 256 |
+
Ensemble VQA
|
| 257 |
+
↓
|
| 258 |
+
┌─────────────────────────────────┐
|
| 259 |
+
│ Conversation Manager │
|
| 260 |
+
│ - Resolve pronouns │
|
| 261 |
+
│ - Track context │
|
| 262 |
+
└─────────────────────────────────┘
|
| 263 |
+
↓
|
| 264 |
+
┌─────────────────────────────────┐
|
| 265 |
+
│ Semantic Neuro-Symbolic VQA │
|
| 266 |
+
│ - Detect objects (VQA) │
|
| 267 |
+
│ - Query Wikidata │
|
| 268 |
+
└─────────────────────────────────┘
|
| 269 |
+
↓
|
| 270 |
+
┌─────────────────────────────────┐
|
| 271 |
+
│ LLM Reasoning Service │
|
| 272 |
+
│ - Chain-of-Thought reasoning │
|
| 273 |
+
│ - Fallback to rules │
|
| 274 |
+
└─────────────────────────────────┘
|
| 275 |
+
↓
|
| 276 |
+
Answer + Reasoning Chain
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
---
|
| 280 |
+
|
| 281 |
+
## 🎯 Key Improvements
|
| 282 |
+
|
| 283 |
+
| Feature | Before | After |
|
| 284 |
+
|---------|--------|-------|
|
| 285 |
+
| **Reasoning** | Hardcoded if/else rules | LLM Chain-of-Thought |
|
| 286 |
+
| **Conversations** | Single-shot only | Multi-turn with context |
|
| 287 |
+
| **Pronouns** | Not handled | Automatic resolution |
|
| 288 |
+
| **Transparency** | Black box | Reasoning chains visible |
|
| 289 |
+
| **Flexibility** | Rigid rules | Adaptive LLM reasoning |
|
| 290 |
+
|
| 291 |
+
---
|
| 292 |
+
|
| 293 |
+
## 📝 Notes
|
| 294 |
+
|
| 295 |
+
- LLM reasoning requires `GROQ_API_KEY` environment variable
|
| 296 |
+
- Conversation sessions auto-expire after 30 minutes
|
| 297 |
+
- All features have fallback mechanisms for robustness
|
| 298 |
+
- Zero breaking changes to existing code
|
__pycache__/backend_api.cpython-312.pyc
ADDED
|
Binary file (14.8 kB). View file
|
|
|
__pycache__/conversation_manager.cpython-312.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
__pycache__/ensemble_vqa_app.cpython-312.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
__pycache__/groq_service.cpython-312.pyc
ADDED
|
Binary file (5.32 kB). View file
|
|
|
__pycache__/knowledge_graph_service.cpython-312.pyc
ADDED
|
Binary file (10 kB). View file
|
|
|
__pycache__/llm_reasoning_service.cpython-312.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
__pycache__/model_spatial.cpython-312.pyc
ADDED
|
Binary file (25.5 kB). View file
|
|
|
__pycache__/semantic_neurosymbolic_vqa.cpython-312.pyc
ADDED
|
Binary file (32 kB). View file
|
|
|
architecture_draft.html
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
<!DOCTYPE html>
|
| 3 |
+
<html>
|
| 4 |
+
<head>
|
| 5 |
+
<title>VQA Architecture Draft</title>
|
| 6 |
+
<script type="module">
|
| 7 |
+
import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
|
| 8 |
+
mermaid.initialize({ startOnLoad: true, theme: 'dark', flowchart: { curve: 'basis' } });
|
| 9 |
+
</script>
|
| 10 |
+
<style>
|
| 11 |
+
body { background-color: #0D1117; color: white; font-family: sans-serif; display: flex; justify-content: center; padding: 20px; }
|
| 12 |
+
.mermaid { background-color: #161B22; padding: 20px; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.5); }
|
| 13 |
+
</style>
|
| 14 |
+
</head>
|
| 15 |
+
<body>
|
| 16 |
+
<div class="mermaid">
|
| 17 |
+
|
| 18 |
+
graph TD
|
| 19 |
+
%% Styling
|
| 20 |
+
classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
|
| 21 |
+
classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
|
| 22 |
+
classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
|
| 23 |
+
classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
|
| 24 |
+
classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
|
| 25 |
+
classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
|
| 26 |
+
classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;
|
| 27 |
+
|
| 28 |
+
%% Nodes
|
| 29 |
+
UserApp[📱 Mobile App]:::mobile
|
| 30 |
+
|
| 31 |
+
ImgUpload[🖼️ Image]:::preproc
|
| 32 |
+
Question[⌨️ Question Text]:::preproc
|
| 33 |
+
|
| 34 |
+
PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc
|
| 35 |
+
|
| 36 |
+
CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
|
| 37 |
+
GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model
|
| 38 |
+
|
| 39 |
+
Route1{Question<br/>spatial?}:::condition
|
| 40 |
+
|
| 41 |
+
Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
|
| 42 |
+
Base[🧠 Base VQA Model<br/>General VQA]:::model
|
| 43 |
+
|
| 44 |
+
Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
|
| 45 |
+
NeuralAns[💬 Neural Answer]:::final
|
| 46 |
+
|
| 47 |
+
Route2{Knowledge<br/>question?}:::condition
|
| 48 |
+
|
| 49 |
+
ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
|
| 50 |
+
Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
|
| 51 |
+
GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
|
| 52 |
+
KGAns[🧩 KG Enhancement]:::final
|
| 53 |
+
|
| 54 |
+
FastAPI[🚀 FastAPI]:::preproc
|
| 55 |
+
GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
|
| 56 |
+
Audio[🔊 2-sentence description]:::final
|
| 57 |
+
|
| 58 |
+
%% Edges
|
| 59 |
+
UserApp -- "Image uploaded" --> ImgUpload
|
| 60 |
+
UserApp -- "Question typed" --> Question
|
| 61 |
+
|
| 62 |
+
ImgUpload --> PIL
|
| 63 |
+
PIL --> CLIP
|
| 64 |
+
Question --> GPT2
|
| 65 |
+
|
| 66 |
+
CLIP & GPT2 --> Route1
|
| 67 |
+
|
| 68 |
+
Route1 -- "YES" --> Spatial
|
| 69 |
+
Route1 -- "NO" --> Base
|
| 70 |
+
|
| 71 |
+
Spatial & Base -- "Beam search (width=5)" --> Decoder
|
| 72 |
+
Decoder --> NeuralAns
|
| 73 |
+
|
| 74 |
+
CLIP -- "Anchor similarity" --> Route2
|
| 75 |
+
|
| 76 |
+
Route2 -- "YES" --> ObjDet
|
| 77 |
+
ObjDet -- "Detected objects" --> Wikidata
|
| 78 |
+
Wikidata -- "Structured facts" --> GroqV
|
| 79 |
+
GroqV --> KGAns
|
| 80 |
+
|
| 81 |
+
FastAPI -- "Narration request" --> GroqA
|
| 82 |
+
GroqA --> Audio
|
| 83 |
+
|
| 84 |
+
NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
|
| 85 |
+
FastAPI --> UserApp
|
| 86 |
+
|
| 87 |
+
</div>
|
| 88 |
+
</body>
|
| 89 |
+
</html>
|
architecture_draft.mmd
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
graph TD
|
| 3 |
+
%% Styling
|
| 4 |
+
classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
|
| 5 |
+
classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
|
| 6 |
+
classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
|
| 7 |
+
classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
|
| 8 |
+
classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
|
| 9 |
+
classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
|
| 10 |
+
classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;
|
| 11 |
+
|
| 12 |
+
%% Nodes
|
| 13 |
+
UserApp[📱 Mobile App]:::mobile
|
| 14 |
+
|
| 15 |
+
ImgUpload[🖼️ Image]:::preproc
|
| 16 |
+
Question[⌨️ Question Text]:::preproc
|
| 17 |
+
|
| 18 |
+
PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc
|
| 19 |
+
|
| 20 |
+
CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
|
| 21 |
+
GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model
|
| 22 |
+
|
| 23 |
+
Route1{Question<br/>spatial?}:::condition
|
| 24 |
+
|
| 25 |
+
Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
|
| 26 |
+
Base[🧠 Base VQA Model<br/>General VQA]:::model
|
| 27 |
+
|
| 28 |
+
Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
|
| 29 |
+
NeuralAns[💬 Neural Answer]:::final
|
| 30 |
+
|
| 31 |
+
Route2{Knowledge<br/>question?}:::condition
|
| 32 |
+
|
| 33 |
+
ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
|
| 34 |
+
Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
|
| 35 |
+
GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
|
| 36 |
+
KGAns[🧩 KG Enhancement]:::final
|
| 37 |
+
|
| 38 |
+
FastAPI[🚀 FastAPI]:::preproc
|
| 39 |
+
GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
|
| 40 |
+
Audio[🔊 2-sentence description]:::final
|
| 41 |
+
|
| 42 |
+
%% Edges
|
| 43 |
+
UserApp -- "Image uploaded" --> ImgUpload
|
| 44 |
+
UserApp -- "Question typed" --> Question
|
| 45 |
+
|
| 46 |
+
ImgUpload --> PIL
|
| 47 |
+
PIL --> CLIP
|
| 48 |
+
Question --> GPT2
|
| 49 |
+
|
| 50 |
+
CLIP & GPT2 --> Route1
|
| 51 |
+
|
| 52 |
+
Route1 -- "YES" --> Spatial
|
| 53 |
+
Route1 -- "NO" --> Base
|
| 54 |
+
|
| 55 |
+
Spatial & Base -- "Beam search (width=5)" --> Decoder
|
| 56 |
+
Decoder --> NeuralAns
|
| 57 |
+
|
| 58 |
+
CLIP -- "Anchor similarity" --> Route2
|
| 59 |
+
|
| 60 |
+
Route2 -- "YES" --> ObjDet
|
| 61 |
+
ObjDet -- "Detected objects" --> Wikidata
|
| 62 |
+
Wikidata -- "Structured facts" --> GroqV
|
| 63 |
+
GroqV --> KGAns
|
| 64 |
+
|
| 65 |
+
FastAPI -- "Narration request" --> GroqA
|
| 66 |
+
GroqA --> Audio
|
| 67 |
+
|
| 68 |
+
NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
|
| 69 |
+
FastAPI --> UserApp
|
backend_api.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI Backend for Ensemble VQA Mobile App
|
| 3 |
+
Provides REST API endpoints for the React Native mobile application
|
| 4 |
+
"""
|
| 5 |
+
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
| 6 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
import uvicorn
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import io
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
load_dotenv()
|
| 16 |
+
from ensemble_vqa_app import ProductionEnsembleVQA
|
| 17 |
+
from groq_service import get_groq_service
|
| 18 |
+
app = FastAPI(
|
| 19 |
+
title="Ensemble VQA API",
|
| 20 |
+
description="Visual Question Answering API with ensemble model routing",
|
| 21 |
+
version="1.0.0"
|
| 22 |
+
)
|
| 23 |
+
app.add_middleware(
|
| 24 |
+
CORSMiddleware,
|
| 25 |
+
allow_origins=["*"],
|
| 26 |
+
allow_credentials=True,
|
| 27 |
+
allow_methods=["*"],
|
| 28 |
+
allow_headers=["*"],
|
| 29 |
+
)
|
| 30 |
+
ensemble_model = None
|
| 31 |
+
groq_service = None
|
| 32 |
+
@app.on_event("startup")
async def startup_event():
    """Load the ensemble VQA models and the Groq LLM service at server start.

    Exits the process (sys.exit) if a required checkpoint is missing or the
    models fail to load, so the server never runs half-initialized. The Groq
    service is optional: when its API key is absent, accessibility
    descriptions fall back to a templated string elsewhere in the API.
    """
    global ensemble_model, groq_service
    print("=" * 80)
    print("🚀 STARTING VQA API SERVER")
    print("=" * 80)

    BASE_CHECKPOINT = "./vqa_checkpoint.pt"
    SPATIAL_CHECKPOINT = "./vqa_spatial_checkpoint.pt"

    # Fail fast with a clear message if either checkpoint is missing.
    if not os.path.exists(BASE_CHECKPOINT):
        print(f"❌ Base checkpoint not found: {BASE_CHECKPOINT}")
        print("Please ensure vqa_checkpoint.pt is in the project root")
        sys.exit(1)
    if not os.path.exists(SPATIAL_CHECKPOINT):
        print(f"❌ Spatial checkpoint not found: {SPATIAL_CHECKPOINT}")
        print("Please ensure vqa_spatial_checkpoint.pt is in the project root")
        sys.exit(1)

    # BUGFIX: the device was hard-coded to 'cuda', which crashes on CPU-only
    # hosts (e.g. free Hugging Face Spaces). Detect GPU availability instead.
    import torch  # local import: only needed here for device selection
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🖥️  Using device: {device}")

    try:
        ensemble_model = ProductionEnsembleVQA(
            base_checkpoint=BASE_CHECKPOINT,
            spatial_checkpoint=SPATIAL_CHECKPOINT,
            device=device
        )
        print("\n✅ VQA models loaded successfully!")
        # Groq is best-effort: a missing API key degrades gracefully.
        try:
            groq_service = get_groq_service()
            print("✅ Groq LLM service initialized for accessibility features")
        except ValueError as e:
            print(f"⚠️ Groq service not available: {e}")
            print("   Accessibility descriptions will use fallback mode")
            groq_service = None
        print("📱 Mobile app can now connect")
        print("=" * 80)
    except Exception as e:
        print(f"\n❌ Failed to load models: {e}")
        sys.exit(1)
@app.get("/")
async def root():
    """Landing endpoint: basic service metadata and an endpoint index."""
    info = {
        "message": "Ensemble VQA API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "health": "/health",
            "answer": "/api/answer (POST)",
        },
    }
    return info
@app.get("/health")
async def health_check():
    """Report whether the ensemble VQA models have been loaded."""
    loaded = ensemble_model is not None
    # Both sub-models are loaded together by the startup hook, so one flag
    # describes them both.
    state = "loaded" if loaded else "not loaded"
    return {
        "status": "healthy",
        "model_loaded": loaded,
        "models": {
            "base": state,
            "spatial": state,
        },
    }
@app.post("/api/answer")
async def answer_question(
    image: UploadFile = File(...),
    question: str = Form(...)
):
    """
    Answer a visual question using the ensemble VQA system.

    Args:
        image: Image file (JPEG, PNG)
        question: Question text

    Returns:
        JSON response with answer, model used, accessibility description, and metadata

    Raises:
        HTTPException: 503 if the model is not loaded, 400 for an empty question
            or an unreadable image, 500 for unexpected inference failures.
    """
    import tempfile  # local import keeps this fix self-contained in the handler

    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not question or question.strip() == "":
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        image_bytes = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")

        # Use a unique temp file (the previous fixed "temp_upload.jpg" name raced
        # between concurrent requests) and guarantee cleanup even when inference
        # raises (the previous code leaked the file on any exception).
        tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        temp_image_path = tmp.name
        tmp.close()
        try:
            pil_image.save(temp_image_path)
            result = ensemble_model.answer(
                image_path=temp_image_path,
                question=question,
                use_beam_search=True,
                beam_width=5,
                verbose=True
            )
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        is_spatial = ensemble_model.is_spatial_question(question)

        # Accessibility description: prefer the Groq LLM, fall back to a
        # templated sentence when the service is missing or errors out.
        description = None
        description_status = "not_available"
        if groq_service is not None:
            try:
                desc_result = groq_service.generate_description(
                    question=question,
                    answer=result['answer']
                )
                description = desc_result.get('description')
                description_status = desc_result.get('status', 'success')
            except Exception as e:
                print(f"⚠️ Groq description generation failed: {e}")
                description = f"Question: {question}. Answer: {result['answer']}."
                description_status = "fallback"
        else:
            description = f"Question: {question}. Answer: {result['answer']}."
            description_status = "fallback"

        # Expose the reasoning chain only when KG enhancement actually ran.
        reasoning_chain = None
        if result.get('kg_enhancement'):
            reasoning_chain = result.get('reasoning_chain', [])

        return JSONResponse(content={
            "success": True,
            "answer": result['answer'],
            "description": description,
            "description_status": description_status,
            "model_used": result['model_used'],
            "confidence": result['confidence'],
            "question_type": "spatial" if is_spatial else "general",
            "question": question,
            "kg_enhancement": result.get('kg_enhancement'),
            "reasoning_type": result.get('reasoning_type', 'neural'),
            "reasoning_chain": reasoning_chain,
            "metadata": {
                "beam_search": True,
                "beam_width": 5
            }
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
| 168 |
+
@app.get("/api/models/info")
async def models_info():
    """Describe the loaded models, the routing strategy, and conversation support."""
    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")

    base_info = {
        "name": "Base VQA Model",
        "description": "General visual question answering",
        "accuracy": "50%",
        "use_case": "General questions about objects, colors, counts, etc."
    }
    spatial_info = {
        "name": "Spatial Adapter Model",
        "description": "Spatial reasoning and positional questions",
        "accuracy": "40%",
        "use_case": "Spatial questions (left, right, above, below, etc.)"
    }
    return {
        "base_model": base_info,
        "spatial_model": spatial_info,
        "routing": {
            "method": "Keyword-based classification",
            "spatial_keywords": ensemble_model.SPATIAL_KEYWORDS,
        },
        "conversation": {
            "enabled": ensemble_model.conversation_enabled if ensemble_model else False,
            "timeout_minutes": 30,
        },
    }
|
| 195 |
+
@app.post("/api/conversation/answer")
async def answer_conversational(
    image: UploadFile = File(...),
    question: str = Form(...),
    session_id: str = Form(None)
):
    """
    Answer a visual question with multi-turn conversation support.
    Handles pronoun resolution and maintains conversation context.

    Args:
        image: Image file (JPEG, PNG)
        question: Question text (may contain pronouns like "it", "this")
        session_id: Optional session ID to continue conversation

    Returns:
        JSON response with answer, session_id, resolved question, and context

    Raises:
        HTTPException: 503 if the model is not loaded, 501 when conversation
            mode is disabled, 400 for bad input, 500 for unexpected failures.
    """
    import tempfile  # local import keeps this fix self-contained in the handler

    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not ensemble_model.conversation_enabled:
        raise HTTPException(
            status_code=501,
            detail="Conversational VQA not available. Use /api/answer instead."
        )
    if not question or question.strip() == "":
        raise HTTPException(status_code=400, detail="Question cannot be empty")

    try:
        image_bytes = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")

        # Unique temp file (the previous fixed name raced between concurrent
        # requests); cleanup is guaranteed even if inference raises.
        tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
        temp_image_path = tmp.name
        tmp.close()
        try:
            pil_image.save(temp_image_path)
            result = ensemble_model.answer_conversational(
                image_path=temp_image_path,
                question=question,
                session_id=session_id,
                use_beam_search=True,
                beam_width=5,
                verbose=True
            )
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        # Accessibility description with templated fallback. The original used a
        # bare `except:` here, which also swallowed KeyboardInterrupt/SystemExit;
        # narrowed to Exception and the failure is now logged.
        description = None
        if groq_service is not None:
            try:
                desc_result = groq_service.generate_description(
                    question=result['resolved_question'],
                    answer=result['answer']
                )
                description = desc_result.get('description')
            except Exception as e:
                print(f"⚠️ Groq description generation failed: {e}")
                description = f"Question: {question}. Answer: {result['answer']}."
        else:
            description = f"Question: {question}. Answer: {result['answer']}."

        return JSONResponse(content={
            "success": True,
            "answer": result['answer'],
            "description": description,
            "session_id": result['session_id'],
            "resolved_question": result['resolved_question'],
            "original_question": question,
            "conversation_context": result['conversation_context'],
            "model_used": result['model_used'],
            "confidence": result['confidence'],
            "kg_enhancement": result.get('kg_enhancement'),
            "reasoning_type": result.get('reasoning_type', 'neural'),
            "reasoning_chain": result.get('reasoning_chain'),
            "metadata": {
                "beam_search": True,
                "beam_width": 5,
                "conversation_enabled": True
            }
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error processing conversational request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
| 274 |
+
@app.get("/api/conversation/{session_id}/history")
async def get_conversation_history(session_id: str):
    """
    Return the recorded turns for one conversation session.

    Args:
        session_id: Session ID

    Returns:
        JSON with conversation history
    """
    conversation_ready = (
        ensemble_model is not None and ensemble_model.conversation_enabled
    )
    if not conversation_ready:
        raise HTTPException(status_code=503, detail="Conversation service not available")

    history = ensemble_model.conversation_manager.get_history(session_id)
    # get_history returning None means the session never existed or timed out.
    if history is None:
        raise HTTPException(
            status_code=404,
            detail=f"Session {session_id} not found or expired"
        )

    return JSONResponse(content={
        "success": True,
        "session_id": session_id,
        "history": history,
        "turn_count": len(history),
    })
|
| 297 |
+
@app.delete("/api/conversation/{session_id}")
async def delete_conversation(session_id: str):
    """
    Delete a conversation session.

    Args:
        session_id: Session ID to delete

    Returns:
        JSON with success status
    """
    service_ready = (
        ensemble_model is not None and ensemble_model.conversation_enabled
    )
    if not service_ready:
        raise HTTPException(status_code=503, detail="Conversation service not available")

    if not ensemble_model.conversation_manager.delete_session(session_id):
        raise HTTPException(
            status_code=404,
            detail=f"Session {session_id} not found"
        )

    return JSONResponse(content={
        "success": True,
        "message": f"Session {session_id} deleted"
    })
|
| 318 |
+
if __name__ == "__main__":
    # The banner previously advertised port 8000 while uvicorn actually binds
    # 7860; derive every user-facing URL from one constant so they cannot drift.
    PORT = 7860  # HuggingFace Spaces requires port 7860
    print("\n" + "=" * 80)
    print("🚀 ENSEMBLE VQA API SERVER")
    print("=" * 80)
    print("\n📋 Configuration:")
    print(" - Host: 0.0.0.0 (accessible from network)")
    print(f" - Port: {PORT}")
    print(" - Reload: Enabled (development mode)")
    print("\n🔗 Access URLs:")
    print(f" - Local: http://localhost:{PORT}")
    print(f" - Network: http://<your-ip>:{PORT}")
    print(f" - Docs: http://localhost:{PORT}/docs")
    print("\n💡 For mobile testing:")
    print(" 1. Find your local IP: ipconfig (Windows) or ifconfig (Mac/Linux)")
    print(f" 2. Update API_URL in mobile app to http://<your-ip>:{PORT}")
    print(" 3. Ensure phone and computer are on same network")
    print("=" * 80 + "\n")
    uvicorn.run(
        "backend_api:app",
        host="0.0.0.0",
        port=PORT,
        reload=True,
        log_level="info"
    )
|
continue.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import GPT2Tokenizer
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from collections import Counter
|
| 12 |
+
from nltk.tokenize import word_tokenize
|
| 13 |
+
from sklearn.model_selection import train_test_split
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from model import VQAModel
|
| 16 |
+
device = 'cuda'
|
| 17 |
+
class Vocab:
    """Answer-side vocabulary with special tokens and encode/decode helpers."""

    def __init__(self):
        # Populated by build_vocab() — or assigned externally when restoring
        # from a checkpoint (see the resume script's main()).
        self.vocab = None
        self.vocab_size = None
        self.word2idx = None
        self.idx2word = None
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'

    def build_vocab(self, df, min_freq=1):
        """Build the vocabulary from the 'answer' column of *df*."""
        counts = Counter()
        for answer_text in df['answer']:
            counts.update(word_tokenize(answer_text.lower()))
        # Deterministic ordering: specials first, then kept words sorted.
        kept = sorted(word for word, freq in counts.items() if freq >= min_freq)
        full = [self.pad, self.bos, self.eos, self.unk] + kept
        self.vocab = full
        self.vocab_size = len(full)
        self.word2idx = {word: idx for idx, word in enumerate(full)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Tokenize *text* to ids wrapped in <bos>/<eos>, padded/cut to *max_len*."""
        body = [self.word2idx.get(tok, self.unk_token_id)
                for tok in word_tokenize(text.lower())]
        ids = [self.bos_token_id] + body + [self.eos_token_id]
        # NOTE(review): truncation can drop <eos>; preserved as-is so encodings
        # match what the checkpointed model was trained on.
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_token_id] * (max_len - len(ids))

    def decoder(self, token_ids):
        """Ids -> text; stops at <eos>, skips <pad>/<bos>, unknown ids -> '<unk>'."""
        words = []
        for tid in token_ids:
            if tid == self.eos_token_id:
                break
            if tid in (self.pad_token_id, self.bos_token_id):
                continue
            words.append(self.idx2word.get(tid, "<unk>"))
        return ' '.join(words).strip()
|
| 62 |
+
class AugmentedVQADataset(Dataset):
    """VQA dataset yielding a CLIP-preprocessed image, a tokenized question,
    and the encoded answer ids for each metadata row."""

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        self.text_processor = text_processor
        self.clip_processor = clip_processor
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        # Light photometric/geometric augmentation, enabled for training only.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
        ]) if augment else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        full_path = os.path.join(self.img_dir, sample['image_path'])
        img = Image.open(full_path).convert('RGB')
        if self.augment and self.transform:
            img = self.transform(img)

        tokenized = self.question_tokenizer(
            sample['question'],
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        answer_ids = self.text_processor.encoder(sample['answer'], max_len=self.answer_max_len)

        return {
            'image_path': full_path,
            'image': self.clip_processor(img),
            'question_ids': tokenized['input_ids'].squeeze(0),
            'question_mask': tokenized['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(answer_ids, dtype=torch.long),
        }
|
| 107 |
+
def save_checkpoint(model, optimizer, epoch, vocab, path):
    """Persist model/optimizer state plus the vocabulary and sequence-length
    settings needed to reconstruct the exact training setup on resume."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab': vocab.vocab,
        'word2idx': vocab.word2idx,
        'idx2word': vocab.idx2word,
        'pad_token_id': vocab.pad_token_id,
        'bos_token_id': vocab.bos_token_id,
        'eos_token_id': vocab.eos_token_id,
        'unk_token_id': vocab.unk_token_id,
        'question_max_len': model.question_max_len,
        'answer_max_len': model.answer_max_len,
    }
    torch.save(state, path)
|
| 122 |
+
def plot_losses(train_losses, val_losses, save_path="loss_plot.png"):
    """Write a PNG comparing per-epoch train and validation loss curves."""
    plt.figure(figsize=(8, 6))
    for series, label in ((train_losses, "Train Loss"),
                          (val_losses, "Validation Loss")):
        plt.plot(series, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train vs Validation Loss")
    plt.legend()
    plt.savefig(save_path)
    plt.close()
|
| 132 |
+
def train_one_epoch(model, dataloader, optimizer, device, scaler, vocab):
    """Run one mixed-precision training epoch.

    Returns:
        (avg_loss, avg_token_acc): per-batch averages over the epoch.
    """
    model.train()
    total_loss = 0
    total_token_acc = 0
    # Label smoothing regularizes the decoder; padding positions are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        # Mixed-precision forward pass and loss.
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Teacher forcing: position t predicts target token t+1.
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
        # Token accuracy over non-padding target positions only.
        predicted_tokens = shifted_logits.argmax(dim=-1)
        correct = (predicted_tokens == shifted_answers).float()
        mask = (shifted_answers != vocab.pad_token_id).float()
        token_acc = (correct * mask).sum() / mask.sum()
        total_token_acc += token_acc.item()
        # AMP bookkeeping: scale the loss for backward, then unscale BEFORE
        # clipping so the clip threshold applies to true gradient magnitudes.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    return avg_loss, avg_token_acc
|
| 167 |
+
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate one epoch without gradients.

    Returns:
        (avg_loss, avg_token_acc, exact_match_acc): teacher-forced loss and
        token accuracy plus the fraction of generated answers that exactly
        match the reference after decoding.
    """
    model.eval()
    total_loss = 0
    total_token_acc = 0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            # Teacher-forced pass for loss / token accuracy.
            logits = model(images, questions, answer_input_ids=answers)
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            total_loss += loss.item()
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            token_acc = (correct * mask).sum() / mask.sum()
            total_token_acc += token_acc.item()
            # Free-running generation for exact match.
            # NOTE(review): presumably the model generates when answer_input_ids
            # is omitted — confirm against VQAModel.forward.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    # NOTE(review): divides by zero if the dataloader yields no samples.
    exact_match_acc = exact_matches / total_samples
    return avg_loss, avg_token_acc, exact_match_acc
|
| 206 |
+
def main():
    """Resume VQA training from a saved checkpoint with unchanged settings.

    Restores the vocabulary, model weights, and (when possible) optimizer state
    from RESUME_CHECKPOINT, then trains up to `additional_epochs` more epochs
    with early stopping on validation exact-match accuracy.
    """
    print()
    print("# VQA: Continue Training (Same Settings)")
    print()
    # Seed every RNG so the data split and augmentation order are reproducible.
    import random
    import numpy as np
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
    # Paths: resume from the previous run's checkpoint; write into a fresh dir.
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    RESUME_CHECKPOINT = r"./output2/continued_training/vqa_checkpoint.pt"
    OUTPUT_DIR = r"./output2/continued_training_2"
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters — intentionally identical to the original run.
    batch_size = 64
    additional_epochs = 50
    patience = 8
    question_max_len = 20
    answer_max_len = 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    print(f"Loading checkpoint from: {RESUME_CHECKPOINT}")
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=device)
    start_epoch = checkpoint['epoch'] + 1
    metadata = pd.read_csv(CSV_PATH)
    # Restore the answer vocabulary exactly as saved — never rebuild, or the
    # token ids would no longer match the checkpointed embedding rows.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f"Answer Vocab Size: {len(vocab.vocab)}")
    print(f"Resuming from epoch: {start_epoch}")
    # Same seeded 80/10/10 split as the original run so the folds line up.
    train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    print()
    model = VQAModel(
        vocab_size=len(vocab.vocab),
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=vocab.pad_token_id,
        bos_token_id=vocab.bos_token_id,
        eos_token_id=vocab.eos_token_id,
        unk_token_id=vocab.unk_token_id,
        hidden_size=512,
        num_layers=2
    ).to(device)
    clip_processor = model.clip_preprocess
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; add one and grow the embedding table.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    # strict=False tolerates the resized embedding row added above.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("Model loaded from checkpoint!")
    if model.fine_tuning_mode:
        print("Model already in fine-tuning mode (encoders unfrozen)")
    else:
        print("Continuing with same training configuration")
    print()
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=True
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # Optimize only the parameters left unfrozen by the model.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-6, weight_decay=1e-4)
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    # Restoring optimizer state keeps AdamW moments; fall back to a fresh
    # optimizer if the saved state no longer matches the parameter groups.
    if 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Optimizer state loaded from checkpoint!")
            for param_group in optimizer.param_groups:
                print(f" Loaded LR: {param_group['lr']}")
        except Exception as e:
            print(f"Could not load optimizer state: {e}")
            print("Using fresh optimizer")
    else:
        print("No optimizer state in checkpoint, using fresh optimizer")
    print()
    scaler = torch.amp.GradScaler(device)
    best_val_exact_match = 0.0
    counter = 0
    logs = []
    # Carry the previous run's log forward so the CSV stays continuous, and
    # use its best score as the checkpointing baseline.
    if os.path.exists(LOG_CSV):
        old_logs = pd.read_csv(LOG_CSV)
        logs = old_logs.values.tolist()
        best_val_exact_match = old_logs['val_exact_match'].max()
        print(f"Previous best exact match: {best_val_exact_match:.4f}")
    # NOTE(review): `verbose=` was removed from ReduceLROnPlateau in recent
    # torch releases — confirm the pinned torch version still accepts it.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=4, verbose=True
    )
    total_epochs = start_epoch + additional_epochs
    for epoch in range(start_epoch, total_epochs):
        print(f"\nEpoch {epoch+1}/{total_epochs}")
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, scaler, vocab)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f} | Val Exact Match: {val_exact_match:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_exact_match)
        # Checkpoint only on exact-match improvement; early-stop after
        # `patience` stagnant epochs.
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, CHECKPOINT_PATH)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement in exact match for {counter} epochs.")
            if counter >= patience:
                print(f"\nEarly stopping after {patience} epochs without improvement")
                break
        # NOTE(review): the epoch that triggers early stopping breaks out before
        # this append, so its metrics never reach the CSV or the plot.
        logs.append([epoch+1, train_loss, train_token_acc, val_loss, val_token_acc, val_exact_match, optimizer.param_groups[0]['lr']])
        log_df = pd.DataFrame(logs, columns=["epoch","train_loss","train_token_acc","val_loss","val_token_acc","val_exact_match","lr"])
        log_df.to_csv(LOG_CSV, index=False)
        plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\nContinued training complete!")
    print(f"Best exact match accuracy: {best_val_exact_match:.4f}")


if __name__ == "__main__":
    main()
|
continued_training_metric.csv
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
epoch,train_loss,train_token_acc,val_loss,val_token_acc,val_exact_match,lr
|
| 2 |
+
30,1.9502590653601657,0.7322589969014642,1.3020859152640936,0.6990427433882119,0.38998964956380305,1e-06
|
| 3 |
+
31,1.9464403521302605,0.7328945476229131,1.3008682691263702,0.7001620300535886,0.3919858051160727,1e-06
|
| 4 |
+
32,1.9446046694293662,0.733795435205851,1.2995548267971795,0.7003354483617926,0.39220760017743606,1e-06
|
| 5 |
+
33,1.9418390615673053,0.7339540544097625,1.2990998206835873,0.7004338480391592,0.3923554635516783,1e-06
|
| 6 |
+
34,1.9405346881137806,0.733893274767451,1.299637350552487,0.7005681339299904,0.39257725861304155,1e-06
|
| 7 |
+
35,1.9380957318931413,0.7351758044757201,1.2987835997680448,0.7006050677232023,0.39265119030016266,1e-06
|
| 8 |
+
36,1.9369506880350187,0.7359647384978554,1.2979233675407913,0.7013796053405078,0.39405589235546357,1e-06
|
| 9 |
+
37,1.9360789391220428,0.7364758676075498,1.2977605515493538,0.7014409610123005,0.39398196066834246,1e-06
|
| 10 |
+
38,1.9357275886693557,0.7362391176412685,1.297402927054549,0.7011285817848062,0.39353837054561586,1e-06
|
| 11 |
+
39,1.932767997896227,0.736806065813456,1.2974532218474262,0.7004903276573937,0.39220760017743606,1e-06
|
| 12 |
+
40,1.9330583925010325,0.7374090065552876,1.2972412691363748,0.7010474972567469,0.39316871211001037,1e-06
|
| 13 |
+
41,1.9306796564990991,0.7378562083616544,1.2969766115804888,0.7015751037957534,0.39427768741682684,1e-06
|
| 14 |
+
42,1.9282727334571266,0.7377051650808099,1.2973702516195909,0.7011518692070583,0.39331657548425253,1e-06
|
| 15 |
+
43,1.9271106582502864,0.7386680361718415,1.2968672679842643,0.7010392338599799,0.39316871211001037,1e-06
|
| 16 |
+
44,1.9269962475047457,0.7397106953586509,1.296902930680311,0.7012923545432541,0.3936862339198581,1e-06
|
| 17 |
+
45,1.9244166012701376,0.7400805678048972,1.2962118839880206,0.7011814176473977,0.39353837054561586,1e-06
|
| 18 |
+
46,1.9289601324296857,0.7377478108470924,1.296351783118158,0.7014656890675707,0.3941298240425847,5e-07
|
| 19 |
+
47,1.9269490434459369,0.7386778470752227,1.2962728831565604,0.7015336796922503,0.39420375572970573,5e-07
|
| 20 |
+
48,1.9252020313075702,0.7394137923214155,1.2964817043745294,0.7014302642277952,0.39420375572970573,5e-07
|
| 21 |
+
49,1.9241666916486853,0.7392096879001484,1.296351099070513,0.7016751350096937,0.39449948247819017,5e-07
|
conversation_manager.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation Manager for Multi-turn VQA
|
| 3 |
+
Manages conversation state, context, and pronoun resolution
|
| 4 |
+
"""
|
| 5 |
+
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Dict, List, Optional, Any
|
| 7 |
+
from datetime import datetime, timedelta
|
| 8 |
+
import uuid
|
| 9 |
+
import re
|
| 10 |
+
@dataclass
class ConversationTurn:
    """A single question/answer exchange within a conversation."""
    question: str
    answer: str
    objects_detected: List[str]
    timestamp: datetime
    reasoning_chain: Optional[List[str]] = None
    model_used: Optional[str] = None


@dataclass
class ConversationSession:
    """Complete state of one image-grounded conversation."""
    session_id: str
    image_path: str
    history: List[ConversationTurn] = field(default_factory=list)
    current_objects: List[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)

    def add_turn(
        self,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ):
        """Record one Q/A exchange and refresh the session's activity state."""
        self.history.append(
            ConversationTurn(
                question=question,
                answer=answer,
                objects_detected=objects_detected,
                timestamp=datetime.now(),
                reasoning_chain=reasoning_chain,
                model_used=model_used,
            )
        )
        # Only overwrite the tracked objects when this turn detected any,
        # so earlier detections survive object-free turns.
        if objects_detected:
            self.current_objects = objects_detected
        self.last_activity = datetime.now()

    def get_context_summary(self) -> str:
        """Summarize up to the last three turns as a single string."""
        if not self.history:
            return "No previous conversation"
        recent = self.history[-3:]
        return " | ".join(
            f"Turn {idx}: Q: {t.question} A: {t.answer}"
            for idx, t in enumerate(recent, 1)
        )

    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True once the session has been idle longer than the timeout."""
        deadline = self.last_activity + timedelta(minutes=timeout_minutes)
        return datetime.now() > deadline
|
| 61 |
+
class ConversationManager:
    """
    Manages multi-turn conversation sessions for VQA.

    Handles context retention, pronoun resolution, and session lifecycle
    (creation, expiry-aware lookup, deletion, cleanup).
    """

    # Pronouns that may refer back to objects from earlier turns.
    # 'it'/'this'/'that' are singular (resolve to the primary object);
    # 'these'/'those'/'they'/'them' are plural (resolve to the full list).
    PRONOUNS = ['it', 'this', 'that', 'these', 'those', 'they', 'them']

    def __init__(self, session_timeout_minutes: int = 30):
        """
        Initialize conversation manager.

        Args:
            session_timeout_minutes: Minutes of inactivity before a session expires
        """
        self.sessions: Dict[str, ConversationSession] = {}
        self.session_timeout = session_timeout_minutes
        print(f"✅ Conversation Manager initialized (timeout: {session_timeout_minutes}min)")

    def create_session(self, image_path: str, session_id: Optional[str] = None) -> str:
        """
        Create a new conversation session.

        Args:
            image_path: Path to the image for this conversation
            session_id: Optional custom session ID (generates UUID if not provided)

        Returns:
            Session ID
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        self.sessions[session_id] = ConversationSession(
            session_id=session_id,
            image_path=image_path
        )
        return session_id

    def get_session(self, session_id: str) -> Optional[ConversationSession]:
        """
        Get an existing session, evicting it first if it has expired.

        Args:
            session_id: Session ID to retrieve

        Returns:
            ConversationSession or None if not found/expired
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None
        if session.is_expired(self.session_timeout):
            self.delete_session(session_id)
            return None
        return session

    def get_or_create_session(
        self,
        session_id: Optional[str],
        image_path: str
    ) -> ConversationSession:
        """
        Get an existing session or create a new one.

        Args:
            session_id: Optional session ID
            image_path: Image path used if a new session must be created

        Returns:
            ConversationSession
        """
        if session_id:
            existing = self.get_session(session_id)
            if existing:
                return existing
        # Reuse the caller-supplied ID (if any) for the fresh session.
        new_id = self.create_session(image_path, session_id)
        return self.sessions[new_id]

    def add_turn(
        self,
        session_id: str,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> bool:
        """
        Add a turn to a conversation session.

        Args:
            session_id: Session ID
            question: User's question
            answer: VQA answer
            objects_detected: List of detected objects
            reasoning_chain: Optional reasoning steps
            model_used: Optional model identifier

        Returns:
            True if successful, False if session not found (or expired)
        """
        session = self.get_session(session_id)
        if session is None:
            return False
        session.add_turn(
            question=question,
            answer=answer,
            objects_detected=objects_detected,
            reasoning_chain=reasoning_chain,
            model_used=model_used
        )
        return True

    def resolve_references(
        self,
        question: str,
        session: ConversationSession
    ) -> str:
        """
        Resolve pronouns and references in a question using conversation context.

        Args:
            question: User's question (may contain pronouns)
            session: Conversation session with context

        Returns:
            Question with pronouns resolved

        Example:
            Input: "Is it healthy?"
            Context: Previous object was "apple"
            Output: "Is apple healthy?"
        """
        if not session.history:
            return question
        recent_objects = session.current_objects
        if not recent_objects:
            return question

        # BUG FIX: pronoun detection previously tokenized with str.split(),
        # so trailing punctuation hid the pronoun (e.g. "What color is it?"
        # yields the token "it?") and resolution was silently skipped.
        # Word-boundary regexes detect and replace pronouns regardless of
        # adjacent punctuation.
        def _substitute(pronouns: List[str], replacement: str, text: str) -> str:
            """Replace each whole-word pronoun occurrence with *replacement*."""
            for pronoun in pronouns:
                text = re.sub(r'\b' + pronoun + r'\b', replacement, text,
                              flags=re.IGNORECASE)
            return text

        # Singular pronouns refer to the primary (first) detected object;
        # plural pronouns refer to all detected objects.
        resolved = _substitute(['it', 'this', 'that'], recent_objects[0], question)
        resolved = _substitute(['these', 'those', 'they', 'them'],
                               ', '.join(recent_objects), resolved)
        return resolved

    def get_context_for_question(
        self,
        session_id: str,
        question: str
    ) -> Dict[str, Any]:
        """
        Get relevant context for answering a question.

        Args:
            session_id: Session ID
            question: Current question (unused for now; kept for API stability)

        Returns:
            Dict with keys: 'has_context', 'turn_number', 'previous_objects',
            'previous_questions', 'previous_answers', 'context_summary'
        """
        session = self.get_session(session_id)
        if session is None:
            # Consistency fix: return the same key set as the found-session
            # branch so callers can index unconditionally.
            return {
                'has_context': False,
                'turn_number': 0,
                'previous_objects': [],
                'previous_questions': [],
                'previous_answers': [],
                'context_summary': "No previous conversation"
            }
        recent = session.history[-3:]
        return {
            'has_context': len(session.history) > 0,
            'turn_number': len(session.history) + 1,
            'previous_objects': session.current_objects,
            'previous_questions': [turn.question for turn in recent],
            'previous_answers': [turn.answer for turn in recent],
            'context_summary': session.get_context_summary()
        }

    def get_history(self, session_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get conversation history for a session.

        Args:
            session_id: Session ID

        Returns:
            List of turn dictionaries or None if session not found
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        return [
            {
                'question': turn.question,
                'answer': turn.answer,
                'objects_detected': turn.objects_detected,
                'timestamp': turn.timestamp.isoformat(),
                'reasoning_chain': turn.reasoning_chain,
                'model_used': turn.model_used
            }
            for turn in session.history
        ]

    def delete_session(self, session_id: str) -> bool:
        """
        Delete a conversation session.

        Args:
            session_id: Session ID to delete

        Returns:
            True if deleted, False if not found
        """
        if session_id in self.sessions:
            del self.sessions[session_id]
            return True
        return False

    def cleanup_expired_sessions(self):
        """Remove all expired sessions; returns how many were removed."""
        expired_ids = [
            sid for sid, session in self.sessions.items()
            if session.is_expired(self.session_timeout)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        return len(expired_ids)

    def get_active_sessions_count(self) -> int:
        """Get count of active (non-expired) sessions."""
        self.cleanup_expired_sessions()
        return len(self.sessions)
|
| 274 |
+
if __name__ == "__main__":
    # Lightweight manual smoke tests for the conversation manager.
    banner = "=" * 80
    print(banner)
    print("🧪 Testing Conversation Manager")
    print(banner)

    mgr = ConversationManager(session_timeout_minutes=30)

    print("\n📝 Test 1: Multi-turn conversation")
    sid = mgr.create_session("test_image.jpg")
    print(f"Created session: {sid}")

    mgr.add_turn(
        session_id=sid,
        question="What is this?",
        answer="apple",
        objects_detected=["apple"],
    )
    print("Turn 1: 'What is this?' → 'apple'")

    sess = mgr.get_session(sid)

    q2 = "Is it healthy?"
    print(f"Turn 2: '{q2}' → Resolved: '{mgr.resolve_references(q2, sess)}'")
    mgr.add_turn(
        session_id=sid,
        question=q2,
        answer="Yes, apples are healthy",
        objects_detected=["apple"],
    )

    q3 = "What color is it?"
    print(f"Turn 3: '{q3}' → Resolved: '{mgr.resolve_references(q3, sess)}'")

    print("\n📝 Test 2: Context retrieval")
    ctx = mgr.get_context_for_question(sid, "Another question")
    print(f"Turn number: {ctx['turn_number']}")
    print(f"Previous objects: {ctx['previous_objects']}")
    print(f"Context summary: {ctx['context_summary']}")

    print("\n📝 Test 3: Conversation history")
    for i, turn in enumerate(mgr.get_history(sid), 1):
        print(f" Turn {i}: Q: {turn['question']} | A: {turn['answer']}")

    print("\n" + banner)
    print("✅ Tests completed!")
|
download_models.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from huggingface_hub import hf_hub_download

REPO_ID = "Deva8/GENvqa-model"

# We use the token from the environment variable (which the user must set in Settings -> Secrets)
HF_TOKEN = os.getenv("HF_TOKEN")

print("Downloading models from HuggingFace Hub...")

# (checkpoint filename, success message) for each model artifact we need locally.
_CHECKPOINTS = (
    ("vqa_checkpoint.pt", "Base checkpoint downloaded successfully."),
    ("vqa_spatial_checkpoint.pt", "Spatial checkpoint downloaded successfully."),
)

for _filename, _done_message in _CHECKPOINTS:
    hf_hub_download(
        repo_id=REPO_ID,
        filename=_filename,
        local_dir=".",
        token=HF_TOKEN,
    )
    print(_done_message)
|
draft_generator.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
import os

# Mermaid definition of the end-to-end VQA architecture: mobile app ->
# preprocessing -> CLIP/GPT-2 encoders -> routed VQA models ->
# knowledge-graph / accessibility services -> FastAPI response.
mermaid_code = """
graph TD
    %% Styling
    classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
    classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
    classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
    classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
    classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
    classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
    classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;

    %% Nodes
    UserApp[📱 Mobile App]:::mobile

    ImgUpload[🖼️ Image]:::preproc
    Question[⌨️ Question Text]:::preproc

    PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc

    CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
    GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model

    Route1{Question<br/>spatial?}:::condition

    Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
    Base[🧠 Base VQA Model<br/>General VQA]:::model

    Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
    NeuralAns[💬 Neural Answer]:::final

    Route2{Knowledge<br/>question?}:::condition

    ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
    Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
    GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
    KGAns[🧩 KG Enhancement]:::final

    FastAPI[🚀 FastAPI]:::preproc
    GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
    Audio[🔊 2-sentence description]:::final

    %% Edges
    UserApp -- "Image uploaded" --> ImgUpload
    UserApp -- "Question typed" --> Question

    ImgUpload --> PIL
    PIL --> CLIP
    Question --> GPT2

    CLIP & GPT2 --> Route1

    Route1 -- "YES" --> Spatial
    Route1 -- "NO" --> Base

    Spatial & Base -- "Beam search (width=5)" --> Decoder
    Decoder --> NeuralAns

    CLIP -- "Anchor similarity" --> Route2

    Route2 -- "YES" --> ObjDet
    ObjDet -- "Detected objects" --> Wikidata
    Wikidata -- "Structured facts" --> GroqV
    GroqV --> KGAns

    FastAPI -- "Narration request" --> GroqA
    GroqA --> Audio

    NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
    FastAPI --> UserApp
"""

# FIX: outputs were previously written to a hard-coded absolute path
# (C:\\Users\\rdeva\\Downloads\\vqa_coes\\...), which breaks on any machine
# other than the original author's and on the deployed Space. Write the
# files next to this script instead.
_output_dir = os.path.dirname(os.path.abspath(__file__))

file_path = os.path.join(_output_dir, "architecture_draft.mmd")

with open(file_path, "w", encoding="utf-8") as f:
    f.write(mermaid_code)

print(f"Mermaid file saved to {file_path}")

# Note: In a real environment, we would use mermaid-cli (mmdc) to convert this to SVG/PNG.
# Since it might not be installed globally, we will just provide the mermaid file and
# instructions, or generate an HTML wrapper that renders it in browser.

html_path = os.path.join(_output_dir, "architecture_draft.html")
html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>VQA Architecture Draft</title>
    <script type="module">
        import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
        mermaid.initialize({{ startOnLoad: true, theme: 'dark', flowchart: {{ curve: 'basis' }} }});
    </script>
    <style>
        body {{ background-color: #0D1117; color: white; font-family: sans-serif; display: flex; justify-content: center; padding: 20px; }}
        .mermaid {{ background-color: #161B22; padding: 20px; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.5); }}
    </style>
</head>
<body>
    <div class="mermaid">
{mermaid_code}
    </div>
</body>
</html>
"""

with open(html_path, "w", encoding="utf-8") as f:
    f.write(html_content)

print(f"HTML viewer saved to {html_path}")
|
ensemble_vqa_app.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Production Ensemble VQA Application
|
| 3 |
+
Combines base model (general VQA) and spatial adapter (spatial reasoning)
|
| 4 |
+
for optimal performance on all question types.
|
| 5 |
+
NEW: Neuro-Symbolic VQA with Knowledge Graph integration
|
| 6 |
+
NEW: Multi-turn Conversational VQA with context management
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from transformers import GPT2Tokenizer
|
| 12 |
+
from models.model import VQAModel
|
| 13 |
+
from model_spatial import VQAModelWithSpatialAdapter
|
| 14 |
+
from experiments.train import Vocab
|
| 15 |
+
from knowledge_graph_service import KnowledgeGraphService
|
| 16 |
+
from typing import Optional
|
| 17 |
+
import time
|
| 18 |
+
class ProductionEnsembleVQA:
|
| 19 |
+
|
| 20 |
+
SPATIAL_KEYWORDS = [
|
| 21 |
+
'right', 'left', 'above', 'below', 'top', 'bottom',
|
| 22 |
+
'up', 'down', 'upward', 'downward',
|
| 23 |
+
'front', 'behind', 'back', 'next to', 'beside', 'near', 'between',
|
| 24 |
+
'in front', 'in back', 'across from', 'opposite', 'adjacent',
|
| 25 |
+
'closest', 'farthest', 'nearest', 'furthest', 'closer', 'farther',
|
| 26 |
+
'where is', 'where are', 'which side', 'what side', 'what direction',
|
| 27 |
+
'on the left', 'on the right', 'at the top', 'at the bottom',
|
| 28 |
+
'to the left', 'to the right', 'in the middle', 'in the center',
|
| 29 |
+
'under', 'over', 'underneath', 'on top of', 'inside', 'outside'
|
| 30 |
+
]
|
| 31 |
+
    def __init__(self, base_checkpoint, spatial_checkpoint, device='cuda'):
        """Initialize the ensemble: base + spatial VQA models plus optional services.

        Args:
            base_checkpoint: Path to the base (general VQA) model checkpoint.
            spatial_checkpoint: Path to the spatial-adapter model checkpoint.
            device: Requested device; falls back to CPU when CUDA is unavailable.
        """
        # Fall back to CPU if CUDA is not present on this machine.
        self.device = device if torch.cuda.is_available() else 'cpu'
        print("="*80)
        print("🚀 INITIALIZING ENSEMBLE VQA SYSTEM")
        print("="*80)
        print(f"\n⚙️ Device: {self.device}")
        print("\n📥 Loading models...")
        start_time = time.time()
        print(" [1/2] Loading base model (general VQA)...")
        self.base_model, self.vocab, self.tokenizer = self._load_base_model(base_checkpoint)
        print(" ✓ Base model loaded")
        print(" [2/2] Loading spatial model (spatial reasoning)...")
        # The spatial model shares the same vocab/tokenizer, so its copies
        # are discarded here.
        self.spatial_model, _, _ = self._load_spatial_model(spatial_checkpoint)
        print(" ✓ Spatial model loaded")
        load_time = time.time() - start_time
        # NOTE(review): the "[3/3]" label is inconsistent with the
        # "[1/2]"/"[2/2]" labels above — cosmetic only.
        print(" [3/3] Initializing Semantic Neuro-Symbolic VQA...")
        # The KG service is optional: any import/initialization failure
        # degrades gracefully to neural-only answering.
        try:
            from semantic_neurosymbolic_vqa import SemanticNeurosymbolicVQA
            self.kg_service = SemanticNeurosymbolicVQA(device=self.device)
            print(" ✓ Semantic Neuro-Symbolic VQA ready (CLIP + Wikidata, no pattern matching)")
            self.kg_enabled = True
        except Exception as e:
            print(f" ⚠️ Semantic Neuro-Symbolic VQA unavailable: {e}")
            print(" → Falling back to neural-only mode")
            self.kg_service = None
            self.kg_enabled = False
        print(f"\n✅ Ensemble ready! (loaded in {load_time:.1f}s)")
        print(f"📊 Memory: ~2x single model (~4GB GPU)")
        print(f"🎯 Routing: Automatic based on question type")
        print(f"🧠 Neuro-Symbolic: {'Enabled' if self.kg_enabled else 'Disabled (neural-only)'}")
        print(f"💬 Conversation: Initializing multi-turn support...")
        # The conversation manager is likewise optional; without it the
        # system handles single-shot Q&A only.
        try:
            from conversation_manager import ConversationManager
            self.conversation_manager = ConversationManager(session_timeout_minutes=30)
            self.conversation_enabled = True
            print(f" ✓ Conversational VQA ready (multi-turn with context)")
        except Exception as e:
            print(f" ⚠️ Conversation manager unavailable: {e}")
            print(f" → Single-shot Q&A only")
            self.conversation_manager = None
            self.conversation_enabled = False
        print("="*80)
|
| 74 |
+
    def _load_base_model(self, checkpoint_path):
        """Load the base VQA model together with its vocabulary and tokenizer.

        Args:
            checkpoint_path: Path to a torch checkpoint holding the model
                state dict, vocabulary tables, and special-token ids.

        Returns:
            tuple: (model in eval mode, Vocab, GPT2Tokenizer)
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Rebuild the Vocab object from the serialized lookup tables.
        vocab = Vocab()
        vocab.vocab = checkpoint['vocab']
        vocab.vocab_size = len(checkpoint['vocab'])
        vocab.word2idx = checkpoint['word2idx']
        vocab.idx2word = checkpoint['idx2word']
        vocab.pad_token_id = checkpoint['pad_token_id']
        vocab.bos_token_id = checkpoint['bos_token_id']
        vocab.eos_token_id = checkpoint['eos_token_id']
        vocab.unk_token_id = checkpoint['unk_token_id']
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Add a pad token if the tokenizer ships without one.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model = VQAModel(
            vocab_size=len(checkpoint['vocab']),
            device=self.device,
            # Sequence-length settings fall back to defaults when the
            # checkpoint predates these keys.
            question_max_len=checkpoint.get('question_max_len', 20),
            answer_max_len=checkpoint.get('answer_max_len', 12),
            pad_token_id=checkpoint['pad_token_id'],
            bos_token_id=checkpoint['bos_token_id'],
            eos_token_id=checkpoint['eos_token_id'],
            unk_token_id=checkpoint['unk_token_id'],
            hidden_size=512,
            num_layers=2
        ).to(self.device)
        # Resize embeddings BEFORE loading weights — presumably so the added
        # [PAD] row matches the checkpoint's embedding shape (TODO confirm).
        model.gpt2_model.resize_token_embeddings(len(tokenizer))
        # strict=False tolerates key mismatches between the saved state dict
        # and the freshly constructed model.
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.eval()
        return model, vocab, tokenizer
|
| 105 |
+
    def _load_spatial_model(self, checkpoint_path):
        """Load the spatial VQA model (a base model wrapped by the spatial adapter).

        Mirrors `_load_base_model`, then wraps the rebuilt base model in
        `VQAModelWithSpatialAdapter` before restoring the adapter weights.

        Args:
            checkpoint_path: Path to a torch checkpoint holding the adapter
                state dict, vocabulary tables, and special-token ids.

        Returns:
            tuple: (adapter model in eval mode, Vocab, GPT2Tokenizer)
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Rebuild the Vocab object from the serialized lookup tables.
        vocab = Vocab()
        vocab.vocab = checkpoint['vocab']
        vocab.vocab_size = len(checkpoint['vocab'])
        vocab.word2idx = checkpoint['word2idx']
        vocab.idx2word = checkpoint['idx2word']
        vocab.pad_token_id = checkpoint['pad_token_id']
        vocab.bos_token_id = checkpoint['bos_token_id']
        vocab.eos_token_id = checkpoint['eos_token_id']
        vocab.unk_token_id = checkpoint['unk_token_id']
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        # Add a pad token if the tokenizer ships without one.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # Construct a fresh base model first; the adapter wraps it below.
        base_model = VQAModel(
            vocab_size=len(checkpoint['vocab']),
            device=self.device,
            question_max_len=checkpoint.get('question_max_len', 20),
            answer_max_len=checkpoint.get('answer_max_len', 12),
            pad_token_id=checkpoint['pad_token_id'],
            bos_token_id=checkpoint['bos_token_id'],
            eos_token_id=checkpoint['eos_token_id'],
            unk_token_id=checkpoint['unk_token_id'],
            hidden_size=512,
            num_layers=2
        ).to(self.device)
        base_model.gpt2_model.resize_token_embeddings(len(tokenizer))
        model = VQAModelWithSpatialAdapter(
            base_model=base_model,
            hidden_size=512,
            num_heads=8,  # 8-head spatial attention
            dropout=0.3
        ).to(self.device)
        # strict=False tolerates key mismatches between the saved state dict
        # and the freshly constructed wrapper.
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.eval()
        return model, vocab, tokenizer
|
| 142 |
+
def is_spatial_question(self, question):
|
| 143 |
+
"""
|
| 144 |
+
Classify if a question is spatial using keyword matching.
|
| 145 |
+
Args:
|
| 146 |
+
question: Question string
|
| 147 |
+
Returns:
|
| 148 |
+
bool: True if spatial, False otherwise
|
| 149 |
+
"""
|
| 150 |
+
q_lower = question.lower()
|
| 151 |
+
return any(keyword in q_lower for keyword in self.SPATIAL_KEYWORDS)
|
| 152 |
+
def answer(self, image_path: str, question: str, use_beam_search: bool = True, beam_width: int = 5, verbose: bool = False) -> dict:
    """
    Answer a question by routing to appropriate model.
    Now with Neuro-Symbolic reasoning for common-sense questions!

    The neural VQA answer is always computed and always returned as the
    primary 'answer'; the knowledge-graph pipeline (CLIP object detection
    -> Wikidata -> verbalization) only ever adds supplementary fields.

    Args:
        image_path: Path to image file
        question: Question string
        use_beam_search: Whether to use beam search (better quality)
        beam_width: Beam width for beam search
        verbose: Print routing information
    Returns:
        dict: {
            'answer': str,
            'model_used': 'spatial' or 'base',
            'confidence': float,
            'kg_enhancement': str (optional),
            'reasoning_type': 'neural' or 'neuro-symbolic'
        }
    """
    # Route: spatial questions go to the spatial-adapter model, the rest
    # to the base model (keyword-based classification).
    is_spatial = self.is_spatial_question(question)
    model_used = 'spatial' if is_spatial else 'base'
    if verbose:
        print(f"🔍 Question type: {'Spatial' if is_spatial else 'General'}")
        print(f"🤖 Using: {model_used} model")
    model = self.spatial_model if is_spatial else self.base_model
    # Preprocess the image with the selected model's CLIP transform;
    # unsqueeze adds the batch dimension.
    image = Image.open(image_path).convert('RGB')
    image = model.clip_preprocess(image).unsqueeze(0).to(self.device)
    # Tokenize the question to the model's fixed question length.
    question_tokens = self.tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {
        'input_ids': question_tokens['input_ids'].to(self.device),
        'attention_mask': question_tokens['attention_mask'].to(self.device)
    }
    # Inference only — no gradients needed. Beam search is used when the
    # selected model implements it; otherwise fall back to a plain forward.
    with torch.no_grad():
        if use_beam_search and hasattr(model, 'generate_with_beam_search'):
            generated = model.generate_with_beam_search(
                image, questions, beam_width=beam_width
            )
        else:
            generated = model(image, questions)
    # Always get the neural answer first — it is ALWAYS the primary answer
    if verbose:
        print("📝 Using neural VQA...")
    neural_answer = self.vocab.decoder(generated[0].cpu().numpy())

    # Neuro-symbolic is a *supplement* only — its result goes into kg_enhancement,
    # never replacing the neural answer.
    kg_enhancement = None
    reasoning_type = 'neural'
    objects_detected = []
    question_intent = None
    wikidata_entity = None
    knowledge_source = None

    if self.kg_enabled and self.kg_service:
        if verbose:
            print("🔍 Analyzing question semantics...")
        # The KG service decides (via CLIP question routing) whether this is
        # a knowledge question worth a symbolic lookup.
        # NOTE(review): vqa_confidence is passed as a constant 0.0 here —
        # confirm the service does not rely on a real confidence value.
        should_use_ns = self.kg_service.should_use_neurosymbolic(
            image_features=None,
            question=question,
            vqa_confidence=0.0,
            image_path=image_path
        )
        if should_use_ns:
            if verbose:
                print("🧠 Neuro-Symbolic supplement: detecting subject via CLIP...")

            # CLIP zero-shot: compare image against 80+ concrete noun labels
            # This is much more accurate than asking the VQA model
            detected_objects = self.kg_service.detect_objects_with_clip(
                image_path=image_path, top_k=3
            )

            if verbose:
                print(f"   → CLIP detected: {detected_objects}")
                print("   → Fetching Wikidata facts + Groq verbalization...")

            if detected_objects:
                # detected_objects is converted to a tuple (hashable) —
                # presumably so the service call can be cached; verify.
                ns_result = self.kg_service.answer_with_clip_features(
                    image_features=None,
                    question=question,
                    image_path=image_path,
                    detected_objects=tuple(detected_objects)
                )

                if ns_result:
                    kg_enhancement = ns_result['kg_enhancement']
                    reasoning_type = 'neuro-symbolic'
                    objects_detected = detected_objects  # expose to return dict
                    question_intent = ns_result.get('question_intent')
                    wikidata_entity = ns_result.get('wikidata_entity')
                    knowledge_source = ns_result.get('knowledge_source')
                    if verbose:
                        print(f"✨ Neuro-Symbolic supplement: {kg_enhancement}")
                        print(f"   → Wikidata entity: {wikidata_entity}")
            else:
                if verbose:
                    print("   → CLIP could not identify subject, skipping Wikidata lookup")

    # NOTE(review): confidence is hard-coded to 1.0 — no real confidence is
    # computed anywhere in this method.
    return {
        'answer': neural_answer,
        'model_used': model_used,
        'confidence': 1.0,
        'kg_enhancement': kg_enhancement,
        'reasoning_type': reasoning_type,
        'objects_detected': objects_detected,
        'question_intent': question_intent,
        'wikidata_entity': wikidata_entity,
        'knowledge_source': knowledge_source,
    }
|
| 267 |
+
def answer_conversational(
    self,
    image_path: str,
    question: str,
    session_id: Optional[str] = None,
    use_beam_search: bool = True,
    beam_width: int = 5,
    verbose: bool = False
) -> dict:
    """
    Answer a question inside a multi-turn conversation.

    Resolves pronouns ("it", "this", ...) against the session history,
    delegates to answer(), records the turn, and attaches the session
    context to the returned result.

    Args:
        image_path: Path to the image file.
        question: Question string (may contain pronouns).
        session_id: Existing session to continue, or None for a new one.
        use_beam_search: Whether to decode with beam search.
        beam_width: Beam width for beam search.
        verbose: Print routing / resolution details.

    Returns:
        dict: the answer() result extended with 'session_id',
        'resolved_question' and 'conversation_context'.
    """
    # Without a conversation manager we degrade to single-turn answering.
    if not self.conversation_enabled or not self.conversation_manager:
        single = self.answer(image_path, question, use_beam_search, beam_width, verbose)
        single['session_id'] = None
        single['resolved_question'] = question
        single['conversation_context'] = {'has_context': False}
        return single

    session = self.conversation_manager.get_or_create_session(session_id, image_path)
    sid = session.session_id
    if verbose:
        print(f"💬 Session: {sid}")
        print(f"   Turn number: {len(session.history) + 1}")

    # Substitute pronouns with the entities they refer to, based on history.
    resolved = self.conversation_manager.resolve_references(question, session)
    if verbose and resolved != question:
        print(f"🔄 Pronoun resolution:")
        print(f"   Original: {question}")
        print(f"   Resolved: {resolved}")

    result = self.answer(
        image_path=image_path,
        question=resolved,
        use_beam_search=use_beam_search,
        beam_width=beam_width,
        verbose=verbose
    )

    # Persist this turn (with the ORIGINAL question wording) so later
    # questions can reference it.
    self.conversation_manager.add_turn(
        session_id=sid,
        question=question,
        answer=result['answer'],
        objects_detected=result.get('objects_detected', []),
        reasoning_chain=result.get('reasoning_chain'),
        model_used=result.get('model_used')
    )
    context = self.conversation_manager.get_context_for_question(
        sid,
        question
    )

    result['session_id'] = sid
    result['resolved_question'] = resolved
    result['conversation_context'] = context
    return result
|
| 334 |
+
def _detect_multiple_objects(self, image, vqa_model, top_k=3):
|
| 335 |
+
"""
|
| 336 |
+
Detect the primary subject of the image using neutral, unbiased questions.
|
| 337 |
+
We ask the same question several ways so the VQA model has the best chance
|
| 338 |
+
of identifying the actual subject — never biasing toward food or objects.
|
| 339 |
+
Returns at most top_k unique answers.
|
| 340 |
+
"""
|
| 341 |
+
# Neutral questions — no food bias, no category bias
|
| 342 |
+
detection_questions = [
|
| 343 |
+
"What is the main subject of this image?",
|
| 344 |
+
"What is in this image?",
|
| 345 |
+
"What is shown in this picture?",
|
| 346 |
+
]
|
| 347 |
+
# Tokens we treat as non-answers
|
| 348 |
+
stop_words = {'a', 'an', 'the', 'this', 'that', 'it', 'yes', 'no',
|
| 349 |
+
'some', 'there', 'here', 'image', 'picture', 'photo'}
|
| 350 |
+
detected = []
|
| 351 |
+
for question in detection_questions:
|
| 352 |
+
try:
|
| 353 |
+
question_tokens = self.tokenizer(
|
| 354 |
+
question,
|
| 355 |
+
padding='max_length',
|
| 356 |
+
truncation=True,
|
| 357 |
+
max_length=vqa_model.question_max_len,
|
| 358 |
+
return_tensors='pt'
|
| 359 |
+
)
|
| 360 |
+
questions = {
|
| 361 |
+
'input_ids': question_tokens['input_ids'].to(self.device),
|
| 362 |
+
'attention_mask': question_tokens['attention_mask'].to(self.device)
|
| 363 |
+
}
|
| 364 |
+
with torch.no_grad():
|
| 365 |
+
generated = vqa_model(image, questions)
|
| 366 |
+
answer = self.vocab.decoder(generated[0].cpu().numpy()).strip()
|
| 367 |
+
if (answer
|
| 368 |
+
and answer.lower() not in stop_words
|
| 369 |
+
and answer not in detected):
|
| 370 |
+
detected.append(answer)
|
| 371 |
+
if len(detected) >= top_k:
|
| 372 |
+
break
|
| 373 |
+
except Exception as e:
|
| 374 |
+
print(f" ⚠️ Error detecting objects: {e}")
|
| 375 |
+
continue
|
| 376 |
+
return detected if detected else []
|
| 377 |
+
def batch_answer(self, image_question_pairs, use_beam_search=True, verbose=False):
    """
    Answer a sequence of (image, question) pairs one after another.

    Args:
        image_question_pairs: List of (image_path, question) tuples.
        use_beam_search: Whether to decode with beam search.
        verbose: Print per-item progress.

    Returns:
        list[dict]: one answer() result per input pair, in order.
    """
    total = len(image_question_pairs)
    outcomes = []
    for i, (img_path, q) in enumerate(image_question_pairs):
        if verbose:
            print(f"\n[{i+1}/{total}] Processing...")
        outcomes.append(self.answer(img_path, q, use_beam_search, verbose=verbose))
    return outcomes
|
| 395 |
+
def demo():
    """Run a scripted routing sanity-check over a fixed set of questions."""
    base_ckpt = "./output2/continued_training/vqa_checkpoint.pt"
    spatial_ckpt = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
    image_file = "./im2.jpg"
    vqa = ProductionEnsembleVQA(base_ckpt, spatial_ckpt)
    # (question, expected-to-route-to-spatial-model) pairs
    cases = [
        ("what is to the right of the soup?", True),
        ("what is on the left side?", True),
        ("what is above the table?", True),
        ("what is next to the bowl?", True),
        ("what color is the bowl?", False),
        ("how many items are there?", False),
        ("what room is this?", False),
        ("is there a spoon?", False),
    ]
    print("\n" + "="*80)
    print("🧪 TESTING ENSEMBLE VQA SYSTEM")
    print("="*80)
    print(f"\n📷 Image: {image_file}\n")
    for q, expect_spatial in cases:
        res = vqa.answer(image_file, q, verbose=False)
        went_spatial = res['model_used'] == 'spatial'
        mark = "✓" if went_spatial == expect_spatial else "✗"
        print(f"Q: {q}")
        print(f"A: {res['answer']}")
        print(f"Model: {res['model_used']} {mark}")
        print()
    print("="*80)
    print("✅ Demo complete!")
|
| 425 |
+
def interactive_mode():
    """Simple REPL: prompt for an image path and a question until 'quit'."""
    base_ckpt = "./output2/continued_training/vqa_checkpoint.pt"
    spatial_ckpt = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
    vqa = ProductionEnsembleVQA(base_ckpt, spatial_ckpt)
    print("\n" + "="*80)
    print("🎮 INTERACTIVE MODE")
    print("="*80)
    print("\nCommands:")
    print(" - Enter image path and question")
    print(" - Type 'quit' to exit")
    print("="*80 + "\n")
    while True:
        try:
            img = input("📷 Image path: ").strip()
            if img.lower() == 'quit':
                break
            q = input("❓ Question: ").strip()
            if q.lower() == 'quit':
                break
            res = vqa.answer(img, q, verbose=True)
            print(f"\n💬 Answer: {res['answer']}\n")
            print("-"*80 + "\n")
        except KeyboardInterrupt:
            # Ctrl-C exits cleanly.
            print("\n\n👋 Goodbye!")
            break
        except Exception as e:
            # Bad path / unreadable image etc. — report and keep looping.
            print(f"\n❌ Error: {e}\n")
|
| 453 |
+
if __name__ == "__main__":
|
| 454 |
+
import sys
|
| 455 |
+
if len(sys.argv) > 1 and sys.argv[1] == "interactive":
|
| 456 |
+
interactive_mode()
|
| 457 |
+
else:
|
| 458 |
+
demo()
|
enterprise_architecture.drawio
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<mxGraphModel dx="1800" dy="1100" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1920" pageHeight="1080" math="0" shadow="1">
|
| 3 |
+
<root>
|
| 4 |
+
<mxCell id="0" />
|
| 5 |
+
<mxCell id="1" parent="0" />
|
| 6 |
+
|
| 7 |
+
<mxCell id="bg" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=none;" vertex="1" parent="1">
|
| 8 |
+
<mxGeometry x="-20" y="-20" width="1960" height="1120" as="geometry" />
|
| 9 |
+
</mxCell>
|
| 10 |
+
|
| 11 |
+
<mxCell id="title_bg" value="" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#30363D;" vertex="1" parent="1">
|
| 12 |
+
<mxGeometry x="20" y="20" width="1880" height="70" as="geometry" />
|
| 13 |
+
</mxCell>
|
| 14 |
+
|
| 15 |
+
<mxCell id="title" value="<font style="font-size:24px;font-weight:bold;" color="#58A6FF">Semantic Neuro-Symbolic VQA -- Enterprise Architecture</font><br><font style="font-size:11px;" color="#8B949E">React Native Mobile UI | FastAPI (Uvicorn) | PyTorch | OpenAI CLIP | Wikidata SPARQL | Groq LLM (Llama-3.3-70B-Versatile)</font>" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;" vertex="1" parent="1">
|
| 16 |
+
<mxGeometry x="20" y="20" width="1880" height="70" as="geometry" />
|
| 17 |
+
</mxCell>
|
| 18 |
+
|
| 19 |
+
<!-- ===================== CLIENT LAYER ===================== -->
|
| 20 |
+
<mxCell id="client_layer" value="<font style="font-size:14px;font-weight:bold;" color="#79C0FF">[1] CLIENT LAYER</font>" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
|
| 21 |
+
<mxGeometry x="20" y="110" width="350" height="870" as="geometry" />
|
| 22 |
+
</mxCell>
|
| 23 |
+
|
| 24 |
+
<mxCell id="mobile_label" value="[React Native / Expo]" style="text;html=1;fontSize=20;align=center;fillColor=none;strokeColor=none;fontColor=#58A6FF;" vertex="1" parent="client_layer">
|
| 25 |
+
<mxGeometry x="80" y="38" width="190" height="35" as="geometry" />
|
| 26 |
+
</mxCell>
|
| 27 |
+
|
| 28 |
+
<mxCell id="mobile_app" value="<b>React Native Mobile App</b><br><font color="#8B949E">Expo Framework | iOS and Android</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="client_layer">
|
| 29 |
+
<mxGeometry x="30" y="85" width="290" height="60" as="geometry" />
|
| 30 |
+
</mxCell>
|
| 31 |
+
|
| 32 |
+
<mxCell id="screen_login" value="<b>LoginScreen.js</b><br><font color="#8B949E">Auth | Session Management</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 33 |
+
<mxGeometry x="30" y="165" width="290" height="50" as="geometry" />
|
| 34 |
+
</mxCell>
|
| 35 |
+
|
| 36 |
+
<mxCell id="screen_camera" value="<b>CameraScreen.js</b><br><font color="#8B949E">Image Capture | Upload</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 37 |
+
<mxGeometry x="30" y="225" width="290" height="50" as="geometry" />
|
| 38 |
+
</mxCell>
|
| 39 |
+
|
| 40 |
+
<mxCell id="screen_home" value="<b>HomeScreen.js</b><br><font color="#8B949E">Main Dashboard | History</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 41 |
+
<mxGeometry x="30" y="285" width="290" height="50" as="geometry" />
|
| 42 |
+
</mxCell>
|
| 43 |
+
|
| 44 |
+
<mxCell id="screen_qa" value="<b>QuestionScreen.js</b><br><font color="#8B949E">Q and A Interface | Conversation</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 45 |
+
<mxGeometry x="30" y="345" width="290" height="50" as="geometry" />
|
| 46 |
+
</mxCell>
|
| 47 |
+
|
| 48 |
+
<mxCell id="screen_result" value="<b>ResultScreen.js</b><br><font color="#8B949E">Answer Display | KG Enhancement</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 49 |
+
<mxGeometry x="30" y="405" width="290" height="50" as="geometry" />
|
| 50 |
+
</mxCell>
|
| 51 |
+
|
| 52 |
+
<mxCell id="api_js" value="<b>api.js (API Service)</b><br><font color="#8B949E">Axios | FormData | Session Tokens<br>REST calls to FastAPI backend</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2820;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
|
| 53 |
+
<mxGeometry x="30" y="478" width="290" height="70" as="geometry" />
|
| 54 |
+
</mxCell>
|
| 55 |
+
|
| 56 |
+
<mxCell id="ep1" value="POST /api/answer" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
|
| 57 |
+
<mxGeometry x="30" y="565" width="135" height="30" as="geometry" />
|
| 58 |
+
</mxCell>
|
| 59 |
+
<mxCell id="ep2" value="POST /api/conversation/answer" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
|
| 60 |
+
<mxGeometry x="177" y="565" width="143" height="30" as="geometry" />
|
| 61 |
+
</mxCell>
|
| 62 |
+
<mxCell id="ep3" value="GET /api/models/info" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
|
| 63 |
+
<mxGeometry x="30" y="605" width="135" height="30" as="geometry" />
|
| 64 |
+
</mxCell>
|
| 65 |
+
<mxCell id="ep4" value="GET/DELETE /api/conversation/{id}" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
|
| 66 |
+
<mxGeometry x="177" y="605" width="143" height="30" as="geometry" />
|
| 67 |
+
</mxCell>
|
| 68 |
+
|
| 69 |
+
<mxCell id="client_tech" value="<b>Tech:</b> Expo | React Navigation | Axios | FormData<br><b>Auth:</b> Session tokens | Context API" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#21262D;fontColor=#8B949E;fontSize=10;" vertex="1" parent="client_layer">
|
| 70 |
+
<mxGeometry x="30" y="660" width="290" height="55" as="geometry" />
|
| 71 |
+
</mxCell>
|
| 72 |
+
|
| 73 |
+
<!-- ===================== API GATEWAY LAYER ===================== -->
|
| 74 |
+
<mxCell id="api_layer" value="<font style="font-size:14px;font-weight:bold;" color="#56D364">[2] API GATEWAY LAYER</font>" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#3FB950;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
|
| 75 |
+
<mxGeometry x="400" y="110" width="360" height="870" as="geometry" />
|
| 76 |
+
</mxCell>
|
| 77 |
+
|
| 78 |
+
<mxCell id="apigw_label" value="[FastAPI + Uvicorn]" style="text;html=1;fontSize=20;align=center;fillColor=none;strokeColor=none;fontColor=#3FB950;" vertex="1" parent="api_layer">
|
| 79 |
+
<mxGeometry x="85" y="38" width="190" height="35" as="geometry" />
|
| 80 |
+
</mxCell>
|
| 81 |
+
|
| 82 |
+
<mxCell id="fastapi_main" value="<b>FastAPI Backend (Uvicorn)</b><br><font color="#8B949E">backend_api.py<br>Host: 0.0.0.0 | Port: 8000<br>CORS enabled | Auto-reload dev mode</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#162415;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 83 |
+
<mxGeometry x="20" y="88" width="320" height="80" as="geometry" />
|
| 84 |
+
</mxCell>
|
| 85 |
+
|
| 86 |
+
<mxCell id="startup" value="<b>Startup Event</b><br><font color="#8B949E">Load checkpoints | Init models<br>Init Groq service | Health check</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 87 |
+
<mxGeometry x="20" y="188" width="320" height="60" as="geometry" />
|
| 88 |
+
</mxCell>
|
| 89 |
+
|
| 90 |
+
<mxCell id="ep_health" value="GET /health<br><font color="#8B949E">Model status check</font>" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 91 |
+
<mxGeometry x="20" y="268" width="145" height="50" as="geometry" />
|
| 92 |
+
</mxCell>
|
| 93 |
+
<mxCell id="ep_root" value="GET /<br><font color="#8B949E">API info and docs</font>" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 94 |
+
<mxGeometry x="175" y="268" width="145" height="50" as="geometry" />
|
| 95 |
+
</mxCell>
|
| 96 |
+
<mxCell id="ep_answer" value="POST /api/answer<br><font color="#8B949E">image + question -> JSON answer</font>" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#132D0E;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 97 |
+
<mxGeometry x="20" y="328" width="300" height="50" as="geometry" />
|
| 98 |
+
</mxCell>
|
| 99 |
+
<mxCell id="ep_conv" value="POST /api/conversation/answer<br><font color="#8B949E">Multi-turn | session_id | pronouns</font>" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#132D0E;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 100 |
+
<mxGeometry x="20" y="388" width="300" height="50" as="geometry" />
|
| 101 |
+
</mxCell>
|
| 102 |
+
<mxCell id="ep_hist" value="GET /api/conversation/{id}/history" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 103 |
+
<mxGeometry x="20" y="448" width="300" height="38" as="geometry" />
|
| 104 |
+
</mxCell>
|
| 105 |
+
<mxCell id="ep_del" value="DELETE /api/conversation/{id}" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 106 |
+
<mxGeometry x="20" y="496" width="300" height="38" as="geometry" />
|
| 107 |
+
</mxCell>
|
| 108 |
+
<mxCell id="ep_models" value="GET /api/models/info" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 109 |
+
<mxGeometry x="20" y="544" width="300" height="38" as="geometry" />
|
| 110 |
+
</mxCell>
|
| 111 |
+
|
| 112 |
+
<mxCell id="middleware" value="<b>Middleware</b><br><font color="#8B949E">CORS | Error handling | HTTP 400/503/500</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 113 |
+
<mxGeometry x="20" y="600" width="320" height="50" as="geometry" />
|
| 114 |
+
</mxCell>
|
| 115 |
+
|
| 116 |
+
<mxCell id="conv_manager" value="<b>ConversationManager</b><br><font color="#8B949E">conversation_manager.py<br>Session 30min timeout | Pronoun resolution<br>History storage | Context retrieval</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1A2E;strokeColor=#7B2FBE;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
|
| 117 |
+
<mxGeometry x="20" y="670" width="320" height="80" as="geometry" />
|
| 118 |
+
</mxCell>
|
| 119 |
+
|
| 120 |
+
<!-- ===================== ML INFERENCE ENGINE ===================== -->
|
| 121 |
+
<mxCell id="ml_layer" value="<font style="font-size:14px;font-weight:bold;" color="#FFA657">[3] ML INFERENCE ENGINE</font>" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#D29922;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
|
| 122 |
+
<mxGeometry x="800" y="110" width="380" height="870" as="geometry" />
|
| 123 |
+
</mxCell>
|
| 124 |
+
|
| 125 |
+
<mxCell id="ml_label" value="[PyTorch + CLIP + DistilGPT-2]" style="text;html=1;fontSize=16;align=center;fillColor=none;strokeColor=none;fontColor=#D29922;" vertex="1" parent="ml_layer">
|
| 126 |
+
<mxGeometry x="40" y="38" width="300" height="35" as="geometry" />
|
| 127 |
+
</mxCell>
|
| 128 |
+
|
| 129 |
+
<mxCell id="ensemble_vqa" value="<b>ProductionEnsembleVQA</b><br><font color="#8B949E">ensemble_vqa_app.py<br>Device: CUDA / CPU auto-detect<br>Beam Search width=5 | Top-K Decoding</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#2D2000;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 130 |
+
<mxGeometry x="20" y="88" width="340" height="80" as="geometry" />
|
| 131 |
+
</mxCell>
|
| 132 |
+
|
| 133 |
+
<mxCell id="router" value="<b>Question Router (Keyword Classifier)</b><br><font color="#8B949E">is_spatial_question()<br>Spatial keywords: left, right, above, below, next to...<br>Routes to Base or Spatial model</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1E1E00;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 134 |
+
<mxGeometry x="20" y="188" width="340" height="75" as="geometry" />
|
| 135 |
+
</mxCell>
|
| 136 |
+
|
| 137 |
+
<mxCell id="base_model_box" value="<b>Base VQA Model</b><br><font color="#8B949E">model.py | VQAModel<br>CLIP ViT-B/32 + GPT-2<br>vqa_checkpoint.pt (731 MB)<br>hidden=512 | layers=2 | acc~50%</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#162415;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 138 |
+
<mxGeometry x="20" y="285" width="158" height="120" as="geometry" />
|
| 139 |
+
</mxCell>
|
| 140 |
+
|
| 141 |
+
<mxCell id="spatial_model_box" value="<b>Spatial VQA Model</b><br><font color="#8B949E">model_spatial.py<br>SpatialAdapter + 8-head attn<br>vqa_spatial_checkpoint.pt (739 MB)<br>dropout=0.3 | acc~40%</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 142 |
+
<mxGeometry x="192" y="285" width="168" height="120" as="geometry" />
|
| 143 |
+
</mxCell>
|
| 144 |
+
|
| 145 |
+
<mxCell id="gpt2" value="<b>DistilGPT-2 Tokenizer</b><br><font color="#8B949E">Text tokenization | Vocab<br>BOS / EOS / PAD tokens | Beam search decoding</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 146 |
+
<mxGeometry x="20" y="425" width="340" height="65" as="geometry" />
|
| 147 |
+
</mxCell>
|
| 148 |
+
|
| 149 |
+
<mxCell id="clip_box" value="<b>OpenAI CLIP (ViT-B/32)</b><br><font color="#8B949E">Image encoder + Text encoder<br>Zero-shot object detection (80+ nouns)<br>Question routing: visual vs knowledge<br>Anchor similarity | Softmax x10</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1A0D;strokeColor=#E3B341;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 150 |
+
<mxGeometry x="20" y="508" width="340" height="90" as="geometry" />
|
| 151 |
+
</mxCell>
|
| 152 |
+
|
| 153 |
+
<mxCell id="img_proc" value="<b>Image Preprocessor (PIL)</b><br><font color="#8B949E">JPEG/PNG -> RGB | CLIP preprocess | Tensor</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 154 |
+
<mxGeometry x="20" y="615" width="340" height="55" as="geometry" />
|
| 155 |
+
</mxCell>
|
| 156 |
+
|
| 157 |
+
<mxCell id="pt_files" value="<b>PyTorch Checkpoints (Local Disk)</b><br><font color="#8B949E">vqa_checkpoint.pt (731 MB)<br>vqa_spatial_checkpoint.pt (739 MB)<br>state_dict | vocab | tokenizer config</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#251A00;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
|
| 158 |
+
<mxGeometry x="20" y="688" width="340" height="80" as="geometry" />
|
| 159 |
+
</mxCell>
|
| 160 |
+
|
| 161 |
+
<mxCell id="gpu_badge" value="GPU: CUDA | ~4 GB VRAM | 2x Model Parallel loading" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#D29922;fontColor=#E3B341;fontSize=10;" vertex="1" parent="ml_layer">
|
| 162 |
+
<mxGeometry x="20" y="785" width="340" height="28" as="geometry" />
|
| 163 |
+
</mxCell>
|
| 164 |
+
|
| 165 |
+
<!-- ===================== NEURO-SYMBOLIC PIPELINE ===================== -->
|
| 166 |
+
<mxCell id="ns_layer" value="<font style="font-size:14px;font-weight:bold;" color="#BC8CFF">[4] NEURO-SYMBOLIC PIPELINE</font>" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#8957E5;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
|
| 167 |
+
<mxGeometry x="1220" y="110" width="370" height="870" as="geometry" />
|
| 168 |
+
</mxCell>
|
| 169 |
+
|
| 170 |
+
<mxCell id="ns_label" value="[CLIP + Wikidata SPARQL + Groq LLM]" style="text;html=1;fontSize=14;align=center;fillColor=none;strokeColor=none;fontColor=#8957E5;" vertex="1" parent="ns_layer">
|
| 171 |
+
<mxGeometry x="15" y="38" width="340" height="35" as="geometry" />
|
| 172 |
+
</mxCell>
|
| 173 |
+
|
| 174 |
+
<mxCell id="ns_main" value="<b>SemanticNeurosymbolicVQA</b><br><font color="#8B949E">semantic_neurosymbolic_vqa.py<br>Neural -> Symbolic -> Verbalize pipeline</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A0D2E;strokeColor=#8957E5;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 175 |
+
<mxGeometry x="20" y="88" width="330" height="65" as="geometry" />
|
| 176 |
+
</mxCell>
|
| 177 |
+
|
| 178 |
+
<mxCell id="ns_step1" value="<b>Step 1: CLIP Routing</b><br><font color="#8B949E">should_use_neurosymbolic()<br>VISUAL anchor vs KNOWLEDGE anchor<br>Temperature softmax x10</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D1A30;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 179 |
+
<mxGeometry x="20" y="173" width="330" height="78" as="geometry" />
|
| 180 |
+
</mxCell>
|
| 181 |
+
|
| 182 |
+
<mxCell id="route_decision" value="VISUAL question?<br>-> Neural VQA only<br>KNOWLEDGE question?<br>-> Neuro-Symbolic" style="rhombus;whiteSpace=wrap;html=1;fillColor=#21262D;strokeColor=#8957E5;fontColor=#FFFFFF;fontSize=10;" vertex="1" parent="ns_layer">
|
| 183 |
+
<mxGeometry x="75" y="268" width="220" height="88" as="geometry" />
|
| 184 |
+
</mxCell>
|
| 185 |
+
|
| 186 |
+
<mxCell id="ns_step2" value="<b>Step 2: CLIP Object Detection</b><br><font color="#8B949E">detect_objects_with_clip()<br>80+ noun vocabulary | Top-3 objects<br>Cosine similarity | prompt: 'a photo of a {label}'</font>" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D1A30;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 187 |
+
<mxGeometry x="20" y="375" width="330" height="80" as="geometry" />
|
| 188 |
+
</mxCell>
|
| 189 |
+
|
| 190 |
+
<mxCell id="wikidata_box" value="<b>Step 3: WikidataKnowledgeBase</b><br><font color="#8B949E">SPARQL: query.wikidata.org<br>P31 (category) | P186 (material) | P366 (uses)<br>P2101 (melting pt) | P2054 (density)<br>lru_cache(500) | timeout=10s</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0D2E2E;strokeColor=#2EA8A8;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 191 |
+
<mxGeometry x="20" y="473" width="330" height="100" as="geometry" />
|
| 192 |
+
</mxCell>
|
| 193 |
+
|
| 194 |
+
<mxCell id="groq_box" value="<b>Step 4: Groq LLM Verbalizer</b><br><font color="#8B949E">WikidataGroqAnswerer<br>Model: llama-3.3-70b-versatile<br>Temp=0.1 | max_tokens=180 | top_p=0.9<br>Answers ONLY from Wikidata facts</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2B1A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 195 |
+
<mxGeometry x="20" y="592" width="330" height="95" as="geometry" />
|
| 196 |
+
</mxCell>
|
| 197 |
+
|
| 198 |
+
<mxCell id="groq_access" value="<b>Groq Accessibility Service</b><br><font color="#8B949E">groq_service.py | GroqDescriptionService<br>2-sentence narrations for blind users<br>Temp=0.7 | max_tokens=150</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2B1A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
|
| 199 |
+
<mxGeometry x="20" y="706" width="330" height="85" as="geometry" />
|
| 200 |
+
</mxCell>
|
| 201 |
+
|
| 202 |
+
<mxCell id="groq_badge" value="Groq API | Llama-3.3-70B-Versatile | GROQ_API_KEY env var" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#F85149;fontColor=#F85149;fontSize=10;" vertex="1" parent="ns_layer">
|
| 203 |
+
<mxGeometry x="20" y="808" width="330" height="28" as="geometry" />
|
| 204 |
+
</mxCell>
|
| 205 |
+
|
| 206 |
+
<!-- ===================== EXTERNAL SERVICES ===================== -->
|
| 207 |
+
<mxCell id="wikidata_ext" value="<b>Wikidata SPARQL API</b><br><font color="#8B949E">query.wikidata.org/sparql<br>wikidata.org/w/api.php<br>Entity lookup | Property values<br>Free and Open Knowledge Base</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0A2525;strokeColor=#2EA8A8;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
|
| 208 |
+
<mxGeometry x="1640" y="200" width="250" height="130" as="geometry" />
|
| 209 |
+
</mxCell>
|
| 210 |
+
|
| 211 |
+
<mxCell id="groq_cloud" value="<b>Groq Cloud API</b><br><font color="#8B949E">api.groq.com<br>Llama-3.3-70B-Versatile<br>Ultra-low latency inference<br>chat.completions endpoint</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A0A0A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
|
| 212 |
+
<mxGeometry x="1640" y="385" width="250" height="130" as="geometry" />
|
| 213 |
+
</mxCell>
|
| 214 |
+
|
| 215 |
+
<mxCell id="hf_clip" value="<b>OpenAI / HuggingFace Hub</b><br><font color="#8B949E">CLIP ViT-B/32 weights<br>GPT-2 / DistilGPT-2 tokenizer<br>Cached locally after first download</font>" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1000;strokeColor=#E3B341;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
|
| 216 |
+
<mxGeometry x="1640" y="565" width="250" height="105" as="geometry" />
|
| 217 |
+
</mxCell>
|
| 218 |
+
|
| 219 |
+
<!-- ===================== LEGEND ===================== -->
|
| 220 |
+
<mxCell id="legend" value="<b>LEGEND</b><br>[1] Blue = Client Layer (React Native)<br>[2] Green = API Gateway (FastAPI)<br>[3] Orange = ML Inference (PyTorch)<br>[4] Purple = Neuro-Symbolic Pipeline<br>Solid arrow = Primary data flow<br>Dashed arrow = Conditional / supplement<br>Animated = Live request flow" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#30363D;fontColor=#8B949E;fontSize=11;align=left;" vertex="1" parent="1">
|
| 221 |
+
<mxGeometry x="1640" y="710" width="250" height="155" as="geometry" />
|
| 222 |
+
</mxCell>
|
| 223 |
+
|
| 224 |
+
<!-- ===================== EDGES / ANIMATED FLOWS ===================== -->
|
| 225 |
+
|
| 226 |
+
<!-- 1. api.js -> FastAPI (HTTP REST) -->
|
| 227 |
+
<mxCell id="flow_1" value="<font color="#3FB950">HTTP REST (JSON/FormData)</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#3FB950;strokeWidth=3;fontSize=10;fontColor=#3FB950;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="api_js" target="fastapi_main">
|
| 228 |
+
<mxGeometry relative="1" as="geometry" />
|
| 229 |
+
</mxCell>
|
| 230 |
+
|
| 231 |
+
<!-- 2. FastAPI -> Ensemble VQA -->
|
| 232 |
+
<mxCell id="flow_2" value="<font color="#FFA657">answer()</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#D29922;strokeWidth=3;fontSize=10;fontColor=#FFA657;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="fastapi_main" target="ensemble_vqa">
|
| 233 |
+
<mxGeometry relative="1" as="geometry" />
|
| 234 |
+
</mxCell>
|
| 235 |
+
|
| 236 |
+
<!-- 3. Ensemble -> Router -->
|
| 237 |
+
<mxCell id="flow_3" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#D29922;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ensemble_vqa" target="router">
|
| 238 |
+
<mxGeometry relative="1" as="geometry" />
|
| 239 |
+
</mxCell>
|
| 240 |
+
|
| 241 |
+
<!-- 4a. Router -> Base Model -->
|
| 242 |
+
<mxCell id="flow_4a" value="<font color="#3FB950">General Q</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#3FB950;strokeWidth=2;animation=1;endArrow=block;endFill=1;fontSize=10;fontColor=#3FB950;" edge="1" parent="1" source="router" target="base_model_box">
|
| 243 |
+
<mxGeometry relative="1" as="geometry" />
|
| 244 |
+
</mxCell>
|
| 245 |
+
|
| 246 |
+
<!-- 4b. Router -> Spatial Model -->
|
| 247 |
+
<mxCell id="flow_4b" value="<font color="#58A6FF">Spatial Q</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#1F6FEB;strokeWidth=2;animation=1;endArrow=block;endFill=1;fontSize=10;fontColor=#58A6FF;" edge="1" parent="1" source="router" target="spatial_model_box">
|
| 248 |
+
<mxGeometry relative="1" as="geometry" />
|
| 249 |
+
</mxCell>
|
| 250 |
+
|
| 251 |
+
<!-- 5. Ensemble -> NS Pipeline (supplement) -->
|
| 252 |
+
<mxCell id="flow_5" value="<font color="#BC8CFF">NS supplement</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#8957E5;strokeWidth=3;fontSize=10;fontColor=#BC8CFF;animation=1;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ensemble_vqa" target="ns_main">
|
| 253 |
+
<mxGeometry relative="1" as="geometry" />
|
| 254 |
+
</mxCell>
|
| 255 |
+
|
| 256 |
+
<!-- 6. NS main -> CLIP Routing -->
|
| 257 |
+
<mxCell id="flow_6" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_main" target="ns_step1">
|
| 258 |
+
<mxGeometry relative="1" as="geometry" />
|
| 259 |
+
</mxCell>
|
| 260 |
+
|
| 261 |
+
<!-- 7. CLIP Routing -> Decision diamond -->
|
| 262 |
+
<mxCell id="flow_7" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_step1" target="route_decision">
|
| 263 |
+
<mxGeometry relative="1" as="geometry" />
|
| 264 |
+
</mxCell>
|
| 265 |
+
|
| 266 |
+
<!-- 8. Decision -> Object Detection -->
|
| 267 |
+
<mxCell id="flow_8" value="<font color="#BC8CFF">Knowledge Q</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;fontSize=10;fontColor=#BC8CFF;" edge="1" parent="1" source="route_decision" target="ns_step2">
|
| 268 |
+
<mxGeometry relative="1" as="geometry" />
|
| 269 |
+
</mxCell>
|
| 270 |
+
|
| 271 |
+
<!-- 9. Object Detection -> Wikidata box -->
|
| 272 |
+
<mxCell id="flow_9" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#2EA8A8;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_step2" target="wikidata_box">
|
| 273 |
+
<mxGeometry relative="1" as="geometry" />
|
| 274 |
+
</mxCell>
|
| 275 |
+
|
| 276 |
+
<!-- 10. Wikidata box -> Wikidata external API -->
|
| 277 |
+
<mxCell id="flow_10" value="<font color="#2EA8A8">SPARQL queries</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#2EA8A8;strokeWidth=3;fontSize=10;fontColor=#2EA8A8;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="wikidata_box" target="wikidata_ext">
|
| 278 |
+
<mxGeometry relative="1" as="geometry" />
|
| 279 |
+
</mxCell>
|
| 280 |
+
|
| 281 |
+
<!-- 11. Wikidata facts -> Groq verbalizer -->
|
| 282 |
+
<mxCell id="flow_11" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="wikidata_box" target="groq_box">
|
| 283 |
+
<mxGeometry relative="1" as="geometry" />
|
| 284 |
+
</mxCell>
|
| 285 |
+
|
| 286 |
+
<!-- 12. Groq box -> Groq Cloud -->
|
| 287 |
+
<mxCell id="flow_12" value="<font color="#F85149">API call | Llama-3.3-70B</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=3;fontSize=10;fontColor=#F85149;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="groq_box" target="groq_cloud">
|
| 288 |
+
<mxGeometry relative="1" as="geometry" />
|
| 289 |
+
</mxCell>
|
| 290 |
+
|
| 291 |
+
<!-- 13. Groq accessibility -> Groq Cloud -->
|
| 292 |
+
<mxCell id="flow_13" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="groq_access" target="groq_cloud">
|
| 293 |
+
<mxGeometry relative="1" as="geometry" />
|
| 294 |
+
</mxCell>
|
| 295 |
+
|
| 296 |
+
<!-- 14. FastAPI -> Groq Accessibility (top arc) -->
|
| 297 |
+
<mxCell id="flow_14" value="<font color="#F85149">accessibility narration</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;fontSize=10;fontColor=#F85149;animation=1;dashed=1;endArrow=block;endFill=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="fastapi_main" target="groq_access">
|
| 298 |
+
<mxGeometry relative="1" as="geometry">
|
| 299 |
+
<Array as="points">
|
| 300 |
+
<mxPoint x="580" y="140" />
|
| 301 |
+
<mxPoint x="1385" y="140" />
|
| 302 |
+
</Array>
|
| 303 |
+
</mxGeometry>
|
| 304 |
+
</mxCell>
|
| 305 |
+
|
| 306 |
+
<!-- 15. CLIP box -> HuggingFace (model weights) -->
|
| 307 |
+
<mxCell id="flow_15" value="<font color="#E3B341">model weights (cached)</font>" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#E3B341;strokeWidth=2;fontSize=10;fontColor=#E3B341;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="clip_box" target="hf_clip">
|
| 308 |
+
<mxGeometry relative="1" as="geometry" />
|
| 309 |
+
</mxCell>
|
| 310 |
+
|
| 311 |
+
<!-- 16a. Base model -> GPT2 Tokenizer -->
|
| 312 |
+
<mxCell id="flow_16a" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#30363D;strokeWidth=1;endArrow=block;endFill=1;" edge="1" parent="1" source="base_model_box" target="gpt2">
|
| 313 |
+
<mxGeometry relative="1" as="geometry" />
|
| 314 |
+
</mxCell>
|
| 315 |
+
|
| 316 |
+
<!-- 16b. Spatial model -> GPT2 Tokenizer -->
|
| 317 |
+
<mxCell id="flow_16b" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#30363D;strokeWidth=1;endArrow=block;endFill=1;" edge="1" parent="1" source="spatial_model_box" target="gpt2">
|
| 318 |
+
<mxGeometry relative="1" as="geometry" />
|
| 319 |
+
</mxCell>
|
| 320 |
+
|
| 321 |
+
<!-- 17. Conv Manager <-> Ensemble VQA -->
|
| 322 |
+
<mxCell id="flow_17" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#7B2FBE;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;startArrow=block;startFill=1;" edge="1" parent="1" source="conv_manager" target="ensemble_vqa">
|
| 323 |
+
<mxGeometry relative="1" as="geometry" />
|
| 324 |
+
</mxCell>
|
| 325 |
+
|
| 326 |
+
<!-- ===================== PHASE ANNOTATIONS ===================== -->
|
| 327 |
+
<mxCell id="ann1" value="(1) User uploads image + question" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#58A6FF;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
|
| 328 |
+
<mxGeometry x="100" y="988" width="250" height="28" as="geometry" />
|
| 329 |
+
</mxCell>
|
| 330 |
+
<mxCell id="ann2" value="(2) REST API routes to ensemble" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#3FB950;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
|
| 331 |
+
<mxGeometry x="460" y="988" width="240" height="28" as="geometry" />
|
| 332 |
+
</mxCell>
|
| 333 |
+
<mxCell id="ann3" value="(3) Neural model answers question" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#FFA657;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
|
| 334 |
+
<mxGeometry x="860" y="988" width="250" height="28" as="geometry" />
|
| 335 |
+
</mxCell>
|
| 336 |
+
<mxCell id="ann4" value="(4) Symbolic + Groq enriches answer" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#BC8CFF;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
|
| 337 |
+
<mxGeometry x="1270" y="988" width="260" height="28" as="geometry" />
|
| 338 |
+
</mxCell>
|
| 339 |
+
|
| 340 |
+
</root>
|
| 341 |
+
</mxGraphModel>
|
exp_results/feature_extraction_metric.csv
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
epoch,train_loss,train_token_acc,val_loss,val_token_acc,val_exact_match,lr
|
| 2 |
+
1,3.687392619148223,0.5010925703618669,2.6377785576964325,0.531001718679689,0.0625462073044507,0.0001
|
| 3 |
+
2,3.0861334211370917,0.5492582896593264,2.1873294205035805,0.5735707690693298,0.1437971314505397,0.0001
|
| 4 |
+
3,2.8613873058208554,0.5773015727241105,2.0188139058508963,0.5919563278274717,0.18172408694366404,0.0001
|
| 5 |
+
4,2.737266832805117,0.5940482014385925,1.8989913845961948,0.6057449292461827,0.2079698358716546,0.0001
|
| 6 |
+
5,2.64607786060719,0.6068536389304081,1.8126546847370435,0.6131761748835726,0.22467839716102322,0.0001
|
| 7 |
+
6,2.5737654996439945,0.6159161500927967,1.745610311908542,0.6227055006432083,0.23806003252994234,0.0001
|
| 8 |
+
7,2.514629547727101,0.6238974923921153,1.6846065549355633,0.6310539678582605,0.25521218394203754,0.0001
|
| 9 |
+
8,2.467853448716654,0.630066124487741,1.6530387682734795,0.6351331795723933,0.2616442407215733,0.0001
|
| 10 |
+
9,2.430272235876001,0.6363310434633568,1.6044414886888467,0.6438829395568596,0.2796096406920006,0.0001
|
| 11 |
+
10,2.3940254725485,0.6410929495099732,1.5768477393771119,0.6476609546620891,0.2876681945882005,0.0001
|
| 12 |
+
11,2.3626844231579023,0.6466396824626934,1.553934060740021,0.6507747072093891,0.2935087978707674,0.0001
|
| 13 |
+
12,2.3347287295768417,0.6508579807194079,1.5344560882955227,0.6529503009229336,0.29957119621469763,0.0001
|
| 14 |
+
13,2.309176077580466,0.6551987208042674,1.5069528773145855,0.6592958943461472,0.3086647937305929,0.0001
|
| 15 |
+
14,2.2852324938224235,0.6583507632729854,1.4877223473674845,0.6627878375210852,0.31820198136921485,0.0001
|
| 16 |
+
15,2.265477722738707,0.6621552250710977,1.4731922914397042,0.6635274037999926,0.3206417270442111,0.0001
|
| 17 |
+
16,2.245406344189297,0.6660276569959188,1.454425812892194,0.6657813076140746,0.3254472867070827,1e-06
|
| 18 |
+
17,2.2047869251156476,0.6741207528932076,1.4267255866302635,0.6736559963451242,0.3408990093153926,1e-06
|
| 19 |
+
18,2.173899897451869,0.6801777819710184,1.4036545191171035,0.6780021879470574,0.34703533934644387,1e-06
|
| 20 |
+
19,2.15051551812644,0.6852958937991237,1.3850691127327253,0.6806749330376679,0.3535413278131007,1e-06
|
| 21 |
+
20,2.130151925532512,0.6903713528113137,1.3759601954019294,0.682907020145992,0.3590862043471832,1e-06
|
| 22 |
+
21,2.111327923803482,0.6937075932303665,1.3607378039719924,0.6867363317957464,0.3650746710039923,1e-06
|
| 23 |
+
22,2.092705831874552,0.6989087903379759,1.3529389587775715,0.6871686296642951,0.3676622800532308,1e-06
|
| 24 |
+
23,2.0762000757163266,0.7018636832358497,1.3471845992893543,0.6889090611124938,0.3711370693479225,1e-06
|
| 25 |
+
24,2.0588077032516723,0.7061800249295429,1.3332587570514318,0.6925943864966339,0.37853023806003255,1e-06
|
| 26 |
+
25,2.043530640342685,0.7086816234112068,1.323614944545728,0.6927403596774587,0.3790477598698802,1e-06
|
| 27 |
+
26,2.028976038177644,0.7119645012827895,1.321273627989697,0.6960837739818501,0.38511015821381045,1e-06
|
| 28 |
+
27,2.0125017191516372,0.7166598519934908,1.3151825143481202,0.6966083350608934,0.38651486026911136,1e-06
|
| 29 |
+
28,1.998029633995205,0.7198163744333156,1.3046240308937036,0.6980289071798325,0.38836315244713887,1e-06
|
| 30 |
+
29,1.9832194559959038,0.7228894007410402,1.3061683574375116,0.6981341627971182,0.3905811030607719,1e-06
|
| 31 |
+
30,1.96923152904127,0.7272438684699805,1.3041821732273642,0.6986926667532831,0.3902114446251663,1e-06
|
experiments/__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (21.6 kB). View file
|
|
|
experiments/test.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from transformers import GPT2Tokenizer
|
| 5 |
+
from model import VQAModel
|
| 6 |
+
from train import Vocab
|
| 7 |
+
def load_model(checkpoint_path, device='cuda'):
    """Restore a trained VQAModel together with its answer vocab and tokenizer.

    Args:
        checkpoint_path: Path to a checkpoint saved by train.save_checkpoint
            (must contain model_state_dict, vocab tables and special-token ids).
        device: Torch device string to load the model onto (default 'cuda').

    Returns:
        (model, vocab, tokenizer) — the model is left in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints produced by this project's own training scripts.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Rebuild the answer-side vocabulary from the serialized tables.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    # Recreate the model with the hyperparameters used in training; sequence
    # lengths fall back to defaults for older checkpoints lacking those keys.
    model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    # The question tokenizer must match training: distilgpt2 with an added
    # [PAD] token, so the embedding table is resized before weights load.
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(tokenizer))
    # strict=False tolerates minor state-dict key mismatches between training
    # runs; mismatched keys are skipped silently — a known trade-off here.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
|
| 37 |
+
def answer_question(model, vocab, tokenizer, image_path, question, device='cuda', use_beam_search=True, beam_width=5, temperature=0.8):
    """Run VQA inference for one (image, question) pair and return the answer text.

    Args:
        model: Loaded VQAModel (see load_model).
        vocab: Vocab used to decode generated token ids back to words.
        tokenizer: GPT2 tokenizer used to encode the question.
        image_path: Path of the image file to query.
        question: Natural-language question string.
        device: Torch device string.
        use_beam_search: Use model.generate_with_beam_search when the model
            exposes it; otherwise fall back to a plain forward call.
        beam_width: Beam size for beam search.
        temperature: NOTE(review): accepted but never used in this function —
            confirm whether a sampling temperature should be wired through.

    Returns:
        The decoded answer string for the first batch element.
    """
    image = Image.open(image_path).convert('RGB')
    # CLIP's own preprocessing pipeline; unsqueeze adds a batch dim of 1.
    image = model.clip_preprocess(image).unsqueeze(0).to(device)
    question_tokens = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {
        'input_ids': question_tokens['input_ids'].to(device),
        'attention_mask': question_tokens['attention_mask'].to(device)
    }
    with torch.no_grad():
        if use_beam_search and hasattr(model, 'generate_with_beam_search'):
            generated = model.generate_with_beam_search(image, questions, beam_width=beam_width)
        else:
            # Fallback path: forward with no answer ids is assumed to return
            # generated token ids — TODO confirm against VQAModel.forward.
            generated = model(image, questions)
    # Decode only the first (and only) sequence in the batch.
    answer = vocab.decoder(generated[0].cpu().numpy())
    return answer
|
| 58 |
+
CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
IMAGE_PATH = r"./im2.jpg"
QUESTION = ""


def _run_demo():
    """Load the spatial checkpoint and answer each demo question for IMAGE_PATH."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Loading model...")
    model, vocab, tokenizer = load_model(CHECKPOINT, device)
    print("Model loaded!\n")
    test_questions = [
        "What is to the right of the soup?"
    ]
    print(f"Image: {IMAGE_PATH}\n")
    for q in test_questions:
        print(f"Question: {q}")
        answer = answer_question(model, vocab, tokenizer, IMAGE_PATH, q, device, use_beam_search=True, beam_width=5)
        print(f"Answer: {answer}\n")


if __name__ == "__main__":
    _run_demo()
|
experiments/train.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import GPT2Tokenizer
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from collections import Counter
|
| 12 |
+
from nltk.tokenize import word_tokenize
|
| 13 |
+
from sklearn.model_selection import train_test_split
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
from models.model import VQAModel
|
| 16 |
+
# Select the compute device once at import time. Falling back to CPU keeps
# the training script runnable on machines without a CUDA-capable GPU
# (the hard-coded 'cuda' previously crashed tensor .to() calls there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 17 |
+
class Vocab:
    """Word-level answer vocabulary with encode/decode helpers.

    After build_vocab() the special tokens occupy the first four indices:
    <pad>=0, <bos>=1, <eos>=2, <unk>=3. The token-id attributes may instead
    be assigned directly when restoring a vocabulary from a checkpoint.
    """

    def __init__(self):
        self.vocab = None       # list of all tokens (specials first) once built
        self.vocab_size = None
        self.word2idx = None
        self.idx2word = None
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'
        # Fix: define the special-token ids up front so a fresh Vocab does
        # not raise AttributeError if encoder()/decoder() is reached before
        # build_vocab() (or a checkpoint restore) populates them.
        self.pad_token_id = None
        self.bos_token_id = None
        self.eos_token_id = None
        self.unk_token_id = None

    def build_vocab(self, df, min_freq=1):
        """Build the vocabulary from the 'answer' column of df.

        Args:
            df: DataFrame with an 'answer' column of strings.
            min_freq: Minimum occurrence count for a word to be kept.

        Requires NLTK's 'punkt' tokenizer data to be available.
        """
        counter = Counter()
        for ans in df['answer']:
            tokens = word_tokenize(ans.lower())
            counter.update(tokens)
        # Deterministic ordering: specials first, then sorted content words.
        vocab = sorted([word for word, freq in counter.items() if freq >= min_freq])
        vocab = [self.pad, self.bos, self.eos, self.unk] + vocab
        word2idx = {word: idx for idx, word in enumerate(vocab)}
        idx2word = {idx: word for word, idx in word2idx.items()}
        self.vocab = vocab
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.vocab_size = len(vocab)
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Tokenize text into a fixed-length id sequence.

        The sequence is wrapped in <bos>/<eos>, padded with <pad> or
        truncated to exactly max_len ids.
        NOTE(review): truncation can drop the trailing <eos> — confirm the
        decoder side tolerates sequences that never emit <eos>.
        """
        tokens = word_tokenize(text.lower())
        token_ids = [self.word2idx.get(token, self.unk_token_id) for token in tokens]
        token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
        if len(token_ids) < max_len:
            token_ids += [self.pad_token_id] * (max_len - len(token_ids))
        else:
            token_ids = token_ids[:max_len]
        return token_ids

    def decoder(self, token_ids):
        """Convert ids back to a space-joined string.

        Stops at the first <eos>, skips <pad>/<bos>, and maps any id not in
        the vocabulary to the literal '<unk>'.
        """
        tokens = []
        for idx in token_ids:
            if idx == self.eos_token_id:
                break
            if idx in (self.pad_token_id, self.bos_token_id):
                continue
            tokens.append(self.idx2word.get(idx, "<unk>"))
        return ' '.join(tokens).strip()
|
| 62 |
+
class AugmentedVQADataset(Dataset):
    """VQA dataset yielding (image tensor, question ids, answer ids) samples.

    Images are loaded from img_dir using the 'image_path' column of df;
    optional light augmentation is applied before CLIP preprocessing.
    """

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        self.text_processor = text_processor
        self.clip_processor = clip_processor
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        # Build the augmentation pipeline only when requested; otherwise
        # leave it unset so __getitem__ skips the transform step.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
        ]) if augment else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        sample = self.df.iloc[idx]
        full_path = os.path.join(self.img_dir, sample['image_path'])
        img = Image.open(full_path).convert('RGB')
        if self.augment and self.transform:
            img = self.transform(img)
        # Fixed-length question encoding with padding/truncation.
        tokenized = self.question_tokenizer(
            sample['question'],
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        target_ids = self.text_processor.encoder(sample['answer'], max_len=self.answer_max_len)
        return {
            'image_path': full_path,
            'image': self.clip_processor(img),
            'question_ids': tokenized['input_ids'].squeeze(0),
            'question_mask': tokenized['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(target_ids, dtype=torch.long)
        }
|
| 107 |
+
def save_checkpoint(model, optimizer, epoch, vocab, path):
    """Serialize model/optimizer state plus everything needed to reload.

    The vocabulary tables and special-token ids are saved alongside the
    weights so inference can rebuild the Vocab without the training data.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    state['vocab'] = vocab.vocab
    state['word2idx'] = vocab.word2idx
    state['idx2word'] = vocab.idx2word
    state['pad_token_id'] = vocab.pad_token_id
    state['bos_token_id'] = vocab.bos_token_id
    state['eos_token_id'] = vocab.eos_token_id
    state['unk_token_id'] = vocab.unk_token_id
    # Sequence lengths are stored so load-time model construction matches.
    state['question_max_len'] = model.question_max_len
    state['answer_max_len'] = model.answer_max_len
    torch.save(state, path)
|
| 122 |
+
def plot_losses(train_losses, val_losses, save_path="loss_plot.png"):
    """Save a PNG comparing per-epoch train vs validation loss.

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
        save_path: Output image path (default "loss_plot.png").
    """
    plt.figure(figsize=(8,6))
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train vs Validation Loss")
    plt.legend()
    plt.savefig(save_path)
    # Close the figure to free memory when called once per epoch.
    plt.close()
|
| 132 |
+
def train_one_epoch(model, dataloader, optimizer, device, scaler, vocab):
    """Run one mixed-precision (AMP) training epoch.

    Args:
        model: VQA model; called with answer ids it returns per-position
            vocabulary logits (teacher forcing).
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask' and 'answer_ids' tensors.
        optimizer: optimizer over the trainable parameters.
        device: 'cuda' or 'cpu' device string, also used for autocast.
        scaler: torch.amp.GradScaler performing loss scaling.
        vocab: answer vocabulary; only pad_token_id is used here.

    Returns:
        (avg_loss, avg_token_acc): batch-averaged loss and pad-masked
        next-token accuracy over the epoch.
    """
    model.train()
    total_loss = 0
    total_token_acc = 0
    # Padding positions are excluded from the loss; label smoothing
    # regularizes the decoder.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Shift so position t predicts token t+1.
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
        # Token accuracy counts only non-padding target positions.
        predicted_tokens = shifted_logits.argmax(dim=-1)
        correct = (predicted_tokens == shifted_answers).float()
        mask = (shifted_answers != vocab.pad_token_id).float()
        token_acc = (correct * mask).sum() / mask.sum()
        total_token_acc += token_acc.item()
        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 norm threshold applies to the
        # true (unscaled) gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    return avg_loss, avg_token_acc
|
| 167 |
+
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate one epoch without gradients.

    Computes the teacher-forced loss and pad-masked token accuracy, then
    runs free-running generation (beam search when the model supports it)
    and scores decoded strings by exact match.

    Args:
        model: VQA model; with answer ids it returns logits, without them
            it returns generated token id sequences.
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask' and 'answer_ids' tensors.
        device: 'cuda' or 'cpu' device string.
        vocab: answer vocabulary providing pad_token_id and decoder().

    Returns:
        (avg_loss, avg_token_acc, exact_match_acc). All divisions are
        guarded, so an empty dataloader yields zeros instead of raising.
    """
    model.eval()
    total_loss = 0
    total_token_acc = 0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            # Teacher-forced pass for loss / token accuracy.
            logits = model(images, questions, answer_input_ids=answers)
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            total_loss += loss.item()
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            # clamp guards against an all-padding batch producing NaN.
            token_acc = (correct * mask).sum() / mask.sum().clamp(min=1.0)
            total_token_acc += token_acc.item()
            # Free-running generation for exact match; prefer beam search
            # when the model implements it.
            if hasattr(model, 'generate_with_beam_search'):
                generated = model.generate_with_beam_search(images, questions, beam_width=3)
            else:
                generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    # Guard all denominators: an empty loader previously raised
    # ZeroDivisionError here.
    num_batches = max(len(dataloader), 1)
    avg_loss = total_loss / num_batches
    avg_token_acc = total_token_acc / num_batches
    exact_match_acc = exact_matches / total_samples if total_samples else 0.0
    return avg_loss, avg_token_acc, exact_match_acc
|
| 209 |
+
def main():
    """Stage-wise VQA training.

    Stage 1 trains the decoder head with frozen CLIP/GPT-2 encoders;
    after epoch 16 the top encoder layers are unfrozen and a new optimizer
    with tiny encoder learning rates takes over (stage 2). Checkpointing,
    LR scheduling and early stopping all key off validation exact match.
    """
    print()
    print("# VQA: Training with Staged Unfreezing")
    print()
    import random
    import numpy as np
    # Fix every RNG source for reproducibility.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    OUTPUT_DIR = r"./output2/feature_extraction"
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters.
    batch_size = 64
    learning_rate = 1e-4
    num_epochs = 30
    patience = 8
    question_max_len = 20
    answer_max_len = 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    metadata = pd.read_csv(CSV_PATH)
    print(f"Using: question_max_len={question_max_len}, answer_max_len={answer_max_len}")
    vocab = Vocab()
    vocab.build_vocab(metadata, min_freq=3)
    answer_vocab_size = len(vocab.vocab)
    print(f"Answer Vocab Size: {answer_vocab_size}")
    # Corpus sanity check: most frequent answer tokens.
    word_freq = Counter()
    for ans in metadata['answer']:
        tokens = word_tokenize(ans.lower())
        word_freq.update(tokens)
    print("\nTop 20 most common answer words:")
    for word, freq in word_freq.most_common(20):
        print(f"  {word}: {freq}")
    # 80/10/10 train/val/test split.
    train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"\nTrain size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    print()
    model = VQAModel(
        vocab_size=answer_vocab_size,
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=vocab.pad_token_id,
        bos_token_id=vocab.bos_token_id,
        eos_token_id=vocab.eos_token_id,
        unk_token_id=vocab.unk_token_id,
        hidden_size=512,
        num_layers=2
    ).to(device)
    print("STAGE 1: Training decoder with frozen encoders")
    print()
    clip_processor = model.clip_preprocess
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; add one and grow the embeddings.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=True
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=1e-4)
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    print()
    scaler = torch.amp.GradScaler(device)
    best_val_exact_match = 0.0
    counter = 0
    logs = []
    # Scheduler maximizes exact-match accuracy.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=4, verbose=True
    )
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, scaler, vocab)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f} | Val Exact Match: {val_exact_match:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_exact_match)
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, CHECKPOINT_PATH)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement in exact match for {counter} epochs.")
        if epoch == 15 and not model.fine_tuning_mode:
            print("\n" + "="*50)
            print("STAGE 2: Unfreezing encoders for fine-tuning")
            print("="*50)
            model.unfreeze_clip_layers(num_layers=3)
            model.unfreeze_gpt2_layers(num_layers=3)
            clip_params = []
            gpt2_params = []
            other_params = []
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if 'clip_model' in name:
                        clip_params.append(param)
                    elif 'gpt2_model' in name:
                        gpt2_params.append(param)
                    else:
                        other_params.append(param)
            # Tiny LR for the pretrained encoders, larger LR for the head.
            optimizer = torch.optim.AdamW([
                {'params': clip_params, 'lr': 1e-6},
                {'params': gpt2_params, 'lr': 1e-6},
                {'params': other_params, 'lr': 5e-5}
            ], weight_decay=1e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=4, verbose=True
            )
            print()
        # Record this epoch before the early-stop check so the final epoch
        # always appears in the log.
        logs.append([epoch+1, train_loss, train_token_acc, val_loss, val_token_acc, val_exact_match, optimizer.param_groups[0]['lr']])
        if counter >= patience:
            print(f"\nEarly stopping after {patience} epochs without improvement")
            # BUGFIX: the message used to print without terminating the
            # loop, so "early stopping" never actually stopped training.
            break
    log_df = pd.DataFrame(logs, columns=["epoch","train_loss","train_token_acc","val_loss","val_token_acc","val_exact_match","lr"])
    log_df.to_csv(LOG_CSV, index=False)
    plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("Training complete!")
    print(f"Best exact match accuracy: {best_val_exact_match:.4f}")
|
| 348 |
+
if __name__ == "__main__":
|
| 349 |
+
main()
|
experiments/utils/preprocess.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from transformers import GPT2Tokenizer
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from nltk.tokenize import word_tokenize
|
| 12 |
+
from sklearn.model_selection import train_test_split
|
| 13 |
+
from torchvision import transforms
|
| 14 |
+
from model import VQAModel
|
| 15 |
+
class Vocab:
    """Word-level answer vocabulary with <pad>/<bos>/<eos>/<unk> specials.

    build_vocab() must be called (or the fields populated from a
    checkpoint) before encoder()/decoder() are usable.
    """

    def __init__(self):
        self.vocab = None          # full token list, specials first
        self.vocab_size = None     # len(self.vocab)
        self.word2idx = None       # token -> id
        self.idx2word = None       # id -> token
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'

    def build_vocab(self, df, min_freq=1):
        """Build the vocabulary from df['answer'], keeping tokens that
        occur at least *min_freq* times."""
        freq = Counter()
        for answer in df['answer']:
            freq.update(word_tokenize(answer.lower()))
        kept = sorted(tok for tok, count in freq.items() if count >= min_freq)
        full_vocab = [self.pad, self.bos, self.eos, self.unk] + kept
        self.vocab = full_vocab
        self.word2idx = {tok: i for i, tok in enumerate(full_vocab)}
        self.idx2word = {i: tok for i, tok in enumerate(full_vocab)}
        self.vocab_size = len(full_vocab)
        # Cache the special-token ids for fast access.
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Encode *text* as [<bos>, tokens..., <eos>], padded or truncated
        to exactly *max_len* ids. Unknown tokens map to <unk>."""
        ids = [self.bos_token_id]
        ids.extend(self.word2idx.get(tok, self.unk_token_id)
                   for tok in word_tokenize(text.lower()))
        ids.append(self.eos_token_id)
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_token_id] * (max_len - len(ids))

    def decoder(self, token_ids):
        """Decode ids into a space-joined string, stopping at <eos> and
        skipping <pad>/<bos>. Unknown ids render as '<unk>'."""
        words = []
        for token_id in token_ids:
            if token_id == self.eos_token_id:
                break
            if token_id in (self.pad_token_id, self.bos_token_id):
                continue
            words.append(self.idx2word.get(token_id, "<unk>"))
        return ' '.join(words).strip()
|
| 60 |
+
class AugmentedVQADataset(Dataset):
    """Image/question/answer dataset for VQA training.

    Each item tokenizes the question with a GPT-2-style tokenizer, encodes
    the answer with the custom word-level vocabulary, and preprocesses the
    image with the supplied CLIP transform. Optional light augmentation is
    applied to the raw PIL image before CLIP preprocessing.
    """

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        # df must have 'image_path', 'question' and 'answer' columns.
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        # text_processor: Vocab-like object exposing encoder(text, max_len).
        self.text_processor = text_processor
        self.clip_processor = clip_processor
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        if augment:
            # NOTE(review): RandomHorizontalFlip can invert left/right
            # semantics for spatial questions — confirm this is intended.
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomRotation(10),
            ])
        else:
            self.transform = None

    def __len__(self):
        """Number of samples (rows in the backing DataFrame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors plus the image path."""
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])
        image = Image.open(img_path).convert('RGB')
        question = row['question']
        answer = row['answer']
        # Augment the raw PIL image before CLIP preprocessing.
        if self.augment and self.transform:
            image = self.transform(image)
        question_tokenized = self.question_tokenizer(
            question,
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        answer_ids = self.text_processor.encoder(answer, max_len=self.answer_max_len)
        image = self.clip_processor(image)
        return {
            'image_path': img_path,
            'image': image,
            # squeeze(0) drops the batch dim added by return_tensors='pt'.
            'question_ids': question_tokenized['input_ids'].squeeze(0),
            'question_mask': question_tokenized['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(answer_ids, dtype=torch.long)
        }
|
| 105 |
+
if __name__ == "__main__":
    # Smoke test: build the vocab, instantiate the datasets and print the
    # tensor shapes of a single training batch.
    DATA_DIR = r"/home/devarajan8/Documents/vqa/gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    batch_size = 16
    question_max_len = 16
    answer_max_len = 10
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    metadata = pd.read_csv(CSV_PATH)
    vocab = Vocab()
    vocab.build_vocab(metadata, min_freq=5)
    answer_vocab_size = len(vocab.vocab)
    print(f"Answer Vocab Size: {answer_vocab_size}")
    # 80/10/10 split: test takes half of the 20% hold-out.
    train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    print()
    model = VQAModel(
        vocab_size=answer_vocab_size,
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=vocab.pad_token_id,
        bos_token_id=vocab.bos_token_id,
        eos_token_id=vocab.eos_token_id,
        unk_token_id=vocab.unk_token_id,
        hidden_size=512,
        num_layers=2
    ).to(device)
    clip_processor = model.clip_preprocess
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; add one and resize the embeddings.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=True
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    # Inspect the first batch only, then stop.
    for batch in train_loader:
        images = batch['image']
        ques_ids = batch['question_ids']
        attn_mask = batch['question_mask']
        answers = batch['answer_ids']
        print(f"Image: {images.shape}")
        print(f"Question Ids: {ques_ids.shape}")
        print(f"Attention Mask: {attn_mask.shape}")
        print(f"Answer Ids: {answers.shape}")
        print(answers[0])
        break
|
experiments/utils/vocab.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from nltk.tokenize import word_tokenize
|
| 5 |
+
class Vocab:
    """Answer vocabulary over whitespace/NLTK word tokens.

    Reserves ids 0-3 for <pad>/<bos>/<eos>/<unk>; the remaining tokens are
    sorted alphabetically. Call build_vocab() before encoding/decoding.
    """

    def __init__(self):
        self.vocab = None        # token list, specials first
        self.vocab_size = None   # total token count
        self.word2idx = None     # token -> id mapping
        self.idx2word = None     # id -> token mapping
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'

    def build_vocab(self, df, min_freq=1):
        """Populate the vocabulary from df['answer'], dropping tokens that
        occur fewer than *min_freq* times."""
        counts = Counter()
        for answer in df['answer']:
            counts.update(word_tokenize(answer.lower()))
        frequent = sorted(tok for tok, n in counts.items() if n >= min_freq)
        all_tokens = [self.pad, self.bos, self.eos, self.unk] + frequent
        self.vocab = all_tokens
        self.vocab_size = len(all_tokens)
        self.word2idx = {tok: i for i, tok in enumerate(all_tokens)}
        self.idx2word = {i: tok for i, tok in enumerate(all_tokens)}
        # Cache special-token ids.
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Return exactly *max_len* ids: <bos> + tokens + <eos>, padded
        with <pad> or truncated as needed; OOV tokens become <unk>."""
        body = [self.word2idx.get(tok, self.unk_token_id)
                for tok in word_tokenize(text.lower())]
        ids = [self.bos_token_id] + body + [self.eos_token_id]
        if len(ids) >= max_len:
            return ids[:max_len]
        padding = [self.pad_token_id] * (max_len - len(ids))
        return ids + padding

    def decoder(self, token_ids):
        """Turn ids back into a string: stop at <eos>, skip <pad>/<bos>,
        render unknown ids as '<unk>'."""
        out = []
        for tid in token_ids:
            if tid == self.eos_token_id:
                break
            if tid in (self.pad_token_id, self.bos_token_id):
                continue
            out.append(self.idx2word.get(tid, "<unk>"))
        return ' '.join(out).strip()
|
| 50 |
+
if __name__ == "__main__":
    # Round-trip demo: build the vocab from the metadata CSV, then encode
    # and decode the first answer to eyeball the mapping.
    CSV_PATH = r"./gen_vqa_v2/metadata.csv"
    answer_max_len = 10
    metadata = pd.read_csv(CSV_PATH)
    vocab = Vocab()
    vocab.build_vocab(metadata, min_freq=5)
    answer_vocab_size = len(vocab.vocab)
    print(f"Answer Vocab Size: {answer_vocab_size}")
    sample_answer = metadata['answer'].values
    text = sample_answer[0]
    print("")
    encoded = vocab.encoder(text, answer_max_len)
    decoded = vocab.decoder(encoded)
    print(f"Sample Answer: {text}")
    print(f"Encoded: {encoded}")
    print(f"Decoded: {decoded}")
|
finetune.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from transformers import GPT2Tokenizer
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
from model import VQAModel
|
| 12 |
+
from train import AugmentedVQADataset, Vocab, save_checkpoint, plot_losses
|
| 13 |
+
def create_optimizer_with_differential_lr(model, clip_lr=5e-7, gpt_lr=5e-7, other_lr=3e-5):
    """Build an AdamW optimizer with per-component learning rates.

    Trainable parameters are routed by name: CLIP encoder params
    ('clip_model'), GPT-2 params ('gpt2_model'), and everything else.
    Each bucket becomes its own param group with its own learning rate;
    weight decay 1e-4 is shared.

    Returns:
        torch.optim.AdamW with three param groups (clip, gpt, other).
    """
    buckets = {'clip': [], 'gpt': [], 'other': []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'clip_model' in name:
            buckets['clip'].append(param)
        elif 'gpt2_model' in name:
            buckets['gpt'].append(param)
        else:
            buckets['other'].append(param)
    clip_params = buckets['clip']
    gpt_params = buckets['gpt']
    other_params = buckets['other']
    optimizer = torch.optim.AdamW(
        [
            {'params': clip_params, 'lr': clip_lr},
            {'params': gpt_params, 'lr': gpt_lr},
            {'params': other_params, 'lr': other_lr},
        ],
        weight_decay=1e-4,
    )
    print(f"Optimizer: CLIP params: {len(clip_params)}, GPT-2 params: {len(gpt_params)}, Other params: {len(other_params)}")
    return optimizer
|
| 30 |
+
def train_one_epoch(model, dataloader, optimizer, device, vocab, scaler):
    """Run one AMP fine-tuning epoch and return the mean batch loss.

    Batches whose loss is NaN are skipped (no optimizer step) and
    excluded from the loss average.

    Args:
        model: VQA model returning per-position logits under teacher forcing.
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask' and 'answer_ids' tensors.
        optimizer: optimizer over the trainable parameters.
        device: 'cuda' or 'cpu' device string, also used for autocast.
        vocab: answer vocabulary; only pad_token_id is used here.
        scaler: torch.amp.GradScaler for loss scaling.

    Returns:
        Mean loss over the batches actually processed (0.0-safe).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Shift so position t predicts token t+1 (teacher forcing).
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_answers = answers[:, 1:].contiguous()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_answers.view(-1)
            )
        if torch.isnan(loss):
            print("NaN loss detected, skipping batch.")
            continue
        scaler.scale(loss).backward()
        # BUGFIX: gradients must be unscaled before clipping, otherwise the
        # 1.0 threshold is applied to scaler-scaled gradients and clipping
        # is effectively a no-op (train.py's stage-1 loop does this right).
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        num_batches += 1
    # Average only over processed batches so skipped NaN batches don't
    # deflate the reported loss; guard against an empty loader.
    return total_loss / max(num_batches, 1)
|
| 59 |
+
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate on *dataloader*: teacher-forced loss plus free-running
    generation scored by exact string match after decoding.

    Args:
        model: VQA model; with answer ids it returns logits, without them
            it returns generated token id sequences.
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask' and 'answer_ids' tensors.
        device: target device string.
        vocab: answer vocabulary providing pad_token_id and decoder().

    Returns:
        (avg_loss, exact_match_acc)
    """
    model.eval()
    total_loss = 0.0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            with torch.amp.autocast("cuda"):
                logits = model(images, questions, answer_input_ids=answers)
                # Shift so position t predicts token t+1 (teacher forcing).
                shifted_logits = logits[:, :-1, :].contiguous()
                shifted_answers = answers[:, 1:].contiguous()
                loss = criterion(
                    shifted_logits.view(-1, shifted_logits.size(-1)),
                    shifted_answers.view(-1)
                )
            total_loss += loss.item()
            # Calling the model without answer ids runs free generation.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    # NOTE(review): both divisions raise ZeroDivisionError if the
    # dataloader is empty — acceptable here, but worth guarding.
    avg_loss = total_loss / len(dataloader)
    exact_match_acc = exact_matches / total_samples
    return avg_loss, exact_match_acc
|
| 92 |
+
def filter_spatial_directional_data(df):
    """Select rows whose question mentions a spatial relation or whose
    answer is a directional word.

    Args:
        df: DataFrame with string 'question' and 'answer' columns.

    Returns:
        A copy of the matching rows (original index preserved).
    """
    import re  # local import: only used to build the word-boundary patterns
    spatial_keywords = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'front', 'behind', 'next to', 'beside', 'near',
        'looking', 'facing', 'pointing', 'direction',
        'where is', 'which side', 'what side'
    ]
    directional_answers = [
        'up', 'down', 'left', 'right', 'forward', 'backward',
        'north', 'south', 'east', 'west', 'straight', 'sideways'
    ]
    # BUGFIX: plain substring matching produced false positives
    # ('soup' -> 'up', 'brightest' -> 'right', 'nearly' -> 'near');
    # \b anchors restrict matches to whole words/phrases.
    spatial_pattern = r'\b(?:' + '|'.join(map(re.escape, spatial_keywords)) + r')\b'
    directional_pattern = r'\b(?:' + '|'.join(map(re.escape, directional_answers)) + r')\b'
    spatial_mask = df['question'].str.lower().str.contains(spatial_pattern, regex=True, na=False)
    directional_mask = df['answer'].str.lower().str.contains(directional_pattern, regex=True, na=False)
    spatial_df = df[spatial_mask | directional_mask].copy()
    print(f"Found {len(spatial_df)} spatial/directional samples out of {len(df)} total")
    return spatial_df
|
| 108 |
+
def main():
    """Fine-tune a pretrained VQA checkpoint on spatial/directional data.

    Loads the stage-1 checkpoint (weights + vocabulary), filters the
    dataset to spatial/directional samples (mixing in general data when
    too few are found), unfreezes the top CLIP/GPT-2 layers and trains
    with differential learning rates, early-stopping on exact match.
    """
    print("# VQA: Spatial-Enhanced Fine-Tuning")
    # Fix every RNG source for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    PRETRAINED_CHECKPOINT = "./output2/feature_extraction/vqa_checkpoint.pt"
    OUTPUT_DIR = "./output2/spatial_finetuning"
    FINE_TUNED_CHECKPOINT = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters.
    batch_size = 64
    num_epochs = 50
    patience = 8
    clip_layers_to_unfreeze = 8
    gpt_layers_to_unfreeze = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load unpickles arbitrary objects; fine for our
    # own checkpoint, never point this at an untrusted file.
    checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=device)
    metadata = pd.read_csv(CSV_PATH)
    print(f"\nOriginal dataset size: {len(metadata)}")
    spatial_data = filter_spatial_directional_data(metadata)
    if len(spatial_data) < 1000:
        print(f"\nWARNING: Only {len(spatial_data)} spatial samples found!")
        print("Mixing 70% spatial data with 30% general data for balanced training")
        general_data = metadata[~metadata.index.isin(spatial_data.index)].sample(n=min(len(spatial_data)//2, len(metadata)//3), random_state=42)
        mixed_data = pd.concat([spatial_data, general_data]).sample(frac=1, random_state=42).reset_index(drop=True)
    else:
        print(f"Using {len(spatial_data)} spatial/directional samples")
        mixed_data = spatial_data
    # Restore the answer vocabulary saved with the checkpoint so token ids
    # stay consistent with the loaded weights.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f"Answer vocabulary size: {len(vocab.vocab)}")
    model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; add one and grow the embeddings
    # before loading weights so shapes line up.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    # strict=False tolerates keys affected by the embedding resize.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("Pretrained model loaded successfully!\n")
    print(f"UNFREEZING {clip_layers_to_unfreeze} CLIP LAYERS & {gpt_layers_to_unfreeze} GPT-2 LAYERS FOR SPATIAL UNDERSTANDING")
    model.unfreeze_clip_layers(num_layers=clip_layers_to_unfreeze)
    model.unfreeze_gpt2_layers(num_layers=gpt_layers_to_unfreeze)
    # 80/10/10 train/val/test split of the (possibly mixed) subset.
    train_df, test_df = train_test_split(mixed_data, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}\n")
    train_dataset = AugmentedVQADataset(train_df, DATA_DIR, question_tokenizer, vocab,
                                        clip_processor=model.clip_preprocess, augment=True,
                                        question_max_len=20, answer_max_len=12)
    val_dataset = AugmentedVQADataset(val_df, DATA_DIR, question_tokenizer, vocab,
                                      clip_processor=model.clip_preprocess, augment=False,
                                      question_max_len=20, answer_max_len=12)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # Tiny LRs for the pretrained encoders, larger LR for everything else.
    optimizer = create_optimizer_with_differential_lr(
        model,
        clip_lr=3e-7,
        gpt_lr=3e-7,
        other_lr=2e-5
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True)
    scaler = torch.amp.GradScaler(device)
    print("\nSTARTING SPATIAL-ENHANCED FINE-TUNING")
    best_exact_match = 0.0
    logs = []
    counter = 0
    for epoch in range(num_epochs):
        print(f"\nSpatial Fine-tuning Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device, vocab, scaler)
        val_loss, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Exact Match: {val_exact_match:.4f} | LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_exact_match)
        if val_exact_match > best_exact_match:
            best_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, FINE_TUNED_CHECKPOINT)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement for {counter} epochs.")
        # Record this epoch before the early-stop check so the final epoch
        # always appears in the log (previously the stopping epoch was lost).
        logs.append([epoch + 1, train_loss, val_loss, val_exact_match, optimizer.param_groups[0]['lr']])
        if counter >= patience:
            print(f"\nEarly stopping after {patience} epochs without improvement")
            break
    pd.DataFrame(logs, columns=["epoch", "train_loss", "val_loss", "val_exact_match", "lr"]).to_csv(LOG_CSV, index=False)
    plot_losses([x[1] for x in logs], [x[2] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\nFINE-TUNING COMPLETE")
    print(f"Best exact match: {best_exact_match:.4f}")
    print(f"Model saved to: {FINE_TUNED_CHECKPOINT}")
|
| 219 |
+
if __name__ == "__main__":
|
| 220 |
+
main()
|
finetune2.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from transformers import GPT2Tokenizer
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
from model import VQAModel
|
| 12 |
+
from model_spatial import VQAModelWithSpatialAdapter
|
| 13 |
+
from train import AugmentedVQADataset, Vocab, save_checkpoint, plot_losses
|
| 14 |
+
import math
|
| 15 |
+
def filter_spatial_questions(df):
    """
    Filter dataset for spatial/directional questions.
    Returns both spatial subset and general subset for mixed training.

    Fix: keywords are now matched as whole words. The previous pattern
    matched raw substrings, so e.g. 'bright' matched 'right' and
    'stop' matched 'top', polluting the spatial subset.

    Args:
        df: DataFrame with at least 'question' and 'answer' columns.

    Returns:
        Tuple (spatial_df, general_df) — copies of the matching and
        non-matching rows, in that order.
    """
    import re  # local import keeps this fix self-contained

    spatial_keywords = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'front', 'behind', 'next to', 'beside', 'near', 'between',
        'in front', 'in back', 'across from', 'opposite',
        'closest', 'farthest', 'nearest', 'furthest',
        'where is', 'which side', 'what side', 'what direction',
        'on the left', 'on the right', 'at the top', 'at the bottom'
    ]
    # \b anchors prevent substring false positives; re.escape guards
    # against any future keyword containing regex metacharacters.
    pattern = r'\b(?:' + '|'.join(re.escape(k) for k in spatial_keywords) + r')\b'
    spatial_mask = df['question'].str.lower().str.contains(pattern, na=False, regex=True)
    spatial_df = df[spatial_mask].copy()
    general_df = df[~spatial_mask].copy()
    print(f"\n📊 Dataset Filtering Results:")
    print(f"   Total samples: {len(df):,}")
    print(f"   Spatial samples: {len(spatial_df):,} ({len(spatial_df)/len(df)*100:.1f}%)")
    print(f"   General samples: {len(general_df):,} ({len(general_df)/len(df)*100:.1f}%)")
    if len(spatial_df) > 0:
        print(f"\n📝 Sample Spatial Questions:")
        for i, row in spatial_df.sample(min(5, len(spatial_df))).iterrows():
            print(f"   Q: {row['question']}")
            print(f"   A: {row['answer']}\n")
    return spatial_df, general_df
|
| 42 |
+
def create_mixed_dataset(spatial_df, general_df, spatial_ratio=0.85, min_spatial_samples=1000):
    """
    Create mixed dataset with specified ratio of spatial to general questions.
    Increased default to 85% spatial for better spatial learning.

    Fix: the original if/else branches computed exactly the same
    quantities and differed only in a warning print; the duplication
    is collapsed so the sizing logic exists once.

    Args:
        spatial_df: rows whose questions are spatial/directional.
        general_df: all remaining rows.
        spatial_ratio: target fraction of spatial samples in the mix.
        min_spatial_samples: threshold below which a warning is emitted.

    Returns:
        Shuffled DataFrame combining all spatial rows with a fixed-seed
        sample of general rows.
    """
    if len(spatial_df) < min_spatial_samples:
        print(f"\n⚠️ WARNING: Only {len(spatial_df)} spatial samples found!")
        print(f"   Recommended minimum: {min_spatial_samples}")
        print(f"   Mixing with general data to prevent catastrophic forgetting...")
    num_spatial = len(spatial_df)
    # Solve spatial/(spatial+general) == spatial_ratio for the general count,
    # capped by how many general rows actually exist.
    num_general = int(num_spatial * (1 - spatial_ratio) / spatial_ratio)
    num_general = min(num_general, len(general_df))
    # Fixed seeds keep the mix reproducible across runs.
    general_sample = general_df.sample(n=num_general, random_state=42)
    mixed_df = pd.concat([spatial_df, general_sample]).sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"\n🔀 Mixed Dataset Created:")
    print(f"   Spatial: {num_spatial:,} ({num_spatial/len(mixed_df)*100:.1f}%)")
    print(f"   General: {num_general:,} ({num_general/len(mixed_df)*100:.1f}%)")
    print(f"   Total: {len(mixed_df):,}")
    return mixed_df
|
| 65 |
+
def unfreeze_clip_layers(model, num_layers=4):
    """
    Unfreeze last N layers of CLIP for spatial feature learning.

    Re-enables gradients on the final ``num_layers`` transformer
    residual blocks of the CLIP visual tower, plus the output
    projection and the post-layernorm when present, so the visual
    backbone can adapt to spatial questions while earlier layers
    stay frozen.

    Args:
        model: model exposing ``clip_model.visual.transformer.resblocks``
            (OpenAI-CLIP-style layout — TODO confirm against model_spatial.py).
        num_layers: how many trailing resblocks to unfreeze.
    """
    total_blocks = len(model.clip_model.visual.transformer.resblocks)
    for i, block in enumerate(model.clip_model.visual.transformer.resblocks):
        # Only the last `num_layers` blocks become trainable.
        if i >= total_blocks - num_layers:
            for p in block.parameters():
                p.requires_grad = True
    # The visual projection may be a bare Parameter (OpenAI CLIP) or a
    # module (other implementations); handle both shapes.
    if hasattr(model.clip_model.visual, "proj") and model.clip_model.visual.proj is not None:
        if isinstance(model.clip_model.visual.proj, torch.nn.Parameter):
            model.clip_model.visual.proj.requires_grad = True
        else:
            for p in model.clip_model.visual.proj.parameters():
                p.requires_grad = True
    # Final layernorm feeds the projection; unfreeze it too when present.
    if hasattr(model.clip_model.visual, "ln_post"):
        for p in model.clip_model.visual.ln_post.parameters():
            p.requires_grad = True
    print(f" ✓ Unfroze last {num_layers} CLIP layers")
|
| 84 |
+
def freeze_base_model(model, unfreeze_clip_layers_count=4):
    """
    Freeze most of the model, unfreeze spatial adapter and last CLIP layers.

    Trainable after this call: the last ``unfreeze_clip_layers_count``
    CLIP blocks (via :func:`unfreeze_clip_layers`), the spatial adapter,
    the spatial context projection, the question projection, and the
    spatial fusion module. GPT-2 and the answer decoder stay frozen.

    Args:
        model: VQAModelWithSpatialAdapter instance.
        unfreeze_clip_layers_count: trailing CLIP blocks to re-enable.

    Returns:
        The same model object (mutated in place).
    """
    # Freeze the entire CLIP tower first, then selectively re-enable
    # its trailing blocks.
    for param in model.clip_model.parameters():
        param.requires_grad = False
    unfreeze_clip_layers(model, num_layers=unfreeze_clip_layers_count)
    # Language components remain frozen for this fine-tuning stage.
    for param in model.gpt2_model.parameters():
        param.requires_grad = False
    for param in model.decoder.parameters():
        param.requires_grad = False
    # Spatial-adapter stack is the primary trainable component.
    for param in model.spatial_adapter.parameters():
        param.requires_grad = True
    for param in model.spatial_context_proj.parameters():
        param.requires_grad = True
    for param in model.q_proj.parameters():
        param.requires_grad = True
    for param in model.spatial_fusion.parameters():
        param.requires_grad = True
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n🔒 Model Freezing Applied:")
    print(f"   Total parameters: {total_params:,}")
    print(f"   Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")
    print(f"   Frozen parameters: {total_params - trainable_params:,}")
    return model
|
| 110 |
+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr=1e-7):
    """
    Create learning rate scheduler with warmup and cosine decay.

    Returns a ``LambdaLR`` whose multiplier ramps linearly from 0 to 1
    over ``num_warmup_steps``, then follows a cosine curve down toward
    zero over the remaining steps.

    Note: ``min_lr`` is a floor on the *multiplier* returned by the
    lambda, not an absolute learning rate.
    """
    def scale_for(step):
        # Linear warmup phase.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # Cosine decay phase, floored at min_lr.
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        frac = float(step - num_warmup_steps) / decay_span
        cosine = 0.5 * (1.0 + math.cos(math.pi * frac))
        return cosine if cosine > min_lr else min_lr

    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale_for)
|
| 120 |
+
def create_optimizer_with_differential_lr(model, base_lr=5e-5):
    """
    Create optimizer with differential learning rates for different components.

    Routes trainable parameters into three AdamW groups: CLIP backbone
    at 10% of ``base_lr``, the spatial adapter at the full ``base_lr``,
    and everything else at 50%.
    """
    buckets = {'clip': [], 'adapter': [], 'other': []}
    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        if 'clip_model' in pname:
            buckets['clip'].append(tensor)
        elif 'spatial_adapter' in pname:
            buckets['adapter'].append(tensor)
        else:
            buckets['other'].append(tensor)

    optimizer = torch.optim.AdamW([
        {'params': buckets['clip'], 'lr': base_lr * 0.1},
        {'params': buckets['adapter'], 'lr': base_lr},
        {'params': buckets['other'], 'lr': base_lr * 0.5}
    ], weight_decay=1e-4)

    # Note: the counts below are numbers of parameter *tensors*, not
    # element counts.
    print(f"\n⚙️ Optimizer Configuration:")
    print(f"   CLIP params: {len(buckets['clip']):,} (LR: {base_lr * 0.1:.2e})")
    print(f"   Spatial adapter params: {len(buckets['adapter']):,} (LR: {base_lr:.2e})")
    print(f"   Other params: {len(buckets['other']):,} (LR: {base_lr * 0.5:.2e})")
    return optimizer
|
| 145 |
+
def train_one_epoch(model, dataloader, optimizer, device, vocab, scaler):
    """Training loop for one epoch.

    Runs mixed-precision forward/backward passes with gradient scaling
    and gradient clipping, skipping batches whose loss is NaN.

    Fixes vs. previous version:
      * ``scaler.unscale_(optimizer)`` is now called before
        ``clip_grad_norm_`` — clipping scaled gradients applies the
        threshold to the wrong magnitudes (per PyTorch AMP docs).
      * NaN batches no longer contribute to the token-accuracy sum;
        averages are taken over batches actually processed.

    Args:
        model: the VQA model; called as model(images, questions, answer_input_ids=...).
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask', 'answer_ids' tensors.
        optimizer: wrapped optimizer (stepped through ``scaler``).
        device: device string, e.g. 'cuda' or 'cpu'.
        vocab: vocabulary providing ``pad_token_id``.
        scaler: ``torch.amp.GradScaler`` instance.

    Returns:
        (avg_loss, avg_token_acc) averaged over non-skipped batches.
    """
    model.train()
    total_loss = 0.0
    total_token_acc = 0.0
    num_batches = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Teacher forcing: predict token t+1 from tokens <= t.
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_answers = answers[:, 1:].contiguous()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_answers.view(-1)
            )
        # Skip unstable batches before accumulating any statistics,
        # so a NaN batch affects neither loss nor accuracy averages.
        if torch.isnan(loss):
            print("⚠️ NaN loss detected, skipping batch.")
            continue
        # Token accuracy over non-padding positions.
        predicted_tokens = shifted_logits.argmax(dim=-1)
        correct = (predicted_tokens == shifted_answers).float()
        mask = (shifted_answers != vocab.pad_token_id).float()
        token_acc = (correct * mask).sum() / mask.sum()
        total_token_acc += token_acc.item()
        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 max-norm applies to the
        # true gradient magnitudes, not the loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        num_batches += 1
    # Guard against an epoch where every batch was skipped.
    denom = max(num_batches, 1)
    avg_loss = total_loss / denom
    avg_token_acc = total_token_acc / denom
    return avg_loss, avg_token_acc
|
| 183 |
+
def validate_one_epoch(model, dataloader, device, vocab):
    """Validation loop for one epoch.

    Computes three metrics under ``torch.no_grad()``:
      * teacher-forced cross-entropy loss,
      * token accuracy over non-padding positions,
      * exact-match accuracy of free generation against the reference.

    Args:
        model: VQA model; with ``answer_input_ids`` it returns logits,
            without it it returns generated token sequences
            (presumably greedy decoding — see model_spatial.py).
        dataloader: yields dicts with 'image', 'question_ids',
            'question_mask', 'answer_ids' tensors.
        device: device string, e.g. 'cuda' or 'cpu'.
        vocab: vocabulary providing ``pad_token_id`` and ``decoder``.

    Returns:
        (avg_loss, avg_token_acc, exact_match_acc)
    """
    model.eval()
    total_loss = 0.0
    total_token_acc = 0.0
    exact_matches = 0
    total_samples = 0
    # Note: no label smoothing here, unlike training — validation loss
    # reflects the raw cross-entropy.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            with torch.amp.autocast(device):
                logits = model(images, questions, answer_input_ids=answers)
                # Teacher forcing shift: predict token t+1 from tokens <= t.
                shifted_logits = logits[:, :-1, :].contiguous()
                shifted_answers = answers[:, 1:].contiguous()
                loss = criterion(
                    shifted_logits.view(-1, shifted_logits.size(-1)),
                    shifted_answers.view(-1)
                )
                # Token accuracy counts only non-padding positions.
                predicted_tokens = shifted_logits.argmax(dim=-1)
                correct = (predicted_tokens == shifted_answers).float()
                mask = (shifted_answers != vocab.pad_token_id).float()
                token_acc = (correct * mask).sum() / mask.sum()
                total_token_acc += token_acc.item()
                total_loss += loss.item()
            # Second forward pass in generation mode for exact-match.
            # NOTE(review): the reference text is decoded from the padded
            # answer ids — relies on vocab.decoder stripping pad/eos.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    exact_match_acc = exact_matches / total_samples
    return avg_loss, avg_token_acc, exact_match_acc
|
| 224 |
+
def main():
    """Entry point: fine-tune the spatial adapter on a spatial-heavy mix.

    Pipeline: seed RNGs → filter/mix the dataset → load the pretrained
    base VQA checkpoint → wrap it with the spatial adapter → freeze all
    but the adapter and trailing CLIP layers → train with warmup+cosine
    LR and early stopping on validation exact match.
    """
    print("=" * 80)
    print("🚀 VQA SPATIAL ADAPTER FINE-TUNING V2 (ENHANCED)")
    print("=" * 80)
    # Seed every RNG source for reproducibility.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    # Paths: dataset produced by genvqa-dataset.py, checkpoint from the
    # continued-training stage.
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    PRETRAINED_CHECKPOINT = "./output2/continued_training/vqa_checkpoint.pt"
    OUTPUT_DIR = "./output2/spatial_adapter_v2_2"
    FINE_TUNED_CHECKPOINT = os.path.join(OUTPUT_DIR, "vqa_spatial_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters for this run.
    batch_size = 64
    base_learning_rate = 5e-5
    num_epochs = 100
    patience = 15
    warmup_epochs = 3
    spatial_ratio = 0.85
    clip_layers_to_unfreeze = 6
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\n⚙️ Enhanced Configuration:")
    print(f"   Device: {device}")
    print(f"   Batch size: {batch_size}")
    print(f"   Base learning rate: {base_learning_rate:.2e}")
    print(f"   Max epochs: {num_epochs} (increased from 20)")
    print(f"   Warmup epochs: {warmup_epochs}")
    print(f"   Early stopping patience: {patience}")
    print(f"   Spatial ratio: {spatial_ratio:.0%} (increased from 70%)")
    print(f"   CLIP layers to unfreeze: {clip_layers_to_unfreeze}")
    print(f"\n📂 Loading dataset from: {CSV_PATH}")
    metadata = pd.read_csv(CSV_PATH)
    spatial_df, general_df = filter_spatial_questions(metadata)
    mixed_data = create_mixed_dataset(spatial_df, general_df, spatial_ratio=spatial_ratio)
    print(f"\n📥 Loading pretrained model from: {PRETRAINED_CHECKPOINT}")
    checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=device)
    # Rehydrate the answer vocabulary exactly as saved with the checkpoint.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f"   Vocabulary size: {len(vocab.vocab):,}")
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    base_model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    # Resize GPT-2 embeddings for the added [PAD] token before loading
    # weights; strict=False tolerates the adapter keys absent from the
    # base checkpoint.
    base_model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print(" ✓ Pretrained weights loaded")
    print(f"\n🔧 Creating VQA model with spatial adapter...")
    model = VQAModelWithSpatialAdapter(
        base_model=base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ).to(device)
    model = freeze_base_model(model, unfreeze_clip_layers_count=clip_layers_to_unfreeze)
    # 80/10/10 train/val/test split (the test slice is not used here).
    train_df, test_df = train_test_split(mixed_data, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"\n📊 Data Split:")
    print(f"   Train: {len(train_df):,} samples")
    print(f"   Validation: {len(val_df):,} samples")
    print(f"   Test: {len(test_df):,} samples")
    from torchvision import transforms
    # NOTE(review): safe_augmentation is defined but unused — both
    # datasets are built with augment=False. Confirm whether this was
    # intentional.
    safe_augmentation = transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomRotation(5),
    ])
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=model.clip_preprocess,
        augment=False,
        question_max_len=20,
        answer_max_len=12
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=model.clip_preprocess,
        augment=False,
        question_max_len=20,
        answer_max_len=12
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    optimizer = create_optimizer_with_differential_lr(model, base_lr=base_learning_rate)
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = len(train_loader) * warmup_epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    print(f"\n📈 Learning Rate Schedule:")
    print(f"   Warmup steps: {num_warmup_steps:,} ({warmup_epochs} epochs)")
    print(f"   Total steps: {num_training_steps:,}")
    print(f"   Schedule: Linear warmup → Cosine decay")
    scaler = torch.amp.GradScaler(device)
    print("\n" + "=" * 80)
    print("🎯 STARTING ENHANCED SPATIAL ADAPTER FINE-TUNING")
    print("=" * 80)
    best_val_exact_match = 0.0
    best_val_loss = np.inf
    counter = 0
    logs = []
    for epoch in range(num_epochs):
        print(f"\n📅 Epoch {epoch+1}/{num_epochs}")
        print("-" * 80)
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, vocab, scaler)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        # Group 1 is the spatial-adapter group (full base LR).
        current_lr = optimizer.param_groups[1]['lr']
        print(f"\n📈 Metrics:")
        print(f"   Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f"   Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f}")
        print(f"   Val Exact Match: {val_exact_match:.4f}")
        print(f"   Learning Rate: {current_lr:.2e}")
        # Model selection is on exact-match accuracy, not loss.
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, FINE_TUNED_CHECKPOINT)
            print(f"   ✅ New best model saved! (Exact Match: {val_exact_match:.4f})")
            counter = 0
        else:
            counter += 1
            print(f"   ⏳ No improvement for {counter} epoch(s)")
            # NOTE(review): breaking here skips the logs.append below,
            # so the final (non-improving) epoch is never written to the CSV.
            if counter >= patience:
                print(f"\n⏹️ Early stopping triggered after {patience} epochs without improvement")
                break
        logs.append([
            epoch + 1,
            train_loss,
            train_token_acc,
            val_loss,
            val_token_acc,
            val_exact_match,
            current_lr
        ])
        # NOTE(review): the per-batch cosine schedule is advanced in one
        # burst at the end of each epoch instead of once per optimizer
        # step inside train_one_epoch — same total steps, coarser curve.
        for _ in range(len(train_loader)):
            scheduler.step()
    log_df = pd.DataFrame(
        logs,
        columns=["epoch", "train_loss", "train_token_acc", "val_loss", "val_token_acc", "val_exact_match", "lr"]
    )
    log_df.to_csv(LOG_CSV, index=False)
    plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\n" + "=" * 80)
    print("✅ ENHANCED FINE-TUNING COMPLETE")
    print("=" * 80)
    print(f"\n📊 Final Results:")
    print(f"   Best Exact Match: {best_val_exact_match:.4f}")
    print(f"   Total Epochs: {len(logs)}")
    # 0.2037 is the v1 baseline exact match used for comparison.
    print(f"   Improvement from v1: {best_val_exact_match - 0.2037:.4f} ({(best_val_exact_match - 0.2037) / 0.2037 * 100:+.1f}%)")
    print(f"\n💾 Outputs:")
    print(f"   Model: {FINE_TUNED_CHECKPOINT}")
    print(f"   Logs: {LOG_CSV}")
    print(f"   Plot: {LOSS_GRAPH_PATH}")
    print("\n🎉 Ready to test on spatial questions!")
|
genvqa-dataset.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import json
import random
import shutil
import pandas as pd
from tqdm import tqdm
from collections import Counter

# Builds a filtered, class-balanced subset of VQA v2 (train2014) for
# generative VQA training: merges questions with their majority answer,
# drops trivial/ambiguous items, caps per-answer frequency, copies the
# referenced images, and writes qa_pairs.json + metadata.csv.

IMAGES_DIR = r"../train2014"
QUESTIONS_PATH = r"../v2_OpenEnded_mscoco_train2014_questions.json"
ANNOTATIONS_PATH = r"../v2_mscoco_train2014_annotations.json"
OUTPUT_DIR = "./gen_vqa_v2"
os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
print("Loading VQA v2 data...")
with open(QUESTIONS_PATH, "r") as f:
    questions = json.load(f)["questions"]
with open(ANNOTATIONS_PATH, "r") as f:
    annotations = json.load(f)["annotations"]
# Index annotations by question id for O(1) join with questions.
qid_to_ann = {ann["question_id"]: ann for ann in annotations}
print("Merging questions and answers...")
merged_data = []
answer_counter = Counter()
# Answers that carry no generative signal (yes/no, non-answers).
EXCLUDED_ANSWERS = ['yes', 'no', 'unknown', 'none', 'n/a', 'cant tell', 'not sure']
# Question stems too vague to learn a grounded answer from.
AMBIGUOUS_QUESTIONS = ['what is in the image', 'what is this', 'what is that', 'what do you see']
for q in tqdm(questions, total=len(questions)):
    ann = qid_to_ann.get(q["question_id"])
    if not ann:
        continue
    answers = [a["answer"] for a in ann["answers"] if a["answer"].strip()]
    if not answers:
        continue
    # Majority vote over the 10 human answers.
    main_answer = max(set(answers), key=answers.count)
    main_answer = main_answer.lower().strip()
    question_text = q["question"].lower().strip()
    if main_answer in EXCLUDED_ANSWERS:
        continue
    if any(ambig in question_text for ambig in AMBIGUOUS_QUESTIONS):
        continue
    # Keep only short answers suitable for the answer decoder.
    if len(main_answer.split()) <= 5 and len(main_answer) <= 30:
        merged_data.append({
            "image_id": q["image_id"],
            "question_id": q["question_id"],
            "question": q["question"],
            "answer": main_answer
        })
        answer_counter[main_answer] += 1
print(f"Total valid Q-A pairs (after filtering): {len(merged_data)}")
# Drop long-tail answers seen fewer than MIN_ANSWER_FREQ times.
MIN_ANSWER_FREQ = 20
frequent_answers = {ans for ans, count in answer_counter.items() if count >= MIN_ANSWER_FREQ}
filtered_data = [item for item in merged_data if item["answer"] in frequent_answers]
print(f"After frequency filtering (min_freq={MIN_ANSWER_FREQ}): {len(filtered_data)} pairs")
# Cap each answer class to avoid head-class domination.
# NOTE(review): the cap keeps the FIRST occurrences in dataset order,
# and random.shuffle below is unseeded — reruns produce different
# orderings. Seed random for reproducibility if that matters.
MAX_SAMPLES_PER_ANSWER = 600
answer_samples = {}
for item in filtered_data:
    ans = item["answer"]
    if ans not in answer_samples:
        answer_samples[ans] = []
    if len(answer_samples[ans]) < MAX_SAMPLES_PER_ANSWER:
        answer_samples[ans].append(item)
balanced_data = []
for samples in answer_samples.values():
    balanced_data.extend(samples)
random.shuffle(balanced_data)
print(f"After balancing: {len(balanced_data)} pairs with {len(answer_samples)} unique answers")
print("Copying selected images and saving data...")
final_data = []
for item in tqdm(balanced_data):
    # COCO file naming convention: zero-padded 12-digit image id.
    img_name = f"COCO_train2014_{item['image_id']:012d}.jpg"
    src_path = os.path.join(IMAGES_DIR, img_name)
    dst_path = os.path.join(OUTPUT_DIR, "images", img_name)
    if os.path.exists(src_path):
        shutil.copy(src_path, dst_path)
        item["image_path"] = f"images/{img_name}"
        final_data.append(item)
print(f"Final dataset: {len(final_data)} pairs")
with open(os.path.join(OUTPUT_DIR, "qa_pairs.json"), "w") as f:
    json.dump(final_data, f, indent=2, ensure_ascii=False)
pd.DataFrame(final_data).to_csv(os.path.join(OUTPUT_DIR, "metadata.csv"), index=False)
print("Data preparation complete.")
|
groq_service.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Groq LLM Service for VQA Accessibility
|
| 3 |
+
Generates descriptive 2-sentence narrations for blind users
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
+
from groq import Groq
|
| 8 |
+
class GroqDescriptionService:
    """Service to generate accessible descriptions using Groq LLM"""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize Groq service

        Args:
            api_key: Groq API key (if not provided, reads from GROQ_API_KEY env var)

        Raises:
            ValueError: if no API key is supplied or found in the environment.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key not found. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter"
            )
        self.client = Groq(api_key=self.api_key)
        # Model id for all completions; must exist in Groq's catalog.
        self.model = "llama-3.3-70b-versatile"

    def generate_description(
        self,
        question: str,
        answer: str,
        max_retries: int = 2
    ) -> Dict[str, str]:
        """
        Generate a 2-sentence accessible description for blind users

        Args:
            question: The question asked by the user
            answer: The VQA model's answer
            max_retries: Number of retry attempts on failure

        Returns:
            Dict with 'description' and 'status' keys. On success status
            is 'success' and 'model' is included; after exhausting
            retries status is 'fallback' with a template description
            and the last 'error' string.
        """
        prompt = f"""You are an accessibility assistant helping blind users understand visual question answering results.
Question asked: "{question}"
Answer from VQA model: "{answer}"
Task: Create a clear, natural 2-sentence description that:
1. First sentence: Restates the question and provides the answer
2. Second sentence: Adds helpful context or clarification
Keep it concise, natural, and easy to understand when spoken aloud.
Example:
Question: "What color is the car?"
Answer: "red"
Description: "The question asks about the color of the car, and the answer is red. This indicates there is a red-colored vehicle visible in the image."
Now generate the description:"""
        for attempt in range(max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful accessibility assistant. Always respond with exactly 2 clear, natural sentences."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    temperature=0.7,
                    max_tokens=150,
                    top_p=0.9
                )
                description = response.choices[0].message.content.strip()
                # The model sometimes echoes the "Description:" label; strip it.
                if description.startswith("Description:"):
                    description = description.replace("Description:", "").strip()
                return {
                    "description": description,
                    "status": "success",
                    "model": self.model
                }
            except Exception as e:
                # Retry on any API error; fall back to a template on the
                # final attempt so callers always get a description.
                if attempt < max_retries:
                    continue
                else:
                    fallback = f"The question asks: {question}. The answer is: {answer}."
                    return {
                        "description": fallback,
                        "status": "fallback",
                        "error": str(e)
                    }

    def generate_batch_descriptions(
        self,
        qa_pairs: list[Dict[str, str]]
    ) -> list[Dict[str, str]]:
        """
        Generate descriptions for multiple Q&A pairs

        Calls :meth:`generate_description` sequentially (one API request
        per pair).

        Args:
            qa_pairs: List of dicts with 'question' and 'answer' keys

        Returns:
            List of description results, one per input pair.
        """
        results = []
        for pair in qa_pairs:
            result = self.generate_description(
                question=pair.get("question", ""),
                answer=pair.get("answer", "")
            )
            results.append(result)
        return results
|
| 106 |
+
# Module-level singleton; the first successful construction wins.
_groq_service_instance = None

def get_groq_service(api_key: Optional[str] = None) -> GroqDescriptionService:
    """
    Get or create Groq service singleton

    NOTE(review): ``api_key`` is honored only on the first call; later
    calls return the cached instance and silently ignore a different key.

    Args:
        api_key: Optional API key (uses env var if not provided)

    Returns:
        GroqDescriptionService instance
    """
    global _groq_service_instance
    if _groq_service_instance is None:
        _groq_service_instance = GroqDescriptionService(api_key=api_key)
    return _groq_service_instance
|
knowledge_graph_service.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge Graph Service for Neuro-Symbolic VQA
|
| 3 |
+
Uses ConceptNet API to provide common-sense reasoning capabilities
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
import re
|
| 8 |
+
from typing import Dict, List, Optional
|
| 9 |
+
from functools import lru_cache
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class KnowledgeGraphService:
    """
    Lightweight ConceptNet API wrapper for common-sense reasoning.

    Enhances VQA answers with external knowledge about object properties,
    capabilities, uses, and relationships.
    """

    # Base URL of the public ConceptNet REST API.
    CONCEPTNET_API = "https://api.conceptnet.io"

    # Common-sense question patterns: (regex, ConceptNet relation).
    # A question matching a regex is routed to the paired relation.
    COMMONSENSE_PATTERNS = [
        # Capability questions
        (r'can .* (melt|freeze|fly|swim|float|sink|break|burn|explode)', 'CapableOf'),
        (r'is .* able to', 'CapableOf'),
        (r'does .* (float|sink)', 'CapableOf'),

        # Property questions
        (r'is .* (edible|poisonous|dangerous|safe|hot|cold|sweet|sour)', 'HasProperty'),
        (r'is this (food|drink|toy|tool|weapon)', 'HasProperty'),

        # Purpose questions
        (r'what .* (used for|for)', 'UsedFor'),
        (r'why .* (used|made)', 'UsedFor'),
        (r'how .* use', 'UsedFor'),

        # Composition questions
        (r'what .* made (of|from)', 'MadeOf'),
        (r'what .* (material|ingredient)', 'MadeOf'),

        # Location questions
        (r'where .* (found|located|kept|stored)', 'AtLocation'),
        (r'where (do|does) .* (live|grow)', 'AtLocation'),
    ]

    def __init__(self, cache_size=100, timeout=5):
        """
        Initialize Knowledge Graph service.

        Args:
            cache_size: Number of API responses to cache
            timeout: API request timeout in seconds
        """
        self.timeout = timeout
        self.cache_size = cache_size
        # BUGFIX: the previous version decorated the method itself with
        # @lru_cache(maxsize=100), which (a) ignored the `cache_size`
        # parameter entirely and (b) cached on `self`, keeping instances
        # alive inside the cache (the classic lru_cache-on-method pitfall).
        # Building the cached wrapper per instance honors cache_size and
        # keeps the cache's lifetime tied to this instance.
        self._query_conceptnet = lru_cache(maxsize=cache_size)(self._query_conceptnet_uncached)
        print("✅ Knowledge Graph service initialized (ConceptNet API)")

    def _query_conceptnet_uncached(self, concept: str, relation: str, limit: int = 10) -> Optional[Dict]:
        """
        Query ConceptNet API (uncached implementation).

        The LRU-cached wrapper is created in __init__ and exposed as
        ``self._query_conceptnet``; call that instead of this method.

        Args:
            concept: Concept to query (e.g., "ice_cream")
            relation: Relation type (e.g., "CapableOf", "HasProperty")
            limit: Maximum number of results

        Returns:
            API response dict or None if failed
        """
        try:
            # Normalize concept (ConceptNet URIs use underscores, lowercase).
            concept = concept.lower().replace(' ', '_')

            # Build API URL
            url = f"{self.CONCEPTNET_API}/query"
            params = {
                'start': f'/c/en/{concept}',
                'rel': f'/r/{relation}',
                'limit': limit
            }

            # Make request
            response = requests.get(url, params=params, timeout=self.timeout)
            response.raise_for_status()

            return response.json()

        except requests.exceptions.Timeout:
            # Timeout is caught before the general RequestException so it can
            # be reported distinctly; all failures degrade to None.
            print(f"⚠️ ConceptNet API timeout for {concept}")
            return None
        except requests.exceptions.RequestException as e:
            print(f"⚠️ ConceptNet API error: {e}")
            return None
        except Exception as e:
            print(f"⚠️ Unexpected error querying ConceptNet: {e}")
            return None

    def get_concept_properties(self, concept: str) -> Dict[str, List[str]]:
        """
        Fetch every supported relation type for a concept.

        Args:
            concept: Concept to look up (e.g., "hammer")

        Returns:
            Mapping from relation name to the list of related concept labels;
            relations with no results (or failed queries) map to empty lists.
        """
        properties = {
            'CapableOf': [],
            'HasProperty': [],
            'UsedFor': [],
            'MadeOf': [],
            'AtLocation': []
        }

        # Query each relation type (each call goes through the LRU cache).
        for relation in properties.keys():
            data = self._query_conceptnet(concept, relation)

            if data and 'edges' in data:
                for edge in data['edges']:
                    # Extract the end concept's human-readable label.
                    if 'end' in edge and 'label' in edge['end']:
                        end_label = edge['end']['label']
                        properties[relation].append(end_label)

        return properties

    def is_commonsense_question(self, question: str) -> bool:
        """
        Detect if a question requires common-sense reasoning.

        Args:
            question: Question string

        Returns:
            True if question matches one of COMMONSENSE_PATTERNS
        """
        q_lower = question.lower()

        for pattern, _ in self.COMMONSENSE_PATTERNS:
            if re.search(pattern, q_lower):
                return True

        return False

    def _detect_question_type(self, question: str) -> Optional[str]:
        """
        Detect which ConceptNet relation the question is asking about.

        Args:
            question: Question string

        Returns:
            Relation type of the first matching pattern, or None
        """
        q_lower = question.lower()

        for pattern, relation in self.COMMONSENSE_PATTERNS:
            if re.search(pattern, q_lower):
                return relation

        return None

    def answer_commonsense_question(self, object_name: str, question: str) -> Optional[str]:
        """
        Answer a common-sense question using the Knowledge Graph.

        Args:
            object_name: Object detected by VQA (e.g., "ice cream")
            question: User's question

        Returns:
            Enhanced answer string, or None when the question type is not
            recognised or ConceptNet yields no usable knowledge.
        """
        # Detect question type
        relation = self._detect_question_type(question)
        if not relation:
            return None

        # Query ConceptNet
        data = self._query_conceptnet(object_name, relation, limit=5)
        if not data or 'edges' not in data:
            return None

        # Extract relevant knowledge
        knowledge = []
        for edge in data['edges']:
            if 'end' in edge and 'label' in edge['end']:
                knowledge.append(edge['end']['label'])

        if not knowledge:
            return None

        # Generate natural language answer based on question type
        return self._synthesize_answer(object_name, question, relation, knowledge)

    def _synthesize_answer(self, object_name: str, question: str,
                           relation: str, knowledge: List[str]) -> Optional[str]:
        """
        Synthesize a natural language answer from knowledge.

        Args:
            object_name: Detected object
            question: Original question
            relation: Relation type
            knowledge: List of related concepts from KG

        Returns:
            Natural language answer, or None when no template applies.
            (Annotation fixed: the final fall-through returns None.)
        """
        q_lower = question.lower()

        # Capability questions (can X do Y?)
        if relation == 'CapableOf':
            # Check if the specific capability asked about is known.
            for capability in knowledge:
                if capability in q_lower:
                    return f"Yes, {object_name} can {capability}."

            # General capability answer
            if knowledge:
                caps = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} can {caps}."

        # Property questions (is X Y?)
        elif relation == 'HasProperty':
            # Check for specific property
            if 'edible' in q_lower:
                if 'edible' in knowledge:
                    return f"Yes, {object_name} is edible."
                else:
                    return f"No, {object_name} is not edible."

            if 'dangerous' in q_lower or 'safe' in q_lower:
                if any(prop in knowledge for prop in ['dangerous', 'harmful', 'poisonous']):
                    return f"Caution: {object_name} may be dangerous."
                else:
                    return f"{object_name.capitalize()} is generally safe."

            # General properties
            if knowledge:
                props = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is {props}."

        # Purpose questions (what is X used for?)
        elif relation == 'UsedFor':
            if knowledge:
                uses = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is used for {uses}."

        # Composition questions (what is X made of?)
        elif relation == 'MadeOf':
            if knowledge:
                materials = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is made of {materials}."

        # Location questions (where is X found?)
        elif relation == 'AtLocation':
            if knowledge:
                locations = ', '.join(knowledge[:2])
                return f"{object_name.capitalize()} is typically found at {locations}."

        return None
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Test function
|
| 262 |
+
if __name__ == "__main__":
    # Manual smoke test: exercises the pattern detector and, for questions
    # recognised as common-sense, the live ConceptNet lookup.
    banner = "=" * 80
    print(banner)
    print("🧪 Testing Knowledge Graph Service")
    print(banner)

    service = KnowledgeGraphService()

    # (object, question) samples covering each relation family.
    samples = [
        ("ice cream", "Can this melt?"),
        ("apple", "Is this edible?"),
        ("hammer", "What is this used for?"),
        ("knife", "Is this dangerous?"),
        ("bread", "What is this made of?"),
    ]

    for obj_name, query in samples:
        print(f"\n📝 Object: {obj_name}")
        print(f"❓ Question: {query}")

        # Check whether the question needs external knowledge at all.
        needs_kg = service.is_commonsense_question(query)
        print(f"🔍 Common-sense: {needs_kg}")

        if needs_kg:
            # Fetch and print the KG-backed answer.
            kg_answer = service.answer_commonsense_question(obj_name, query)
            print(f"💬 Answer: {kg_answer}")

        print("-" * 80)
|
llm_reasoning_service.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Reasoning Service for VQA
|
| 3 |
+
Uses Groq LLM for Chain-of-Thought reasoning instead of hardcoded rules
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
from typing import Dict, List, Optional, Any
|
| 7 |
+
from groq import Groq
|
| 8 |
+
import json
|
| 9 |
+
class LLMReasoningService:
    """
    Service that uses Groq LLM for deductive reasoning from Wikidata facts.
    Replaces hardcoded if/else rules with flexible Chain-of-Thought reasoning.
    """

    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize LLM Reasoning service

        Args:
            api_key: Groq API key (if not provided, reads from GROQ_API_KEY env var)
            model: Groq model to use for reasoning

        Raises:
            ValueError: if no API key is supplied and GROQ_API_KEY is unset.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key not found. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter"
            )
        self.client = Groq(api_key=self.api_key)
        self.model = model
        print(f"✅ LLM Reasoning Service initialized (model: {model})")

    def reason_with_facts(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str,
        max_retries: int = 2
    ) -> Dict[str, Any]:
        """
        Use LLM to reason about a question using Wikidata facts.

        Args:
            object_name: Name of the detected object (e.g., "candle")
            facts: Dictionary of Wikidata facts about the object
            question: User's question
            max_retries: Number of retry attempts on failure

        Returns:
            Dict with 'answer', 'reasoning_chain', and 'confidence' keys
            (plus 'status' and, on success, 'model'). Falls back to
            rule-based reasoning if all LLM attempts fail.

        Example:
            >>> service.reason_with_facts(
            ...     object_name="candle",
            ...     facts={"materials": ["wax"], "categories": ["light source"]},
            ...     question="Can this melt?"
            ... )
            {
                'answer': 'Yes, the candle can melt because it is made of wax...',
                'reasoning_chain': [
                    'The object is a candle',
                    'It is made of wax',
                    'Wax has a low melting point',
                    'Therefore, yes, it can melt'
                ],
                'confidence': 0.95
            }
        """
        prompt = self._build_reasoning_prompt(object_name, facts, question)
        for attempt in range(max_retries + 1):
            try:
                # Low temperature + JSON response_format to get structured,
                # parseable output from the model.
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": """You are an expert reasoning assistant for a Visual Question Answering system.
Your task is to use Chain-of-Thought reasoning to answer questions about objects based on factual knowledge.
IMPORTANT: Respond in JSON format with this structure:
{
"reasoning_chain": ["step 1", "step 2", "step 3"],
"answer": "final answer in natural language",
"confidence": 0.0-1.0
}
Keep reasoning steps clear and logical. The answer should be conversational and helpful."""
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    temperature=0.3,
                    max_tokens=500,
                    response_format={"type": "json_object"}
                )
                content = response.choices[0].message.content.strip()
                result = json.loads(content)
                # Guard against structurally invalid (but valid-JSON) replies.
                if not all(key in result for key in ['reasoning_chain', 'answer', 'confidence']):
                    raise ValueError("Invalid response structure from LLM")
                return {
                    'answer': result['answer'],
                    'reasoning_chain': result['reasoning_chain'],
                    'confidence': float(result['confidence']),
                    'status': 'success',
                    'model': self.model
                }
            except json.JSONDecodeError as e:
                # Malformed JSON: retry, then degrade to rule-based fallback.
                if attempt < max_retries:
                    continue
                else:
                    return self._fallback_reasoning(object_name, facts, question)
            except Exception as e:
                # Any other API/validation failure: retry, then fall back.
                if attempt < max_retries:
                    continue
                else:
                    print(f"⚠️ LLM reasoning failed: {e}")
                    return self._fallback_reasoning(object_name, facts, question)

    def _build_reasoning_prompt(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str
    ) -> str:
        """
        Build a Chain-of-Thought reasoning prompt.

        Includes a fully worked one-shot example (ice cream / "Can this
        melt?") so the model imitates both the reasoning style and the
        required JSON shape.

        Args:
            object_name: Name of the object
            facts: Wikidata facts about the object
            question: User's question

        Returns:
            Formatted prompt string
        """
        facts_text = self._format_facts(facts)
        prompt = f"""Question: {question}
Object Detected: {object_name}
Available Facts from Knowledge Graph:
{facts_text}
Task: Use Chain-of-Thought reasoning to answer the question based on the available facts.
Example of good reasoning:
Question: "Can this melt?"
Object: "ice cream"
Facts: {{
"categories": ["frozen dessert", "food"],
"materials": ["milk", "sugar", "cream"]
}}
Reasoning:
{{
"reasoning_chain": [
"The object is ice cream, which is a frozen dessert",
"Ice cream is made of milk, sugar, and cream",
"These ingredients are frozen to create ice cream",
"Frozen items melt when exposed to heat",
"Therefore, yes, ice cream can melt at room temperature"
],
"answer": "Yes, ice cream can melt. It's a frozen dessert made from milk, sugar, and cream, which will melt when exposed to temperatures above freezing.",
"confidence": 0.95
}}
Now reason about the actual question above:"""
        return prompt

    def _format_facts(self, facts: Dict[str, Any]) -> str:
        """Format facts dictionary into readable text (one bullet per key;
        empty lists/values are skipped)."""
        if not facts:
            return "No specific facts available"
        lines = []
        for key, value in facts.items():
            if isinstance(value, list):
                if value:
                    lines.append(f" - {key}: {', '.join(str(v) for v in value)}")
            elif value:
                lines.append(f" - {key}: {value}")
        return "\n".join(lines) if lines else "No specific facts available"

    def _fallback_reasoning(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str
    ) -> Dict[str, Any]:
        """
        Fallback reasoning when LLM fails.
        Uses simple rule-based approach covering only 'melt' and
        'edible/eat' questions; anything else returns a low-confidence
        generic answer (status 'fallback_generic').

        Args:
            object_name: Name of the object
            facts: Wikidata facts
            question: User's question

        Returns:
            Fallback reasoning result
        """
        q_lower = question.lower()
        if 'melt' in q_lower:
            materials = facts.get('materials', [])
            # Exact-match check against a small list of low-melting-point materials.
            if any(m in ['wax', 'ice', 'chocolate', 'butter'] for m in materials):
                return {
                    'answer': f"Yes, {object_name} can melt as it contains materials with low melting points.",
                    'reasoning_chain': [
                        f"The {object_name} contains materials that can melt",
                        "These materials have low melting points",
                        "Therefore, it can melt when heated"
                    ],
                    'confidence': 0.7,
                    'status': 'fallback'
                }
        if 'edible' in q_lower or 'eat' in q_lower:
            categories = facts.get('categories', [])
            # Substring check so e.g. "frozen food" also counts as food.
            if any('food' in str(c).lower() for c in categories):
                return {
                    'answer': f"Yes, {object_name} is edible as it is categorized as food.",
                    'reasoning_chain': [
                        f"The {object_name} is categorized as food",
                        "Food items are generally edible",
                        "Therefore, it is edible"
                    ],
                    'confidence': 0.8,
                    'status': 'fallback'
                }
        return {
            'answer': f"Based on the available information about {object_name}, I cannot provide a definitive answer to this question.",
            'reasoning_chain': [
                f"Analyzing {object_name}",
                "Available facts are limited",
                "Cannot make a confident conclusion"
            ],
            'confidence': 0.3,
            'status': 'fallback_generic'
        }

    def batch_reason(
        self,
        reasoning_tasks: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Perform reasoning on multiple tasks sequentially.

        Args:
            reasoning_tasks: List of dicts with 'object_name', 'facts', 'question' keys

        Returns:
            List of reasoning results, one per task, in input order
        """
        results = []
        for task in reasoning_tasks:
            result = self.reason_with_facts(
                object_name=task.get('object_name', ''),
                facts=task.get('facts', {}),
                question=task.get('question', '')
            )
            results.append(result)
        return results
|
| 239 |
+
# Module-level cache for the shared reasoning-service singleton.
_llm_reasoning_instance = None


def get_llm_reasoning_service(api_key: Optional[str] = None) -> LLMReasoningService:
    """Return the process-wide LLMReasoningService, creating it lazily.

    Args:
        api_key: Optional API key; when omitted the service reads the
            GROQ_API_KEY environment variable itself.

    Returns:
        The shared LLMReasoningService instance.
    """
    global _llm_reasoning_instance
    if _llm_reasoning_instance is not None:
        return _llm_reasoning_instance
    _llm_reasoning_instance = LLMReasoningService(api_key=api_key)
    return _llm_reasoning_instance
|
| 252 |
+
if __name__ == "__main__":
    # Manual smoke test: requires a valid GROQ_API_KEY and network access.
    print("=" * 80)
    print("🧪 Testing LLM Reasoning Service")
    print("=" * 80)

    def _show(result):
        # Pretty-print one reasoning result (answer, numbered chain, confidence).
        print(f"Answer: {result['answer']}")
        print(f"Reasoning Chain:")
        for step_no, step in enumerate(result['reasoning_chain'], 1):
            print(f" {step_no}. {step}")
        print(f"Confidence: {result['confidence']}")

    try:
        service = get_llm_reasoning_service()

        print("\n📝 Test 1: Can a candle melt?")
        _show(service.reason_with_facts(
            object_name="candle",
            facts={
                "materials": ["wax", "wick"],
                "categories": ["light source", "household item"],
                "uses": ["provide light", "decoration"]
            },
            question="Can this melt?"
        ))

        print("\n📝 Test 2: Would ice cream survive in the desert?")
        _show(service.reason_with_facts(
            object_name="ice cream",
            facts={
                "materials": ["milk", "sugar", "cream"],
                "categories": ["frozen dessert", "food"],
                "properties": ["cold", "frozen"]
            },
            question="Would this survive in the desert?"
        ))

        print("\n" + "=" * 80)
        print("✅ Tests completed!")
    except ValueError as e:
        # Raised by LLMReasoningService.__init__ when no API key is available.
        print(f"\n❌ Error: {e}")
        print("Please set GROQ_API_KEY environment variable")
|
model.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import clip
|
| 4 |
+
from transformers import GPT2Model
|
| 5 |
+
class AttentionDecoder(nn.Module):
    """GRU decoder that conditions each step on a fused image/question context.

    The context vector is concatenated with the token embedding at every
    time step; a learned scalar score produces softmax weights over the
    sequence positions.
    """

    def __init__(self, hidden_size, vocab_size, num_layers=1, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Token embedding shares the decoder's hidden width.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Scores one scalar per position from [embedding ; context].
        self.attention = nn.Linear(hidden_size * 2, 1)
        # Inter-layer dropout is only valid for stacked GRUs.
        self.gru = nn.GRU(
            input_size=hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.ln_gru = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, context, hidden):
        """Run the decoder over `input_ids` given a per-sample context vector.

        Args:
            input_ids: (batch, seq) or (batch,) token ids; a 1-D input is
                treated as a single step per sample.
            context: (batch, hidden_size) fused context vector.
            hidden: GRU hidden state (or None for zeros).

        Returns:
            (logits of shape (batch, seq, vocab_size), updated hidden state)
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(1)
        tok_emb = self.embedding(input_ids).float()
        seq_len = tok_emb.size(1)

        # Broadcast the single context vector across all time steps.
        ctx = context.unsqueeze(1).expand(-1, seq_len, -1)
        scores = self.attention(torch.cat([tok_emb, ctx], dim=-1))
        weights = torch.softmax(scores, dim=1)
        # NOTE(review): every position of `ctx` holds the same vector, so
        # this weighted sum mathematically reduces to the context itself —
        # confirm whether per-position attention was intended here.
        pooled_ctx = (ctx * weights).sum(dim=1, keepdim=True)

        rnn_in = torch.cat([tok_emb, pooled_ctx.expand(-1, seq_len, -1)], dim=-1)
        rnn_out, hidden = self.gru(rnn_in, hidden)
        rnn_out = self.ln_gru(rnn_out)
        return self.output(rnn_out), hidden
|
| 33 |
+
class VQAModel(nn.Module):
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
vocab_size=3600,
|
| 37 |
+
question_max_len=16,
|
| 38 |
+
answer_max_len=10,
|
| 39 |
+
hidden_size=512,
|
| 40 |
+
num_layers=2,
|
| 41 |
+
dropout=0.3,
|
| 42 |
+
device='cuda',
|
| 43 |
+
pad_token_id=0,
|
| 44 |
+
bos_token_id=1,
|
| 45 |
+
eos_token_id=2,
|
| 46 |
+
unk_token_id=3
|
| 47 |
+
):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.device = device
|
| 50 |
+
self.question_max_len = question_max_len
|
| 51 |
+
self.answer_max_len = answer_max_len
|
| 52 |
+
self.vocab_size = vocab_size
|
| 53 |
+
self.hidden_size = hidden_size
|
| 54 |
+
self.num_layers = num_layers
|
| 55 |
+
self.fine_tuning_mode = False
|
| 56 |
+
self.pad_token_id = pad_token_id
|
| 57 |
+
self.bos_token_id = bos_token_id
|
| 58 |
+
self.eos_token_id = eos_token_id
|
| 59 |
+
self.unk_token_id = unk_token_id
|
| 60 |
+
self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=device)
|
| 61 |
+
for p in self.clip_model.parameters():
|
| 62 |
+
p.requires_grad = False
|
| 63 |
+
self.gpt2_model = GPT2Model.from_pretrained("distilgpt2")
|
| 64 |
+
self.gpt2_model.to(device)
|
| 65 |
+
for p in self.gpt2_model.parameters():
|
| 66 |
+
p.requires_grad = False
|
| 67 |
+
self.img_proj = nn.Linear(512, hidden_size)
|
| 68 |
+
self.q_proj = nn.Linear(768, hidden_size)
|
| 69 |
+
self.gate_layer = nn.Linear(hidden_size*2, hidden_size)
|
| 70 |
+
self.fusion = nn.Sequential(
|
| 71 |
+
nn.Linear(hidden_size*3, hidden_size),
|
| 72 |
+
nn.ReLU(),
|
| 73 |
+
nn.Dropout(dropout),
|
| 74 |
+
nn.Linear(hidden_size, hidden_size)
|
| 75 |
+
)
|
| 76 |
+
self.decoder = AttentionDecoder(hidden_size, vocab_size, num_layers, dropout)
|
| 77 |
+
def unfreeze_clip_layers(self, num_layers=2):
|
| 78 |
+
self.clip_model.train()
|
| 79 |
+
self.clip_model.visual.float()
|
| 80 |
+
total_blocks = len(self.clip_model.visual.transformer.resblocks)
|
| 81 |
+
for i, block in enumerate(self.clip_model.visual.transformer.resblocks):
|
| 82 |
+
if i >= total_blocks - num_layers:
|
| 83 |
+
for p in block.parameters():
|
| 84 |
+
p.requires_grad = True
|
| 85 |
+
if hasattr(self.clip_model.visual, "proj") and self.clip_model.visual.proj is not None:
|
| 86 |
+
if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
|
| 87 |
+
self.clip_model.visual.proj.requires_grad = True
|
| 88 |
+
else:
|
| 89 |
+
for p in self.clip_model.visual.proj.parameters():
|
| 90 |
+
p.requires_grad = True
|
| 91 |
+
if hasattr(self.clip_model.visual, "ln_post"):
|
| 92 |
+
for p in self.clip_model.visual.ln_post.parameters():
|
| 93 |
+
p.requires_grad = True
|
| 94 |
+
self.fine_tuning_mode = True
|
| 95 |
+
print(f"Unfrozen last {num_layers} CLIP layers")
|
| 96 |
+
def unfreeze_gpt2_layers(self, num_layers=1):
    """Enable gradients for the last `num_layers` GPT-2 blocks and the
    final LayerNorm (`ln_f`), and switch the model into fine-tuning mode.

    Trainable parameters are cast in place to float32.
    """
    self.gpt2_model.train()
    total_layers = len(self.gpt2_model.h)
    for i, layer in enumerate(self.gpt2_model.h):
        if i >= total_layers - num_layers:
            for p in layer.parameters():
                p.requires_grad = True
                # In-place cast so the optimizer sees fp32 weights.
                p.data = p.data.float()
    for p in self.gpt2_model.ln_f.parameters():
        p.requires_grad = True
        p.data = p.data.float()
    self.fine_tuning_mode = True
    print(f"Unfrozen last {num_layers} GPT-2 layers")
def encode_image(self, images):
    """Encode a batch of images with CLIP and L2-normalize the result.

    Gradients flow through CLIP only in fine-tuning mode; otherwise the
    encoder runs under no_grad. Returns float32 embeddings (512-dim for
    ViT-B/16, matching img_proj's input size).
    """
    if self.fine_tuning_mode:
        images = images.float()  # match the fp32 weights after unfreezing
        img_features = self.clip_model.encode_image(images)
    else:
        with torch.no_grad():
            img_features = self.clip_model.encode_image(images)
    # Unit-normalize so both modalities feed the fusion at similar scale.
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    return img_features.float()
def encode_question(self, input_ids, attention_mask):
    """Encode a tokenized question with GPT-2 and pool to one vector.

    Uses a mask-aware mean over real (non-padding) token positions of the
    final hidden states, then L2-normalizes. Gradients flow only in
    fine-tuning mode. Returns float32 features (768-dim, per q_proj).
    """
    if self.fine_tuning_mode:
        outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
    else:
        with torch.no_grad():
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
    last_hidden = outputs.last_hidden_state
    # Masked mean pooling: zero padded positions, divide by the token
    # count (clamped to avoid division by zero on an all-zero mask).
    mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
    masked = last_hidden * mask
    sum_hidden = masked.sum(dim=1)
    lengths = mask.sum(dim=1).clamp(min=1e-6)
    text_features = sum_hidden / lengths
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features.float()
def fuse_features(self, img_features, q_features):
    """Gated fusion of image and question features.

    A sigmoid gate (computed from the concatenated pair) interpolates
    between the two modalities per dimension; the gated mix plus the raw
    concatenation then pass through the fusion MLP.
    """
    joint = torch.cat([img_features, q_features], dim=-1)
    gate = torch.sigmoid(self.gate_layer(joint))
    mixed = gate * img_features + (1 - gate) * q_features
    return self.fusion(torch.cat([mixed, joint], dim=-1))
def forward(self, images, questions, answer_input_ids=None):
    """Run the VQA forward pass.

    Args:
        images: batch of preprocessed images for CLIP.
        questions: dict with "input_ids" and "attention_mask" tensors.
        answer_input_ids: teacher-forcing answer tokens. When given,
            returns decoder logits (training path). When None, performs
            greedy decoding and returns generated ids of shape
            (batch, answer_max_len), padded with pad_token_id.
    """
    img_features = self.encode_image(images)
    img_features = self.img_proj(img_features).float()
    q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
    q_features = self.q_proj(q_features).float()
    batch_size = img_features.size(0)
    context = self.fuse_features(img_features, q_features)
    # Fresh zero hidden state for the GRU decoder.
    hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                         device=self.device, dtype=torch.float)
    if answer_input_ids is not None:
        # Training: decode the whole gold sequence in one call.
        logits, _ = self.decoder(answer_input_ids, context, hidden)
        return logits
    else:
        # Inference: greedy autoregressive decoding starting from BOS.
        generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                               dtype=torch.long, device=self.device)
        generated[:, 0] = self.bos_token_id
        for t in range(1, self.answer_max_len):
            current_input = generated[:, t-1]
            logits, hidden = self.decoder(current_input, context, hidden)
            next_tokens = logits.squeeze(1).argmax(dim=-1)
            generated[:, t] = next_tokens
            # Stop early only when every sequence in the batch emitted EOS.
            if (next_tokens == self.eos_token_id).all():
                break
        return generated
def generate_with_beam_search(self, images, questions, beam_width=5):
    """Beam-search decoding, processed one sample at a time.

    Beams are scored by summed token log-probabilities; the final pick
    applies a length penalty (score / length**0.7). Returns token ids of
    shape [batch, answer_max_len], padded with pad_token_id.
    """
    batch_size = images.size(0)
    all_results = []
    for b in range(batch_size):
        # Encode this single sample exactly as forward() does.
        img = images[b:b+1]
        q_ids = questions["input_ids"][b:b+1]
        q_mask = questions["attention_mask"][b:b+1]
        img_features = self.encode_image(img)
        img_features = self.img_proj(img_features).float()
        q_features = self.encode_question(q_ids, q_mask)
        q_features = self.q_proj(q_features).float()
        context = self.fuse_features(img_features, q_features)
        initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                     device=self.device, dtype=torch.float)
        # Each beam is (token sequence, cumulative log-prob, GRU hidden).
        beams = [(
            torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
            0.0,
            initial_hidden
        )]
        completed_beams = []
        for t in range(1, self.answer_max_len):
            candidates = []
            for seq, score, hidden in beams:
                # Beams ending in EOS are retired, not extended further.
                if seq[0, -1].item() == self.eos_token_id:
                    completed_beams.append((seq, score))
                    continue
                current_input = seq[:, -1]
                logits, new_hidden = self.decoder(current_input, context, hidden)
                log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                for i in range(beam_width):
                    next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                    new_seq = torch.cat([seq, next_token], dim=1)
                    new_score = score + top_log_probs[i].item()
                    candidates.append((new_seq, new_score, new_hidden))
            if len(candidates) == 0:
                break  # every active beam has finished
            # Keep only the top `beam_width` partial hypotheses.
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
        if len(all_beams) == 0:
            # Degenerate fallback: emit an all-padding row.
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
        else:
            # Length-normalized selection (alpha = 0.7).
            best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
            seq_len = min(best_beam[0].size(1), self.answer_max_len)
            result[:, :seq_len] = best_beam[0][:, :seq_len]
        all_results.append(result)
    return torch.cat(all_results, dim=0)
if __name__ == "__main__":
    # Smoke test: push one dummy image/question pair through the model.
    # Fix: fall back to CPU when CUDA is unavailable instead of crashing
    # (matches the device selection used by the spatial-adapter script).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VQAModel(device=device).to(device)
    model.eval()
    fake_image = torch.randn(1, 3, 224, 224).to(device)
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    # Inference-only smoke test; no gradients needed.
    with torch.no_grad():
        output = model(fake_image, question_batch)
    print(output)
model_spatial.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import clip
|
| 4 |
+
from transformers import GPT2Model
|
| 5 |
+
import math
|
| 6 |
+
class SpatialAdapter(nn.Module):
|
| 7 |
+
"""
|
| 8 |
+
Spatial Adapter with Multi-Head Cross-Attention for spatial reasoning.
|
| 9 |
+
Processes CLIP patch features (14x14 grid) with question guidance.
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, patch_dim=512, question_dim=512, hidden_dim=512, num_heads=8, dropout=0.3):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.hidden_dim = hidden_dim
|
| 14 |
+
self.num_heads = num_heads
|
| 15 |
+
self.head_dim = hidden_dim // num_heads
|
| 16 |
+
assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
|
| 17 |
+
self.register_buffer('pos_encoding_2d', self._create_2d_positional_encoding(14, 14, patch_dim))
|
| 18 |
+
self.patch_proj = nn.Linear(patch_dim, hidden_dim)
|
| 19 |
+
self.question_proj = nn.Linear(question_dim, hidden_dim)
|
| 20 |
+
self.cross_attn_query = nn.Linear(hidden_dim, hidden_dim)
|
| 21 |
+
self.cross_attn_key = nn.Linear(hidden_dim, hidden_dim)
|
| 22 |
+
self.cross_attn_value = nn.Linear(hidden_dim, hidden_dim)
|
| 23 |
+
self.cross_attn_out = nn.Linear(hidden_dim, hidden_dim)
|
| 24 |
+
self.self_attn_query = nn.Linear(hidden_dim, hidden_dim)
|
| 25 |
+
self.self_attn_key = nn.Linear(hidden_dim, hidden_dim)
|
| 26 |
+
self.self_attn_value = nn.Linear(hidden_dim, hidden_dim)
|
| 27 |
+
self.self_attn_out = nn.Linear(hidden_dim, hidden_dim)
|
| 28 |
+
self.ffn = nn.Sequential(
|
| 29 |
+
nn.Linear(hidden_dim, hidden_dim * 4),
|
| 30 |
+
nn.GELU(),
|
| 31 |
+
nn.Dropout(dropout),
|
| 32 |
+
nn.Linear(hidden_dim * 4, hidden_dim),
|
| 33 |
+
nn.Dropout(dropout)
|
| 34 |
+
)
|
| 35 |
+
self.ln1 = nn.LayerNorm(hidden_dim)
|
| 36 |
+
self.ln2 = nn.LayerNorm(hidden_dim)
|
| 37 |
+
self.ln3 = nn.LayerNorm(hidden_dim)
|
| 38 |
+
self.dropout = nn.Dropout(dropout)
|
| 39 |
+
def _create_2d_positional_encoding(self, height, width, dim):
|
| 40 |
+
"""Create 2D positional encoding for spatial grid"""
|
| 41 |
+
pos_h = torch.arange(height).unsqueeze(1).repeat(1, width).flatten()
|
| 42 |
+
pos_w = torch.arange(width).unsqueeze(0).repeat(height, 1).flatten()
|
| 43 |
+
pe = torch.zeros(height * width, dim)
|
| 44 |
+
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
|
| 45 |
+
pe[:, 0:dim//2:2] = torch.sin(pos_h.unsqueeze(1) * div_term[:dim//4])
|
| 46 |
+
pe[:, 1:dim//2:2] = torch.cos(pos_h.unsqueeze(1) * div_term[:dim//4])
|
| 47 |
+
pe[:, dim//2::2] = torch.sin(pos_w.unsqueeze(1) * div_term[:dim//4])
|
| 48 |
+
pe[:, dim//2+1::2] = torch.cos(pos_w.unsqueeze(1) * div_term[:dim//4])
|
| 49 |
+
return pe.unsqueeze(0)
|
| 50 |
+
def _multi_head_attention(self, query, key, value, num_heads):
|
| 51 |
+
"""Generic multi-head attention implementation"""
|
| 52 |
+
batch_size = query.size(0)
|
| 53 |
+
Q = query.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
|
| 54 |
+
K = key.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
|
| 55 |
+
V = value.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
|
| 56 |
+
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 57 |
+
attn_weights = torch.softmax(scores, dim=-1)
|
| 58 |
+
attn_weights = self.dropout(attn_weights)
|
| 59 |
+
context = torch.matmul(attn_weights, V)
|
| 60 |
+
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
|
| 61 |
+
return context, attn_weights
|
| 62 |
+
def forward(self, patch_features, question_features):
|
| 63 |
+
"""
|
| 64 |
+
Args:
|
| 65 |
+
patch_features: [batch_size, num_patches, patch_dim] - CLIP patch features
|
| 66 |
+
question_features: [batch_size, question_dim] - Question encoding
|
| 67 |
+
Returns:
|
| 68 |
+
spatial_context: [batch_size, hidden_dim] - Spatially-aware context
|
| 69 |
+
"""
|
| 70 |
+
batch_size, num_patches, _ = patch_features.shape
|
| 71 |
+
patch_features = patch_features + self.pos_encoding_2d[:, :num_patches, :].to(patch_features.device)
|
| 72 |
+
patches = self.patch_proj(patch_features)
|
| 73 |
+
question = self.question_proj(question_features.unsqueeze(1))
|
| 74 |
+
Q_cross = self.cross_attn_query(patches)
|
| 75 |
+
K_cross = self.cross_attn_key(question)
|
| 76 |
+
V_cross = self.cross_attn_value(question)
|
| 77 |
+
cross_context, _ = self._multi_head_attention(Q_cross, K_cross, V_cross, self.num_heads)
|
| 78 |
+
cross_out = self.cross_attn_out(cross_context)
|
| 79 |
+
patches = self.ln1(patches + self.dropout(cross_out))
|
| 80 |
+
Q_self = self.self_attn_query(patches)
|
| 81 |
+
K_self = self.self_attn_key(patches)
|
| 82 |
+
V_self = self.self_attn_value(patches)
|
| 83 |
+
self_context, _ = self._multi_head_attention(Q_self, K_self, V_self, self.num_heads)
|
| 84 |
+
self_out = self.self_attn_out(self_context)
|
| 85 |
+
patches = self.ln2(patches + self.dropout(self_out))
|
| 86 |
+
ffn_out = self.ffn(patches)
|
| 87 |
+
patches = self.ln3(patches + ffn_out)
|
| 88 |
+
attn_scores = torch.matmul(patches, question.transpose(1, 2))
|
| 89 |
+
attn_weights = torch.softmax(attn_scores, dim=1)
|
| 90 |
+
spatial_context = (patches * attn_weights).sum(dim=1)
|
| 91 |
+
return spatial_context
|
| 92 |
+
class VQAModelWithSpatialAdapter(nn.Module):
    """
    Enhanced VQA Model with Spatial Adapter for spatial reasoning.
    Uses patch-based CLIP features instead of global encoding.

    Shares (does not copy) the base model's CLIP/GPT-2 encoders and GRU
    decoder; only the adapter and fusion heads are new modules.
    """
    def __init__(
        self,
        base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ):
        super().__init__()
        # Mirror the base model's configuration and special token ids.
        self.device = base_model.device
        self.question_max_len = base_model.question_max_len
        self.answer_max_len = base_model.answer_max_len
        self.vocab_size = base_model.vocab_size
        self.hidden_size = hidden_size
        self.num_layers = base_model.num_layers
        self.fine_tuning_mode = base_model.fine_tuning_mode
        self.pad_token_id = base_model.pad_token_id
        self.bos_token_id = base_model.bos_token_id
        self.eos_token_id = base_model.eos_token_id
        self.unk_token_id = base_model.unk_token_id
        # Shared pretrained submodules.
        self.clip_model = base_model.clip_model
        self.clip_preprocess = base_model.clip_preprocess
        self.gpt2_model = base_model.gpt2_model
        self.decoder = base_model.decoder
        self.spatial_adapter = SpatialAdapter(
            patch_dim=512,      # CLIP ViT-B/16 projected patch width
            question_dim=768,   # GPT-2 hidden size
            hidden_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout
        )
        self.spatial_context_proj = nn.Linear(hidden_size, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        self.spatial_fusion = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size)
        )

    def _clip_patch_pipeline(self, images):
        """Run the CLIP visual tower manually, keeping per-patch tokens.

        Mirrors CLIP's VisionTransformer.forward, but returns all patch
        tokens (class token dropped) after the output projection.
        NOTE(review): ln_post is not applied to the patch tokens here —
        confirm that is intentional vs. CLIP's class-token path.
        """
        visual = self.clip_model.visual
        x = visual.conv1(images)                   # [B, C, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [B, C, grid*grid]
        x = x.permute(0, 2, 1)                     # [B, grid*grid, C]
        class_token = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
        )
        x = torch.cat([class_token, x], dim=1)
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # transformer expects [T, B, C]
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)
        patch_features = x[:, 1:, :]  # drop the class token
        # `proj` is a raw Parameter in OpenAI CLIP; tolerate a module too.
        if hasattr(visual, 'proj') and visual.proj is not None:
            if isinstance(visual.proj, torch.nn.Parameter):
                patch_features = patch_features @ visual.proj
            else:
                patch_features = visual.proj(patch_features)
        return patch_features

    def extract_clip_patch_features(self, images):
        """
        Extract patch features from CLIP instead of global features.
        Returns: [batch_size, num_patches, patch_dim] float32 tensor.

        Fix: the full pipeline was previously duplicated verbatim in both
        the fine-tuning and frozen branches; it now lives once in
        _clip_patch_pipeline and only the grad context differs.
        """
        clip_dtype = self.clip_model.visual.conv1.weight.dtype
        images = images.to(clip_dtype)
        if self.fine_tuning_mode:
            patch_features = self._clip_patch_pipeline(images)
        else:
            with torch.no_grad():
                patch_features = self._clip_patch_pipeline(images)
        return patch_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Encode question using GPT-2 (same as base model).

        Masked mean pooling over real tokens followed by L2 norm.
        """
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def forward(self, images, questions, answer_input_ids=None):
        """
        Forward pass with spatial adapter.

        Training (answer_input_ids given): returns decoder logits.
        Inference: greedy decoding, returns [batch, answer_max_len] ids.
        """
        patch_features = self.extract_clip_patch_features(images)
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        spatial_context = self.spatial_adapter(patch_features, q_features)
        spatial_context = self.spatial_context_proj(spatial_context)
        q_projected = self.q_proj(q_features)
        fused = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
        batch_size = images.size(0)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                             device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            logits, _ = self.decoder(answer_input_ids, fused, hidden)
            return logits
        else:
            generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                                   dtype=torch.long, device=self.device)
            generated[:, 0] = self.bos_token_id
            for t in range(1, self.answer_max_len):
                current_input = generated[:, t-1]
                logits, hidden = self.decoder(current_input, fused, hidden)
                next_tokens = logits.squeeze(1).argmax(dim=-1)
                generated[:, t] = next_tokens
                # Stop only when the whole batch has emitted EOS.
                if (next_tokens == self.eos_token_id).all():
                    break
            return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam search generation (same as base model but with spatial features).

        Scores are summed log-probs; final pick uses a length**0.7 penalty.
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b+1]
            q_ids = questions["input_ids"][b:b+1]
            q_mask = questions["attention_mask"][b:b+1]
            patch_features = self.extract_clip_patch_features(img)
            q_features = self.encode_question(q_ids, q_mask)
            spatial_context = self.spatial_adapter(patch_features, q_features)
            spatial_context = self.spatial_context_proj(spatial_context)
            q_projected = self.q_proj(q_features)
            context = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            # Each beam: (token sequence, cumulative log-prob, GRU hidden).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # Retire beams that already emitted EOS.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                if len(candidates) == 0:
                    break
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
            else:
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
if __name__ == "__main__":
    # Smoke test: build the base model, wrap it with the spatial adapter,
    # and run two dummy image/question pairs end to end.
    print("Testing Spatial Adapter Architecture...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from model import VQAModel
    base_model = VQAModel(device=device).to(device)
    spatial_model = VQAModelWithSpatialAdapter(base_model).to(device)
    spatial_model.eval()
    fake_image = torch.randn(2, 3, 224, 224).to(device)
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0], [1, 15, 25, 35, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    print(f"\nInput shapes:")
    print(f"  Images: {fake_image.shape}")
    print(f"  Questions: {fake_question_ids.shape}")
    with torch.no_grad():
        # Inspect patch extraction separately before the full pass.
        patch_features = spatial_model.extract_clip_patch_features(fake_image)
        print(f"\nPatch features shape: {patch_features.shape}")
        print(f"  Expected: [2, 196, 512] (batch_size, num_patches, patch_dim)")
        output = spatial_model(fake_image, question_batch)
        print(f"\nGenerated output shape: {output.shape}")
        print(f"  Expected: [2, {spatial_model.answer_max_len}]")
    # Parameter accounting: adapter size vs. total and trainable counts.
    total_params = sum(p.numel() for p in spatial_model.parameters())
    spatial_adapter_params = sum(p.numel() for p in spatial_model.spatial_adapter.parameters())
    trainable_params = sum(p.numel() for p in spatial_model.parameters() if p.requires_grad)
    print(f"\nParameter counts:")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Spatial adapter parameters: {spatial_adapter_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print("\n✓ Spatial adapter architecture test passed!")
models/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (16.5 kB). View file
|
|
|
models/model.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import clip
|
| 4 |
+
from transformers import GPT2Model
|
| 5 |
+
class AttentionDecoder(nn.Module):
    """GRU decoder conditioned on a fused image/question context vector.

    An additive attention score is computed per time step over the
    time-expanded context before it is concatenated with the token
    embeddings and fed to the GRU.
    """

    def __init__(self, hidden_size, vocab_size, num_layers=1, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention = nn.Linear(hidden_size * 2, 1)
        self.gru = nn.GRU(
            input_size=hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0  # torch warns otherwise
        )
        self.ln_gru = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, context, hidden):
        """Decode one step or a whole teacher-forced sequence.

        Args:
            input_ids: [B] or [B, T] token ids.
            context: [B, hidden_size] fused context vector.
            hidden: [num_layers, B, hidden_size] GRU state.
        Returns:
            (logits [B, T, vocab_size], updated hidden state)
        """
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(1)  # single-step decoding
        tok_emb = self.embedding(input_ids).float()
        steps = tok_emb.size(1)
        ctx = context.unsqueeze(1).expand(-1, steps, -1)
        scores = self.attention(torch.cat([tok_emb, ctx], dim=-1))
        weights = torch.softmax(scores, dim=1)
        # NOTE(review): every row of `ctx` is the same vector, so this
        # weighted sum reduces back to `context`; kept as-is to preserve
        # the original behavior (and trained checkpoints).
        pooled = (ctx * weights).sum(dim=1, keepdim=True)
        rnn_in = torch.cat([tok_emb, pooled.expand(-1, steps, -1)], dim=-1)
        rnn_out, hidden = self.gru(rnn_in, hidden)
        rnn_out = self.ln_gru(rnn_out)
        return self.output(rnn_out), hidden
class VQAModel(nn.Module):
|
| 34 |
+
def __init__(
    self,
    vocab_size=3600,
    question_max_len=16,
    answer_max_len=10,
    hidden_size=512,
    num_layers=2,
    dropout=0.3,
    device='cuda',
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    unk_token_id=3
):
    """CLIP + GPT-2 encoder pair with gated fusion and a GRU decoder.

    Both pretrained encoders load frozen; unfreeze_clip_layers /
    unfreeze_gpt2_layers enable fine-tuning of their top blocks.
    Special token ids default to pad=0, bos=1, eos=2, unk=3.
    """
    super().__init__()
    self.device = device
    self.question_max_len = question_max_len
    self.answer_max_len = answer_max_len
    self.vocab_size = vocab_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.fine_tuning_mode = False  # flipped on by the unfreeze_* helpers
    self.pad_token_id = pad_token_id
    self.bos_token_id = bos_token_id
    self.eos_token_id = eos_token_id
    self.unk_token_id = unk_token_id
    # Frozen CLIP image encoder (ViT-B/16 -> 512-dim global embedding).
    self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=device)
    for p in self.clip_model.parameters():
        p.requires_grad = False
    # Frozen distilGPT-2 question encoder (768-dim hidden states).
    self.gpt2_model = GPT2Model.from_pretrained("distilgpt2")
    self.gpt2_model.to(device)
    for p in self.gpt2_model.parameters():
        p.requires_grad = False
    # Trainable projection and fusion heads.
    self.img_proj = nn.Linear(512, hidden_size)
    self.q_proj = nn.Linear(768, hidden_size)
    self.gate_layer = nn.Linear(hidden_size*2, hidden_size)
    self.fusion = nn.Sequential(
        nn.Linear(hidden_size*3, hidden_size),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_size, hidden_size)
    )
    self.decoder = AttentionDecoder(hidden_size, vocab_size, num_layers, dropout)
def unfreeze_clip_layers(self, num_layers=2):
    """Make the last `num_layers` CLIP visual blocks trainable (plus the
    output projection and ln_post), casting the visual tower to fp32 and
    enabling fine-tuning mode."""
    self.clip_model.train()
    self.clip_model.visual.float()
    total_blocks = len(self.clip_model.visual.transformer.resblocks)
    for i, block in enumerate(self.clip_model.visual.transformer.resblocks):
        if i >= total_blocks - num_layers:  # only the top blocks
            for p in block.parameters():
                p.requires_grad = True
    # `proj` may be a raw Parameter (OpenAI CLIP) or a module.
    if hasattr(self.clip_model.visual, "proj") and self.clip_model.visual.proj is not None:
        if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
            self.clip_model.visual.proj.requires_grad = True
        else:
            for p in self.clip_model.visual.proj.parameters():
                p.requires_grad = True
    if hasattr(self.clip_model.visual, "ln_post"):
        for p in self.clip_model.visual.ln_post.parameters():
            p.requires_grad = True
    self.fine_tuning_mode = True
    print(f"Unfrozen last {num_layers} CLIP layers")
def unfreeze_gpt2_layers(self, num_layers=1):
    """Make the last `num_layers` GPT-2 blocks and ln_f trainable,
    casting their weights to fp32 and enabling fine-tuning mode."""
    self.gpt2_model.train()
    total_layers = len(self.gpt2_model.h)
    for i, layer in enumerate(self.gpt2_model.h):
        if i >= total_layers - num_layers:
            for p in layer.parameters():
                p.requires_grad = True
                p.data = p.data.float()  # in-place fp32 cast
    for p in self.gpt2_model.ln_f.parameters():
        p.requires_grad = True
        p.data = p.data.float()
    self.fine_tuning_mode = True
    print(f"Unfrozen last {num_layers} GPT-2 layers")
def encode_image(self, images):
    """CLIP-encode a batch of images; gradients only in fine-tuning mode.
    Returns L2-normalized float32 embeddings."""
    if self.fine_tuning_mode:
        images = images.float()  # match fp32 weights after unfreezing
        img_features = self.clip_model.encode_image(images)
    else:
        with torch.no_grad():
            img_features = self.clip_model.encode_image(images)
    img_features = img_features / img_features.norm(dim=-1, keepdim=True)
    return img_features.float()
def encode_question(self, input_ids, attention_mask):
    """Encode a tokenized question into a unit-norm float32 vector.

    Runs GPT-2 over the tokens (with gradients only in fine-tuning mode)
    and mean-pools the last hidden states over the positions marked valid
    by ``attention_mask``.
    """
    if self.fine_tuning_mode:
        outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
    else:
        with torch.no_grad():
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
    hidden = outputs.last_hidden_state
    weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
    # Masked mean over the sequence dimension; clamp guards all-padding rows.
    pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-6)
    pooled = pooled / pooled.norm(dim=-1, keepdim=True)
    return pooled.float()
def fuse_features(self, img_features, q_features):
    """Gated fusion of image and question features.

    A sigmoid gate conditioned on the concatenated features interpolates
    between the two modalities; the blend plus the raw concatenation is
    then projected by ``self.fusion``.
    """
    combined = torch.cat([img_features, q_features], dim=-1)
    gate = torch.sigmoid(self.gate_layer(combined))
    blended = gate * img_features + (1.0 - gate) * q_features
    return self.fusion(torch.cat([blended, combined], dim=-1))
def forward(self, images, questions, answer_input_ids=None):
    """Run the VQA model.

    With ``answer_input_ids`` supplied (teacher forcing), returns the
    decoder logits for the whole gold sequence. Otherwise decodes greedily
    up to ``self.answer_max_len`` tokens, seeded with BOS, stopping early
    once every sequence in the batch has emitted EOS; returns the token ids.
    """
    img_feats = self.img_proj(self.encode_image(images)).float()
    q_feats = self.q_proj(
        self.encode_question(questions["input_ids"], questions["attention_mask"])
    ).float()
    batch_size = img_feats.size(0)
    context = self.fuse_features(img_feats, q_feats)
    hidden = torch.zeros(
        self.num_layers, batch_size, self.hidden_size,
        device=self.device, dtype=torch.float,
    )
    if answer_input_ids is not None:
        # Teacher forcing: one decoder call over the full target sequence.
        logits, _ = self.decoder(answer_input_ids, context, hidden)
        return logits
    # Greedy autoregressive decoding.
    generated = torch.full(
        (batch_size, self.answer_max_len), self.pad_token_id,
        dtype=torch.long, device=self.device,
    )
    generated[:, 0] = self.bos_token_id
    for step in range(1, self.answer_max_len):
        logits, hidden = self.decoder(generated[:, step - 1], context, hidden)
        next_tokens = logits.squeeze(1).argmax(dim=-1)
        generated[:, step] = next_tokens
        if (next_tokens == self.eos_token_id).all():
            break
    return generated
def generate_with_beam_search(self, images, questions, beam_width=5):
    """Decode answers with per-sample beam search.

    Each batch element is decoded independently. A beam is a tuple of
    (token sequence, cumulative log-probability, decoder hidden state).
    Beams that emit EOS are moved to ``completed_beams``; the final choice
    uses length-normalized score (exponent 0.7) to avoid favoring very
    short sequences. Returns a (batch, answer_max_len) LongTensor padded
    with ``self.pad_token_id``.
    """
    batch_size = images.size(0)
    all_results = []
    for b in range(batch_size):
        # Slice out a single sample (keep the batch dimension of size 1).
        img = images[b:b+1]
        q_ids = questions["input_ids"][b:b+1]
        q_mask = questions["attention_mask"][b:b+1]
        img_features = self.encode_image(img)
        img_features = self.img_proj(img_features).float()
        q_features = self.encode_question(q_ids, q_mask)
        q_features = self.q_proj(q_features).float()
        context = self.fuse_features(img_features, q_features)
        initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                     device=self.device, dtype=torch.float)
        # Start from a single beam holding only the BOS token.
        beams = [(
            torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
            0.0,
            initial_hidden
        )]
        completed_beams = []
        for t in range(1, self.answer_max_len):
            candidates = []
            for seq, score, hidden in beams:
                # A beam ending in EOS is finished; retire it.
                if seq[0, -1].item() == self.eos_token_id:
                    completed_beams.append((seq, score))
                    continue
                current_input = seq[:, -1]
                logits, new_hidden = self.decoder(current_input, context, hidden)
                log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                # Expand this beam with its top-k next tokens.
                top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                for i in range(beam_width):
                    next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                    new_seq = torch.cat([seq, next_token], dim=1)
                    new_score = score + top_log_probs[i].item()
                    candidates.append((new_seq, new_score, new_hidden))
            if len(candidates) == 0:
                # Every beam has completed; stop expanding.
                break
            # Keep only the best beam_width candidates by cumulative score.
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
        if len(all_beams) == 0:
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
        else:
            # Length-normalized selection: score / len^0.7.
            best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
            seq_len = min(best_beam[0].size(1), self.answer_max_len)
            result[:, :seq_len] = best_beam[0][:, :seq_len]
        all_results.append(result)
    return torch.cat(all_results, dim=0)
if __name__ == "__main__":
    # Smoke test: build the model and run one forward pass on random inputs.
    # Fall back to CPU so the demo also runs on hosts without CUDA
    # (the original hard-coded "cuda" and crashed on CPU-only machines).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VQAModel(device=device).to(device)
    model.eval()
    fake_image = torch.randn(1, 3, 224, 224).to(device)
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    output = model(fake_image, question_batch)
    print(output)
quick_start.bat
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
REM Quick Start Script for VQA Mobile App
REM This script helps you start the backend and frontend
REM Delayed expansion is required for the !IP! references inside the
REM for-loop below; without it "Your IP" prints the literal text "!IP!".
setlocal enabledelayedexpansion

echo ========================================
echo VQA Mobile App - Quick Start
echo ========================================
echo.

REM Get current IP address
echo [1/3] Checking your IP address...
for /f "tokens=2 delims=:" %%a in ('ipconfig ^| findstr /c:"IPv4"') do (
    set IP=%%a
    set IP=!IP:~1!
    echo Your IP: !IP!
)

echo.
echo [2/3] Current Configuration:
echo Backend: http://10.215.4.143:8000
echo Frontend: ui/src/config/api.js
echo.

echo IMPORTANT: Make sure both laptop and mobile are on the SAME network!
echo.

echo [3/3] Choose an option:
echo 1. Start Backend (Python)
echo 2. Start Frontend (Expo)
echo 3. Start Both (Opens 2 terminals)
echo 4. Exit
echo.

choice /c 1234 /n /m "Enter your choice (1-4): "

REM choice sets ERRORLEVEL to the 1-based index of the pressed key;
REM test from highest to lowest because "if errorlevel N" means ">= N".
if errorlevel 4 goto :end
if errorlevel 3 goto :both
if errorlevel 2 goto :frontend
if errorlevel 1 goto :backend

:backend
echo.
echo Starting Backend Server...
echo Make sure you have activated your Python environment!
echo.
python backend_api.py
goto :end

:frontend
echo.
echo Starting Expo Frontend...
cd ui
npx expo start
goto :end

:both
echo.
echo Starting both Backend and Frontend...
echo Opening Backend in new window...
start cmd /k "python backend_api.py"
timeout /t 3 /nobreak >nul
echo Opening Frontend in new window...
start cmd /k "cd ui && npx expo start"
echo.
echo Both servers are starting in separate windows!
goto :end

:end
echo.
echo Done!
pause
|
requirements_api.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115.6
|
| 2 |
+
uvicorn>=0.34.0
|
| 3 |
+
python-multipart>=0.0.20
|
| 4 |
+
pillow>=11.1.0
|
| 5 |
+
torch>=2.0.0
|
| 6 |
+
torchvision>=0.15.0
|
| 7 |
+
transformers>=4.30.0
|
| 8 |
+
ftfy
|
| 9 |
+
regex
|
| 10 |
+
tqdm
|
| 11 |
+
git+https://github.com/openai/CLIP.git
|
| 12 |
+
groq>=0.4.0
|
| 13 |
+
python-dotenv>=1.0.0
|
| 14 |
+
huggingface-hub
|
scores/feature.txt
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
================================================================================
|
| 2 |
+
EVALUATION RESULTS
|
| 3 |
+
================================================================================
|
| 4 |
+
|
| 5 |
+
📊 Accuracy Metrics:
|
| 6 |
+
Exact Match Accuracy: 50.17% (63805/135256)
|
| 7 |
+
VQA Accuracy: 15.72%
|
| 8 |
+
|
| 9 |
+
📊 ANLS Metrics:
|
| 10 |
+
Average ANLS (τ=0.5): 50.18%
|
| 11 |
+
ANLS Std Dev: 48.96%
|
| 12 |
+
|
| 13 |
+
📊 Additional Statistics:
|
| 14 |
+
Total samples: 135256
|
| 15 |
+
Avg prediction length: 1.13 words
|
| 16 |
+
Avg GT length: 1.10 words
|
| 17 |
+
|
| 18 |
+
================================================================================
|
| 19 |
+
SAMPLE PREDICTIONS
|
| 20 |
+
================================================================================
|
| 21 |
+
|
| 22 |
+
🏆 Best Predictions (Highest ANLS):
|
| 23 |
+
--------------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
Ground Truth: tusks
|
| 26 |
+
Prediction: tusks
|
| 27 |
+
ANLS: 1.0000
|
| 28 |
+
Exact Match: ✓
|
| 29 |
+
|
| 30 |
+
Ground Truth: seagull
|
| 31 |
+
Prediction: seagull
|
| 32 |
+
ANLS: 1.0000
|
| 33 |
+
Exact Match: ✓
|
| 34 |
+
|
| 35 |
+
Ground Truth: bedroom
|
| 36 |
+
Prediction: bedroom
|
| 37 |
+
ANLS: 1.0000
|
| 38 |
+
Exact Match: ✓
|
| 39 |
+
|
| 40 |
+
Ground Truth: cake
|
| 41 |
+
Prediction: cake
|
| 42 |
+
ANLS: 1.0000
|
| 43 |
+
Exact Match: ✓
|
| 44 |
+
|
| 45 |
+
Ground Truth: short
|
| 46 |
+
Prediction: short
|
| 47 |
+
ANLS: 1.0000
|
| 48 |
+
Exact Match: ✓
|
| 49 |
+
|
| 50 |
+
================================================================================
|
| 51 |
+
⚠️ Worst Predictions (Lowest ANLS):
|
| 52 |
+
--------------------------------------------------------------------------------
|
| 53 |
+
|
| 54 |
+
Ground Truth: mirror
|
| 55 |
+
Prediction: car
|
| 56 |
+
ANLS: 0.0000
|
| 57 |
+
Exact Match: ✗
|
| 58 |
+
|
| 59 |
+
Ground Truth: towel
|
| 60 |
+
Prediction: toy
|
| 61 |
+
ANLS: 0.0000
|
| 62 |
+
Exact Match: ✗
|
| 63 |
+
|
| 64 |
+
Ground Truth: book
|
| 65 |
+
Prediction: camera
|
| 66 |
+
ANLS: 0.0000
|
| 67 |
+
Exact Match: ✗
|
| 68 |
+
|
| 69 |
+
Ground Truth: usa
|
| 70 |
+
Prediction: england
|
| 71 |
+
ANLS: 0.0000
|
| 72 |
+
Exact Match: ✗
|
| 73 |
+
|
| 74 |
+
Ground Truth: red and yellow
|
| 75 |
+
Prediction: green
|
| 76 |
+
ANLS: 0.0000
|
| 77 |
+
Exact Match: ✗
|
scores/score.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from PIL import Image
|
| 5 |
+
from transformers import GPT2Tokenizer
|
| 6 |
+
from model import VQAModel
|
| 7 |
+
from model_spatial import VQAModelWithSpatialAdapter
|
| 8 |
+
from train import Vocab
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
import numpy as np
|
| 11 |
+
try:
|
| 12 |
+
from Levenshtein import distance as levenshtein_distance
|
| 13 |
+
except ImportError:
|
| 14 |
+
print("Installing python-Levenshtein...")
|
| 15 |
+
import subprocess
|
| 16 |
+
subprocess.check_call(['pip', 'install', 'python-Levenshtein'])
|
| 17 |
+
from Levenshtein import distance as levenshtein_distance
|
| 18 |
+
MODEL_TYPE = "feature"
|
| 19 |
+
SPATIAL_CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
|
| 20 |
+
FEATURE_CHECKPOINT = "./output2/feature_extraction/vqa_checkpoint.pt"
|
| 21 |
+
CSV_PATH = "./gen_vqa_v2/metadata.csv"
|
| 22 |
+
IMG_DIR = "./gen_vqa_v2"
|
| 23 |
+
MAX_SAMPLES = None
|
| 24 |
+
def load_spatial_model(checkpoint_path, device='cuda'):
    """Restore the spatial-adapter VQA model, its vocab, and tokenizer.

    Loads the checkpoint, rebuilds the ``Vocab`` from the serialized fields,
    wraps the base ``VQAModel`` in ``VQAModelWithSpatialAdapter``, and loads
    the saved weights (non-strict, so adapter-only checkpoints also work).
    Returns (model, vocab, tokenizer) with the model in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab = Vocab()
    vocab.vocab = ckpt['vocab']
    vocab.vocab_size = len(ckpt['vocab'])
    for field in ('word2idx', 'idx2word', 'pad_token_id',
                  'bos_token_id', 'eos_token_id', 'unk_token_id'):
        setattr(vocab, field, ckpt[field])
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    base_model = VQAModel(
        vocab_size=len(ckpt['vocab']),
        device=device,
        question_max_len=ckpt.get('question_max_len', 20),
        answer_max_len=ckpt.get('answer_max_len', 12),
        pad_token_id=ckpt['pad_token_id'],
        bos_token_id=ckpt['bos_token_id'],
        eos_token_id=ckpt['eos_token_id'],
        unk_token_id=ckpt['unk_token_id'],
        hidden_size=512,
        num_layers=2,
    ).to(device)
    # Embeddings must match the tokenizer after the [PAD] token was added.
    base_model.gpt2_model.resize_token_embeddings(len(tokenizer))
    model = VQAModelWithSpatialAdapter(
        base_model=base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3,
    ).to(device)
    model.load_state_dict(ckpt['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
def load_feature_model(checkpoint_path, device='cuda'):
    """Restore the baseline feature-extraction VQA model, vocab, and tokenizer.

    Rebuilds the ``Vocab`` from the checkpoint's serialized fields,
    constructs ``VQAModel`` with the saved hyperparameters, resizes the
    GPT-2 embeddings for the [PAD] token, and loads the weights non-strictly.
    Returns (model, vocab, tokenizer) with the model in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    vocab = Vocab()
    vocab.vocab = ckpt['vocab']
    vocab.vocab_size = len(ckpt['vocab'])
    for field in ('word2idx', 'idx2word', 'pad_token_id',
                  'bos_token_id', 'eos_token_id', 'unk_token_id'):
        setattr(vocab, field, ckpt[field])
    model = VQAModel(
        vocab_size=len(ckpt['vocab']),
        device=device,
        question_max_len=ckpt.get('question_max_len', 20),
        answer_max_len=ckpt.get('answer_max_len', 12),
        pad_token_id=ckpt['pad_token_id'],
        bos_token_id=ckpt['bos_token_id'],
        eos_token_id=ckpt['eos_token_id'],
        unk_token_id=ckpt['unk_token_id'],
        hidden_size=512,
        num_layers=2,
    ).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    model.gpt2_model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(ckpt['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
def generate_answer(model, vocab, tokenizer, image_path, question, device='cuda'):
    """Answer ``question`` about the image at ``image_path``.

    Preprocesses the image with the model's CLIP transform, tokenizes the
    question to ``model.question_max_len``, decodes with beam search when
    the model supports it (greedy argmax otherwise), and returns the
    vocab-decoded answer string.
    """
    img = Image.open(image_path).convert('RGB')
    img_tensor = model.clip_preprocess(img).unsqueeze(0).to(device)
    encoded = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt',
    )
    question_batch = {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }
    with torch.no_grad():
        if hasattr(model, 'generate_with_beam_search'):
            token_ids = model.generate_with_beam_search(img_tensor, question_batch, beam_width=5)
        else:
            token_ids = model(img_tensor, question_batch).argmax(dim=-1)
    return vocab.decoder(token_ids[0].cpu().numpy())
def exact_match_accuracy(predictions, ground_truths):
    """
    Calculate exact match accuracy (case-insensitive, stripped).

    Args:
        predictions: List of predicted answers
        ground_truths: List of ground truth answers

    Returns:
        (accuracy, matches): Percentage of exact matches (0-100) and the
        raw match count. (The original docstring documented only a single
        return value, but the function has always returned this tuple.)
    """
    matches = sum(
        1 for pred, gt in zip(predictions, ground_truths)
        if pred.strip().lower() == gt.strip().lower()
    )
    # Guard against division by zero on an empty prediction list.
    accuracy = (matches / len(predictions)) * 100 if predictions else 0
    return accuracy, matches
def vqa_accuracy(predictions, ground_truths_list):
    """
    VQA official accuracy: per question, score = min(#matching annotations / 3, 1).

    Note: This assumes ground_truths_list is a list of lists,
    where each inner list contains multiple human annotations.
    With only one annotation per question, a correct answer scores 1/3
    (not 1), since the metric expects at least three annotators.

    Args:
        predictions: List of predicted answers
        ground_truths_list: List of lists of ground truth answers
    Returns:
        vqa_score: VQA accuracy score (0-100)
    """
    # Guard: the isinstance probe below raised IndexError on empty input.
    if not predictions or not ground_truths_list:
        return 0
    if not isinstance(ground_truths_list[0], list):
        # Promote single-annotation inputs to the list-of-lists shape.
        ground_truths_list = [[gt] for gt in ground_truths_list]
    scores = []
    for pred, gt_list in zip(predictions, ground_truths_list):
        pred_clean = pred.strip().lower()
        matches = sum(1 for gt in gt_list if pred_clean == gt.strip().lower())
        scores.append(min(matches / 3.0, 1.0))
    return (sum(scores) / len(scores)) * 100 if scores else 0
def calculate_anls(prediction, ground_truth, threshold=0.5):
    """
    Compute the Normalized Levenshtein Similarity for one answer pair.

    Both strings are stripped and lower-cased. Similarities below
    ``threshold`` are clipped to 0, per the ANLS definition.

    Args:
        prediction: Predicted answer string
        ground_truth: Ground truth answer string
        threshold: Minimum similarity threshold (default: 0.5)
    Returns:
        anls_score: ANLS score (0-1)
    """
    pred = prediction.strip().lower()
    gt = ground_truth.strip().lower()
    # An empty ground truth only matches an empty prediction.
    if not gt:
        return 1.0 if not pred else 0.0
    longest = max(len(pred), len(gt))
    if longest == 0:
        return 1.0
    similarity = 1 - levenshtein_distance(pred, gt) / longest
    return similarity if similarity >= threshold else 0.0
def average_anls(predictions, ground_truths, threshold=0.5):
    """
    Average ANLS across all prediction/ground-truth pairs.

    Args:
        predictions: List of predicted answers
        ground_truths: List of ground truth answers
        threshold: Minimum similarity threshold
    Returns:
        (avg_anls, anls_scores): mean ANLS scaled to 0-100, plus the
        per-pair scores in input order.
    """
    anls_scores = [
        calculate_anls(pred, gt, threshold)
        for pred, gt in zip(predictions, ground_truths)
    ]
    avg_anls = (sum(anls_scores) / len(anls_scores)) * 100 if anls_scores else 0
    return avg_anls, anls_scores
if __name__ == "__main__":
    # Entry point: load the selected model, run it over the evaluation CSV,
    # compute Exact Match / VQA / ANLS metrics, then mirror the report to
    # <MODEL_TYPE>.txt and dump per-sample results to a CSV.
    print("=" * 80)
    print("VQA EVALUATION: ACCURACY + ANLS")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    print(f"Model: {MODEL_TYPE.upper()}\n")
    # Module-level MODEL_TYPE selects which checkpoint/loader pair to use.
    if MODEL_TYPE == "spatial":
        model, vocab, tokenizer = load_spatial_model(SPATIAL_CHECKPOINT, device)
    else:
        model, vocab, tokenizer = load_feature_model(FEATURE_CHECKPOINT, device)
    print("✓ Model loaded!\n")
    df = pd.read_csv(CSV_PATH)
    if MAX_SAMPLES:
        df = df.head(MAX_SAMPLES)
    print(f"Evaluating {len(df)} samples\n")
    print("Generating predictions...")
    predictions = []
    ground_truths = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        image_path = os.path.join(IMG_DIR, row['image_path'])
        if not os.path.exists(image_path):
            continue  # skip rows whose image file is missing on disk
        try:
            prediction = generate_answer(model, vocab, tokenizer,
                                         image_path, row['question'], device)
            ground_truth = row['answer']
            predictions.append(prediction)
            ground_truths.append(ground_truth)
        except Exception as e:
            # NOTE(review): failures are silently skipped; consider logging `e`.
            continue
    print(f"\n✓ Generated {len(predictions)} predictions\n")
    print("Calculating metrics...\n")
    exact_acc, exact_matches = exact_match_accuracy(predictions, ground_truths)
    # With single-annotation ground truths this caps per-item score at 1/3.
    vqa_acc = vqa_accuracy(predictions, ground_truths)
    anls_score, anls_scores = average_anls(predictions, ground_truths, threshold=0.5)
    print("=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\n📊 Accuracy Metrics:")
    print(f" Exact Match Accuracy: {exact_acc:.2f}% ({exact_matches}/{len(predictions)})")
    print(f" VQA Accuracy: {vqa_acc:.2f}%")
    print(f"\n📊 ANLS Metrics:")
    print(f" Average ANLS (τ=0.5): {anls_score:.2f}%")
    print(f" ANLS Std Dev: {np.std(anls_scores)*100:.2f}%")
    print(f"\n📊 Additional Statistics:")
    print(f" Total samples: {len(predictions)}")
    print(f" Avg prediction length: {np.mean([len(p.split()) for p in predictions]):.2f} words")
    print(f" Avg GT length: {np.mean([len(gt.split()) for gt in ground_truths]):.2f} words")
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTIONS")
    print("=" * 80)
    # Rank samples by ANLS so the extremes can be displayed.
    sorted_indices = np.argsort(anls_scores)
    print("\n🏆 Best Predictions (Highest ANLS):")
    print("-" * 80)
    for i in sorted_indices[-5:][::-1]:
        print(f"\nGround Truth: {ground_truths[i]}")
        print(f"Prediction: {predictions[i]}")
        print(f"ANLS: {anls_scores[i]:.4f}")
        print(f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}")
    print("\n" + "=" * 80)
    print("⚠️ Worst Predictions (Lowest ANLS):")
    print("-" * 80)
    for i in sorted_indices[:5]:
        print(f"\nGround Truth: {ground_truths[i]}")
        print(f"Prediction: {predictions[i]}")
        print(f"ANLS: {anls_scores[i]:.4f}")
        print(f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}")
    print("\n" + "=" * 80)
    print("✅ EVALUATION COMPLETE")
    print("=" * 80)
    # Write the same report (metrics + best/worst samples) to <MODEL_TYPE>.txt.
    with open(f"{MODEL_TYPE}.txt", "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("EVALUATION RESULTS\n")
        f.write("=" * 80 + "\n")
        f.write("\n📊 Accuracy Metrics:\n")
        f.write(f" Exact Match Accuracy: {exact_acc:.2f}% ({exact_matches}/{len(predictions)})\n")
        f.write(f" VQA Accuracy: {vqa_acc:.2f}%\n")
        f.write("\n📊 ANLS Metrics:\n")
        f.write(f" Average ANLS (τ=0.5): {anls_score:.2f}%\n")
        f.write(f" ANLS Std Dev: {np.std(anls_scores)*100:.2f}%\n")
        f.write("\n📊 Additional Statistics:\n")
        f.write(f" Total samples: {len(predictions)}\n")
        f.write(f" Avg prediction length: {np.mean([len(p.split()) for p in predictions]):.2f} words\n")
        f.write(f" Avg GT length: {np.mean([len(gt.split()) for gt in ground_truths]):.2f} words\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("SAMPLE PREDICTIONS\n")
        f.write("=" * 80 + "\n")
        sorted_indices = np.argsort(anls_scores)
        f.write("\n🏆 Best Predictions (Highest ANLS):\n")
        f.write("-" * 80 + "\n")
        for i in sorted_indices[-5:][::-1]:
            f.write(f"\nGround Truth: {ground_truths[i]}\n")
            f.write(f"Prediction: {predictions[i]}\n")
            f.write(f"ANLS: {anls_scores[i]:.4f}\n")
            f.write(
                f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}\n"
            )
        f.write("\n" + "=" * 80 + "\n")
        f.write("⚠️ Worst Predictions (Lowest ANLS):\n")
        f.write("-" * 80 + "\n")
        for i in sorted_indices[:5]:
            f.write(f"\nGround Truth: {ground_truths[i]}\n")
            f.write(f"Prediction: {predictions[i]}\n")
            f.write(f"ANLS: {anls_scores[i]:.4f}\n")
            f.write(
                f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}\n"
            )
    # Per-sample results for downstream analysis.
    results_df = pd.DataFrame({
        'prediction': predictions,
        'ground_truth': ground_truths,
        'anls_score': anls_scores,
        'exact_match': [pred.strip().lower() == gt.strip().lower()
                        for pred, gt in zip(predictions, ground_truths)]
    })
    output_file = f"vqa_evaluation_{MODEL_TYPE}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\n💾 Results saved to: {output_file}")
|
scores/vqa_evaluation_feature.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|