divinixx committed on
Commit
139033f
·
1 Parent(s): fab6ee2

Skip MiDaS depth model on CPU for faster performance

Browse files
Files changed (2) hide show
  1. app/models/loader.py +22 -5
  2. app/services/pipeline.py +26 -12
app/models/loader.py CHANGED
@@ -3,11 +3,15 @@ Model loader — loads BLIP-2, YOLOv8, and MiDaS at startup.
3
  """
4
 
5
  import logging
 
6
  import torch
7
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
8
  from ultralytics import YOLO
9
  from pathlib import Path
10
 
 
 
 
11
  log = logging.getLogger(__name__)
12
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
13
 
@@ -43,11 +47,24 @@ def load_models() -> dict:
43
  yolo_model.to(device)
44
  log.info("YOLOv8 loaded.")
45
 
46
- log.info("Loading MiDaS (small)...")
47
- midas_model = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
48
- midas_model.to(device)
49
- midas_model.eval()
50
- log.info("MiDaS loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  log.info("All models ready.")
53
  return {
 
3
  """
4
 
5
  import logging
6
+ import os
7
  import torch
8
  from transformers import Blip2Processor, Blip2ForConditionalGeneration
9
  from ultralytics import YOLO
10
  from pathlib import Path
11
 
12
+ # Redirect the torch hub cache to a writable location (needed to download MiDaS weights)
13
+ os.environ['TORCH_HOME'] = '/tmp/torch_cache'
14
+
15
  log = logging.getLogger(__name__)
16
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
17
 
 
47
  yolo_model.to(device)
48
  log.info("YOLOv8 loaded.")
49
 
50
+ # Skip depth estimation on CPU for faster performance
51
+ # Depth is used for prioritization but not critical for navigation
52
+ if device.type == "cpu":
53
+ log.info("Skipping MiDaS on CPU for better performance...")
54
+ midas_model = None
55
+ else:
56
+ log.info("Loading MiDaS (small)...")
57
+ torch.hub.set_dir('/tmp/torch_cache')
58
+ midas_model = torch.hub.load(
59
+ "intel-isl/MiDaS",
60
+ "MiDaS_small",
61
+ trust_repo=True,
62
+ force_reload=False,
63
+ skip_validation=True
64
+ )
65
+ midas_model.to(device)
66
+ midas_model.eval()
67
+ log.info("MiDaS loaded.")
68
 
69
  log.info("All models ready.")
70
  return {
app/services/pipeline.py CHANGED
@@ -108,22 +108,30 @@ def run_pipeline_frame_data(frame: np.ndarray, models: Dict) -> Dict:
108
  blip_model = models.get("blip2_model")
109
  blip_processor = models.get("blip2_processor")
110
  yolo_model = models.get("yolo_model")
111
- midas_model = models.get("midas_model")
112
  device = models.get("device")
113
 
114
- if any(v is None for v in [blip_model, blip_processor, yolo_model, midas_model, device]):
115
- missing = [
116
- k for k, v in [
117
- ("blip2_model", blip_model), ("blip2_processor", blip_processor),
118
- ("yolo_model", yolo_model), ("midas_model", midas_model), ("device", device),
119
- ] if v is None
120
- ]
121
- raise RuntimeError(f"Models not fully loaded. Missing: {missing}")
 
 
122
 
123
  # --- Run each model stage ---
124
  desc, caption_latency = description(frame, blip_model, blip_processor, device)
125
  dets, detection_latency = detect_objects(frame, yolo_model)
126
- depth_map, depth_latency = estimate_depth(frame, midas_model, device)
 
 
 
 
 
 
127
 
128
  main_feature = extract_main_feature(dets)
129
  unique_objects = _unique_ordered(dets["class_names"])
@@ -133,8 +141,14 @@ def run_pipeline_frame_data(frame: np.ndarray, models: Dict) -> Dict:
133
  detection_details = []
134
  for box, label, conf in zip(dets["boxes"], dets["class_names"], dets["confidences"]):
135
  x1, y1, x2, y2 = box
136
- region = depth_map[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
137
- median_depth = float(np.median(region)) if region.size > 0 else None
 
 
 
 
 
 
138
  detection_details.append({
139
  "label": label,
140
  "confidence": round(float(conf), 3),
 
108
  blip_model = models.get("blip2_model")
109
  blip_processor = models.get("blip2_processor")
110
  yolo_model = models.get("yolo_model")
111
+ midas_model = models.get("midas_model") # Can be None on CPU
112
  device = models.get("device")
113
 
114
+ # Check required models (MiDaS is optional)
115
+ required_models = [
116
+ ("blip2_model", blip_model),
117
+ ("blip2_processor", blip_processor),
118
+ ("yolo_model", yolo_model),
119
+ ("device", device),
120
+ ]
121
+ missing = [k for k, v in required_models if v is None]
122
+ if missing:
123
+ raise RuntimeError(f"Required models not loaded. Missing: {missing}")
124
 
125
  # --- Run each model stage ---
126
  desc, caption_latency = description(frame, blip_model, blip_processor, device)
127
  dets, detection_latency = detect_objects(frame, yolo_model)
128
+
129
+ # Skip depth estimation if MiDaS not loaded (CPU optimization)
130
+ if midas_model is not None:
131
+ depth_map, depth_latency = estimate_depth(frame, midas_model, device)
132
+ else:
133
+ depth_map = None
134
+ depth_latency = 0.0
135
 
136
  main_feature = extract_main_feature(dets)
137
  unique_objects = _unique_ordered(dets["class_names"])
 
141
  detection_details = []
142
  for box, label, conf in zip(dets["boxes"], dets["class_names"], dets["confidences"]):
143
  x1, y1, x2, y2 = box
144
+
145
+ # Calculate depth only if depth_map available
146
+ if depth_map is not None:
147
+ region = depth_map[max(0, y1):min(h, y2), max(0, x1):min(w, x2)]
148
+ median_depth = float(np.median(region)) if region.size > 0 else None
149
+ else:
150
+ median_depth = None
151
+
152
  detection_details.append({
153
  "label": label,
154
  "confidence": round(float(conf), 3),