# Tan Zi Xu
# detection implementation
# 0aae217
# core/model_loader.py
import os, json, pathlib, tempfile, torch
from pathlib import Path
from typing import List, Tuple, Optional
from cloud.storage_s3 import S3Store
import streamlit as st
from infer import load_predict_fn
from core.detect_infer import YOLODetector
# Keep all transient artifacts under the system temp directory so the app also
# works on hosts with a read-only or ephemeral home directory.
TMP = tempfile.gettempdir()
# Redirect torch.hub's checkpoint cache away from the default ~/.cache location,
# which may be unwritable in containerized/Streamlit-cloud deployments.
torch.hub.set_dir(f"{TMP}/.cache/torch/hub")
def setup_page(name: str, tagline: str):
    """Configure the Streamlit page: title, wide layout, and width override."""
    # NOTE: set_page_config must be the first Streamlit call on the page.
    st.set_page_config(page_title=name, layout="wide")
    wide_css = "<style>.block-container{max-width:1800px;padding:1rem}</style>"
    st.markdown(wide_css, unsafe_allow_html=True)
    st.title(name)
    st.caption(tagline)
def list_model_dirs(root: Path) -> List[str]:
    """Return the sorted names of subdirectories of *root* that look like
    model bundles, i.e. contain both model_config.json and class_map.json.

    An empty list is returned when *root* does not exist.
    """
    if not root.exists():
        return []
    return [
        entry.name
        for entry in sorted(root.iterdir())
        if entry.is_dir()
        and (entry / "model_config.json").exists()
        and (entry / "class_map.json").exists()
    ]
@st.cache_resource(show_spinner=False)
def _load(cfg, ckpt, clsf, model_name, models_root):
    """Cached wrapper around load_predict_fn.

    Path-like arguments are stringified so Streamlit's resource cache can hash
    them consistently across reruns.
    """
    return load_predict_fn(
        str(cfg),
        str(ckpt),
        str(clsf),
        model_name=model_name,
        models_root=str(models_root),
    )
def load_model_cached(model_name: str, models_root: Path):
    """Load a classifier bundle from <models_root>/<model_name>.

    Expects the conventional bundle layout:
      model_config.json, class_map.json, weights/best.pth

    Returns a 9-tuple:
      (predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH,
       model_image_size, model_dir, cfg_path, ckpt_path, clsf_path)
    """
    model_dir = models_root / model_name
    cfg = model_dir / "model_config.json"
    ckpt = model_dir / "weights" / "best.pth"
    clsf = model_dir / "class_map.json"
    st.write(f"• Reading model folder: {model_dir}")
    predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH = _load(cfg, ckpt, clsf, model_name, models_root)
    with open(cfg, "r", encoding="utf-8") as f:
        model_config = json.load(f)
    # Tolerate a config without a "preprocess" section instead of raising
    # KeyError; fall back to the 224px default in either case.
    model_image_size = int(model_config.get("preprocess", {}).get("image_size", 224))
    return (predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH,
            model_image_size, model_dir, str(cfg), str(ckpt), str(clsf))
# ========= Detector resolution & cache =========
def _read_manifest_dict(models_root: Path) -> dict:
    """Parse <models_root>/manifest.json into a dict.

    Returns {} when the file is absent or unparseable — callers treat a
    corrupt manifest the same as a missing one.
    """
    manifest_path = models_root / "manifest.json"
    if manifest_path.exists():
        try:
            return json.loads(manifest_path.read_text(encoding="utf-8"))
        except Exception:
            return {}
    return {}
