kobiakor15 committed on
Commit
3ab6ebf
·
verified ·
1 Parent(s): fbcbc74

Upload training/train_detection.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_detection.py +491 -0
training/train_detection.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCULUS Detection Head Training
4
+
5
+ Trains the detection (box) and point heads on COCO detection data.
6
+ Uses the frozen vision encoders + trained projector, only trains the heads.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import time
13
+ import random
14
+ from pathlib import Path
15
+ from dataclasses import dataclass
16
+ from typing import List, Dict, Tuple, Optional
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch.utils.data import Dataset, DataLoader
23
+ from PIL import Image
24
+
25
+ OCULUS_ROOT = Path(__file__).parent
26
+
27
+ # Add to path
28
+ sys.path.insert(0, str(OCULUS_ROOT))
29
+
30
+ from oculus_unified_model import OculusForConditionalGeneration, OculusConfig
31
+
32
+
33
@dataclass
class DetectionTrainingConfig:
    """Settings bundle consumed by the detection trainer.

    Every field has a default, so ``DetectionTrainingConfig()`` is valid and
    callers override individual values by keyword.
    """

    # ---- Data ------------------------------------------------------------
    # Dataset root; ``annotations_file`` and ``images_subdir`` are joined
    # onto this directory by the dataset class.
    data_dir: str = "data/coco"
    annotations_file: str = "annotations/instances_train2017.json"
    images_subdir: str = "images"

    # ---- Training --------------------------------------------------------
    # NOTE(review): ``batch_size`` and ``warmup_steps`` are not read by any
    # code visible in this file (the loop feeds one sample per step and no
    # LR scheduler is created) -- confirm whether they are still needed.
    batch_size: int = 4
    learning_rate: float = 1e-4
    num_epochs: int = 3
    warmup_steps: int = 100
    max_samples: int = 3000  # Limit for faster training

    # ---- Model -----------------------------------------------------------
    # Joined onto OCULUS_ROOT when the pretrained model is loaded.
    checkpoint_path: str = "checkpoints/oculus_coco/final"

    # ---- Checkpointing ---------------------------------------------------
    # A checkpoint is written every ``save_every`` optimisation steps.
    save_every: int = 200
    checkpoint_dir: str = "checkpoints/oculus_detection"

    # ---- Logging ---------------------------------------------------------
    # A progress line is printed every ``log_every`` optimisation steps.
    log_every: int = 25
