Deva8 commited on
Commit
bb8f662
·
1 Parent(s): 016e102

Deploy VQA Space with model downloader

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.example +6 -0
  2. .gitattributes +12 -34
  3. DATASET_CARD.md +250 -0
  4. Dockerfile +23 -0
  5. HOW_TO_RUN.md +255 -0
  6. PATTERN_MATCHING_FIX.md +86 -0
  7. QUICK_START.md +196 -0
  8. README.md +203 -7
  9. README_COMPLETE.md +530 -0
  10. SETUP_GUIDE.md +118 -0
  11. VQA_ENHANCEMENTS.md +298 -0
  12. __pycache__/backend_api.cpython-312.pyc +0 -0
  13. __pycache__/conversation_manager.cpython-312.pyc +0 -0
  14. __pycache__/ensemble_vqa_app.cpython-312.pyc +0 -0
  15. __pycache__/groq_service.cpython-312.pyc +0 -0
  16. __pycache__/knowledge_graph_service.cpython-312.pyc +0 -0
  17. __pycache__/llm_reasoning_service.cpython-312.pyc +0 -0
  18. __pycache__/model_spatial.cpython-312.pyc +0 -0
  19. __pycache__/semantic_neurosymbolic_vqa.cpython-312.pyc +0 -0
  20. architecture_draft.html +89 -0
  21. architecture_draft.mmd +69 -0
  22. backend_api.py +341 -0
  23. continue.py +344 -0
  24. continued_training_metric.csv +21 -0
  25. conversation_manager.py +312 -0
  26. download_models.py +27 -0
  27. draft_generator.py +112 -0
  28. ensemble_vqa_app.py +458 -0
  29. enterprise_architecture.drawio +341 -0
  30. exp_results/feature_extraction_metric.csv +31 -0
  31. experiments/__pycache__/train.cpython-312.pyc +0 -0
  32. experiments/test.py +73 -0
  33. experiments/train.py +349 -0
  34. experiments/utils/preprocess.py +164 -0
  35. experiments/utils/vocab.py +65 -0
  36. finetune.py +220 -0
  37. finetune2.py +395 -0
  38. genvqa-dataset.py +78 -0
  39. groq_service.py +118 -0
  40. knowledge_graph_service.py +291 -0
  41. llm_reasoning_service.py +292 -0
  42. model.py +224 -0
  43. model_spatial.py +309 -0
  44. models/__pycache__/model.cpython-312.pyc +0 -0
  45. models/model.py +224 -0
  46. quick_start.bat +71 -0
  47. requirements_api.txt +14 -0
  48. scores/feature.txt +77 -0
  49. scores/score.py +300 -0
  50. scores/vqa_evaluation_feature.csv +0 -0
.env.example ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Groq API Configuration
2
+ # Get your API key from: https://console.groq.com/keys
3
+ GROQ_API_KEY=your_groq_api_key_here
4
+
5
+ # Optional: Model selection (default: llama-3.3-70b-versatile)
6
+ # GROQ_MODEL=llama-3.3-70b-versatile
.gitattributes CHANGED
@@ -1,35 +1,13 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  *.pt filter=lfs diff=lfs merge=lfs -text
2
+ *.bin filter=lfs diff=lfs merge=lfs -text
3
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
4
+ *.json filter=lfs diff=lfs merge=lfs -text
9
+ *.csv filter=lfs diff=lfs merge=lfs -text
 