def _download_s3_to_tmp(s3_uri: str, tmp_subdir: str = "yolo") -> str:
    """Download an s3://bucket/key object into <tmpdir>/models/<tmp_subdir>/
    and return the local file path as a string.

    Raises ValueError when *s3_uri* does not use the s3:// scheme.
    """
    # Validate with a real exception: `assert` is stripped under `python -O`,
    # which would let malformed URIs slip through to S3Store.
    if not s3_uri.startswith("s3://"):
        raise ValueError("s3_uri must start with s3://")
    bucket_and_key = s3_uri[len("s3://"):]
    bucket, _, key = bucket_and_key.partition("/")
    store = S3Store(bucket=bucket, prefix="")
    data = store.read_bytes(key)
    out_dir = Path(tempfile.gettempdir()) / "models" / tmp_subdir
    out_dir.mkdir(parents=True, exist_ok=True)
    # Keys ending in "/" yield an empty basename; fall back to a default name.
    fname = key.split("/")[-1] or "best.pt"
    out_path = out_dir / fname
    out_path.write_bytes(data)
    return str(out_path)
def resolve_detector_weights(models_root: Path,
                             model_name: Optional[str] = None) -> Tuple[str, str]:
    """
    Resolve the detector weights file for *model_name*.

    Returns (weights_path, source_str) where source_str is one of:
      {'env','manifest-hf','manifest-s3','manifest-local','local','s3'}
    Precedence: DET_WEIGHTS env -> manifest.detectors[model_name]
    (hf -> s3 -> local) -> local convention
    <models_root>/<model_name>/weights/best.pt.

    Raises FileNotFoundError when no source yields a weights file, or when the
    manifest requests S3 but AWS_S3_BUCKET is not configured.
    """
    # 1) explicit env override (s3 URI or local path)
    env_w = os.getenv("DET_WEIGHTS", "").strip()
    if env_w:
        if env_w.startswith("s3://"):
            return _download_s3_to_tmp(env_w), "s3"
        return env_w, "env"
    # 2) manifest-driven resolution
    if model_name:
        man = _read_manifest_dict(models_root)
        det = (man.get("detectors") or {}).get(model_name)
        if det:
            # 2a) Hugging Face Hub (optional dependency; fall through on failure)
            hf = det.get("hf")
            if isinstance(hf, dict) and hf.get("repo_id") and hf.get("filename"):
                try:
                    from huggingface_hub import hf_hub_download
                    token = os.getenv("HF_TOKEN")
                    wpath = hf_hub_download(hf["repo_id"], hf["filename"], token=token)
                    return wpath, "manifest-hf"
                except Exception:
                    pass  # best-effort: try S3 / local next
            # 2b) S3 key relative to AWS_S3_BUCKET / AWS_S3_PREFIX
            if det.get("s3"):
                bkt = os.getenv("AWS_S3_BUCKET", "")
                if not bkt:
                    raise FileNotFoundError(
                        "manifest requests an S3 detector but AWS_S3_BUCKET is not set")
                pfx = os.getenv("AWS_S3_PREFIX", "").strip("/")
                key = det["s3"].lstrip("/")
                # Join only the non-empty parts so the URI never contains '//'
                # (the previous .replace('//','/') hack corrupted keys).
                s3_uri = "s3://" + "/".join(part for part in (bkt, pfx, key) if part)
                return _download_s3_to_tmp(s3_uri, tmp_subdir=model_name), "manifest-s3"
            # 2c) path relative to models_root
            if det.get("local"):
                return str((models_root / det["local"]).resolve()), "manifest-local"
    # 3) fallback local convention
    if model_name:
        guess = models_root / model_name / "weights" / "best.pt"
        if guess.exists():
            return str(guess), "local"
    raise FileNotFoundError("Detector weights not found. Set DET_WEIGHTS or add to manifest.json.")
@st.cache_resource(show_spinner=False)
def get_detector_cached(models_root: str, model_name: str):
    """
    Build (or fetch from Streamlit's resource cache) the YOLO detector.

    Returns (detector, meta_dict) where meta_dict has the keys
    {'name','path','source','device'}; cached across Streamlit reruns.
    """
    weights_path, source = resolve_detector_weights(Path(models_root), model_name)
    detector = YOLODetector(weights_path)  # auto-selects cuda/cpu
    meta = {
        "name": model_name,
        "path": weights_path,
        "source": source,
        # YOLODetector may not expose .device; report "auto" in that case.
        "device": getattr(detector, "device", "auto"),
    }
    return detector, meta