Spaces:
Paused
Paused
RT-DETR compatibility
Browse files- perception_roi_server.py +18 -7
perception_roi_server.py
CHANGED
|
@@ -31,7 +31,7 @@ import numpy as np
|
|
| 31 |
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 32 |
from fastapi.middleware.cors import CORSMiddleware
|
| 33 |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
|
| 34 |
-
from ultralytics import YOLO
|
| 35 |
|
| 36 |
DEFAULT_WEIGHTS = os.environ.get("YOLO_WEIGHTS", "yolov8s.pt")
|
| 37 |
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", os.path.dirname(__file__))
|
|
@@ -57,7 +57,13 @@ def root():
|
|
| 57 |
return {"status": "ok", "service": "roi-compression"}
|
| 58 |
|
| 59 |
_model_lock = threading.Lock()
|
| 60 |
-
_models: Dict[str,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
def _resolve_weights_path(weights: str) -> (str, List[str]):
|
| 63 |
if not weights:
|
|
@@ -93,8 +99,9 @@ def _resolve_weights_path(weights: str) -> (str, List[str]):
|
|
| 93 |
return os.path.abspath(cand), checked
|
| 94 |
return w, checked
|
| 95 |
|
| 96 |
-
def get_model(weights: str) ->
|
| 97 |
key, checked = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
|
|
|
|
| 98 |
if str(key).endswith(".pt") and not os.path.exists(key):
|
| 99 |
search_list = ", ".join(checked) if checked else "(no local paths searched)"
|
| 100 |
raise RuntimeError(
|
|
@@ -102,9 +109,13 @@ def get_model(weights: str) -> YOLO:
|
|
| 102 |
f"Set WEIGHTS_DIR or upload the weights to the app directory."
|
| 103 |
)
|
| 104 |
with _model_lock:
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def _parse_queries(q: str) -> List[str]:
|
|
@@ -124,7 +135,7 @@ def _keep_det(label: str, queries: List[str]) -> bool:
|
|
| 124 |
|
| 125 |
|
| 126 |
def _yolo_detect_frame(
|
| 127 |
-
model:
|
| 128 |
frame_bgr: np.ndarray,
|
| 129 |
conf: float,
|
| 130 |
queries: List[str],
|
|
|
|
| 31 |
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
| 32 |
from fastapi.middleware.cors import CORSMiddleware
|
| 33 |
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
|
| 34 |
+
from ultralytics import YOLO, RTDETR
|
| 35 |
|
| 36 |
DEFAULT_WEIGHTS = os.environ.get("YOLO_WEIGHTS", "yolov8s.pt")
|
| 37 |
WEIGHTS_DIR = os.environ.get("WEIGHTS_DIR", os.path.dirname(__file__))
|
|
|
|
| 57 |
return {"status": "ok", "service": "roi-compression"}
|
| 58 |
|
| 59 |
_model_lock = threading.Lock()
|
| 60 |
+
_models: Dict[str, Any] = {}
|
| 61 |
+
|
| 62 |
+
def _infer_model_type(weights: str) -> str:
|
| 63 |
+
name = os.path.basename(str(weights or "")).lower()
|
| 64 |
+
if name.startswith("rtdetr"):
|
| 65 |
+
return "rtdetr"
|
| 66 |
+
return "yolo"
|
| 67 |
|
| 68 |
def _resolve_weights_path(weights: str) -> (str, List[str]):
|
| 69 |
if not weights:
|
|
|
|
| 99 |
return os.path.abspath(cand), checked
|
| 100 |
return w, checked
|
| 101 |
|
| 102 |
+
def get_model(weights: str) -> Any:
|
| 103 |
key, checked = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
|
| 104 |
+
model_type = _infer_model_type(key)
|
| 105 |
if str(key).endswith(".pt") and not os.path.exists(key):
|
| 106 |
search_list = ", ".join(checked) if checked else "(no local paths searched)"
|
| 107 |
raise RuntimeError(
|
|
|
|
| 109 |
f"Set WEIGHTS_DIR or upload the weights to the app directory."
|
| 110 |
)
|
| 111 |
with _model_lock:
|
| 112 |
+
cache_key = f"{model_type}:{key}"
|
| 113 |
+
if cache_key not in _models:
|
| 114 |
+
if model_type == "rtdetr":
|
| 115 |
+
_models[cache_key] = RTDETR(key)
|
| 116 |
+
else:
|
| 117 |
+
_models[cache_key] = YOLO(key)
|
| 118 |
+
return _models[cache_key]
|
| 119 |
|
| 120 |
|
| 121 |
def _parse_queries(q: str) -> List[str]:
|
|
|
|
| 135 |
|
| 136 |
|
| 137 |
def _yolo_detect_frame(
|
| 138 |
+
model: Any,
|
| 139 |
frame_bgr: np.ndarray,
|
| 140 |
conf: float,
|
| 141 |
queries: List[str],
|