kobiakor15 committed on
Commit
989f87b
·
verified ·
1 Parent(s): 3ab6ebf

Upload training/train_detection_extended.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_detection_extended.py +485 -0
training/train_detection_extended.py ADDED
@@ -0,0 +1,485 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ OCULUS Extended Detection Training
4
+
5
+ Longer training with more data for better detection accuracy.
6
+ """
7
+
8
+ import os
9
+ import sys
10
+ import json
11
+ import time
12
+ import random
13
+ from pathlib import Path
14
+ from dataclasses import dataclass
15
+ from typing import List, Dict, Tuple, Optional
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.utils.data import Dataset, DataLoader
22
+ from PIL import Image
23
+
24
+ OCULUS_ROOT = Path(__file__).parent
25
+ sys.path.insert(0, str(OCULUS_ROOT))
26
+
27
+ from oculus_unified_model import OculusForConditionalGeneration, OculusConfig
28
+
29
+
30
@dataclass
class ExtendedTrainingConfig:
    """Configuration for the extended detection training run.

    Field names, ordering, and defaults are the public interface consumed by
    `ExtendedTrainer` and `main()`; do not reorder or rename them.
    """

    # --- Dataset locations ---
    data_dir: str = "data/coco"                                     # COCO root directory
    annotations_file: str = "annotations/instances_train2017.json"  # relative to data_dir
    images_subdir: str = "images"                                   # relative to data_dir

    # --- Optimization schedule (the "extended" run: more data, more steps) ---
    batch_size: int = 1
    learning_rate: float = 3e-4
    num_epochs: int = 5
    warmup_steps: int = 200
    max_samples: int = 8000  # cap on the number of indexed training images

    # --- Checkpoint to fine-tune from ---
    checkpoint_path: str = "checkpoints/oculus_detection/final"

    # --- Checkpoint output ---
    save_every: int = 500  # optimizer steps between periodic saves
    checkpoint_dir: str = "checkpoints/oculus_detection_v2"

    # --- Console logging cadence ---
    log_every: int = 50  # optimizer steps between progress lines
54
+
55
+
56
class COCODetectionDataset:
    """COCO Detection dataset.

    Indexes COCO-format instance annotations into per-image samples. Each
    sample is a dict with:
      - 'image_path': absolute/relative path string to the image file
      - 'boxes':  list of [x1, y1, x2, y2] normalized to [0, 1]
      - 'labels': list of contiguous category indices (0..79 for full COCO)
      - 'width'/'height': original image size in pixels

    Images are only checked for existence here; decoding happens at train time.
    """

    # Reference list of the 80 COCO class names, ordered by the contiguous
    # label indices produced below (unused internally; exposed for consumers).
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    def __init__(self, data_dir: str, annotations_file: str, images_subdir: str,
                 max_samples: Optional[int] = None):
        """Build the sample index.

        Args:
            data_dir: Root directory of the COCO data.
            annotations_file: Instances JSON path, relative to ``data_dir``.
            images_subdir: Image directory name under ``data_dir``.
            max_samples: Stop indexing after this many images (None = no cap).

        Raises:
            FileNotFoundError / json.JSONDecodeError: if the annotations file
            is missing or malformed.
        """
        self.data_dir = Path(data_dir)
        self.images_dir = self.data_dir / images_subdir

        annotations_path = self.data_dir / annotations_file
        print(f" Loading annotations from {annotations_path}...")

        with open(annotations_path) as f:
            coco_data = json.load(f)

        # COCO category ids are sparse (1..90 with gaps); remap them to
        # contiguous indices in the file's category order.
        self.cat_id_to_idx = {cat['id']: i for i, cat in enumerate(coco_data['categories'])}

        # Group annotations by image id so we make a single pass over images.
        img_to_anns: Dict[int, list] = {}
        for ann in coco_data['annotations']:
            img_to_anns.setdefault(ann['image_id'], []).append(ann)

        self.samples = []
        for img_info in coco_data['images']:
            img_id = img_info['id']
            if img_id not in img_to_anns:
                continue  # image has no annotations at all

            img_path = self.images_dir / img_info['file_name']
            if not img_path.exists():
                continue  # partial downloads are tolerated

            boxes = []
            labels = []
            for ann in img_to_anns[img_id]:
                # Skip crowd regions and annotations without a box.
                if 'bbox' not in ann or ann.get('iscrowd', 0):
                    continue

                # COCO bboxes are [x, y, w, h] in pixels; convert to
                # normalized [x1, y1, x2, y2] and clamp to the image.
                x, y, w, h = ann['bbox']
                x1 = x / img_info['width']
                y1 = y / img_info['height']
                x2 = (x + w) / img_info['width']
                y2 = (y + h) / img_info['height']

                x1, y1, x2, y2 = max(0, x1), max(0, y1), min(1, x2), min(1, y2)

                # Drop degenerate boxes (zero width/height after clamping):
                # they make meaningless IoU/regression targets downstream.
                if x2 <= x1 or y2 <= y1:
                    continue

                boxes.append([x1, y1, x2, y2])
                labels.append(self.cat_id_to_idx[ann['category_id']])

            # Only keep images that still have at least one usable box.
            if boxes:
                self.samples.append({
                    'image_path': str(img_path),
                    'boxes': boxes,
                    'labels': labels,
                    'width': img_info['width'],
                    'height': img_info['height']
                })

            if max_samples and len(self.samples) >= max_samples:
                break

        print(f" Loaded {len(self.samples):,} images with detections")

    def __len__(self):
        """Number of indexed images (not annotations)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the raw sample dict for image ``idx``."""
        return self.samples[idx]
142
+
143
+
144
class ExtendedTrainer:
    """Extended trainer with better loss functions.

    Fine-tunes only the detection and point heads (plus an on-demand vision
    adapter) on top of a frozen vision encoder and projector. Trains one
    image at a time (batch_size is effectively 1).

    Fixes vs. the previous version:
      - the per-epoch IoU-loss accumulator read the key 'giou_loss', but
        compute_loss() emits 'iou_loss', so the epoch stat was always 0;
      - train_step no longer swallows sample failures silently;
      - the LR schedule guards against division by zero for warmup_steps=0
        or tiny datasets.
    """

    def __init__(self, config: ExtendedTrainingConfig):
        self.config = config

        print("\n" + "=" * 60)
        print("🎯 OCULUS EXTENDED DETECTION TRAINER")
        print("=" * 60)

        self._load_model()
        self._load_dataset()
        self._create_optimizer()

        self.checkpoint_dir = Path(config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    def _load_model(self):
        """Load model with trained projector and heads.

        Prefers resuming from an existing V2 checkpoint so repeated runs
        continue training rather than restarting from the V1 heads.
        """
        print("\n[Loading Model]")

        # Try to resume from V2 checkpoint first
        v2_checkpoint = Path("checkpoints/oculus_detection_v2/final")
        if v2_checkpoint.exists():
            print(f" ✨ Resuming from V2 checkpoint: {v2_checkpoint}")
            checkpoint_path = v2_checkpoint
        else:
            checkpoint_path = OCULUS_ROOT / self.config.checkpoint_path

        self.model = OculusForConditionalGeneration.from_pretrained(checkpoint_path)

        # Load existing detection heads if present (fresh heads otherwise).
        heads_path = checkpoint_path / "heads.pth"
        if heads_path.exists():
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources.
            heads = torch.load(heads_path)
            self.model.detection_head.load_state_dict(heads['detection'])
            self.model.point_head.load_state_dict(heads['point'])
            print(" ✓ Loaded pre-trained detection heads")

        # Load vision encoders
        self.model.vision_encoder.load_encoders()

        # Freeze vision encoder and projector — only the heads learn.
        for param in self.model.vision_encoder.parameters():
            param.requires_grad = False
        for param in self.model.projector.parameters():
            param.requires_grad = False

        # Detection heads are trainable
        for param in self.model.detection_head.parameters():
            param.requires_grad = True
        for param in self.model.point_head.parameters():
            param.requires_grad = True

        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        print(f" ✓ Trainable: {trainable:,} / {total:,} parameters")

    def _load_dataset(self):
        """Load COCO detection dataset."""
        print("\n[Loading Dataset]")
        self.dataset = COCODetectionDataset(
            self.config.data_dir,
            self.config.annotations_file,
            self.config.images_subdir,
            max_samples=self.config.max_samples
        )

    def _create_optimizer(self):
        """Create the AdamW optimizer and a warmup + linear-decay scheduler."""
        print("\n[Optimizer]")

        params = list(self.model.detection_head.parameters()) + \
                 list(self.model.point_head.parameters())

        self.optimizer = torch.optim.AdamW(params, lr=self.config.learning_rate, weight_decay=0.01)

        # One optimizer step per sample, so total steps = epochs * dataset size.
        total_steps = self.config.num_epochs * len(self.dataset)
        # Guard against division by zero for warmup_steps=0 or tiny datasets.
        warmup_steps = max(1, self.config.warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)

        def lr_lambda(step):
            # Linear warmup, then linear decay with a 10% LR floor.
            if step < warmup_steps:
                return step / warmup_steps
            return max(0.1, 1.0 - (step - warmup_steps) / decay_steps)

        self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

        print(f" ✓ AdamW (lr={self.config.learning_rate}) + scheduler")

    def _compute_iou(self, box1: torch.Tensor, box2: torch.Tensor) -> torch.Tensor:
        """Compute IoU between two boxes [x1, y1, x2, y2], clamped to [0, 1]."""
        # Intersection rectangle corners.
        x1 = torch.max(box1[0], box2[0])
        y1 = torch.max(box1[1], box2[1])
        x2 = torch.min(box1[2], box2[2])
        y2 = torch.min(box1[3], box2[3])

        inter_w = torch.clamp(x2 - x1, min=0)
        inter_h = torch.clamp(y2 - y1, min=0)
        inter_area = inter_w * inter_h

        # Clamp areas so degenerate/inverted boxes cannot produce a zero or
        # negative union (which would blow up the division below).
        area1 = torch.clamp((box1[2] - box1[0]) * (box1[3] - box1[1]), min=1e-8)
        area2 = torch.clamp((box2[2] - box2[0]) * (box2[3] - box2[1]), min=1e-8)

        union_area = area1 + area2 - inter_area + 1e-8
        iou = inter_area / union_area

        return torch.clamp(iou, min=0.0, max=1.0)

    def compute_loss(
        self,
        vision_tokens: torch.Tensor,
        target_boxes: List[List[float]],
        target_labels: List[int]
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute detection loss with IoU and classification.

        Each ground-truth box is greedily matched (by IoU, no-grad) to its
        best-overlapping predicted token; per-match we combine cross-entropy,
        Smooth-L1 box regression, and (1 - IoU).

        Returns:
            (total_loss, metrics) where metrics holds the averaged
            'cls_loss', 'box_loss', 'iou_loss' floats and 'num_matches'.
        """
        cls_logits, box_preds = self.model.detection_head(vision_tokens)

        num_tokens = vision_tokens.shape[1]

        # requires_grad=True keeps the totals differentiable even before any
        # match is accumulated (backward on the bare leaf is then a no-op).
        total_cls_loss = torch.tensor(0.0, requires_grad=True)
        total_box_loss = torch.tensor(0.0, requires_grad=True)
        total_iou_loss = torch.tensor(0.0, requires_grad=True)
        num_matches = 0

        for gt_idx, (gt_box, gt_label) in enumerate(zip(target_boxes, target_labels)):
            gt_box_t = torch.tensor(gt_box, dtype=torch.float32)
            gt_label_t = torch.tensor([gt_label], dtype=torch.long)

            pred_boxes = box_preds[0]  # [num_tokens, 4]

            # Find best matching prediction using IoU (matching itself must
            # not contribute gradients, hence no_grad).
            with torch.no_grad():
                ious = []
                for j in range(num_tokens):
                    iou = self._compute_iou(pred_boxes[j], gt_box_t)
                    ious.append(float(iou.detach()))
                best_idx = int(np.argmax(ious))

            # Classification loss
            cls_loss = F.cross_entropy(
                cls_logits[0, best_idx:best_idx+1],
                gt_label_t,
                label_smoothing=0.1
            )

            # Box regression loss (Smooth L1)
            box_loss = F.smooth_l1_loss(pred_boxes[best_idx], gt_box_t)

            # IoU loss (1 - IoU)
            iou = self._compute_iou(pred_boxes[best_idx], gt_box_t)
            iou_loss = 1.0 - iou

            total_cls_loss = total_cls_loss + cls_loss
            total_box_loss = total_box_loss + box_loss
            total_iou_loss = total_iou_loss + iou_loss
            num_matches += 1

        if num_matches > 0:
            total_cls_loss = total_cls_loss / num_matches
            total_box_loss = total_box_loss / num_matches
            total_iou_loss = total_iou_loss / num_matches

        # Combined loss: box regression weighted hardest, then IoU, then cls.
        total_loss = total_cls_loss + 5.0 * total_box_loss + 2.0 * total_iou_loss

        return total_loss, {
            'cls_loss': float(total_cls_loss.detach()),
            'box_loss': float(total_box_loss.detach()),
            'iou_loss': float(total_iou_loss.detach()),
            'num_matches': num_matches
        }

    def train_step(self, sample: Dict) -> Tuple[float, Dict]:
        """Single training step on one dataset sample.

        Returns (loss, metrics); (0.0, {}) signals a skipped/failed sample.
        """
        self.optimizer.zero_grad()

        try:
            image = Image.open(sample['image_path']).convert('RGB')

            # Frozen encoder: no gradients needed through feature extraction.
            with torch.no_grad():
                vision_features = self.model.vision_encoder(image)

            actual_dim = vision_features.shape[-1]
            expected_dim = self.model.config.fused_vision_dim

            # Lazily create a linear adapter if the encoder output width does
            # not match the projector's input, and register its parameters
            # with the optimizer so the adapter is trained too.
            if actual_dim != expected_dim:
                if self.model.vision_adapter is None:
                    self.model.vision_adapter = nn.Linear(actual_dim, expected_dim)
                    nn.init.xavier_uniform_(self.model.vision_adapter.weight)
                    nn.init.zeros_(self.model.vision_adapter.bias)
                    self.optimizer.add_param_group({
                        'params': self.model.vision_adapter.parameters()
                    })

                vision_features = self.model.vision_adapter(vision_features)

            vision_tokens = self.model.projector(vision_features)

            loss, metrics = self.compute_loss(
                vision_tokens,
                sample['boxes'],
                sample['labels']
            )

            if loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.scheduler.step()

            return float(loss.detach()), metrics

        except Exception as e:
            # Best-effort: skip unreadable/corrupt samples, but surface the
            # failure instead of swallowing it silently.
            print(f" ⚠️ Skipping sample {sample.get('image_path', '?')}: {e}")
            return 0.0, {}

    def save_checkpoint(self, step: int, loss: float, is_final: bool = False):
        """Save heads (+ adapter) and copy the projector/config alongside.

        Args:
            step: Global step number (used for the checkpoint dir name).
            loss: Loss recorded into state.json for bookkeeping.
            is_final: Write to "final" instead of a step-numbered directory.
        """
        if is_final:
            checkpoint_path = self.checkpoint_dir / "final"
        else:
            checkpoint_path = self.checkpoint_dir / f"step_{step:06d}"

        checkpoint_path.mkdir(parents=True, exist_ok=True)

        torch.save({
            'detection': self.model.detection_head.state_dict(),
            'point': self.model.point_head.state_dict(),
            'adapter': self.model.vision_adapter.state_dict() if self.model.vision_adapter else None,
        }, checkpoint_path / "heads.pth")

        # Copy the frozen projector weights and model config from the source
        # checkpoint so each saved directory is self-contained.
        import shutil
        src_projector = OCULUS_ROOT / self.config.checkpoint_path / "projector.npz"
        src_config = OCULUS_ROOT / self.config.checkpoint_path / "config.json"
        if src_projector.exists():
            shutil.copy(src_projector, checkpoint_path / "projector.npz")
        if src_config.exists():
            shutil.copy(src_config, checkpoint_path / "config.json")

        state = {'step': step, 'loss': loss}
        with open(checkpoint_path / "state.json", "w") as f:
            json.dump(state, f, indent=2)

        print(f" 💾 Checkpoint: {checkpoint_path}")

    def train(self):
        """Main training loop.

        Shuffles sample indices each epoch, steps per sample, logs/saves on
        the configured cadences, and writes a final checkpoint at the end.

        Returns:
            Path to the final checkpoint directory.
        """
        print("\n" + "=" * 60)
        print("🚀 STARTING EXTENDED TRAINING")
        print("=" * 60)
        print(f" Dataset: {len(self.dataset):,} samples")
        print(f" Epochs: {self.config.num_epochs}")
        print(f" Learning rate: {self.config.learning_rate}")

        global_step = 0
        best_loss = float('inf')  # tracked for bookkeeping at save points
        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            print(f"\n📚 Epoch {epoch + 1}/{self.config.num_epochs}")
            print("-" * 40)

            indices = list(range(len(self.dataset)))
            random.shuffle(indices)

            epoch_loss = 0
            epoch_cls = 0
            epoch_box = 0
            epoch_iou = 0
            num_batches = 0

            for i, idx in enumerate(indices):
                sample = self.dataset[idx]

                loss, metrics = self.train_step(sample)

                # train_step returns exactly 0.0 for skipped samples.
                if loss == 0:
                    continue

                epoch_loss += loss
                epoch_cls += metrics.get('cls_loss', 0)
                epoch_box += metrics.get('box_loss', 0)
                # BUGFIX: compute_loss emits 'iou_loss' (the old code read
                # 'giou_loss', so this stat was always 0).
                epoch_iou += metrics.get('iou_loss', 0)
                num_batches += 1
                global_step += 1

                if global_step % self.config.log_every == 0:
                    elapsed = time.time() - start_time
                    avg_loss = epoch_loss / num_batches
                    lr = self.scheduler.get_last_lr()[0]
                    print(f" Step {global_step:5d} | Loss: {loss:.4f} | Avg: {avg_loss:.4f} | "
                          f"Cls: {metrics.get('cls_loss', 0):.3f} | Box: {metrics.get('box_loss', 0):.3f} | "
                          f"IoU: {metrics.get('iou_loss', 0):.3f} | LR: {lr:.6f} | {elapsed:.0f}s")

                if global_step % self.config.save_every == 0:
                    self.save_checkpoint(global_step, loss)
                    if loss < best_loss:
                        best_loss = loss

            avg_epoch_loss = epoch_loss / max(num_batches, 1)
            print(f"\n ✓ Epoch {epoch + 1} | Avg: {avg_epoch_loss:.4f} | "
                  f"Cls: {epoch_cls/max(num_batches,1):.3f} | "
                  f"Box: {epoch_box/max(num_batches,1):.3f} | "
                  f"IoU: {epoch_iou/max(num_batches,1):.3f}")

        print("\n" + "=" * 60)
        print("💾 Saving Final Model")
        print("=" * 60)

        self.save_checkpoint(global_step, avg_epoch_loss, is_final=True)

        print(f"✅ Training complete! Model: {self.checkpoint_dir / 'final'}")
        return self.checkpoint_dir / "final"
460
+
461
+
462
def main():
    """Entry point: run extended detection training, then the benchmarks."""
    config = ExtendedTrainingConfig(
        data_dir="data/coco",
        max_samples=5000,  # smaller than the 8000 default for this run
        num_epochs=4,
        learning_rate=3e-4,
        save_every=500,
        log_every=50,
    )

    trainer = ExtendedTrainer(config)
    model_path = trainer.train()

    # Run benchmarks after training
    print("\n" + "=" * 60)
    print("📊 RUNNING BENCHMARKS")
    print("=" * 60)

    # Benchmarks are best-effort: a missing eval harness must not crash the
    # process after the (expensive) training run has already completed.
    try:
        from eval_benchmarks import run_benchmarks
    except ImportError as e:
        print(f"⚠️ Skipping benchmarks (eval_benchmarks unavailable: {e})")
        return

    run_benchmarks(str(model_path), benchmarks=['coco', 'counting', 'vqa'])
482
+
483
+
484
# Script entry point: guard so that importing this module for its classes
# does not trigger a training run.
if __name__ == "__main__":
    main()