|
|
|
|
|
import os |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from torch.optim import AdamW |
|
|
from transformers import get_scheduler |
|
|
from PIL import Image |
|
|
import json |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
import requests |
|
|
from io import BytesIO |
|
|
|
|
|
|
|
|
import sys |
|
|
sys.path.insert(0, str(Path(__file__).parent)) |
|
|
from oculus_unified_model import OculusForConditionalGeneration |
|
|
|
|
|
class InstructionDataset(Dataset):
    """Dataset for Visual Instruction Tuning.

    Loads COCO Captions annotations (``captions_train2017.json``) and pairs
    each caption with a randomly chosen instruction prompt, yielding
    (image, question, answer) training samples.

    Falls back to a small synthetic sample list when the annotation file is
    missing, so the training loop can still be smoke-tested without data.
    """

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: HuggingFace-style processor used in ``__getitem__`` to
                tokenize text and preprocess images (must accept
                ``images=``/``text=`` keyword arguments).
            data_dir: Root directory expected to contain ``annotations/`` and
                ``images/`` subdirectories.
            max_samples: Optional cap on the number of samples loaded.
        """
        self.processor = processor
        self.samples = []

        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # BUG FIX: these entries previously used keys "q"/"a", but
            # __getitem__ reads "question"/"answer" — the mismatch raised
            # KeyError on the first fetched fallback item.
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
            # Honor max_samples on the fallback path too, matching the
            # behavior of the real-data branch below.
            if max_samples:
                self.samples = self.samples[:max_samples]
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map COCO image id -> file name so each annotation can be
            # resolved to an on-disk image path.
            img_map = {img['id']: img['file_name'] for img in coco['images']}

            # Instruction prompts sampled uniformly per caption to add
            # phrasing variety to the instruction-tuning data.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene."
            ]
            import random

            for ann in coco['annotations']:
                img_id = ann['image_id']
                caption = ann['caption']
                filename = img_map.get(img_id)

                if filename:
                    img_path = Path(data_dir) / "images" / filename

                    # Skip annotations whose image was not downloaded.
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": caption
                        })

                if max_samples and len(self.samples) >= max_samples:
                    break

        print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        """Return the number of loaded instruction samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of tensors for one sample.

        Keys: ``pixel_values``, ``input_ids``, ``attention_mask`` (question
        encoding) and ``labels`` (answer token ids), each squeezed to remove
        the processor's batch dimension.
        """
        item = self.samples[idx]

        # Fall back to a blank image on any load failure so one corrupt or
        # missing file cannot kill a long training run. Was a bare `except:`,
        # which also swallowed KeyboardInterrupt/SystemExit — narrowed.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except Exception:
            image = Image.new('RGB', (224, 224))

        question = item['question']
        answer = item['answer']

        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        # NOTE(review): pad token ids in `labels` are NOT masked to -100, so
        # the loss is also computed over padding positions — confirm this is
        # intended for the wrapped VQA model.
        labels = self.processor(text=answer, padding="max_length", truncation=True, max_length=32, return_tensors="pt").input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
|
|
|
|
|
def train(
    epochs=4,
    batch_size=4,
    lr=2e-5,
    max_samples=5000,
    model_path="checkpoints/oculus_detection_v2/final",
    output_dir="checkpoints/oculus_instruct_v1",
):
    """Instruction-tune the Oculus VQA (reasoning) module on COCO captions.

    Previously all of these settings were hard-coded; they are now keyword
    parameters whose defaults preserve the original behavior, so `train()`
    with no arguments is unchanged.

    Args:
        epochs: Number of passes over the dataset.
        batch_size: DataLoader batch size.
        lr: AdamW learning rate.
        max_samples: Cap on instruction samples loaded from COCO.
        model_path: Checkpoint directory for the pretrained Oculus model.
        output_dir: Directory where the tuned VQA model is saved.
    """
    # Prefer CUDA, then Apple MPS, falling back to CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.backends.mps.is_available():
        device = "mps"
    print(f"Using device: {device}")

    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)

    # Load the language-model half (VQA head) onto the chosen device.
    oculus.load_language_model(device=device)

    # Only the VQA/reasoning sub-model is tuned here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()

    optimizer = AdamW(vqa_model.parameters(), lr=lr)

    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=max_samples)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")

    for epoch in range(epochs):
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in pbar:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = vqa_model(**batch)
            loss = outputs.loss

            # Standard step order: clear grads, backprop, update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving tuned VQA model to {out_path}")
    vqa_model.save_pretrained(out_path / "vqa_model")
    oculus.lm_vqa_processor.save_pretrained(out_path / "vqa_model")

    print("✅ Instruction Tuning Complete!")
|
|
|
|
|
# Script entry point: run instruction tuning with default settings.
if __name__ == "__main__":


    train()
|
|
|