| |
| """ |
| Oculus Reasoning Training V2 - BEAST MODE |
| |
| Goal: Beat Isaac 0.2-2B on VQA benchmarks |
| Strategy: |
| 1. Use ALL available COCO data |
| 2. Diverse question templates |
| 3. Chain-of-thought style training |
| 4. Longer training (8 epochs) |
| 5. Learning rate warmup + decay |
| """ |
|
|
| import os |
| import sys |
| import json |
| import random |
| import math |
| from pathlib import Path |
| from dataclasses import dataclass |
| from typing import List, Dict, Optional |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| OCULUS_ROOT = Path(__file__).parent |
| sys.path.insert(0, str(OCULUS_ROOT)) |
|
|
| from oculus_unified_model import OculusForConditionalGeneration |
|
|
|
|
| |
| |
| |
|
|
class ReasoningDataset(Dataset):
    """
    Advanced dataset for reasoning training.

    Builds a mixture of VQA-style samples from the COCO 2017 annotations:

    - captioning samples (up to 2 per image),
    - positive existence questions ("Is there a {obj}?" -> "Yes"),
    - negative existence questions for categories absent from the image,
    - counting questions for categories with a small, unambiguous count.

    Diverse question templates are used so the model does not overfit to a
    single prompt phrasing.
    """

    # Prompt templates; one is drawn at random per generated sample.
    CAPTION_PROMPTS = [
        "Describe this image in detail.",
        "What is happening in this image?",
        "Explain what you see.",
        "Provide a detailed description of the scene.",
        "What can you observe in this picture?",
        "Describe the contents of this image.",
        "What is shown here?",
        "Give a comprehensive description.",
    ]

    COUNTING_PROMPTS = [
        "How many {obj}s are in this image?",
        "Count the number of {obj}s visible.",
        "What is the count of {obj}s?",
        "How many {obj}s can you see?",
    ]

    EXISTENCE_PROMPTS = [
        "Is there a {obj} in this image?",
        "Can you see a {obj}?",
        "Does this image contain a {obj}?",
        "Is a {obj} visible in this picture?",
    ]

    ATTRIBUTE_PROMPTS = [
        "What objects are visible in this image?",
        "What type of scene is this?",
        "Describe the main subject of this image.",
        "What is the setting of this image?",
    ]

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: multimodal processor used in ``__getitem__`` to encode
                image + text (assumed HuggingFace-style — TODO confirm against
                ``oculus.lm_vqa_processor``).
            data_dir: root of a COCO-2017 layout containing
                ``annotations/captions_train2017.json``,
                ``annotations/instances_train2017.json`` and ``images/``.
            max_samples: optional cap on the number of *images* processed
                (each image may yield several samples).
        """
        self.processor = processor
        self.samples = []

        cap_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        inst_file = Path(data_dir) / "annotations" / "instances_train2017.json"

        # FIX: the original only checked the captions file; a missing
        # instances file would crash on open() below instead of degrading
        # gracefully to an empty dataset.
        if not cap_file.exists() or not inst_file.exists():
            print("⚠️ COCO data not found!")
            return

        print("📚 Loading COCO data for reasoning training...")

        with open(cap_file, encoding="utf-8") as f:
            captions_data = json.load(f)

        with open(inst_file, encoding="utf-8") as f:
            instances_data = json.load(f)

        # image id -> image metadata record; category id -> readable name.
        img_map = {img['id']: img for img in captions_data['images']}
        cat_map = {c['id']: c['name'] for c in instances_data['categories']}

        # Group all captions per image id.
        img_captions = {}
        for ann in captions_data['annotations']:
            img_captions.setdefault(ann['image_id'], []).append(ann['caption'])

        # Per-image object counts, skipping crowd annotations (their counts
        # are unreliable for counting questions).
        img_objects = {}
        for ann in instances_data['annotations']:
            if ann.get('iscrowd', 0):
                continue
            cat = cat_map.get(ann['category_id'], 'object')
            counts = img_objects.setdefault(ann['image_id'], {})
            counts[cat] = counts.get(cat, 0) + 1

        count = 0  # number of images processed (for max_samples cap)
        for img_id, captions in img_captions.items():
            img = img_map.get(img_id)
            if not img:
                continue

            img_path = Path(data_dir) / "images" / img['file_name']
            if not img_path.exists():
                continue

            # Up to two caption samples per image, each with a random prompt.
            for caption in captions[:2]:
                prompt = random.choice(self.CAPTION_PROMPTS)
                self.samples.append({
                    'path': str(img_path),
                    'question': prompt,
                    'answer': caption,
                    'type': 'caption'
                })

            objects = img_objects.get(img_id, {})
            if objects:
                # Positive existence sample for one present category.
                obj = random.choice(list(objects.keys()))
                prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=obj)
                self.samples.append({
                    'path': str(img_path),
                    'question': prompt,
                    'answer': "Yes",
                    'type': 'existence'
                })

            # Negative existence sample for a category absent from the image.
            all_cats = list(cat_map.values())
            missing = [c for c in all_cats if c not in objects]
            if missing:
                # FIX: the original drew from missing[:10], which almost
                # always samples the same few low-id COCO categories
                # (person, bicycle, car, ...); sample over all missing ones.
                neg_obj = random.choice(missing)
                prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=neg_obj)
                self.samples.append({
                    'path': str(img_path),
                    'question': prompt,
                    'answer': "No",
                    'type': 'existence_neg'
                })

            # One counting sample for the first category whose count is
            # small enough to be visually unambiguous.
            for obj, count_val in objects.items():
                if 2 <= count_val <= 10:
                    prompt = random.choice(self.COUNTING_PROMPTS).format(obj=obj)
                    answer = f"There are {count_val} {obj}s in this image."
                    self.samples.append({
                        'path': str(img_path),
                        'question': prompt,
                        'answer': answer,
                        'type': 'counting'
                    })
                    break

            count += 1
            if max_samples and count >= max_samples:
                break

        # Shuffle so sample types are interleaved during training.
        random.shuffle(self.samples)

        print(f"✅ Loaded {len(self.samples)} reasoning samples")
        print(f"   - Captions: {sum(1 for s in self.samples if s['type'] == 'caption')}")
        print(f"   - Existence: {sum(1 for s in self.samples if 'existence' in s['type'])}")
        print(f"   - Counting: {sum(1 for s in self.samples if s['type'] == 'counting')}")

    def __len__(self):
        """Return the number of generated QA samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Encode one sample into model tensors.

        Returns a dict with ``pixel_values``, ``input_ids``,
        ``attention_mask`` (question, max 32 tokens) and ``labels``
        (answer token ids, max 64 tokens), all with the batch dim squeezed.
        """
        item = self.samples[idx]

        try:
            image = Image.open(item['path']).convert('RGB')
        except Exception:
            # FIX: was a bare `except:` (also caught KeyboardInterrupt /
            # SystemExit).  Keep the best-effort blank-image fallback so a
            # single corrupt file cannot kill a training run.
            image = Image.new('RGB', (224, 224))

        encoding = self.processor(
            images=image,
            text=item['question'],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        labels = self.processor(
            text=item['answer'],
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
|
|
|
|
| |
| |
| |
|
|
def train():
    """Fine-tune the Oculus VQA language model on COCO reasoning samples.

    Loads the detection-v2 checkpoint, builds a ReasoningDataset, and runs
    8 epochs of AdamW + cosine-annealed LR training with gradient clipping.
    The best-by-average-loss epoch is checkpointed, and the final weights
    are always saved to ``checkpoints/oculus_reasoning_v2/final``.
    """
    # Device preference: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("🚀 BEAST MODE TRAINING")
    print(f"Device: {device}")

    # Start from the detection-trained checkpoint and attach the LM head.
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"\nLoading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    oculus.load_language_model(device=device)

    # Only the VQA language model is trained here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()
    vqa_model.to(device)

    dataset = ReasoningDataset(oculus.lm_vqa_processor, max_samples=50000)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

    optimizer = AdamW(vqa_model.parameters(), lr=3e-5, weight_decay=0.01)

    epochs = 8
    total_steps = epochs * len(loader)
    # Cosine decay from 3e-5 down to 1e-6 over the whole run
    # (scheduler is stepped once per batch).
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-6)

    print("\n📊 Training Config:")
    print(f" Samples: {len(dataset)}")
    print(" Batch size: 8")
    print(f" Epochs: {epochs}")
    print(f" Total steps: {total_steps}")
    print(" LR: 3e-5 -> 1e-6 (cosine)")

    print("\n🔥 Starting training...")

    best_loss = float('inf')
    global_step = 0

    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            batch = {key: tensor.to(device) for key, tensor in batch.items()}

            loss = vqa_model(**batch).loss
            loss.backward()

            # Clip to unit grad-norm for stability.
            torch.nn.utils.clip_grad_norm_(vqa_model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            global_step += 1

            current_lr = scheduler.get_last_lr()[0]
            progress.set_postfix(loss=f"{loss.item():.4f}", lr=f"{current_lr:.2e}")

        avg_loss = running_loss / len(loader)
        print(f"\n✓ Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # Checkpoint whenever the epoch-average loss improves.
        if avg_loss < best_loss:
            best_loss = avg_loss
            checkpoint_dir = Path("checkpoints/oculus_reasoning_v2")
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

            print(f" 💾 New best! Saving to {checkpoint_dir}")
            vqa_model.save_pretrained(checkpoint_dir / "vqa_model")
            oculus.lm_vqa_processor.save_pretrained(checkpoint_dir / "vqa_model")

    # Always persist the final weights, regardless of best-loss checkpoints.
    final_dir = Path("checkpoints/oculus_reasoning_v2/final")
    final_dir.mkdir(parents=True, exist_ok=True)
    vqa_model.save_pretrained(final_dir)
    oculus.lm_vqa_processor.save_pretrained(final_dir)

    print("\n✅ BEAST MODE TRAINING COMPLETE!")
    print(f" Best Loss: {best_loss:.4f}")
    print(" Model saved to: checkpoints/oculus_reasoning_v2/final")
|
|
|
|
# Script entry point: run the full reasoning-training pipeline.
if __name__ == "__main__":
    train()
|
|