DATASET_CARD.md ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VQA v2 Curated Dataset for Spatial Reasoning
2
+
3
+ ## Dataset Description
4
+
5
+ This is a **curated and balanced subset** of the VQA v2 (Visual Question Answering v2.0) dataset, specifically preprocessed for training visual question answering models with enhanced spatial reasoning capabilities.
6
+
7
+ ### Dataset Summary
8
+
9
+ - **Source**: VQA v2 (MSCOCO train2014 split)
10
+ - **Task**: Visual Question Answering
11
+ - **Language**: English
12
+ - **License**: CC BY 4.0 (inherited from VQA v2)
13
+
14
+ ### Key Features
15
+
16
+ ✨ **Quality-Focused Curation**:
17
+ - Filtered out ambiguous yes/no questions
18
+ - Removed vague questions ("what is in the image", etc.)
19
+ - Answer length limited to 5 words / 30 characters
20
+ - Minimum answer frequency threshold (20 occurrences)
21
+
22
+ 🎯 **Balanced Distribution**:
23
+ - Maximum 600 samples per answer class
24
+ - Prevents model bias toward common answers
25
+ - Ensures diverse question-answer coverage
26
+
27
+ 📊 **Dataset Statistics**:
28
+ - **Total Q-A pairs**: ~[Your final count from running the script]
29
+ - **Unique answers**: ~[Number of unique answer classes]
30
+ - **Images**: MSCOCO train2014 subset
31
+ - **Format**: JSON + CSV metadata
32
+
33
+ ---
34
+
35
+ ## Dataset Structure
36
+
37
+ ### Data Fields
38
+
39
+ Each sample contains:
40
+
41
+ ```json
42
+ {
43
+ "image_id": 123456, // MSCOCO image ID
44
+ "question_id": 789012, // VQA v2 question ID
45
+ "question": "What color is the car?",
46
+ "answer": "red", // Most frequent answer from annotators
47
+ "image_path": "images/COCO_train2014_000000123456.jpg"
48
+ }
49
+ ```
50
+
51
+ ### Data Splits
52
+
53
+ - **Training**: Main dataset (recommend 80-90% for training)
54
+ - **Validation**: User-defined split (recommend 10-20% for validation)
55
+
56
+ ### File Structure
57
+
58
+ ```
59
+ gen_vqa_v2/
60
+ ├── images/ # MSCOCO train2014 images
61
+ │ └── COCO_train2014_*.jpg
62
+ ├── qa_pairs.json # Question-answer pairs (JSON)
63
+ └── metadata.csv # Same data in CSV format
64
+ ```
65
+
66
+ ---
67
+
68
+ ## Data Preprocessing
69
+
70
+ ### Filtering Criteria
71
+
72
+ **Excluded Answers**:
73
+ - Generic responses: `yes`, `no`, `unknown`, `none`, `n/a`, `cant tell`, `not sure`
74
+
75
+ **Excluded Questions**:
76
+ - Ambiguous queries: "what is in the image", "what is this", "what is that", "what do you see"
77
+
78
+ **Answer Constraints**:
79
+ - Maximum 5 words per answer
80
+ - Maximum 30 characters per answer
81
+ - Minimum frequency: 20 occurrences across dataset
82
+
83
+ **Balancing Strategy**:
84
+ - Maximum 600 samples per answer class
85
+ - Prevents over-representation of common answers (e.g., "white", "2")
86
+
87
+ ### Preprocessing Script
88
+
89
+ The dataset was generated using `genvqa-dataset.py`:
90
+
91
+ ```python
92
+ # Key parameters
93
+ MIN_ANSWER_FREQ = 20 # Minimum answer occurrences
94
+ MAX_SAMPLES_PER_ANSWER = 600 # Class balancing limit
95
+ ```
96
+
97
+ ---
98
+
99
+ ## Intended Use
100
+
101
+ ### Primary Use Cases
102
+
103
+ ✅ **Training VQA Models**:
104
+ - Visual question answering systems
105
+ - Multimodal vision-language models
106
+ - Spatial reasoning research
107
+
108
+ ✅ **Research Applications**:
109
+ - Evaluating spatial understanding in VQA
110
+ - Studying answer distribution bias
111
+ - Benchmarking ensemble architectures
112
+
113
+ ### Out-of-Scope Use
114
+
115
+ ❌ Medical diagnosis or safety-critical applications
116
+ ❌ Surveillance or privacy-invasive systems
117
+ ❌ Generating misleading or harmful content
118
+
119
+ ---
120
+
121
+ ## Dataset Creation
122
+
123
+ ### Source Data
124
+
125
+ **VQA v2 Dataset**:
126
+ - **Paper**: [Making the V in VQA Matter](https://arxiv.org/abs/1612.00837)
127
+ - **Authors**: Goyal et al. (2017)
128
+ - **Images**: MSCOCO train2014
129
+ - **Original Size**: 443,757 question-answer pairs (train split)
130
+
131
+ ### Curation Rationale
132
+
133
+ This curated subset addresses common VQA training challenges:
134
+
135
+ 1. **Bias Reduction**: Limits over-represented answers
136
+ 2. **Quality Control**: Removes ambiguous/uninformative samples
137
+ 3. **Spatial Focus**: Retains questions requiring spatial reasoning
138
+ 4. **Practical Constraints**: Focuses on concise, specific answers
139
+
140
+ ### Annotations
141
+
142
+ Annotations are inherited from VQA v2:
143
+ - 10 answers per question from human annotators
144
+ - **Answer selection**: Most frequent answer among annotators
145
+ - **Consensus**: Majority voting for ground truth
146
+
147
+ ---
148
+
149
+ ## Considerations for Using the Data
150
+
151
+ ### Social Impact
152
+
153
+ This dataset inherits biases from MSCOCO and VQA v2:
154
+ - **Geographic bias**: Primarily Western/North American scenes
155
+ - **Cultural bias**: Limited representation of global diversity
156
+ - **Object bias**: Common objects over-represented
157
+
158
+ ### Limitations
159
+
160
+ ⚠️ **Known Issues**:
161
+ - Answer distribution still skewed toward common objects (e.g., "white", "2", "yes")
162
+ - Spatial reasoning questions may be underrepresented
163
+ - Some questions may have multiple valid answers
164
+
165
+ ⚠️ **Not Suitable For**:
166
+ - Fine-grained visual reasoning (e.g., "How many stripes on the 3rd zebra?")
167
+ - Rare object recognition
168
+ - Non-English languages
169
+
170
+ ---
171
+
172
+ ## Citation
173
+
174
+ ### BibTeX
175
+
176
+ ```bibtex
177
+ @inproceedings{goyal2017making,
178
+ title={Making the V in VQA Matter: Elevating the Role of Image Understanding in Visual Question Answering},
179
+ author={Goyal, Yash and Khot, Tejas and Summers-Stay, Douglas and Batra, Dhruv and Parikh, Devi},
180
+ booktitle={CVPR},
181
+ year={2017}
182
+ }
183
+ ```
184
+
185
+ ### Original VQA v2 Dataset
186
+
187
+ - **Homepage**: https://visualqa.org/
188
+ - **Paper**: https://arxiv.org/abs/1612.00837
189
+ - **License**: CC BY 4.0
190
+
191
+ ---
192
+
193
+ ## Additional Information
194
+
195
+ ### Dataset Curators
196
+
197
+ Curated from VQA v2 by [Your Name/Organization]
198
+
199
+ ### Licensing
200
+
201
+ This dataset is released under **CC BY 4.0**, consistent with the original VQA v2 license.
202
+
203
+ ### Contact
204
+
205
+ For questions or issues, please contact [your email/GitHub].
206
+
207
+ ---
208
+
209
+ ## Usage Example
210
+
211
+ ### Loading the Dataset
212
+
213
+ ```python
214
+ import json
215
+ import pandas as pd
216
+ from PIL import Image
217
+
218
+ # Load metadata
219
+ with open("gen_vqa_v2/qa_pairs.json", "r") as f:
220
+ data = json.load(f)
221
+
222
+ # Or use CSV
223
+ df = pd.read_csv("gen_vqa_v2/metadata.csv")
224
+
225
+ # Access a sample
226
+ sample = data[0]
227
+ image = Image.open(f"gen_vqa_v2/{sample['image_path']}")
228
+ question = sample['question']
229
+ answer = sample['answer']
230
+
231
+ print(f"Q: {question}")
232
+ print(f"A: {answer}")
233
+ ```
234
+
235
+ ### Training Split
236
+
237
+ ```python
238
+ from sklearn.model_selection import train_test_split
239
+
240
+ # 80-20 train-val split
241
+ train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
242
+ ```
243
+
244
+ ---
245
+
246
+ ## Acknowledgments
247
+
248
+ - **VQA v2 Team**: Goyal et al. for the original dataset
249
+ - **MSCOCO Team**: Lin et al. for the image dataset
250
+ - **Community**: Open-source VQA research community
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
2
+
3
+ WORKDIR /app
4
+
5
+ # System deps
6
+ RUN apt-get update && apt-get install -y \
7
+ git \
8
+ libglib2.0-0 \
9
+ libsm6 \
10
+ libxrender1 \
11
+ libxext6 \
12
+ libgl1-mesa-glx \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Install Python deps
16
+ COPY requirements_api.txt .
17
+ RUN pip install --no-cache-dir -r requirements_api.txt
18
+
19
+ # Copy all project files
20
+ COPY . .
21
+
22
+ # Download models before starting server
23
+ CMD python download_models.py && uvicorn backend_api:app --host 0.0.0.0 --port 7860
HOW_TO_RUN.md ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 How to Run the VQA Mobile App
2
+
3
+ ## Quick Overview
4
+
5
+ You now have a complete React Native mobile app for Visual Question Answering! Here's what was created:
6
+
7
+ ### ✅ What's Built
8
+
9
+ 1. **Backend API** (`backend_api.py`)
10
+ - FastAPI server wrapping your ensemble VQA models
11
+ - Automatic routing between base and spatial models
12
+ - Image upload and question answering endpoints
13
+
14
+ 2. **Mobile App** (`ui/` folder)
15
+ - Beautiful React Native app with Expo
16
+ - Google OAuth authentication
17
+ - Camera and gallery image picker
18
+ - Question input and answer display
19
+ - Model routing visualization
20
+
21
+ ## 🎯 Running the App (4 Steps)
22
+
23
+ ### Step 1: Start the Backend Server
24
+
25
+ ```bash
26
+ # Open PowerShell/Terminal
27
+ cd c:\Users\rdeva\Downloads\vqa_coes
28
+
29
+ # Install API dependencies (FIRST TIME ONLY)
30
+ # If you get import errors, run this:
31
+ pip install fastapi uvicorn python-multipart
32
+
33
+ # Start the server
34
+ python start_backend.py
35
+ # Or: python backend_api.py
36
+ ```
37
+
38
+ > **Note**: If you get "ModuleNotFoundError", see [IMPORT_ERRORS_FIX.md](IMPORT_ERRORS_FIX.md) for solutions.
39
+
40
+ ✅ **Keep this window open!** The server must stay running.
41
+
42
+ You should see:
43
+ ```
44
+ 🚀 INITIALIZING ENSEMBLE VQA SYSTEM
45
+ ✅ Ensemble ready!
46
+ ```
47
+
48
+ ### Step 2: Configure the Mobile App
49
+
50
+ 1. **Find your local IP address:**
51
+ ```bash
52
+ ipconfig
53
+ ```
54
+ Look for "IPv4 Address" (e.g., `192.168.1.100`)
55
+
56
+ 2. **Update the API URL:**
57
+ - Open: `ui\src\config\api.js`
58
+ - Change line 8:
59
+ ```javascript
60
+ export const API_BASE_URL = 'http://YOUR_IP_HERE:8000';
61
+ ```
62
+ - Example:
63
+ ```javascript
64
+ export const API_BASE_URL = 'http://192.168.1.100:8000';
65
+ ```
66
+
67
+ ### Step 3: Start the Mobile App
68
+
69
+ ```bash
70
+ # Open a NEW PowerShell/Terminal window
71
+ cd c:\Users\rdeva\Downloads\vqa_coes\ui
72
+
73
+ # Start Expo
74
+ npm start
75
+ ```
76
+
77
+ You'll see a QR code in the terminal.
78
+
79
+ ### Step 4: Run on Your Phone
80
+
81
+ 1. **Install Expo Go** on your smartphone:
82
+ - [Android - Play Store](https://play.google.com/store/apps/details?id=host.exp.exponent)
83
+ - [iOS - App Store](https://apps.apple.com/app/expo-go/id982107779)
84
+
85
+ 2. **Scan the QR code:**
86
+ - Android: Open Expo Go → Scan QR
87
+ - iOS: Open Camera → Scan QR → Tap notification
88
+
89
+ 3. **Wait for the app to load** (first time takes ~1-2 minutes)
90
+
91
+ ## 📱 Using the App
92
+
93
+ ### Option A: Test Without Google Login
94
+
95
+ For quick testing, you can bypass Google authentication:
96
+
97
+ 1. Open `ui\App.js`
98
+ 2. Find line 23-27 and replace with:
99
+ ```javascript
100
+ <Stack.Screen name="Home" component={HomeScreen} />
101
+ ```
102
+ 3. Save and reload the app (shake phone → Reload)
103
+
104
+ ### Option B: Set Up Google Login
105
+
106
+ 1. Go to [Google Cloud Console](https://console.cloud.google.com/)
107
+ 2. Create a new project
108
+ 3. Enable the Google People API (the legacy Google+ API has been shut down)
109
+ 4. Create OAuth 2.0 credentials
110
+ 5. Update `ui\src\config\google.js` with your client IDs
111
+
112
+ ### Testing VQA Functionality
113
+
114
+ 1. **Select an image:**
115
+ - Tap "Camera" to take a photo
116
+ - Tap "Gallery" to choose existing image
117
+
118
+ 2. **Ask a question:**
119
+ - Type your question (e.g., "What color is the car?")
120
+ - Tap "Ask Question"
121
+
122
+ 3. **View the answer:**
123
+ - See the AI-generated answer
124
+ - Check which model was used:
125
+ - 🔍 **Base Model** - General questions
126
+ - 📍 **Spatial Model** - Spatial questions (left, right, above, etc.)
127
+
128
+ ## 🧪 Example Questions to Try
129
+
130
+ ### General Questions (Base Model 🔍)
131
+ - "What color is the car?"
132
+ - "How many people are in the image?"
133
+ - "What room is this?"
134
+ - "Is there a dog?"
135
+
136
+ ### Spatial Questions (Spatial Model 📍)
137
+ - "What is to the right of the table?"
138
+ - "What is above the chair?"
139
+ - "What is next to the door?"
140
+ - "What is on the left side?"
141
+
142
+ ## 🔧 Troubleshooting
143
+
144
+ ### "Cannot connect to server"
145
+ - ✅ Check backend is running (`python backend_api.py`)
146
+ - ✅ Verify IP address in `api.js` matches your computer's IP
147
+ - ✅ Ensure phone and computer are on the **same WiFi network**
148
+ - ✅ Check Windows Firewall isn't blocking port 8000
149
+
150
+ ### "Model not loaded"
151
+ - ✅ Ensure these files exist in `c:\Users\rdeva\Downloads\vqa_coes\`:
152
+ - `vqa_checkpoint.pt`
153
+ - `vqa_spatial_checkpoint.pt`
154
+ - ✅ Check backend terminal for error messages
155
+
156
+ ### App won't load on phone
157
+ - ✅ Verify Expo Go is installed
158
+ - ✅ Both devices on same WiFi
159
+ - ✅ Try restarting Expo: Press `Ctrl+C`, then `npm start`
160
+ - ✅ Clear cache: `npm start -- --clear`
161
+
162
+ ### Camera/Gallery not working
163
+ - ✅ Grant permissions when prompted
164
+ - ✅ Check phone Settings → App Permissions
165
+
166
+ ## 📁 Project Structure
167
+
168
+ ```
169
+ vqa_coes/
170
+ ├── backend_api.py # FastAPI backend server
171
+ ├── ensemble_vqa_app.py # Your existing ensemble system
172
+ ├── model_spatial.py # Spatial model
173
+ ├── models/model.py # Base model
174
+ ├── vqa_checkpoint.pt # Base model weights
175
+ ├── vqa_spatial_checkpoint.pt # Spatial model weights
176
+ ├── requirements_api.txt # Backend dependencies
177
+ ├── QUICK_START.md # This guide
178
+ └── ui/ # Mobile app
179
+ ├── App.js # Main app component
180
+ ├── app.json # Expo configuration
181
+ ├── package.json # Dependencies
182
+ └── src/
183
+ ├── config/
184
+ │ ├── api.js # ⚠️ UPDATE YOUR IP HERE
185
+ │ └── google.js # Google OAuth config
186
+ ├── contexts/
187
+ │ └── AuthContext.js # Authentication
188
+ ├── screens/
189
+ │ ├── LoginScreen.js # Login UI
190
+ │ └── HomeScreen.js # Main VQA UI
191
+ ├── services/
192
+ │ └── api.js # API client
193
+ └── styles/
194
+ ├── theme.js # Design system
195
+ └── globalStyles.js
196
+ ```
197
+
198
+ ## 📚 Documentation
199
+
200
+ - **How to Run**: `HOW_TO_RUN.md` (this file)
201
+ - **Full README**: `ui/README.md`
202
+ - **Implementation Details**: See walkthrough artifact
203
+
204
+ ## 🎨 Customization
205
+
206
+ ### Change Colors
207
+ Edit `ui/src/styles/theme.js`:
208
+ ```javascript
209
+ colors: {
210
+ primary: '#6366F1', // Change to your color
211
+ secondary: '#EC4899', // Change to your color
212
+ // ...
213
+ }
214
+ ```
215
+
216
+ ### Change App Name
217
+ Edit `ui/app.json`:
218
+ ```json
219
+ {
220
+ "expo": {
221
+ "name": "Your App Name",
222
+ "slug": "your-app-slug"
223
+ }
224
+ }
225
+ ```
226
+
227
+ ## 🚢 Next Steps
228
+
229
+ Once everything works:
230
+
231
+ 1. **Add Google OAuth** for production
232
+ 2. **Create custom icons** (see `ui/assets/ICONS_README.md`)
233
+ 3. **Build standalone app**:
234
+ ```bash
235
+ npx eas-cli build --platform android
236
+ ```
237
+
238
+ ## 💡 Tips
239
+
240
+ - **Backend must run first** before starting the mobile app
241
+ - **Same WiFi network** is required for phone and computer
242
+ - **First load is slow** - subsequent loads are faster
243
+ - **Shake phone** to access Expo developer menu
244
+ - **Check logs** in both terminals for debugging
245
+
246
+ ## 🆘 Need Help?
247
+
248
+ 1. Check the troubleshooting section above
249
+ 2. Review backend terminal for errors
250
+ 3. Check Expo console in terminal
251
+ 4. Verify all configuration steps
252
+
253
+ ---
254
+
255
+ **Ready to test?** Follow the 4 steps above and start asking questions about images! 🎉
PATTERN_MATCHING_FIX.md ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Fix: Removed Hardcoded Patterns from Neuro-Symbolic VQA
2
+
3
+ ## Problem Identified
4
+ The `_detect_objects_with_clip()` method in `semantic_neurosymbolic_vqa.py` contained a **predefined list of object categories**, which is essentially pattern matching and defeats the purpose of a truly neuro-symbolic approach.
5
+
6
+ ```python
7
+ # ❌ OLD CODE - Hardcoded categories (pattern matching!)
8
+ object_categories = [
9
+ "food", "soup", "noodles", "rice", "meat", "vegetable", "fruit",
10
+ "bowl", "plate", "cup", "glass", "spoon", "fork", "knife", ...
11
+ ]
12
+ ```
13
+
14
+ This is **not acceptable** because:
15
+ - It limits detection to predefined categories only
16
+ - It's essentially pattern matching, not true neural understanding
17
+ - It violates the neuro-symbolic principle of learning from data
18
+
19
+ ## Solution Applied
20
+
21
+ ### 1. Deprecated `_detect_objects_with_clip()`
22
+ The method now returns an empty list and warns that it's deprecated:
23
+
24
+ ```python
25
+ # ✅ NEW CODE - No predefined lists!
26
+ def _detect_objects_with_clip(self, image_features, image_path=None):
27
+ """
28
+ NOTE: This method is deprecated in favor of using the VQA model
29
+ directly from ensemble_vqa_app.py.
30
+ """
31
+ print("⚠️ _detect_objects_with_clip is deprecated")
32
+ print("→ Use VQA model's _detect_multiple_objects() instead")
33
+ return []
34
+ ```
35
+
36
+ ### 2. Updated `answer_with_clip_features()`
37
+ Now **requires** objects to be provided by the VQA model:
38
+
39
+ ```python
40
+ # ✅ Objects must come from VQA model, not predefined lists
41
+ def answer_with_clip_features(
42
+ self,
43
+ image_features,
44
+ question,
45
+ image_path=None,
46
+ detected_objects: List[str] = None # REQUIRED!
47
+ ):
48
+ if not detected_objects:
49
+ print("⚠️ No objects provided - neuro-symbolic reasoning requires VQA-detected objects")
50
+ return None
51
+ ```
52
+
53
+ ### 3. Ensemble VQA Uses True VQA Detection
54
+ The `ensemble_vqa_app.py` already uses `_detect_multiple_objects()` which:
55
+ - Asks the VQA model **open-ended questions** like "What is this?"
56
+ - Uses the model's learned knowledge, not predefined categories
57
+ - Generates objects dynamically based on visual understanding
58
+
59
+ ```python
60
+ # ✅ TRUE NEURO-SYMBOLIC APPROACH
61
+ detected_objects = self._detect_multiple_objects(image, model, top_k=5)
62
+ # This asks VQA model: "What is this?", "What food is this?", etc.
63
+ # NO predefined categories!
64
+ ```
65
+
66
+ ## Result
67
+
68
+ ✅ **Pure Neuro-Symbolic Pipeline**:
69
+ 1. **VQA Model** detects objects using learned visual understanding (no predefined lists)
70
+ 2. **Wikidata** provides factual knowledge about detected objects
71
+ 3. **LLM** performs Chain-of-Thought reasoning on the facts
72
+ 4. **No pattern matching** anywhere in the pipeline
73
+
74
+ ## Files Modified
75
+ - `semantic_neurosymbolic_vqa.py`:
76
+ - Deprecated `_detect_objects_with_clip()`
77
+ - Updated `answer_with_clip_features()` to require VQA-detected objects
78
+ - Changed knowledge source from "CLIP + Wikidata" to "VQA + Wikidata"
79
+
80
+ ## Verification
81
+ The system now uses a **truly neuro-symbolic approach**:
82
+ - ✅ No hardcoded object categories
83
+ - ✅ No predefined patterns
84
+ - ✅ Pure learned visual understanding from VQA model
85
+ - ✅ Symbolic reasoning from Wikidata + LLM
86
+ - ✅ Chain-of-Thought transparency
QUICK_START.md ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Start Guide - VQA Mobile App
2
+
3
+ This guide will help you get the VQA mobile app running quickly.
4
+
5
+ ## Prerequisites Checklist
6
+
7
+ - [ ] Python 3.8+ installed
8
+ - [ ] Node.js 16+ installed
9
+ - [ ] VQA model checkpoints available
10
+ - [ ] Smartphone with Expo Go app installed
11
+ - [ ] Computer and phone on same WiFi network
12
+
13
+ ## Step-by-Step Setup
14
+
15
+ ### Step 1: Start the Backend Server
16
+
17
+ ```bash
18
+ # Open terminal/PowerShell
19
+ cd c:\Users\rdeva\Downloads\vqa_coes
20
+
21
+ # Install backend dependencies (first time only)
22
+ pip install -r requirements_api.txt
23
+
24
+ # Start the server
25
+ python backend_api.py
26
+ ```
27
+
28
+ **Expected output:**
29
+ ```
30
+ 🚀 INITIALIZING ENSEMBLE VQA SYSTEM
31
+ ⚙️ Device: cuda
32
+ 📥 Loading models...
33
+ ✅ Ensemble ready!
34
+ ```
35
+
36
+ **Important:** Keep this terminal window open! The server must keep running.
37
+
38
+ ### Step 2: Find Your Local IP Address
39
+
40
+ **Windows:**
41
+ ```bash
42
+ ipconfig
43
+ ```
44
+ Look for "IPv4 Address" under your WiFi adapter (e.g., `192.168.1.100`)
45
+
46
+ **Mac/Linux:**
47
+ ```bash
48
+ ifconfig
49
+ # or
50
+ ip addr
51
+ ```
52
+
53
+ ### Step 3: Configure the Mobile App
54
+
55
+ 1. Open `ui/src/config/api.js`
56
+ 2. Replace the IP address:
57
+ ```javascript
58
+ export const API_BASE_URL = 'http://YOUR_IP_HERE:8000';
59
+ // Example: export const API_BASE_URL = 'http://192.168.1.100:8000';
60
+ ```
61
+
62
+ ### Step 4: Configure Google OAuth (Optional for Testing)
63
+
64
+ **For testing without Google login**, you can skip this and modify the app to bypass authentication.
65
+
66
+ **For full Google login:**
67
+
68
+ 1. Go to [Google Cloud Console](https://console.cloud.google.com/)
69
+ 2. Create a project
70
+ 3. Enable the Google People API (the legacy Google+ API has been shut down)
71
+ 4. Create OAuth 2.0 credentials
72
+ 5. Update `ui/src/config/google.js` with your client IDs
73
+
74
+ ### Step 5: Start the Mobile App
75
+
76
+ ```bash
77
+ # Open a NEW terminal/PowerShell
78
+ cd c:\Users\rdeva\Downloads\vqa_coes\ui
79
+
80
+ # Start Expo
81
+ npm start
82
+ ```
83
+
84
+ **Expected output:**
85
+ ```
86
+ Metro waiting on exp://192.168.1.100:8081
87
+ › Scan the QR code above with Expo Go (Android) or the Camera app (iOS)
88
+ ```
89
+
90
+ ### Step 6: Run on Your Phone
91
+
92
+ 1. **Install Expo Go** on your phone:
93
+ - [Android - Play Store](https://play.google.com/store/apps/details?id=host.exp.exponent)
94
+ - [iOS - App Store](https://apps.apple.com/app/expo-go/id982107779)
95
+
96
+ 2. **Scan the QR code**:
97
+ - Android: Open Expo Go app → Scan QR code
98
+ - iOS: Open Camera app → Scan QR code → Tap notification
99
+
100
+ 3. **Wait for app to load** (first time may take 1-2 minutes)
101
+
102
+ ## Testing Without Google Login
103
+
104
+ If you want to test the VQA functionality without setting up Google OAuth:
105
+
106
+ 1. Open `ui/App.js`
107
+ 2. Temporarily modify the navigation to always show HomeScreen:
108
+
109
+ ```javascript
110
+ // Replace this:
111
+ {user ? (
112
+ <Stack.Screen name="Home" component={HomeScreen} />
113
+ ) : (
114
+ <Stack.Screen name="Login" component={LoginScreen} />
115
+ )}
116
+
117
+ // With this:
118
+ <Stack.Screen name="Home" component={HomeScreen} />
119
+ ```
120
+
121
+ 3. Restart the Expo server
122
+
123
+ ## Testing the App
124
+
125
+ ### Test 1: General Question (Base Model)
126
+ 1. Tap "Gallery" and select an image
127
+ 2. Enter question: "What color is the car?"
128
+ 3. Tap "Ask Question"
129
+ 4. Should show: 🔍 Base Model
130
+
131
+ ### Test 2: Spatial Question (Spatial Model)
132
+ 1. Select an image with multiple objects
133
+ 2. Enter question: "What is to the right of the table?"
134
+ 3. Tap "Ask Question"
135
+ 4. Should show: 📍 Spatial Model
136
+
137
+ ## Troubleshooting
138
+
139
+ ### "Cannot connect to server"
140
+ - ✅ Check backend is running
141
+ - ✅ Verify IP address in `api.js` is correct
142
+ - ✅ Ensure phone and computer on same WiFi
143
+ - ✅ Check firewall isn't blocking port 8000
144
+
145
+ ### "Model not loaded"
146
+ - ✅ Check checkpoint files are in project root
147
+ - ✅ Verify file names: `vqa_checkpoint.pt` and `vqa_spatial_checkpoint.pt`
148
+ - ✅ Check backend terminal for error messages
149
+
150
+ ### App won't load on phone
151
+ - ✅ Ensure Expo Go is installed
152
+ - ✅ Check both devices on same network
153
+ - ✅ Try restarting Expo server (Ctrl+C, then `npm start`)
154
+ - ✅ Clear Expo cache: `npm start -- --clear`
155
+
156
+ ### "Permission denied" for camera/gallery
157
+ - ✅ Grant permissions when prompted
158
+ - ✅ Check phone settings → App permissions
159
+
160
+ ## Next Steps
161
+
162
+ Once everything works:
163
+
164
+ 1. **Set up Google OAuth** for production use
165
+ 2. **Customize the UI** in `src/styles/theme.js`
166
+ 3. **Add custom icons** in `assets/` folder
167
+ 4. **Build standalone app** with `eas build`
168
+
169
+ ## Quick Commands Reference
170
+
171
+ ```bash
172
+ # Start backend
173
+ cd c:\Users\rdeva\Downloads\vqa_coes
174
+ python backend_api.py
175
+
176
+ # Start mobile app
177
+ cd c:\Users\rdeva\Downloads\vqa_coes\ui
178
+ npm start
179
+
180
+ # Clear Expo cache
181
+ npm start -- --clear
182
+
183
+ # Install new package
184
+ npm install package-name
185
+
186
+ # Check backend health
187
+ curl http://localhost:8000/health
188
+ ```
189
+
190
+ ## Support
191
+
192
+ If you encounter issues:
193
+ 1. Check the main README.md
194
+ 2. Review backend terminal logs
195
+ 3. Check Expo console for errors
196
+ 4. Verify all prerequisites are met
README.md CHANGED
@@ -1,10 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Vqa Backend
3
- emoji: 🐢
4
- colorFrom: indigo
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # GenVQA — Generative Visual Question Answering
4
+
5
+ **A neuro-symbolic VQA system that detects objects with a neural model, retrieves structured facts from Wikidata, and generates grounded answers with Groq.**
6
+
7
+ [![Backend CI](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/backend-ci.yml/badge.svg)](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/backend-ci.yml)
8
+ [![UI CI](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/ui-ci.yml/badge.svg)](https://github.com/DevaRajan8/Generative-vqa/actions/workflows/ui-ci.yml)
9
+ ![Python](https://img.shields.io/badge/Python-3.10%2B-blue?logo=python)
10
+ ![License](https://img.shields.io/badge/License-MIT-green)
11
+
12
+ </div>
13
+
14
+ ---
15
+
16
+ ## Architecture
17
+
18
+ ```
19
+ ┌─────────────────────────────────────────────────────────────┐
20
+ │ CLIENT LAYER │
21
+ │ 📱 Expo Mobile App (React Native) │
22
+ │ • Image upload + question input │
23
+ │ • Displays answer + accessibility description │
24
+ └────────────────────────┬────────────────────────────────────┘
25
+ │ HTTP POST /api/answer
26
+
27
+ ┌─────────────────────────────────────────────────────────────┐
28
+ │ BACKEND LAYER (FastAPI) │
29
+ │ backend_api.py │
30
+ │ • Request handling, session management │
31
+ │ • Conversation Manager → multi-turn context tracking │
32
+ └────────────────────────┬────────────────────────────────────┘
33
+
34
+
35
+ ┌─────────────────────────────────────────────────────────────┐
36
+ │ ROUTING LAYER (ensemble_vqa_app.py) │
37
+ │ │
38
+ │ CLIP encodes question → compares against: │
39
+ │ "reasoning question" vs "visual/perceptual question" │
40
+ │ │
41
+ │ Reasoning? Visual? │
42
+ │ │ │ │
43
+ │ ▼ ▼ │
44
+ │ ┌─────────────────┐ ┌─────────────────────┐ │
45
+ │ │ NEURO-SYMBOLIC │ │ NEURAL VQA PATH │ │
46
+ │ │ │ │ │ │
47
+ │ │ 1. VQA model │ │ VQA model (GRU + │ │
48
+ │ │ detects obj │ │ Attention) predicts │ │
49
+ │ │ │ │ answer directly │ │
50
+ │ │ 2. Wikidata API │ └──────────┬──────────┘ │
51
+ │ │ fetches facts│ │ │
52
+ │ │ (P31, P2101, │ │ │
53
+ │ │ P2054, P186,│ │ │
54
+ │ │ P366 ...) │ │ │
55
+ │ │ │ │ │
56
+ │ │ 3. Groq LLM │ │ │
57
+ │ │ verbalizes │ │ │
58
+ │ │ from facts │ │ │
59
+ │ └─────────┬───────┘ │ │
60
+ │ └──────────────┬──────────┘ │
61
+ └────────────────────────── │ ─────────────────────────────┘
62
+
63
+
64
+ ┌─────────────────┐
65
+ │ GROQ SERVICE │
66
+ │ Accessibility │
67
+ │ description │
68
+ │ (2 sentences, │
69
+ │ screen-reader │
70
+ │ friendly) │
71
+                      └────────┬────────┘
72
+
73
+
74
+ JSON response
75
+ { answer, model_used,
76
+ kg_enhancement,
77
+ wikidata_entity,
78
+ description }
79
+ ```
80
+
81
+ | Layer | Component | Role |
82
+ |---|---|---|
83
+ | **Client** | Expo React Native | Image upload, question input, answer display |
84
+ | **API** | FastAPI (`backend_api.py`) | Routing, sessions, conversation state |
85
+ | **Conversation** | `conversation_manager.py` | Multi-turn context, history tracking |
86
+ | **Router** | CLIP (in `ensemble_vqa_app.py`) | Classifies question as reasoning vs visual |
87
+ | **Neural VQA** | GRU + Attention (`model.py`) | Answers visual questions directly from image |
88
+ | **Neuro-Symbolic** | `semantic_neurosymbolic_vqa.py` | VQA detects objects → Wikidata fetches facts → Groq verbalizes |
89
+ | **Accessibility** | `groq_service.py` | Generates spoken-friendly 2-sentence description for every answer |
90
+
91
+ ---
92
+
93
+ ## Features
94
+
95
+ - 🔍 **Visual Question Answering** — trained on VQAv2, fine-tuned on custom data
96
+ - 🧠 **Neuro-Symbolic Routing** — CLIP semantically classifies questions as _reasoning_ vs _visual_, routes accordingly
97
+ - 🌐 **Live Wikidata Facts** — queries physical properties, categories, materials, uses in real time
98
+ - 🤖 **Groq Verbalization** — Llama 3.3 70B answers from structured facts, not hallucination
99
+ - 💬 **Conversational Support** — multi-turn conversation manager with context tracking
100
+ - 📱 **Expo Mobile UI** — React Native app for iOS/Android/Web
101
+ - ♿ **Accessibility** — Groq generates spoken-friendly descriptions for every answer
102
+
103
+ ---
104
+
105
+ ## Quick Start
106
+
107
+ ### 1 — Backend
108
+
109
+ ```bash
110
+ # Clone and install
111
+ git clone https://github.com/DevaRajan8/Generative-vqa.git
112
+ cd Generative-vqa
113
+ pip install -r requirements_api.txt
114
+
115
+ # Set your Groq API key
116
+ cp .env.example .env
117
+ # Edit .env → GROQ_API_KEY=your_key_here
118
+
119
+ # Start API
120
+ python backend_api.py
121
+ # → http://localhost:8000
122
+ ```
123
+
124
+ ### 2 — Mobile UI
125
+
126
+ ```bash
127
+ cd ui
128
+ npm install
129
+ npx expo start --clear
130
+ ```
131
+
132
+ > Scan the QR code with Expo Go, or press `w` for browser.
133
+
134
+ ---
135
+
136
+ ## API
137
+
138
+ | Endpoint | Method | Description |
139
+ |---|---|---|
140
+ | `/api/answer` | POST | Answer a question about an uploaded image |
141
+ | `/api/health` | GET | Health check |
142
+ | `/api/conversation/new` | POST | Start a new conversation session |
143
+
144
+ **Example:**
145
+
146
+ ```bash
147
+ curl -X POST http://localhost:8000/api/answer \
148
+ -F "image=@photo.jpg" \
149
+ -F "question=Can this melt?"
150
+ ```
151
+
152
+ **Response:**
153
+
154
+ ```json
155
+ {
156
+ "answer": "ice",
157
+ "model_used": "neuro-symbolic",
158
+ "kg_enhancement": "Yes — ice can melt. [Wikidata P2101: melting point = 0.0 °C]",
159
+ "knowledge_source": "VQA (neural) + Wikidata (symbolic) + Groq (verbalize)",
160
+ "wikidata_entity": "Q86"
161
+ }
162
+ ```
163
+
164
+ ---
165
+
166
+ ## Project Structure
167
+
168
+ ```
169
+ ├── backend_api.py # FastAPI server
170
+ ├── ensemble_vqa_app.py # VQA orchestrator (routing + inference)
171
+ ├── semantic_neurosymbolic_vqa.py # Wikidata KB + Groq verbalizer
172
+ ├── groq_service.py # Groq accessibility descriptions
173
+ ├── conversation_manager.py # Multi-turn conversation tracking
174
+ ├── model.py # VQA model definition
175
+ ├── train.py # Training pipeline
176
+ ├── ui/ # Expo React Native app
177
+ │ └── src/screens/HomeScreen.js
178
+ └── .github/
179
+ ├── workflows/ # CI — backend lint + UI build
180
+ └── ISSUE_TEMPLATE/
181
+ ```
182
+
183
  ---
184
+
185
+ ## Environment Variables
186
+
187
+ | Variable | Required | Description |
188
+ |---|---|---|
189
+ | `GROQ_API_KEY` | ✅ | Groq API key — [get one free](https://console.groq.com) |
190
+ | `MODEL_PATH` | optional | Path to VQA checkpoint (default: `vqa_checkpoint.pt`) |
191
+ | `PORT` | optional | API server port (default: `8000`) |
192
+
193
  ---
194
 
195
+ ## Requirements
196
+
197
+ - Python 3.10+
198
+ - CUDA GPU recommended (CPU works but is slow)
199
+ - Node.js 20+ (for UI)
200
+ - Groq API key (free tier available)
201
+
202
+ ---
203
+
204
+ ## License
205
+
206
+ MIT © [DevaRajan8](https://github.com/DevaRajan8)
README_COMPLETE.md ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # 🧠 GenVQA — Generative Visual Question Answering
4
+
5
+ **A hybrid neuro-symbolic VQA system that intelligently routes between pure neural networks and knowledge-grounded reasoning**
6
+
7
+ </div>
8
+
9
+ ---
10
+
11
+ ## Overview
12
+
13
+ GenVQA is an advanced Visual Question Answering system that combines the best of both worlds:
14
+
15
+ - **Neural networks** for perception-based visual questions
16
+ - **Symbolic reasoning** for knowledge-intensive reasoning questions
17
+
18
+ The system automatically classifies incoming questions and routes them to the optimal processing pipeline, ensuring accurate and grounded answers.
19
+
20
+ ---
21
+
22
+ ## System Architecture
23
+
24
+ ```
25
+ ┌──────────────────────────────────────────────────────────────────┐
26
+ │ CLIENT │
27
+ │ Expo React Native App (iOS/Android/Web) │
28
+ │ • Image upload via camera/gallery │
29
+ │ • Question input with suggested prompts │
30
+ │ • Multi-turn conversational interface │
31
+ │ • Google OAuth authentication │
32
+ └───────────────────────────┬──────────────────────────────────────┘
33
+ │ HTTP POST /api/answer
34
+
35
+ ┌──────────────────────────────────────────────────────────────────┐
36
+ │ BACKEND API LAYER │
37
+ │ FastAPI (backend_api.py) │
38
+ │ • Request handling & validation │
39
+ │ • Session management & authentication │
40
+ │ • Multi-turn conversation tracking │
41
+ └───────────────────────────┬──────────────────────────────────────┘
42
+
43
+
44
+ ┌──────────────────────────────────────────────────────────────────┐
45
+ │ INTELLIGENT ROUTING LAYER │
46
+ │ (ensemble_vqa_app.py) │
47
+ │ │
48
+ │ CLIP Semantic Classifier: │
49
+ │ Encodes question → Compares similarity: │
50
+ │ "This is a reasoning question about facts" │
51
+ │ vs │
52
+ │ "This is a visual perception question" │
53
+ │ │
54
+ │                    Similarity > threshold?                       │
55
+
56
+ │ ├─────────┬────────┐ │
57
+ │ │ │ │ │
58
+ │ REASONING VISUAL SPATIAL │
59
+ │ │ │ │ │
60
+ └─────────────────────┼─────────┼────────┼─────────────────────────┘
61
+ │ │ │
62
+ ┌─────────────┘ │ └─────────────┐
63
+ ▼ ▼ ▼
64
+ ┌──────────────────┐ ┌───────────────────┐ ┌─────────────────┐
65
+ │ NEURO-SYMBOLIC │ │ NEURAL VQA PATH │ │ SPATIAL ADAPTER │
66
+ │ PIPELINE │ │ │ │ PATH │
67
+ │ │ │ CLIP + GRU + │ │ │
68
+ │ ① VQA Model │ │ Attention │ │ Enhanced with │
69
+ │ Detects │ │ │ │ spatial │
70
+ │ Objects │ │ Direct answer │ │ self-attention │
71
+ │ (e.g. "soup") │ │ prediction from │ │ for left/right │
72
+ │ │ │ image features │ │ above/below │
73
+ │ ② Wikidata API │ │ │ │ questions │
74
+ │ Fetches Facts │ │ Outputs: │ │ │
75
+ │ P31: category │ │ "red" │ │ Outputs: │
76
+ │ P186: material│ └───────┬───────────┘ │ "on the left" │
77
+ │ P2101: melting│ │ └────────┬────────┘
78
+ │ P366: use │ │ │
79
+ │ P2054: density│ │ │
80
+ │ │ │ │
81
+ │ ③ Groq LLM │ │ │
82
+ │ Verbalizes │ │ │
83
+ │ from facts │ │ │
84
+ │ (instead
85
+ of free │ │ │
86
+ │ reasoning) │ │ │
87
+ │ │ │ │
88
+ │ Outputs: │ │ │
89
+ │ "Soup is made of │ │ │
90
+ │ water and │ │ │
91
+ │ vegetables, │ │ │
92
+ │ used for eating"│ │ │
93
+ └────────┬─────────┘ │ │
94
+ │ │ │
95
+ └──────────┬──────────┴────────────────────────┘
96
+
97
+ ┌──────────────────────┐
98
+ │ GROQ ACCESSIBILITY │
99
+ │ SERVICE │
100
+ │ │
101
+ │ Generates 2-sentence│
102
+ │ screen-reader │
103
+ │ friendly description│
104
+ │ for every answer │
105
+ └──────────┬───────────┘
106
+
107
+
108
+ JSON Response
109
+ {
110
+ "answer": "...",
111
+ "model_used": "neuro_symbolic|base|spatial",
112
+ "confidence": 0.85,
113
+ "kg_enhancement": true/false,
114
+ "wikidata_entity": "Q123456",
115
+ "description": "...",
116
+ "session_id": "..."
117
+ }
118
+ ```
119
+
120
+ ---
121
+
122
+ ## Neural vs Neuro-Symbolic: Deep Dive
123
+
124
+ ### Neural Pathway
125
+
126
+ **When Used**: Perceptual questions about what's directly visible
127
+
128
+ - _"What color is the car?"_
129
+ - _"How many people are in the image?"_
130
+ - _"Is the dog sitting or standing?"_
131
+
132
+ **Architecture**:
133
+
134
+ ```
135
+ Image Input
136
+
137
+
138
+ ┌─────────────────────────────┐
139
+ │ CLIP Vision Encoder │
140
+ │ (ViT-B/16) │
141
+ │ • Pre-trained on 400M │
142
+ │ image-text pairs │
143
+ │ • 512-dim embeddings │
144
+ └──────────┬──────────────────┘
145
+
146
+
147
+ [512-dim vector] ────────────┐
148
+
149
+ Question Input │
150
+ │ │
151
+ ▼ │
152
+ ┌─────────────────────────────┐ │
153
+ │ GPT-2 Text Encoder │ │
154
+ │ (distilgpt2) │ │
155
+ │ • Contextual embeddings │ │
156
+ │ • 768-dim output │ │
157
+ └──────────┬──────────────────┘ │
158
+ │ │
159
+ ▼ │
160
+ [768-dim vector] │
161
+ │ │
162
+ ▼ │
163
+ ┌──────────────┐ │
164
+ │ Linear Proj │ │
165
+ │ 768 → 512 │ │
166
+ └──────┬───────┘ │
167
+ │ │
168
+ └───────────┬───────────┘
169
+
170
+
171
+ ┌──────────────────────┐
172
+ │ Multimodal Fusion │
173
+ │ • Gated combination │
174
+ │ • 3-layer MLP │
175
+ │ • ReLU + Dropout │
176
+ └──────────┬───────────┘
177
+
178
+
179
+ ┌──────────────────────┐
180
+ │ GRU Decoder with │
181
+ │ Attention Mechanism │
182
+ │ │
183
+ │ • Hidden: 512-dim │
184
+ │ • 2 layers │
185
+ │ • Seq2seq decoding │
186
+ │ • Attention over │
187
+ │ fused features │
188
+ └──────────┬───────────┘
189
+
190
+
191
+ Answer Tokens
192
+ "red car"
193
+ ```
194
+
195
+ **Key Components**:
196
+
197
+ - **CLIP**: Zero-shot image understanding, robust to domain shift
198
+ - **GPT-2**: Contextual question encoding
199
+ - **Attention**: Decoder focuses on relevant image regions per word
200
+ - **GRU**: Sequential answer generation with memory
201
+
202
+ **Training**:
203
+
204
+ - Dataset: VQA v2 (curated, balanced subset)
205
+ - Loss: Cross-entropy over answer vocabulary
206
+ - Fine-tuning: Last 2 CLIP layers + full decoder
207
+ - Accuracy: ~39% on general VQA, ~28% on spatial questions
208
+
209
+ ---
210
+
211
+ ### Neuro-Symbolic Pathway (Knowledge-Grounded Reasoning)
212
+
213
+ **When Used**: Questions requiring external knowledge or reasoning
214
+
215
+ - _"Can soup melt?"_
216
+ - _"What is ice cream made of?"_
217
+ - _"Does this float in water?"_
218
+
219
+ **Architecture**:
220
+
221
+ ```
222
+ Step 1: NEURAL DETECTION
223
+ ─────────────────────────
224
+ Image + Question
225
+
226
+
227
+ ┌──────────────────────┐
228
+ │ VQA Model │
229
+ │ (same as above) │
230
+ │ │
231
+ │ Predicts: "soup" │
232
+ └──────────┬───────────┘
233
+
234
+
235
+ Detected Object
236
+ "soup"
237
+
238
+ Step 2: SYMBOLIC FACT RETRIEVAL
239
+ ────────────────────────────────
240
+ "soup"
241
+
242
+
243
+ ┌──────────────────────────────────┐
244
+ │ Wikidata SPARQL Queries │
245
+ │ │
246
+ │ ① Entity Resolution: │
247
+ │ "soup" → Q41415 (Wikidata ID) │
248
+ │ │
249
+ │ ② Fetch ALL Relevant Properties: │
250
+ │ │
251
+ │ P31 (instance of): │
252
+ │ → "food" │
253
+ │ → "liquid food" │
254
+ │ → "dish" │
255
+ │ │
256
+ │ P186 (made of): │
257
+ │ → "water" │
258
+ │ → "vegetables" │
259
+ │ → "broth" │
260
+ │ │
261
+ │ P366 (used for): │
262
+ │ → "consumption" │
263
+ │ → "nutrition" │
264
+ │ │
265
+ │ P2101 (melting point): │
266
+ │ → (not found) │
267
+ │ │
268
+ │ P2054 (density): │
269
+ │ → ~1000 kg/m³ │
270
+ │ → (floats/sinks calc) │
271
+ │ │
272
+ │ P2777 (flash point): │
273
+ │ → (not found) │
274
+ └──────────────┬───────────────────┘
275
+
276
+
277
+ Structured Knowledge Graph
278
+ {
279
+ "entity": "soup (Q41415)",
280
+ "categories": ["food", "liquid"],
281
+ "materials": ["water", "vegetables"],
282
+ "uses": ["consumption"],
283
+ "density": 1000,
284
+ "melting_point": null
285
+ }
286
+
287
+ Step 3: LLM VERBALIZATION (NOT REASONING!)
288
+ ───────────────────────────────────────────
289
+ Knowledge Graph
290
+
291
+
292
+ ┌────────────────────────────────────┐
293
+ │ Groq API │
294
+ │ (Llama 3.3 70B) │
295
+ │ │
296
+ │ System Prompt: │
297
+ │ "You are a fact verbalizer. │
298
+ │ Answer ONLY from provided │
299
+ │ Wikidata facts. Do NOT use │
300
+ │ your training knowledge. │
301
+ │ If facts don't contain the │
302
+ │ answer, say 'unknown from │
303
+ │ available data'." │
304
+ │ │
305
+ │ User Input: │
306
+ │ Question: "Can soup melt?" │
307
+ │ Facts: {structured data above} │
308
+ └────────────┬───────────────────────┘
309
+
310
+
311
+ Natural Language Answer
312
+ "According to Wikidata, soup is
313
+ a liquid food made of water and
314
+ vegetables. Since it's already
315
+ liquid, it doesn't have a melting
316
+ point like solids do. It can
317
+ freeze, but not melt."
318
+ ```
319
+
320
+ **Critical Design Principle**:
321
+
322
+ > Groq is a **verbalizer**, NOT a reasoner. All reasoning happens in the symbolic layer (Wikidata facts). Groq only translates structured facts into natural language.
323
+
324
+ **Why This Matters**:
325
+
326
+ - **Without facts**: Groq hallucinates from training data
327
+ - **With facts**: Groq grounds answers in real-time data
328
+ - **Result**: Factual accuracy, no made-up information
329
+
330
+ **Knowledge Base Properties Fetched**:
331
+ | Property | Wikidata Code | Example Value |
332
+ |----------|---------------|---------------|
333
+ | Category | P31 | "food", "tool", "animal" |
334
+ | Material | P186 | "metal", "wood", "plastic" |
335
+ | Melting Point | P2101 | 273.15 K (0°C) |
336
+ | Density | P2054 | 917 kg/m³ (floats/sinks) |
337
+ | Use | P366 | "eating", "transportation" |
338
+ | Flash Point | P2777 | 310 K (flammable) |
339
+ | Location | P276 | "ocean", "forest" |
340
+
341
+ ---
342
+
343
+ ### Spatial Reasoning Pathway
344
+
345
+ **When Used**: Questions about relative positions
346
+
347
+ - _"What is to the left of the car?"_
348
+ - _"Is the cat above or below the table?"_
349
+
350
+ **Architecture Enhancement**:
351
+
352
+ ```
353
+ Base VQA Model
354
+
355
+
356
+ ┌──────────────────────────────┐
357
+ │ Spatial Self-Attention │
358
+ │ • Multi-head attention (8) │
359
+ │ • Learns spatial relations │
360
+ │ • Position-aware weighting │
361
+ └──────────┬───────────────────┘
362
+
363
+
364
+ Spatial-aware answer
365
+ "on the left side"
366
+ ```
367
+
368
+ **Keyword Triggering**:
369
+
370
+ - Detects: `left`, `right`, `above`, `below`, `top`, `bottom`, `next to`, `behind`, `between`, etc.
371
+ - Routes to spatial adapter model
372
+ - Enhanced accuracy on positional questions
373
+
374
+ ---
375
+
376
+ ## Intelligent Routing System
377
+
378
+ **CLIP-Based Semantic Routing**:
379
+
380
+ ```python
381
+ # Encode question with CLIP
382
+ question_embedding = clip.encode_text(question)
383
+
384
+ # Compare against two templates
385
+ reasoning_prompt = "This is a reasoning question about facts and knowledge"
386
+ visual_prompt = "This is a visual perception question about what you see"
387
+
388
+ reasoning_similarity = cosine_similarity(question_embedding,
389
+ clip.encode_text(reasoning_prompt))
390
+ visual_similarity = cosine_similarity(question_embedding,
391
+ clip.encode_text(visual_prompt))
392
+
393
+ # Route decision
394
+ if reasoning_similarity > visual_similarity + THRESHOLD:
395
+ route_to_neuro_symbolic()
396
+ elif contains_spatial_keywords(question):
397
+ route_to_spatial_adapter()
398
+ else:
399
+ route_to_base_neural()
400
+ ```
401
+
402
+ **Routing Logic**:
403
+
404
+ 1. **Neuro-Symbolic** if CLIP classifies as reasoning (>0.6 similarity)
405
+ 2. **Spatial** if contains spatial keywords (`left`, `right`, `above`, etc.)
406
+ 3. **Base Neural** for all other visual perception questions
407
+
408
+ ---
409
+
410
+ ## Multi-Turn Conversation Support
411
+
412
+ **Conversation Manager Features**:
413
+
414
+ - Session tracking with UUID
415
+ - Context retention across turns
416
+ - Pronoun resolution (`it`, `this`, `that` → previous object)
417
+ - Automatic session expiry (30 min timeout)
418
+
419
+ **Example Conversation**:
420
+
421
+ ```
422
+ Turn 1:
423
+ User: "What is this?"
424
+ VQA: "A red car"
425
+ Objects: ["car"]
426
+
427
+ Turn 2:
428
+ User: "Can it float?" # "it" = "car"
429
+ System: Resolves "it" → "car"
430
+ VQA: [Neuro-Symbolic] "According to Wikidata, cars are made
431
+ of metal and plastic with density around 800-1000 kg/m³,
432
+ which is close to water. Most cars would sink."
433
+
434
+ Turn 3:
435
+ User: "What color is it again?" # Still referring to car
436
+ VQA: [Neural] "red" # From Turn 1 context
437
+ ```
438
+
439
+ ---
440
+
441
+ ## Quick Start
442
+
443
+ ### Prerequisites
444
+
445
+ - Python 3.10+
446
+ - CUDA GPU (recommended, 4GB+ VRAM)
447
+ - Node.js 16+ (for mobile UI)
448
+ - Groq API key ([get one free](https://console.groq.com))
449
+
450
+ ### Backend Setup
451
+
452
+ ```bash
453
+ # 1. Clone repository
454
+ git clone https://github.com/YourUsername/vqa_coes.git
455
+ cd vqa_coes
456
+
457
+ # 2. Install dependencies
458
+ pip install -r requirements_api.txt
459
+
460
+ # 3. Set environment variables
461
+ echo "GROQ_API_KEY=your_groq_api_key_here" > .env
462
+
463
+ # 4. Download model checkpoints (if not included)
464
+ # Ensure these files exist in project root:
465
+ # - vqa_checkpoint.pt (base model)
466
+ # - vqa_spatial_checkpoint.pt (spatial model)
467
+
468
+ # 5. Start API server
469
+ python backend_api.py
470
+
471
+ # Server will start at http://localhost:8000
472
+ ```
473
+
474
+ ### Mobile UI Setup
475
+
476
+ ```bash
477
+ # 1. Navigate to UI folder
478
+ cd ui
479
+
480
+ # 2. Install dependencies
481
+ npm install
482
+
483
+ # 3. Configure API endpoint
484
+ # Edit ui/src/config/api.js
485
+ # Change: export const API_BASE_URL = 'http://YOUR_LOCAL_IP:8000';
486
+
487
+ # 4. Start Expo
488
+ npx expo start --clear
489
+
490
+ # Scan QR code with Expo Go app, or press 'w' for web
491
+ ```
492
+
493
+ ---
494
+
495
+ ## 🔧 API Reference
496
+
497
+ ### POST `/api/answer`
498
+
499
+ Answer a visual question with optional conversation context.
500
+
501
+ **Request**:
502
+
503
+ ```bash
504
+ curl -X POST http://localhost:8000/api/answer \
505
+ -F "image=@photo.jpg" \
506
+ -F "question=Can this float in water?" \
507
+ -F "session_id=optional-uuid-here"
508
+ ```
509
+
510
+ **Response**:
511
+
512
+ ```json
513
+ {
514
+ "answer": "According to Wikidata, this object has a density of 917 kg/m³, which is less than water (1000 kg/m³), so it would float.",
515
+ "model_used": "neuro_symbolic",
516
+ "confidence": 0.87,
517
+ "kg_enhancement": true,
518
+ "wikidata_entity": "Q41576",
519
+ "description": "The object appears to be made of ice. Based on its physical properties from scientific data, it would float on water due to lower density.",
520
+ "session_id": "550e8400-e29b-41d4-a716-446655440000",
521
+ "conversation_turn": 2
522
+ }
523
+
524
+
525
+ ## 📄 License
526
+
527
+ MIT License - see LICENSE file for details
528
+
529
+ ---
530
+ ```
SETUP_GUIDE.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VQA Accessibility Enhancement - Setup Guide
2
+
3
+ ## Backend Setup
4
+
5
+ ### 1. Install Python Dependencies
6
+ ```bash
7
+ cd c:\Users\rdeva\Downloads\vqa_coes
8
+ pip install -r requirements_api.txt
9
+ ```
10
+
11
+ ### 2. Configure Groq API Key
12
+
13
+ 1. Get your Groq API key from: https://console.groq.com/keys
14
+ 2. Create a `.env` file in the project root:
15
+ ```bash
16
+ copy .env.example .env
17
+ ```
18
+ 3. Edit `.env` and add your API key:
19
+ ```
20
+ GROQ_API_KEY=your_actual_groq_api_key_here
21
+ ```
22
+
23
+ ### 3. Start Backend Server
24
+ ```bash
25
+ python backend_api.py
26
+ ```
27
+
28
+ The server will start on `http://localhost:8000`
29
+
30
+ ---
31
+
32
+ ## Frontend Setup
33
+
34
+ ### 1. Install Node Dependencies
35
+ ```bash
36
+ cd ui
37
+ npm install
38
+ ```
39
+
40
+ This will install the new `expo-speech` package for text-to-speech functionality.
41
+
42
+ ### 2. Start Expo App
43
+ ```bash
44
+ npm start
45
+ ```
46
+
47
+ Then:
48
+ - Press `a` for Android emulator
49
+ - Press `i` for iOS simulator
50
+ - Scan QR code with Expo Go app for physical device
51
+
52
+ ---
53
+
54
+ ## Testing the Features
55
+
56
+ ### Image Display Fix
57
+ 1. Open the app
58
+ 2. Tap "Camera" or "Gallery" to select an image
59
+ 3. **Expected**: Image should display correctly (no blank screen)
60
+
61
+ ### LLM Description Feature
62
+ 1. Upload an image
63
+ 2. Enter a question (e.g., "What color is the car?")
64
+ 3. Tap "Ask Question"
65
+ 4. **Expected**:
66
+ - Original answer appears in the "Answer" card
67
+ - "Accessible Description" card appears below with 2-sentence description
68
+ - Speaker icon button is visible
69
+
70
+ ### Text-to-Speech
71
+ 1. After getting an answer with description
72
+ 2. Tap the speaker icon (🔊) in the "Accessible Description" card
73
+ 3. **Expected**: The description is read aloud
74
+ 4. Tap the stop icon (⏹️) to stop playback
75
+
76
+ ---
77
+
78
+ ## Troubleshooting
79
+
80
+ ### Backend Issues
81
+
82
+ **Groq API Key Error**
83
+ ```
84
+ ValueError: Groq API key not found
85
+ ```
86
+ **Solution**: Make sure `.env` file exists with `GROQ_API_KEY=your_key`
87
+
88
+ **Models Not Loading**
89
+ ```
90
+ ❌ Base checkpoint not found
91
+ ```
92
+ **Solution**: Ensure `vqa_checkpoint.pt` and `vqa_spatial_checkpoint.pt` are in the project root
93
+
94
+ ### Frontend Issues
95
+
96
+ **Image Not Displaying**
97
+ - Make sure you've run `npm install` to get the latest `expo-image` package
98
+ - Check console logs for image URI format issues
99
+
100
+ **Text-to-Speech Not Working**
101
+ - Ensure device volume is turned up
102
+ - Check that `expo-speech` package is installed
103
+ - On iOS simulator, speech may not work (test on physical device)
104
+
105
+ **Cannot Connect to Backend**
106
+ - Verify backend is running on port 8000
107
+ - Update `ui/src/config/api.js` with correct backend URL
108
+ - For physical devices, use ngrok or your computer's local IP
109
+
110
+ ---
111
+
112
+ ## Features Summary
113
+
114
+ ✅ **Fixed**: Image display issue (using expo-image instead of react-native Image)
115
+ ✅ **Added**: Groq LLM integration for 2-sentence descriptions
116
+ ✅ **Added**: Text-to-speech accessibility feature
117
+ ✅ **Added**: Visual distinction between raw answer and description
118
+ ✅ **Added**: Fallback mode when Groq API is unavailable
VQA_ENHANCEMENTS.md ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VQA Enhancements: LLM Reasoning & Conversational VQA
2
+
3
+ This document describes the two major enhancements added to the VQA system.
4
+
5
+ ## 🧠 Feature 1: LLM-Driven Reasoning Engine
6
+
7
+ ### Overview
8
+ Replaced hardcoded if/else rules with **Groq LLM Chain-of-Thought reasoning** for intelligent deductive reasoning from Wikidata facts.
9
+
10
+ ### What Changed
11
+ **Before**: Hardcoded rules in `semantic_neurosymbolic_vqa.py`
12
+ ```python
13
+ if 'melt' in question:
14
+ check material properties...
15
+ ```
16
+
17
+ **After**: LLM-driven reasoning
18
+ ```python
19
+ reasoning_result = llm_service.reason_with_facts(
20
+ object_name="candle",
21
+ facts={"materials": ["wax"], "categories": ["light source"]},
22
+ question="Can this melt?"
23
+ )
24
+ # Returns: Chain-of-Thought reasoning + answer
25
+ ```
26
+
27
+ ### Benefits
28
+ - ✅ Handles complex questions like "Would this survive a fire?"
29
+ - ✅ Provides transparent reasoning chains
30
+ - ✅ More flexible and generalizable
31
+ - ✅ Automatic fallback to rule-based reasoning if LLM fails
32
+
33
+ ### Example
34
+ **Question**: "Can this melt?"
35
+ **Object**: Candle
36
+ **Facts**: Material: wax, Category: light source
37
+
38
+ **LLM Reasoning Chain**:
39
+ 1. The object is a candle
40
+ 2. It is made of wax
41
+ 3. Wax has a low melting point (~60°C)
42
+ 4. Therefore, yes, it can melt at moderate temperatures
43
+
44
+ **Answer**: "Yes, the candle can melt because it's made of wax, which has a low melting point."
45
+
46
+ ### Files Added/Modified
47
+ - **NEW**: `llm_reasoning_service.py` - LLM reasoning with Chain-of-Thought
48
+ - **MODIFIED**: `semantic_neurosymbolic_vqa.py` - Integrated LLM reasoning
49
+ - **MODIFIED**: `backend_api.py` - Added reasoning_chain to API responses
50
+
51
+ ---
52
+
53
+ ## 💬 Feature 2: Conversational VQA
54
+
55
+ ### Overview
56
+ Added **multi-turn conversation support** with context management and pronoun resolution.
57
+
58
+ ### What Changed
59
+ **Before**: Single-shot Q&A with no context
60
+ ```
61
+ User: "What is this?" → System: "A red apple."
62
+ User: "Is it healthy?" → System: "What is 'it'?" ❌
63
+ ```
64
+
65
+ **After**: Multi-turn conversations
66
+ ```
67
+ User: "What is this?" → System: "A red apple."
68
+ User: "Is it healthy?" → System: "Yes, apples are rich in fiber..." ✅
69
+ (System knows "it" = apple)
70
+ ```
71
+
72
+ ### Benefits
73
+ - ✅ Natural follow-up questions
74
+ - ✅ Context-aware pronoun resolution
75
+ - ✅ Session management with auto-expiration
76
+ - ✅ Conversation history tracking
77
+
78
+ ### Example Conversation
79
+ ```
80
+ Turn 1:
81
+ Q: "What is this?"
82
+ A: "A red apple"
83
+ Objects: ["apple"]
84
+
85
+ Turn 2:
86
+ Q: "Is it healthy?"
87
+ Resolved: "Is apple healthy?"
88
+ A: "Yes, apples are rich in fiber and vitamins"
89
+
90
+ Turn 3:
91
+ Q: "What color is it?"
92
+ Resolved: "What color is apple?"
93
+ A: "Red"
94
+ ```
95
+
96
+ ### Files Added/Modified
97
+ - **NEW**: `conversation_manager.py` - Multi-turn conversation management
98
+ - **MODIFIED**: `ensemble_vqa_app.py` - Added `answer_conversational()` method
99
+ - **MODIFIED**: `backend_api.py` - Added conversation endpoints
100
+
101
+ ---
102
+
103
+ ## 🚀 API Endpoints
104
+
105
+ ### Existing Endpoint (Enhanced)
106
+ **POST** `/api/answer`
107
+ - Now includes `reasoning_chain` in response
108
+ - Backward compatible
109
+
110
+ ### New Conversation Endpoints
111
+
112
+ **POST** `/api/conversation/answer`
113
+ - Multi-turn conversation support
114
+ - Request: `image`, `question`, `session_id` (optional)
115
+ - Response includes:
116
+ - `session_id` - For continuing conversation
117
+ - `resolved_question` - Question with pronouns resolved
118
+ - `conversation_context` - Previous turns, objects, etc.
119
+ - `reasoning_chain` - LLM reasoning steps (if applicable)
120
+
121
+ **GET** `/api/conversation/{session_id}/history`
122
+ - Get full conversation history
123
+ - Returns all turns with timestamps
124
+
125
+ **DELETE** `/api/conversation/{session_id}`
126
+ - Clear conversation session
127
+ - Useful for starting fresh
128
+
129
+ ---
130
+
131
+ ## 📋 Usage Examples
132
+
133
+ ### Example 1: LLM Reasoning (Python)
134
+ ```python
135
+ from llm_reasoning_service import get_llm_reasoning_service
136
+
137
+ service = get_llm_reasoning_service()
138
+
139
+ result = service.reason_with_facts(
140
+ object_name="ice cream",
141
+ facts={
142
+ "materials": ["milk", "sugar", "cream"],
143
+ "categories": ["frozen dessert"]
144
+ },
145
+ question="Would this survive in the desert?"
146
+ )
147
+
148
+ print(result['answer'])
149
+ # "No, ice cream would not survive in the desert because..."
150
+
151
+ print(result['reasoning_chain'])
152
+ # ["Ice cream is a frozen dessert", "Deserts are hot...", ...]
153
+ ```
154
+
155
+ ### Example 2: Conversational VQA (API)
156
+ ```bash
157
+ # Turn 1: Ask what it is
158
+ curl -X POST http://localhost:8000/api/conversation/answer \
159
+ -F "image=@apple.jpg" \
160
+ -F "question=What is this?"
161
+
162
+ # Response: {"session_id": "abc123", "answer": "apple", ...}
163
+
164
+ # Turn 2: Follow-up question with pronoun
165
+ curl -X POST http://localhost:8000/api/conversation/answer \
166
+ -F "image=@apple.jpg" \
167
+ -F "question=Is it healthy?" \
168
+ -F "session_id=abc123"
169
+
170
+ # Response: {
171
+ # "resolved_question": "Is apple healthy?",
172
+ # "answer": "Yes, apples are healthy",
173
+ # "conversation_context": {"turn_number": 2, ...}
174
+ # }
175
+ ```
176
+
177
+ ### Example 3: Conversational VQA (Python)
178
+ ```python
179
+ from ensemble_vqa_app import ProductionEnsembleVQA
180
+
181
+ ensemble = ProductionEnsembleVQA(
182
+ base_checkpoint="vqa_checkpoint.pt",
183
+ spatial_checkpoint="vqa_spatial_checkpoint.pt"
184
+ )
185
+
186
+ # Turn 1
187
+ result1 = ensemble.answer_conversational(
188
+ image_path="apple.jpg",
189
+ question="What is this?",
190
+ verbose=True
191
+ )
192
+ session_id = result1['session_id']
193
+ print(f"Answer: {result1['answer']}") # "apple"
194
+
195
+ # Turn 2 - pronoun resolution
196
+ result2 = ensemble.answer_conversational(
197
+ image_path="apple.jpg",
198
+ question="Is it healthy?",
199
+ session_id=session_id,
200
+ verbose=True
201
+ )
202
+ print(f"Resolved: {result2['resolved_question']}") # "Is apple healthy?"
203
+ print(f"Answer: {result2['answer']}") # "Yes, apples are healthy"
204
+ ```
205
+
206
+ ---
207
+
208
+ ## ⚙️ Configuration
209
+
210
+ ### Environment Variables
211
+ ```bash
212
+ # Required for LLM reasoning
213
+ GROQ_API_KEY=your_groq_api_key_here
214
+ ```
215
+
216
+ ### Session Timeout
217
+ Conversations expire after **30 minutes** of inactivity (configurable in `ConversationManager`).
218
+
219
+ ---
220
+
221
+ ## 🧪 Testing
222
+
223
+ Run the test suite:
224
+ ```bash
225
+ python test_vqa_enhancements.py
226
+ ```
227
+
228
+ Tests include:
229
+ - ✅ LLM reasoning with various question types
230
+ - ✅ Conversation manager pronoun resolution
231
+ - ✅ Session management and expiration
232
+ - ✅ Integration with existing VQA system
233
+
234
+ ---
235
+
236
+ ## 🔄 Backward Compatibility
237
+
238
+ **All existing functionality remains intact:**
239
+ - ✅ Original `/api/answer` endpoint works unchanged
240
+ - ✅ Single-shot Q&A still supported
241
+ - ✅ Spatial routing unchanged
242
+ - ✅ Neuro-symbolic fallback preserved
243
+
244
+ **New features are opt-in:**
245
+ - Use `/api/conversation/answer` for multi-turn
246
+ - LLM reasoning activates automatically for reasoning questions
247
+ - Fallback to rule-based if LLM unavailable
248
+
249
+ ---
250
+
251
+ ## 📊 Architecture
252
+
253
```
User Question
      ↓
Ensemble VQA
      ↓
┌─────────────────────────────────┐
│ Conversation Manager            │
│ - Resolve pronouns              │
│ - Track context                 │
└─────────────────────────────────┘
      ↓
┌─────────────────────────────────┐
│ Semantic Neuro-Symbolic VQA     │
│ - Detect objects (VQA)          │
│ - Query Wikidata                │
└─────────────────────────────────┘
      ↓
┌─────────────────────────────────┐
│ LLM Reasoning Service           │
│ - Chain-of-Thought reasoning    │
│ - Fallback to rules             │
└─────────────────────────────────┘
      ↓
Answer + Reasoning Chain
```
278
+
279
+ ---
280
+
281
+ ## 🎯 Key Improvements
282
+
283
+ | Feature | Before | After |
284
+ |---------|--------|-------|
285
+ | **Reasoning** | Hardcoded if/else rules | LLM Chain-of-Thought |
286
+ | **Conversations** | Single-shot only | Multi-turn with context |
287
+ | **Pronouns** | Not handled | Automatic resolution |
288
+ | **Transparency** | Black box | Reasoning chains visible |
289
+ | **Flexibility** | Rigid rules | Adaptive LLM reasoning |
290
+
291
+ ---
292
+
293
+ ## 📝 Notes
294
+
295
+ - LLM reasoning requires `GROQ_API_KEY` environment variable
296
+ - Conversation sessions auto-expire after 30 minutes
297
+ - All features have fallback mechanisms for robustness
298
+ - Zero breaking changes to existing code
__pycache__/backend_api.cpython-312.pyc ADDED
Binary file (14.8 kB). View file
 
__pycache__/conversation_manager.cpython-312.pyc ADDED
Binary file (15.5 kB). View file
 
__pycache__/ensemble_vqa_app.cpython-312.pyc ADDED
Binary file (22.3 kB). View file
 
__pycache__/groq_service.cpython-312.pyc ADDED
Binary file (5.32 kB). View file
 
__pycache__/knowledge_graph_service.cpython-312.pyc ADDED
Binary file (10 kB). View file
 
__pycache__/llm_reasoning_service.cpython-312.pyc ADDED
Binary file (13.4 kB). View file
 
__pycache__/model_spatial.cpython-312.pyc ADDED
Binary file (25.5 kB). View file
 
__pycache__/semantic_neurosymbolic_vqa.cpython-312.pyc ADDED
Binary file (32 kB). View file
 
architecture_draft.html ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ <!DOCTYPE html>
3
+ <html>
4
+ <head>
5
+ <title>VQA Architecture Draft</title>
6
+ <script type="module">
7
+ import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
8
+ mermaid.initialize({ startOnLoad: true, theme: 'dark', flowchart: { curve: 'basis' } });
9
+ </script>
10
+ <style>
11
+ body { background-color: #0D1117; color: white; font-family: sans-serif; display: flex; justify-content: center; padding: 20px; }
12
+ .mermaid { background-color: #161B22; padding: 20px; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.5); }
13
+ </style>
14
+ </head>
15
+ <body>
16
+ <div class="mermaid">
17
+
18
+ graph TD
19
+ %% Styling
20
+ classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
21
+ classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
22
+ classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
23
+ classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
24
+ classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
25
+ classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
26
+ classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;
27
+
28
+ %% Nodes
29
+ UserApp[📱 Mobile App]:::mobile
30
+
31
+ ImgUpload[🖼️ Image]:::preproc
32
+ Question[⌨️ Question Text]:::preproc
33
+
34
+ PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc
35
+
36
+ CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
37
+ GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model
38
+
39
+ Route1{Question<br/>spatial?}:::condition
40
+
41
+ Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
42
+ Base[🧠 Base VQA Model<br/>General VQA]:::model
43
+
44
+ Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
45
+ NeuralAns[💬 Neural Answer]:::final
46
+
47
+ Route2{Knowledge<br/>question?}:::condition
48
+
49
+ ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
50
+ Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
51
+ GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
52
+ KGAns[🧩 KG Enhancement]:::final
53
+
54
+ FastAPI[🚀 FastAPI]:::preproc
55
+ GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
56
+ Audio[🔊 2-sentence description]:::final
57
+
58
+ %% Edges
59
+ UserApp -- "Image uploaded" --> ImgUpload
60
+ UserApp -- "Question typed" --> Question
61
+
62
+ ImgUpload --> PIL
63
+ PIL --> CLIP
64
+ Question --> GPT2
65
+
66
+ CLIP & GPT2 --> Route1
67
+
68
+ Route1 -- "YES" --> Spatial
69
+ Route1 -- "NO" --> Base
70
+
71
+ Spatial & Base -- "Beam search (width=5)" --> Decoder
72
+ Decoder --> NeuralAns
73
+
74
+ CLIP -- "Anchor similarity" --> Route2
75
+
76
+ Route2 -- "YES" --> ObjDet
77
+ ObjDet -- "Detected objects" --> Wikidata
78
+ Wikidata -- "Structured facts" --> GroqV
79
+ GroqV --> KGAns
80
+
81
+ FastAPI -- "Narration request" --> GroqA
82
+ GroqA --> Audio
83
+
84
+ NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
85
+ FastAPI --> UserApp
86
+
87
+ </div>
88
+ </body>
89
+ </html>
architecture_draft.mmd ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ graph TD
3
+ %% Styling
4
+ classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
5
+ classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
6
+ classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
7
+ classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
8
+ classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
9
+ classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
10
+ classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;
11
+
12
+ %% Nodes
13
+ UserApp[📱 Mobile App]:::mobile
14
+
15
+ ImgUpload[🖼️ Image]:::preproc
16
+ Question[⌨️ Question Text]:::preproc
17
+
18
+ PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc
19
+
20
+ CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
21
+ GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model
22
+
23
+ Route1{Question<br/>spatial?}:::condition
24
+
25
+ Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
26
+ Base[🧠 Base VQA Model<br/>General VQA]:::model
27
+
28
+ Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
29
+ NeuralAns[💬 Neural Answer]:::final
30
+
31
+ Route2{Knowledge<br/>question?}:::condition
32
+
33
+ ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
34
+ Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
35
+ GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
36
+ KGAns[🧩 KG Enhancement]:::final
37
+
38
+ FastAPI[🚀 FastAPI]:::preproc
39
+ GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
40
+ Audio[🔊 2-sentence description]:::final
41
+
42
+ %% Edges
43
+ UserApp -- "Image uploaded" --> ImgUpload
44
+ UserApp -- "Question typed" --> Question
45
+
46
+ ImgUpload --> PIL
47
+ PIL --> CLIP
48
+ Question --> GPT2
49
+
50
+ CLIP & GPT2 --> Route1
51
+
52
+ Route1 -- "YES" --> Spatial
53
+ Route1 -- "NO" --> Base
54
+
55
+ Spatial & Base -- "Beam search (width=5)" --> Decoder
56
+ Decoder --> NeuralAns
57
+
58
+ CLIP -- "Anchor similarity" --> Route2
59
+
60
+ Route2 -- "YES" --> ObjDet
61
+ ObjDet -- "Detected objects" --> Wikidata
62
+ Wikidata -- "Structured facts" --> GroqV
63
+ GroqV --> KGAns
64
+
65
+ FastAPI -- "Narration request" --> GroqA
66
+ GroqA --> Audio
67
+
68
+ NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
69
+ FastAPI --> UserApp
backend_api.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
FastAPI Backend for Ensemble VQA Mobile App
Provides REST API endpoints for the React Native mobile application
"""
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import uvicorn
from PIL import Image
import io
import os
import sys
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables (e.g. GROQ_API_KEY) before the services import them.
load_dotenv()

from ensemble_vqa_app import ProductionEnsembleVQA
from groq_service import get_groq_service

app = FastAPI(
    title="Ensemble VQA API",
    description="Visual Question Answering API with ensemble model routing",
    version="1.0.0"
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive — tighten the origin list before public deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Populated once by startup_event(); None means "not loaded / unavailable".
ensemble_model = None
groq_service = None
@app.on_event("startup")
async def startup_event():
    """Load the ensemble VQA models (and optional Groq service) at startup.

    Exits the process when a checkpoint is missing or loading fails, since
    the API cannot serve any request without the models.
    """
    global ensemble_model, groq_service
    print("=" * 80)
    print("🚀 STARTING VQA API SERVER")
    print("=" * 80)
    base_ckpt = "./vqa_checkpoint.pt"
    spatial_ckpt = "./vqa_spatial_checkpoint.pt"
    # Fail fast if either checkpoint file is absent.
    for ckpt, label, filename in (
        (base_ckpt, "Base", "vqa_checkpoint.pt"),
        (spatial_ckpt, "Spatial", "vqa_spatial_checkpoint.pt"),
    ):
        if not os.path.exists(ckpt):
            print(f"❌ {label} checkpoint not found: {ckpt}")
            print(f"Please ensure {filename} is in the project root")
            sys.exit(1)
    try:
        ensemble_model = ProductionEnsembleVQA(
            base_checkpoint=base_ckpt,
            spatial_checkpoint=spatial_ckpt,
            device='cuda'
        )
        print("\n✅ VQA models loaded successfully!")
        # Groq is optional: missing API key degrades to fallback descriptions.
        try:
            groq_service = get_groq_service()
            print("✅ Groq LLM service initialized for accessibility features")
        except ValueError as err:
            print(f"⚠️ Groq service not available: {err}")
            print(" Accessibility descriptions will use fallback mode")
            groq_service = None
        print("📱 Mobile app can now connect")
        print("=" * 80)
    except Exception as err:
        print(f"\n❌ Failed to load models: {err}")
        sys.exit(1)
@app.get("/")
async def root():
    """Service banner: name, version, status and the main endpoints."""
    endpoints = {
        "health": "/health",
        "answer": "/api/answer (POST)"
    }
    return {
        "message": "Ensemble VQA API",
        "version": "1.0.0",
        "status": "running",
        "endpoints": endpoints,
    }
@app.get("/health")
async def health_check():
    """Liveness probe reporting whether the ensemble models are loaded."""
    loaded = ensemble_model is not None
    state = "loaded" if loaded else "not loaded"
    return {
        "status": "healthy",
        "model_loaded": loaded,
        # Both models are loaded together, so they share one state.
        "models": {"base": state, "spatial": state},
    }
@app.post("/api/answer")
async def answer_question(
    image: UploadFile = File(...),
    question: str = Form(...)
):
    """
    Answer a visual question using the ensemble VQA system.

    Args:
        image: Image file (JPEG, PNG)
        question: Question text

    Returns:
        JSON response with answer, model used, accessibility description,
        and metadata.

    Raises:
        HTTPException: 503 if models are not loaded, 400 for an empty
            question or an unreadable image, 500 for unexpected failures.
    """
    import tempfile

    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not question or question.strip() == "":
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    try:
        image_bytes = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")

        # BUGFIX: the previous fixed name "temp_upload.jpg" was clobbered by
        # concurrent requests and leaked when inference raised. Use a unique
        # per-request temp file and guarantee cleanup with try/finally.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            temp_image_path = tmp.name
        try:
            pil_image.save(temp_image_path)
            result = ensemble_model.answer(
                image_path=temp_image_path,
                question=question,
                use_beam_search=True,
                beam_width=5,
                verbose=True
            )
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        is_spatial = ensemble_model.is_spatial_question(question)

        # Accessibility description: prefer the Groq LLM, fall back to a plain
        # "Question/Answer" sentence when the service is missing or fails.
        description = f"Question: {question}. Answer: {result['answer']}."
        description_status = "fallback"
        if groq_service is not None:
            try:
                desc_result = groq_service.generate_description(
                    question=question,
                    answer=result['answer']
                )
                description = desc_result.get('description')
                description_status = desc_result.get('status', 'success')
            except Exception as e:
                print(f"⚠️ Groq description generation failed: {e}")

        # A reasoning chain is only meaningful when KG enhancement ran.
        reasoning_chain = None
        if result.get('kg_enhancement'):
            reasoning_chain = result.get('reasoning_chain', [])

        return JSONResponse(content={
            "success": True,
            "answer": result['answer'],
            "description": description,
            "description_status": description_status,
            "model_used": result['model_used'],
            "confidence": result['confidence'],
            "question_type": "spatial" if is_spatial else "general",
            "question": question,
            "kg_enhancement": result.get('kg_enhancement'),
            "reasoning_type": result.get('reasoning_type', 'neural'),
            "reasoning_chain": reasoning_chain,
            "metadata": {
                "beam_search": True,
                "beam_width": 5
            }
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error processing request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/api/models/info")
async def models_info():
    """Describe the loaded models, routing strategy and conversation support."""
    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    base_info = {
        "name": "Base VQA Model",
        "description": "General visual question answering",
        "accuracy": "50%",
        "use_case": "General questions about objects, colors, counts, etc."
    }
    spatial_info = {
        "name": "Spatial Adapter Model",
        "description": "Spatial reasoning and positional questions",
        "accuracy": "40%",
        "use_case": "Spatial questions (left, right, above, below, etc.)"
    }
    return {
        "base_model": base_info,
        "spatial_model": spatial_info,
        "routing": {
            "method": "Keyword-based classification",
            "spatial_keywords": ensemble_model.SPATIAL_KEYWORDS
        },
        "conversation": {
            "enabled": ensemble_model.conversation_enabled if ensemble_model else False,
            "timeout_minutes": 30
        }
    }
@app.post("/api/conversation/answer")
async def answer_conversational(
    image: UploadFile = File(...),
    question: str = Form(...),
    session_id: str = Form(None)
):
    """
    Answer a visual question with multi-turn conversation support.
    Handles pronoun resolution and maintains conversation context.

    Args:
        image: Image file (JPEG, PNG)
        question: Question text (may contain pronouns like "it", "this")
        session_id: Optional session ID to continue conversation

    Returns:
        JSON response with answer, session_id, resolved question, and context.

    Raises:
        HTTPException: 503 when models are not loaded, 501 when conversation
            mode is disabled, 400 for bad input, 500 for unexpected failures.
    """
    import tempfile

    if ensemble_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not ensemble_model.conversation_enabled:
        raise HTTPException(
            status_code=501,
            detail="Conversational VQA not available. Use /api/answer instead."
        )
    if not question or question.strip() == "":
        raise HTTPException(status_code=400, detail="Question cannot be empty")
    try:
        image_bytes = await image.read()
        try:
            pil_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid image format: {str(e)}")

        # BUGFIX: unique per-request temp file (the shared "temp_upload.jpg"
        # raced between concurrent requests) with guaranteed cleanup even
        # when inference raises.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
            temp_image_path = tmp.name
        try:
            pil_image.save(temp_image_path)
            result = ensemble_model.answer_conversational(
                image_path=temp_image_path,
                question=question,
                session_id=session_id,
                use_beam_search=True,
                beam_width=5,
                verbose=True
            )
        finally:
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)

        # Accessibility description with plain-text fallback.
        description = f"Question: {question}. Answer: {result['answer']}."
        if groq_service is not None:
            try:
                desc_result = groq_service.generate_description(
                    question=result['resolved_question'],
                    answer=result['answer']
                )
                description = desc_result.get('description')
            # BUGFIX: was a bare `except:` (also caught SystemExit /
            # KeyboardInterrupt); keep the best-effort fallback behavior.
            except Exception:
                description = f"Question: {question}. Answer: {result['answer']}."

        return JSONResponse(content={
            "success": True,
            "answer": result['answer'],
            "description": description,
            "session_id": result['session_id'],
            "resolved_question": result['resolved_question'],
            "original_question": question,
            "conversation_context": result['conversation_context'],
            "model_used": result['model_used'],
            "confidence": result['confidence'],
            "kg_enhancement": result.get('kg_enhancement'),
            "reasoning_type": result.get('reasoning_type', 'neural'),
            "reasoning_chain": result.get('reasoning_chain'),
            "metadata": {
                "beam_search": True,
                "beam_width": 5,
                "conversation_enabled": True
            }
        })
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Error processing conversational request: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@app.get("/api/conversation/{session_id}/history")
async def get_conversation_history(session_id: str):
    """Return all recorded turns (with timestamps) for a conversation session.

    Args:
        session_id: Session ID

    Raises:
        HTTPException: 503 if the conversation service is unavailable,
            404 if the session is unknown or has expired.
    """
    if ensemble_model is None or not ensemble_model.conversation_enabled:
        raise HTTPException(status_code=503, detail="Conversation service not available")
    history = ensemble_model.conversation_manager.get_history(session_id)
    if history is None:
        raise HTTPException(
            status_code=404,
            detail=f"Session {session_id} not found or expired"
        )
    payload = {
        "success": True,
        "session_id": session_id,
        "history": history,
        "turn_count": len(history),
    }
    return JSONResponse(content=payload)
@app.delete("/api/conversation/{session_id}")
async def delete_conversation(session_id: str):
    """Delete a conversation session so the user can start fresh.

    Args:
        session_id: Session ID to delete

    Raises:
        HTTPException: 503 if the conversation service is unavailable,
            404 if the session does not exist.
    """
    if ensemble_model is None or not ensemble_model.conversation_enabled:
        raise HTTPException(status_code=503, detail="Conversation service not available")
    if not ensemble_model.conversation_manager.delete_session(session_id):
        raise HTTPException(
            status_code=404,
            detail=f"Session {session_id} not found"
        )
    return JSONResponse(content={
        "success": True,
        "message": f"Session {session_id} deleted"
    })
if __name__ == "__main__":
    # HuggingFace Spaces requires the server to listen on port 7860.
    # BUGFIX: the banner previously claimed "Port: 8000" (and printed
    # localhost:8000 URLs) while the server actually bound 7860; a single
    # constant keeps the message and the behavior in sync.
    PORT = 7860
    print("\n" + "=" * 80)
    print("🚀 ENSEMBLE VQA API SERVER")
    print("=" * 80)
    print("\n📋 Configuration:")
    print(" - Host: 0.0.0.0 (accessible from network)")
    print(f" - Port: {PORT}")
    print(" - Reload: Enabled (development mode)")
    print("\n🔗 Access URLs:")
    print(f" - Local: http://localhost:{PORT}")
    print(f" - Network: http://<your-ip>:{PORT}")
    print(f" - Docs: http://localhost:{PORT}/docs")
    print("\n💡 For mobile testing:")
    print(" 1. Find your local IP: ipconfig (Windows) or ifconfig (Mac/Linux)")
    print(f" 2. Update API_URL in mobile app to http://<your-ip>:{PORT}")
    print(" 3. Ensure phone and computer are on same network")
    print("=" * 80 + "\n")
    uvicorn.run(
        "backend_api:app",
        host="0.0.0.0",
        port=PORT,
        reload=True,
        log_level="info"
    )
continue.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ from transformers import GPT2Tokenizer
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from collections import Counter
12
+ from nltk.tokenize import word_tokenize
13
+ from sklearn.model_selection import train_test_split
14
+ from torchvision import transforms
15
+ from model import VQAModel
16
+ device = 'cuda'
17
class Vocab:
    """Answer-side vocabulary with special tokens and encode/decode helpers."""

    def __init__(self):
        self.vocab = None
        self.vocab_size = None
        self.word2idx = None
        self.idx2word = None
        # Special token strings; their ids are assigned in build_vocab().
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'

    def build_vocab(self, df, min_freq=1):
        """Build the vocabulary from the 'answer' column of *df*.

        Words occurring fewer than *min_freq* times are dropped; the four
        special tokens always occupy indices 0..3.
        """
        counter = Counter()
        for answer in df['answer']:
            counter.update(word_tokenize(answer.lower()))
        kept = sorted(word for word, freq in counter.items() if freq >= min_freq)
        self.vocab = [self.pad, self.bos, self.eos, self.unk] + kept
        self.word2idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.vocab_size = len(self.vocab)
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Tokenize *text*, wrap with <bos>/<eos>, pad or truncate to *max_len*."""
        ids = [self.word2idx.get(token, self.unk_token_id)
               for token in word_tokenize(text.lower())]
        ids = [self.bos_token_id] + ids + [self.eos_token_id]
        if len(ids) < max_len:
            return ids + [self.pad_token_id] * (max_len - len(ids))
        # NOTE(review): truncation may drop the trailing <eos> token.
        return ids[:max_len]

    def decoder(self, token_ids):
        """Map ids back to text, stopping at <eos> and skipping <pad>/<bos>."""
        words = []
        for tid in token_ids:
            if tid == self.eos_token_id:
                break
            if tid in (self.pad_token_id, self.bos_token_id):
                continue
            words.append(self.idx2word.get(tid, "<unk>"))
        return ' '.join(words).strip()
class AugmentedVQADataset(Dataset):
    """VQA dataset pairing CLIP-processed images with tokenized questions and
    vocab-encoded answers, with optional train-time image augmentation.
    """

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        self.text_processor = text_processor
        self.clip_processor = clip_processor
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        # Light geometric/photometric augmentation, applied before CLIP
        # preprocessing; disabled entirely for validation/test splits.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
        ]) if augment else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])
        image = Image.open(img_path).convert('RGB')
        if self.augment and self.transform:
            image = self.transform(image)
        tokenized = self.question_tokenizer(
            row['question'],
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        answer_ids = self.text_processor.encoder(row['answer'], max_len=self.answer_max_len)
        return {
            'image_path': img_path,
            'image': self.clip_processor(image),
            'question_ids': tokenized['input_ids'].squeeze(0),
            'question_mask': tokenized['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(answer_ids, dtype=torch.long)
        }
def save_checkpoint(model, optimizer, epoch, vocab, path):
    """Serialize model/optimizer state plus everything needed for inference.

    The full vocab (token lists, lookup tables, special-token ids) and the
    sequence-length settings are stored alongside the weights so the
    checkpoint can be decoded without rebuilding the vocabulary.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab': vocab.vocab,
        'word2idx': vocab.word2idx,
        'idx2word': vocab.idx2word,
        'pad_token_id': vocab.pad_token_id,
        'bos_token_id': vocab.bos_token_id,
        'eos_token_id': vocab.eos_token_id,
        'unk_token_id': vocab.unk_token_id,
        'question_max_len': model.question_max_len,
        'answer_max_len': model.answer_max_len,
    }
    torch.save(checkpoint, path)
def plot_losses(train_losses, val_losses, save_path="loss_plot.png"):
    """Plot train vs. validation loss per epoch and save the figure to *save_path*."""
    plt.figure(figsize=(8, 6))
    for series, label in ((train_losses, "Train Loss"), (val_losses, "Validation Loss")):
        plt.plot(series, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train vs Validation Loss")
    plt.legend()
    plt.savefig(save_path)
    # Close the figure to free memory when called repeatedly during training.
    plt.close()
def train_one_epoch(model, dataloader, optimizer, device, scaler, vocab):
    """Run one mixed-precision training epoch.

    Uses teacher forcing with label smoothing, gradient clipping at 1.0,
    and AMP scaling.

    Returns:
        (avg_loss, avg_token_acc): mean cross-entropy loss and mean
        non-pad token accuracy across batches.
    """
    model.train()
    running_loss = 0.0
    running_token_acc = 0.0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Teacher forcing: predict token t+1 from tokens <= t.
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            # Token accuracy measured over non-pad positions only.
            preds = shifted_logits.argmax(dim=-1)
            hits = (preds == shifted_answers).float()
            pad_mask = (shifted_answers != vocab.pad_token_id).float()
            running_token_acc += ((hits * pad_mask).sum() / pad_mask.sum()).item()
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to true grads.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    n_batches = len(dataloader)
    return running_loss / n_batches, running_token_acc / n_batches
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate one epoch: loss, token accuracy, and exact-match accuracy.

    Exact match compares a free-running generation pass (no teacher forcing)
    against ground truth after vocab decoding.

    Returns:
        (avg_loss, avg_token_acc, exact_match_acc); all 0.0 when the
        dataloader is empty (previously this raised ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    total_token_acc = 0.0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            # Teacher-forced pass for loss / token accuracy.
            logits = model(images, questions, answer_input_ids=answers)
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            total_loss += loss.item()
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            total_token_acc += ((correct * mask).sum() / mask.sum()).item()
            # Generation pass (no answer ids supplied) for exact match.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    # BUGFIX: guard the divisions against an empty dataloader.
    n_batches = len(dataloader)
    avg_loss = total_loss / n_batches if n_batches else 0.0
    avg_token_acc = total_token_acc / n_batches if n_batches else 0.0
    exact_match_acc = exact_matches / total_samples if total_samples else 0.0
    return avg_loss, avg_token_acc, exact_match_acc
def main():
    """Resume VQA training from a saved checkpoint with identical hyperparameters."""
    print()
    print("# VQA: Continue Training (Same Settings)")
    print()
    # Seed all RNGs so the data split and augmentation order are reproducible.
    import random
    import numpy as np
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
    # --- Paths: resume from the previous continued-training run's checkpoint ---
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    RESUME_CHECKPOINT = r"./output2/continued_training/vqa_checkpoint.pt"
    OUTPUT_DIR = r"./output2/continued_training_2"
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # --- Hyperparameters (kept identical to the original run) ---
    batch_size = 64
    additional_epochs = 50
    patience = 8  # early-stopping window: epochs without exact-match improvement
    question_max_len = 20
    answer_max_len = 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    print(f"Loading checkpoint from: {RESUME_CHECKPOINT}")
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=device)
    start_epoch = checkpoint['epoch'] + 1
    metadata = pd.read_csv(CSV_PATH)
    # Rebuild the vocabulary exactly as saved so token ids stay consistent.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f"Answer Vocab Size: {len(vocab.vocab)}")
    print(f"Resuming from epoch: {start_epoch}")
    # Same seeded 80/10/10 split as the original run, so the splits are stable
    # across resumes (no train/val leakage).
    train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    print()
    model = VQAModel(
        vocab_size=len(vocab.vocab),
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=vocab.pad_token_id,
        bos_token_id=vocab.bos_token_id,
        eos_token_id=vocab.eos_token_id,
        unk_token_id=vocab.unk_token_id,
        hidden_size=512,
        num_layers=2
    ).to(device)
    clip_processor = model.clip_preprocess
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # Resize embeddings BEFORE loading weights so shapes match the checkpoint.
    model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("Model loaded from checkpoint!")
    if model.fine_tuning_mode:
        print("Model already in fine-tuning mode (encoders unfrozen)")
    else:
        print("Continuing with same training configuration")
    print()
    # Augmentation only on the training split.
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=True
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # Only optimize parameters left unfrozen by the checkpointed configuration.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=1e-6, weight_decay=1e-4)
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    if 'optimizer_state_dict' in checkpoint:
        try:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Optimizer state loaded from checkpoint!")
            for param_group in optimizer.param_groups:
                print(f" Loaded LR: {param_group['lr']}")
        except Exception as e:
            # Best-effort resume: a param-group mismatch falls back to a fresh optimizer.
            print(f"Could not load optimizer state: {e}")
            print("Using fresh optimizer")
    else:
        print("No optimizer state in checkpoint, using fresh optimizer")
    print()
    scaler = torch.amp.GradScaler(device)
    best_val_exact_match = 0.0
    counter = 0  # epochs since last exact-match improvement
    logs = []
    if os.path.exists(LOG_CSV):
        # Resume the metrics log so the best score carries over between runs.
        old_logs = pd.read_csv(LOG_CSV)
        logs = old_logs.values.tolist()
        best_val_exact_match = old_logs['val_exact_match'].max()
        print(f"Previous best exact match: {best_val_exact_match:.4f}")
    # NOTE(review): `verbose=True` is deprecated (and removed in recent torch
    # releases) for ReduceLROnPlateau — confirm the pinned torch version.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=4, verbose=True
    )
    total_epochs = start_epoch + additional_epochs
    for epoch in range(start_epoch, total_epochs):
        print(f"\nEpoch {epoch+1}/{total_epochs}")
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, scaler, vocab)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f} | Val Exact Match: {val_exact_match:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")
        # LR scheduling keyed on the same metric as checkpointing/early stopping.
        scheduler.step(val_exact_match)
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, CHECKPOINT_PATH)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement in exact match for {counter} epochs.")
            if counter >= patience:
                # NOTE(review): breaking here skips the logs.append below, so
                # the final (stopping) epoch's metrics are not written to CSV.
                print(f"\nEarly stopping after {patience} epochs without improvement")
                break
        # Persist metrics and the loss curve after every epoch (crash-safe logging).
        logs.append([epoch+1, train_loss, train_token_acc, val_loss, val_token_acc, val_exact_match, optimizer.param_groups[0]['lr']])
        log_df = pd.DataFrame(logs, columns=["epoch","train_loss","train_token_acc","val_loss","val_token_acc","val_exact_match","lr"])
        log_df.to_csv(LOG_CSV, index=False)
        plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\nContinued training complete!")
    print(f"Best exact match accuracy: {best_val_exact_match:.4f}")
if __name__ == "__main__":
    main()
continued_training_metric.csv ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,train_loss,train_token_acc,val_loss,val_token_acc,val_exact_match,lr
2
+ 30,1.9502590653601657,0.7322589969014642,1.3020859152640936,0.6990427433882119,0.38998964956380305,1e-06
3
+ 31,1.9464403521302605,0.7328945476229131,1.3008682691263702,0.7001620300535886,0.3919858051160727,1e-06
4
+ 32,1.9446046694293662,0.733795435205851,1.2995548267971795,0.7003354483617926,0.39220760017743606,1e-06
5
+ 33,1.9418390615673053,0.7339540544097625,1.2990998206835873,0.7004338480391592,0.3923554635516783,1e-06
6
+ 34,1.9405346881137806,0.733893274767451,1.299637350552487,0.7005681339299904,0.39257725861304155,1e-06
7
+ 35,1.9380957318931413,0.7351758044757201,1.2987835997680448,0.7006050677232023,0.39265119030016266,1e-06
8
+ 36,1.9369506880350187,0.7359647384978554,1.2979233675407913,0.7013796053405078,0.39405589235546357,1e-06
9
+ 37,1.9360789391220428,0.7364758676075498,1.2977605515493538,0.7014409610123005,0.39398196066834246,1e-06
10
+ 38,1.9357275886693557,0.7362391176412685,1.297402927054549,0.7011285817848062,0.39353837054561586,1e-06
11
+ 39,1.932767997896227,0.736806065813456,1.2974532218474262,0.7004903276573937,0.39220760017743606,1e-06
12
+ 40,1.9330583925010325,0.7374090065552876,1.2972412691363748,0.7010474972567469,0.39316871211001037,1e-06
13
+ 41,1.9306796564990991,0.7378562083616544,1.2969766115804888,0.7015751037957534,0.39427768741682684,1e-06
14
+ 42,1.9282727334571266,0.7377051650808099,1.2973702516195909,0.7011518692070583,0.39331657548425253,1e-06
15
+ 43,1.9271106582502864,0.7386680361718415,1.2968672679842643,0.7010392338599799,0.39316871211001037,1e-06
16
+ 44,1.9269962475047457,0.7397106953586509,1.296902930680311,0.7012923545432541,0.3936862339198581,1e-06
17
+ 45,1.9244166012701376,0.7400805678048972,1.2962118839880206,0.7011814176473977,0.39353837054561586,1e-06
18
+ 46,1.9289601324296857,0.7377478108470924,1.296351783118158,0.7014656890675707,0.3941298240425847,5e-07
19
+ 47,1.9269490434459369,0.7386778470752227,1.2962728831565604,0.7015336796922503,0.39420375572970573,5e-07
20
+ 48,1.9252020313075702,0.7394137923214155,1.2964817043745294,0.7014302642277952,0.39420375572970573,5e-07
21
+ 49,1.9241666916486853,0.7392096879001484,1.296351099070513,0.7016751350096937,0.39449948247819017,5e-07
conversation_manager.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation Manager for Multi-turn VQA
3
+ Manages conversation state, context, and pronoun resolution
4
+ """
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional, Any
7
+ from datetime import datetime, timedelta
8
+ import uuid
9
+ import re
10
@dataclass
class ConversationTurn:
    """Represents a single turn in a conversation"""
    question: str  # the user's question as asked (before any pronoun resolution)
    answer: str  # the answer produced for this turn
    objects_detected: List[str]  # objects detected for this turn (may be empty)
    timestamp: datetime  # when the turn was recorded
    reasoning_chain: Optional[List[str]] = None  # optional reasoning steps behind the answer
    model_used: Optional[str] = None  # optional identifier of the model that answered
19
@dataclass
class ConversationSession:
    """A complete multi-turn VQA conversation tied to a single image."""
    session_id: str
    image_path: str
    history: List[ConversationTurn] = field(default_factory=list)
    current_objects: List[str] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)
    def add_turn(
        self,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ):
        """Append one Q/A exchange and refresh the activity timestamp."""
        self.history.append(ConversationTurn(
            question=question,
            answer=answer,
            objects_detected=objects_detected,
            timestamp=datetime.now(),
            reasoning_chain=reasoning_chain,
            model_used=model_used,
        ))
        # Only overwrite the tracked objects when the detector found something.
        if objects_detected:
            self.current_objects = objects_detected
        self.last_activity = datetime.now()
    def get_context_summary(self) -> str:
        """Condense the last three turns into one pipe-separated string."""
        if not self.history:
            return "No previous conversation"
        recent = self.history[-3:]
        return " | ".join(
            f"Turn {idx}: Q: {turn.question} A: {turn.answer}"
            for idx, turn in enumerate(recent, 1)
        )
    def is_expired(self, timeout_minutes: int = 30) -> bool:
        """Return True when no activity occurred within the timeout window."""
        deadline = self.last_activity + timedelta(minutes=timeout_minutes)
        return datetime.now() > deadline
61
class ConversationManager:
    """
    Manages multi-turn conversation sessions for VQA.
    Handles context retention, pronoun resolution, and session lifecycle.
    """
    # Pronouns that may refer back to objects mentioned in earlier turns.
    PRONOUNS = ['it', 'this', 'that', 'these', 'those', 'they', 'them']
    def __init__(self, session_timeout_minutes: int = 30):
        """
        Initialize conversation manager
        Args:
            session_timeout_minutes: Minutes before a session expires
        """
        # session_id -> live session; expired entries are purged lazily on access.
        self.sessions: Dict[str, ConversationSession] = {}
        self.session_timeout = session_timeout_minutes
        print(f"✅ Conversation Manager initialized (timeout: {session_timeout_minutes}min)")
    def create_session(self, image_path: str, session_id: Optional[str] = None) -> str:
        """
        Create a new conversation session
        Args:
            image_path: Path to the image for this conversation
            session_id: Optional custom session ID (generates UUID if not provided)
        Returns:
            Session ID
        """
        if session_id is None:
            session_id = str(uuid.uuid4())
        session = ConversationSession(
            session_id=session_id,
            image_path=image_path
        )
        self.sessions[session_id] = session
        return session_id
    def get_session(self, session_id: str) -> Optional[ConversationSession]:
        """
        Get an existing session
        Args:
            session_id: Session ID to retrieve
        Returns:
            ConversationSession or None if not found/expired
        """
        session = self.sessions.get(session_id)
        if session is None:
            return None
        # Expired sessions are deleted on access rather than by a background task.
        if session.is_expired(self.session_timeout):
            self.delete_session(session_id)
            return None
        return session
    def get_or_create_session(
        self,
        session_id: Optional[str],
        image_path: str
    ) -> ConversationSession:
        """
        Get existing session or create new one
        Args:
            session_id: Optional session ID
            image_path: Image path for new session
        Returns:
            ConversationSession
        """
        if session_id:
            session = self.get_session(session_id)
            if session:
                return session
        # Missing or expired: create a fresh session, reusing the caller's ID if given.
        new_id = self.create_session(image_path, session_id)
        return self.sessions[new_id]
    def add_turn(
        self,
        session_id: str,
        question: str,
        answer: str,
        objects_detected: List[str],
        reasoning_chain: Optional[List[str]] = None,
        model_used: Optional[str] = None
    ) -> bool:
        """
        Add a turn to a conversation session
        Args:
            session_id: Session ID
            question: User's question
            answer: VQA answer
            objects_detected: List of detected objects
            reasoning_chain: Optional reasoning steps
            model_used: Optional model identifier
        Returns:
            True if successful, False if session not found
        """
        session = self.get_session(session_id)
        if session is None:
            return False
        session.add_turn(
            question=question,
            answer=answer,
            objects_detected=objects_detected,
            reasoning_chain=reasoning_chain,
            model_used=model_used
        )
        return True
    def resolve_references(
        self,
        question: str,
        session: ConversationSession
    ) -> str:
        """
        Resolve pronouns and references in a question using conversation context.
        Args:
            question: User's question (may contain pronouns)
            session: Conversation session with context
        Returns:
            Question with pronouns resolved
        Example:
            Input: "Is it healthy?"
            Context: Previous object was "apple"
            Output: "Is apple healthy?"
        """
        if not session.history:
            return question
        # Tokenize with a regex so trailing punctuation does not hide a pronoun:
        # a plain whitespace split turned "What color is it?" into the token
        # "it?", which never matched PRONOUNS and skipped resolution entirely.
        tokens = re.findall(r"[a-z']+", question.lower())
        if not any(pronoun in tokens for pronoun in self.PRONOUNS):
            return question
        recent_objects = session.current_objects
        if not recent_objects:
            return question
        resolved = question
        if any(pronoun in tokens for pronoun in ('it', 'this', 'that')):
            # Singular pronouns refer to the most recently detected object.
            primary_object = recent_objects[0]
            for pronoun in ('it', 'this', 'that'):
                resolved = re.sub(rf'\b{pronoun}\b', primary_object, resolved, flags=re.IGNORECASE)
        if any(pronoun in tokens for pronoun in ('these', 'those', 'they', 'them')):
            # Plural pronouns expand to the full list of detected objects.
            objects_phrase = ', '.join(recent_objects)
            for pronoun in ('these', 'those', 'they', 'them'):
                resolved = re.sub(rf'\b{pronoun}\b', objects_phrase, resolved, flags=re.IGNORECASE)
        return resolved
    def get_context_for_question(
        self,
        session_id: str,
        question: str
    ) -> Dict[str, Any]:
        """
        Get relevant context for answering a question
        Args:
            session_id: Session ID
            question: Current question
        Returns:
            Dict with context information
        """
        session = self.get_session(session_id)
        if session is None:
            # Unknown/expired session: return an empty-context placeholder.
            return {
                'has_context': False,
                'turn_number': 0,
                'previous_objects': [],
                'previous_questions': []
            }
        return {
            'has_context': len(session.history) > 0,
            'turn_number': len(session.history) + 1,
            'previous_objects': session.current_objects,
            'previous_questions': [turn.question for turn in session.history[-3:]],
            'previous_answers': [turn.answer for turn in session.history[-3:]],
            'context_summary': session.get_context_summary()
        }
    def get_history(self, session_id: str) -> Optional[List[Dict[str, Any]]]:
        """
        Get conversation history for a session
        Args:
            session_id: Session ID
        Returns:
            List of turn dictionaries or None if session not found
        """
        session = self.get_session(session_id)
        if session is None:
            return None
        history = []
        for turn in session.history:
            history.append({
                'question': turn.question,
                'answer': turn.answer,
                'objects_detected': turn.objects_detected,
                'timestamp': turn.timestamp.isoformat(),
                'reasoning_chain': turn.reasoning_chain,
                'model_used': turn.model_used
            })
        return history
    def delete_session(self, session_id: str) -> bool:
        """
        Delete a conversation session
        Args:
            session_id: Session ID to delete
        Returns:
            True if deleted, False if not found
        """
        if session_id in self.sessions:
            del self.sessions[session_id]
            return True
        return False
    def cleanup_expired_sessions(self):
        """Remove all expired sessions and return how many were removed."""
        expired_ids = [
            sid for sid, session in self.sessions.items()
            if session.is_expired(self.session_timeout)
        ]
        for sid in expired_ids:
            self.delete_session(sid)
        return len(expired_ids)
    def get_active_sessions_count(self) -> int:
        """Get count of active (non-expired) sessions"""
        self.cleanup_expired_sessions()
        return len(self.sessions)
274
+ if __name__ == "__main__":
275
+ print("=" * 80)
276
+ print("🧪 Testing Conversation Manager")
277
+ print("=" * 80)
278
+ manager = ConversationManager(session_timeout_minutes=30)
279
+ print("\n📝 Test 1: Multi-turn conversation")
280
+ session_id = manager.create_session("test_image.jpg")
281
+ print(f"Created session: {session_id}")
282
+ manager.add_turn(
283
+ session_id=session_id,
284
+ question="What is this?",
285
+ answer="apple",
286
+ objects_detected=["apple"]
287
+ )
288
+ print("Turn 1: 'What is this?' → 'apple'")
289
+ session = manager.get_session(session_id)
290
+ question_2 = "Is it healthy?"
291
+ resolved_2 = manager.resolve_references(question_2, session)
292
+ print(f"Turn 2: '{question_2}' → Resolved: '{resolved_2}'")
293
+ manager.add_turn(
294
+ session_id=session_id,
295
+ question=question_2,
296
+ answer="Yes, apples are healthy",
297
+ objects_detected=["apple"]
298
+ )
299
+ question_3 = "What color is it?"
300
+ resolved_3 = manager.resolve_references(question_3, session)
301
+ print(f"Turn 3: '{question_3}' → Resolved: '{resolved_3}'")
302
+ print("\n📝 Test 2: Context retrieval")
303
+ context = manager.get_context_for_question(session_id, "Another question")
304
+ print(f"Turn number: {context['turn_number']}")
305
+ print(f"Previous objects: {context['previous_objects']}")
306
+ print(f"Context summary: {context['context_summary']}")
307
+ print("\n📝 Test 3: Conversation history")
308
+ history = manager.get_history(session_id)
309
+ for i, turn in enumerate(history, 1):
310
+ print(f" Turn {i}: Q: {turn['question']} | A: {turn['answer']}")
311
+ print("\n" + "=" * 80)
312
+ print("✅ Tests completed!")
download_models.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Fetch the VQA model checkpoints from the HuggingFace Hub at startup."""
import os
from huggingface_hub import hf_hub_download

REPO_ID = "Deva8/GENvqa-model"

# We use the token from the environment variable (which the user must set in
# Settings -> Secrets); None falls back to anonymous access for public repos.
HF_TOKEN = os.getenv("HF_TOKEN")

# (checkpoint filename, message printed after a successful download)
CHECKPOINTS = [
    ("vqa_checkpoint.pt", "Base checkpoint downloaded successfully."),
    ("vqa_spatial_checkpoint.pt", "Spatial checkpoint downloaded successfully."),
]

print("Downloading models from HuggingFace Hub...")
for filename, done_message in CHECKPOINTS:
    # Files are placed in the current working directory, where the app expects them.
    hf_hub_download(
        repo_id=REPO_ID,
        filename=filename,
        local_dir=".",
        token=HF_TOKEN
    )
    print(done_message)
draft_generator.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import subprocess  # NOTE(review): unused — mmdc conversion was considered but never wired up (see comment below)
import os

# Mermaid flowchart source describing the VQA system end to end
# (mobile client -> preprocessing -> model routing -> KG/accessibility services).
mermaid_code = """
graph TD
%% Styling
classDef default fill:#1A1A1A,stroke:#444,stroke-width:2px,color:#FFF,rx:8px,ry:8px,font-family:arial;
classDef mobile fill:#003366,stroke:#0055AA,stroke-width:2px,color:#FFF;
classDef preproc fill:#333333,stroke:#555,stroke-width:2px,color:#FFF;
classDef model fill:#4B0082,stroke:#8A2BE2,stroke-width:2px,color:#FFF;
classDef condition fill:#2B2B2B,stroke:#F4A460,stroke-width:2px,color:#FFF,shape:rhombus;
classDef external fill:#004d00,stroke:#009900,stroke-width:2px,color:#FFF;
classDef final fill:#660000,stroke:#CC0000,stroke-width:2px,color:#FFF;

%% Nodes
UserApp[📱 Mobile App]:::mobile

ImgUpload[🖼️ Image]:::preproc
Question[⌨️ Question Text]:::preproc

PIL[🐍 PIL Preprocessing<br/>RGB conversion]:::preproc

CLIP[👁️ OpenAI CLIP ViT-B/32<br/>Image Features 512-dim]:::model
GPT2[🤗 DistilGPT-2<br/>Tokenized Question]:::model

Route1{Question<br/>spatial?}:::condition

Spatial[📐 Spatial VQA Model<br/>8-head attention]:::model
Base[🧠 Base VQA Model<br/>General VQA]:::model

Decoder[🤗 GPT-2 Decoder<br/>vocab decode]:::model
NeuralAns[💬 Neural Answer]:::final

Route2{Knowledge<br/>question?}:::condition

ObjDet[👁️ CLIP Object Detector<br/>Top-3 objects]:::model
Wikidata[🌍 Wikidata SPARQL<br/>P31, P186, P366]:::external
GroqV[⚡ Groq Llama-3.3<br/>Verbalizer]:::external
KGAns[🧩 KG Enhancement]:::final

FastAPI[🚀 FastAPI]:::preproc
GroqA[⚡ Groq Llama-3.3<br/>Accessibility]:::external
Audio[🔊 2-sentence description]:::final

%% Edges
UserApp -- "Image uploaded" --> ImgUpload
UserApp -- "Question typed" --> Question

ImgUpload --> PIL
PIL --> CLIP
Question --> GPT2

CLIP & GPT2 --> Route1

Route1 -- "YES" --> Spatial
Route1 -- "NO" --> Base

Spatial & Base -- "Beam search (width=5)" --> Decoder
Decoder --> NeuralAns

CLIP -- "Anchor similarity" --> Route2

Route2 -- "YES" --> ObjDet
ObjDet -- "Detected objects" --> Wikidata
Wikidata -- "Structured facts" --> GroqV
GroqV --> KGAns

FastAPI -- "Narration request" --> GroqA
GroqA --> Audio

NeuralAns & KGAns & Audio -- "JSON output" --> FastAPI
FastAPI --> UserApp
"""

# NOTE(review): hard-coded absolute Windows path — this script only works on
# the original author's machine as-is.
file_path = r"C:\Users\rdeva\Downloads\vqa_coes\architecture_draft.mmd"

with open(file_path, "w", encoding="utf-8") as f:
    f.write(mermaid_code)

print(f"Mermaid file saved to {file_path}")

# Note: In a real environment, we would use mermaid-cli (mmdc) to convert this to SVG/PNG.
# Since it might not be installed globally, we will just provide the mermaid file and
# instructions, or generate an HTML wrapper that renders it in browser.

html_path = r"C:\Users\rdeva\Downloads\vqa_coes\architecture_draft.html"
# f-string interpolates mermaid_code into the page; literal braces in the
# embedded JS/CSS are escaped as {{ }}.
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>VQA Architecture Draft</title>
<script type="module">
import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
mermaid.initialize({{ startOnLoad: true, theme: 'dark', flowchart: {{ curve: 'basis' }} }});
</script>
<style>
body {{ background-color: #0D1117; color: white; font-family: sans-serif; display: flex; justify-content: center; padding: 20px; }}
.mermaid {{ background-color: #161B22; padding: 20px; border-radius: 10px; box-shadow: 0 4px 8px rgba(0,0,0,0.5); }}
</style>
</head>
<body>
<div class="mermaid">
{mermaid_code}
</div>
</body>
</html>
"""

with open(html_path, "w", encoding="utf-8") as f:
    f.write(html_content)

print(f"HTML viewer saved to {html_path}")
ensemble_vqa_app.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Production Ensemble VQA Application
3
+ Combines base model (general VQA) and spatial adapter (spatial reasoning)
4
+ for optimal performance on all question types.
5
+ NEW: Neuro-Symbolic VQA with Knowledge Graph integration
6
+ NEW: Multi-turn Conversational VQA with context management
7
+ """
8
+ import os
9
+ import torch
10
+ from PIL import Image
11
+ from transformers import GPT2Tokenizer
12
+ from models.model import VQAModel
13
+ from model_spatial import VQAModelWithSpatialAdapter
14
+ from experiments.train import Vocab
15
+ from knowledge_graph_service import KnowledgeGraphService
16
+ from typing import Optional
17
+ import time
18
+ class ProductionEnsembleVQA:
19
+
20
    # Keywords/phrases indicating a spatial-reasoning question; presumably
    # matched against the incoming question to pick the spatial model (the
    # routing method is not visible in this chunk — TODO confirm usage).
    SPATIAL_KEYWORDS = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'up', 'down', 'upward', 'downward',
        'front', 'behind', 'back', 'next to', 'beside', 'near', 'between',
        'in front', 'in back', 'across from', 'opposite', 'adjacent',
        'closest', 'farthest', 'nearest', 'furthest', 'closer', 'farther',
        'where is', 'where are', 'which side', 'what side', 'what direction',
        'on the left', 'on the right', 'at the top', 'at the bottom',
        'to the left', 'to the right', 'in the middle', 'in the center',
        'under', 'over', 'underneath', 'on top of', 'inside', 'outside'
    ]
31
    def __init__(self, base_checkpoint, spatial_checkpoint, device='cuda'):
        """Load both VQA models plus the optional KG and conversation services.

        Args:
            base_checkpoint: Path to the base (general VQA) model checkpoint.
            spatial_checkpoint: Path to the spatial-adapter model checkpoint.
            device: Preferred device; silently falls back to 'cpu' when CUDA
                is unavailable.
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        print("="*80)
        print("🚀 INITIALIZING ENSEMBLE VQA SYSTEM")
        print("="*80)
        print(f"\n⚙️ Device: {self.device}")
        print("\n📥 Loading models...")
        start_time = time.time()
        print(" [1/2] Loading base model (general VQA)...")
        self.base_model, self.vocab, self.tokenizer = self._load_base_model(base_checkpoint)
        print(" ✓ Base model loaded")
        print(" [2/2] Loading spatial model (spatial reasoning)...")
        self.spatial_model, _, _ = self._load_spatial_model(spatial_checkpoint)
        print(" ✓ Spatial model loaded")
        load_time = time.time() - start_time
        # Optional neuro-symbolic layer: imported lazily so a missing
        # dependency degrades to neural-only mode instead of crashing startup.
        print(" [3/3] Initializing Semantic Neuro-Symbolic VQA...")
        try:
            from semantic_neurosymbolic_vqa import SemanticNeurosymbolicVQA
            self.kg_service = SemanticNeurosymbolicVQA(device=self.device)
            print(" ✓ Semantic Neuro-Symbolic VQA ready (CLIP + Wikidata, no pattern matching)")
            self.kg_enabled = True
        except Exception as e:
            print(f" ⚠️ Semantic Neuro-Symbolic VQA unavailable: {e}")
            print(" → Falling back to neural-only mode")
            self.kg_service = None
            self.kg_enabled = False
        print(f"\n✅ Ensemble ready! (loaded in {load_time:.1f}s)")
        print(f"📊 Memory: ~2x single model (~4GB GPU)")
        print(f"🎯 Routing: Automatic based on question type")
        print(f"🧠 Neuro-Symbolic: {'Enabled' if self.kg_enabled else 'Disabled (neural-only)'}")
        print(f"💬 Conversation: Initializing multi-turn support...")
        # Optional multi-turn conversation support; same graceful-degradation pattern.
        try:
            from conversation_manager import ConversationManager
            self.conversation_manager = ConversationManager(session_timeout_minutes=30)
            self.conversation_enabled = True
            print(f" ✓ Conversational VQA ready (multi-turn with context)")
        except Exception as e:
            print(f" ⚠️ Conversation manager unavailable: {e}")
            print(f" → Single-shot Q&A only")
            self.conversation_manager = None
            self.conversation_enabled = False
        print("="*80)
74
    def _load_base_model(self, checkpoint_path):
        """Load base VQA model.

        Restores the answer vocabulary, the DistilGPT-2 question tokenizer,
        and the model weights from a single checkpoint file.

        Returns:
            (model, vocab, tokenizer) with the model in eval mode.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Rebuild vocab from the checkpoint so token ids match training.
        vocab = Vocab()
        vocab.vocab = checkpoint['vocab']
        vocab.vocab_size = len(checkpoint['vocab'])
        vocab.word2idx = checkpoint['word2idx']
        vocab.idx2word = checkpoint['idx2word']
        vocab.pad_token_id = checkpoint['pad_token_id']
        vocab.bos_token_id = checkpoint['bos_token_id']
        vocab.eos_token_id = checkpoint['eos_token_id']
        vocab.unk_token_id = checkpoint['unk_token_id']
        tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        # Max lengths default to the training values when absent from older checkpoints.
        model = VQAModel(
            vocab_size=len(checkpoint['vocab']),
            device=self.device,
            question_max_len=checkpoint.get('question_max_len', 20),
            answer_max_len=checkpoint.get('answer_max_len', 12),
            pad_token_id=checkpoint['pad_token_id'],
            bos_token_id=checkpoint['bos_token_id'],
            eos_token_id=checkpoint['eos_token_id'],
            unk_token_id=checkpoint['unk_token_id'],
            hidden_size=512,
            num_layers=2
        ).to(self.device)
        # Resize embeddings BEFORE loading weights so shapes match the checkpoint.
        model.gpt2_model.resize_token_embeddings(len(tokenizer))
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        model.eval()
        return model, vocab, tokenizer
105
def _load_spatial_model(self, checkpoint_path):
    """Restore the spatial-adapter VQA model from a saved checkpoint.

    Rebuilds the vocabulary and tokenizer, constructs the base VQAModel,
    wraps it with the spatial cross-attention adapter, loads the trained
    weights, and switches to eval mode.

    Args:
        checkpoint_path: Path to the spatial model checkpoint file.

    Returns:
        Tuple ``(model, vocab, tokenizer)`` ready for inference.
    """
    ckpt = torch.load(checkpoint_path, map_location=self.device)

    # Reconstruct the training-time vocabulary from the checkpoint.
    vocab = Vocab()
    vocab.vocab = ckpt['vocab']
    vocab.vocab_size = len(ckpt['vocab'])
    vocab.word2idx = ckpt['word2idx']
    vocab.idx2word = ckpt['idx2word']
    vocab.pad_token_id = ckpt['pad_token_id']
    vocab.bos_token_id = ckpt['bos_token_id']
    vocab.eos_token_id = ckpt['eos_token_id']
    vocab.unk_token_id = ckpt['unk_token_id']

    # DistilGPT-2 ships without a pad token; add one so padding works.
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    backbone = VQAModel(
        vocab_size=len(ckpt['vocab']),
        device=self.device,
        question_max_len=ckpt.get('question_max_len', 20),
        answer_max_len=ckpt.get('answer_max_len', 12),
        pad_token_id=ckpt['pad_token_id'],
        bos_token_id=ckpt['bos_token_id'],
        eos_token_id=ckpt['eos_token_id'],
        unk_token_id=ckpt['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(self.device)
    backbone.gpt2_model.resize_token_embeddings(len(tokenizer))

    # Wrap the backbone with the spatial adapter (8-head attention).
    model = VQAModelWithSpatialAdapter(
        base_model=backbone,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ).to(self.device)
    # strict=False tolerates minor key drift between training and serving.
    model.load_state_dict(ckpt['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
142
def is_spatial_question(self, question):
    """Heuristically decide whether a question needs spatial reasoning.

    A question is considered spatial when any entry from
    ``self.SPATIAL_KEYWORDS`` occurs as a substring of the lowercased text.

    Args:
        question: Question string.

    Returns:
        bool: True when a spatial keyword is present, False otherwise.
    """
    lowered = question.lower()
    for keyword in self.SPATIAL_KEYWORDS:
        if keyword in lowered:
            return True
    return False
152
def answer(self, image_path, question, use_beam_search=True, beam_width=5, verbose=False):
    """Answer a question about an image, routing to the best-suited model.

    The neural VQA answer is ALWAYS the primary answer. When the knowledge
    service decides the question is knowledge-seeking, a neuro-symbolic
    supplement (CLIP object detection + Wikidata facts) is attached in
    'kg_enhancement' — it never replaces the neural answer.

    Args:
        image_path: Path to image file.
        question: Question string.
        use_beam_search: Use beam-search decoding when the model supports it.
        beam_width: Beam width for beam search.
        verbose: Print routing / reasoning progress.

    Returns:
        dict: {
            'answer': str,
            'model_used': 'spatial' or 'base',
            'confidence': float,
            'kg_enhancement': str or None,
            'reasoning_type': 'neural' or 'neuro-symbolic',
            'objects_detected': list,
            'question_intent': str or None,
            'wikidata_entity': str or None,
            'knowledge_source': str or None,
        }
    """
    # Route: spatial-keyword questions go to the spatial-adapter model.
    is_spatial = self.is_spatial_question(question)
    model_used = 'spatial' if is_spatial else 'base'
    if verbose:
        print(f"🔍 Question type: {'Spatial' if is_spatial else 'General'}")
        print(f"🤖 Using: {model_used} model")
    model = self.spatial_model if is_spatial else self.base_model

    # FIX: use a context manager so PIL closes the underlying file handle
    # promptly; Image.open() otherwise keeps it open lazily until GC.
    with Image.open(image_path) as pil_image:
        image = model.clip_preprocess(pil_image.convert('RGB')).unsqueeze(0).to(self.device)

    # Tokenize the question to the model's fixed question length.
    question_tokens = self.tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {
        'input_ids': question_tokens['input_ids'].to(self.device),
        'attention_mask': question_tokens['attention_mask'].to(self.device)
    }

    # Generate answer tokens (beam search when available and requested).
    with torch.no_grad():
        if use_beam_search and hasattr(model, 'generate_with_beam_search'):
            generated = model.generate_with_beam_search(
                image, questions, beam_width=beam_width
            )
        else:
            generated = model(image, questions)

    # Always get the neural answer first — it is ALWAYS the primary answer.
    if verbose:
        print("📝 Using neural VQA...")
    neural_answer = self.vocab.decoder(generated[0].cpu().numpy())

    # Neuro-symbolic is a *supplement* only — its result goes into
    # kg_enhancement, never replacing the neural answer.
    kg_enhancement = None
    reasoning_type = 'neural'
    objects_detected = []
    question_intent = None
    wikidata_entity = None
    knowledge_source = None

    if self.kg_enabled and self.kg_service:
        if verbose:
            print("🔍 Analyzing question semantics...")
        should_use_ns = self.kg_service.should_use_neurosymbolic(
            image_features=None,
            question=question,
            vqa_confidence=0.0,
            image_path=image_path
        )
        if should_use_ns:
            if verbose:
                print("🧠 Neuro-Symbolic supplement: detecting subject via CLIP...")

            # CLIP zero-shot: compare image against 80+ concrete noun labels.
            # This is much more accurate than asking the VQA model.
            detected_objects = self.kg_service.detect_objects_with_clip(
                image_path=image_path, top_k=3
            )

            if verbose:
                print(f" → CLIP detected: {detected_objects}")
                print(" → Fetching Wikidata facts + Groq verbalization...")

            if detected_objects:
                ns_result = self.kg_service.answer_with_clip_features(
                    image_features=None,
                    question=question,
                    image_path=image_path,
                    detected_objects=tuple(detected_objects)
                )

                if ns_result:
                    kg_enhancement = ns_result['kg_enhancement']
                    reasoning_type = 'neuro-symbolic'
                    objects_detected = detected_objects  # expose to return dict
                    question_intent = ns_result.get('question_intent')
                    wikidata_entity = ns_result.get('wikidata_entity')
                    knowledge_source = ns_result.get('knowledge_source')
                    if verbose:
                        print(f"✨ Neuro-Symbolic supplement: {kg_enhancement}")
                        print(f" → Wikidata entity: {wikidata_entity}")
            else:
                if verbose:
                    print(" → CLIP could not identify subject, skipping Wikidata lookup")

    # NOTE(review): confidence is a hard-coded placeholder — the decoder
    # does not expose a real probability yet; confirm before relying on it.
    return {
        'answer': neural_answer,
        'model_used': model_used,
        'confidence': 1.0,
        'kg_enhancement': kg_enhancement,
        'reasoning_type': reasoning_type,
        'objects_detected': objects_detected,
        'question_intent': question_intent,
        'wikidata_entity': wikidata_entity,
        'knowledge_source': knowledge_source,
    }
267
def answer_conversational(
    self,
    image_path: str,
    question: str,
    session_id: Optional[str] = None,
    use_beam_search: bool = True,
    beam_width: int = 5,
    verbose: bool = False
) -> dict:
    """Answer a question with multi-turn conversation support.

    Resolves pronouns ("it", "this", ...) against the session history,
    delegates to :meth:`answer`, then records the turn and attaches the
    session context to the result.

    Args:
        image_path: Path to image file.
        question: Question string (may contain pronouns).
        session_id: Optional session ID for continuing a conversation.
        use_beam_search: Whether to use beam search.
        beam_width: Beam width for beam search.
        verbose: Print routing information.

    Returns:
        dict: the :meth:`answer` result plus 'session_id',
        'resolved_question' and 'conversation_context'.
    """
    # Without a conversation manager, degrade gracefully to one-shot Q&A.
    if not (self.conversation_enabled and self.conversation_manager):
        single = self.answer(image_path, question, use_beam_search, beam_width, verbose)
        single['session_id'] = None
        single['resolved_question'] = question
        single['conversation_context'] = {'has_context': False}
        return single

    session = self.conversation_manager.get_or_create_session(session_id, image_path)
    sid = session.session_id
    if verbose:
        print(f"💬 Session: {sid}")
        print(f" Turn number: {len(session.history) + 1}")

    # Replace pronouns with referents from earlier turns, when possible.
    resolved = self.conversation_manager.resolve_references(question, session)
    if verbose and resolved != question:
        print(f"🔄 Pronoun resolution:")
        print(f" Original: {question}")
        print(f" Resolved: {resolved}")

    result = self.answer(
        image_path=image_path,
        question=resolved,
        use_beam_search=use_beam_search,
        beam_width=beam_width,
        verbose=verbose
    )

    # Record the turn with the ORIGINAL question so history reads naturally.
    # NOTE(review): answer() never returns 'reasoning_chain', so this is
    # always None — confirm whether that key was meant to be populated.
    self.conversation_manager.add_turn(
        session_id=sid,
        question=question,
        answer=result['answer'],
        objects_detected=result.get('objects_detected', []),
        reasoning_chain=result.get('reasoning_chain'),
        model_used=result.get('model_used')
    )
    context = self.conversation_manager.get_context_for_question(sid, question)

    result['session_id'] = sid
    result['resolved_question'] = resolved
    result['conversation_context'] = context
    return result
334
def _detect_multiple_objects(self, image, vqa_model, top_k=3):
    """Probe the VQA model with neutral prompts to name the image subject.

    The same question is phrased several ways so the model gets multiple
    chances to identify the actual subject — never biasing toward food or
    any other category. Stop-word and duplicate answers are discarded.

    Args:
        image: Preprocessed image tensor (already on device).
        vqa_model: The VQA model used to answer the probe questions.
        top_k: Maximum number of unique answers to collect.

    Returns:
        list[str]: Up to ``top_k`` unique candidate subjects (may be empty).
    """
    # Deliberately neutral phrasings — no food bias, no category bias.
    probes = (
        "What is the main subject of this image?",
        "What is in this image?",
        "What is shown in this picture?",
    )
    # Tokens we treat as non-answers.
    non_answers = {'a', 'an', 'the', 'this', 'that', 'it', 'yes', 'no',
                   'some', 'there', 'here', 'image', 'picture', 'photo'}

    found = []
    for probe in probes:
        try:
            toks = self.tokenizer(
                probe,
                padding='max_length',
                truncation=True,
                max_length=vqa_model.question_max_len,
                return_tensors='pt'
            )
            batch = {
                'input_ids': toks['input_ids'].to(self.device),
                'attention_mask': toks['attention_mask'].to(self.device)
            }
            with torch.no_grad():
                out = vqa_model(image, batch)
            candidate = self.vocab.decoder(out[0].cpu().numpy()).strip()
            if (candidate
                    and candidate.lower() not in non_answers
                    and candidate not in found):
                found.append(candidate)
                if len(found) >= top_k:
                    break
        except Exception as e:
            # Best-effort: a failed probe should not abort the others.
            print(f" ⚠️ Error detecting objects: {e}")
            continue
    return found
377
def batch_answer(self, image_question_pairs, use_beam_search=True, verbose=False):
    """Answer multiple (image, question) pairs sequentially.

    Args:
        image_question_pairs: List of ``(image_path, question)`` tuples.
        use_beam_search: Whether to use beam search for each answer.
        verbose: Print per-item progress.

    Returns:
        list[dict]: One result dict (from :meth:`answer`) per input pair.
    """
    total = len(image_question_pairs)
    collected = []
    for idx, (img_path, q) in enumerate(image_question_pairs, start=1):
        if verbose:
            print(f"\n[{idx}/{total}] Processing...")
        collected.append(self.answer(img_path, q, use_beam_search, verbose=verbose))
    return collected
395
def demo():
    """Smoke-test the ensemble: run a fixed set of spatial and general
    questions against one image and report whether each was routed to the
    expected model."""
    BASE_CHECKPOINT = "./output2/continued_training/vqa_checkpoint.pt"
    SPATIAL_CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
    IMAGE = "./im2.jpg"

    ensemble = ProductionEnsembleVQA(BASE_CHECKPOINT, SPATIAL_CHECKPOINT)

    # (question, should-route-to-spatial-model)
    test_cases = [
        ("what is to the right of the soup?", True),
        ("what is on the left side?", True),
        ("what is above the table?", True),
        ("what is next to the bowl?", True),
        ("what color is the bowl?", False),
        ("how many items are there?", False),
        ("what room is this?", False),
        ("is there a spoon?", False),
    ]

    banner = "="*80
    print("\n" + banner)
    print("🧪 TESTING ENSEMBLE VQA SYSTEM")
    print(banner)
    print(f"\n📷 Image: {IMAGE}\n")

    for question, expect_spatial in test_cases:
        outcome = ensemble.answer(IMAGE, question, verbose=False)
        routed_spatial = outcome['model_used'] == 'spatial'
        mark = "✓" if routed_spatial == expect_spatial else "✗"
        print(f"Q: {question}")
        print(f"A: {outcome['answer']}")
        print(f"Model: {outcome['model_used']} {mark}")
        print()

    print(banner)
    print("✅ Demo complete!")
425
def interactive_mode():
    """Run a console REPL: prompt for an image path and a question, answer
    with full verbosity, and repeat until the user types 'quit'."""
    BASE_CHECKPOINT = "./output2/continued_training/vqa_checkpoint.pt"
    SPATIAL_CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"

    ensemble = ProductionEnsembleVQA(BASE_CHECKPOINT, SPATIAL_CHECKPOINT)

    banner = "="*80
    print("\n" + banner)
    print("🎮 INTERACTIVE MODE")
    print(banner)
    print("\nCommands:")
    print(" - Enter image path and question")
    print(" - Type 'quit' to exit")
    print(banner + "\n")

    while True:
        try:
            img = input("📷 Image path: ").strip()
            if img.lower() == 'quit':
                break
            q = input("❓ Question: ").strip()
            if q.lower() == 'quit':
                break
            outcome = ensemble.answer(img, q, verbose=True)
            print(f"\n💬 Answer: {outcome['answer']}\n")
            print("-"*80 + "\n")
        except KeyboardInterrupt:
            print("\n\n👋 Goodbye!")
            break
        except Exception as e:
            # Keep the REPL alive on bad paths / model errors.
            print(f"\n❌ Error: {e}\n")
453
if __name__ == "__main__":
    import sys

    # CLI entry point: `python <script> interactive` starts the REPL;
    # any other invocation runs the scripted demo.
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    (interactive_mode if mode == "interactive" else demo)()
enterprise_architecture.drawio ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <mxGraphModel dx="1800" dy="1100" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="1920" pageHeight="1080" math="0" shadow="1">
3
+ <root>
4
+ <mxCell id="0" />
5
+ <mxCell id="1" parent="0" />
6
+
7
+ <mxCell id="bg" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=none;" vertex="1" parent="1">
8
+ <mxGeometry x="-20" y="-20" width="1960" height="1120" as="geometry" />
9
+ </mxCell>
10
+
11
+ <mxCell id="title_bg" value="" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#30363D;" vertex="1" parent="1">
12
+ <mxGeometry x="20" y="20" width="1880" height="70" as="geometry" />
13
+ </mxCell>
14
+
15
+ <mxCell id="title" value="&lt;font style=&quot;font-size:24px;font-weight:bold;&quot; color=&quot;#58A6FF&quot;&gt;Semantic Neuro-Symbolic VQA -- Enterprise Architecture&lt;/font&gt;&lt;br&gt;&lt;font style=&quot;font-size:11px;&quot; color=&quot;#8B949E&quot;&gt;React Native Mobile UI | FastAPI (Uvicorn) | PyTorch | OpenAI CLIP | Wikidata SPARQL | Groq LLM (Llama-3.3-70B-Versatile)&lt;/font&gt;" style="text;html=1;strokeColor=none;fillColor=none;align=center;verticalAlign=middle;whiteSpace=wrap;" vertex="1" parent="1">
16
+ <mxGeometry x="20" y="20" width="1880" height="70" as="geometry" />
17
+ </mxCell>
18
+
19
+ <!-- ===================== CLIENT LAYER ===================== -->
20
+ <mxCell id="client_layer" value="&lt;font style=&quot;font-size:14px;font-weight:bold;&quot; color=&quot;#79C0FF&quot;&gt;[1] CLIENT LAYER&lt;/font&gt;" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
21
+ <mxGeometry x="20" y="110" width="350" height="870" as="geometry" />
22
+ </mxCell>
23
+
24
+ <mxCell id="mobile_label" value="[React Native / Expo]" style="text;html=1;fontSize=20;align=center;fillColor=none;strokeColor=none;fontColor=#58A6FF;" vertex="1" parent="client_layer">
25
+ <mxGeometry x="80" y="38" width="190" height="35" as="geometry" />
26
+ </mxCell>
27
+
28
+ <mxCell id="mobile_app" value="&lt;b&gt;React Native Mobile App&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Expo Framework | iOS and Android&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="client_layer">
29
+ <mxGeometry x="30" y="85" width="290" height="60" as="geometry" />
30
+ </mxCell>
31
+
32
+ <mxCell id="screen_login" value="&lt;b&gt;LoginScreen.js&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Auth | Session Management&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
33
+ <mxGeometry x="30" y="165" width="290" height="50" as="geometry" />
34
+ </mxCell>
35
+
36
+ <mxCell id="screen_camera" value="&lt;b&gt;CameraScreen.js&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Image Capture | Upload&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
37
+ <mxGeometry x="30" y="225" width="290" height="50" as="geometry" />
38
+ </mxCell>
39
+
40
+ <mxCell id="screen_home" value="&lt;b&gt;HomeScreen.js&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Main Dashboard | History&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
41
+ <mxGeometry x="30" y="285" width="290" height="50" as="geometry" />
42
+ </mxCell>
43
+
44
+ <mxCell id="screen_qa" value="&lt;b&gt;QuestionScreen.js&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Q and A Interface | Conversation&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
45
+ <mxGeometry x="30" y="345" width="290" height="50" as="geometry" />
46
+ </mxCell>
47
+
48
+ <mxCell id="screen_result" value="&lt;b&gt;ResultScreen.js&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Answer Display | KG Enhancement&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
49
+ <mxGeometry x="30" y="405" width="290" height="50" as="geometry" />
50
+ </mxCell>
51
+
52
+ <mxCell id="api_js" value="&lt;b&gt;api.js (API Service)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Axios | FormData | Session Tokens&lt;br&gt;REST calls to FastAPI backend&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2820;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="client_layer">
53
+ <mxGeometry x="30" y="478" width="290" height="70" as="geometry" />
54
+ </mxCell>
55
+
56
+ <mxCell id="ep1" value="POST /api/answer" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
57
+ <mxGeometry x="30" y="565" width="135" height="30" as="geometry" />
58
+ </mxCell>
59
+ <mxCell id="ep2" value="POST /api/conversation/answer" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
60
+ <mxGeometry x="177" y="565" width="143" height="30" as="geometry" />
61
+ </mxCell>
62
+ <mxCell id="ep3" value="GET /api/models/info" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
63
+ <mxGeometry x="30" y="605" width="135" height="30" as="geometry" />
64
+ </mxCell>
65
+ <mxCell id="ep4" value="GET/DELETE /api/conversation/{id}" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#3FB950;fontColor=#3FB950;fontSize=10;" vertex="1" parent="client_layer">
66
+ <mxGeometry x="177" y="605" width="143" height="30" as="geometry" />
67
+ </mxCell>
68
+
69
+ <mxCell id="client_tech" value="&lt;b&gt;Tech:&lt;/b&gt; Expo | React Navigation | Axios | FormData&lt;br&gt;&lt;b&gt;Auth:&lt;/b&gt; Session tokens | Context API" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#21262D;fontColor=#8B949E;fontSize=10;" vertex="1" parent="client_layer">
70
+ <mxGeometry x="30" y="660" width="290" height="55" as="geometry" />
71
+ </mxCell>
72
+
73
+ <!-- ===================== API GATEWAY LAYER ===================== -->
74
+ <mxCell id="api_layer" value="&lt;font style=&quot;font-size:14px;font-weight:bold;&quot; color=&quot;#56D364&quot;&gt;[2] API GATEWAY LAYER&lt;/font&gt;" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#3FB950;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
75
+ <mxGeometry x="400" y="110" width="360" height="870" as="geometry" />
76
+ </mxCell>
77
+
78
+ <mxCell id="apigw_label" value="[FastAPI + Uvicorn]" style="text;html=1;fontSize=20;align=center;fillColor=none;strokeColor=none;fontColor=#3FB950;" vertex="1" parent="api_layer">
79
+ <mxGeometry x="85" y="38" width="190" height="35" as="geometry" />
80
+ </mxCell>
81
+
82
+ <mxCell id="fastapi_main" value="&lt;b&gt;FastAPI Backend (Uvicorn)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;backend_api.py&lt;br&gt;Host: 0.0.0.0 | Port: 8000&lt;br&gt;CORS enabled | Auto-reload dev mode&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#162415;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
83
+ <mxGeometry x="20" y="88" width="320" height="80" as="geometry" />
84
+ </mxCell>
85
+
86
+ <mxCell id="startup" value="&lt;b&gt;Startup Event&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Load checkpoints | Init models&lt;br&gt;Init Groq service | Health check&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
87
+ <mxGeometry x="20" y="188" width="320" height="60" as="geometry" />
88
+ </mxCell>
89
+
90
+ <mxCell id="ep_health" value="GET /health&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Model status check&lt;/font&gt;" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
91
+ <mxGeometry x="20" y="268" width="145" height="50" as="geometry" />
92
+ </mxCell>
93
+ <mxCell id="ep_root" value="GET /&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;API info and docs&lt;/font&gt;" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
94
+ <mxGeometry x="175" y="268" width="145" height="50" as="geometry" />
95
+ </mxCell>
96
+ <mxCell id="ep_answer" value="POST /api/answer&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;image + question -&gt; JSON answer&lt;/font&gt;" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#132D0E;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
97
+ <mxGeometry x="20" y="328" width="300" height="50" as="geometry" />
98
+ </mxCell>
99
+ <mxCell id="ep_conv" value="POST /api/conversation/answer&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Multi-turn | session_id | pronouns&lt;/font&gt;" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#132D0E;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
100
+ <mxGeometry x="20" y="388" width="300" height="50" as="geometry" />
101
+ </mxCell>
102
+ <mxCell id="ep_hist" value="GET /api/conversation/{id}/history" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
103
+ <mxGeometry x="20" y="448" width="300" height="38" as="geometry" />
104
+ </mxCell>
105
+ <mxCell id="ep_del" value="DELETE /api/conversation/{id}" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
106
+ <mxGeometry x="20" y="496" width="300" height="38" as="geometry" />
107
+ </mxCell>
108
+ <mxCell id="ep_models" value="GET /api/models/info" style="rounded=6;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
109
+ <mxGeometry x="20" y="544" width="300" height="38" as="geometry" />
110
+ </mxCell>
111
+
112
+ <mxCell id="middleware" value="&lt;b&gt;Middleware&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;CORS | Error handling | HTTP 400/503/500&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
113
+ <mxGeometry x="20" y="600" width="320" height="50" as="geometry" />
114
+ </mxCell>
115
+
116
+ <mxCell id="conv_manager" value="&lt;b&gt;ConversationManager&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;conversation_manager.py&lt;br&gt;Session 30min timeout | Pronoun resolution&lt;br&gt;History storage | Context retrieval&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1A2E;strokeColor=#7B2FBE;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="api_layer">
117
+ <mxGeometry x="20" y="670" width="320" height="80" as="geometry" />
118
+ </mxCell>
119
+
120
+ <!-- ===================== ML INFERENCE ENGINE ===================== -->
121
+ <mxCell id="ml_layer" value="&lt;font style=&quot;font-size:14px;font-weight:bold;&quot; color=&quot;#FFA657&quot;&gt;[3] ML INFERENCE ENGINE&lt;/font&gt;" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#D29922;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
122
+ <mxGeometry x="800" y="110" width="380" height="870" as="geometry" />
123
+ </mxCell>
124
+
125
+ <mxCell id="ml_label" value="[PyTorch + CLIP + DistilGPT-2]" style="text;html=1;fontSize=16;align=center;fillColor=none;strokeColor=none;fontColor=#D29922;" vertex="1" parent="ml_layer">
126
+ <mxGeometry x="40" y="38" width="300" height="35" as="geometry" />
127
+ </mxCell>
128
+
129
+ <mxCell id="ensemble_vqa" value="&lt;b&gt;ProductionEnsembleVQA&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;ensemble_vqa_app.py&lt;br&gt;Device: CUDA / CPU auto-detect&lt;br&gt;Beam Search width=5 | Top-K Decoding&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#2D2000;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
130
+ <mxGeometry x="20" y="88" width="340" height="80" as="geometry" />
131
+ </mxCell>
132
+
133
+ <mxCell id="router" value="&lt;b&gt;Question Router (Keyword Classifier)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;is_spatial_question()&lt;br&gt;Spatial keywords: left, right, above, below, next to...&lt;br&gt;Routes to Base or Spatial model&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1E1E00;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
134
+ <mxGeometry x="20" y="188" width="340" height="75" as="geometry" />
135
+ </mxCell>
136
+
137
+ <mxCell id="base_model_box" value="&lt;b&gt;Base VQA Model&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;model.py | VQAModel&lt;br&gt;CLIP ViT-B/32 + GPT-2&lt;br&gt;vqa_checkpoint.pt (731 MB)&lt;br&gt;hidden=512 | layers=2 | acc~50%&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#162415;strokeColor=#3FB950;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
138
+ <mxGeometry x="20" y="285" width="158" height="120" as="geometry" />
139
+ </mxCell>
140
+
141
+ <mxCell id="spatial_model_box" value="&lt;b&gt;Spatial VQA Model&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;model_spatial.py&lt;br&gt;SpatialAdapter + 8-head attn&lt;br&gt;vqa_spatial_checkpoint.pt (739 MB)&lt;br&gt;dropout=0.3 | acc~40%&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0D2137;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
142
+ <mxGeometry x="192" y="285" width="168" height="120" as="geometry" />
143
+ </mxCell>
144
+
145
+ <mxCell id="gpt2" value="&lt;b&gt;DistilGPT-2 Tokenizer&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Text tokenization | Vocab&lt;br&gt;BOS / EOS / PAD tokens | Beam search decoding&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
146
+ <mxGeometry x="20" y="425" width="340" height="65" as="geometry" />
147
+ </mxCell>
148
+
149
+ <mxCell id="clip_box" value="&lt;b&gt;OpenAI CLIP (ViT-B/32)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;Image encoder + Text encoder&lt;br&gt;Zero-shot object detection (80+ nouns)&lt;br&gt;Question routing: visual vs knowledge&lt;br&gt;Anchor similarity | Softmax x10&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1A0D;strokeColor=#E3B341;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
150
+ <mxGeometry x="20" y="508" width="340" height="90" as="geometry" />
151
+ </mxCell>
152
+
153
+ <mxCell id="img_proc" value="&lt;b&gt;Image Preprocessor (PIL)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;JPEG/PNG -&gt; RGB | CLIP preprocess | Tensor&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#1C2128;strokeColor=#30363D;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
154
+ <mxGeometry x="20" y="615" width="340" height="55" as="geometry" />
155
+ </mxCell>
156
+
157
+ <mxCell id="pt_files" value="&lt;b&gt;PyTorch Checkpoints (Local Disk)&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;vqa_checkpoint.pt (731 MB)&lt;br&gt;vqa_spatial_checkpoint.pt (739 MB)&lt;br&gt;state_dict | vocab | tokenizer config&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#251A00;strokeColor=#D29922;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ml_layer">
158
+ <mxGeometry x="20" y="688" width="340" height="80" as="geometry" />
159
+ </mxCell>
160
+
161
+ <mxCell id="gpu_badge" value="GPU: CUDA | ~4 GB VRAM | 2x Model Parallel loading" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#D29922;fontColor=#E3B341;fontSize=10;" vertex="1" parent="ml_layer">
162
+ <mxGeometry x="20" y="785" width="340" height="28" as="geometry" />
163
+ </mxCell>
164
+
165
+ <!-- ===================== NEURO-SYMBOLIC PIPELINE ===================== -->
166
+ <mxCell id="ns_layer" value="&lt;font style=&quot;font-size:14px;font-weight:bold;&quot; color=&quot;#BC8CFF&quot;&gt;[4] NEURO-SYMBOLIC PIPELINE&lt;/font&gt;" style="swimlane;startSize=30;fillColor=#161B22;strokeColor=#8957E5;fontColor=#FFFFFF;fontStyle=1;fontSize=13;rounded=10;" vertex="1" parent="1">
167
+ <mxGeometry x="1220" y="110" width="370" height="870" as="geometry" />
168
+ </mxCell>
169
+
170
+ <mxCell id="ns_label" value="[CLIP + Wikidata SPARQL + Groq LLM]" style="text;html=1;fontSize=14;align=center;fillColor=none;strokeColor=none;fontColor=#8957E5;" vertex="1" parent="ns_layer">
171
+ <mxGeometry x="15" y="38" width="340" height="35" as="geometry" />
172
+ </mxCell>
173
+
174
+ <mxCell id="ns_main" value="&lt;b&gt;SemanticNeurosymbolicVQA&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;semantic_neurosymbolic_vqa.py&lt;br&gt;Neural -&gt; Symbolic -&gt; Verbalize pipeline&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A0D2E;strokeColor=#8957E5;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
175
+ <mxGeometry x="20" y="88" width="330" height="65" as="geometry" />
176
+ </mxCell>
177
+
178
+ <mxCell id="ns_step1" value="&lt;b&gt;Step 1: CLIP Routing&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;should_use_neurosymbolic()&lt;br&gt;VISUAL anchor vs KNOWLEDGE anchor&lt;br&gt;Temperature softmax x10&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D1A30;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
179
+ <mxGeometry x="20" y="173" width="330" height="78" as="geometry" />
180
+ </mxCell>
181
+
182
+ <mxCell id="route_decision" value="VISUAL question?&lt;br&gt;-&gt; Neural VQA only&lt;br&gt;KNOWLEDGE question?&lt;br&gt;-&gt; Neuro-Symbolic" style="rhombus;whiteSpace=wrap;html=1;fillColor=#21262D;strokeColor=#8957E5;fontColor=#FFFFFF;fontSize=10;" vertex="1" parent="ns_layer">
183
+ <mxGeometry x="75" y="268" width="220" height="88" as="geometry" />
184
+ </mxCell>
185
+
186
+ <mxCell id="ns_step2" value="&lt;b&gt;Step 2: CLIP Object Detection&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;detect_objects_with_clip()&lt;br&gt;80+ noun vocabulary | Top-3 objects&lt;br&gt;Cosine similarity | prompt: &apos;a photo of a {label}&apos;&lt;/font&gt;" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#0D1A30;strokeColor=#1F6FEB;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
187
+ <mxGeometry x="20" y="375" width="330" height="80" as="geometry" />
188
+ </mxCell>
189
+
190
+ <mxCell id="wikidata_box" value="&lt;b&gt;Step 3: WikidataKnowledgeBase&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;SPARQL: query.wikidata.org&lt;br&gt;P31 (category) | P186 (material) | P366 (uses)&lt;br&gt;P2101 (melting pt) | P2054 (density)&lt;br&gt;lru_cache(500) | timeout=10s&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0D2E2E;strokeColor=#2EA8A8;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
191
+ <mxGeometry x="20" y="473" width="330" height="100" as="geometry" />
192
+ </mxCell>
193
+
194
+ <mxCell id="groq_box" value="&lt;b&gt;Step 4: Groq LLM Verbalizer&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;WikidataGroqAnswerer&lt;br&gt;Model: llama-3.3-70b-versatile&lt;br&gt;Temp=0.1 | max_tokens=180 | top_p=0.9&lt;br&gt;Answers ONLY from Wikidata facts&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2B1A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
195
+ <mxGeometry x="20" y="592" width="330" height="95" as="geometry" />
196
+ </mxCell>
197
+
198
+ <mxCell id="groq_access" value="&lt;b&gt;Groq Accessibility Service&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;groq_service.py | GroqDescriptionService&lt;br&gt;2-sentence narrations for blind users&lt;br&gt;Temp=0.7 | max_tokens=150&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A2B1A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=11;" vertex="1" parent="ns_layer">
199
+ <mxGeometry x="20" y="706" width="330" height="85" as="geometry" />
200
+ </mxCell>
201
+
202
+ <mxCell id="groq_badge" value="Groq API | Llama-3.3-70B-Versatile | GROQ_API_KEY env var" style="rounded=5;whiteSpace=wrap;html=1;fillColor=#0D1117;strokeColor=#F85149;fontColor=#F85149;fontSize=10;" vertex="1" parent="ns_layer">
203
+ <mxGeometry x="20" y="808" width="330" height="28" as="geometry" />
204
+ </mxCell>
205
+
206
+ <!-- ===================== EXTERNAL SERVICES ===================== -->
207
+ <mxCell id="wikidata_ext" value="&lt;b&gt;Wikidata SPARQL API&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;query.wikidata.org/sparql&lt;br&gt;wikidata.org/w/api.php&lt;br&gt;Entity lookup | Property values&lt;br&gt;Free and Open Knowledge Base&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#0A2525;strokeColor=#2EA8A8;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
208
+ <mxGeometry x="1640" y="200" width="250" height="130" as="geometry" />
209
+ </mxCell>
210
+
211
+ <mxCell id="groq_cloud" value="&lt;b&gt;Groq Cloud API&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;api.groq.com&lt;br&gt;Llama-3.3-70B-Versatile&lt;br&gt;Ultra-low latency inference&lt;br&gt;chat.completions endpoint&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A0A0A;strokeColor=#F85149;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
212
+ <mxGeometry x="1640" y="385" width="250" height="130" as="geometry" />
213
+ </mxCell>
214
+
215
+ <mxCell id="hf_clip" value="&lt;b&gt;OpenAI / HuggingFace Hub&lt;/b&gt;&lt;br&gt;&lt;font color=&quot;#8B949E&quot;&gt;CLIP ViT-B/32 weights&lt;br&gt;GPT-2 / DistilGPT-2 tokenizer&lt;br&gt;Cached locally after first download&lt;/font&gt;" style="rounded=10;whiteSpace=wrap;html=1;fillColor=#1A1000;strokeColor=#E3B341;fontColor=#FFFFFF;fontSize=12;" vertex="1" parent="1">
216
+ <mxGeometry x="1640" y="565" width="250" height="105" as="geometry" />
217
+ </mxCell>
218
+
219
+ <!-- ===================== LEGEND ===================== -->
220
+ <mxCell id="legend" value="&lt;b&gt;LEGEND&lt;/b&gt;&lt;br&gt;[1] Blue = Client Layer (React Native)&lt;br&gt;[2] Green = API Gateway (FastAPI)&lt;br&gt;[3] Orange = ML Inference (PyTorch)&lt;br&gt;[4] Purple = Neuro-Symbolic Pipeline&lt;br&gt;Solid arrow = Primary data flow&lt;br&gt;Dashed arrow = Conditional / supplement&lt;br&gt;Animated = Live request flow" style="rounded=8;whiteSpace=wrap;html=1;fillColor=#161B22;strokeColor=#30363D;fontColor=#8B949E;fontSize=11;align=left;" vertex="1" parent="1">
221
+ <mxGeometry x="1640" y="710" width="250" height="155" as="geometry" />
222
+ </mxCell>
223
+
224
+ <!-- ===================== EDGES / ANIMATED FLOWS ===================== -->
225
+
226
+ <!-- 1. api.js -> FastAPI (HTTP REST) -->
227
+ <mxCell id="flow_1" value="&lt;font color=&quot;#3FB950&quot;&gt;HTTP REST (JSON/FormData)&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#3FB950;strokeWidth=3;fontSize=10;fontColor=#3FB950;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="api_js" target="fastapi_main">
228
+ <mxGeometry relative="1" as="geometry" />
229
+ </mxCell>
230
+
231
+ <!-- 2. FastAPI -> Ensemble VQA -->
232
+ <mxCell id="flow_2" value="&lt;font color=&quot;#FFA657&quot;&gt;answer()&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#D29922;strokeWidth=3;fontSize=10;fontColor=#FFA657;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="fastapi_main" target="ensemble_vqa">
233
+ <mxGeometry relative="1" as="geometry" />
234
+ </mxCell>
235
+
236
+ <!-- 3. Ensemble -> Router -->
237
+ <mxCell id="flow_3" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#D29922;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ensemble_vqa" target="router">
238
+ <mxGeometry relative="1" as="geometry" />
239
+ </mxCell>
240
+
241
+ <!-- 4a. Router -> Base Model -->
242
+ <mxCell id="flow_4a" value="&lt;font color=&quot;#3FB950&quot;&gt;General Q&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#3FB950;strokeWidth=2;animation=1;endArrow=block;endFill=1;fontSize=10;fontColor=#3FB950;" edge="1" parent="1" source="router" target="base_model_box">
243
+ <mxGeometry relative="1" as="geometry" />
244
+ </mxCell>
245
+
246
+ <!-- 4b. Router -> Spatial Model -->
247
+ <mxCell id="flow_4b" value="&lt;font color=&quot;#58A6FF&quot;&gt;Spatial Q&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#1F6FEB;strokeWidth=2;animation=1;endArrow=block;endFill=1;fontSize=10;fontColor=#58A6FF;" edge="1" parent="1" source="router" target="spatial_model_box">
248
+ <mxGeometry relative="1" as="geometry" />
249
+ </mxCell>
250
+
251
+ <!-- 5. Ensemble -> NS Pipeline (supplement) -->
252
+ <mxCell id="flow_5" value="&lt;font color=&quot;#BC8CFF&quot;&gt;NS supplement&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;orthogonalLoop=1;jettySize=auto;strokeColor=#8957E5;strokeWidth=3;fontSize=10;fontColor=#BC8CFF;animation=1;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ensemble_vqa" target="ns_main">
253
+ <mxGeometry relative="1" as="geometry" />
254
+ </mxCell>
255
+
256
+ <!-- 6. NS main -> CLIP Routing -->
257
+ <mxCell id="flow_6" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_main" target="ns_step1">
258
+ <mxGeometry relative="1" as="geometry" />
259
+ </mxCell>
260
+
261
+ <!-- 7. CLIP Routing -> Decision diamond -->
262
+ <mxCell id="flow_7" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_step1" target="route_decision">
263
+ <mxGeometry relative="1" as="geometry" />
264
+ </mxCell>
265
+
266
+ <!-- 8. Decision -> Object Detection -->
267
+ <mxCell id="flow_8" value="&lt;font color=&quot;#BC8CFF&quot;&gt;Knowledge Q&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#8957E5;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;fontSize=10;fontColor=#BC8CFF;" edge="1" parent="1" source="route_decision" target="ns_step2">
268
+ <mxGeometry relative="1" as="geometry" />
269
+ </mxCell>
270
+
271
+ <!-- 9. Object Detection -> Wikidata box -->
272
+ <mxCell id="flow_9" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#2EA8A8;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="ns_step2" target="wikidata_box">
273
+ <mxGeometry relative="1" as="geometry" />
274
+ </mxCell>
275
+
276
+ <!-- 10. Wikidata box -> Wikidata external API -->
277
+ <mxCell id="flow_10" value="&lt;font color=&quot;#2EA8A8&quot;&gt;SPARQL queries&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#2EA8A8;strokeWidth=3;fontSize=10;fontColor=#2EA8A8;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="wikidata_box" target="wikidata_ext">
278
+ <mxGeometry relative="1" as="geometry" />
279
+ </mxCell>
280
+
281
+ <!-- 11. Wikidata facts -> Groq verbalizer -->
282
+ <mxCell id="flow_11" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="wikidata_box" target="groq_box">
283
+ <mxGeometry relative="1" as="geometry" />
284
+ </mxCell>
285
+
286
+ <!-- 12. Groq box -> Groq Cloud -->
287
+ <mxCell id="flow_12" value="&lt;font color=&quot;#F85149&quot;&gt;API call | Llama-3.3-70B&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=3;fontSize=10;fontColor=#F85149;animation=1;endArrow=block;endFill=1;" edge="1" parent="1" source="groq_box" target="groq_cloud">
288
+ <mxGeometry relative="1" as="geometry" />
289
+ </mxCell>
290
+
291
+ <!-- 13. Groq accessibility -> Groq Cloud -->
292
+ <mxCell id="flow_13" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="groq_access" target="groq_cloud">
293
+ <mxGeometry relative="1" as="geometry" />
294
+ </mxCell>
295
+
296
+ <!-- 14. FastAPI -> Groq Accessibility (top arc) -->
297
+ <mxCell id="flow_14" value="&lt;font color=&quot;#F85149&quot;&gt;accessibility narration&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#F85149;strokeWidth=2;fontSize=10;fontColor=#F85149;animation=1;dashed=1;endArrow=block;endFill=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;" edge="1" parent="1" source="fastapi_main" target="groq_access">
298
+ <mxGeometry relative="1" as="geometry">
299
+ <Array as="points">
300
+ <mxPoint x="580" y="140" />
301
+ <mxPoint x="1385" y="140" />
302
+ </Array>
303
+ </mxGeometry>
304
+ </mxCell>
305
+
306
+ <!-- 15. CLIP box -> HuggingFace (model weights) -->
307
+ <mxCell id="flow_15" value="&lt;font color=&quot;#E3B341&quot;&gt;model weights (cached)&lt;/font&gt;" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#E3B341;strokeWidth=2;fontSize=10;fontColor=#E3B341;dashed=1;endArrow=block;endFill=1;" edge="1" parent="1" source="clip_box" target="hf_clip">
308
+ <mxGeometry relative="1" as="geometry" />
309
+ </mxCell>
310
+
311
+ <!-- 16a. Base model -> GPT2 Tokenizer -->
312
+ <mxCell id="flow_16a" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#30363D;strokeWidth=1;endArrow=block;endFill=1;" edge="1" parent="1" source="base_model_box" target="gpt2">
313
+ <mxGeometry relative="1" as="geometry" />
314
+ </mxCell>
315
+
316
+ <!-- 16b. Spatial model -> GPT2 Tokenizer -->
317
+ <mxCell id="flow_16b" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#30363D;strokeWidth=1;endArrow=block;endFill=1;" edge="1" parent="1" source="spatial_model_box" target="gpt2">
318
+ <mxGeometry relative="1" as="geometry" />
319
+ </mxCell>
320
+
321
+ <!-- 17. Conv Manager <-> Ensemble VQA -->
322
+ <mxCell id="flow_17" value="" style="edgeStyle=orthogonalEdgeStyle;rounded=1;strokeColor=#7B2FBE;strokeWidth=2;animation=1;dashed=1;endArrow=block;endFill=1;startArrow=block;startFill=1;" edge="1" parent="1" source="conv_manager" target="ensemble_vqa">
323
+ <mxGeometry relative="1" as="geometry" />
324
+ </mxCell>
325
+
326
+ <!-- ===================== PHASE ANNOTATIONS ===================== -->
327
+ <mxCell id="ann1" value="(1) User uploads image + question" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#58A6FF;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
328
+ <mxGeometry x="100" y="988" width="250" height="28" as="geometry" />
329
+ </mxCell>
330
+ <mxCell id="ann2" value="(2) REST API routes to ensemble" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#3FB950;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
331
+ <mxGeometry x="460" y="988" width="240" height="28" as="geometry" />
332
+ </mxCell>
333
+ <mxCell id="ann3" value="(3) Neural model answers question" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#FFA657;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
334
+ <mxGeometry x="860" y="988" width="250" height="28" as="geometry" />
335
+ </mxCell>
336
+ <mxCell id="ann4" value="(4) Symbolic + Groq enriches answer" style="text;html=1;strokeColor=none;fillColor=#0D1117;fontColor=#BC8CFF;fontSize=11;fontStyle=1;align=center;" vertex="1" parent="1">
337
+ <mxGeometry x="1270" y="988" width="260" height="28" as="geometry" />
338
+ </mxCell>
339
+
340
+ </root>
341
+ </mxGraphModel>
exp_results/feature_extraction_metric.csv ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epoch,train_loss,train_token_acc,val_loss,val_token_acc,val_exact_match,lr
2
+ 1,3.687392619148223,0.5010925703618669,2.6377785576964325,0.531001718679689,0.0625462073044507,0.0001
3
+ 2,3.0861334211370917,0.5492582896593264,2.1873294205035805,0.5735707690693298,0.1437971314505397,0.0001
4
+ 3,2.8613873058208554,0.5773015727241105,2.0188139058508963,0.5919563278274717,0.18172408694366404,0.0001
5
+ 4,2.737266832805117,0.5940482014385925,1.8989913845961948,0.6057449292461827,0.2079698358716546,0.0001
6
+ 5,2.64607786060719,0.6068536389304081,1.8126546847370435,0.6131761748835726,0.22467839716102322,0.0001
7
+ 6,2.5737654996439945,0.6159161500927967,1.745610311908542,0.6227055006432083,0.23806003252994234,0.0001
8
+ 7,2.514629547727101,0.6238974923921153,1.6846065549355633,0.6310539678582605,0.25521218394203754,0.0001
9
+ 8,2.467853448716654,0.630066124487741,1.6530387682734795,0.6351331795723933,0.2616442407215733,0.0001
10
+ 9,2.430272235876001,0.6363310434633568,1.6044414886888467,0.6438829395568596,0.2796096406920006,0.0001
11
+ 10,2.3940254725485,0.6410929495099732,1.5768477393771119,0.6476609546620891,0.2876681945882005,0.0001
12
+ 11,2.3626844231579023,0.6466396824626934,1.553934060740021,0.6507747072093891,0.2935087978707674,0.0001
13
+ 12,2.3347287295768417,0.6508579807194079,1.5344560882955227,0.6529503009229336,0.29957119621469763,0.0001
14
+ 13,2.309176077580466,0.6551987208042674,1.5069528773145855,0.6592958943461472,0.3086647937305929,0.0001
15
+ 14,2.2852324938224235,0.6583507632729854,1.4877223473674845,0.6627878375210852,0.31820198136921485,0.0001
16
+ 15,2.265477722738707,0.6621552250710977,1.4731922914397042,0.6635274037999926,0.3206417270442111,0.0001
17
+ 16,2.245406344189297,0.6660276569959188,1.454425812892194,0.6657813076140746,0.3254472867070827,1e-06
18
+ 17,2.2047869251156476,0.6741207528932076,1.4267255866302635,0.6736559963451242,0.3408990093153926,1e-06
19
+ 18,2.173899897451869,0.6801777819710184,1.4036545191171035,0.6780021879470574,0.34703533934644387,1e-06
20
+ 19,2.15051551812644,0.6852958937991237,1.3850691127327253,0.6806749330376679,0.3535413278131007,1e-06
21
+ 20,2.130151925532512,0.6903713528113137,1.3759601954019294,0.682907020145992,0.3590862043471832,1e-06
22
+ 21,2.111327923803482,0.6937075932303665,1.3607378039719924,0.6867363317957464,0.3650746710039923,1e-06
23
+ 22,2.092705831874552,0.6989087903379759,1.3529389587775715,0.6871686296642951,0.3676622800532308,1e-06
24
+ 23,2.0762000757163266,0.7018636832358497,1.3471845992893543,0.6889090611124938,0.3711370693479225,1e-06
25
+ 24,2.0588077032516723,0.7061800249295429,1.3332587570514318,0.6925943864966339,0.37853023806003255,1e-06
26
+ 25,2.043530640342685,0.7086816234112068,1.323614944545728,0.6927403596774587,0.3790477598698802,1e-06
27
+ 26,2.028976038177644,0.7119645012827895,1.321273627989697,0.6960837739818501,0.38511015821381045,1e-06
28
+ 27,2.0125017191516372,0.7166598519934908,1.3151825143481202,0.6966083350608934,0.38651486026911136,1e-06
29
+ 28,1.998029633995205,0.7198163744333156,1.3046240308937036,0.6980289071798325,0.38836315244713887,1e-06
30
+ 29,1.9832194559959038,0.7228894007410402,1.3061683574375116,0.6981341627971182,0.3905811030607719,1e-06
31
+ 30,1.96923152904127,0.7272438684699805,1.3041821732273642,0.6986926667532831,0.3902114446251663,1e-06
experiments/__pycache__/train.cpython-312.pyc ADDED
Binary file (21.6 kB). View file
 
experiments/test.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import GPT2Tokenizer
5
+ from model import VQAModel
6
+ from train import Vocab
7
def load_model(checkpoint_path, device='cuda'):
    """Restore a VQAModel, its answer vocabulary, and the question tokenizer
    from a training checkpoint.

    Parameters
    ----------
    checkpoint_path : str
        Path to a checkpoint written by the training script; contains the
        model state_dict, vocab tables, and special-token ids.
    device : str
        Target device for checkpoint tensors and the model.

    Returns
    -------
    (model, vocab, tokenizer) ready for inference (model is in eval mode).
    """
    ckpt = torch.load(checkpoint_path, map_location=device)

    # Rebuild the answer-side vocabulary from the saved tables.
    vocab = Vocab()
    vocab.vocab = ckpt['vocab']
    vocab.vocab_size = len(ckpt['vocab'])
    vocab.word2idx = ckpt['word2idx']
    vocab.idx2word = ckpt['idx2word']
    for key in ('pad_token_id', 'bos_token_id', 'eos_token_id', 'unk_token_id'):
        setattr(vocab, key, ckpt[key])

    model = VQAModel(
        vocab_size=len(ckpt['vocab']),
        device=device,
        question_max_len=ckpt.get('question_max_len', 20),
        answer_max_len=ckpt.get('answer_max_len', 12),
        pad_token_id=ckpt['pad_token_id'],
        bos_token_id=ckpt['bos_token_id'],
        eos_token_id=ckpt['eos_token_id'],
        unk_token_id=ckpt['unk_token_id'],
        hidden_size=512,
        num_layers=2,
    ).to(device)

    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        # Mirror the training-time setup: add [PAD] and grow the GPT-2
        # embedding matrix so the saved weights line up.
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(tokenizer))

    # strict=False tolerates key mismatches between checkpoint and model
    # (e.g. auxiliary buffers) — presumably intentional; confirm on upgrade.
    model.load_state_dict(ckpt['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
37
def answer_question(model, vocab, tokenizer, image_path, question, device='cuda', use_beam_search=True, beam_width=5, temperature=0.8):
    """Run one VQA inference: load the image, encode the question, and
    decode the model's generated answer to plain text.

    NOTE: ``temperature`` is accepted for signature compatibility but is
    not used by either decoding path below.
    """
    # Image -> CLIP-preprocessed batch of one on the target device.
    pil_img = Image.open(image_path).convert('RGB')
    pixels = model.clip_preprocess(pil_img).unsqueeze(0).to(device)

    # Question -> fixed-length GPT-2 token ids + attention mask.
    encoded = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {k: encoded[k].to(device) for k in ('input_ids', 'attention_mask')}

    with torch.no_grad():
        if use_beam_search and hasattr(model, 'generate_with_beam_search'):
            token_ids = model.generate_with_beam_search(pixels, questions, beam_width=beam_width)
        else:
            token_ids = model(pixels, questions)

    # Decode the first (only) sequence in the batch back to words.
    return vocab.decoder(token_ids[0].cpu().numpy())
58
# Inference configuration: checkpoint to load and the image to query.
CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
IMAGE_PATH = r"./im2.jpg"
# NOTE(review): QUESTION is never used below — test_questions drives the loop.
QUESTION = ""
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Loading model...")
    model, vocab, tokenizer = load_model(CHECKPOINT, device)
    print("Model loaded!\n")
    # Questions to ask about IMAGE_PATH; extend this list to probe more.
    test_questions = [
        "What is to the right of the soup?"
    ]
    print(f"Image: {IMAGE_PATH}\n")
    for question in test_questions:
        print(f"Question: {question}")
        answer = answer_question(model, vocab, tokenizer, IMAGE_PATH, question, device, use_beam_search=True, beam_width=5)
        print(f"Answer: {answer}\n")
experiments/train.py ADDED
@@ -0,0 +1,349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ from transformers import GPT2Tokenizer
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from collections import Counter
12
+ from nltk.tokenize import word_tokenize
13
+ from sklearn.model_selection import train_test_split
14
+ from torchvision import transforms
15
+ from models.model import VQAModel
16
+ device = 'cuda'
17
class Vocab:
    """Word-level answer vocabulary with <pad>/<bos>/<eos>/<unk> specials.

    Token-id attributes are populated by :meth:`build_vocab`, or assigned
    directly when restoring from a checkpoint (see the test script's
    load path).
    """

    def __init__(self):
        self.vocab = None        # list[str]; index == token id
        self.vocab_size = None   # len(self.vocab) once built
        self.word2idx = None     # str -> int
        self.idx2word = None     # int -> str
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'
        # Declared up front so instances are complete objects even before
        # build_vocab() runs (checkpoint loading sets these externally);
        # previously they only existed after build_vocab(), risking
        # AttributeError on early encoder()/decoder() calls.
        self.pad_token_id = None
        self.bos_token_id = None
        self.eos_token_id = None
        self.unk_token_id = None

    def build_vocab(self, df, min_freq=1):
        """Build the vocabulary from df['answer'], keeping words that occur
        at least *min_freq* times. Specials always occupy ids 0..3."""
        counter = Counter()
        for ans in df['answer']:
            counter.update(word_tokenize(ans.lower()))
        kept = sorted(word for word, freq in counter.items() if freq >= min_freq)
        vocab = [self.pad, self.bos, self.eos, self.unk] + kept
        self.vocab = vocab
        self.word2idx = {word: idx for idx, word in enumerate(vocab)}
        self.idx2word = {idx: word for word, idx in self.word2idx.items()}
        self.vocab_size = len(vocab)
        # Look up via the special-token attributes (not string literals) so
        # the marker strings are defined in exactly one place.
        self.pad_token_id = self.word2idx[self.pad]
        self.bos_token_id = self.word2idx[self.bos]
        self.eos_token_id = self.word2idx[self.eos]
        self.unk_token_id = self.word2idx[self.unk]

    def encoder(self, text, max_len):
        """Encode *text* to a fixed-length id list: <bos> tokens <eos>,
        right-padded with <pad> or truncated to *max_len*."""
        tokens = word_tokenize(text.lower())
        token_ids = [self.word2idx.get(token, self.unk_token_id) for token in tokens]
        token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
        if len(token_ids) < max_len:
            token_ids += [self.pad_token_id] * (max_len - len(token_ids))
        else:
            # NOTE: truncation can drop <eos>; decoder then emits the full
            # remainder (original behavior, preserved).
            token_ids = token_ids[:max_len]
        return token_ids

    def decoder(self, token_ids):
        """Decode ids to a space-joined string; stops at <eos> and skips
        <pad>/<bos> markers."""
        tokens = []
        for idx in token_ids:
            if idx == self.eos_token_id:
                break
            if idx in (self.pad_token_id, self.bos_token_id):
                continue
            tokens.append(self.idx2word.get(idx, "<unk>"))
        return ' '.join(tokens).strip()
62
class AugmentedVQADataset(Dataset):
    """VQA dataset pairing images with (question, answer) metadata rows.

    Each item yields CLIP-preprocessed pixels, fixed-length GPT-2 question
    ids/mask, and word-level answer ids. When ``augment`` is True, light
    photometric/geometric augmentation is applied to the raw PIL image
    before CLIP preprocessing.
    """

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        self.text_processor = text_processor    # Vocab-style encoder for answers
        self.clip_processor = clip_processor    # CLIP preprocess callable
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        # Augmentation pipeline only for the training split.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
        ]) if augment else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['image_path'])
        image = Image.open(img_path).convert('RGB')
        if self.augment and self.transform:
            image = self.transform(image)
        encoded_q = self.question_tokenizer(
            row['question'],
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        answer_ids = self.text_processor.encoder(row['answer'], max_len=self.answer_max_len)
        return {
            'image_path': img_path,
            'image': self.clip_processor(image),
            'question_ids': encoded_q['input_ids'].squeeze(0),
            'question_mask': encoded_q['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(answer_ids, dtype=torch.long)
        }
107
def save_checkpoint(model, optimizer, epoch, vocab, path):
    """Serialize the full training state to *path* via torch.save: model
    and optimizer state_dicts, vocab tables, special-token ids, and the
    sequence-length configuration needed to rebuild the model for
    inference."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'vocab': vocab.vocab,
        'word2idx': vocab.word2idx,
        'idx2word': vocab.idx2word,
    }
    # Special-token ids travel with the checkpoint so decoding works
    # without re-reading the training config.
    for key in ('pad_token_id', 'bos_token_id', 'eos_token_id', 'unk_token_id'):
        state[key] = getattr(vocab, key)
    state['question_max_len'] = model.question_max_len
    state['answer_max_len'] = model.answer_max_len
    torch.save(state, path)
122
def plot_losses(train_losses, val_losses, save_path="loss_plot.png"):
    """Render and save a PNG comparing per-epoch training and validation
    loss, then close the figure to free matplotlib resources."""
    plt.figure(figsize=(8, 6))
    for series, label in ((train_losses, "Train Loss"), (val_losses, "Validation Loss")):
        plt.plot(series, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train vs Validation Loss")
    plt.legend()
    plt.savefig(save_path)
    plt.close()
132
def train_one_epoch(model, dataloader, optimizer, device, scaler, vocab):
    """Run one mixed-precision training epoch.

    Returns:
        (avg_loss, avg_token_acc): mean cross-entropy loss and mean
        non-pad token accuracy, both averaged over batches.
    """
    model.train()
    total_loss = 0
    total_token_acc = 0
    # Label smoothing regularizes the decoder; pad positions are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        # Teacher-forced forward pass under autocast (AMP).
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Next-token objective: predict answers[t+1] from logits[t].
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            # Token accuracy counted over non-pad positions only.
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            token_acc = (correct * mask).sum() / mask.sum()
            total_token_acc += token_acc.item()
        # AMP backward: scale, then unscale BEFORE clipping so the clip
        # threshold applies to true gradient magnitudes, then step/update.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    return avg_loss, avg_token_acc
167
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate the model for one epoch (no gradient updates).

    Computes teacher-forced loss and token accuracy, plus a free-running
    exact-match accuracy that compares the decoded generated answer string
    against the decoded gold answer.

    Returns:
        (avg_loss, avg_token_acc, exact_match_acc); all 0.0 when the
        dataloader is empty instead of raising ZeroDivisionError.
    """
    model.eval()
    total_loss = 0
    total_token_acc = 0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            # Teacher-forced loss, shifted for next-token prediction.
            logits = model(images, questions, answer_input_ids=answers)
            shifted_logits = logits[:, :-1, :]
            shifted_answers = answers[:, 1:]
            loss = criterion(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_answers.reshape(-1)
            )
            total_loss += loss.item()
            # Token-level accuracy over non-pad positions.
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            token_acc = (correct * mask).sum() / mask.sum()
            total_token_acc += token_acc.item()
            # Free-running generation for exact-match (beam search if the
            # model provides it, greedy forward pass otherwise).
            if hasattr(model, 'generate_with_beam_search'):
                generated = model.generate_with_beam_search(images, questions, beam_width=3)
            else:
                generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    # Guard against an empty dataloader / zero samples: return zeros
    # instead of raising ZeroDivisionError.
    num_batches = len(dataloader)
    avg_loss = total_loss / num_batches if num_batches else 0.0
    avg_token_acc = total_token_acc / num_batches if num_batches else 0.0
    exact_match_acc = exact_matches / total_samples if total_samples else 0.0
    return avg_loss, avg_token_acc, exact_match_acc
209
def main():
    """Train the VQA model end-to-end with staged unfreezing.

    Stage 1: only the task-specific layers train while the CLIP/GPT-2
    encoders stay frozen. At epoch 16 (index 15) the last encoder layers are
    unfrozen and the optimizer/scheduler are rebuilt with differential
    learning rates. Early stopping tracks validation exact-match accuracy.

    Fix over the original: reaching the patience limit now actually breaks
    out of the training loop — previously only a message was printed and
    training continued for every remaining epoch.
    """
    print()
    print("# VQA: Training with Staged Unfreezing")
    print()
    import random
    import numpy as np
    # Seed every RNG source for reproducible splits and augmentation.
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    OUTPUT_DIR = r"./output2/feature_extraction"
    CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters.
    batch_size = 64
    learning_rate = 1e-4
    num_epochs = 30
    patience = 8
    question_max_len = 20
    answer_max_len = 12
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    metadata = pd.read_csv(CSV_PATH)
    print(f"Using: question_max_len={question_max_len}, answer_max_len={answer_max_len}")
    vocab = Vocab()
    vocab.build_vocab(metadata, min_freq=3)
    answer_vocab_size = len(vocab.vocab)
    print(f"Answer Vocab Size: {answer_vocab_size}")
    # Quick corpus statistics to sanity-check the answer distribution.
    word_freq = Counter()
    for ans in metadata['answer']:
        tokens = word_tokenize(ans.lower())
        word_freq.update(tokens)
    print("\nTop 20 most common answer words:")
    for word, freq in word_freq.most_common(20):
        print(f" {word}: {freq}")
    # 80/10/10 train/val/test split (test half of the 20% hold-out).
    train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"\nTrain size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
    print()
    model = VQAModel(
        vocab_size=answer_vocab_size,
        device=device,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        pad_token_id=vocab.pad_token_id,
        bos_token_id=vocab.bos_token_id,
        eos_token_id=vocab.eos_token_id,
        unk_token_id=vocab.unk_token_id,
        hidden_size=512,
        num_layers=2
    ).to(device)
    print("STAGE 1: Training decoder with frozen encoders")
    print()
    # Reuse the model's own CLIP preprocessing pipeline for the datasets.
    clip_processor = model.clip_preprocess
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # GPT-2 ships without a pad token; add one and resize the embeddings.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=True
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=clip_processor,
        question_max_len=question_max_len,
        answer_max_len=answer_max_len,
        augment=False
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # Stage-1 optimizer covers only params left trainable by the model ctor.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=1e-4)
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
    print()
    scaler = torch.amp.GradScaler(device)
    best_val_loss = np.inf
    best_val_exact_match = 0.0
    counter = 0
    logs = []
    # mode='max': the scheduler follows exact-match accuracy, not loss.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=4, verbose=True
    )
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, scaler, vocab)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f} | Val Exact Match: {val_exact_match:.4f}")
        print(f"LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_exact_match)
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, CHECKPOINT_PATH)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement in exact match for {counter} epochs.")
        if epoch == 15 and not model.fine_tuning_mode:
            # Stage 2: partially unfreeze the encoders and rebuild the
            # optimizer/scheduler with per-component learning rates.
            print("\n" + "="*50)
            print("STAGE 2: Unfreezing encoders for fine-tuning")
            print("="*50)
            model.unfreeze_clip_layers(num_layers=3)
            model.unfreeze_gpt2_layers(num_layers=3)
            clip_params = []
            gpt2_params = []
            other_params = []
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if 'clip_model' in name:
                        clip_params.append(param)
                    elif 'gpt2_model' in name:
                        gpt2_params.append(param)
                    else:
                        other_params.append(param)
            optimizer = torch.optim.AdamW([
                {'params': clip_params, 'lr': 1e-6},
                {'params': gpt2_params, 'lr': 1e-6},
                {'params': other_params, 'lr': 5e-5}
            ], weight_decay=1e-4)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=4, verbose=True
            )
            print()
        # Persist per-epoch metrics before any early exit so the final epoch
        # is always logged.
        logs.append([epoch+1, train_loss, train_token_acc, val_loss, val_token_acc, val_exact_match, optimizer.param_groups[0]['lr']])
        log_df = pd.DataFrame(logs, columns=["epoch","train_loss","train_token_acc","val_loss","val_token_acc","val_exact_match","lr"])
        log_df.to_csv(LOG_CSV, index=False)
        if counter >= patience:
            # BUGFIX: the original printed this message but never exited the
            # loop, so early stopping had no effect.
            print(f"\nEarly stopping after {patience} epochs without improvement")
            break
    plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("Training complete!")
    print(f"Best exact match accuracy: {best_val_exact_match:.4f}")

if __name__ == "__main__":
    main()
experiments/utils/preprocess.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from PIL import Image
7
+ from transformers import GPT2Tokenizer
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from collections import Counter
11
+ from nltk.tokenize import word_tokenize
12
+ from sklearn.model_selection import train_test_split
13
+ from torchvision import transforms
14
+ from model import VQAModel
15
class Vocab:
    """Word-level answer vocabulary with <pad>/<bos>/<eos>/<unk> specials.

    NOTE: this class is duplicated in experiments/utils/vocab.py — keep the
    two copies in sync.
    """
    def __init__(self):
        # All index structures below are populated by build_vocab().
        self.vocab = None       # list: index -> word
        self.vocab_size = None  # len(self.vocab) once built
        self.word2idx = None    # dict: word -> index
        self.idx2word = None    # dict: index -> word
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'
    def build_vocab(self, df, min_freq=1):
        """Build index maps from df['answer'], keeping words seen >= min_freq times."""
        counter = Counter()
        for ans in df['answer']:
            tokens = word_tokenize(ans.lower())
            counter.update(tokens)
        # Sort for a deterministic word order across runs.
        vocab = sorted([word for word, freq in counter.items() if freq >= min_freq])
        # Special tokens occupy the first four ids (pad=0, bos=1, eos=2, unk=3).
        vocab = [self.pad, self.bos, self.eos, self.unk] + vocab
        word2idx = {word: idx for idx, word in enumerate(vocab)}
        idx2word = {idx: word for word, idx in word2idx.items()}
        self.vocab = vocab
        self.word2idx = word2idx
        self.idx2word = idx2word
        self.vocab_size = len(vocab)
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]
    def encoder(self, text, max_len):
        """Tokenize text, wrap in <bos>/<eos>, then pad or truncate to max_len.

        NOTE(review): truncation can drop the <eos> token for long answers,
        which prevents the decoder from stopping early — confirm intended.
        """
        tokens = word_tokenize(text.lower())
        token_ids = [self.word2idx.get(token, self.unk_token_id) for token in tokens]
        token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
        if len(token_ids) < max_len:
            token_ids += [self.pad_token_id] * (max_len - len(token_ids))
        else:
            token_ids = token_ids[:max_len]
        return token_ids
    def decoder(self, token_ids):
        """Convert ids back to a string: stop at <eos>, skip <pad> and <bos>."""
        tokens = []
        for idx in token_ids:
            if idx == self.eos_token_id:
                break
            if idx in (self.pad_token_id, self.bos_token_id):
                continue
            tokens.append(self.idx2word.get(idx, "<unk>"))
        return ' '.join(tokens).strip()
60
class AugmentedVQADataset(Dataset):
    """VQA dataset pairing CLIP-preprocessed images with tokenized Q/A.

    Questions go through a GPT-2 tokenizer, answers through the project's
    word-level vocab encoder; optional photometric/geometric augmentation is
    applied before CLIP preprocessing when ``augment`` is True.
    """

    def __init__(self, df, img_dir, question_tokenizer, text_processor, clip_processor,
                 question_max_len=32, answer_max_len=16, augment=True):
        self.df = df
        self.img_dir = img_dir
        self.question_tokenizer = question_tokenizer
        self.text_processor = text_processor
        self.clip_processor = clip_processor
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.augment = augment
        # Light augmentation for training only; validation stays deterministic.
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomRotation(10),
        ]) if augment else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, record['image_path'])
        image = Image.open(img_path).convert('RGB')
        if self.augment and self.transform:
            image = self.transform(image)
        tokenized = self.question_tokenizer(
            record['question'],
            padding='max_length',
            truncation=True,
            max_length=self.question_max_len,
            return_tensors='pt'
        )
        answer_ids = self.text_processor.encoder(record['answer'], max_len=self.answer_max_len)
        return {
            'image_path': img_path,
            'image': self.clip_processor(image),
            'question_ids': tokenized['input_ids'].squeeze(0),
            'question_mask': tokenized['attention_mask'].squeeze(0),
            'answer_ids': torch.tensor(answer_ids, dtype=torch.long)
        }
105
+ if __name__ == "__main__":
106
+ DATA_DIR = r"/home/devarajan8/Documents/vqa/gen_vqa_v2"
107
+ CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
108
+ batch_size = 16
109
+ question_max_len = 16
110
+ answer_max_len = 10
111
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
112
+ metadata = pd.read_csv(CSV_PATH)
113
+ vocab = Vocab()
114
+ vocab.build_vocab(metadata, min_freq=5)
115
+ answer_vocab_size = len(vocab.vocab)
116
+ print(f"Answer Vocab Size: {answer_vocab_size}")
117
+ train_df, test_df = train_test_split(metadata, test_size=0.2, random_state=42)
118
+ val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
119
+ print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}")
120
+ print()
121
+ model = VQAModel(
122
+ vocab_size=answer_vocab_size,
123
+ device=device,
124
+ question_max_len=question_max_len,
125
+ answer_max_len=answer_max_len,
126
+ pad_token_id=vocab.pad_token_id,
127
+ bos_token_id=vocab.bos_token_id,
128
+ eos_token_id=vocab.eos_token_id,
129
+ unk_token_id=vocab.unk_token_id,
130
+ hidden_size=512,
131
+ num_layers=2
132
+ ).to(device)
133
+ clip_processor = model.clip_preprocess
134
+ question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
135
+ if question_tokenizer.pad_token is None:
136
+ question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
137
+ model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
138
+ train_dataset = AugmentedVQADataset(
139
+ train_df, DATA_DIR, question_tokenizer, vocab,
140
+ clip_processor=clip_processor,
141
+ question_max_len=question_max_len,
142
+ answer_max_len=answer_max_len,
143
+ augment=True
144
+ )
145
+ val_dataset = AugmentedVQADataset(
146
+ val_df, DATA_DIR, question_tokenizer, vocab,
147
+ clip_processor=clip_processor,
148
+ question_max_len=question_max_len,
149
+ answer_max_len=answer_max_len,
150
+ augment=False
151
+ )
152
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
153
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
154
+ for batch in train_loader:
155
+ images = batch['image']
156
+ ques_ids = batch['question_ids']
157
+ attn_mask = batch['question_mask']
158
+ answers = batch['answer_ids']
159
+ print(f"Image: {images.shape}")
160
+ print(f"Question Ids: {ques_ids.shape}")
161
+ print(f"Attention Mask: {attn_mask.shape}")
162
+ print(f"Answer Ids: {answers.shape}")
163
+ print(answers[0])
164
+ break
experiments/utils/vocab.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from collections import Counter
4
+ from nltk.tokenize import word_tokenize
5
class Vocab:
    """Answer-side word vocabulary with <pad>/<bos>/<eos>/<unk> specials."""

    def __init__(self):
        # Populated by build_vocab().
        self.vocab = None
        self.vocab_size = None
        self.word2idx = None
        self.idx2word = None
        self.pad = '<pad>'
        self.bos = '<bos>'
        self.eos = '<eos>'
        self.unk = '<unk>'

    def build_vocab(self, df, min_freq=1):
        """Build index maps from df['answer'], keeping words with freq >= min_freq."""
        counts = Counter()
        for answer in df['answer']:
            counts.update(word_tokenize(answer.lower()))
        # Deterministic ordering: specials first, then sorted kept words.
        kept = sorted(w for w, c in counts.items() if c >= min_freq)
        words = [self.pad, self.bos, self.eos, self.unk] + kept
        self.vocab = words
        self.word2idx = {w: i for i, w in enumerate(words)}
        self.idx2word = {i: w for i, w in enumerate(words)}
        self.vocab_size = len(words)
        self.pad_token_id = self.word2idx["<pad>"]
        self.bos_token_id = self.word2idx["<bos>"]
        self.eos_token_id = self.word2idx["<eos>"]
        self.unk_token_id = self.word2idx["<unk>"]

    def encoder(self, text, max_len):
        """Tokenize text to ids, wrap with <bos>/<eos>, pad/truncate to max_len."""
        ids = [self.bos_token_id]
        ids += [self.word2idx.get(tok, self.unk_token_id) for tok in word_tokenize(text.lower())]
        ids.append(self.eos_token_id)
        if len(ids) >= max_len:
            return ids[:max_len]
        return ids + [self.pad_token_id] * (max_len - len(ids))

    def decoder(self, token_ids):
        """Map ids back to a string, stopping at <eos> and skipping <pad>/<bos>."""
        words = []
        for tid in token_ids:
            if tid == self.eos_token_id:
                break
            if tid not in (self.pad_token_id, self.bos_token_id):
                words.append(self.idx2word.get(tid, "<unk>"))
        return ' '.join(words).strip()
50
+ if __name__ == "__main__":
51
+ CSV_PATH = r"./gen_vqa_v2/metadata.csv"
52
+ answer_max_len = 10
53
+ metadata = pd.read_csv(CSV_PATH)
54
+ vocab = Vocab()
55
+ vocab.build_vocab(metadata, min_freq=5)
56
+ answer_vocab_size = len(vocab.vocab)
57
+ print(f"Answer Vocab Size: {answer_vocab_size}")
58
+ sample_answer = metadata['answer'].values
59
+ text = sample_answer[0]
60
+ print("")
61
+ encoded = vocab.encoder(text, answer_max_len)
62
+ decoded = vocab.decoder(encoded)
63
+ print(f"Sample Answer: {text}")
64
+ print(f"Encoded: {encoded}")
65
+ print(f"Decoded: {decoded}")
finetune.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ from transformers import GPT2Tokenizer
9
+ from tqdm import tqdm
10
+ from sklearn.model_selection import train_test_split
11
+ from model import VQAModel
12
+ from train import AugmentedVQADataset, Vocab, save_checkpoint, plot_losses
13
def create_optimizer_with_differential_lr(model, clip_lr=5e-7, gpt_lr=5e-7, other_lr=3e-5):
    """Build an AdamW optimizer with separate LRs for CLIP, GPT-2 and the rest.

    Parameters are partitioned by name: anything under ``clip_model`` gets
    ``clip_lr``, anything under ``gpt2_model`` gets ``gpt_lr``, and every
    other trainable parameter gets ``other_lr``. Frozen params are skipped.
    """
    clip_params, gpt_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'clip_model' in name:
            clip_params.append(param)
        elif 'gpt2_model' in name:
            gpt_params.append(param)
        else:
            other_params.append(param)
    optimizer = torch.optim.AdamW([
        {'params': clip_params, 'lr': clip_lr},
        {'params': gpt_params, 'lr': gpt_lr},
        {'params': other_params, 'lr': other_lr},
    ], weight_decay=1e-4)
    print(f"Optimizer: CLIP params: {len(clip_params)}, GPT-2 params: {len(gpt_params)}, Other params: {len(other_params)}")
    return optimizer
30
def train_one_epoch(model, dataloader, optimizer, device, vocab, scaler):
    """Run one mixed-precision training epoch and return the mean batch loss.

    Uses teacher forcing: logits are shifted against the answer ids so the
    model predicts token t+1 from tokens <= t. Batches producing a NaN loss
    are skipped without an optimizer step.

    Fix over the original: gradients are unscaled before clipping. Clipping
    the still-scaled gradients made the max-norm threshold meaningless (the
    effective clip depended on the current loss scale) — see the torch.amp
    gradient-clipping recipe.
    """
    model.train()
    total_loss = 0.0
    # Label smoothing regularizes the decoder; padding positions are ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Shift for next-token prediction.
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_answers = answers[:, 1:].contiguous()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_answers.view(-1)
            )
        if torch.isnan(loss):
            print("NaN loss detected, skipping batch.")
            continue
        scaler.scale(loss).backward()
        # BUGFIX: unscale before clipping so the 1.0 max-norm applies to the
        # true gradients rather than the loss-scaled ones.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    return total_loss / len(dataloader)
59
def validate_one_epoch(model, dataloader, device, vocab):
    """Evaluate one epoch: mean teacher-forced loss plus exact-match accuracy.

    Exact match compares the free-running generated answer string against the
    decoded ground truth after stripping special tokens.
    """
    model.eval()
    total_loss = 0.0
    exact_matches = 0
    total_samples = 0
    # No label smoothing at validation time; padding ignored in the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            with torch.amp.autocast("cuda"):
                logits = model(images, questions, answer_input_ids=answers)
                # Same next-token shift as training.
                shifted_logits = logits[:, :-1, :].contiguous()
                shifted_answers = answers[:, 1:].contiguous()
                loss = criterion(
                    shifted_logits.view(-1, shifted_logits.size(-1)),
                    shifted_answers.view(-1)
                )
            total_loss += loss.item()
            # Second forward pass without answer ids = free-running generation.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    avg_loss = total_loss / len(dataloader)
    # NOTE(review): raises ZeroDivisionError on an empty dataloader — confirm
    # callers always pass a non-empty validation set.
    exact_match_acc = exact_matches / total_samples
    return avg_loss, exact_match_acc
92
def filter_spatial_directional_data(df):
    """Return the subset of df whose question or answer is spatial/directional.

    Fix over the original: keywords are matched as whole words/phrases
    (word-boundary anchored regex) instead of raw substrings, so e.g.
    'bright' no longer matches 'right' and 'cup' no longer matches 'up'.
    """
    import re  # local import: module header does not import re
    spatial_keywords = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'front', 'behind', 'next to', 'beside', 'near',
        'looking', 'facing', 'pointing', 'direction',
        'where is', 'which side', 'what side'
    ]
    directional_answers = [
        'up', 'down', 'left', 'right', 'forward', 'backward',
        'north', 'south', 'east', 'west', 'straight', 'sideways'
    ]
    # \b anchors prevent substring false positives; re.escape is defensive
    # in case a keyword ever contains regex metacharacters.
    spatial_pattern = r'\b(?:' + '|'.join(re.escape(k) for k in spatial_keywords) + r')\b'
    directional_pattern = r'\b(?:' + '|'.join(re.escape(a) for a in directional_answers) + r')\b'
    spatial_mask = df['question'].str.lower().str.contains(spatial_pattern, na=False, regex=True)
    directional_mask = df['answer'].str.lower().str.contains(directional_pattern, na=False, regex=True)
    spatial_df = df[spatial_mask | directional_mask].copy()
    print(f"Found {len(spatial_df)} spatial/directional samples out of {len(df)} total")
    return spatial_df
108
def main():
    """Fine-tune a pretrained VQA checkpoint on spatial/directional questions.

    Loads the feature-extraction checkpoint, filters the metadata down to
    spatial questions (mixing in general data when too few are found),
    unfreezes the last encoder layers, and trains with differential learning
    rates, early-stopping on validation exact-match accuracy.
    """
    print("# VQA: Spatial-Enhanced Fine-Tuning")
    # Seed everything for reproducible splits and sampling.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    PRETRAINED_CHECKPOINT = "./output2/feature_extraction/vqa_checkpoint.pt"
    OUTPUT_DIR = "./output2/spatial_finetuning"
    FINE_TUNED_CHECKPOINT = os.path.join(OUTPUT_DIR, "vqa_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Hyperparameters.
    batch_size = 64
    num_epochs = 50
    patience = 8
    clip_layers_to_unfreeze = 8
    gpt_layers_to_unfreeze = 8
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=device)
    metadata = pd.read_csv(CSV_PATH)
    print(f"\nOriginal dataset size: {len(metadata)}")
    spatial_data = filter_spatial_directional_data(metadata)
    if len(spatial_data) < 1000:
        # Too few spatial samples: blend in general data to reduce
        # catastrophic forgetting of non-spatial answers.
        print(f"\nWARNING: Only {len(spatial_data)} spatial samples found!")
        print("Mixing 70% spatial data with 30% general data for balanced training")
        general_data = metadata[~metadata.index.isin(spatial_data.index)].sample(n=min(len(spatial_data)//2, len(metadata)//3), random_state=42)
        mixed_data = pd.concat([spatial_data, general_data]).sample(frac=1, random_state=42).reset_index(drop=True)
    else:
        print(f"Using {len(spatial_data)} spatial/directional samples")
        mixed_data = spatial_data
    # Restore the exact vocabulary the checkpoint was trained with.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f"Answer vocabulary size: {len(vocab.vocab)}")
    model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    # Resize the embeddings to match the tokenizer BEFORE loading weights.
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    # strict=False tolerates keys absent from the checkpoint (e.g. adapters).
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print("Pretrained model loaded successfully!\n")
    print(f"UNFREEZING {clip_layers_to_unfreeze} CLIP LAYERS & {gpt_layers_to_unfreeze} GPT-2 LAYERS FOR SPATIAL UNDERSTANDING")
    model.unfreeze_clip_layers(num_layers=clip_layers_to_unfreeze)
    model.unfreeze_gpt2_layers(num_layers=gpt_layers_to_unfreeze)
    # 80/10/10 split of the (possibly mixed) spatial dataset.
    train_df, test_df = train_test_split(mixed_data, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"Train size: {len(train_df)}, Val size: {len(val_df)}, Test size: {len(test_df)}\n")
    train_dataset = AugmentedVQADataset(train_df, DATA_DIR, question_tokenizer, vocab,
                                        clip_processor=model.clip_preprocess, augment=True,
                                        question_max_len=20, answer_max_len=12)
    val_dataset = AugmentedVQADataset(val_df, DATA_DIR, question_tokenizer, vocab,
                                      clip_processor=model.clip_preprocess, augment=False,
                                      question_max_len=20, answer_max_len=12)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # Tiny LR for the pretrained encoders, larger for task-specific heads.
    optimizer = create_optimizer_with_differential_lr(
        model,
        clip_lr=3e-7,
        gpt_lr=3e-7,
        other_lr=2e-5
    )
    # mode='max': the scheduler follows exact-match accuracy, not loss.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, verbose=True)
    scaler = torch.amp.GradScaler(device)
    print("\nSTARTING SPATIAL-ENHANCED FINE-TUNING")
    best_val_loss = np.inf
    best_exact_match = 0.0
    logs = []
    counter = 0
    for epoch in range(num_epochs):
        print(f"\nSpatial Fine-tuning Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device, vocab, scaler)
        val_loss, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Exact Match: {val_exact_match:.4f} | LR: {optimizer.param_groups[0]['lr']}")
        scheduler.step(val_exact_match)
        if val_exact_match > best_exact_match:
            best_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, FINE_TUNED_CHECKPOINT)
            print("Checkpoint saved!")
            counter = 0
        else:
            counter += 1
            print(f"No improvement for {counter} epochs.")
        if counter >= patience:
            # NOTE(review): the epoch that triggers early stopping is never
            # appended to the logs because the break happens before the
            # append below — confirm this is intentional.
            print(f"\nEarly stopping after {patience} epochs without improvement")
            break
        logs.append([epoch + 1, train_loss, val_loss, val_exact_match, optimizer.param_groups[0]['lr']])
        # Rewrite the full CSV every epoch so progress survives a crash.
        pd.DataFrame(logs, columns=["epoch", "train_loss", "val_loss", "val_exact_match", "lr"]).to_csv(LOG_CSV, index=False)
    plot_losses([x[1] for x in logs], [x[2] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\nFINE-TUNING COMPLETE")
    print(f"Best exact match: {best_exact_match:.4f}")
    print(f"Model saved to: {FINE_TUNED_CHECKPOINT}")

if __name__ == "__main__":
    main()
finetune2.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.utils.data import DataLoader
8
+ from transformers import GPT2Tokenizer
9
+ from tqdm import tqdm
10
+ from sklearn.model_selection import train_test_split
11
+ from model import VQAModel
12
+ from model_spatial import VQAModelWithSpatialAdapter
13
+ from train import AugmentedVQADataset, Vocab, save_checkpoint, plot_losses
14
+ import math
15
def filter_spatial_questions(df):
    """
    Filter dataset for spatial/directional questions.
    Returns both spatial subset and general subset for mixed training.

    Fix over the original: keywords are matched as whole words/phrases
    (word-boundary anchored) so substrings such as 'bright' no longer
    match 'right'.
    """
    import re  # local import: module header does not import re
    spatial_keywords = [
        'right', 'left', 'above', 'below', 'top', 'bottom',
        'front', 'behind', 'next to', 'beside', 'near', 'between',
        'in front', 'in back', 'across from', 'opposite',
        'closest', 'farthest', 'nearest', 'furthest',
        'where is', 'which side', 'what side', 'what direction',
        'on the left', 'on the right', 'at the top', 'at the bottom'
    ]
    # \b anchors avoid substring false positives (e.g. 'cup' matching 'up').
    pattern = r'\b(?:' + '|'.join(re.escape(k) for k in spatial_keywords) + r')\b'
    spatial_mask = df['question'].str.lower().str.contains(pattern, na=False, regex=True)
    spatial_df = df[spatial_mask].copy()
    general_df = df[~spatial_mask].copy()
    print(f"\n📊 Dataset Filtering Results:")
    print(f" Total samples: {len(df):,}")
    print(f" Spatial samples: {len(spatial_df):,} ({len(spatial_df)/len(df)*100:.1f}%)")
    print(f" General samples: {len(general_df):,} ({len(general_df)/len(df)*100:.1f}%)")
    if len(spatial_df) > 0:
        print(f"\n📝 Sample Spatial Questions:")
        for i, row in spatial_df.sample(min(5, len(spatial_df))).iterrows():
            print(f" Q: {row['question']}")
            print(f" A: {row['answer']}\n")
    return spatial_df, general_df
42
def create_mixed_dataset(spatial_df, general_df, spatial_ratio=0.85, min_spatial_samples=1000):
    """
    Combine spatial and general questions so that spatial rows make up
    roughly `spatial_ratio` of the shuffled result.

    The general sample count is derived from the spatial count; when fewer
    than `min_spatial_samples` spatial rows exist, a warning is printed
    (the mixing formula itself is unchanged). Sampling and shuffling use
    fixed seeds, so the output is deterministic.
    """
    if len(spatial_df) < min_spatial_samples:
        print(f"\n⚠️ WARNING: Only {len(spatial_df)} spatial samples found!")
        print(f" Recommended minimum: {min_spatial_samples}")
        print(f" Mixing with general data to prevent catastrophic forgetting...")
    num_spatial = len(spatial_df)
    # Solve spatial/(spatial+general) ≈ spatial_ratio for the general count,
    # capped by what is actually available.
    num_general = min(int(num_spatial * (1 - spatial_ratio) / spatial_ratio), len(general_df))
    general_sample = general_df.sample(n=num_general, random_state=42)
    mixed_df = pd.concat([spatial_df, general_sample]).sample(frac=1, random_state=42).reset_index(drop=True)
    print(f"\n🔀 Mixed Dataset Created:")
    print(f" Spatial: {num_spatial:,} ({num_spatial/len(mixed_df)*100:.1f}%)")
    print(f" General: {num_general:,} ({num_general/len(mixed_df)*100:.1f}%)")
    print(f" Total: {len(mixed_df):,}")
    return mixed_df
65
def unfreeze_clip_layers(model, num_layers=4):
    """
    Unfreeze last N layers of CLIP for spatial feature learning.

    Also re-enables gradients for the visual output projection and the final
    layer norm so the newly unfrozen blocks can actually influence the
    pooled image embedding.
    """
    # Total transformer blocks in the CLIP visual tower.
    total_blocks = len(model.clip_model.visual.transformer.resblocks)
    for i, block in enumerate(model.clip_model.visual.transformer.resblocks):
        if i >= total_blocks - num_layers:
            for p in block.parameters():
                p.requires_grad = True
    # The output projection is a bare Parameter in OpenAI CLIP but may be a
    # module in other implementations — handle both.
    if hasattr(model.clip_model.visual, "proj") and model.clip_model.visual.proj is not None:
        if isinstance(model.clip_model.visual.proj, torch.nn.Parameter):
            model.clip_model.visual.proj.requires_grad = True
        else:
            for p in model.clip_model.visual.proj.parameters():
                p.requires_grad = True
    # Final post-transformer layer norm also participates in fine-tuning.
    if hasattr(model.clip_model.visual, "ln_post"):
        for p in model.clip_model.visual.ln_post.parameters():
            p.requires_grad = True
    print(f" ✓ Unfroze last {num_layers} CLIP layers")
84
def freeze_base_model(model, unfreeze_clip_layers_count=4):
    """
    Freeze most of the model, unfreeze spatial adapter and last CLIP layers.

    Returns the same model instance after mutating requires_grad flags and
    printing a trainable-parameter summary.
    """
    # Freeze all of CLIP first, then selectively re-enable the last N layers.
    for param in model.clip_model.parameters():
        param.requires_grad = False
    unfreeze_clip_layers(model, num_layers=unfreeze_clip_layers_count)
    # The language encoder and the answer decoder stay fully frozen.
    for param in model.gpt2_model.parameters():
        param.requires_grad = False
    for param in model.decoder.parameters():
        param.requires_grad = False
    # Spatial-adapter path (attributes presumably defined on
    # VQAModelWithSpatialAdapter — confirm against model_spatial.py) is the
    # main trainable component of this fine-tuning stage.
    for param in model.spatial_adapter.parameters():
        param.requires_grad = True
    for param in model.spatial_context_proj.parameters():
        param.requires_grad = True
    for param in model.q_proj.parameters():
        param.requires_grad = True
    for param in model.spatial_fusion.parameters():
        param.requires_grad = True
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n🔒 Model Freezing Applied:")
    print(f" Total parameters: {total_params:,}")
    print(f" Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")
    print(f" Frozen parameters: {total_params - trainable_params:,}")
    return model
110
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr=1e-7):
    """
    LambdaLR with linear warmup followed by cosine decay.

    The lambda returns a multiplier on each group's base LR: it ramps
    linearly from 0 to 1 over `num_warmup_steps`, then follows a cosine
    curve down, floored at `min_lr` (note: a multiplier floor, not an
    absolute learning rate).
    """
    def lr_lambda(step):
        if step < num_warmup_steps:
            # Linear ramp; max(1, ...) guards against zero warmup steps.
            return step / max(1, num_warmup_steps)
        decay_span = max(1, num_training_steps - num_warmup_steps)
        progress = (step - num_warmup_steps) / decay_span
        return max(min_lr, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
120
def create_optimizer_with_differential_lr(model, base_lr=5e-5):
    """
    AdamW with per-component learning rates derived from `base_lr`:
    CLIP parameters at base_lr*0.1, the spatial adapter at base_lr, and
    every other trainable parameter at base_lr*0.5. Frozen params skipped.
    """
    clip_params, spatial_adapter_params, other_params = [], [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if 'clip_model' in name:
            clip_params.append(param)
        elif 'spatial_adapter' in name:
            spatial_adapter_params.append(param)
        else:
            other_params.append(param)
    optimizer = torch.optim.AdamW([
        {'params': clip_params, 'lr': base_lr * 0.1},
        {'params': spatial_adapter_params, 'lr': base_lr},
        {'params': other_params, 'lr': base_lr * 0.5},
    ], weight_decay=1e-4)
    print(f"\n⚙️ Optimizer Configuration:")
    print(f" CLIP params: {len(clip_params):,} (LR: {base_lr * 0.1:.2e})")
    print(f" Spatial adapter params: {len(spatial_adapter_params):,} (LR: {base_lr:.2e})")
    print(f" Other params: {len(other_params):,} (LR: {base_lr * 0.5:.2e})")
    return optimizer
145
def train_one_epoch(model, dataloader, optimizer, device, vocab, scaler):
    """
    Run one training epoch with mixed precision (AMP).

    Args:
        model: VQA model; forward(images, questions, answer_input_ids) -> logits.
        dataloader: Yields dicts with 'image', 'question_ids', 'question_mask',
            and 'answer_ids' tensors.
        optimizer: Optimizer whose gradients are managed by `scaler`.
        device: 'cuda' or 'cpu'; used for tensor placement and autocast.
        vocab: Vocabulary object providing `pad_token_id`.
        scaler: torch.amp.GradScaler handling loss scaling.

    Returns:
        Tuple (avg_loss, avg_token_accuracy) averaged over batches.
    """
    model.train()
    total_loss = 0.0
    total_token_acc = 0.0
    # Label smoothing regularizes next-token prediction; padding is ignored.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id, label_smoothing=0.1)
    for batch in tqdm(dataloader, desc="Training"):
        optimizer.zero_grad()
        images = batch['image'].to(device)
        questions = {
            'input_ids': batch['question_ids'].to(device),
            'attention_mask': batch['question_mask'].to(device)
        }
        answers = batch['answer_ids'].to(device)
        with torch.amp.autocast(device):
            logits = model(images, questions, answer_input_ids=answers)
            # Teacher forcing: predict token t+1 from tokens <= t.
            shifted_logits = logits[:, :-1, :].contiguous()
            shifted_answers = answers[:, 1:].contiguous()
            loss = criterion(
                shifted_logits.view(-1, shifted_logits.size(-1)),
                shifted_answers.view(-1)
            )
        # BUGFIX: skip NaN batches *before* accumulating metrics so a bad
        # batch does not pollute the token-accuracy average (previously the
        # accuracy was added first and only the loss/step was skipped).
        if torch.isnan(loss):
            print("⚠️ NaN loss detected, skipping batch.")
            continue
        predicted_tokens = shifted_logits.argmax(dim=-1)
        correct = (predicted_tokens == shifted_answers).float()
        mask = (shifted_answers != vocab.pad_token_id).float()
        token_acc = (correct * mask).sum() / mask.sum()
        total_token_acc += token_acc.item()
        scaler.scale(loss).backward()
        # BUGFIX: gradients must be unscaled before clipping; otherwise the
        # clip threshold (1.0) is applied to *scaled* gradients and is
        # effectively a no-op under AMP.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    avg_token_acc = total_token_acc / len(dataloader)
    return avg_loss, avg_token_acc
183
def validate_one_epoch(model, dataloader, device, vocab):
    """
    Run one validation epoch.

    Computes teacher-forced loss and token accuracy, plus free-running
    generation to measure exact-match accuracy against the references.

    Args:
        model: VQA model; with answers -> logits, without -> generated ids.
        dataloader: Yields dicts with 'image', 'question_ids', 'question_mask',
            and 'answer_ids' tensors.
        device: 'cuda' or 'cpu'.
        vocab: Vocabulary providing `pad_token_id` and `decoder()`.

    Returns:
        Tuple (avg_loss, avg_token_accuracy, exact_match_accuracy).
    """
    model.eval()
    total_loss = 0.0
    total_token_acc = 0.0
    exact_matches = 0
    total_samples = 0
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_token_id)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            images = batch['image'].to(device)
            questions = {
                'input_ids': batch['question_ids'].to(device),
                'attention_mask': batch['question_mask'].to(device)
            }
            answers = batch['answer_ids'].to(device)
            with torch.amp.autocast(device):
                logits = model(images, questions, answer_input_ids=answers)
                shifted_logits = logits[:, :-1, :].contiguous()
                shifted_answers = answers[:, 1:].contiguous()
                loss = criterion(
                    shifted_logits.view(-1, shifted_logits.size(-1)),
                    shifted_answers.view(-1)
                )
            predicted_tokens = shifted_logits.argmax(dim=-1)
            correct = (predicted_tokens == shifted_answers).float()
            mask = (shifted_answers != vocab.pad_token_id).float()
            token_acc = (correct * mask).sum() / mask.sum()
            total_token_acc += token_acc.item()
            total_loss += loss.item()
            # Free-running generation (no answer_input_ids) for exact match.
            generated = model(images, questions)
            for pred, true in zip(generated, answers):
                pred_text = vocab.decoder(pred.cpu().numpy())
                true_text = vocab.decoder(true.cpu().numpy())
                if pred_text.strip() == true_text.strip():
                    exact_matches += 1
                total_samples += 1
    # BUGFIX: guard against an empty dataloader to avoid ZeroDivisionError.
    num_batches = max(len(dataloader), 1)
    avg_loss = total_loss / num_batches
    avg_token_acc = total_token_acc / num_batches
    exact_match_acc = exact_matches / max(total_samples, 1)
    return avg_loss, avg_token_acc, exact_match_acc
224
def main():
    """Entry point: fine-tune the spatial adapter on top of a pretrained VQA model."""
    print("=" * 80)
    print("🚀 VQA SPATIAL ADAPTER FINE-TUNING V2 (ENHANCED)")
    print("=" * 80)
    # Seed every RNG source for reproducible splits, sampling and init.
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
    # --- Paths (relative to the working directory) ---
    DATA_DIR = r"./gen_vqa_v2"
    CSV_PATH = os.path.join(DATA_DIR, "metadata.csv")
    PRETRAINED_CHECKPOINT = "./output2/continued_training/vqa_checkpoint.pt"
    OUTPUT_DIR = "./output2/spatial_adapter_v2_2"
    FINE_TUNED_CHECKPOINT = os.path.join(OUTPUT_DIR, "vqa_spatial_checkpoint.pt")
    LOG_CSV = os.path.join(OUTPUT_DIR, "train_log.csv")
    LOSS_GRAPH_PATH = os.path.join(OUTPUT_DIR, "loss_plot.png")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # --- Hyperparameters ---
    batch_size = 64
    base_learning_rate = 5e-5
    num_epochs = 100
    patience = 15  # early-stopping patience (epochs without exact-match gain)
    warmup_epochs = 3
    spatial_ratio = 0.85  # fraction of spatial questions in the mixed dataset
    clip_layers_to_unfreeze = 6
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\n⚙️ Enhanced Configuration:")
    print(f" Device: {device}")
    print(f" Batch size: {batch_size}")
    print(f" Base learning rate: {base_learning_rate:.2e}")
    print(f" Max epochs: {num_epochs} (increased from 20)")
    print(f" Warmup epochs: {warmup_epochs}")
    print(f" Early stopping patience: {patience}")
    print(f" Spatial ratio: {spatial_ratio:.0%} (increased from 70%)")
    print(f" CLIP layers to unfreeze: {clip_layers_to_unfreeze}")
    # --- Dataset: build a mixed set oversampling spatial questions ---
    print(f"\n📂 Loading dataset from: {CSV_PATH}")
    metadata = pd.read_csv(CSV_PATH)
    spatial_df, general_df = filter_spatial_questions(metadata)
    mixed_data = create_mixed_dataset(spatial_df, general_df, spatial_ratio=spatial_ratio)
    # --- Restore vocabulary and special-token ids from the checkpoint ---
    print(f"\n📥 Loading pretrained model from: {PRETRAINED_CHECKPOINT}")
    checkpoint = torch.load(PRETRAINED_CHECKPOINT, map_location=device)
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    print(f" Vocabulary size: {len(vocab.vocab):,}")
    question_tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if question_tokenizer.pad_token is None:
        question_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # --- Rebuild the base model and load pretrained weights ---
    base_model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    base_model.gpt2_model.resize_token_embeddings(len(question_tokenizer))
    # strict=False: spatial-adapter weights are new and absent from the checkpoint.
    base_model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    print(" ✓ Pretrained weights loaded")
    print(f"\n🔧 Creating VQA model with spatial adapter...")
    model = VQAModelWithSpatialAdapter(
        base_model=base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ).to(device)
    # Freeze everything except the adapter stack and the top CLIP layers.
    model = freeze_base_model(model, unfreeze_clip_layers_count=clip_layers_to_unfreeze)
    # --- 80/10/10 train/val/test split (val/test from the 20% holdout) ---
    train_df, test_df = train_test_split(mixed_data, test_size=0.2, random_state=42)
    val_df, test_df = train_test_split(test_df, test_size=0.5, random_state=42)
    print(f"\n📊 Data Split:")
    print(f" Train: {len(train_df):,} samples")
    print(f" Validation: {len(val_df):,} samples")
    print(f" Test: {len(test_df):,} samples")
    from torchvision import transforms
    # NOTE(review): safe_augmentation is defined but never passed anywhere,
    # and both datasets are built with augment=False — confirm whether image
    # augmentation was meant to be enabled.
    safe_augmentation = transforms.Compose([
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomRotation(5),
    ])
    train_dataset = AugmentedVQADataset(
        train_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=model.clip_preprocess,
        augment=False,
        question_max_len=20,
        answer_max_len=12
    )
    val_dataset = AugmentedVQADataset(
        val_df, DATA_DIR, question_tokenizer, vocab,
        clip_processor=model.clip_preprocess,
        augment=False,
        question_max_len=20,
        answer_max_len=12
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # --- Optimizer + LR schedule (linear warmup then cosine decay) ---
    optimizer = create_optimizer_with_differential_lr(model, base_lr=base_learning_rate)
    num_training_steps = len(train_loader) * num_epochs
    num_warmup_steps = len(train_loader) * warmup_epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    print(f"\n📈 Learning Rate Schedule:")
    print(f" Warmup steps: {num_warmup_steps:,} ({warmup_epochs} epochs)")
    print(f" Total steps: {num_training_steps:,}")
    print(f" Schedule: Linear warmup → Cosine decay")
    scaler = torch.amp.GradScaler(device)
    print("\n" + "=" * 80)
    print("🎯 STARTING ENHANCED SPATIAL ADAPTER FINE-TUNING")
    print("=" * 80)
    best_val_exact_match = 0.0
    best_val_loss = np.inf  # NOTE(review): tracked but never updated or used below
    counter = 0  # epochs since the last exact-match improvement
    logs = []
    for epoch in range(num_epochs):
        print(f"\n📅 Epoch {epoch+1}/{num_epochs}")
        print("-" * 80)
        train_loss, train_token_acc = train_one_epoch(model, train_loader, optimizer, device, vocab, scaler)
        val_loss, val_token_acc, val_exact_match = validate_one_epoch(model, val_loader, device, vocab)
        # Param group 1 is the spatial adapter (runs at the full base LR).
        current_lr = optimizer.param_groups[1]['lr']
        print(f"\n📈 Metrics:")
        print(f" Train Loss: {train_loss:.4f} | Train Token Acc: {train_token_acc:.4f}")
        print(f" Val Loss: {val_loss:.4f} | Val Token Acc: {val_token_acc:.4f}")
        print(f" Val Exact Match: {val_exact_match:.4f}")
        print(f" Learning Rate: {current_lr:.2e}")
        # Checkpoint on the best exact-match score; stop after `patience`
        # stale epochs. NOTE(review): on the early-stop epoch the break below
        # skips the logs.append, so the final epoch is absent from train_log.csv.
        if val_exact_match > best_val_exact_match:
            best_val_exact_match = val_exact_match
            save_checkpoint(model, optimizer, epoch, vocab, FINE_TUNED_CHECKPOINT)
            print(f" ✅ New best model saved! (Exact Match: {val_exact_match:.4f})")
            counter = 0
        else:
            counter += 1
            print(f" ⏳ No improvement for {counter} epoch(s)")
            if counter >= patience:
                print(f"\n⏹️ Early stopping triggered after {patience} epochs without improvement")
                break
        logs.append([
            epoch + 1,
            train_loss,
            train_token_acc,
            val_loss,
            val_token_acc,
            val_exact_match,
            current_lr
        ])
        # The scheduler was sized in optimizer steps but is advanced once per
        # epoch here: step it by one epoch's worth of batches.
        for _ in range(len(train_loader)):
            scheduler.step()
    log_df = pd.DataFrame(
        logs,
        columns=["epoch", "train_loss", "train_token_acc", "val_loss", "val_token_acc", "val_exact_match", "lr"]
    )
    log_df.to_csv(LOG_CSV, index=False)
    # Columns 1 and 3 of each log row are train_loss and val_loss.
    plot_losses([x[1] for x in logs], [x[3] for x in logs], save_path=LOSS_GRAPH_PATH)
    print("\n" + "=" * 80)
    print("✅ ENHANCED FINE-TUNING COMPLETE")
    print("=" * 80)
    print(f"\n📊 Final Results:")
    print(f" Best Exact Match: {best_val_exact_match:.4f}")
    print(f" Total Epochs: {len(logs)}")
    # 0.2037 is the hardcoded v1 baseline exact-match score.
    print(f" Improvement from v1: {best_val_exact_match - 0.2037:.4f} ({(best_val_exact_match - 0.2037) / 0.2037 * 100:+.1f}%)")
    print(f"\n💾 Outputs:")
    print(f" Model: {FINE_TUNED_CHECKPOINT}")
    print(f" Logs: {LOG_CSV}")
    print(f" Plot: {LOSS_GRAPH_PATH}")
    print("\n🎉 Ready to test on spatial questions!")
394
# Run fine-tuning only when executed as a script (not on import).
if __name__ == "__main__":
    main()
genvqa-dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import random
4
+ import shutil
5
+ import pandas as pd
6
+ from tqdm import tqdm
7
+ from collections import Counter
8
# Build a filtered, balanced VQA subset from COCO train2014 + VQA v2 annotations.
# --- Input/output locations ---
IMAGES_DIR = r"../train2014"
QUESTIONS_PATH = r"../v2_OpenEnded_mscoco_train2014_questions.json"
ANNOTATIONS_PATH = r"../v2_mscoco_train2014_annotations.json"
OUTPUT_DIR = "./gen_vqa_v2"
os.makedirs(os.path.join(OUTPUT_DIR, "images"), exist_ok=True)
print("Loading VQA v2 data...")
with open(QUESTIONS_PATH, "r") as f:
    questions = json.load(f)["questions"]
with open(ANNOTATIONS_PATH, "r") as f:
    annotations = json.load(f)["annotations"]
# Index annotations by question_id for O(1) joins with the question list.
qid_to_ann = {ann["question_id"]: ann for ann in annotations}
print("Merging questions and answers...")
merged_data = []
answer_counter = Counter()
# Drop low-information answers and overly generic question phrasings.
EXCLUDED_ANSWERS = ['yes', 'no', 'unknown', 'none', 'n/a', 'cant tell', 'not sure']
AMBIGUOUS_QUESTIONS = ['what is in the image', 'what is this', 'what is that', 'what do you see']
for q in tqdm(questions, total=len(questions)):
    ann = qid_to_ann.get(q["question_id"])
    if not ann:
        continue
    answers = [a["answer"] for a in ann["answers"] if a["answer"].strip()]
    if not answers:
        continue
    # Majority vote across the annotators' answers.
    main_answer = max(set(answers), key=answers.count)
    main_answer = main_answer.lower().strip()
    question_text = q["question"].lower().strip()
    if main_answer in EXCLUDED_ANSWERS:
        continue
    if any(ambig in question_text for ambig in AMBIGUOUS_QUESTIONS):
        continue
    # Keep only short answers (at most 5 words and 30 characters).
    if len(main_answer.split()) <= 5 and len(main_answer) <= 30:
        merged_data.append({
            "image_id": q["image_id"],
            "question_id": q["question_id"],
            "question": q["question"],
            "answer": main_answer
        })
        answer_counter[main_answer] += 1
print(f"Total valid Q-A pairs (after filtering): {len(merged_data)}")
# Remove rare answers so the label space stays learnable.
MIN_ANSWER_FREQ = 20
frequent_answers = {ans for ans, count in answer_counter.items() if count >= MIN_ANSWER_FREQ}
filtered_data = [item for item in merged_data if item["answer"] in frequent_answers]
print(f"After frequency filtering (min_freq={MIN_ANSWER_FREQ}): {len(filtered_data)} pairs")
# Cap samples per answer to reduce class imbalance.
# NOTE(review): this keeps the *first* N occurrences rather than a random
# sample — the shuffle happens only afterwards; confirm this is intended.
MAX_SAMPLES_PER_ANSWER = 600
answer_samples = {}
for item in filtered_data:
    ans = item["answer"]
    if ans not in answer_samples:
        answer_samples[ans] = []
    if len(answer_samples[ans]) < MAX_SAMPLES_PER_ANSWER:
        answer_samples[ans].append(item)
balanced_data = []
for samples in answer_samples.values():
    balanced_data.extend(samples)
random.shuffle(balanced_data)
print(f"After balancing: {len(balanced_data)} pairs with {len(answer_samples)} unique answers")
print("Copying selected images and saving data...")
final_data = []
for item in tqdm(balanced_data):
    # COCO file naming convention: 12-digit zero-padded image id.
    img_name = f"COCO_train2014_{item['image_id']:012d}.jpg"
    src_path = os.path.join(IMAGES_DIR, img_name)
    dst_path = os.path.join(OUTPUT_DIR, "images", img_name)
    # Silently skip pairs whose source image is missing on disk.
    if os.path.exists(src_path):
        shutil.copy(src_path, dst_path)
        item["image_path"] = f"images/{img_name}"
        final_data.append(item)
print(f"Final dataset: {len(final_data)} pairs")
with open(os.path.join(OUTPUT_DIR, "qa_pairs.json"), "w") as f:
    json.dump(final_data, f, indent=2, ensure_ascii=False)
pd.DataFrame(final_data).to_csv(os.path.join(OUTPUT_DIR, "metadata.csv"), index=False)
print("Data preparation complete.")
groq_service.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Groq LLM Service for VQA Accessibility
3
+ Generates descriptive 2-sentence narrations for blind users
4
+ """
5
+ import os
6
+ from typing import Dict, Optional
7
+ from groq import Groq
8
class GroqDescriptionService:
    """Service to generate accessible descriptions using Groq LLM"""

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize Groq service

        Args:
            api_key: Groq API key (if not provided, reads from GROQ_API_KEY env var)

        Raises:
            ValueError: If no key was supplied and GROQ_API_KEY is unset/empty.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key not found. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter"
            )
        self.client = Groq(api_key=self.api_key)
        # Default chat model used for all requests.
        self.model = "llama-3.3-70b-versatile"

    def generate_description(
        self,
        question: str,
        answer: str,
        max_retries: int = 2
    ) -> Dict[str, str]:
        """
        Generate a 2-sentence accessible description for blind users

        Args:
            question: The question asked by the user
            answer: The VQA model's answer
            max_retries: Number of retry attempts on failure

        Returns:
            Dict with 'description' and 'status' keys. 'status' is 'success'
            (also carries 'model') or 'fallback' (also carries 'error').
        """
        prompt = f"""You are an accessibility assistant helping blind users understand visual question answering results.
Question asked: "{question}"
Answer from VQA model: "{answer}"
Task: Create a clear, natural 2-sentence description that:
1. First sentence: Restates the question and provides the answer
2. Second sentence: Adds helpful context or clarification
Keep it concise, natural, and easy to understand when spoken aloud.
Example:
Question: "What color is the car?"
Answer: "red"
Description: "The question asks about the color of the car, and the answer is red. This indicates there is a red-colored vehicle visible in the image."
Now generate the description:"""
        # Retry the API call up to max_retries times; on final failure return
        # a deterministic template so callers always receive usable text.
        for attempt in range(max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": "You are a helpful accessibility assistant. Always respond with exactly 2 clear, natural sentences."
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    temperature=0.7,
                    max_tokens=150,
                    top_p=0.9
                )
                description = response.choices[0].message.content.strip()
                # Strip a leading "Description:" label if the model echoes it.
                if description.startswith("Description:"):
                    description = description.replace("Description:", "").strip()
                return {
                    "description": description,
                    "status": "success",
                    "model": self.model
                }
            except Exception as e:
                if attempt < max_retries:
                    # Immediate retry (no backoff) on any API/client error.
                    continue
                else:
                    fallback = f"The question asks: {question}. The answer is: {answer}."
                    return {
                        "description": fallback,
                        "status": "fallback",
                        "error": str(e)
                    }

    def generate_batch_descriptions(
        self,
        qa_pairs: list[Dict[str, str]]
    ) -> list[Dict[str, str]]:
        """
        Generate descriptions for multiple Q&A pairs

        Args:
            qa_pairs: List of dicts with 'question' and 'answer' keys

        Returns:
            List of description results (same order as qa_pairs; one
            sequential API call per pair).
        """
        results = []
        for pair in qa_pairs:
            result = self.generate_description(
                question=pair.get("question", ""),
                answer=pair.get("answer", "")
            )
            results.append(result)
        return results
106
# Module-level singleton holder for the Groq description service.
_groq_service_instance = None


def get_groq_service(api_key: Optional[str] = None) -> GroqDescriptionService:
    """
    Return the process-wide GroqDescriptionService, creating it on first use.

    Args:
        api_key: Optional API key (uses env var if not provided). Only
            consulted on the first call; later calls reuse the cached instance.

    Returns:
        GroqDescriptionService instance
    """
    global _groq_service_instance
    if _groq_service_instance is not None:
        return _groq_service_instance
    _groq_service_instance = GroqDescriptionService(api_key=api_key)
    return _groq_service_instance
knowledge_graph_service.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Graph Service for Neuro-Symbolic VQA
3
+ Uses ConceptNet API to provide common-sense reasoning capabilities
4
+ """
5
+
6
+ import requests
7
+ import re
8
+ from typing import Dict, List, Optional
9
+ from functools import lru_cache
10
+ import time
11
+
12
+
13
class KnowledgeGraphService:
    """
    Lightweight ConceptNet API wrapper for common-sense reasoning.
    Enhances VQA answers with external knowledge about object properties,
    capabilities, uses, and relationships.
    """

    CONCEPTNET_API = "https://api.conceptnet.io"

    # (regex, ConceptNet relation) pairs used to classify common-sense
    # questions; order matters — the first matching pattern wins.
    COMMONSENSE_PATTERNS = [
        # Capability questions
        (r'can .* (melt|freeze|fly|swim|float|sink|break|burn|explode)', 'CapableOf'),
        (r'is .* able to', 'CapableOf'),
        (r'does .* (float|sink)', 'CapableOf'),

        # Property questions
        (r'is .* (edible|poisonous|dangerous|safe|hot|cold|sweet|sour)', 'HasProperty'),
        (r'is this (food|drink|toy|tool|weapon)', 'HasProperty'),

        # Purpose questions
        (r'what .* (used for|for)', 'UsedFor'),
        (r'why .* (used|made)', 'UsedFor'),
        (r'how .* use', 'UsedFor'),

        # Composition questions
        (r'what .* made (of|from)', 'MadeOf'),
        (r'what .* (material|ingredient)', 'MadeOf'),

        # Location questions
        (r'where .* (found|located|kept|stored)', 'AtLocation'),
        (r'where (do|does) .* (live|grow)', 'AtLocation'),
    ]

    def __init__(self, cache_size=100, timeout=5):
        """
        Initialize Knowledge Graph service.

        Args:
            cache_size: Number of API responses to cache
            timeout: API request timeout in seconds
        """
        self.timeout = timeout
        self.cache_size = cache_size
        # BUGFIX: the previous version decorated the *method* with @lru_cache,
        # which (a) keys the cache on `self`, keeping every instance alive for
        # the cache's lifetime (ruff B019), and (b) hardcoded maxsize=100,
        # silently ignoring the `cache_size` argument. A per-instance cached
        # wrapper fixes both while keeping the `_query_conceptnet` call sites
        # unchanged.
        self._query_conceptnet = lru_cache(maxsize=cache_size)(self._query_conceptnet_uncached)
        print("✅ Knowledge Graph service initialized (ConceptNet API)")

    def _query_conceptnet_uncached(self, concept: str, relation: str, limit: int = 10) -> Optional[Dict]:
        """
        Query ConceptNet API (uncached; __init__ installs the cached wrapper
        as ``self._query_conceptnet``).

        Args:
            concept: Concept to query (e.g., "ice_cream")
            relation: Relation type (e.g., "CapableOf", "HasProperty")
            limit: Maximum number of results

        Returns:
            API response dict or None if failed
        """
        try:
            # Normalize concept (replace spaces with underscores)
            concept = concept.lower().replace(' ', '_')

            # Build API URL
            url = f"{self.CONCEPTNET_API}/query"
            params = {
                'start': f'/c/en/{concept}',
                'rel': f'/r/{relation}',
                'limit': limit
            }

            # Make request
            response = requests.get(url, params=params, timeout=self.timeout)
            response.raise_for_status()

            return response.json()

        except requests.exceptions.Timeout:
            print(f"⚠️ ConceptNet API timeout for {concept}")
            return None
        except requests.exceptions.RequestException as e:
            print(f"⚠️ ConceptNet API error: {e}")
            return None
        except Exception as e:
            print(f"⚠️ Unexpected error querying ConceptNet: {e}")
            return None

    def get_concept_properties(self, concept: str) -> Dict[str, List[str]]:
        """
        Collect properties of a concept across all supported relations.

        Args:
            concept: Concept name (e.g., "apple")

        Returns:
            Dict mapping each relation type to a list of related concept
            labels (empty lists when the API returns nothing).
        """
        properties = {
            'CapableOf': [],
            'HasProperty': [],
            'UsedFor': [],
            'MadeOf': [],
            'AtLocation': []
        }

        # Query each relation type
        for relation in properties.keys():
            data = self._query_conceptnet(concept, relation)

            if data and 'edges' in data:
                for edge in data['edges']:
                    # Extract the end concept
                    if 'end' in edge and 'label' in edge['end']:
                        end_label = edge['end']['label']
                        properties[relation].append(end_label)

        return properties

    def is_commonsense_question(self, question: str) -> bool:
        """
        Detect if a question requires common-sense reasoning.

        Args:
            question: Question string

        Returns:
            True if question needs external knowledge
        """
        q_lower = question.lower()

        for pattern, _ in self.COMMONSENSE_PATTERNS:
            if re.search(pattern, q_lower):
                return True

        return False

    def _detect_question_type(self, question: str) -> Optional[str]:
        """
        Detect which ConceptNet relation the question is asking about.

        Args:
            question: Question string

        Returns:
            Relation type or None
        """
        q_lower = question.lower()

        for pattern, relation in self.COMMONSENSE_PATTERNS:
            if re.search(pattern, q_lower):
                return relation

        return None

    def answer_commonsense_question(self, object_name: str, question: str) -> Optional[str]:
        """
        Answer a common-sense question using Knowledge Graph.

        Args:
            object_name: Object detected by VQA (e.g., "ice cream")
            question: User's question

        Returns:
            Enhanced answer string or None
        """
        # Detect question type
        relation = self._detect_question_type(question)
        if not relation:
            return None

        # Query ConceptNet
        data = self._query_conceptnet(object_name, relation, limit=5)
        if not data or 'edges' not in data:
            return None

        # Extract relevant knowledge
        knowledge = []
        for edge in data['edges']:
            if 'end' in edge and 'label' in edge['end']:
                knowledge.append(edge['end']['label'])

        if not knowledge:
            return None

        # Generate natural language answer based on question type
        return self._synthesize_answer(object_name, question, relation, knowledge)

    def _synthesize_answer(self, object_name: str, question: str,
                           relation: str, knowledge: List[str]) -> str:
        """
        Synthesize natural language answer from knowledge.

        Args:
            object_name: Detected object
            question: Original question
            relation: Relation type
            knowledge: List of related concepts from KG

        Returns:
            Natural language answer (or None when no template applies)
        """
        q_lower = question.lower()

        # Capability questions (can X do Y?)
        if relation == 'CapableOf':
            # Check if specific capability is mentioned
            for capability in knowledge:
                if capability in q_lower:
                    return f"Yes, {object_name} can {capability}."

            # General capability answer
            if knowledge:
                caps = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} can {caps}."

        # Property questions (is X Y?)
        elif relation == 'HasProperty':
            # Check for specific property
            if 'edible' in q_lower:
                if 'edible' in knowledge:
                    return f"Yes, {object_name} is edible."
                else:
                    return f"No, {object_name} is not edible."

            if 'dangerous' in q_lower or 'safe' in q_lower:
                if any(prop in knowledge for prop in ['dangerous', 'harmful', 'poisonous']):
                    return f"Caution: {object_name} may be dangerous."
                else:
                    return f"{object_name.capitalize()} is generally safe."

            # General properties
            if knowledge:
                props = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is {props}."

        # Purpose questions (what is X used for?)
        elif relation == 'UsedFor':
            if knowledge:
                uses = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is used for {uses}."

        # Composition questions (what is X made of?)
        elif relation == 'MadeOf':
            if knowledge:
                materials = ', '.join(knowledge[:3])
                return f"{object_name.capitalize()} is made of {materials}."

        # Location questions (where is X found?)
        elif relation == 'AtLocation':
            if knowledge:
                locations = ', '.join(knowledge[:2])
                return f"{object_name.capitalize()} is typically found at {locations}."

        return None
259
+
260
+
261
# Manual smoke test — exercises live ConceptNet queries (network required).
if __name__ == "__main__":
    print("=" * 80)
    print("🧪 Testing Knowledge Graph Service")
    print("=" * 80)

    kg = KnowledgeGraphService()

    # Test cases: (detected object, user question)
    test_cases = [
        ("ice cream", "Can this melt?"),
        ("apple", "Is this edible?"),
        ("hammer", "What is this used for?"),
        ("knife", "Is this dangerous?"),
        ("bread", "What is this made of?"),
    ]

    for obj, question in test_cases:
        print(f"\n📝 Object: {obj}")
        print(f"❓ Question: {question}")

        # Check if common-sense question
        is_cs = kg.is_commonsense_question(question)
        print(f"🔍 Common-sense: {is_cs}")

        if is_cs:
            # Get answer (may be None when ConceptNet has no matching edges)
            answer = kg.answer_commonsense_question(obj, question)
            print(f"💬 Answer: {answer}")

        print("-" * 80)
llm_reasoning_service.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Reasoning Service for VQA
3
+ Uses Groq LLM for Chain-of-Thought reasoning instead of hardcoded rules
4
+ """
5
+ import os
6
+ from typing import Dict, List, Optional, Any
7
+ from groq import Groq
8
+ import json
9
class LLMReasoningService:
    """
    Service that uses Groq LLM for deductive reasoning from Wikidata facts.
    Replaces hardcoded if/else rules with flexible Chain-of-Thought reasoning.
    """
    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b-versatile"):
        """
        Initialize LLM Reasoning service.
        Args:
            api_key: Groq API key (if not provided, reads from GROQ_API_KEY env var)
            model: Groq model to use for reasoning
        Raises:
            ValueError: If no API key is supplied and GROQ_API_KEY is unset.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Groq API key not found. Set GROQ_API_KEY environment variable "
                "or pass api_key parameter"
            )
        self.client = Groq(api_key=self.api_key)
        self.model = model
        print(f"✅ LLM Reasoning Service initialized (model: {model})")
    def reason_with_facts(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str,
        max_retries: int = 2
    ) -> Dict[str, Any]:
        """
        Use LLM to reason about a question using Wikidata facts.
        Args:
            object_name: Name of the detected object (e.g., "candle")
            facts: Dictionary of Wikidata facts about the object
            question: User's question
            max_retries: Number of retry attempts on failure
        Returns:
            Dict with 'answer', 'reasoning_chain', 'confidence' and 'status'
            keys ('model' is also included on success). If the LLM call or
            response parsing fails after all retries, a rule-based fallback
            result is returned (status 'fallback' or 'fallback_generic').
        """
        prompt = self._build_reasoning_prompt(object_name, facts, question)
        for attempt in range(max_retries + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": """You are an expert reasoning assistant for a Visual Question Answering system.
Your task is to use Chain-of-Thought reasoning to answer questions about objects based on factual knowledge.
IMPORTANT: Respond in JSON format with this structure:
{
"reasoning_chain": ["step 1", "step 2", "step 3"],
"answer": "final answer in natural language",
"confidence": 0.0-1.0
}
Keep reasoning steps clear and logical. The answer should be conversational and helpful."""
                        },
                        {
                            "role": "user",
                            "content": prompt
                        }
                    ],
                    # Low temperature: we want deterministic, factual reasoning.
                    temperature=0.3,
                    max_tokens=500,
                    response_format={"type": "json_object"}
                )
                content = response.choices[0].message.content.strip()
                result = json.loads(content)
                if not all(key in result for key in ('reasoning_chain', 'answer', 'confidence')):
                    raise ValueError("Invalid response structure from LLM")
                # Clamp confidence into [0, 1]: the model occasionally returns
                # values slightly outside the requested range.
                confidence = min(1.0, max(0.0, float(result['confidence'])))
                return {
                    'answer': result['answer'],
                    'reasoning_chain': result['reasoning_chain'],
                    'confidence': confidence,
                    'status': 'success',
                    'model': self.model
                }
            except json.JSONDecodeError:
                # Malformed JSON from the model: retry, then fall back.
                if attempt < max_retries:
                    continue
                return self._fallback_reasoning(object_name, facts, question)
            except Exception as e:
                # API/validation errors: retry, then log and fall back.
                if attempt < max_retries:
                    continue
                print(f"⚠️ LLM reasoning failed: {e}")
                return self._fallback_reasoning(object_name, facts, question)
    def _build_reasoning_prompt(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str
    ) -> str:
        """
        Build a Chain-of-Thought reasoning prompt.
        Args:
            object_name: Name of the object
            facts: Wikidata facts about the object
            question: User's question
        Returns:
            Formatted prompt string (includes a worked few-shot example).
        """
        facts_text = self._format_facts(facts)
        prompt = f"""Question: {question}
Object Detected: {object_name}
Available Facts from Knowledge Graph:
{facts_text}
Task: Use Chain-of-Thought reasoning to answer the question based on the available facts.
Example of good reasoning:
Question: "Can this melt?"
Object: "ice cream"
Facts: {{
"categories": ["frozen dessert", "food"],
"materials": ["milk", "sugar", "cream"]
}}
Reasoning:
{{
"reasoning_chain": [
"The object is ice cream, which is a frozen dessert",
"Ice cream is made of milk, sugar, and cream",
"These ingredients are frozen to create ice cream",
"Frozen items melt when exposed to heat",
"Therefore, yes, ice cream can melt at room temperature"
],
"answer": "Yes, ice cream can melt. It's a frozen dessert made from milk, sugar, and cream, which will melt when exposed to temperatures above freezing.",
"confidence": 0.95
}}
Now reason about the actual question above:"""
        return prompt
    def _format_facts(self, facts: Dict[str, Any]) -> str:
        """Format facts dictionary into readable text (one bullet per key;
        list values are comma-joined, empty values are skipped)."""
        if not facts:
            return "No specific facts available"
        lines = []
        for key, value in facts.items():
            if isinstance(value, list):
                if value:
                    lines.append(f" - {key}: {', '.join(str(v) for v in value)}")
            elif value:
                lines.append(f" - {key}: {value}")
        return "\n".join(lines) if lines else "No specific facts available"
    def _fallback_reasoning(
        self,
        object_name: str,
        facts: Dict[str, Any],
        question: str
    ) -> Dict[str, Any]:
        """
        Fallback reasoning when LLM fails.
        Uses simple rule-based approach (melt/edible keywords only).
        Args:
            object_name: Name of the object
            facts: Wikidata facts
            question: User's question
        Returns:
            Fallback reasoning result (status 'fallback' when a rule matched,
            'fallback_generic' with low confidence otherwise).
        """
        q_lower = question.lower()
        if 'melt' in q_lower:
            materials = facts.get('materials', [])
            if any(m in ['wax', 'ice', 'chocolate', 'butter'] for m in materials):
                return {
                    'answer': f"Yes, {object_name} can melt as it contains materials with low melting points.",
                    'reasoning_chain': [
                        f"The {object_name} contains materials that can melt",
                        "These materials have low melting points",
                        "Therefore, it can melt when heated"
                    ],
                    'confidence': 0.7,
                    'status': 'fallback'
                }
        if 'edible' in q_lower or 'eat' in q_lower:
            categories = facts.get('categories', [])
            if any('food' in str(c).lower() for c in categories):
                return {
                    'answer': f"Yes, {object_name} is edible as it is categorized as food.",
                    'reasoning_chain': [
                        f"The {object_name} is categorized as food",
                        "Food items are generally edible",
                        "Therefore, it is edible"
                    ],
                    'confidence': 0.8,
                    'status': 'fallback'
                }
        return {
            'answer': f"Based on the available information about {object_name}, I cannot provide a definitive answer to this question.",
            'reasoning_chain': [
                f"Analyzing {object_name}",
                "Available facts are limited",
                "Cannot make a confident conclusion"
            ],
            'confidence': 0.3,
            'status': 'fallback_generic'
        }
    def batch_reason(
        self,
        reasoning_tasks: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Perform reasoning on multiple tasks (sequentially).
        Args:
            reasoning_tasks: List of dicts with 'object_name', 'facts', 'question' keys
        Returns:
            List of reasoning results, one per task, in input order.
        """
        results = []
        for task in reasoning_tasks:
            result = self.reason_with_facts(
                object_name=task.get('object_name', ''),
                facts=task.get('facts', {}),
                question=task.get('question', '')
            )
            results.append(result)
        return results
239
# Module-level singleton: repeated callers share one Groq client/session.
_llm_reasoning_instance = None
def get_llm_reasoning_service(api_key: Optional[str] = None) -> LLMReasoningService:
    """
    Get or create LLM Reasoning service singleton.
    Args:
        api_key: Optional API key (uses env var if not provided)
    Returns:
        LLMReasoningService instance (created lazily on first call)
    """
    global _llm_reasoning_instance
    instance = _llm_reasoning_instance
    if instance is None:
        instance = LLMReasoningService(api_key=api_key)
        _llm_reasoning_instance = instance
    return instance
252
if __name__ == "__main__":
    print("=" * 80)
    print("🧪 Testing LLM Reasoning Service")
    print("=" * 80)

    def _report(result):
        # Shared pretty-printer for one reasoning result.
        print(f"Answer: {result['answer']}")
        print(f"Reasoning Chain:")
        for i, step in enumerate(result['reasoning_chain'], 1):
            print(f" {i}. {step}")
        print(f"Confidence: {result['confidence']}")

    try:
        service = get_llm_reasoning_service()

        print("\n📝 Test 1: Can a candle melt?")
        _report(service.reason_with_facts(
            object_name="candle",
            facts={
                "materials": ["wax", "wick"],
                "categories": ["light source", "household item"],
                "uses": ["provide light", "decoration"]
            },
            question="Can this melt?"
        ))

        print("\n📝 Test 2: Would ice cream survive in the desert?")
        _report(service.reason_with_facts(
            object_name="ice cream",
            facts={
                "materials": ["milk", "sugar", "cream"],
                "categories": ["frozen dessert", "food"],
                "properties": ["cold", "frozen"]
            },
            question="Would this survive in the desert?"
        ))

        print("\n" + "=" * 80)
        print("✅ Tests completed!")
    except ValueError as e:
        # Raised by the service constructor when no API key is configured.
        print(f"\n❌ Error: {e}")
        print("Please set GROQ_API_KEY environment variable")
model.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import clip
4
+ from transformers import GPT2Model
5
class AttentionDecoder(nn.Module):
    """GRU decoder that attends over a fused image/question context vector.

    Each step scores the token embedding against the (broadcast) context to
    form attention weights, pools the context, and feeds the concatenation
    of embedding and pooled context through a GRU followed by LayerNorm and
    a vocabulary projection.
    """

    def __init__(self, hidden_size, vocab_size, num_layers=1, dropout=0.3):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # NOTE: submodule names are part of the checkpoint format — keep stable.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.attention = nn.Linear(hidden_size * 2, 1)
        # GRU-internal dropout only applies between stacked layers, so it is
        # disabled for a single-layer decoder (avoids a PyTorch warning).
        gru_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=hidden_size * 2,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=gru_dropout,
        )
        self.ln_gru = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, context, hidden):
        """Decode one or more steps.

        Args:
            input_ids: [batch] or [batch, seq] answer token ids.
            context: [batch, hidden_size] fused image/question context.
            hidden: [num_layers, batch, hidden_size] GRU state.

        Returns:
            Tuple of logits [batch, seq, vocab_size] and the updated hidden state.
        """
        # Promote a single-step input to a length-1 sequence.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(1)
        tok_emb = self.embedding(input_ids).float()
        seq_len = tok_emb.size(1)
        ctx = context.unsqueeze(1).expand(-1, seq_len, -1)
        # Score each position's (embedding, context) pair; softmax over time.
        scores = self.attention(torch.cat([tok_emb, ctx], dim=-1))
        weights = torch.softmax(scores, dim=1)
        pooled_ctx = (ctx * weights).sum(dim=1, keepdim=True)
        step_input = torch.cat([tok_emb, pooled_ctx.expand(-1, seq_len, -1)], dim=-1)
        out, hidden = self.gru(step_input, hidden)
        out = self.ln_gru(out)
        return self.output(out), hidden
