# Instruction-tuning training script for the Oculus VQA model.
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from PIL import Image
import json
from pathlib import Path
from tqdm import tqdm
import requests
from io import BytesIO
# Import Oculus
import sys
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration
class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Loads COCO Captions from a JSON annotation file and converts each caption
    into a (question, answer) instruction pair, pairing a random prompt with
    the human-written caption. Expected JSON format:
    [{'image': 'path/to/img', 'conversations': [{'from': 'human', 'value': '...'}, {'from': 'gpt', 'value': '...'}]}]

    If the annotation file is missing, a small synthetic sample list is used
    so the training loop can still be smoke-tested.
    """

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: HuggingFace image/text processor used in __getitem__.
            data_dir: Root directory containing 'annotations/' and 'images/'.
            max_samples: Optional cap on the number of samples loaded.
        """
        import random  # local import: only needed while building the sample list

        self.processor = processor
        self.samples = []

        # Load COCO Captions
        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # FIX: keys must match what __getitem__ reads ('question'/'answer');
            # the previous 'q'/'a' keys raised KeyError at fetch time.
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
            # Honor max_samples on the fallback path too, for consistency
            # with the real-data path below.
            if max_samples:
                self.samples = self.samples[:max_samples]
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)
            # Map image_id -> filename so captions can be joined to image files.
            img_map = {img['id']: img['file_name'] for img in coco['images']}
            # Prompt pool: each caption is paired with a random instruction.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene."
            ]
            for ann in coco['annotations']:
                filename = img_map.get(ann['image_id'])
                if filename:
                    img_path = Path(data_dir) / "images" / filename
                    # Only add samples whose image file actually exists on disk.
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": ann['caption'],
                        })
                        if max_samples and len(self.samples) >= max_samples:
                            break
            print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of tensors (pixel_values, input_ids, attention_mask, labels)."""
        item = self.samples[idx]
        # Load image; on a missing/corrupt file substitute a blank image
        # rather than crashing mid-epoch. PIL raises OSError subclasses
        # (FileNotFoundError, UnidentifiedImageError) here.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except OSError:
            image = Image.new('RGB', (224, 224))
        question = item['question']
        answer = item['answer']
        # Encode image + question for the VQA model.
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )
        # NOTE(review): pad tokens in labels are not masked to -100, so the
        # loss also trains on padding — confirm whether this is intended.
        labels = self.processor(text=answer, padding="max_length", truncation=True, max_length=32, return_tensors="pt").input_ids
        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
def train():
    """Fine-tune the VQA component of Oculus on instruction data.

    Side effects: loads a checkpoint from disk, trains for a fixed number of
    epochs over COCO-derived instruction samples, and saves the tuned VQA
    model and its processor under checkpoints/oculus_instruct_v1/vqa_model.
    """
    # Device selection: prefer CUDA, then Apple MPS, else CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")
    # Load Model
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)
    # Check if VQA model is loaded
    # NOTE(review): presumably this also moves the language model to `device`
    # — confirm against oculus_unified_model.
    oculus.load_language_model(device=device)
    # We fine-tune the VQA component specifically
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()
    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)
    # Dataset - Use 5000 real samples for instruction tuning
    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4
    for epoch in range(epochs):
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch in pbar:
            # Move every tensor in the batch to the training device.
            batch = {k: v.to(device) for k, v in batch.items()}
            # Forward pass
            outputs = vqa_model(**batch)
            loss = outputs.loss
            # Backward; zeroing after step means no gradient accumulation
            # across batches.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")
    # Save finetuned weights
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    vqa_model.save_pretrained(output_dir / "vqa_model")
    oculus.lm_vqa_processor.save_pretrained(output_dir / "vqa_model")
    print("✅ Instruction Tuning Complete!")


if __name__ == "__main__":
    train()