Zhen Ye and Claude Opus 4.6 committed on
Commit bfc6bae · 1 Parent(s): 562781d

fix: weight download race condition + rewrite CLAUDE.md


Add ensure_weights() classmethod and prefetch_weights() to download
model weights once before parallel multi-GPU init, fixing FileNotFoundError
when 4 GPUs race to download visDrone.pt simultaneously.

Rewrite CLAUDE.md to reflect current architecture: async job pipeline,
multi-GPU inference, GSAM2 segmentation, frontend SPA modules.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
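
The core of the race fix is ordering: download the weights once on the main thread, then let the per-GPU workers start. A minimal standalone sketch of that pattern (the file path and the download stub are hypothetical stand-ins, not the repository's code):

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import tempfile

WEIGHTS = Path(tempfile.mkdtemp()) / "visDrone.pt"  # hypothetical cache path

def ensure_weights() -> None:
    """Idempotent download; runs once on the main thread."""
    if not WEIGHTS.exists():
        WEIGHTS.write_bytes(b"fake-weights")  # stand-in for hf_hub_download

def load_models_on_gpu(gpu_id: int) -> str:
    # Workers only read the file; the FileNotFoundError race is gone
    # because the download finished before the pool started.
    if not WEIGHTS.exists():
        raise FileNotFoundError(WEIGHTS)
    return f"cuda:{gpu_id}"

ensure_weights()  # prefetch once, before parallel init
with ThreadPoolExecutor(max_workers=4) as pool:
    devices = list(pool.map(load_models_on_gpu, range(4)))
print(devices)  # → ['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3']
```

Without the up-front `ensure_weights()` call, all four workers could see a missing file and start writing it concurrently, which is the race the commit removes.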

CLAUDE.md CHANGED
@@ -4,221 +4,155 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 
 ## Project Overview
 
-Simple video object detection system with three modes:
-- **Object Detection**: Detect custom objects using text queries (fully functional)
-- **Segmentation**: Mask overlays using SAM3
-- **Drone Detection**: (Coming Soon) Specialized UAV detection
-
-## Core Architecture
-
-### Simple Detection Flow
-
-```
-User → demo.html → POST /detect → inference.py → detector → processed video
-```
-
-1. User selects mode and uploads video via web interface
-2. Frontend sends video + mode + queries to `/detect` endpoint
-3. Backend runs detection inference with selected model
-4. Returns processed video with bounding boxes
-
-### Available Detectors
-
-The system includes 4 pre-trained object detection models:
-
-| Detector | Key | Type | Best For |
-|----------|-----|------|----------|
-| **OWLv2** | `owlv2_base` | Open-vocabulary | Custom text queries (default) |
-| **YOLOv8** | `hf_yolov8` | COCO classes | Fast real-time detection |
-| **DETR** | `detr_resnet50` | COCO classes | Transformer-based detection |
-| **Grounding DINO** | `grounding_dino` | Open-vocabulary | Text-grounded detection |
-
-All detectors implement the `ObjectDetector` interface in `models/detectors/base.py` with a single `predict()` method.
 
 ## Development Commands
 
-### Setup
 ```bash
-python -m venv .venv
-source .venv/bin/activate
 pip install -r requirements.txt
-```
 
-### Running the Server
-```bash
-# Development
 uvicorn app:app --host 0.0.0.0 --port 7860 --reload
 
-# Production (Docker)
-docker build -t object_detectors .
-docker run -p 7860:7860 object_detectors
-```
 
-### Testing the API
-```bash
-# Test object detection
-curl -X POST http://localhost:7860/detect \
-  -F "video=@sample.mp4" \
-  -F "mode=object_detection" \
-  -F "queries=person,car,dog" \
-  -F "detector=owlv2_base" \
-  --output processed.mp4
 
-# Test placeholder modes (returns JSON)
-curl -X POST http://localhost:7860/detect \
   -F "video=@sample.mp4" \
-  -F "mode=segmentation"
 ```
 
-## Key Implementation Details
-
-### API Endpoint: `/detect`
-
-**Parameters:**
-- `video` (file): Video file to process
-- `mode` (string): Detection mode - `object_detection`, `segmentation`, or `drone_detection`
-- `queries` (string): Comma-separated object classes (for object_detection mode)
-- `detector` (string): Model key (default: `owlv2_base`)
-
-**Returns:**
-- For `object_detection`: MP4 video with bounding boxes
-- For `segmentation`: MP4 video with mask overlays
-- For `drone_detection`: JSON with `{"status": "coming_soon", "message": "..."}`
 
-### Inference Pipeline
-
-The `run_inference()` function in `inference.py` follows these steps:
-
-1. **Extract Frames**: Decode video using OpenCV
-2. **Parse Queries**: Split comma-separated text into list (defaults to common objects if empty)
-3. **Select Detector**: Load detector by key (cached via `@lru_cache`)
-4. **Process Frames**: Run detection on each frame
-   - Call `detector.predict(frame, queries)`
-   - Draw green bounding boxes on detections
-5. **Write Video**: Encode processed frames back to MP4
-
-Default queries (if none provided): `["person", "car", "truck", "motorcycle", "bicycle", "bus", "train", "airplane"]`
-
-### Detector Loading
-
-Detectors are registered in `models/model_loader.py`:
-
-```python
-_REGISTRY: Dict[str, Callable[[], ObjectDetector]] = {
-    "owlv2_base": Owlv2Detector,
-    "hf_yolov8": HuggingFaceYoloV8Detector,
-    "detr_resnet50": DetrDetector,
-    "grounding_dino": GroundingDinoDetector,
-}
-```
-
-Loaded via `load_detector(name)` which caches instances for performance.
-
-### Detection Result Format
-
-All detectors return a `DetectionResult` namedtuple:
-```python
-DetectionResult(
-    boxes: np.ndarray,        # Nx4 array [x1, y1, x2, y2]
-    scores: Sequence[float],  # Confidence scores
-    labels: Sequence[int],    # Class indices
-    label_names: Optional[Sequence[str]]  # Human-readable names
-)
-```
-
-## File Structure
-
-```
-.
-├── app.py               # FastAPI server with /detect endpoint
-├── inference.py         # Video processing and detection pipeline
-├── demo.html            # Web interface with mode selector
-├── requirements.txt     # Python dependencies
-├── models/
-│   ├── model_loader.py  # Detector registry and loading
-│   └── detectors/
-│       ├── base.py      # ObjectDetector interface
-│       ├── owlv2.py     # OWLv2 implementation
-│       ├── yolov8.py    # YOLOv8 implementation
-│       ├── detr.py      # DETR implementation
-│       └── grounding_dino.py  # Grounding DINO implementation
-├── utils/
-│   └── video.py         # Video encoding/decoding utilities
-└── coco_classes.py      # COCO dataset class definitions
-```
-
-## Adding New Detectors
-
-To add a new detector:
-
-1. **Create detector class** in `models/detectors/`:
-```python
-from .base import ObjectDetector, DetectionResult
-
-class MyDetector(ObjectDetector):
-    name = "my_detector"
-
-    def predict(self, frame, queries):
-        # Your detection logic
-        return DetectionResult(boxes, scores, labels, label_names)
-```
-
-2. **Register in model_loader.py**:
-```python
-_REGISTRY = {
-    ...
-    "my_detector": MyDetector,
-}
-```
-
-3. **Update frontend** `demo.html` detector dropdown:
-```html
-<option value="my_detector">My Detector</option>
-```
-
-## Adding New Detection Modes
-
-To implement additional modes such as drone detection:
-
-1. **Create specialized detector** (if needed):
-   - For segmentation: Extend `SegmentationResult` to include masks
-   - For drone detection: Create `DroneDetector` with specialized filtering
-
-2. **Update `/detect` endpoint** in `app.py`:
-```python
-if mode == "segmentation":
-    # Run segmentation inference
-    # Return video with masks rendered
-```
-
-3. **Update frontend** to remove "disabled" class from mode card
-
-4. **Update inference.py** if needed to handle new output types
-
-## Common Patterns
-
-### Query Processing
-Queries are parsed from comma-separated strings:
-```python
-queries = [q.strip() for q in "person, car, dog".split(",") if q.strip()]
-# Result: ["person", "car", "dog"]
-```
-
-### Frame Processing Loop
-Standard pattern for processing video frames:
-```python
-processed_frames = []
-for idx, frame in enumerate(frames):
-    processed_frame, detections = infer_frame(frame, queries, detector_name)
-    processed_frames.append(processed_frame)
-```
-
-### Temporary File Management
-FastAPI's `BackgroundTasks` cleans up temp files after response:
-```python
-_schedule_cleanup(background_tasks, input_path)
-_schedule_cleanup(background_tasks, output_path)
-```
 
 ## Parallel Execution with Team Mode
 
@@ -227,7 +161,6 @@ When implementing features that touch independent subsystems, **use team mode (p
 ### When to Parallelize
 - Backend (Python) + Frontend (JS) changes — always parallelizable
 - Independent API endpoints or UI components
-- Test writing + implementation when in different files
 - Any 2+ tasks that don't modify the same files
 
 ### How to Parallelize
@@ -235,46 +168,5 @@ When implementing features that touch independent subsystems, **use team mode (p
 2. Dispatch one agent per domain using `isolation: "worktree"`
 3. Each agent works in its own git worktree — no conflicts
 4. Merge results back: `git checkout <worktree-branch> -- <files>`
-5. Clean up worktrees after merge
-
-### Example
-```
-Agent 1 (worktree): Backend — app.py, jobs/storage.py
-Agent 2 (worktree): Frontend — timeline.js, client.js
-→ Both run simultaneously, merge when done
-```
 
 **Default to parallel** when tasks are independent. Sequential only when one task's output is the other's input.
-
-## Performance Notes
-
-- **Detector Caching**: Models are loaded once and cached via `@lru_cache`
-- **Default Resolution**: Videos processed at original resolution
-- **Frame Limit**: Use `max_frames` parameter in `run_inference()` for testing
-- **Memory Usage**: Entire video is loaded into memory (frames list)
-
-## Troubleshooting
-
-### "No module named 'fastapi'"
-Install dependencies: `pip install -r requirements.txt`
-
-### "Video decoding failed"
-Check video codec compatibility. System expects MP4/H.264.
-
-### "Detector not found"
-Verify detector key exists in `model_loader._REGISTRY`
-
-### Slow processing
-- Try faster detector: YOLOv8 (`hf_yolov8`)
-- Reduce video resolution before uploading
-- Use `max_frames` parameter for testing
-
-## Dependencies
-
-Core packages:
-- `fastapi` + `uvicorn`: Web server
-- `torch` + `transformers`: Deep learning models
-- `opencv-python-headless`: Video processing
-- `ultralytics`: YOLOv8 implementation
-- `huggingface-hub`: Model downloading
-- `pillow`, `scipy`, `accelerate`, `timm`: Supporting libraries
 
 ## Project Overview
 
+Multi-GPU video analysis platform with three fully functional modes:
+- **Object Detection**: Bounding boxes via YOLO11, DETR, or Grounding DINO
+- **Segmentation**: Mask overlays via Grounded SAM2 (GSAM2) or YOLO+SAM2 (YSAM2)
+- **Drone Detection**: Aerial object detection via YOLOv8 fine-tuned on VisDrone
+
+Deployed as a HuggingFace Space (Docker SDK) at `https://biaslab2025-isr.hf.space`.
 
 ## Development Commands
 
 ```bash
+# Setup
+python -m venv .venv && source .venv/bin/activate
 pip install -r requirements.txt
+
+# Run dev server
 uvicorn app:app --host 0.0.0.0 --port 7860 --reload
+
+# Verify imports (quick smoke test — no tests exist yet)
+python -c "from app import app"
+
+# Docker
+docker build -t isr . && docker run -p 7860:7860 isr
+
+# Test async detection
+curl -X POST http://localhost:7860/detect/async \
   -F "video=@sample.mp4" \
+  -F "mode=object_detection" \
+  -F "queries=person,car" \
+  -F "detector=yolo11"
 ```
 
+## Core Architecture
+
+### Async Detection Flow (primary path)
+
+```
+Frontend (index.html) → POST /detect/async → background task → MJPEG stream + polling
+```
+
+1. Frontend uploads video + mode + queries to `/detect/async`
+2. Backend creates a `JobInfo`, spawns `process_video_async()` as an `asyncio.Task`
+3. `inference.py` runs multi-GPU parallel inference, publishing frames to an MJPEG stream
+4. Frontend consumes `/detect/stream/{job_id}` for live video, polls `/detect/status/{job_id}`
+5. On completion, frontend fetches the final video from `/detect/video/{job_id}`
+
+### API Endpoints (app.py)
+
+| Method | Path | Purpose |
+|--------|------|---------|
+| POST | `/detect/async` | Start async job (returns `job_id` + stream/status URLs) |
+| GET | `/detect/status/{job_id}` | Poll job status |
+| GET | `/detect/stream/{job_id}` | MJPEG live stream (event-driven, 640px wide) |
+| GET | `/detect/video/{job_id}` | Download processed MP4 |
+| GET | `/detect/depth-video/{job_id}` | Download depth video |
+| GET | `/detect/tracks/{job_id}/summary` | Per-frame detection counts (timeline heatmap) |
+| GET | `/detect/tracks/{job_id}/{frame_idx}` | Per-frame track data |
+| DELETE | `/detect/job/{job_id}` | Cancel running job |
+| POST | `/detect` | Synchronous detection (returns MP4 directly) |
+| POST | `/benchmark` | GSAM2 latency breakdown |
+| POST | `/benchmark/profile` | Per-frame timing breakdown |
+| POST | `/benchmark/analysis` | Full roofline analysis |
+
+**`/detect/async` params:** `video`, `mode` (object_detection/segmentation/drone_detection), `queries`, `detector` (default: `yolo11`), `segmenter` (default: `GSAM2-L`), `enable_depth` (default: false), `step` (default: 7, segmentation keyframe interval).
+
+### Multi-GPU Inference Pipeline (inference.py)
+
+**`run_inference()`** (detection and drone modes):
+- `AsyncVideoReader` prefetches frames into a queue (up to 32 frames)
+- Models loaded in parallel via `ThreadPoolExecutor` (one detector per GPU)
+- Queue-based producer/consumer: main thread feeds `queue_in`, N GPU workers drain it
+- Workers batch frames (up to `max_batch_size=32` for YOLO) under a per-model `RLock`
+- Writer thread reorders frames, runs `ByteTracker` + `SpeedEstimator`, writes via `StreamingVideoWriter`, publishes to the MJPEG stream
+- Cancellation: workers poll `_check_cancellation(job_id)` each cycle
+
+**`run_grounded_sam2_tracking()`** (segmentation mode):
+- Extracts all frames to JPEG files on disk
+- Runs detection on keyframes (every `step` frames) to seed SAM2
+- SAM2 video predictor propagates masks between keyframes
+- ID reconciliation via IoU matching in `MaskDictionary`
+- Renders colored semi-transparent mask overlays with contours
+
+### Jobs System (jobs/)
+
+- **`models.py`** — `JobInfo` dataclass + `JobStatus` enum (PROCESSING/COMPLETED/FAILED/CANCELLED)
+- **`storage.py`** — In-memory `JobStorage` (singleton, `RLock`-protected) + disk at `/tmp/detection_jobs/{job_id}/`. Per-frame track data is stored here. Auto-cleanup every 10 min (1 hr expiry).
+- **`background.py`** — `process_video_async()` coroutine dispatches to the right inference function
+- **`streaming.py`** — MJPEG frame queue + `asyncio.Event` publisher; `publish_frame()` resizes to 640px
+
+### Frontend (frontend/)
+
+Single-page app served at `/app`. No build step. Uses the `window.APP` global namespace.
+
+**Script modules (load order matters):**
+- `init.js` → bootstraps the `window.APP` namespace
+- `core/config.js` → backend URL, tracking constants
+- `core/state.js` → all client state (video, job, tracks, UI)
+- `core/video.js` → video load/unload, blob lifecycle, depth toggle
+- `core/tracker.js` → client-side IoU + velocity tracker
+- `core/timeline.js` → canvas heatmap timeline bar
+- `api/client.js` → `hfDetectAsync()`, `pollAsyncJob()`, `cancelBackendJob()`
+- `ui/overlays.js` → canvas bounding box rendering
+- `ui/cards.js` → live track card panel
+- `ui/logging.js` → system log, status indicators
+- `main.js` → event wiring, app entry point
+
+The frontend infers `mode` from the `data-kind` attribute on the `<select id="detectorSelect">` options.
+
+## Models
+
+### Detectors (models/detectors/)
+
+| Key | Class | Type | Batch | Notes |
+|-----|-------|------|-------|-------|
+| `yolo11` | `Yolo11Detector` | COCO closed-set | Yes (32) | Default. Tiling for large frames. |
+| `detr_resnet50` | `DetrDetector` | COCO closed-set | No | HF transformers pipeline |
+| `grounding_dino` | `GroundingDinoDetector` | Open-vocabulary | No | Text-query grounded detection |
+| `yolov8_visdrone` | `YoloV8VisDroneDetector` | VisDrone aerial | Yes (32) | `ensure_weights()` for safe parallel init |
+
+All implement `ObjectDetector.predict(frame, queries) → DetectionResult(boxes, scores, labels, label_names)`.
+
+Registered in `models/model_loader.py`. Cached via `@lru_cache` for single-GPU; `load_detector_on_device(name, device)` for multi-GPU (uncached). Call `prefetch_weights(name)` before parallel GPU init to avoid download race conditions.
+
+### Segmenters (models/segmenters/)
+
+| Key | Detector | SAM2 Size |
+|-----|----------|-----------|
+| `GSAM2-S/B/L` | Grounding DINO | small/base/large |
+| `YSAM2-S/B/L` | YOLO11 | small/base/large |
+
+Default: `GSAM2-L`. Registered in `models/segmenters/model_loader.py`.
+
+### Depth Estimators (models/depth_estimators/)
+
+Single entry: key `depth` → `DepthAnythingV2Estimator`. Optional, enabled via `enable_depth=True`.
+
+## Adding New Detectors
+
+1. Create a class in `models/detectors/` implementing `ObjectDetector.predict()` → `DetectionResult`
+2. If weights need downloading, add an `ensure_weights()` classmethod for thread-safe prefetch
+3. Register in the `models/model_loader.py` `_REGISTRY`
+4. Add an `<option>` to `frontend/index.html` `#detectorSelect` with the appropriate `data-kind`
+
+## Key Patterns
+
+- **Weight downloads**: Use the `ensure_weights()` classmethod + `prefetch_weights()` in `inference.py` before `ThreadPoolExecutor` to avoid race conditions (see `yolov8_visdrone.py`)
+- **Per-model locking**: Each detector/depth instance gets a `threading.RLock` for thread-safe `predict()` calls in multi-GPU workers
+- **Frame reordering**: The writer thread uses a reorder buffer (128 frames) since GPU workers finish out of order
+- **MJPEG streaming**: `publish_frame()` drops frames if the queue is full (backpressure); the consumer is event-driven at ~30 fps
+- **Job file layout**: `/tmp/detection_jobs/{job_id}/` → `input.mp4`, `output.mp4`, `depth.mp4`
 
 ## Parallel Execution with Team Mode
 
 ### When to Parallelize
 - Backend (Python) + Frontend (JS) changes — always parallelizable
 - Independent API endpoints or UI components
 - Any 2+ tasks that don't modify the same files
 
 ### How to Parallelize
 2. Dispatch one agent per domain using `isolation: "worktree"`
 3. Each agent works in its own git worktree — no conflicts
 4. Merge results back: `git checkout <worktree-branch> -- <files>`
 
 **Default to parallel** when tasks are independent. Sequential only when one task's output is the other's input.
 
inference.py CHANGED
@@ -646,7 +646,10 @@ def run_inference(
 
     if num_gpus > 0:
         logging.info("Detected %d GPUs. Loading models in parallel...", num_gpus)
-
+        # Pre-download weights before parallel GPU init to avoid race conditions
+        from models.model_loader import prefetch_weights
+        prefetch_weights(active_detector)
+
         def load_models_on_gpu(gpu_id: int):
             device_str = f"cuda:{gpu_id}"
             try:
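
The queue-based producer/consumer pipeline that the rewritten CLAUDE.md attributes to `run_inference()` (main thread feeds a queue, GPU workers drain it, a writer reassembles out-of-order results) can be sketched in miniature. This is an illustration only; the worker count and the doubling stand-in for `detector.predict()` are not the project's code:

```python
import queue
import threading

def run_pipeline(frames, num_workers=3):
    """Feed frames to workers via a queue, then reassemble in frame order."""
    q_in, q_out = queue.Queue(), queue.Queue()

    def worker():
        while True:
            item = q_in.get()
            if item is None:  # poison pill: shut this worker down
                break
            idx, frame = item
            q_out.put((idx, frame * 2))  # stand-in for detector.predict()

    threads = [threading.Thread(target=worker) for _ in range(num_workers)]
    for t in threads:
        t.start()
    for item in enumerate(frames):  # producer: main thread feeds the queue
        q_in.put(item)
    for _ in threads:
        q_in.put(None)
    for t in threads:
        t.join()

    # Writer side: workers finish out of order, so reorder by frame index
    # (a real reorder buffer would flush incrementally with a bounded window).
    results = dict(q_out.get() for _ in range(len(frames)))
    return [results[i] for i in range(len(frames))]

print(run_pipeline([1, 2, 3, 4, 5]))  # → [2, 4, 6, 8, 10]
```

Tagging every queue item with its frame index is what lets the writer restore order regardless of which worker finishes first.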
models/detectors/yolov8_visdrone.py CHANGED
@@ -23,6 +23,17 @@ class YoloV8VisDroneDetector(ObjectDetector):
     supports_batch = True
     max_batch_size = 32
 
+    @classmethod
+    def ensure_weights(cls):
+        """Download weights once (call before parallel GPU init)."""
+        if not _VISDRONE_PATH.exists():
+            logging.info("Downloading visDrone.pt to %s ...", _VISDRONE_PATH)
+            hf_hub_download(
+                repo_id=cls.REPO_ID,
+                filename="visDrone.pt",
+                local_dir=str(_WEIGHTS_CACHE),
+            )
+
     def __init__(self, score_threshold: float = 0.3, device: str = None) -> None:
         self.name = "yolov8_visdrone"
         self.score_threshold = score_threshold
@@ -31,17 +42,9 @@ class YoloV8VisDroneDetector(ObjectDetector):
         else:
             self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         logging.info(
-            "Loading YOLOv8-VisDrone from HuggingFace Hub: %s onto %s",
-            self.REPO_ID,
-            self.device,
+            "Loading YOLOv8-VisDrone onto %s", self.device,
         )
-        if not _VISDRONE_PATH.exists():
-            logging.info("Downloading visDrone.pt to %s ...", _VISDRONE_PATH)
-            hf_hub_download(
-                repo_id=self.REPO_ID,
-                filename="visDrone.pt",
-                local_dir=str(_WEIGHTS_CACHE),
-            )
+        self.ensure_weights()
         self.model = YOLO(str(_VISDRONE_PATH))
         self.model.to(self.device)
         self.class_names = self.model.names
models/model_loader.py CHANGED
@@ -39,6 +39,13 @@ def load_detector(name: Optional[str] = None) -> ObjectDetector:
     return _get_cached_detector(detector_name)
 
 
+def prefetch_weights(name: str) -> None:
+    """Pre-download model weights (call before parallel GPU init)."""
+    factory = _REGISTRY.get(name)
+    if factory and hasattr(factory, "ensure_weights"):
+        factory.ensure_weights()
+
+
 def load_detector_on_device(name: str, device: str) -> ObjectDetector:
     """Create a new detector instance on the specified device (no caching)."""
     return _create_detector(name, device=device)
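
`prefetch_weights()` dispatches by duck typing: only registry factories that expose an `ensure_weights` hook get prefetched, while unknown keys and hook-less detectors are silent no-ops. A self-contained sketch of that mechanism (the two stub classes are illustrative stand-ins, not the real detectors):

```python
class VisDroneStub:
    downloaded = False

    @classmethod
    def ensure_weights(cls):
        cls.downloaded = True  # stand-in for hf_hub_download

class DetrStub:
    pass  # weights resolved at load time; no prefetch hook

_REGISTRY = {"yolov8_visdrone": VisDroneStub, "detr_resnet50": DetrStub}

def prefetch_weights(name: str) -> None:
    """Pre-download model weights (call before parallel GPU init)."""
    factory = _REGISTRY.get(name)
    if factory and hasattr(factory, "ensure_weights"):
        factory.ensure_weights()

prefetch_weights("yolov8_visdrone")  # triggers the download hook
prefetch_weights("detr_resnet50")    # no hook: silently a no-op
prefetch_weights("unknown_key")      # missing key: also a no-op
print(VisDroneStub.downloaded)  # → True
```

The `hasattr` check keeps the hook optional, so detectors whose weights ship with their libraries need no changes.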