33
class VQAModel(nn.Module):
    """VQA model: frozen CLIP image encoder + frozen DistilGPT-2 question
    encoder, gated feature fusion, and an attention-based GRU decoder that
    generates the answer token-by-token.

    Both backbones start fully frozen; ``unfreeze_clip_layers`` /
    ``unfreeze_gpt2_layers`` selectively re-enable gradients for fine-tuning
    (which also makes the encoders run without ``torch.no_grad()``).
    """
    def __init__(
        self,
        vocab_size=3600,
        question_max_len=16,
        answer_max_len=10,
        hidden_size=512,
        num_layers=2,
        dropout=0.3,
        device='cuda',
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=3
    ):
        """Build encoders, fusion layers and decoder.

        Args:
            vocab_size: Answer vocabulary size.
            question_max_len: Maximum question length (tokens).
            answer_max_len: Maximum generated answer length (tokens).
            hidden_size: Shared projection/decoder hidden dimension.
            num_layers: GRU layers in the decoder.
            dropout: Dropout probability in fusion and decoder.
            device: Device CLIP/GPT-2 are loaded onto.
            pad_token_id, bos_token_id, eos_token_id, unk_token_id:
                Special token ids of the answer vocabulary.
        """
        super().__init__()
        self.device = device
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Toggled by the unfreeze_* helpers; controls no_grad usage below.
        self.fine_tuning_mode = False
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.unk_token_id = unk_token_id
        # Frozen CLIP ViT-B/16 image encoder (512-d global features).
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=device)
        for p in self.clip_model.parameters():
            p.requires_grad = False
        # Frozen DistilGPT-2 question encoder (768-d hidden states).
        self.gpt2_model = GPT2Model.from_pretrained("distilgpt2")
        self.gpt2_model.to(device)
        for p in self.gpt2_model.parameters():
            p.requires_grad = False
        # Project both modalities into the shared hidden space.
        self.img_proj = nn.Linear(512, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        self.gate_layer = nn.Linear(hidden_size*2, hidden_size)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size*3, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size)
        )
        self.decoder = AttentionDecoder(hidden_size, vocab_size, num_layers, dropout)
    def unfreeze_clip_layers(self, num_layers=2):
        """Unfreeze the last `num_layers` CLIP visual transformer blocks, plus
        the output projection and final LayerNorm, for fine-tuning."""
        self.clip_model.train()
        # CLIP weights may be fp16; fine-tuning needs fp32.
        self.clip_model.visual.float()
        total_blocks = len(self.clip_model.visual.transformer.resblocks)
        for i, block in enumerate(self.clip_model.visual.transformer.resblocks):
            if i >= total_blocks - num_layers:
                for p in block.parameters():
                    p.requires_grad = True
        # `visual.proj` may be a raw Parameter (standard CLIP) or a module.
        if hasattr(self.clip_model.visual, "proj") and self.clip_model.visual.proj is not None:
            if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
                self.clip_model.visual.proj.requires_grad = True
            else:
                for p in self.clip_model.visual.proj.parameters():
                    p.requires_grad = True
        if hasattr(self.clip_model.visual, "ln_post"):
            for p in self.clip_model.visual.ln_post.parameters():
                p.requires_grad = True
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} CLIP layers")
    def unfreeze_gpt2_layers(self, num_layers=1):
        """Unfreeze the last `num_layers` GPT-2 blocks and the final LayerNorm,
        casting their weights to float32 for stable training."""
        self.gpt2_model.train()
        total_layers = len(self.gpt2_model.h)
        for i, layer in enumerate(self.gpt2_model.h):
            if i >= total_layers - num_layers:
                for p in layer.parameters():
                    p.requires_grad = True
                    p.data = p.data.float()
        for p in self.gpt2_model.ln_f.parameters():
            p.requires_grad = True
            p.data = p.data.float()
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} GPT-2 layers")
    def encode_image(self, images):
        """Encode images with CLIP; returns L2-normalized float32 features."""
        if self.fine_tuning_mode:
            images = images.float()
            img_features = self.clip_model.encode_image(images)
        else:
            # Backbone frozen: skip autograd bookkeeping.
            with torch.no_grad():
                img_features = self.clip_model.encode_image(images)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        return img_features.float()
    def encode_question(self, input_ids, attention_mask):
        """Encode the question with GPT-2, mean-pooled over non-pad tokens;
        returns L2-normalized float32 features."""
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        # Masked mean pooling: zero pad positions, divide by true length.
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()
    def fuse_features(self, img_features, q_features):
        """Gated fusion of image and question features into one context vector."""
        x = torch.cat([img_features, q_features], dim=-1)
        # Learned gate decides how much each modality contributes.
        gate = torch.sigmoid(self.gate_layer(x))
        fused = gate * img_features + (1-gate) * q_features
        fused = self.fusion(torch.cat([fused, x], dim=-1))
        return fused
    def forward(self, images, questions, answer_input_ids=None):
        """Teacher-forced logits when `answer_input_ids` is given; otherwise
        greedy autoregressive generation of up to `answer_max_len` tokens.

        Args:
            images: Preprocessed image batch.
            questions: Dict with "input_ids" and "attention_mask".
            answer_input_ids: Optional gold answer ids for teacher forcing.
        """
        img_features = self.encode_image(images)
        img_features = self.img_proj(img_features).float()
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        q_features = self.q_proj(q_features).float()
        batch_size = img_features.size(0)
        context = self.fuse_features(img_features, q_features)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                             device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            logits, _ = self.decoder(answer_input_ids, context, hidden)
            return logits
        else:
            generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                                   dtype=torch.long, device=self.device)
            generated[:, 0] = self.bos_token_id
            for t in range(1, self.answer_max_len):
                current_input = generated[:, t-1]
                logits, hidden = self.decoder(current_input, context, hidden)
                next_tokens = logits.squeeze(1).argmax(dim=-1)
                generated[:, t] = next_tokens
                # Stop early once every sequence in the batch emitted EOS.
                if (next_tokens == self.eos_token_id).all():
                    break
            return generated
    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Beam-search decoding, one sample at a time.

        Returns a [batch, answer_max_len] tensor padded with `pad_token_id`.
        """
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b+1]
            q_ids = questions["input_ids"][b:b+1]
            q_mask = questions["attention_mask"][b:b+1]
            img_features = self.encode_image(img)
            img_features = self.img_proj(img_features).float()
            q_features = self.encode_question(q_ids, q_mask)
            q_features = self.q_proj(q_features).float()
            context = self.fuse_features(img_features, q_features)
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            # Each beam entry: (token sequence, cumulative log-prob, GRU state).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # Beams ending in EOS are finished; park them.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                # All beams finished: nothing left to expand.
                if len(candidates) == 0:
                    break
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
            else:
                # Length-normalized score (exponent 0.7) so very short
                # sequences are not unfairly favored.
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
212
if __name__ == "__main__":
    # Smoke test: push one dummy image/question pair through the model.
    device = "cuda"
    model = VQAModel(device=device).to(device)
    model.eval()

    dummy_image = torch.randn(1, 3, 224, 224).to(device)
    dummy_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0]]).to(device)
    dummy_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]]).to(device)
    batch = {
        "input_ids": dummy_ids,
        "attention_mask": dummy_mask
    }

    print(model(dummy_image, batch))
model_spatial.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import clip
4
+ from transformers import GPT2Model
5
+ import math
6
class SpatialAdapter(nn.Module):
    """
    Spatial Adapter with Multi-Head Cross-Attention for spatial reasoning.
    Processes CLIP patch features (14x14 grid) with question guidance.

    Pipeline: add fixed 2D positional encodings -> cross-attention
    (patches attend to the question) -> self-attention among patches ->
    feed-forward -> question-guided pooling into one context vector.
    """
    def __init__(self, patch_dim=512, question_dim=512, hidden_dim=512, num_heads=8, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        # Fixed sinusoidal 2D position table for the 14x14 patch grid; a
        # buffer so it moves with the module across devices but isn't trained.
        self.register_buffer('pos_encoding_2d', self._create_2d_positional_encoding(14, 14, patch_dim))
        self.patch_proj = nn.Linear(patch_dim, hidden_dim)
        self.question_proj = nn.Linear(question_dim, hidden_dim)
        # Cross-attention projections: patches (queries) attend to the
        # question (keys/values).
        self.cross_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.cross_attn_out = nn.Linear(hidden_dim, hidden_dim)
        # Self-attention projections: patches attend to each other.
        self.self_attn_query = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_key = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_value = nn.Linear(hidden_dim, hidden_dim)
        self.self_attn_out = nn.Linear(hidden_dim, hidden_dim)
        # Position-wise feed-forward block (transformer-style, 4x expansion).
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
            nn.Dropout(dropout)
        )
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ln3 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
    def _create_2d_positional_encoding(self, height, width, dim):
        """Create 2D positional encoding for spatial grid.

        Returns a [1, height*width, dim] tensor. The first half of the
        channels encodes the row index, the second half the column index,
        with sin/cos interleaved within each half.
        """
        # Row/column index of each flattened grid cell.
        pos_h = torch.arange(height).unsqueeze(1).repeat(1, width).flatten()
        pos_w = torch.arange(width).unsqueeze(0).repeat(height, 1).flatten()
        pe = torch.zeros(height * width, dim)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0:dim//2:2] = torch.sin(pos_h.unsqueeze(1) * div_term[:dim//4])
        pe[:, 1:dim//2:2] = torch.cos(pos_h.unsqueeze(1) * div_term[:dim//4])
        pe[:, dim//2::2] = torch.sin(pos_w.unsqueeze(1) * div_term[:dim//4])
        pe[:, dim//2+1::2] = torch.cos(pos_w.unsqueeze(1) * div_term[:dim//4])
        return pe.unsqueeze(0)
    def _multi_head_attention(self, query, key, value, num_heads):
        """Generic multi-head attention implementation.

        Inputs are already-projected [batch, seq, hidden_dim] tensors;
        returns (context [batch, seq_q, hidden_dim], attention weights).
        """
        batch_size = query.size(0)
        # [B, T, H] -> [B, heads, T, head_dim]
        Q = query.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        K = key.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        V = value.view(batch_size, -1, num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge heads back: [B, heads, T, head_dim] -> [B, T, H].
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        return context, attn_weights
    def forward(self, patch_features, question_features):
        """
        Args:
            patch_features: [batch_size, num_patches, patch_dim] - CLIP patch features
            question_features: [batch_size, question_dim] - Question encoding
        Returns:
            spatial_context: [batch_size, hidden_dim] - Spatially-aware context
        """
        batch_size, num_patches, _ = patch_features.shape
        # Inject 2D positional information before projecting.
        patch_features = patch_features + self.pos_encoding_2d[:, :num_patches, :].to(patch_features.device)
        patches = self.patch_proj(patch_features)
        question = self.question_proj(question_features.unsqueeze(1))
        # Block 1: cross-attention (patches -> question), residual + LayerNorm.
        Q_cross = self.cross_attn_query(patches)
        K_cross = self.cross_attn_key(question)
        V_cross = self.cross_attn_value(question)
        cross_context, _ = self._multi_head_attention(Q_cross, K_cross, V_cross, self.num_heads)
        cross_out = self.cross_attn_out(cross_context)
        patches = self.ln1(patches + self.dropout(cross_out))
        # Block 2: self-attention among patches, residual + LayerNorm.
        Q_self = self.self_attn_query(patches)
        K_self = self.self_attn_key(patches)
        V_self = self.self_attn_value(patches)
        self_context, _ = self._multi_head_attention(Q_self, K_self, V_self, self.num_heads)
        self_out = self.self_attn_out(self_context)
        patches = self.ln2(patches + self.dropout(self_out))
        # Block 3: position-wise feed-forward, residual + LayerNorm.
        ffn_out = self.ffn(patches)
        patches = self.ln3(patches + ffn_out)
        # Question-guided pooling over patches into one context vector.
        attn_scores = torch.matmul(patches, question.transpose(1, 2))
        attn_weights = torch.softmax(attn_scores, dim=1)
        spatial_context = (patches * attn_weights).sum(dim=1)
        return spatial_context
92
+ class VQAModelWithSpatialAdapter(nn.Module):
93
+ """
94
+ Enhanced VQA Model with Spatial Adapter for spatial reasoning.
95
+ Uses patch-based CLIP features instead of global encoding.
96
+ """
97
+ def __init__(
98
+ self,
99
+ base_model,
100
+ hidden_size=512,
101
+ num_heads=8,
102
+ dropout=0.3
103
+ ):
104
+ super().__init__()
105
+ self.device = base_model.device
106
+ self.question_max_len = base_model.question_max_len
107
+ self.answer_max_len = base_model.answer_max_len
108
+ self.vocab_size = base_model.vocab_size
109
+ self.hidden_size = hidden_size
110
+ self.num_layers = base_model.num_layers
111
+ self.fine_tuning_mode = base_model.fine_tuning_mode
112
+ self.pad_token_id = base_model.pad_token_id
113
+ self.bos_token_id = base_model.bos_token_id
114
+ self.eos_token_id = base_model.eos_token_id
115
+ self.unk_token_id = base_model.unk_token_id
116
+ self.clip_model = base_model.clip_model
117
+ self.clip_preprocess = base_model.clip_preprocess
118
+ self.gpt2_model = base_model.gpt2_model
119
+ self.decoder = base_model.decoder
120
+ self.spatial_adapter = SpatialAdapter(
121
+ patch_dim=512,
122
+ question_dim=768,
123
+ hidden_dim=hidden_size,
124
+ num_heads=num_heads,
125
+ dropout=dropout
126
+ )
127
+ self.spatial_context_proj = nn.Linear(hidden_size, hidden_size)
128
+ self.q_proj = nn.Linear(768, hidden_size)
129
+ self.spatial_fusion = nn.Sequential(
130
+ nn.Linear(hidden_size * 2, hidden_size),
131
+ nn.GELU(),
132
+ nn.Dropout(dropout),
133
+ nn.Linear(hidden_size, hidden_size),
134
+ nn.LayerNorm(hidden_size)
135
+ )
136
+ def extract_clip_patch_features(self, images):
137
+ """
138
+ Extract patch features from CLIP instead of global features.
139
+ Returns: [batch_size, num_patches, patch_dim]
140
+ """
141
+ clip_dtype = self.clip_model.visual.conv1.weight.dtype
142
+ images = images.to(clip_dtype)
143
+ if self.fine_tuning_mode:
144
+ x = self.clip_model.visual.conv1(images)
145
+ x = x.reshape(x.shape[0], x.shape[1], -1)
146
+ x = x.permute(0, 2, 1)
147
+ class_token = self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(
148
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
149
+ )
150
+ x = torch.cat([class_token, x], dim=1)
151
+ x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
152
+ x = self.clip_model.visual.ln_pre(x)
153
+ x = x.permute(1, 0, 2)
154
+ x = self.clip_model.visual.transformer(x)
155
+ x = x.permute(1, 0, 2)
156
+ patch_features = x[:, 1:, :]
157
+ if hasattr(self.clip_model.visual, 'proj') and self.clip_model.visual.proj is not None:
158
+ if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
159
+ patch_features = patch_features @ self.clip_model.visual.proj
160
+ else:
161
+ patch_features = self.clip_model.visual.proj(patch_features)
162
+ else:
163
+ with torch.no_grad():
164
+ x = self.clip_model.visual.conv1(images)
165
+ x = x.reshape(x.shape[0], x.shape[1], -1)
166
+ x = x.permute(0, 2, 1)
167
+ class_token = self.clip_model.visual.class_embedding.to(x.dtype) + torch.zeros(
168
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
169
+ )
170
+ x = torch.cat([class_token, x], dim=1)
171
+ x = x + self.clip_model.visual.positional_embedding.to(x.dtype)
172
+ x = self.clip_model.visual.ln_pre(x)
173
+ x = x.permute(1, 0, 2)
174
+ x = self.clip_model.visual.transformer(x)
175
+ x = x.permute(1, 0, 2)
176
+ patch_features = x[:, 1:, :]
177
+ if hasattr(self.clip_model.visual, 'proj') and self.clip_model.visual.proj is not None:
178
+ if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
179
+ patch_features = patch_features @ self.clip_model.visual.proj
180
+ else:
181
+ patch_features = self.clip_model.visual.proj(patch_features)
182
+ return patch_features.float()
183
+ def encode_question(self, input_ids, attention_mask):
184
+ """Encode question using GPT-2 (same as base model)"""
185
+ if self.fine_tuning_mode:
186
+ outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
187
+ else:
188
+ with torch.no_grad():
189
+ outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
190
+ last_hidden = outputs.last_hidden_state
191
+ mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
192
+ masked = last_hidden * mask
193
+ sum_hidden = masked.sum(dim=1)
194
+ lengths = mask.sum(dim=1).clamp(min=1e-6)
195
+ text_features = sum_hidden / lengths
196
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
197
+ return text_features.float()
198
def forward(self, images, questions, answer_input_ids=None):
    """Forward pass with spatial adapter.

    Args:
        images: preprocessed image batch, shape (B, 3, 224, 224).
        questions: dict with "input_ids" and "attention_mask" tensors.
        answer_input_ids: teacher-forcing answer tokens. When provided,
            the method returns decoder logits (training path); otherwise
            it greedily decodes up to ``answer_max_len`` tokens and
            returns the generated token ids.
    """
    # CLIP patch grid plus pooled question features feed the spatial adapter.
    patch_features = self.extract_clip_patch_features(images)
    q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
    spatial_context = self.spatial_adapter(patch_features, q_features)
    spatial_context = self.spatial_context_proj(spatial_context)
    q_projected = self.q_proj(q_features)
    # Fuse question-conditioned spatial context with the projected question.
    fused = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
    batch_size = images.size(0)
    # Zero-initialized GRU state for the decoder.
    hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                         device=self.device, dtype=torch.float)
    if answer_input_ids is not None:
        # Training path: teacher forcing, return per-token logits.
        logits, _ = self.decoder(answer_input_ids, fused, hidden)
        return logits
    else:
        # Inference path: greedy autoregressive decoding starting from BOS.
        generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                               dtype=torch.long, device=self.device)
        generated[:, 0] = self.bos_token_id
        for t in range(1, self.answer_max_len):
            current_input = generated[:, t-1]
            logits, hidden = self.decoder(current_input, fused, hidden)
            next_tokens = logits.squeeze(1).argmax(dim=-1)
            generated[:, t] = next_tokens
            # Stop early only once every sequence in the batch emitted EOS.
            if (next_tokens == self.eos_token_id).all():
                break
        return generated
226
def generate_with_beam_search(self, images, questions, beam_width=5):
    """Beam search generation (same as base model but with spatial features).

    Processes the batch one sample at a time; each beam is a tuple of
    (sequence tensor, cumulative log-prob, GRU hidden state). Returns a
    (batch, answer_max_len) tensor of token ids, right-padded with PAD.
    """
    batch_size = images.size(0)
    all_results = []
    for b in range(batch_size):
        img = images[b:b+1]
        q_ids = questions["input_ids"][b:b+1]
        q_mask = questions["attention_mask"][b:b+1]
        # Same feature pipeline as forward(), for a single sample.
        patch_features = self.extract_clip_patch_features(img)
        q_features = self.encode_question(q_ids, q_mask)
        spatial_context = self.spatial_adapter(patch_features, q_features)
        spatial_context = self.spatial_context_proj(spatial_context)
        q_projected = self.q_proj(q_features)
        context = self.spatial_fusion(torch.cat([spatial_context, q_projected], dim=-1))
        initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                     device=self.device, dtype=torch.float)
        # Single initial beam holding only the BOS token.
        beams = [(
            torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
            0.0,
            initial_hidden
        )]
        completed_beams = []
        for t in range(1, self.answer_max_len):
            candidates = []
            for seq, score, hidden in beams:
                # Beams that already ended in EOS are retired untouched.
                if seq[0, -1].item() == self.eos_token_id:
                    completed_beams.append((seq, score))
                    continue
                current_input = seq[:, -1]
                logits, new_hidden = self.decoder(current_input, context, hidden)
                log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                # Expand each live beam by its top-k continuations.
                for i in range(beam_width):
                    next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                    new_seq = torch.cat([seq, next_token], dim=1)
                    new_score = score + top_log_probs[i].item()
                    candidates.append((new_seq, new_score, new_hidden))
            # All beams finished: nothing left to expand.
            if len(candidates) == 0:
                break
            # Keep only the beam_width highest-scoring candidates.
            beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
        if len(all_beams) == 0:
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
        else:
            # Length-normalized score (alpha = 0.7) to avoid favoring
            # very short sequences.
            best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
            result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                dtype=torch.long, device=self.device)
            seq_len = min(best_beam[0].size(1), self.answer_max_len)
            result[:, :seq_len] = best_beam[0][:, :seq_len]
        all_results.append(result)
    return torch.cat(all_results, dim=0)
278
if __name__ == "__main__":
    # Smoke test: wrap the base VQAModel with the spatial adapter and push
    # a fake two-sample batch through feature extraction and generation.
    print("Testing Spatial Adapter Architecture...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from model import VQAModel
    base_model = VQAModel(device=device).to(device)
    spatial_model = VQAModelWithSpatialAdapter(base_model).to(device)
    spatial_model.eval()
    fake_image = torch.randn(2, 3, 224, 224).to(device)
    # Ids follow the model defaults: 1 = BOS, 2 = EOS, 0 = PAD.
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0], [1, 15, 25, 35, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    print(f"\nInput shapes:")
    print(f" Images: {fake_image.shape}")
    print(f" Questions: {fake_question_ids.shape}")
    with torch.no_grad():
        patch_features = spatial_model.extract_clip_patch_features(fake_image)
        print(f"\nPatch features shape: {patch_features.shape}")
        print(f" Expected: [2, 196, 512] (batch_size, num_patches, patch_dim)")
        output = spatial_model(fake_image, question_batch)
        print(f"\nGenerated output shape: {output.shape}")
        print(f" Expected: [2, {spatial_model.answer_max_len}]")
    # Parameter accounting: only the adapter (and anything explicitly
    # unfrozen) should be trainable.
    total_params = sum(p.numel() for p in spatial_model.parameters())
    spatial_adapter_params = sum(p.numel() for p in spatial_model.spatial_adapter.parameters())
    trainable_params = sum(p.numel() for p in spatial_model.parameters() if p.requires_grad)
    print(f"\nParameter counts:")
    print(f" Total parameters: {total_params:,}")
    print(f" Spatial adapter parameters: {spatial_adapter_params:,}")
    print(f" Trainable parameters: {trainable_params:,}")
    print("\n✓ Spatial adapter architecture test passed!")
models/__pycache__/model.cpython-312.pyc ADDED
Binary file (16.5 kB). View file
 
models/model.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import clip
4
+ from transformers import GPT2Model
5
+ class AttentionDecoder(nn.Module):
6
+ def __init__(self, hidden_size, vocab_size, num_layers=1, dropout=0.3):
7
+ super().__init__()
8
+ self.hidden_size = hidden_size
9
+ self.num_layers = num_layers
10
+ self.embedding = nn.Embedding(vocab_size, hidden_size)
11
+ self.attention = nn.Linear(hidden_size * 2, 1)
12
+ self.gru = nn.GRU(
13
+ input_size=hidden_size * 2,
14
+ hidden_size=hidden_size,
15
+ num_layers=num_layers,
16
+ batch_first=True,
17
+ dropout=dropout if num_layers > 1 else 0
18
+ )
19
+ self.ln_gru = nn.LayerNorm(hidden_size)
20
+ self.output = nn.Linear(hidden_size, vocab_size)
21
+ def forward(self, input_ids, context, hidden):
22
+ if input_ids.dim() == 1:
23
+ input_ids = input_ids.unsqueeze(1)
24
+ embeddings = self.embedding(input_ids).float()
25
+ context_expanded = context.unsqueeze(1).expand(-1, embeddings.size(1), -1)
26
+ combined = torch.cat([embeddings, context_expanded], dim=-1)
27
+ attn_weights = torch.softmax(self.attention(combined), dim=1)
28
+ attended_context = (context_expanded * attn_weights).sum(dim=1, keepdim=True)
29
+ gru_input = torch.cat([embeddings, attended_context.expand(-1, embeddings.size(1), -1)], dim=-1)
30
+ gru_output, hidden = self.gru(gru_input, hidden)
31
+ gru_output = self.ln_gru(gru_output)
32
+ return self.output(gru_output), hidden
33
class VQAModel(nn.Module):
    """CLIP + GPT-2 VQA model with gated fusion and a GRU answer decoder.

    CLIP (ViT-B/16) encodes the image and distilgpt2 encodes the question;
    both backbones are frozen at construction and can be partially unfrozen
    via the ``unfreeze_*`` methods for fine-tuning. Fused features condition
    an :class:`AttentionDecoder` that emits answer tokens.
    """

    def __init__(
        self,
        vocab_size=3600,
        question_max_len=16,
        answer_max_len=10,
        hidden_size=512,
        num_layers=2,
        dropout=0.3,
        device='cuda',
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=3
    ):
        super().__init__()
        self.device = device
        self.question_max_len = question_max_len
        self.answer_max_len = answer_max_len
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # False = frozen backbones wrapped in no_grad; flipped by unfreeze_*.
        self.fine_tuning_mode = False
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.unk_token_id = unk_token_id
        # Frozen CLIP image encoder (preprocess transform kept for callers).
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/16", device=device)
        for p in self.clip_model.parameters():
            p.requires_grad = False
        # Frozen distilgpt2 question encoder.
        self.gpt2_model = GPT2Model.from_pretrained("distilgpt2")
        self.gpt2_model.to(device)
        for p in self.gpt2_model.parameters():
            p.requires_grad = False
        # Project CLIP (512-d) and GPT-2 (768-d) features to a shared space.
        self.img_proj = nn.Linear(512, hidden_size)
        self.q_proj = nn.Linear(768, hidden_size)
        self.gate_layer = nn.Linear(hidden_size*2, hidden_size)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size*3, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size)
        )
        self.decoder = AttentionDecoder(hidden_size, vocab_size, num_layers, dropout)

    def unfreeze_clip_layers(self, num_layers=2):
        """Unfreeze the last ``num_layers`` CLIP visual blocks for fine-tuning."""
        self.clip_model.train()
        # CLIP ships half-precision weights; cast visual tower to fp32 so
        # gradient updates are numerically stable.
        self.clip_model.visual.float()
        total_blocks = len(self.clip_model.visual.transformer.resblocks)
        for i, block in enumerate(self.clip_model.visual.transformer.resblocks):
            if i >= total_blocks - num_layers:
                for p in block.parameters():
                    p.requires_grad = True
        # The output projection may be a bare Parameter or a Module
        # depending on the CLIP build.
        if hasattr(self.clip_model.visual, "proj") and self.clip_model.visual.proj is not None:
            if isinstance(self.clip_model.visual.proj, torch.nn.Parameter):
                self.clip_model.visual.proj.requires_grad = True
            else:
                for p in self.clip_model.visual.proj.parameters():
                    p.requires_grad = True
        if hasattr(self.clip_model.visual, "ln_post"):
            for p in self.clip_model.visual.ln_post.parameters():
                p.requires_grad = True
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} CLIP layers")

    def unfreeze_gpt2_layers(self, num_layers=1):
        """Unfreeze the last ``num_layers`` GPT-2 blocks plus the final LayerNorm."""
        self.gpt2_model.train()
        total_layers = len(self.gpt2_model.h)
        for i, layer in enumerate(self.gpt2_model.h):
            if i >= total_layers - num_layers:
                for p in layer.parameters():
                    p.requires_grad = True
                    p.data = p.data.float()
        for p in self.gpt2_model.ln_f.parameters():
            p.requires_grad = True
            p.data = p.data.float()
        self.fine_tuning_mode = True
        print(f"Unfrozen last {num_layers} GPT-2 layers")

    def encode_image(self, images):
        """Return L2-normalized float32 CLIP image features, shape (B, 512)."""
        if self.fine_tuning_mode:
            images = images.float()
            img_features = self.clip_model.encode_image(images)
        else:
            with torch.no_grad():
                img_features = self.clip_model.encode_image(images)
        img_features = img_features / img_features.norm(dim=-1, keepdim=True)
        return img_features.float()

    def encode_question(self, input_ids, attention_mask):
        """Return L2-normalized, mean-pooled GPT-2 features, shape (B, 768)."""
        if self.fine_tuning_mode:
            outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        else:
            with torch.no_grad():
                outputs = self.gpt2_model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state
        # Mask out padding before averaging over the sequence dimension;
        # clamp guards against an all-padding row.
        mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        masked = last_hidden * mask
        sum_hidden = masked.sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1e-6)
        text_features = sum_hidden / lengths
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.float()

    def fuse_features(self, img_features, q_features):
        """Gated fusion of projected image and question features."""
        x = torch.cat([img_features, q_features], dim=-1)
        # Sigmoid gate decides, per dimension, how much image vs. question
        # signal to keep before the MLP refinement.
        gate = torch.sigmoid(self.gate_layer(x))
        fused = gate * img_features + (1-gate) * q_features
        fused = self.fusion(torch.cat([fused, x], dim=-1))
        return fused

    def forward(self, images, questions, answer_input_ids=None):
        """Teacher-forced logits when ``answer_input_ids`` is given, otherwise
        greedy decoding of up to ``answer_max_len`` tokens."""
        img_features = self.encode_image(images)
        img_features = self.img_proj(img_features).float()
        q_features = self.encode_question(questions["input_ids"], questions["attention_mask"])
        q_features = self.q_proj(q_features).float()
        batch_size = img_features.size(0)
        context = self.fuse_features(img_features, q_features)
        hidden = torch.zeros(self.num_layers, batch_size, self.hidden_size,
                             device=self.device, dtype=torch.float)
        if answer_input_ids is not None:
            # Training path: teacher forcing.
            logits, _ = self.decoder(answer_input_ids, context, hidden)
            return logits
        else:
            # Inference path: greedy decoding from BOS, PAD-filled output.
            generated = torch.full((batch_size, self.answer_max_len), self.pad_token_id,
                                   dtype=torch.long, device=self.device)
            generated[:, 0] = self.bos_token_id
            for t in range(1, self.answer_max_len):
                current_input = generated[:, t-1]
                logits, hidden = self.decoder(current_input, context, hidden)
                next_tokens = logits.squeeze(1).argmax(dim=-1)
                generated[:, t] = next_tokens
                # Stop only once every sequence in the batch emitted EOS.
                if (next_tokens == self.eos_token_id).all():
                    break
            return generated

    def generate_with_beam_search(self, images, questions, beam_width=5):
        """Per-sample beam search; returns (batch, answer_max_len) token ids."""
        batch_size = images.size(0)
        all_results = []
        for b in range(batch_size):
            img = images[b:b+1]
            q_ids = questions["input_ids"][b:b+1]
            q_mask = questions["attention_mask"][b:b+1]
            img_features = self.encode_image(img)
            img_features = self.img_proj(img_features).float()
            q_features = self.encode_question(q_ids, q_mask)
            q_features = self.q_proj(q_features).float()
            context = self.fuse_features(img_features, q_features)
            initial_hidden = torch.zeros(self.num_layers, 1, self.hidden_size,
                                         device=self.device, dtype=torch.float)
            # Each beam: (sequence, cumulative log-prob, GRU hidden state).
            beams = [(
                torch.full((1, 1), self.bos_token_id, dtype=torch.long, device=self.device),
                0.0,
                initial_hidden
            )]
            completed_beams = []
            for t in range(1, self.answer_max_len):
                candidates = []
                for seq, score, hidden in beams:
                    # Retire beams that already emitted EOS.
                    if seq[0, -1].item() == self.eos_token_id:
                        completed_beams.append((seq, score))
                        continue
                    current_input = seq[:, -1]
                    logits, new_hidden = self.decoder(current_input, context, hidden)
                    log_probs = torch.log_softmax(logits.squeeze(1), dim=-1)
                    top_log_probs, top_indices = torch.topk(log_probs[0], beam_width)
                    for i in range(beam_width):
                        next_token = top_indices[i].unsqueeze(0).unsqueeze(0)
                        new_seq = torch.cat([seq, next_token], dim=1)
                        new_score = score + top_log_probs[i].item()
                        candidates.append((new_seq, new_score, new_hidden))
                if len(candidates) == 0:
                    break
                # Keep the beam_width best-scoring candidates.
                beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
            all_beams = completed_beams + [(seq, score) for seq, score, _ in beams]
            if len(all_beams) == 0:
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
            else:
                # Length-normalized score (alpha = 0.7) discourages
                # degenerate very-short answers.
                best_beam = max(all_beams, key=lambda x: x[1] / (x[0].size(1) ** 0.7))
                result = torch.full((1, self.answer_max_len), self.pad_token_id,
                                    dtype=torch.long, device=self.device)
                seq_len = min(best_beam[0].size(1), self.answer_max_len)
                result[:, :seq_len] = best_beam[0][:, :seq_len]
            all_results.append(result)
        return torch.cat(all_results, dim=0)
212
if __name__ == "__main__":
    # Smoke test: build the model and run one fake image/question batch
    # through greedy decoding.
    # Fix: device was hard-coded to "cuda", which crashes on CPU-only
    # machines; fall back like the spatial-adapter test script does.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VQAModel(device=device).to(device)
    model.eval()
    fake_image = torch.randn(1, 3, 224, 224).to(device)
    # Ids follow the model defaults: 1 = BOS, 2 = EOS, 0 = PAD.
    fake_question_ids = torch.tensor([[1, 10, 20, 30, 2, 0, 0]]).to(device)
    fake_question_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]]).to(device)
    question_batch = {
        "input_ids": fake_question_ids,
        "attention_mask": fake_question_mask
    }
    output = model(fake_image, question_batch)
    print(output)
quick_start.bat ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
@echo off
REM Quick Start Script for VQA Mobile App
REM This script helps you start the backend and frontend
REM Fix: delayed expansion must be enabled, otherwise !IP! inside the
REM for-loop block below is printed literally instead of the address.
setlocal EnableDelayedExpansion

echo ========================================
echo VQA Mobile App - Quick Start
echo ========================================
echo.

REM Get current IP address
echo [1/3] Checking your IP address...
for /f "tokens=2 delims=:" %%a in ('ipconfig ^| findstr /c:"IPv4"') do (
    set IP=%%a
    REM Strip the leading space left by the "delims=:" split.
    set IP=!IP:~1!
    echo Your IP: !IP!
)

echo.
echo [2/3] Current Configuration:
echo Backend: http://10.215.4.143:8000
echo Frontend: ui/src/config/api.js
echo.

echo IMPORTANT: Make sure both laptop and mobile are on the SAME network!
echo.

echo [3/3] Choose an option:
echo 1. Start Backend (Python)
echo 2. Start Frontend (Expo)
echo 3. Start Both (Opens 2 terminals)
echo 4. Exit
echo.

choice /c 1234 /n /m "Enter your choice (1-4): "

REM choice sets ERRORLEVEL to the 1-based index; test highest first.
if errorlevel 4 goto :end
if errorlevel 3 goto :both
if errorlevel 2 goto :frontend
if errorlevel 1 goto :backend

:backend
echo.
echo Starting Backend Server...
echo Make sure you have activated your Python environment!
echo.
python backend_api.py
goto :end

:frontend
echo.
echo Starting Expo Frontend...
cd ui
npx expo start
goto :end

:both
echo.
echo Starting both Backend and Frontend...
echo Opening Backend in new window...
start cmd /k "python backend_api.py"
timeout /t 3 /nobreak >nul
echo Opening Frontend in new window...
start cmd /k "cd ui && npx expo start"
echo.
echo Both servers are starting in separate windows!
goto :end

:end
echo.
echo Done!
pause
requirements_api.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.115.6
2
+ uvicorn>=0.34.0
3
+ python-multipart>=0.0.20
4
+ pillow>=11.1.0
5
+ torch>=2.0.0
6
+ torchvision>=0.15.0
7
+ transformers>=4.30.0
8
+ ftfy
9
+ regex
10
+ tqdm
11
+ git+https://github.com/openai/CLIP.git
12
+ groq>=0.4.0
13
+ python-dotenv>=1.0.0
14
+ huggingface-hub
scores/feature.txt ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ================================================================================
2
+ EVALUATION RESULTS
3
+ ================================================================================
4
+
5
+ 📊 Accuracy Metrics:
6
+ Exact Match Accuracy: 50.17% (63805/135256)
7
+ VQA Accuracy: 15.72%
8
+
9
+ 📊 ANLS Metrics:
10
+ Average ANLS (τ=0.5): 50.18%
11
+ ANLS Std Dev: 48.96%
12
+
13
+ 📊 Additional Statistics:
14
+ Total samples: 135256
15
+ Avg prediction length: 1.13 words
16
+ Avg GT length: 1.10 words
17
+
18
+ ================================================================================
19
+ SAMPLE PREDICTIONS
20
+ ================================================================================
21
+
22
+ 🏆 Best Predictions (Highest ANLS):
23
+ --------------------------------------------------------------------------------
24
+
25
+ Ground Truth: tusks
26
+ Prediction: tusks
27
+ ANLS: 1.0000
28
+ Exact Match: ✓
29
+
30
+ Ground Truth: seagull
31
+ Prediction: seagull
32
+ ANLS: 1.0000
33
+ Exact Match: ✓
34
+
35
+ Ground Truth: bedroom
36
+ Prediction: bedroom
37
+ ANLS: 1.0000
38
+ Exact Match: ✓
39
+
40
+ Ground Truth: cake
41
+ Prediction: cake
42
+ ANLS: 1.0000
43
+ Exact Match: ✓
44
+
45
+ Ground Truth: short
46
+ Prediction: short
47
+ ANLS: 1.0000
48
+ Exact Match: ✓
49
+
50
+ ================================================================================
51
+ ⚠️ Worst Predictions (Lowest ANLS):
52
+ --------------------------------------------------------------------------------
53
+
54
+ Ground Truth: mirror
55
+ Prediction: car
56
+ ANLS: 0.0000
57
+ Exact Match: ✗
58
+
59
+ Ground Truth: towel
60
+ Prediction: toy
61
+ ANLS: 0.0000
62
+ Exact Match: ✗
63
+
64
+ Ground Truth: book
65
+ Prediction: camera
66
+ ANLS: 0.0000
67
+ Exact Match: ✗
68
+
69
+ Ground Truth: usa
70
+ Prediction: england
71
+ ANLS: 0.0000
72
+ Exact Match: ✗
73
+
74
+ Ground Truth: red and yellow
75
+ Prediction: green
76
+ ANLS: 0.0000
77
+ Exact Match: ✗
scores/score.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
import pandas as pd
from PIL import Image
from transformers import GPT2Tokenizer
from model import VQAModel
from model_spatial import VQAModelWithSpatialAdapter
from train import Vocab
from tqdm import tqdm
import numpy as np
# Optional dependency: self-install python-Levenshtein on first run.
try:
    from Levenshtein import distance as levenshtein_distance
except ImportError:
    print("Installing python-Levenshtein...")
    import subprocess
    subprocess.check_call(['pip', 'install', 'python-Levenshtein'])
    from Levenshtein import distance as levenshtein_distance
# --- Evaluation configuration ---
MODEL_TYPE = "feature"  # "feature" = base model, "spatial" = adapter model
SPATIAL_CHECKPOINT = "./output2/spatial_adapter_v2_2/vqa_spatial_checkpoint.pt"
FEATURE_CHECKPOINT = "./output2/feature_extraction/vqa_checkpoint.pt"
CSV_PATH = "./gen_vqa_v2/metadata.csv"
IMG_DIR = "./gen_vqa_v2"
MAX_SAMPLES = None  # cap on evaluated rows; None evaluates the entire CSV
24
def load_spatial_model(checkpoint_path, device='cuda'):
    """Restore the spatial-adapter VQA model plus its vocab and tokenizer.

    The checkpoint is expected to carry the answer vocabulary, the special
    token ids, and a state dict for VQAModelWithSpatialAdapter.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Rebuild the answer vocabulary saved alongside the weights.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    base_model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    # Embedding table must match the tokenizer after adding [PAD].
    base_model.gpt2_model.resize_token_embeddings(len(tokenizer))
    model = VQAModelWithSpatialAdapter(
        base_model=base_model,
        hidden_size=512,
        num_heads=8,
        dropout=0.3
    ).to(device)
    # strict=False tolerates keys added/removed between training runs.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
