kobiakor15 commited on
Commit
3712bfe
·
verified ·
1 Parent(s): 989f87b

Upload training/train_instruction_tuning.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_instruction_tuning.py +171 -0
training/train_instruction_tuning.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import torch
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torch.optim import AdamW
6
+ from transformers import get_scheduler
7
+ from PIL import Image
8
+ import json
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import requests
12
+ from io import BytesIO
13
+
14
+ # Import Oculus
15
+ import sys
16
+ sys.path.insert(0, str(Path(__file__).parent))
17
+ from oculus_unified_model import OculusForConditionalGeneration
18
+
19
class InstructionDataset(Dataset):
    """
    Dataset for Visual Instruction Tuning.

    Converts COCO Captions into instruction-following samples: each image is
    paired with a randomly chosen prompt (the "question") and its human
    caption (the "answer").

    Expected layout:
        <data_dir>/annotations/captions_train2017.json
        <data_dir>/images/<file_name>

    Args:
        processor: multimodal HF-style processor called as
            ``processor(images=..., text=..., ...)``; used to encode both the
            question (with image) and the answer (labels).
        data_dir: root directory of the COCO dataset.
        max_samples: optional cap on the number of samples loaded
            (applies to both the real and the synthetic fallback data).
    """

    def __init__(self, processor, data_dir="data/coco", max_samples=None):
        import random  # local import, consistent with the original file layout

        self.processor = processor
        self.samples = []

        # Load COCO Captions
        ann_file = Path(data_dir) / "annotations" / "captions_train2017.json"
        if not ann_file.exists():
            print(f"⚠️ COCO Captions not found at {ann_file}. Using synthetic fallback.")
            # BUGFIX: the fallback previously used keys "q"/"a", but
            # __getitem__ reads "question"/"answer" — every fallback sample
            # raised KeyError. Keys now match __getitem__.
            self.samples = [
                {
                    "image_path": "data/coco/images/000000071345.jpg",
                    "question": "Describe this.",
                    "answer": "A car parked on the street.",
                }
            ] * 100
            # Honor max_samples on the fallback path too (previously ignored).
            if max_samples:
                self.samples = self.samples[:max_samples]
        else:
            print(f"Loading real instruction data from {ann_file}...")
            with open(ann_file) as f:
                coco = json.load(f)

            # Map image_id -> filename so annotations can be resolved to files.
            img_map = {img['id']: img['file_name'] for img in coco['images']}

            # Pool of instruction prompts; one is sampled per caption so the
            # model sees varied phrasings of the same task.
            prompts = [
                "Describe this image.",
                "What is going on here?",
                "Write a caption for this photo.",
                "What do you see?",
                "Provide a detailed description.",
                "Explain the scene."
            ]

            for ann in coco['annotations']:
                img_id = ann['image_id']
                caption = ann['caption']
                filename = img_map.get(img_id)

                if filename:
                    img_path = Path(data_dir) / "images" / filename
                    # Only add samples whose image file actually exists on disk.
                    if img_path.exists():
                        self.samples.append({
                            "image_path": str(img_path),
                            "question": random.choice(prompts),
                            "answer": caption
                        })

                if max_samples and len(self.samples) >= max_samples:
                    break

            print(f"✅ Loaded {len(self.samples)} instruction samples from COCO")

    def __len__(self):
        """Return the number of (image, question, answer) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Return one encoded training example.

        Returns a dict with ``pixel_values``, ``input_ids``,
        ``attention_mask`` (question side) and ``labels`` (tokenized answer),
        each with the batch dimension squeezed out.
        """
        item = self.samples[idx]

        # Load image; fall back to a blank RGB canvas on any decode/IO error
        # so a single corrupt file does not abort training.
        try:
            image = Image.open(item['image_path']).convert('RGB')
        except Exception:  # BUGFIX: bare except also swallowed KeyboardInterrupt
            image = Image.new('RGB', (224, 224))

        question = item['question']
        answer = item['answer']

        # Encode image + question for the VQA model.
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            truncation=True,
            max_length=32,
            return_tensors="pt"
        )

        # NOTE(review): pad tokens in `labels` are NOT masked to -100 here, so
        # padding positions contribute to the loss — confirm whether the VQA
        # model masks them internally before relying on reported loss values.
        labels = self.processor(text=answer, padding="max_length", truncation=True, max_length=32, return_tensors="pt").input_ids

        return {
            "pixel_values": encoding.pixel_values.squeeze(0),
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": labels.squeeze(0)
        }
110
+
111
def train():
    """Instruction-tune the VQA component of the Oculus model on COCO captions."""
    # Pick the best available accelerator; MPS takes precedence when present
    # (matching the original selection order).
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    # Load Model
    model_path = "checkpoints/oculus_detection_v2/final"
    print(f"Loading Oculus from {model_path}...")
    oculus = OculusForConditionalGeneration.from_pretrained(model_path)

    # Ensure the language (VQA) component is loaded onto the chosen device.
    oculus.load_language_model(device=device)

    # Only the VQA component is fine-tuned here.
    vqa_model = oculus.lm_vqa_model
    vqa_model.train()

    optimizer = AdamW(vqa_model.parameters(), lr=2e-5)

    # Dataset - Use 5000 real samples for instruction tuning
    dataset = InstructionDataset(oculus.lm_vqa_processor, max_samples=5000)
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    print("\n🚀 Starting Instruction Tuning (Reasoning Module)...")
    epochs = 4

    for epoch in range(epochs):
        running_loss = 0.0
        progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress:
            inputs = {key: tensor.to(device) for key, tensor in batch.items()}

            # Forward pass — the model computes the loss from `labels`.
            loss = vqa_model(**inputs).loss

            # Backward + parameter update.
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            progress.set_postfix(loss=loss.item())

        avg_loss = running_loss / len(dataloader)
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f}")

    # Persist the tuned weights and the matching processor side by side.
    output_dir = Path("checkpoints/oculus_instruct_v1")
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n💾 Saving tuned VQA model to {output_dir}")
    save_target = output_dir / "vqa_model"
    vqa_model.save_pretrained(save_target)
    oculus.lm_vqa_processor.save_pretrained(save_target)

    print("✅ Instruction Tuning Complete!")
169
+
170
# Script entry point: run instruction tuning when executed directly.
if __name__ == "__main__":
    train()