57
+
58
+
59
class COCODetectionDataset(Dataset):
    """In-memory index of COCO detection annotations.

    The constructor reads a COCO-format annotation JSON and keeps, for every
    image that exists on disk and has at least one usable box, a dict with
    the keys ``image_path``, ``boxes``, ``labels``, ``width`` and ``height``.
    Boxes are stored as normalised ``[x1, y1, x2, y2]`` lists and labels as
    contiguous class indices (position of the category in the annotation
    file's ``categories`` list). Images themselves are not loaded here.
    """

    # COCO 80 class names (reference only; not used by the code in this class)
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    def __init__(self, data_dir: str, annotations_file: str, images_subdir: str,
                 max_samples: Optional[int] = None):
        """Build the sample index.

        Args:
            data_dir: Dataset root directory.
            annotations_file: Annotation JSON path, relative to ``data_dir``.
            images_subdir: Image directory, relative to ``data_dir``.
            max_samples: Stop indexing once this many samples were collected.
                ``None`` or ``0`` means no limit.
        """
        self.data_dir = Path(data_dir)
        self.images_dir = self.data_dir / images_subdir

        # Load annotations
        annotations_path = self.data_dir / annotations_file
        print(f" Loading annotations from {annotations_path}...")

        with open(annotations_path, encoding="utf-8") as f:
            coco_data = json.load(f)

        # COCO category ids are sparse; map them to contiguous indices in
        # file order so they can be used directly as classification targets.
        self.cat_id_to_idx = {
            cat['id']: idx for idx, cat in enumerate(coco_data['categories'])
        }

        # Group annotations by the image they belong to.
        img_to_anns: Dict[int, List[dict]] = {}
        for ann in coco_data['annotations']:
            img_to_anns.setdefault(ann['image_id'], []).append(ann)

        # Build samples list
        self.samples = []
        for img_info in coco_data['images']:
            anns = img_to_anns.get(img_info['id'])
            if not anns:
                continue

            # Skip images that are listed but not present on disk.
            img_path = self.images_dir / img_info['file_name']
            if not img_path.exists():
                continue

            width, height = img_info['width'], img_info['height']
            if width <= 0 or height <= 0:
                # Cannot normalise against an empty image.
                continue

            boxes = []
            labels = []
            for ann in anns:
                # Crowd regions and box-less annotations are not usable
                # as single-object targets.
                if 'bbox' not in ann or ann.get('iscrowd', 0):
                    continue

                label = self.cat_id_to_idx.get(ann.get('category_id'))
                if label is None:
                    # Category not declared in the file's category list.
                    continue

                box = self._normalize_bbox(ann['bbox'], width, height)
                if box is None:
                    continue

                boxes.append(box)
                labels.append(label)

            if boxes:
                self.samples.append({
                    'image_path': str(img_path),
                    'boxes': boxes,
                    'labels': labels,
                    'width': width,
                    'height': height
                })

            if max_samples and len(self.samples) >= max_samples:
                break

        print(f" Loaded {len(self.samples):,} images with detections")

    @staticmethod
    def _normalize_bbox(bbox, width, height) -> Optional[List[float]]:
        """Convert a COCO ``[x, y, w, h]`` pixel box to normalised corners.

        Returns ``[x1, y1, x2, y2]`` with every coordinate clamped to
        ``[0, 1]``, or ``None`` when the clamped box has no area (such a box
        cannot overlap any prediction and would only add noise as a target).
        """
        x, y, w, h = bbox

        def unit(value: float) -> float:
            return min(max(value, 0.0), 1.0)

        x1, y1 = unit(x / width), unit(y / height)
        x2, y2 = unit((x + w) / width), unit((y + h) / height)

        if x2 <= x1 or y2 <= y1:
            return None
        return [x1, y1, x2, y2]

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample dict stored at position ``idx``."""
        return self.samples[idx]
157
+
158
+
159
class DetectionTrainer:
    """Trainer for the detection heads of an Oculus model.

    Construction loads the pretrained model, builds the COCO sample index
    and creates an AdamW optimizer over the detection head, the point head
    and (when present) the vision adapter. ``train()`` then runs one
    optimisation step per sample.

    NOTE(review): the loss computed in this class only uses
    ``model.detection_head``; the point head is handed to the optimizer and
    saved, but nothing visible here produces gradients for it.
    """

    def __init__(self, config: "DetectionTrainingConfig"):
        """Load model, dataset and optimizer, and create the checkpoint dir."""
        self.config = config

        print("\n" + "=" * 60)
        print("🎯 OCULUS DETECTION TRAINER")
        print("=" * 60)

        self._load_model()
        self._load_dataset()
        self._create_optimizer()

        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _set_requires_grad(module, flag: bool) -> None:
        """Set ``requires_grad`` on every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = flag

    def _vision_adapter(self):
        """Return the model's vision adapter, or ``None`` if it has none."""
        return getattr(self.model, 'vision_adapter', None)

    def _load_model(self):
        """Load model with trained projector."""
        print("\n[Loading Model]")

        checkpoint_path = OCULUS_ROOT / self.config.checkpoint_path
        self.model = OculusForConditionalGeneration.from_pretrained(checkpoint_path)

        # Load vision encoders
        self.model.vision_encoder.load_encoders()

        # Freeze vision encoder and projector
        self._set_requires_grad(self.model.vision_encoder, False)
        self._set_requires_grad(self.model.projector, False)

        # Make sure detection/point heads are trainable
        self._set_requires_grad(self.model.detection_head, True)
        self._set_requires_grad(self.model.point_head, True)

        # Count trainable params.
        # NOTE(review): this counts every parameter with requires_grad=True,
        # which can include modules that are never given to the optimizer.
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f" ✓ Trainable: {trainable:,} / {total:,} parameters")

    def _load_dataset(self):
        """Load COCO detection dataset."""
        print("\n[Loading Dataset]")
        self.dataset = COCODetectionDataset(
            self.config.data_dir,
            self.config.annotations_file,
            self.config.images_subdir,
            max_samples=self.config.max_samples
        )

    def _create_optimizer(self):
        """Create optimizer for detection heads only."""
        print("\n[Optimizer]")

        # Only optimize detection heads (plus the adapter if one exists)
        params = list(self.model.detection_head.parameters())
        params += list(self.model.point_head.parameters())

        adapter = self._vision_adapter()
        if adapter is not None:
            params += list(adapter.parameters())

        self.optimizer = torch.optim.AdamW(params, lr=self.config.learning_rate, weight_decay=0.01)
        print(f" ✓ AdamW (lr={self.config.learning_rate})")

    def encode_image(self, image_path: str) -> torch.Tensor:
        """Encode image to vision tokens (no gradients are recorded)."""
        # The context manager closes the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        with torch.no_grad():
            vision_tokens = self.model.encode_image(image)

        return vision_tokens

    def compute_detection_loss(
        self,
        vision_tokens: torch.Tensor,
        target_boxes: List[List[float]],
        target_labels: List[int]
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute the detection loss for one set of ground-truth boxes.

        Each ground-truth box is matched to the predicted box with the
        highest IoU; the matched prediction receives a cross-entropy loss on
        its class logits and an L1 loss on its coordinates. The same targets
        are applied to every item along dimension 0 of ``vision_tokens``.

        Args:
            vision_tokens: Input passed to ``model.detection_head``; its
                first dimension is treated as the batch dimension.
            target_boxes: Ground-truth ``[x1, y1, x2, y2]`` boxes.
            target_labels: Class index for each ground-truth box.

        Returns:
            ``(loss, metrics)`` where ``loss`` is always a 0-dim tensor
            (``cls + 5.0 * box``, averaged over matches) and ``metrics`` has
            the keys ``cls_loss``, ``box_loss`` and ``num_matches``. When
            there is nothing to match the loss is a constant zero tensor
            that does not require grad.
        """
        # Get predictions
        cls_logits, box_preds = self.model.detection_head(vision_tokens)

        batch_size = vision_tokens.shape[0]
        device = box_preds.device

        # Mirror zip() semantics: surplus boxes or labels are ignored.
        num_targets = min(len(target_boxes), len(target_labels))
        if batch_size == 0 or num_targets == 0:
            # Return a tensor (not a Python int) so callers can uniformly
            # inspect ``loss.requires_grad``.
            return torch.zeros((), device=device), {
                'cls_loss': 0,
                'box_loss': 0,
                'num_matches': 0
            }

        # Build targets once, on the same device as the predictions.
        gt_boxes = torch.tensor(target_boxes[:num_targets], dtype=torch.float32, device=device)
        gt_labels = torch.tensor(target_labels[:num_targets], dtype=torch.long, device=device)

        total_cls_loss = torch.zeros((), device=device)
        total_box_loss = torch.zeros((), device=device)

        for i in range(batch_size):
            pred_boxes = box_preds[i]  # [num_tokens, 4]
            pred_cls = cls_logits[i]   # [num_tokens, num_classes]

            # Matching is a discrete choice, so no gradient is needed here.
            with torch.no_grad():
                # [num_targets, num_tokens] pairwise IoU via broadcasting
                ious = self._compute_iou(pred_boxes.unsqueeze(0), gt_boxes.unsqueeze(1))
                best_idx = ious.argmax(dim=1)

            # Sum of per-match losses; normalised by num_matches below.
            total_cls_loss = total_cls_loss + F.cross_entropy(
                pred_cls[best_idx], gt_labels, reduction='sum'
            )
            total_box_loss = total_box_loss + F.l1_loss(
                pred_boxes[best_idx], gt_boxes, reduction='none'
            ).mean(dim=1).sum()

        num_matches = batch_size * num_targets
        total_cls_loss = total_cls_loss / num_matches
        total_box_loss = total_box_loss / num_matches

        # Combined loss
        total_loss = total_cls_loss + 5.0 * total_box_loss  # Weight box loss higher

        return total_loss, {
            'cls_loss': float(total_cls_loss),
            'box_loss': float(total_box_loss),
            'num_matches': num_matches
        }

    def _compute_iou(self, boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor:
        """Compute IoU between two sets of ``[x1, y1, x2, y2]`` boxes.

        The last dimension must hold the four coordinates; all leading
        dimensions are broadcast against each other. Two ``[N, 4]`` inputs
        therefore give an element-wise ``[N]`` result, while ``[1, T, 4]``
        against ``[G, 1, 4]`` gives a pairwise ``[G, T]`` matrix.
        """
        x1 = torch.max(boxes1[..., 0], boxes2[..., 0])
        y1 = torch.max(boxes1[..., 1], boxes2[..., 1])
        x2 = torch.min(boxes1[..., 2], boxes2[..., 2])
        y2 = torch.min(boxes1[..., 3], boxes2[..., 3])

        inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

        area1 = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
        area2 = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

        # Epsilon avoids 0/0 when both boxes are empty.
        union_area = area1 + area2 - inter_area + 1e-8

        return inter_area / union_area

    def train_step(self, sample: Dict) -> Tuple[float, Dict]:
        """Run one optimisation step on a single dataset sample.

        Args:
            sample: Dict with ``image_path``, ``boxes`` and ``labels`` keys.

        Returns:
            ``(loss_value, metrics)``. Failures are deliberately best-effort:
            any exception is printed and reported as ``(0.0, {})`` so that a
            single bad sample does not abort the whole run.
        """
        self.optimizer.zero_grad()

        try:
            with Image.open(sample['image_path']) as img:
                image = img.convert('RGB')

            # Get vision features from frozen encoders
            with torch.no_grad():
                vision_features = self.model.vision_encoder(image)

            # Check for dimension mismatch and create adapter
            actual_dim = vision_features.shape[-1]
            expected_dim = self.model.config.fused_vision_dim

            if actual_dim != expected_dim:
                if self._vision_adapter() is None:
                    print(f" [Adapter] Creating: {actual_dim} -> {expected_dim}")
                    adapter = nn.Linear(actual_dim, expected_dim)
                    nn.init.xavier_uniform_(adapter.weight)
                    nn.init.zeros_(adapter.bias)

                    # Keep the new layer compatible with the features it
                    # will consume (device, and dtype for float inputs).
                    if isinstance(vision_features, torch.Tensor):
                        adapter = adapter.to(device=vision_features.device)
                        if vision_features.is_floating_point():
                            adapter = adapter.to(dtype=vision_features.dtype)

                    self.model.vision_adapter = adapter

                    # Add adapter params to optimizer
                    self.optimizer.add_param_group({
                        'params': list(adapter.parameters())
                    })

                vision_features = self.model.vision_adapter(vision_features)

            # Project to tokens
            vision_tokens = self.model.projector(vision_features)

            # Compute detection loss
            loss, metrics = self.compute_detection_loss(
                vision_tokens,
                sample['boxes'],
                sample['labels']
            )

            # A constant loss (nothing matched) has no graph to backprop.
            if loss.requires_grad:
                loss.backward()
                self.optimizer.step()

            return float(loss), metrics

        except Exception as e:
            print(f" ⚠️ Error: {e}")
            return 0.0, {}

    def _save_heads(self, directory: Path) -> None:
        """Write the head (and adapter, if any) weights to ``heads.pth``."""
        adapter = self._vision_adapter()
        torch.save({
            'detection': self.model.detection_head.state_dict(),
            'point': self.model.point_head.state_dict(),
            # Explicit None check: module truthiness can depend on __len__.
            'adapter': adapter.state_dict() if adapter is not None else None,
        }, directory / "heads.pth")

    def save_checkpoint(self, step: int, loss: float):
        """Save head weights and a small ``state.json`` for ``step``."""
        checkpoint_path = self.checkpoint_dir / f"step_{step:06d}"
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        # Save detection heads
        self._save_heads(checkpoint_path)

        # Save state
        state = {'step': step, 'loss': loss}
        with open(checkpoint_path / "state.json", "w") as f:
            json.dump(state, f, indent=2)

        print(f" 💾 Checkpoint: {checkpoint_path}")

    def train(self):
        """Main training loop.

        Iterates over the dataset in a fresh random order each epoch, one
        sample per step, logging and checkpointing at the configured
        intervals.

        Returns:
            Path of the ``final`` directory holding the saved heads.
        """
        print("\n" + "=" * 60)
        print("🚀 STARTING DETECTION TRAINING")
        print("=" * 60)
        print(f" Dataset: {len(self.dataset):,} samples")
        print(f" Epochs: {self.config.num_epochs}")
        print(f" Learning rate: {self.config.learning_rate}")

        global_step = 0
        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            print(f"\n📚 Epoch {epoch + 1}/{self.config.num_epochs}")
            print("-" * 40)

            # Shuffle
            indices = list(range(len(self.dataset)))
            random.shuffle(indices)

            epoch_loss = 0
            epoch_box_loss = 0
            epoch_cls_loss = 0
            num_batches = 0

            for idx in indices:
                sample = self.dataset[idx]

                loss, metrics = self.train_step(sample)

                # train_step reports 0.0 for failed or unmatched samples;
                # those do not count as optimisation steps.
                if loss == 0:
                    continue

                epoch_loss += loss
                epoch_box_loss += metrics.get('box_loss', 0)
                epoch_cls_loss += metrics.get('cls_loss', 0)
                num_batches += 1
                global_step += 1

                # Logging
                if global_step % self.config.log_every == 0:
                    elapsed = time.time() - start_time
                    avg_loss = epoch_loss / num_batches
                    print(f" Step {global_step:5d} | Loss: {loss:.4f} | "
                          f"Avg: {avg_loss:.4f} | Box: {metrics.get('box_loss', 0):.4f} | "
                          f"Cls: {metrics.get('cls_loss', 0):.4f} | {elapsed:.0f}s")

                # Checkpointing
                if global_step % self.config.save_every == 0:
                    self.save_checkpoint(global_step, loss)

            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            print(f"\n ✓ Epoch {epoch + 1} | Avg loss: {avg_epoch_loss:.4f} | "
                  f"Box: {epoch_box_loss/max(num_batches,1):.4f} | "
                  f"Cls: {epoch_cls_loss/max(num_batches,1):.4f}")

        # Final save
        print("\n" + "=" * 60)
        print("💾 Saving Final Model")
        print("=" * 60)

        final_path = self.checkpoint_dir / "final"
        final_path.mkdir(parents=True, exist_ok=True)

        # Save heads
        self._save_heads(final_path)

        # Also copy over the projector and config from the source checkpoint
        # so the final directory carries them alongside the heads.
        import shutil
        source_dir = OCULUS_ROOT / self.config.checkpoint_path
        for file_name in ("projector.npz", "config.json"):
            src = source_dir / file_name
            if src.exists():
                shutil.copy(src, final_path / file_name)

        print(f"✅ Training complete! Model: {final_path}")
        return final_path
474
+
475
+
476
def main():
    """Script entry point: run detection training with a reduced setup."""
    # Values that differ from (or are pinned against) the dataclass defaults.
    overrides = {
        "data_dir": "data/coco",
        "max_samples": 2000,  # Start smaller for faster iteration
        "num_epochs": 2,
        "learning_rate": 5e-4,
        "save_every": 200,
        "log_every": 25,
    }

    DetectionTrainer(DetectionTrainingConfig(**overrides)).train()
488
+
489
+
490
+ if __name__ == "__main__":
491
+ main()