# Oculus / training / train_instruction_tuning.py
# Uploaded via huggingface_hub by kobiakor15 (commit 3712bfe, verified).
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from PIL import Image
import json
from pathlib import Path
from tqdm import tqdm
import requests
from io import BytesIO
# Import Oculus
import sys
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration
class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Builds (image, question, answer) samples from COCO Captions when the
    annotation file is present under ``data_dir``; otherwise falls back to a
    small synthetic sample set so the training loop can still run end to end.

    Each item is processed into model-ready tensors:
    ``pixel_values``, ``input_ids``, ``attention_mask`` (question side) and
    ``labels`` (tokenized answer), all with the batch dimension squeezed out.
    """

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: HF-style multimodal processor; must accept
                ``images=...`` / ``text=...`` and return tensors.
            data_dir: Root directory expected to contain
                ``annotations/captions_train2017.json`` and ``images/``.
            max_samples: Optional cap on the number of samples loaded.
        """
        import random

        self.processor = processor
        self.samples = []

        # Load COCO Captions annotations if available.
        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # Synthetic fallback. NOTE: keys must match what __getitem__ reads
            # ('question'/'answer') — the previous 'q'/'a' keys raised KeyError.
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map image_id -> filename so annotations can locate their image.
            img_map = {img['id']: img['file_name'] for img in coco['images']}

            # Instruction-style prompt pool; one is sampled per caption.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene.",
            ]

            for ann in coco['annotations']:
                img_id = ann['image_id']
                caption = ann['caption']
                filename = img_map.get(img_id)
                if filename:
                    img_path = Path(data_dir) / "images" / filename
                    # Only add samples whose image actually exists on disk.
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": caption,
                        })
                if max_samples and len(self.samples) >= max_samples:
                    break

            print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        """Number of (image, question, answer) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one sample and tokenize it for the VQA model."""
        item = self.samples[idx]

        # Load the image; on any I/O / decode failure (missing file, corrupt
        # JPEG — PIL's UnidentifiedImageError is an OSError subclass) fall
        # back to a blank image rather than crashing the training loop.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except OSError:
            image = Image.new('RGB', (224, 224))

        question = item['question']
        answer = item['answer']

        # Encode the question + image for the VQA model.
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )
        # Tokenized answer serves as the training labels.
        labels = self.processor(
            text=answer,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).input_ids

        # Squeeze the singleton batch dim; DataLoader re-batches later.
        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
        }
def train():
    """Fine-tune the Oculus VQA language head on instruction-style data.

    Loads the pretrained Oculus checkpoint, attaches its language model,
    runs a supervised loop over the COCO-derived InstructionDataset, and
    saves the tuned VQA weights + processor to a new checkpoint directory.
    """
    # Device priority: Apple MPS > CUDA > CPU (MPS check overrides CUDA).
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load the pretrained Oculus checkpoint and attach its language model.
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    oculus.load_language_model(device=device)

    # Only the VQA component is fine-tuned here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()
    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)

    # 5000 real samples (or the synthetic fallback) for instruction tuning.
    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4
    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in progress:
            # Move every tensor in the batch onto the training device.
            batch = {name: tensor.to(device) for name, tensor in batch.items()}

            # Forward + backward + parameter update.
            loss = vqa_model(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

    # Persist the tuned weights and matching processor side by side.
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    vqa_model.save_pretrained(output_dir / "vqa_model")
    oculus.lm_vqa_processor.save_pretrained(output_dir / "vqa_model")
    print("✅ Instruction Tuning Complete!")
# Script entry point: run the instruction-tuning loop when executed directly.
if __name__ == "__main__":
    train()