Spaces:
Paused
Paused
maybe finally fix for weights
Browse files- perception_roi_server.py +36 -14
perception_roi_server.py
CHANGED
|
@@ -59,26 +59,48 @@ def root():
|
|
| 59 |
_model_lock = threading.Lock()
|
| 60 |
_models: Dict[str, YOLO] = {}
|
| 61 |
|
| 62 |
-
def _resolve_weights_path(weights: str) -> str:
|
| 63 |
if not weights:
|
| 64 |
-
return DEFAULT_WEIGHTS
|
| 65 |
-
w = str(weights)
|
| 66 |
-
if
|
| 67 |
-
return
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
if WEIGHTS_DIR:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
cand = os.path.join(base, w)
|
|
|
|
| 76 |
if os.path.exists(cand):
|
| 77 |
-
return cand
|
| 78 |
-
return w
|
| 79 |
|
| 80 |
def get_model(weights: str) -> YOLO:
|
| 81 |
-
key = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
with _model_lock:
|
| 83 |
if key not in _models:
|
| 84 |
_models[key] = YOLO(key)
|
|
|
|
| 59 |
_model_lock = threading.Lock()
|
| 60 |
_models: Dict[str, YOLO] = {}
|
| 61 |
|
| 62 |
+
def _resolve_weights_path(weights: str) -> (str, List[str]):
|
| 63 |
if not weights:
|
| 64 |
+
return DEFAULT_WEIGHTS, []
|
| 65 |
+
w = str(weights).strip()
|
| 66 |
+
if not w:
|
| 67 |
+
return DEFAULT_WEIGHTS, []
|
| 68 |
+
if os.path.isabs(w) and os.path.exists(w):
|
| 69 |
+
return os.path.abspath(w), [os.path.abspath(w)]
|
| 70 |
+
if os.path.exists(w):
|
| 71 |
+
return os.path.abspath(w), [os.path.abspath(w)]
|
| 72 |
+
search_dirs: List[str] = []
|
| 73 |
if WEIGHTS_DIR:
|
| 74 |
+
search_dirs.append(WEIGHTS_DIR)
|
| 75 |
+
search_dirs.extend([
|
| 76 |
+
os.getcwd(),
|
| 77 |
+
os.path.dirname(__file__),
|
| 78 |
+
os.path.abspath(os.path.dirname(__file__)),
|
| 79 |
+
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
|
| 80 |
+
DATA_DIR,
|
| 81 |
+
"/home/user/app",
|
| 82 |
+
"/app",
|
| 83 |
+
"/workspace",
|
| 84 |
+
"/data",
|
| 85 |
+
])
|
| 86 |
+
checked: List[str] = []
|
| 87 |
+
for base in search_dirs:
|
| 88 |
+
if not base:
|
| 89 |
+
continue
|
| 90 |
cand = os.path.join(base, w)
|
| 91 |
+
checked.append(cand)
|
| 92 |
if os.path.exists(cand):
|
| 93 |
+
return os.path.abspath(cand), checked
|
| 94 |
+
return w, checked
|
| 95 |
|
| 96 |
def get_model(weights: str) -> YOLO:
|
| 97 |
+
key, checked = _resolve_weights_path(weights or DEFAULT_WEIGHTS)
|
| 98 |
+
if str(key).endswith(".pt") and not os.path.exists(key):
|
| 99 |
+
search_list = ", ".join(checked) if checked else "(no local paths searched)"
|
| 100 |
+
raise RuntimeError(
|
| 101 |
+
f"Weights not found locally: {weights}. Searched: {search_list}. "
|
| 102 |
+
f"Set WEIGHTS_DIR or upload the weights to the app directory."
|
| 103 |
+
)
|
| 104 |
with _model_lock:
|
| 105 |
if key not in _models:
|
| 106 |
_models[key] = YOLO(key)
|