Spaces: Paused

Commit · 94c85d4
Zhen Ye committed on
1 Parent(s): 469102e

using apple depth pro hf

Files changed:
- CLAUDE.md +254 -0
- demo.html +114 -10
- models/depth_estimators/depth_pro.py +34 -25
- requirements.txt +0 -1
CLAUDE.md
ADDED
@@ -0,0 +1,254 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Simple video object detection system with three modes:
- **Object Detection**: Detect custom objects using text queries (fully functional)
- **Segmentation**: Mask overlays using SAM3
- **Drone Detection**: (Coming Soon) Specialized UAV detection

## Core Architecture

### Simple Detection Flow

```
User → demo.html → POST /detect → inference.py → detector → processed video
```

1. User selects mode and uploads video via web interface
2. Frontend sends video + mode + queries to `/detect` endpoint
3. Backend runs detection inference with selected model
4. Returns processed video with bounding boxes

### Available Detectors

The system includes four pre-trained object detection models:

| Detector | Key | Type | Best For |
|----------|-----|------|----------|
| **OWLv2** | `owlv2_base` | Open-vocabulary | Custom text queries (default) |
| **YOLOv8** | `hf_yolov8` | COCO classes | Fast real-time detection |
| **DETR** | `detr_resnet50` | COCO classes | Transformer-based detection |
| **Grounding DINO** | `grounding_dino` | Open-vocabulary | Text-grounded detection |

All detectors implement the `ObjectDetector` interface in `models/detectors/base.py` with a single `predict()` method.

## Development Commands

### Setup
```bash
python -m venv .venv
source .venv/bin/activate  # .venv\Scripts\activate on Windows
pip install -r requirements.txt
```

### Running the Server
```bash
# Development
uvicorn app:app --host 0.0.0.0 --port 7860 --reload

# Production (Docker)
docker build -t object_detectors .
docker run -p 7860:7860 object_detectors
```

### Testing the API
```bash
# Test object detection
curl -X POST http://localhost:7860/detect \
  -F "video=@sample.mp4" \
  -F "mode=object_detection" \
  -F "queries=person,car,dog" \
  -F "detector=owlv2_base" \
  --output processed.mp4

# Test placeholder modes (returns JSON)
curl -X POST http://localhost:7860/detect \
  -F "video=@sample.mp4" \
  -F "mode=segmentation"
```

## Key Implementation Details

### API Endpoint: `/detect`

**Parameters:**
- `video` (file): Video file to process
- `mode` (string): Detection mode - `object_detection`, `segmentation`, or `drone_detection`
- `queries` (string): Comma-separated object classes (for object_detection mode)
- `detector` (string): Model key (default: `owlv2_base`)

**Returns:**
- For `object_detection`: MP4 video with bounding boxes
- For `segmentation`: MP4 video with mask overlays
- For `drone_detection`: JSON with `{"status": "coming_soon", "message": "..."}`

### Inference Pipeline

The `run_inference()` function in `inference.py` follows these steps:

1. **Extract Frames**: Decode video using OpenCV
2. **Parse Queries**: Split comma-separated text into a list (defaults to common objects if empty)
3. **Select Detector**: Load detector by key (cached via `@lru_cache`)
4. **Process Frames**: Run detection on each frame
   - Call `detector.predict(frame, queries)`
   - Draw green bounding boxes on detections
5. **Write Video**: Encode processed frames back to MP4

Default queries (if none provided): `["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]`

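The parsing step above, with its default fallback, can be sketched as a small function (the function name here is hypothetical; the real logic lives in `inference.py`):

```python
# Hypothetical helper mirroring the "Parse Queries" step: split the
# comma-separated query string and fall back to the documented defaults.
DEFAULT_QUERIES = ["person", "car", "truck", "motorcycle",
                   "bicycle", "bus", "train", "airplane"]

def parse_queries(raw):
    queries = [q.strip() for q in (raw or "").split(",") if q.strip()]
    return queries or DEFAULT_QUERIES
```

`parse_queries("person, car , dog")` yields `["person", "car", "dog"]`; an empty string (or `None`) yields the defaults.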
### Detector Loading

Detectors are registered in `models/model_loader.py`:

```python
_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
    "owlv2_base": Owlv2Detector,
    "hf_yolov8": HuggingFaceYoloV8Detector,
    "detr_resnet50": DetrDetector,
    "grounding_dino": GroundingDinoDetector,
}
```

Detectors are loaded via `load_detector(name)`, which caches instances for performance.

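The registry-plus-cached-loader pattern can be made concrete with a minimal sketch (`DummyDetector` and the error handling are illustrative assumptions, not the repo's code):

```python
from functools import lru_cache

class DummyDetector:  # stand-in for a real detector class
    name = "dummy"

_REGISTRY = {"dummy": DummyDetector}

@lru_cache(maxsize=None)
def load_detector(name: str):
    """Instantiate a detector once; later calls reuse the same instance."""
    if name not in _REGISTRY:
        raise ValueError(f"Unknown detector: {name!r}")
    return _REGISTRY[name]()
```

Because of `@lru_cache`, `load_detector("dummy") is load_detector("dummy")` is `True`, so model weights stay in memory between requests.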
### Detection Result Format

All detectors return a `DetectionResult` namedtuple:
```python
DetectionResult(
    boxes: np.ndarray,                    # Nx4 array [x1, y1, x2, y2]
    scores: Sequence[float],              # Confidence scores
    labels: Sequence[int],                # Class indices
    label_names: Optional[Sequence[str]]  # Human-readable names
)
```

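To make the contract concrete, here is a sketch that builds a result with those fields and filters it by confidence (the namedtuple is re-declared locally for illustration, and the score-threshold helper is an assumption, not something `base.py` is documented to provide):

```python
from collections import namedtuple
import numpy as np

# Local stand-in for the DetectionResult described above.
DetectionResult = namedtuple(
    "DetectionResult", ["boxes", "scores", "labels", "label_names"]
)

def filter_by_score(result, threshold=0.5):
    # Keep only detections whose confidence reaches the threshold.
    keep = [i for i, s in enumerate(result.scores) if s >= threshold]
    return DetectionResult(
        boxes=result.boxes[keep],
        scores=[result.scores[i] for i in keep],
        labels=[result.labels[i] for i in keep],
        label_names=[result.label_names[i] for i in keep],
    )

result = DetectionResult(
    boxes=np.array([[10.0, 10.0, 50.0, 50.0], [5.0, 5.0, 20.0, 20.0]]),
    scores=[0.9, 0.3],
    labels=[0, 1],
    label_names=["person", "car"],
)
kept = filter_by_score(result)  # only the 0.9 "person" detection survives
```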
## File Structure

```
.
├── app.py                      # FastAPI server with /detect endpoint
├── inference.py                # Video processing and detection pipeline
├── demo.html                   # Web interface with mode selector
├── requirements.txt            # Python dependencies
├── models/
│   ├── model_loader.py         # Detector registry and loading
│   └── detectors/
│       ├── base.py             # ObjectDetector interface
│       ├── owlv2.py            # OWLv2 implementation
│       ├── yolov8.py           # YOLOv8 implementation
│       ├── detr.py             # DETR implementation
│       └── grounding_dino.py   # Grounding DINO implementation
├── utils/
│   └── video.py                # Video encoding/decoding utilities
└── coco_classes.py             # COCO dataset class definitions
```

## Adding New Detectors

To add a new detector:

1. **Create detector class** in `models/detectors/`:
```python
from .base import ObjectDetector, DetectionResult

class MyDetector(ObjectDetector):
    name = "my_detector"

    def predict(self, frame, queries):
        # Your detection logic
        return DetectionResult(boxes, scores, labels, label_names)
```

2. **Register in `model_loader.py`**:
```python
_REGISTRY = {
    ...
    "my_detector": MyDetector,
}
```

3. **Update frontend** `demo.html` detector dropdown:
```html
<option value="my_detector">My Detector</option>
```

## Adding New Detection Modes

To implement additional modes such as drone detection:

1. **Create specialized detector** (if needed):
   - For segmentation: Extend `SegmentationResult` to include masks
   - For drone detection: Create a `DroneDetector` with specialized filtering

2. **Update `/detect` endpoint** in `app.py`:
```python
if mode == "segmentation":
    # Run segmentation inference
    # Return video with masks rendered
```

3. **Update frontend** to remove the "disabled" class from the mode card

4. **Update `inference.py`** if needed to handle new output types

## Common Patterns

### Query Processing
Queries are parsed from comma-separated strings:
```python
queries = [q.strip() for q in "person, car, dog".split(",") if q.strip()]
# Result: ["person", "car", "dog"]
```

### Frame Processing Loop
Standard pattern for processing video frames:
```python
processed_frames = []
for idx, frame in enumerate(frames):
    processed_frame, detections = infer_frame(frame, queries, detector_name)
    processed_frames.append(processed_frame)
```

### Temporary File Management
FastAPI's `BackgroundTasks` cleans up temp files after the response:
```python
_schedule_cleanup(background_tasks, input_path)
_schedule_cleanup(background_tasks, output_path)
```

## Performance Notes

- **Detector Caching**: Models are loaded once and cached via `@lru_cache`
- **Default Resolution**: Videos are processed at their original resolution
- **Frame Limit**: Use the `max_frames` parameter in `run_inference()` for testing
- **Memory Usage**: The entire video is loaded into memory (frames list)

## Troubleshooting

### "No module named 'fastapi'"
Install dependencies: `pip install -r requirements.txt`

### "Video decoding failed"
Check video codec compatibility; the system expects MP4/H.264.

### "Detector not found"
Verify the detector key exists in `model_loader._REGISTRY`.

### Slow processing
- Try a faster detector: YOLOv8 (`hf_yolov8`)
- Reduce video resolution before uploading
- Use the `max_frames` parameter for testing

## Dependencies

Core packages:
- `fastapi` + `uvicorn`: Web server
- `torch` + `transformers`: Deep learning models
- `opencv-python-headless`: Video processing
- `ultralytics`: YOLOv8 implementation
- `huggingface-hub`: Model downloading
- `pillow`, `scipy`, `accelerate`, `timm`: Supporting libraries
demo.html
CHANGED

```diff
@@ -306,6 +306,31 @@
       100% { transform: rotate(360deg); }
     }
 
+    /* View toggle buttons */
+    .view-toggle-btn {
+      padding: 12px 28px;
+      margin: 0 10px;
+      background: #e5e7eb;
+      color: #374151;
+      border: 2px solid #d1d5db;
+      border-radius: 8px;
+      cursor: pointer;
+      font-weight: 600;
+      font-size: 14px;
+      transition: all 0.3s;
+    }
+
+    .view-toggle-btn.active {
+      background: #1f2933;
+      color: #f9fafb;
+      border-color: #1f2933;
+    }
+
+    .view-toggle-btn:hover:not(.active) {
+      background: #d1d5db;
+      transform: translateY(-1px);
+    }
+
     .hidden {
       display: none;
     }
@@ -415,6 +440,13 @@
     <!-- Results -->
     <div class="section hidden" id="resultsSection">
       <div class="section-title">Results</div>
+
+      <!-- View Toggle Buttons -->
+      <div id="viewToggleContainer" class="hidden" style="text-align: center; margin-bottom: 20px;">
+        <button class="view-toggle-btn active" id="detectionViewBtn">Detection View</button>
+        <button class="view-toggle-btn" id="depthViewBtn">Depth View</button>
+      </div>
+
       <div class="results-grid">
         <div class="video-card">
           <div class="video-card-header">First Frame</div>
@@ -466,6 +498,11 @@
       // State
       let selectedMode = 'object_detection';
       let videoFile = null;
+      let currentView = 'detection'; // 'detection' or 'depth'
+      let detectionVideoUrl = null;
+      let depthVideoUrl = null;
+      let detectionFirstFrameUrl = null;
+      let depthFirstFrameUrl = null;
 
       // Elements
       const modeCards = document.querySelectorAll('.mode-card');
@@ -490,8 +527,56 @@
       const depthVideo = document.getElementById('depthVideo');
       const depthDownloadBtn = document.getElementById('depthDownloadBtn');
       const depthVideoStatus = document.getElementById('depthVideoStatus');
+      const viewToggleContainer = document.getElementById('viewToggleContainer');
+      const detectionViewBtn = document.getElementById('detectionViewBtn');
+      const depthViewBtn = document.getElementById('depthViewBtn');
       let statusPoller = null;
       const statusLine = document.getElementById('statusLine');
+
+      // View switching function
+      function switchToView(view) {
+        currentView = view;
+
+        if (view === 'detection') {
+          detectionViewBtn.classList.add('active');
+          depthViewBtn.classList.remove('active');
+
+          if (detectionFirstFrameUrl) {
+            firstFrameImage.src = detectionFirstFrameUrl;
+            depthFrameImage.classList.add('hidden');
+            depthFramePlaceholder.classList.remove('hidden');
+          }
+          if (detectionVideoUrl) {
+            processedVideo.src = detectionVideoUrl;
+            downloadBtn.href = detectionVideoUrl;
+            downloadBtn.download = 'processed_detection.mp4';
+            processedVideo.load();
+          }
+        } else {
+          depthViewBtn.classList.add('active');
+          detectionViewBtn.classList.remove('active');
+
+          if (depthFirstFrameUrl) {
+            firstFrameImage.src = depthFirstFrameUrl;
+            depthFrameImage.classList.add('hidden');
+            depthFramePlaceholder.classList.add('hidden');
+          }
+          if (depthVideoUrl) {
+            processedVideo.src = depthVideoUrl;
+            downloadBtn.href = depthVideoUrl;
+            downloadBtn.download = 'depth_map.mp4';
+            processedVideo.load();
+          }
+        }
+      }
+
+      // Toggle button event listeners
+      if (detectionViewBtn) {
+        detectionViewBtn.addEventListener('click', () => switchToView('detection'));
+      }
+      if (depthViewBtn) {
+        depthViewBtn.addEventListener('click', () => switchToView('depth'));
+      }
       // Mode selection handler
       modeCards.forEach(card => {
         card.addEventListener('click', (e) => {
@@ -571,6 +656,12 @@
         depthDownloadBtn.removeAttribute('href');
         depthDownloadBtn.classList.add('hidden');
         depthVideoStatus.textContent = '';
+        viewToggleContainer.classList.add('hidden');
+        currentView = 'detection';
+        detectionVideoUrl = null;
+        depthVideoUrl = null;
+        detectionFirstFrameUrl = null;
+        depthFirstFrameUrl = null;
         statusLine.classList.add('hidden');
         statusLine.textContent = '';
 
@@ -615,16 +706,22 @@
               clearInterval(statusPoller);
               statusPoller = null;
               statusLine.textContent = 'Status: completed';
+
+              // Fetch detection video
               const videoResponse = await fetch(data.video_url);
               if (!videoResponse.ok) {
                 alert('Failed to fetch processed video.');
                 return;
               }
               const blob = await videoResponse.blob();
-
-
-              downloadBtn.href = videoUrl;
+              detectionVideoUrl = URL.createObjectURL(blob);
+              detectionFirstFrameUrl = `${data.first_frame_url}?t=${Date.now()}`;
 
+              // Set initial detection view
+              processedVideo.src = detectionVideoUrl;
+              downloadBtn.href = detectionVideoUrl;
+
+              // Load depth assets
               await loadDepthAssets(data);
             } else if (statusData.status === 'failed') {
               clearInterval(statusPoller);
@@ -662,8 +759,8 @@
           const frameResponse = await fetch(jobData.first_frame_depth_url);
           if (frameResponse.ok) {
             const frameBlob = await frameResponse.blob();
-
-            depthFrameImage.src =
+            depthFirstFrameUrl = URL.createObjectURL(frameBlob);
+            depthFrameImage.src = depthFirstFrameUrl;
             depthFrameImage.classList.remove('hidden');
             depthFramePlaceholder.classList.add('hidden');
           } else {
@@ -678,11 +775,18 @@
           const depthResponse = await fetch(jobData.depth_video_url);
           if (depthResponse.ok) {
             const depthBlob = await depthResponse.blob();
-
-
-
-
-
+            depthVideoUrl = URL.createObjectURL(depthBlob);
+
+            // Keep depth video card hidden - using toggle instead
+            depthVideo.src = depthVideoUrl;
+            depthVideo.classList.add('hidden');
+            depthDownloadBtn.classList.add('hidden');
+
+            // Show toggle buttons now that we have both videos
+            viewToggleContainer.classList.remove('hidden');
+
+            // Start with detection view
+            switchToView('detection');
           } else {
             const error = await depthResponse.json();
             depthVideoStatus.textContent = error.detail || 'Depth video unavailable.';
```
models/depth_estimators/depth_pro.py
CHANGED

```diff
@@ -8,28 +8,32 @@ from .base import DepthEstimator, DepthResult
 
 
 class DepthProEstimator(DepthEstimator):
-    """Apple Depth Pro depth estimator."""
+    """Apple Depth Pro depth estimator using Hugging Face transformers."""
 
     name = "depth_pro"
 
     def __init__(self):
-        """Initialize Depth Pro model."""
+        """Initialize Depth Pro model from Hugging Face."""
         try:
-            import depth_pro
+            from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
         except ImportError as exc:
             raise ImportError(
-                "
-                "
+                "transformers package not installed or doesn't include DepthPro. "
+                "Update with: pip install transformers --upgrade"
             ) from exc
 
-        logging.info("Loading Depth Pro model...")
-        self.model, self.transform = depth_pro.create_model_and_transforms()
-        self.model.eval()
+        logging.info("Loading Depth Pro model from Hugging Face...")
 
-        #
+        # Set device
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Load model and processor
+        model_id = "apple/DepthPro-hf"
+        self.image_processor = DepthProImageProcessorFast.from_pretrained(model_id)
+        self.model = DepthProForDepthEstimation.from_pretrained(model_id).to(self.device)
+        self.model.eval()
+
         if torch.cuda.is_available():
-            self.model = self.model.cuda()
             logging.info("Depth Pro model loaded on GPU")
         else:
             logging.warning("Depth Pro model loaded on CPU (no CUDA available)")
@@ -47,29 +51,34 @@ class DepthProEstimator(DepthEstimator):
         # Convert BGR to RGB
         rgb_frame = frame[:, :, ::-1]  # BGR → RGB
 
-        # Convert to PIL Image
+        # Convert to PIL Image
         pil_image = Image.fromarray(rgb_frame)
+        height, width = pil_image.height, pil_image.width
 
-        #
-
-        image_tensor = image_tensor.to(self.device)
+        # Preprocess image
+        inputs = self.image_processor(images=pil_image, return_tensors="pt").to(self.device)
 
         # Run inference (no gradient needed)
         with torch.no_grad():
-
+            outputs = self.model(**inputs)
+
+        # Post-process to get depth and focal length
+        post_processed = self.image_processor.post_process_depth_estimation(
+            outputs,
+            target_sizes=[(height, width)],
+        )
 
-        # Extract depth map and
-
-
-        focal_length_tensor = prediction.get("focallength_px")
+        # Extract depth map and focal length
+        depth_tensor = post_processed[0]["predicted_depth"]  # Already at target size
+        focal_length_value = post_processed[0].get("focal_length", 1.0)
 
-        # Convert to numpy
-        depth_map = depth_tensor.cpu().numpy()
+        # Convert to numpy
+        depth_map = depth_tensor.cpu().numpy()
 
-        #
-        if
-            focal_length = float(
+        # focal_length might be a tensor, convert to float
+        if isinstance(focal_length_value, torch.Tensor):
+            focal_length = float(focal_length_value.item())
         else:
-            focal_length =
+            focal_length = float(focal_length_value)
 
         return DepthResult(depth_map=depth_map, focal_length=focal_length)
```
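The new estimator returns a metric depth map as a float array. For the depth previews in `demo.html`, something has to map that to an 8-bit image; a minimal min-max normalization sketch (this scheme is an assumption for illustration, not taken from the repo):

```python
import numpy as np

def depth_to_uint8(depth_map):
    # Min-max normalize metric depth to 0..255 (near = dark, far = bright).
    d = np.asarray(depth_map, dtype=np.float64)
    span = d.max() - d.min()
    if span == 0:
        return np.zeros(d.shape, dtype=np.uint8)
    return ((d - d.min()) / span * 255).astype(np.uint8)

vis = depth_to_uint8([[1.0, 2.0], [3.0, 5.0]])
```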
requirements.txt
CHANGED

```diff
@@ -11,4 +11,3 @@ huggingface-hub
 ultralytics
 timm
 ffmpeg-python
-depth-pro @ git+https://github.com/apple/ml-depth-pro.git
```
|