60
def load_feature_model(checkpoint_path, device='cuda'):
    """Restore the base (feature-extraction) VQAModel plus vocab/tokenizer.

    Mirrors load_spatial_model but without the spatial-adapter wrapper.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Rebuild the answer vocabulary saved alongside the weights.
    vocab = Vocab()
    vocab.vocab = checkpoint['vocab']
    vocab.vocab_size = len(checkpoint['vocab'])
    vocab.word2idx = checkpoint['word2idx']
    vocab.idx2word = checkpoint['idx2word']
    vocab.pad_token_id = checkpoint['pad_token_id']
    vocab.bos_token_id = checkpoint['bos_token_id']
    vocab.eos_token_id = checkpoint['eos_token_id']
    vocab.unk_token_id = checkpoint['unk_token_id']
    model = VQAModel(
        vocab_size=len(checkpoint['vocab']),
        device=device,
        question_max_len=checkpoint.get('question_max_len', 20),
        answer_max_len=checkpoint.get('answer_max_len', 12),
        pad_token_id=checkpoint['pad_token_id'],
        bos_token_id=checkpoint['bos_token_id'],
        eos_token_id=checkpoint['eos_token_id'],
        unk_token_id=checkpoint['unk_token_id'],
        hidden_size=512,
        num_layers=2
    ).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # Embedding table must match the tokenizer after adding [PAD].
    model.gpt2_model.resize_token_embeddings(len(tokenizer))
    # strict=False tolerates keys added/removed between training runs.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model, vocab, tokenizer
90
def generate_answer(model, vocab, tokenizer, image_path, question, device='cuda'):
    """Answer one question about one image and return the decoded string.

    Prefers the model's beam search when available, otherwise falls back
    to argmax over the forward logits.
    """
    image = Image.open(image_path).convert('RGB')
    # CLIP's own preprocessing pipeline, with a batch dimension added.
    image = model.clip_preprocess(image).unsqueeze(0).to(device)
    question_tokens = tokenizer(
        question,
        padding='max_length',
        truncation=True,
        max_length=model.question_max_len,
        return_tensors='pt'
    )
    questions = {
        'input_ids': question_tokens['input_ids'].to(device),
        'attention_mask': question_tokens['attention_mask'].to(device)
    }
    with torch.no_grad():
        if hasattr(model, 'generate_with_beam_search'):
            generated = model.generate_with_beam_search(image, questions, beam_width=5)
        else:
            logits = model(image, questions)
            generated = logits.argmax(dim=-1)
    # vocab.decoder maps generated token ids back to the answer string.
    return vocab.decoder(generated[0].cpu().numpy())
111
def exact_match_accuracy(predictions, ground_truths):
    """Case-insensitive exact-match accuracy.

    Args:
        predictions: list of predicted answer strings.
        ground_truths: list of ground-truth answer strings (same length).

    Returns:
        (accuracy, matches): percentage of exact matches plus the raw count.
    """
    matches = 0
    for guess, truth in zip(predictions, ground_truths):
        if guess.strip().lower() == truth.strip().lower():
            matches += 1
    # Guard against division by zero on an empty evaluation set.
    accuracy = (matches / len(predictions)) * 100 if predictions else 0
    return accuracy, matches
124
def vqa_accuracy(predictions, ground_truths_list):
    """Official VQA accuracy: min(#agreeing annotations / 3, 1) per question.

    Each question may carry several human annotations; a prediction earns
    full credit once at least three annotators agree with it. With a single
    annotation per question this reduces to exact match scaled by 1/3.

    Args:
        predictions: list of predicted answer strings.
        ground_truths_list: list of annotation lists. Plain strings are
            accepted too and wrapped into singleton lists.

    Returns:
        VQA accuracy score as a percentage (0-100).
    """
    # Normalize a flat list of strings into a list of singleton lists.
    if not isinstance(ground_truths_list[0], list):
        ground_truths_list = [[gt] for gt in ground_truths_list]
    scores = []
    for pred, annotations in zip(predictions, ground_truths_list):
        normalized = pred.strip().lower()
        agreeing = sum(1 for ann in annotations if normalized == ann.strip().lower())
        scores.append(min(agreeing / 3.0, 1.0))
    return (sum(scores) / len(scores)) * 100 if scores else 0
146
def calculate_anls(prediction, ground_truth, threshold=0.5):
    """Normalized Levenshtein similarity for one prediction/answer pair.

    Similarity below ``threshold`` is truncated to 0, per the ANLS metric.

    Args:
        prediction: predicted answer string.
        ground_truth: reference answer string.
        threshold: minimum similarity to receive credit (default 0.5).

    Returns:
        ANLS score in [0, 1].
    """
    pred = prediction.strip().lower()
    gt = ground_truth.strip().lower()
    if not gt:
        # Empty reference: only an empty prediction counts as correct.
        return 1.0 if not pred else 0.0
    edit_dist = levenshtein_distance(pred, gt)
    longest = max(len(pred), len(gt))
    if longest == 0:
        return 1.0
    similarity = 1 - (edit_dist / longest)
    return similarity if similarity >= threshold else 0.0
167
def average_anls(predictions, ground_truths, threshold=0.5):
    """Mean ANLS over a corpus, expressed as a percentage.

    Args:
        predictions: list of predicted answers.
        ground_truths: list of reference answers (same length).
        threshold: per-pair similarity cutoff passed to calculate_anls.

    Returns:
        (avg_anls, anls_scores): mean score in percent plus the raw
        per-sample scores in [0, 1].
    """
    anls_scores = [
        calculate_anls(pred, gt, threshold)
        for pred, gt in zip(predictions, ground_truths)
    ]
    avg_anls = (sum(anls_scores) / len(anls_scores)) * 100 if anls_scores else 0
    return avg_anls, anls_scores
183
if __name__ == "__main__":
    # End-to-end evaluation driver: load the configured model, generate an
    # answer for every row of the metadata CSV, score with exact match /
    # VQA accuracy / ANLS, then dump a report and a per-sample CSV.
    print("=" * 80)
    print("VQA EVALUATION: ACCURACY + ANLS")
    print("=" * 80)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"\nDevice: {device}")
    print(f"Model: {MODEL_TYPE.upper()}\n")
    if MODEL_TYPE == "spatial":
        model, vocab, tokenizer = load_spatial_model(SPATIAL_CHECKPOINT, device)
    else:
        model, vocab, tokenizer = load_feature_model(FEATURE_CHECKPOINT, device)
    print("✓ Model loaded!\n")
    df = pd.read_csv(CSV_PATH)
    if MAX_SAMPLES:
        df = df.head(MAX_SAMPLES)
    print(f"Evaluating {len(df)} samples\n")
    print("Generating predictions...")
    predictions = []
    ground_truths = []
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        image_path = os.path.join(IMG_DIR, row['image_path'])
        # Skip rows whose image file is missing on disk.
        if not os.path.exists(image_path):
            continue
        try:
            prediction = generate_answer(model, vocab, tokenizer,
                                         image_path, row['question'], device)
            ground_truth = row['answer']
            predictions.append(prediction)
            ground_truths.append(ground_truth)
        except Exception as e:
            # Best-effort evaluation: a single bad sample must not abort
            # the whole run.
            continue
    print(f"\n✓ Generated {len(predictions)} predictions\n")
    print("Calculating metrics...\n")
    exact_acc, exact_matches = exact_match_accuracy(predictions, ground_truths)
    vqa_acc = vqa_accuracy(predictions, ground_truths)
    anls_score, anls_scores = average_anls(predictions, ground_truths, threshold=0.5)
    print("=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print(f"\n📊 Accuracy Metrics:")
    print(f" Exact Match Accuracy: {exact_acc:.2f}% ({exact_matches}/{len(predictions)})")
    print(f" VQA Accuracy: {vqa_acc:.2f}%")
    print(f"\n📊 ANLS Metrics:")
    print(f" Average ANLS (τ=0.5): {anls_score:.2f}%")
    print(f" ANLS Std Dev: {np.std(anls_scores)*100:.2f}%")
    print(f"\n📊 Additional Statistics:")
    print(f" Total samples: {len(predictions)}")
    print(f" Avg prediction length: {np.mean([len(p.split()) for p in predictions]):.2f} words")
    print(f" Avg GT length: {np.mean([len(gt.split()) for gt in ground_truths]):.2f} words")
    print("\n" + "=" * 80)
    print("SAMPLE PREDICTIONS")
    print("=" * 80)
    # Indices sorted ascending by ANLS: tail = best, head = worst.
    sorted_indices = np.argsort(anls_scores)
    print("\n🏆 Best Predictions (Highest ANLS):")
    print("-" * 80)
    for i in sorted_indices[-5:][::-1]:
        print(f"\nGround Truth: {ground_truths[i]}")
        print(f"Prediction: {predictions[i]}")
        print(f"ANLS: {anls_scores[i]:.4f}")
        print(f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}")
    print("\n" + "=" * 80)
    print("⚠️ Worst Predictions (Lowest ANLS):")
    print("-" * 80)
    for i in sorted_indices[:5]:
        print(f"\nGround Truth: {ground_truths[i]}")
        print(f"Prediction: {predictions[i]}")
        print(f"ANLS: {anls_scores[i]:.4f}")
        print(f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}")
    print("\n" + "=" * 80)
    print("✅ EVALUATION COMPLETE")
    print("=" * 80)
    # Mirror the console report into <MODEL_TYPE>.txt for archival.
    with open(f"{MODEL_TYPE}.txt", "w", encoding="utf-8") as f:
        f.write("=" * 80 + "\n")
        f.write("EVALUATION RESULTS\n")
        f.write("=" * 80 + "\n")
        f.write("\n📊 Accuracy Metrics:\n")
        f.write(f" Exact Match Accuracy: {exact_acc:.2f}% ({exact_matches}/{len(predictions)})\n")
        f.write(f" VQA Accuracy: {vqa_acc:.2f}%\n")
        f.write("\n📊 ANLS Metrics:\n")
        f.write(f" Average ANLS (τ=0.5): {anls_score:.2f}%\n")
        f.write(f" ANLS Std Dev: {np.std(anls_scores)*100:.2f}%\n")
        f.write("\n📊 Additional Statistics:\n")
        f.write(f" Total samples: {len(predictions)}\n")
        f.write(f" Avg prediction length: {np.mean([len(p.split()) for p in predictions]):.2f} words\n")
        f.write(f" Avg GT length: {np.mean([len(gt.split()) for gt in ground_truths]):.2f} words\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("SAMPLE PREDICTIONS\n")
        f.write("=" * 80 + "\n")
        sorted_indices = np.argsort(anls_scores)
        f.write("\n🏆 Best Predictions (Highest ANLS):\n")
        f.write("-" * 80 + "\n")
        for i in sorted_indices[-5:][::-1]:
            f.write(f"\nGround Truth: {ground_truths[i]}\n")
            f.write(f"Prediction: {predictions[i]}\n")
            f.write(f"ANLS: {anls_scores[i]:.4f}\n")
            f.write(
                f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}\n"
            )
        f.write("\n" + "=" * 80 + "\n")
        f.write("⚠️ Worst Predictions (Lowest ANLS):\n")
        f.write("-" * 80 + "\n")
        for i in sorted_indices[:5]:
            f.write(f"\nGround Truth: {ground_truths[i]}\n")
            f.write(f"Prediction: {predictions[i]}\n")
            f.write(f"ANLS: {anls_scores[i]:.4f}\n")
            f.write(
                f"Exact Match: {'✓' if predictions[i].strip().lower() == ground_truths[i].strip().lower() else '✗'}\n"
            )
    # Per-sample results for downstream analysis.
    results_df = pd.DataFrame({
        'prediction': predictions,
        'ground_truth': ground_truths,
        'anls_score': anls_scores,
        'exact_match': [pred.strip().lower() == gt.strip().lower()
                        for pred, gt in zip(predictions, ground_truths)]
    })
    output_file = f"vqa_evaluation_{MODEL_TYPE}.csv"
    results_df.to_csv(output_file, index=False)
    print(f"\n💾 Results saved to: {output_file}")
scores/vqa_evaluation_feature.csv ADDED
The diff for this file is too large to render. See raw diff