import json
import random
import sys
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from PIL import Image
from tqdm import tqdm

# Import Oculus
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration


class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Builds (image, question, answer) samples from COCO Captions
    (annotations/captions_train2017.json), pairing each caption with a
    randomly chosen instruction prompt. Falls back to a small synthetic
    sample list if the annotation file is missing.
    """

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        self.processor = processor
        self.samples = []

        # Load COCO Captions
        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"āš ļø COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # ... (synthetic fallback code from before could go here, or just empty)
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map image_id to filename
            img_map = {img["id"]: img["file_name"] for img in coco["images"]}

            # Prompt pool: each caption is paired with a random instruction
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene.",
            ]

            # Create samples
            for ann in coco["annotations"]:
                img_id = ann["image_id"]
                caption = ann["caption"]
                filename = img_map.get(img_id)
                if filename:
                    img_path = Path(data_dir) / "images" / filename
                    # Only add if the image file actually exists on disk
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": caption,
                        })
                        if max_samples and len(self.samples) >= max_samples:
                            break

            print(f"āœ… Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]

        # Load image; fall back to a blank placeholder on read errors
        try:
            image = Image.open(item["image_path"]).convert("RGB")
        except (OSError, FileNotFoundError):
            image = Image.new("RGB", (224, 224))

        question = item["question"]
        answer = item["answer"]

        # Format for the VQA model
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        )
        labels = self.processor(
            text=answer,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).input_ids
        # Mask padding so the loss ignores pad positions (assumes the
        # processor exposes a standard HF tokenizer with pad_token_id)
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
        }


def train():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    # Load Model
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)

    # Ensure the VQA language model is loaded
    oculus.load_language_model(device=device)

    # We fine-tune the VQA component specifically
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()

    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)

    # Dataset - use 5000 real samples for instruction tuning
    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    epochs = 4
    # Linear LR decay over all training steps
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=epochs * len(dataloader),
    )

    print("\nšŸš€ Starting Instruction Tuning (Reasoning Module)...")
    for epoch in range(epochs):
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass
            outputs = vqa_model(**batch)
            loss = outputs.loss

            # Backward
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

    # Save finetuned weights
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nšŸ’¾ Saving tuned VQA model to {output_dir}")
    vqa_model.save_pretrained(output_dir / "vqa_model")
    oculus.lm_vqa_processor.save_pretrained(output_dir / "vqa_model")
    print("āœ… Instruction Tuning Complete!")


if __name__ == "__main__":
    train()
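
# --- Optional sanity check (hedged sketch) ---------------------------------
# A minimal way to smoke-test the saved checkpoint, assuming the VQA
# component is a standard Hugging Face vision-to-text model (e.g., BLIP-style)
# that the Auto classes can resolve; if Oculus wraps a custom class, load it
# with that class instead. The image path below is just an example file.
#
#   from transformers import AutoProcessor, AutoModelForVision2Seq
#   from PIL import Image
#
#   ckpt = "checkpoints/oculus_instruct_v1/vqa_model"
#   processor = AutoProcessor.from_pretrained(ckpt)
#   model = AutoModelForVision2Seq.from_pretrained(ckpt)
#   image = Image.open("data/coco/images/000000071345.jpg").convert("RGB")
#   inputs = processor(images=image, text="Describe this image.", return_tensors="pt")
#   out = model.generate(**inputs, max_new_tokens=32)
#   print(processor.decode(out[0], skip_special_tokens=True))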