AliZak commited on
Commit
8d06707
·
verified ·
1 Parent(s): 81b959f

RT-DETR compatibility

Browse files
Files changed (1) hide show
  1. perception_roi_server.py +18 -7
perception_roi_server.py CHANGED
@@ -31,7 +31,7 @@ import numpy as np
31
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
32
  from fastapi.middleware.cors import CORSMiddleware
33
  from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
34
- from ultralytics import YOLO
35
 
36
  DEFAULT_WEIGHTS = os.environ.get("YOLO_WEIGHTS", "yolov8s.pt")
37
  WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", os.path.dirname(__file__))
@@ -57,7 +57,13 @@ def root():
57
  return {"status": "ok", "service": "roi-compression"}
58
 
59
  _model_lock = threading.Lock()
60
- _models: Dict[str, YOLO] = {}
 
 
 
 
 
 
61
 
62
  def _resolve_weights_path(weights: str) -> (str, List[str]):
63
  if not weights:
@@ -93,8 +99,9 @@ def _resolve_weights_path(weights: str) -> (str, List[str]):
93
  return os.path.abspath(cand), checked
94
  return w, checked
95
 
96
- def get_model(weights: str) -> YOLO:
97
  key, checked = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
 
98
  if str(key).endswith(".pt") and not os.path.exists(key):
99
  search_list = ", ".join(checked) if checked else "(no local paths searched)"
100
  raise RuntimeError(
@@ -102,9 +109,13 @@ def get_model(weights: str) -> YOLO:
102
  f"Set WEIGHTS_DIR or upload the weights to the app directory."
103
  )
104
  with _model_lock:
105
- if key not in _models:
106
- _models[key] = YOLO(key)
107
- return _models[key]
 
 
 
 
108
 
109
 
110
  def _parse_queries(q: str) -> List[str]:
@@ -124,7 +135,7 @@ def _keep_det(label: str, queries: List[str]) -> bool:
124
 
125
 
126
  def _yolo_detect_frame(
127
- model: YOLO,
128
  frame_bgr: np.ndarray,
129
  conf: float,
130
  queries: List[str],
 
31
  from fastapi import FastAPI, UploadFile, File, Form, HTTPException
32
  from fastapi.middleware.cors import CORSMiddleware
33
  from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
34
+ from ultralytics import YOLO, RTDETR
35
 
36
  DEFAULT_WEIGHTS = os.environ.get("YOLO_WEIGHTS", "yolov8s.pt")
37
  WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", os.path.dirname(__file__))
 
57
  return {"status": "ok", "service": "roi-compression"}
58
 
59
  _model_lock = threading.Lock()
60
+ _models: Dict[str, Any] = {}
61
+
62
+ def _infer_model_type(weights: str) -> str:
63
+ name = os.path.basename(str(weights or "")).lower()
64
+ if name.startswith("rtdetr"):
65
+ return "rtdetr"
66
+ return "yolo"
67
 
68
  def _resolve_weights_path(weights: str) -> (str, List[str]):
69
  if not weights:
 
99
  return os.path.abspath(cand), checked
100
  return w, checked
101
 
102
+ def get_model(weights: str) -> Any:
103
  key, checked = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
104
+ model_type = _infer_model_type(key)
105
  if str(key).endswith(".pt") and not os.path.exists(key):
106
  search_list = ", ".join(checked) if checked else "(no local paths searched)"
107
  raise RuntimeError(
 
109
  f"Set WEIGHTS_DIR or upload the weights to the app directory."
110
  )
111
  with _model_lock:
112
+ cache_key = f"{model_type}:{key}"
113
+ if cache_key not in _models:
114
+ if model_type == "rtdetr":
115
+ _models[cache_key] = RTDETR(key)
116
+ else:
117
+ _models[cache_key] = YOLO(key)
118
+ return _models[cache_key]
119
 
120
 
121
  def _parse_queries(q: str) -> List[str]:
 
135
 
136
 
137
  def _yolo_detect_frame(
138
+ model: Any,
139
  frame_bgr: np.ndarray,
140
  conf: float,
141
  queries: List[str],