"""Instruction tuning for the Oculus VQA module on COCO Captions."""

import json
import random
import sys
from pathlib import Path

import torch
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Make the local oculus_unified_model module importable when this script is
# run from another working directory.
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration


class InstructionDataset(Dataset):
    """
    Dataset for visual instruction tuning.

    Builds (image, question, answer) samples from COCO Captions
    (annotations/captions_train2017.json), pairing each caption with a
    randomly chosen instruction prompt. Falls back to a small synthetic
    set when the annotation file is missing.
    """
    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        self.processor = processor
        self.samples = []

        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map image ids to filenames so annotations resolve to paths.
            img_map = {img["id"]: img["file_name"] for img in coco["images"]}

            # Each caption becomes an instruction pair by sampling a prompt.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene.",
            ]

            for ann in coco["annotations"]:
                filename = img_map.get(ann["image_id"])
                if not filename:
                    continue

                img_path = Path(data_dir) / "images" / filename
                if img_path.exists():
                    self.samples.append({
                        "image_path": str(img_path),
                        "question": random.choice(prompts),
                        "answer": ann["caption"],
                    })

                if max_samples and len(self.samples) >= max_samples:
                    break

        print(f"✅ Loaded {len(self.samples)} instruction samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]

        try:
            image = Image.open(item["image_path"]).convert("RGB")
        except OSError:
            # Fall back to a blank image rather than crashing the loader.
            image = Image.new("RGB", (224, 224))

        encoding = self.processor(
            images=image,
            text=item["question"],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        )

        labels = self.processor(
            text=item["answer"],
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt",
        ).input_ids

        # Mask padding positions so they are ignored by the loss (assumes the
        # processor exposes its tokenizer, as HF multimodal processors do).
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
        }

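# A quick way to sanity-check the dataset in isolation; a minimal sketch that
# assumes a BLIP-style processor (the actual class behind `lm_vqa_processor`
# may differ):
#
#   from transformers import BlipProcessor
#   processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
#   ds = InstructionDataset(processor, max_samples=8)
#   sample = ds[0]
#   print(sample["pixel_values"].shape, sample["input_ids"].shape)
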
def train():
    # Prefer CUDA, then Apple Silicon (MPS), then CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu" and torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    # Start from the detection-tuned Oculus checkpoint and attach the
    # language model used for VQA.
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    oculus.load_language_model(device=device)

    # Only the VQA language head is tuned here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()

    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)

    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4

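    # A warmup/decay learning-rate schedule is a common addition at this
    # point; a minimal sketch using transformers.get_scheduler (not wired in
    # here):
    #
    #   from transformers import get_scheduler
    #   lr_scheduler = get_scheduler(
    #       "linear",
    #       optimizer=optimizer,
    #       num_warmup_steps=0,
    #       num_training_steps=epochs * len(dataloader),
    #   )
    #   # ...and call lr_scheduler.step() right after optimizer.step().
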
    for epoch in range(epochs):
        total_loss = 0.0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass; the model computes the loss from `labels`.
            outputs = vqa_model(**batch)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

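    # Inside the epoch loop one could also checkpoint per epoch rather than
    # only at the end; a sketch (directory layout is illustrative):
    #
    #   epoch_dir = Path("checkpoints/oculus_instruct_v1") / f"epoch_{epoch+1}"
    #   vqa_model.save_pretrained(epoch_dir)
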
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    vqa_model.save_pretrained(output_dir / "vqa_model")
    oculus.lm_vqa_processor.save_pretrained(output_dir / "vqa_model")

    print("✅ Instruction Tuning Complete!")


if __name__ == "__main__":
    train()