Zhen Ye committed
Commit 94c85d4 · Parent: 469102e

using apple depth pro hf

Files changed (4)
  1. CLAUDE.md +254 -0
  2. demo.html +114 -10
  3. models/depth_estimators/depth_pro.py +34 -25
  4. requirements.txt +0 -1
CLAUDE.md ADDED
@@ -0,0 +1,254 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Simple video object detection system with three modes:
- **Object Detection**: Detect custom objects using text queries (fully functional)
- **Segmentation**: Mask overlays using SAM3
- **Drone Detection**: (Coming Soon) Specialized UAV detection

## Core Architecture

### Simple Detection Flow

```
User → demo.html → POST /detect → inference.py → detector → processed video
```

1. User selects a mode and uploads a video via the web interface
2. Frontend sends video + mode + queries to the `/detect` endpoint
3. Backend runs detection inference with the selected model
4. Returns the processed video with bounding boxes

### Available Detectors

The system includes 4 pre-trained object detection models:

| Detector | Key | Type | Best For |
|----------|-----|------|----------|
| **OWLv2** | `owlv2_base` | Open-vocabulary | Custom text queries (default) |
| **YOLOv8** | `hf_yolov8` | COCO classes | Fast real-time detection |
| **DETR** | `detr_resnet50` | COCO classes | Transformer-based detection |
| **Grounding DINO** | `grounding_dino` | Open-vocabulary | Text-grounded detection |

All detectors implement the `ObjectDetector` interface in `models/detectors/base.py` with a single `predict()` method.

## Development Commands

### Setup
```bash
python -m venv .venv
source .venv/bin/activate  # on Windows: .venv\Scripts\activate
pip install -r requirements.txt
```

### Running the Server
```bash
# Development
uvicorn app:app --host 0.0.0.0 --port 7860 --reload

# Production (Docker)
docker build -t object_detectors .
docker run -p 7860:7860 object_detectors
```

### Testing the API
```bash
# Test object detection
curl -X POST http://localhost:7860/detect \
  -F "video=@sample.mp4" \
  -F "mode=object_detection" \
  -F "queries=person,car,dog" \
  -F "detector=owlv2_base" \
  --output processed.mp4

# Test a placeholder mode (returns JSON)
curl -X POST http://localhost:7860/detect \
  -F "video=@sample.mp4" \
  -F "mode=drone_detection"
```

## Key Implementation Details

### API Endpoint: `/detect`

**Parameters:**
- `video` (file): Video file to process
- `mode` (string): Detection mode: `object_detection`, `segmentation`, or `drone_detection`
- `queries` (string): Comma-separated object classes (for object_detection mode)
- `detector` (string): Model key (default: `owlv2_base`)

**Returns:**
- For `object_detection`: MP4 video with bounding boxes
- For `segmentation`: MP4 video with mask overlays
- For `drone_detection`: JSON with `{"status": "coming_soon", "message": "..."}`

### Inference Pipeline

The `run_inference()` function in `inference.py` follows these steps:

1. **Extract Frames**: Decode the video using OpenCV
2. **Parse Queries**: Split comma-separated text into a list (defaults to common objects if empty)
3. **Select Detector**: Load the detector by key (cached via `@lru_cache`)
4. **Process Frames**: Run detection on each frame
   - Call `detector.predict(frame, queries)`
   - Draw green bounding boxes on detections
5. **Write Video**: Encode processed frames back to MP4

Default queries (if none provided): `["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]`

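The steps above can be sketched end-to-end. This is a minimal illustration only: video decoding/encoding and box drawing are elided, and `detector` is any object with the `predict(frame, queries)` interface described above.

```python
# Sketch of the run_inference() flow; frame extraction (step 1) and video
# writing (step 5) are elided, and no boxes are drawn here.
DEFAULT_QUERIES = ["person", "car", "truck", "motorcycle",
                   "bicycle", "bus", "train", "airplane"]

def parse_queries(raw: str) -> list[str]:
    # Step 2: split comma-separated text, fall back to common objects
    queries = [q.strip() for q in raw.split(",") if q.strip()]
    return queries or DEFAULT_QUERIES

def run_inference(frames, raw_queries, detector, max_frames=None):
    queries = parse_queries(raw_queries)
    processed = []
    for idx, frame in enumerate(frames):
        if max_frames is not None and idx >= max_frames:
            break  # frame limit used for quick testing
        detections = detector.predict(frame, queries)  # step 4
        processed.append((frame, detections))
    return processed
```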
### Detector Loading

Detectors are registered in `models/model_loader.py`:

```python
_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
    "owlv2_base": Owlv2Detector,
    "hf_yolov8": HuggingFaceYoloV8Detector,
    "detr_resnet50": DetrDetector,
    "grounding_dino": GroundingDinoDetector,
}
```

Loaded via `load_detector(name)`, which caches instances for performance.

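A plausible shape for `load_detector` is sketched below. The stand-in detector class and the `ValueError` handling are illustrative, not the actual implementation in `models/model_loader.py`.

```python
from functools import lru_cache

# Stand-in for a real detector class; the actual registry maps keys to
# the detector classes imported in models/model_loader.py.
class Owlv2Detector:
    name = "owlv2_base"

_REGISTRY = {"owlv2_base": Owlv2Detector}

@lru_cache(maxsize=None)
def load_detector(name: str):
    try:
        factory = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown detector: {name!r}") from None
    return factory()  # instantiated once, then served from the cache
```

Because of `@lru_cache`, repeated calls with the same key return the same instance, so model weights load only once per process.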
### Detection Result Format

All detectors return a `DetectionResult` namedtuple:
```python
DetectionResult(
    boxes: np.ndarray,                    # Nx4 array [x1, y1, x2, y2]
    scores: Sequence[float],              # Confidence scores
    labels: Sequence[int],                # Class indices
    label_names: Optional[Sequence[str]]  # Human-readable names
)
```

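Consuming a `DetectionResult` looks like this. A sketch: plain lists stand in for the numpy arrays the real detectors return, and the confidence threshold is illustrative.

```python
from collections import namedtuple

DetectionResult = namedtuple(
    "DetectionResult", ["boxes", "scores", "labels", "label_names"]
)

result = DetectionResult(
    boxes=[[10, 20, 110, 220]],  # [x1, y1, x2, y2]
    scores=[0.91],
    labels=[0],
    label_names=["person"],
)

# Typical per-frame loop: filter by confidence, then draw/report each box
for box, score, name in zip(result.boxes, result.scores, result.label_names):
    if score >= 0.5:
        x1, y1, x2, y2 = box
        print(f"{name}: {score:.2f} at ({x1}, {y1})-({x2}, {y2})")
```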
## File Structure

```
.
├── app.py                     # FastAPI server with /detect endpoint
├── inference.py               # Video processing and detection pipeline
├── demo.html                  # Web interface with mode selector
├── requirements.txt           # Python dependencies
├── models/
│   ├── model_loader.py        # Detector registry and loading
│   └── detectors/
│       ├── base.py            # ObjectDetector interface
│       ├── owlv2.py           # OWLv2 implementation
│       ├── yolov8.py          # YOLOv8 implementation
│       ├── detr.py            # DETR implementation
│       └── grounding_dino.py  # Grounding DINO implementation
├── utils/
│   └── video.py               # Video encoding/decoding utilities
└── coco_classes.py            # COCO dataset class definitions
```

## Adding New Detectors

To add a new detector:

1. **Create detector class** in `models/detectors/`:
```python
from .base import ObjectDetector, DetectionResult

class MyDetector(ObjectDetector):
    name = "my_detector"

    def predict(self, frame, queries):
        # Your detection logic
        return DetectionResult(boxes, scores, labels, label_names)
```

2. **Register in model_loader.py**:
```python
_REGISTRY = {
    ...
    "my_detector": MyDetector,
}
```

3. **Update frontend** `demo.html` detector dropdown:
```html
<option value="my_detector">My Detector</option>
```

## Adding New Detection Modes

To implement additional modes such as drone detection:

1. **Create specialized detector** (if needed):
   - For segmentation: Extend `SegmentationResult` to include masks
   - For drone detection: Create `DroneDetector` with specialized filtering

2. **Update `/detect` endpoint** in `app.py`:
```python
if mode == "segmentation":
    # Run segmentation inference
    # Return video with masks rendered
```

3. **Update frontend** to remove the "disabled" class from the mode card

4. **Update inference.py** if needed to handle new output types

## Common Patterns

### Query Processing
Queries are parsed from comma-separated strings:
```python
queries = [q.strip() for q in "person, car, dog".split(",") if q.strip()]
# Result: ["person", "car", "dog"]
```

### Frame Processing Loop
Standard pattern for processing video frames:
```python
processed_frames = []
for idx, frame in enumerate(frames):
    processed_frame, detections = infer_frame(frame, queries, detector_name)
    processed_frames.append(processed_frame)
```

### Temporary File Management
FastAPI's `BackgroundTasks` cleans up temp files after the response:
```python
_schedule_cleanup(background_tasks, input_path)
_schedule_cleanup(background_tasks, output_path)
```

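A sketch of how `_schedule_cleanup` might work. The `BackgroundTasks` class below is a stand-in that only mimics FastAPI's `add_task` interface so the idea runs standalone; in the app, the real `fastapi.BackgroundTasks` executes the queued tasks after the response is sent.

```python
import os
import tempfile

class BackgroundTasks:
    """Stand-in mimicking fastapi.BackgroundTasks, for illustration only."""
    def __init__(self):
        self._tasks = []

    def add_task(self, func, *args):
        self._tasks.append((func, args))

    def run_all(self):  # FastAPI does this after sending the response
        for func, args in self._tasks:
            func(*args)

def _schedule_cleanup(background_tasks, path):
    # Queue the temp file for deletion once the response has gone out
    background_tasks.add_task(os.remove, path)

# Usage: the temp file survives until the "response" completes, then vanishes
fd, path = tempfile.mkstemp(suffix=".mp4")
os.close(fd)
tasks = BackgroundTasks()
_schedule_cleanup(tasks, path)
tasks.run_all()
```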
## Performance Notes

- **Detector Caching**: Models are loaded once and cached via `@lru_cache`
- **Default Resolution**: Videos are processed at their original resolution
- **Frame Limit**: Use the `max_frames` parameter in `run_inference()` for testing
- **Memory Usage**: The entire video is loaded into memory (frames list)

## Troubleshooting

### "No module named 'fastapi'"
Install dependencies: `pip install -r requirements.txt`

### "Video decoding failed"
Check video codec compatibility. The system expects MP4/H.264.

### "Detector not found"
Verify the detector key exists in `model_loader._REGISTRY`

### Slow processing
- Try a faster detector: YOLOv8 (`hf_yolov8`)
- Reduce video resolution before uploading
- Use the `max_frames` parameter for testing

## Dependencies

Core packages:
- `fastapi` + `uvicorn`: Web server
- `torch` + `transformers`: Deep learning models
- `opencv-python-headless`: Video processing
- `ultralytics`: YOLOv8 implementation
- `huggingface-hub`: Model downloading
- `pillow`, `scipy`, `accelerate`, `timm`: Supporting libraries
demo.html CHANGED
@@ -306,6 +306,31 @@
     100% { transform: rotate(360deg); }
 }
 
+/* View toggle buttons */
+.view-toggle-btn {
+    padding: 12px 28px;
+    margin: 0 10px;
+    background: #e5e7eb;
+    color: #374151;
+    border: 2px solid #d1d5db;
+    border-radius: 8px;
+    cursor: pointer;
+    font-weight: 600;
+    font-size: 14px;
+    transition: all 0.3s;
+}
+
+.view-toggle-btn.active {
+    background: #1f2933;
+    color: #f9fafb;
+    border-color: #1f2933;
+}
+
+.view-toggle-btn:hover:not(.active) {
+    background: #d1d5db;
+    transform: translateY(-1px);
+}
+
 .hidden {
     display: none;
 }
@@ -415,6 +440,13 @@
 <!-- Results -->
 <div class="section hidden" id="resultsSection">
     <div class="section-title">Results</div>
+
+    <!-- View Toggle Buttons -->
+    <div id="viewToggleContainer" class="hidden" style="text-align: center; margin-bottom: 20px;">
+        <button class="view-toggle-btn active" id="detectionViewBtn">Detection View</button>
+        <button class="view-toggle-btn" id="depthViewBtn">Depth View</button>
+    </div>
+
     <div class="results-grid">
         <div class="video-card">
             <div class="video-card-header">First Frame</div>
@@ -466,6 +498,11 @@
 // State
 let selectedMode = 'object_detection';
 let videoFile = null;
+let currentView = 'detection'; // 'detection' or 'depth'
+let detectionVideoUrl = null;
+let depthVideoUrl = null;
+let detectionFirstFrameUrl = null;
+let depthFirstFrameUrl = null;
 
 // Elements
 const modeCards = document.querySelectorAll('.mode-card');
@@ -490,8 +527,56 @@
 const depthVideo = document.getElementById('depthVideo');
 const depthDownloadBtn = document.getElementById('depthDownloadBtn');
 const depthVideoStatus = document.getElementById('depthVideoStatus');
+const viewToggleContainer = document.getElementById('viewToggleContainer');
+const detectionViewBtn = document.getElementById('detectionViewBtn');
+const depthViewBtn = document.getElementById('depthViewBtn');
 let statusPoller = null;
 const statusLine = document.getElementById('statusLine');
+
+// View switching function
+function switchToView(view) {
+    currentView = view;
+
+    if (view === 'detection') {
+        detectionViewBtn.classList.add('active');
+        depthViewBtn.classList.remove('active');
+
+        if (detectionFirstFrameUrl) {
+            firstFrameImage.src = detectionFirstFrameUrl;
+            depthFrameImage.classList.add('hidden');
+            depthFramePlaceholder.classList.remove('hidden');
+        }
+        if (detectionVideoUrl) {
+            processedVideo.src = detectionVideoUrl;
+            downloadBtn.href = detectionVideoUrl;
+            downloadBtn.download = 'processed_detection.mp4';
+            processedVideo.load();
+        }
+    } else {
+        depthViewBtn.classList.add('active');
+        detectionViewBtn.classList.remove('active');
+
+        if (depthFirstFrameUrl) {
+            firstFrameImage.src = depthFirstFrameUrl;
+            depthFrameImage.classList.add('hidden');
+            depthFramePlaceholder.classList.add('hidden');
+        }
+        if (depthVideoUrl) {
+            processedVideo.src = depthVideoUrl;
+            downloadBtn.href = depthVideoUrl;
+            downloadBtn.download = 'depth_map.mp4';
+            processedVideo.load();
+        }
+    }
+}
+
+// Toggle button event listeners
+if (detectionViewBtn) {
+    detectionViewBtn.addEventListener('click', () => switchToView('detection'));
+}
+if (depthViewBtn) {
+    depthViewBtn.addEventListener('click', () => switchToView('depth'));
+}
 // Mode selection handler
 modeCards.forEach(card => {
     card.addEventListener('click', (e) => {
@@ -571,6 +656,12 @@
 depthDownloadBtn.removeAttribute('href');
 depthDownloadBtn.classList.add('hidden');
 depthVideoStatus.textContent = '';
+viewToggleContainer.classList.add('hidden');
+currentView = 'detection';
+detectionVideoUrl = null;
+depthVideoUrl = null;
+detectionFirstFrameUrl = null;
+depthFirstFrameUrl = null;
 statusLine.classList.add('hidden');
 statusLine.textContent = '';
 
@@ -615,16 +706,22 @@
 clearInterval(statusPoller);
 statusPoller = null;
 statusLine.textContent = 'Status: completed';
+
+// Fetch detection video
 const videoResponse = await fetch(data.video_url);
 if (!videoResponse.ok) {
     alert('Failed to fetch processed video.');
     return;
 }
 const blob = await videoResponse.blob();
-const videoUrl = URL.createObjectURL(blob);
-processedVideo.src = videoUrl;
-downloadBtn.href = videoUrl;
+detectionVideoUrl = URL.createObjectURL(blob);
+detectionFirstFrameUrl = `${data.first_frame_url}?t=${Date.now()}`;
 
+// Set initial detection view
+processedVideo.src = detectionVideoUrl;
+downloadBtn.href = detectionVideoUrl;
+
+// Load depth assets
 await loadDepthAssets(data);
 } else if (statusData.status === 'failed') {
 clearInterval(statusPoller);
@@ -662,8 +759,8 @@
 const frameResponse = await fetch(jobData.first_frame_depth_url);
 if (frameResponse.ok) {
     const frameBlob = await frameResponse.blob();
-    const frameUrl = URL.createObjectURL(frameBlob);
-    depthFrameImage.src = frameUrl;
+    depthFirstFrameUrl = URL.createObjectURL(frameBlob);
+    depthFrameImage.src = depthFirstFrameUrl;
     depthFrameImage.classList.remove('hidden');
     depthFramePlaceholder.classList.add('hidden');
 } else {
@@ -678,11 +775,18 @@
 const depthResponse = await fetch(jobData.depth_video_url);
 if (depthResponse.ok) {
     const depthBlob = await depthResponse.blob();
-    const depthUrl = URL.createObjectURL(depthBlob);
-    depthVideo.src = depthUrl;
-    depthVideo.classList.remove('hidden');
-    depthDownloadBtn.href = depthUrl;
-    depthDownloadBtn.classList.remove('hidden');
+    depthVideoUrl = URL.createObjectURL(depthBlob);
+
+    // Keep depth video card hidden - using toggle instead
+    depthVideo.src = depthVideoUrl;
+    depthVideo.classList.add('hidden');
+    depthDownloadBtn.classList.add('hidden');
+
+    // Show toggle buttons now that we have both videos
+    viewToggleContainer.classList.remove('hidden');
+
+    // Start with detection view
+    switchToView('detection');
 } else {
     const error = await depthResponse.json();
     depthVideoStatus.textContent = error.detail || 'Depth video unavailable.';
models/depth_estimators/depth_pro.py CHANGED
@@ -8,28 +8,32 @@ from .base import DepthEstimator, DepthResult
 
 
 class DepthProEstimator(DepthEstimator):
-    """Apple Depth Pro depth estimator."""
+    """Apple Depth Pro depth estimator using Hugging Face transformers."""
 
     name = "depth_pro"
 
     def __init__(self):
-        """Initialize Depth Pro model."""
+        """Initialize Depth Pro model from Hugging Face."""
         try:
-            import depth_pro
+            from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
         except ImportError as exc:
             raise ImportError(
-                "depth_pro package not installed. "
-                "Install with: pip install git+https://github.com/apple/ml-depth-pro.git"
+                "transformers package not installed or doesn't include DepthPro. "
+                "Update with: pip install transformers --upgrade"
             ) from exc
 
-        logging.info("Loading Depth Pro model...")
-        self.model, self.transform = depth_pro.create_model_and_transforms()
-        self.model.eval()
+        logging.info("Loading Depth Pro model from Hugging Face...")
 
-        # Move model to GPU if available
+        # Set device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Load model and processor
+        model_id = "apple/DepthPro-hf"
+        self.image_processor = DepthProImageProcessorFast.from_pretrained(model_id)
+        self.model = DepthProForDepthEstimation.from_pretrained(model_id).to(self.device)
+        self.model.eval()
+
         if torch.cuda.is_available():
-            self.model = self.model.cuda()
             logging.info("Depth Pro model loaded on GPU")
         else:
             logging.warning("Depth Pro model loaded on CPU (no CUDA available)")
@@ -47,29 +51,34 @@ class DepthProEstimator(DepthEstimator):
         # Convert BGR to RGB
         rgb_frame = frame[:, :, ::-1]  # BGR → RGB
 
-        # Convert to PIL Image for transform
+        # Convert to PIL Image
         pil_image = Image.fromarray(rgb_frame)
+        height, width = pil_image.height, pil_image.width
 
-        # Apply transform and move to device
-        image_tensor = self.transform(pil_image)
-        image_tensor = image_tensor.to(self.device)
+        # Preprocess image
+        inputs = self.image_processor(images=pil_image, return_tensors="pt").to(self.device)
 
         # Run inference (no gradient needed)
         with torch.no_grad():
-            prediction = self.model.infer(image_tensor, f_px=None)
+            outputs = self.model(**inputs)
+
+        # Post-process to get depth and focal length
+        post_processed = self.image_processor.post_process_depth_estimation(
+            outputs,
+            target_sizes=[(height, width)],
+        )
 
-        # Extract depth map and move to CPU/numpy
-        # prediction is a dict: {"depth": tensor, "focallength_px": tensor}
-        depth_tensor = prediction["depth"]
-        focal_length_tensor = prediction.get("focallength_px")
+        # Extract depth map and focal length
+        depth_tensor = post_processed[0]["predicted_depth"]  # Already at target size
+        focal_length_value = post_processed[0].get("focal_length", 1.0)
 
-        # Convert to numpy, remove batch dimension if present
-        depth_map = depth_tensor.cpu().numpy().squeeze()
+        # Convert to numpy
+        depth_map = depth_tensor.cpu().numpy()
 
-        # Extract focal length
-        if focal_length_tensor is not None:
-            focal_length = float(focal_length_tensor.cpu().item())
+        # focal_length might be a tensor, convert to float
+        if isinstance(focal_length_value, torch.Tensor):
+            focal_length = float(focal_length_value.item())
         else:
-            focal_length = 1.0
+            focal_length = float(focal_length_value)
 
         return DepthResult(depth_map=depth_map, focal_length=focal_length)
requirements.txt CHANGED
@@ -11,4 +11,3 @@ huggingface-hub
 ultralytics
 timm
 ffmpeg-python
-depth-pro @ git+https://github.com/apple/ml-depth-pro.git