|
|
|
|
|
""" |
|
|
Oculus Reasoning Training V2 - BEAST MODE |
|
|
|
|
|
Goal: Beat Isaac 0.2-2B on VQA benchmarks |
|
|
Strategy: |
|
|
1. Use COCO train2017 captions and instance annotations (image count capped by max_samples)
|
|
2. Diverse question templates |
|
|
3. Chain-of-thought style training |
|
|
4. Longer training (8 epochs) |
|
|
5. Learning rate warmup + decay |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import random |
|
|
import math |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Dict, Optional |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
|
|
|
# Directory that contains this script. It is put at the front of sys.path so
# that sibling modules (e.g. oculus_unified_model, imported just below) can be
# imported regardless of the current working directory.
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, os.fspath(OCULUS_ROOT))
|
|
|
|
|
from oculus_unified_model import OculusForConditionalGeneration |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReasoningDataset(Dataset):
    """
    COCO-derived question/answer dataset for VQA reasoning fine-tuning.

    Each usable COCO train2017 image contributes up to five samples:

    * ``caption``       - up to two of its captions, each paired with a random
                          prompt from ``CAPTION_PROMPTS``.
    * ``existence``     - "Yes" question about a category annotated in the image.
    * ``existence_neg`` - "No" question about a category NOT annotated in it.
    * ``counting``      - "How many X?" for the first category with 2-10
                          non-crowd instances.

    Answers are short direct answers (caption text, "Yes"/"No", or a counting
    sentence). ``ATTRIBUTE_PROMPTS`` is defined but not currently used when
    building samples.
    """

    CAPTION_PROMPTS = [
        "Describe this image in detail.",
        "What is happening in this image?",
        "Explain what you see.",
        "Provide a detailed description of the scene.",
        "What can you observe in this picture?",
        "Describe the contents of this image.",
        "What is shown here?",
        "Give a comprehensive description.",
    ]

    COUNTING_PROMPTS = [
        "How many {obj}s are in this image?",
        "Count the number of {obj}s visible.",
        "What is the count of {obj}s?",
        "How many {obj}s can you see?",
    ]

    EXISTENCE_PROMPTS = [
        "Is there a {obj} in this image?",
        "Can you see a {obj}?",
        "Does this image contain a {obj}?",
        "Is a {obj} visible in this picture?",
    ]

    ATTRIBUTE_PROMPTS = [
        "What objects are visible in this image?",
        "What type of scene is this?",
        "Describe the main subject of this image.",
        "What is the setting of this image?",
    ]

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Index the COCO annotations and build a shuffled list of samples.

        Args:
            processor: Encoder used by ``__getitem__`` for images and text.
                It is only stored here.
            data_dir: COCO root containing ``annotations/`` and ``images/``.
            max_samples: Optional cap on the number of *images* used. Each
                image may contribute several samples. ``None`` or ``0`` means
                no cap.

        If the captions file is missing, the dataset is left empty. If only
        the instances file is missing, caption samples are still built, but
        there are no existence or counting questions.
        """
        self.processor = processor
        self.samples = []

        data_root = Path(data_dir)
        cap_file = data_root / "annotations" / "captions_train2017.json"
        inst_file = data_root / "annotations" / "instances_train2017.json"

        if not cap_file.exists():
            print("⚠️ COCO data not found!")
            return

        print("📚 Loading COCO data for reasoning training...")

        captions_data = self._load_json(cap_file)
        if inst_file.exists():
            instances_data = self._load_json(inst_file)
        else:
            print(f"⚠️ {inst_file} not found - skipping existence/counting questions.")
            instances_data = {"categories": [], "annotations": []}

        img_map = {img['id']: img for img in captions_data['images']}
        cat_map = {c['id']: c['name'] for c in instances_data['categories']}
        img_captions = self._group_captions(captions_data['annotations'])
        img_objects = self._count_objects(instances_data['annotations'], cat_map)
        # The category list is the same for every image, so build it once.
        all_cats = list(cat_map.values())

        images_used = 0
        for img_id, captions in img_captions.items():
            img = img_map.get(img_id)
            if not img:
                continue

            img_path = data_root / "images" / img['file_name']
            if not img_path.exists():
                continue

            self._add_image_samples(
                str(img_path), captions, img_objects.get(img_id, {}), all_cats
            )

            images_used += 1
            if max_samples and images_used >= max_samples:
                break

        random.shuffle(self.samples)

        print(f"✅ Loaded {len(self.samples)} reasoning samples")
        print(f" - Captions: {sum(1 for s in self.samples if s['type'] == 'caption')}")
        print(f" - Existence: {sum(1 for s in self.samples if 'existence' in s['type'])}")
        print(f" - Counting: {sum(1 for s in self.samples if s['type'] == 'counting')}")

    @staticmethod
    def _load_json(path):
        """Parse a UTF-8 encoded JSON file."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _group_captions(annotations):
        """Map image_id -> list of caption strings, keeping annotation order."""
        img_captions = {}
        for ann in annotations:
            img_captions.setdefault(ann['image_id'], []).append(ann['caption'])
        return img_captions

    @staticmethod
    def _count_objects(annotations, cat_map):
        """Map image_id -> {category name: instance count}, skipping crowd annotations."""
        img_objects = {}
        for ann in annotations:
            if ann.get('iscrowd', 0):
                continue
            cat = cat_map.get(ann['category_id'], 'object')
            per_image = img_objects.setdefault(ann['image_id'], {})
            per_image[cat] = per_image.get(cat, 0) + 1
        return img_objects

    @staticmethod
    def _make_sample(path, question, answer, kind):
        """Build one sample record in the format consumed by ``__getitem__``."""
        return {'path': path, 'question': question, 'answer': answer, 'type': kind}

    def _add_image_samples(self, path, captions, objects, all_cats):
        """Append the caption, existence, and counting samples for one image."""
        for caption in captions[:2]:
            prompt = random.choice(self.CAPTION_PROMPTS)
            self.samples.append(self._make_sample(path, prompt, caption, 'caption'))

        if not objects:
            return

        # Positive existence question about an annotated category.
        obj = random.choice(list(objects.keys()))
        prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=obj)
        self.samples.append(self._make_sample(path, prompt, "Yes", 'existence'))

        # Negative existence question. Sample uniformly from ALL absent
        # categories. Restricting to the first few absent ones would make
        # "No" correlate with specific category names.
        missing = [c for c in all_cats if c not in objects]
        if missing:
            neg_obj = random.choice(missing)
            prompt = random.choice(self.EXISTENCE_PROMPTS).format(obj=neg_obj)
            self.samples.append(self._make_sample(path, prompt, "No", 'existence_neg'))

        # One counting question, for the first category with a small plural count.
        for obj, count_val in objects.items():
            if 2 <= count_val <= 10:
                prompt = random.choice(self.COUNTING_PROMPTS).format(obj=obj)
                answer = f"There are {count_val} {obj}s in this image."
                self.samples.append(self._make_sample(path, prompt, answer, 'counting'))
                break

    def __len__(self):
        """Number of prepared samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Encode sample ``idx`` with ``self.processor``.

        Returns:
            A dict with ``pixel_values``, ``input_ids`` and ``attention_mask``
            (question padded/truncated to 32 tokens), plus ``labels`` (answer
            token ids padded/truncated to 64), each with the leading batch
            dimension squeezed away.
        """
        item = self.samples[idx]

        try:
            with Image.open(item['path']) as img:
                image = img.convert('RGB')
        except (OSError, ValueError, Image.DecompressionBombError):
            # Missing, corrupt, or unreadable file: fall back to a blank image
            # so one bad file does not abort a long training run.
            image = Image.new('RGB', (224, 224))

        encoding = self.processor(
            images=image,
            text=item['question'],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        # NOTE(review): padded label positions keep the tokenizer's pad id
        # (not -100). Whether the loss ignores them depends on how
        # lm_vqa_model consumes `labels`. Confirm before masking, since some
        # VQA heads also feed labels back in as decoder inputs.
        labels = self.processor(
            text=item['answer'],
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _warmup_cosine_lambda(total_steps, warmup_steps, min_ratio):
    """
    Build a LambdaLR multiplier with linear warmup followed by cosine decay.

    Args:
        total_steps: Total number of optimizer steps in the run.
        warmup_steps: Steps over which the multiplier ramps linearly up to 1.0.
        min_ratio: Multiplier reached at ``total_steps`` (``min_lr / lr``).

    Returns:
        A function mapping the scheduler step index to an LR multiplier.
    """
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            # (step + 1) so the very first update already uses a non-zero LR.
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return lr_lambda


def _save_model(vqa_model, processor, target_dir):
    """Create ``target_dir`` if needed and call ``save_pretrained`` on the model and processor."""
    target_dir = Path(target_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    vqa_model.save_pretrained(target_dir)
    processor.save_pretrained(target_dir)


def train(
    model_path="checkpoints/oculus_detection_v2/final",
    output_dir="checkpoints/oculus_reasoning_v2",
    epochs=8,
    batch_size=8,
    lr=3e-5,
    min_lr=1e-6,
    warmup_ratio=0.05,
    max_samples=50000,
):
    """
    Fine-tune Oculus's VQA language model on ReasoningDataset samples.

    Every argument defaults to the values previously hard-coded here, so
    ``train()`` with no arguments keeps working.

    Args:
        model_path: Oculus checkpoint to start from.
        output_dir: Where the best-epoch (``<output_dir>/vqa_model``) and
            final (``<output_dir>/final``) checkpoints are written.
        epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        lr: Peak learning rate, reached at the end of warmup.
        min_lr: Learning rate at the end of the cosine decay.
        warmup_ratio: Fraction of total steps spent in linear LR warmup.
        max_samples: Cap on COCO images passed to ReasoningDataset.
    """
    device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print("🚀 BEAST MODE TRAINING")
    print(f"Device: {device}")

    print(f"\nLoading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    oculus.load_language_model(device=device)

    vqa_model = oculus.lm_vqa_model
    processor = oculus.lm_vqa_processor
    vqa_model.train()
    vqa_model.to(device)

    dataset = ReasoningDataset(processor, max_samples=max_samples)
    if len(dataset) == 0:
        # With no samples, DataLoader(shuffle=True) raises, and the per-epoch
        # average would divide by zero. Stop with a clear message instead.
        print("❌ No training samples found - nothing to train on.")
        return
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    optimizer = AdamW(vqa_model.parameters(), lr=lr, weight_decay=0.01)

    steps_per_epoch = len(dataloader)
    total_steps = steps_per_epoch * epochs
    warmup_steps = int(total_steps * warmup_ratio)
    # Stepped once per batch: linear warmup to `lr`, then cosine decay to `min_lr`.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _warmup_cosine_lambda(total_steps, warmup_steps, min_lr / lr)
    )

    print(f"\n📊 Training Config:")
    print(f" Samples: {len(dataset)}")
    print(f" Batch size: {batch_size}")
    print(f" Epochs: {epochs}")
    print(f" Total steps: {total_steps}")
    print(f" LR: {lr:g} -> {min_lr:g} (cosine, {warmup_steps} warmup steps)")

    print("\n🔥 Starting training...")

    output_dir = Path(output_dir)
    best_loss = float('inf')

    for epoch in range(epochs):
        total_loss = 0.0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            loss = vqa_model(**batch).loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vqa_model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # Read the scalar once per step, since each .item() forces a device sync.
            loss_value = loss.item()
            total_loss += loss_value
            pbar.set_postfix(loss=f"{loss_value:.4f}", lr=f"{scheduler.get_last_lr()[0]:.2e}")

        avg_loss = total_loss / steps_per_epoch
        print(f"\n✓ Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

        # "Best" is judged on average *training* loss, because there is no
        # validation split in this script.
        if avg_loss < best_loss:
            best_loss = avg_loss
            print(f" 💾 New best! Saving to {output_dir}")
            _save_model(vqa_model, processor, output_dir / "vqa_model")

    final_dir = output_dir / "final"
    _save_model(vqa_model, processor, final_dir)

    print(f"\n✅ BEAST MODE TRAINING COMPLETE!")
    print(f" Best Loss: {best_loss:.4f}")
    print(f" Model saved to: {final_dir}")
|
|
|
|
|
|
|
|
# Script entry point: start training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()
|
|
|