kobiakor15 commited on
Commit
d0ded2a
·
verified ·
1 Parent(s): 3712bfe

Upload training/train_reasoning_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_reasoning_v2.py +331 -0
training/train_reasoning_v2.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Oculus Reasoning Training V2 - BEAST MODE
4
+
5
+ Goal: Beat Isaac 0.2-2B on VQA benchmarks
6
+ Strategy:
7
+ 1. Use ALL available COCO data
8
+ 2. Diverse question templates
9
+ 3. Chain-of-thought style training
10
+ 4. Longer training (8 epochs)
11
+ 5. Learning rate warmup + decay
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import json
17
+ import random
18
+ import math
19
+ from pathlib import Path
20
+ from dataclasses import dataclass
21
+ from typing import List, Dict, Optional
22
+
23
+ import torch
24
+ from torch.utils.data import Dataset, DataLoader
25
+ from torch.optim import AdamW
26
+ from torch.optim.lr_scheduler import CosineAnnealingLR
27
+ from PIL import Image
28
+ from tqdm import tqdm
29
+
30
+ OCULUS_ROOT = Path(__file__).parent
31
+ sys.path.insert(0, str(OCULUS_ROOT))
32
+
33
+ from oculus_unified_model import OculusForConditionalGeneration
34
+
35
+
36
+ # ============================================================================
37
+ # Advanced Dataset with Diverse Prompts
38
+ # ============================================================================
39
+
40
class ReasoningDataset(Dataset):
    """VQA-style reasoning dataset built from COCO captions + instances.

    Produces three kinds of (image, question, answer) samples:
      - ``caption``:   open-ended description prompts answered with COCO captions
      - ``existence``: yes/no object-presence questions (with negative examples)
      - ``counting``:  "how many" questions for objects with 2-10 instances

    Expects the standard COCO 2017 layout under *data_dir*:
    ``annotations/captions_train2017.json``, ``annotations/instances_train2017.json``
    and images under ``images/``. If the caption file is missing, the dataset
    is left empty (len == 0) after printing a warning.

    NOTE(review): sample generation uses the global ``random`` module without a
    seed, so the sample set differs between runs — seed externally if
    reproducibility matters.
    """

    # Diverse question templates for VQA-style training
    CAPTION_PROMPTS = [
        "Describe this image in detail.",
        "What is happening in this image?",
        "Explain what you see.",
        "Provide a detailed description of the scene.",
        "What can you observe in this picture?",
        "Describe the contents of this image.",
        "What is shown here?",
        "Give a comprehensive description.",
    ]

    COUNTING_PROMPTS = [
        "How many {obj}s are in this image?",
        "Count the number of {obj}s visible.",
        "What is the count of {obj}s?",
        "How many {obj}s can you see?",
    ]

    EXISTENCE_PROMPTS = [
        "Is there a {obj} in this image?",
        "Can you see a {obj}?",
        "Does this image contain a {obj}?",
        "Is a {obj} visible in this picture?",
    ]

    ATTRIBUTE_PROMPTS = [
        "What objects are visible in this image?",
        "What type of scene is this?",
        "Describe the main subject of this image.",
        "What is the setting of this image?",
    ]

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """Build the sample list from COCO annotations.

        Args:
            processor: multimodal processor used in ``__getitem__`` to encode
                image + text (assumed HF-processor-like — TODO confirm).
            data_dir: root of the COCO 2017 dataset.
            max_samples: cap on the number of *images* processed (each image
                may yield several samples), or None for no cap.
        """
        self.processor = processor
        self.samples = []

        # Load COCO data
        cap_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        inst_file = Path(data_dir) / "annotations" / "instances_train2017.json"

        if not cap_file.exists():
            print("⚠️ COCO data not found!")
            return

        print("📚 Loading COCO data for reasoning training...")

        # Load captions
        with open(cap_file) as f:
            captions_data = json.load(f)

        # Load instances for counting/existence
        with open(inst_file) as f:
            instances_data = json.load(f)

        # Build indexes
        img_map = {img['id']: img for img in captions_data['images']}
        cat_map = {c['id']: c['name'] for c in instances_data['categories']}

        # Image id -> list of caption strings
        img_captions = {}
        for ann in captions_data['annotations']:
            img_captions.setdefault(ann['image_id'], []).append(ann['caption'])

        # Image id -> {category name: instance count}
        img_objects = {}
        for ann in instances_data['annotations']:
            if ann.get('iscrowd', 0):
                continue  # crowd regions don't give reliable per-instance counts
            cat = cat_map.get(ann['category_id'], 'object')
            counts = img_objects.setdefault(ann['image_id'], {})
            counts[cat] = counts.get(cat, 0) + 1

        # Create training samples
        count = 0
        for img_id, captions in img_captions.items():
            img = img_map.get(img_id)
            if not img:
                continue

            img_path = Path(data_dir) / "images" / img['file_name']
            if not img_path.exists():
                continue

            # 1. Caption-based QA (main training signal)
            for caption in captions[:2]:  # Use up to 2 captions per image
                prompt = random.choice(self.CAPTION_PROMPTS)
                self.samples.append({
                    'path': str(img_path),
                    'question': prompt,
                    'answer': caption,
                    'type': 'caption'
                })

            # 2. Existence questions
            objects = img_objects.get(img_id, {})
            if objects:
                obj = random.choice(list(objects.keys()))
                prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=obj)
                self.samples.append({
                    'path': str(img_path),
                    'question': prompt,
                    'answer': "Yes",
                    'type': 'existence'
                })

                # Also add negative examples
                all_cats = list(cat_map.values())
                missing = [c for c in all_cats if c not in objects]
                if missing:
                    # NOTE(review): sampling from missing[:10] biases negatives
                    # toward low category ids — intentional? verify.
                    neg_obj = random.choice(missing[:10])
                    prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=neg_obj)
                    self.samples.append({
                        'path': str(img_path),
                        'question': prompt,
                        'answer': "No",
                        'type': 'existence_neg'
                    })

            # 3. Counting questions (for objects with 2-10 instances)
            for obj, count_val in objects.items():
                if 2 <= count_val <= 10:
                    prompt = random.choice(self.COUNTING_PROMPTS).format(obj=obj)
                    # Chain-of-thought style answer
                    answer = f"There are {count_val} {obj}s in this image."
                    self.samples.append({
                        'path': str(img_path),
                        'question': prompt,
                        'answer': answer,
                        'type': 'counting'
                    })
                    break  # One counting Q per image

            count += 1
            if max_samples and count >= max_samples:
                break

        # Shuffle samples
        random.shuffle(self.samples)

        print(f"✅ Loaded {len(self.samples)} reasoning samples")
        print(f"   - Captions: {sum(1 for s in self.samples if s['type'] == 'caption')}")
        print(f"   - Existence: {sum(1 for s in self.samples if 'existence' in s['type'])}")
        print(f"   - Counting: {sum(1 for s in self.samples if s['type'] == 'counting')}")

    def __len__(self):
        """Return the number of generated (image, question, answer) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Encode sample *idx* into model-ready tensors.

        Returns a dict with ``pixel_values``, ``input_ids``,
        ``attention_mask`` (question encoding) and ``labels`` (answer ids).
        """
        item = self.samples[idx]

        try:
            image = Image.open(item['path']).convert('RGB')
        except (OSError, ValueError):
            # Fix: was a bare `except:` that swallowed everything (including
            # KeyboardInterrupt). Catch only image-decoding failures and fall
            # back to a blank placeholder so training can continue.
            image = Image.new('RGB', (224, 224))

        # Encode
        encoding = self.processor(
            images=image,
            text=item['question'],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        # Labels (answer)
        labels = self.processor(
            text=item['answer'],
            padding="max_length",
            truncation=True,
            max_length=64,  # Longer for chain-of-thought
            return_tensors="pt"
        ).input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
232
+
233
+
234
+ # ============================================================================
235
+ # Training Loop with Advanced Features
236
+ # ============================================================================
237
+
238
def train(
    epochs=8,
    batch_size=8,
    lr=3e-5,
    min_lr=1e-6,
    weight_decay=0.01,
    max_samples=50000,
    model_path="checkpoints/oculus_detection_v2/final",
    output_dir="checkpoints/oculus_reasoning_v2",
):
    """Fine-tune the Oculus VQA language model on COCO reasoning samples.

    Loads a pretrained Oculus checkpoint, builds a :class:`ReasoningDataset`,
    and trains with AdamW + cosine LR decay, saving the best-loss checkpoint
    each epoch and a final checkpoint at the end.

    Args:
        epochs: number of training epochs.
        batch_size: per-step batch size.
        lr: initial learning rate (cosine-decayed to *min_lr*).
        min_lr: floor of the cosine schedule.
        weight_decay: AdamW weight decay.
        max_samples: cap on COCO images used to build the dataset.
        model_path: pretrained Oculus checkpoint to start from.
        output_dir: directory for best/final checkpoints.

    Generalization: all previously hard-coded hyperparameters are now
    keyword arguments whose defaults reproduce the original behavior.
    """
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"🚀 BEAST MODE TRAINING")
    print(f"Device: {device}")

    # Load model
    print(f"\nLoading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    oculus.load_language_model(device=device)

    # Get VQA model
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()
    vqa_model.to(device)

    # Dataset (capped at max_samples images)
    dataset = ReasoningDataset(oculus.lm_vqa_processor, max_samples=max_samples)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    # Optimizer with weight decay
    optimizer = AdamW(vqa_model.parameters(), lr=lr, weight_decay=weight_decay)

    # Cosine LR scheduler over the full run
    total_steps = len(dataloader) * epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=min_lr)

    # Config printout now reflects the actual parameters (was hard-coded text
    # that could drift from the real values).
    print(f"\n📊 Training Config:")
    print(f"   Samples: {len(dataset)}")
    print(f"   Batch size: {batch_size}")
    print(f"   Epochs: {epochs}")
    print(f"   Total steps: {total_steps}")
    print(f"   LR: {lr:g} -> {min_lr:g} (cosine)")

    print("\n🔥 Starting training...")

    best_loss = float('inf')
    global_step = 0

    for epoch in range(epochs):
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward
            outputs = vqa_model(**batch)
            loss = outputs.loss

            # Backward
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(vqa_model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            global_step += 1

            # Progress
            cur_lr = scheduler.get_last_lr()[0]
            pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{cur_lr:.2e}")

        avg_loss = total_loss / len(dataloader)
        print(f"\n✓ Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # Save checkpoint if best so far
        if avg_loss < best_loss:
            best_loss = avg_loss
            checkpoint_dir = Path(output_dir)
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            print(f"   💾 New best! Saving to {checkpoint_dir}")
            vqa_model.save_pretrained(checkpoint_dir / "vqa_model")
            oculus.lm_vqa_processor.save_pretrained(checkpoint_dir / "vqa_model")

    # Final save (always, regardless of best-loss tracking)
    final_dir = Path(output_dir) / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    vqa_model.save_pretrained(final_dir)
    oculus.lm_vqa_processor.save_pretrained(final_dir)

    print(f"\n✅ BEAST MODE TRAINING COMPLETE!")
    print(f"   Best Loss: {best_loss:.4f}")
    print(f"   Model saved to: {final_dir}")
328
+
329
+
330
# Script entry point: run the full reasoning fine-tune when executed directly.
if __name__ == "__main__":
    train()