File size: 5,883 Bytes
3712bfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
from PIL import Image
import json
from pathlib import Path
from tqdm import tqdm
import requests
from io import BytesIO

# Import Oculus
import sys
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration

class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Builds (image, instruction, answer) samples from COCO captions: each
    caption becomes the answer to a randomly chosen generic prompt.  When
    the COCO annotation file is missing, a small synthetic sample list is
    used instead so the pipeline can still be smoke-tested offline.

    Expected annotation JSON (COCO captions format):
    {"images": [{"id": ..., "file_name": ...}, ...],
     "annotations": [{"image_id": ..., "caption": ...}, ...]}
    """
    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        """
        Args:
            processor: HF-style processor used by __getitem__ to encode
                image + text (only stored here, not called at load time).
            data_dir: Root directory holding `annotations/` and `images/`.
            max_samples: Optional cap on the number of samples loaded
                (now also applied to the synthetic fallback).
        """
        import random  # load-time only; kept local to the constructor

        self.processor = processor
        self.samples = []

        # Load COCO Captions
        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # Bug fix: fallback previously used keys "q"/"a", which made
            # __getitem__ (which reads "question"/"answer") raise KeyError.
            fallback = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
            self.samples = fallback[:max_samples] if max_samples else fallback
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map image_id -> filename for annotation lookup.
            img_map = {img['id']: img['file_name'] for img in coco['images']}

            # Pool of generic instruction prompts paired with captions.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene.",
            ]

            for ann in coco['annotations']:
                # Check the cap before any further work on this annotation.
                if max_samples and len(self.samples) >= max_samples:
                    break

                filename = img_map.get(ann['image_id'])
                if not filename:
                    continue

                img_path = Path(data_dir) / "images" / filename
                # Only keep annotations whose image file was downloaded.
                if img_path.exists():
                    self.samples.append({
                        "image_path": str(img_path),
                        "question": random.choice(prompts),
                        "answer": ann['caption'],
                    })

        print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of tensors for one (image, question, answer) sample.

        Keys: pixel_values, input_ids, attention_mask, labels — each with
        the leading batch dimension squeezed off so DataLoader can collate.
        """
        item = self.samples[idx]

        # Load image; substitute a blank image on a corrupt/missing file
        # rather than aborting the whole epoch.
        # Bug fix: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except Exception:
            image = Image.new('RGB', (224, 224))

        question = item['question']
        answer = item['answer']

        # Encode image + question for the VQA model.
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        # Target caption tokens serve as labels.
        labels = self.processor(
            text=answer,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        ).input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }

def train():
    """Instruction-tune the VQA component of the Oculus model.

    Loads a pretrained Oculus checkpoint, fine-tunes its VQA language model
    on COCO caption instruction data, and writes the tuned weights plus the
    processor to `checkpoints/oculus_instruct_v1/vqa_model`.
    """
    # Device priority: MPS > CUDA > CPU (MPS deliberately wins when present).
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load the pretrained Oculus checkpoint.
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)

    # Make sure the VQA language sub-model is materialized on the device.
    oculus.load_language_model(device=device)

    # Only the VQA component is fine-tuned here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()

    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)

    # 5000 real samples for instruction tuning, mini-batches of 4.
    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4

    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            on_device = {name: tensor.to(device) for name, tensor in batch.items()}

            # Forward + backward + parameter update.
            loss = vqa_model(**on_device).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

    # Persist the fine-tuned weights and matching processor side by side.
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    save_target = output_dir / "vqa_model"
    vqa_model.save_pretrained(save_target)
    oculus.lm_vqa_processor.save_pretrained(save_target)

    print("✅ Instruction Tuning Complete!")

# Entry point: run instruction tuning when executed as a script.
if __name__ == "__main__":
    train()