| |
| import os, json, pathlib, tempfile, torch |
| from pathlib import Path |
| from typing import List, Tuple, Optional |
| from cloud.storage_s3 import S3Store |
| import streamlit as st |
| from infer import load_predict_fn |
| from core.detect_infer import YOLODetector |
|
|
# Point the torch.hub cache at a folder underneath the system temp directory.
TMP = tempfile.gettempdir()
torch.hub.set_dir(TMP + "/.cache/torch/hub")
|
|
def setup_page(name: str, tagline: str):
    """Configure the Streamlit page and render its header.

    Sets ``name`` as the page title with a wide layout, injects a CSS rule
    for ``.block-container``, then shows ``name`` as the title and
    ``tagline`` as the caption.
    """
    container_css = "<style>.block-container{max-width:1800px;padding:1rem}</style>"
    st.set_page_config(layout="wide", page_title=name)
    st.markdown(container_css, unsafe_allow_html=True)
    st.title(name)
    st.caption(tagline)
|
|
def list_model_dirs(root: Path) -> List[str]:
    """Return the names of the model folders found directly under ``root``.

    A sub-directory counts as a model folder only when it contains both
    ``model_config.json`` and ``class_map.json``.

    Args:
        root: Directory to scan.

    Returns:
        Folder names in sorted path order; an empty list when ``root`` is
        missing or is not a directory.
    """
    # is_dir() also covers "exists but is a regular file", a case in which
    # iterdir() would raise NotADirectoryError.
    if not root.is_dir():
        return []
    required = ("model_config.json", "class_map.json")
    return [
        entry.name
        for entry in sorted(root.iterdir())
        if entry.is_dir() and all((entry / fname).exists() for fname in required)
    ]
|
|
@st.cache_resource(show_spinner=False)
def _load(cfg, ckpt, clsf, model_name, models_root):
    """Call ``load_predict_fn`` behind ``st.cache_resource``.

    The path-like arguments are converted to strings before being handed to
    ``load_predict_fn``; its return value is passed through unchanged.
    """
    config_path, weights_path, class_map_path = str(cfg), str(ckpt), str(clsf)
    return load_predict_fn(
        config_path,
        weights_path,
        class_map_path,
        model_name=model_name,
        models_root=str(models_root),
    )
|
|
def load_model_cached(model_name: str, models_root: Path):
    """Load a model through the cached loader and read its input image size.

    Args:
        model_name: Name of the model folder under ``models_root``.
        models_root: Directory that contains the model folders.

    Returns:
        A 9-tuple ``(predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH,
        model_image_size, model_dir, cfg_path, ckpt_path, class_map_path)``.
        ``model_dir`` is a ``Path``; the last three entries are strings.
    """
    model_dir = models_root / model_name
    cfg = model_dir / "model_config.json"
    ckpt = model_dir / "weights" / "best.pth"
    clsf = model_dir / "class_map.json"

    st.write(f"• Reading model folder: {model_dir}")
    predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH = _load(cfg, ckpt, clsf, model_name, models_root)

    with open(cfg, "r", encoding="utf-8") as f:
        model_config = json.load(f)
    # image_size already falls back to 224 when absent; treat a missing or
    # null "preprocess" section the same way instead of raising KeyError.
    preprocess = model_config.get("preprocess") or {}
    model_image_size = int(preprocess.get("image_size", 224))

    return (predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH,
            model_image_size, model_dir, str(cfg), str(ckpt), str(clsf))
|
|
| |
| def _read_manifest_dict(models_root: Path) -> dict: |
| m = models_root / "manifest.json" |
| if not m.exists(): |
| return {} |
| try: |
| return json.loads(m.read_text(encoding="utf-8")) |
| except Exception: |
| return {} |
|
|
| def _download_s3_to_tmp(s3_uri: str, tmp_subdir: str = "yolo") -> str: |
| assert s3_uri.startswith("s3://"), "s3_uri must start with s3://" |
| bucket_and_key = s3_uri[len("s3://"):] |
| bucket, _, key = bucket_and_key.partition("/") |
| store = S3Store(bucket=bucket, prefix="") |
| data = store.read_bytes(key) |
| out_dir = Path(tempfile.gettempdir()) / "models" / tmp_subdir |
| out_dir.mkdir(parents=True, exist_ok=True) |
| fname = key.split("/")[-1] or "best.pt" |
| out_path = out_dir / fname |
| out_path.write_bytes(data) |
| return str(out_path) |
|
|
def resolve_detector_weights(models_root: Path,
                             model_name: Optional[str] = None) -> Tuple[str, str]:
    """
    Locate the detector weights and report where they came from.

    Returns (weights_path, source_str) where source_str is one of:
      {'env', 's3', 'manifest-hf', 'manifest-s3', 'manifest-local', 'local'}

    Precedence:
      1. DET_WEIGHTS env var: an s3:// URI is downloaded first ('s3'); any
         other non-empty value is returned as-is ('env').
      2. manifest.json -> detectors[model_name], trying hf, then s3, then
         local. A failed Hugging Face download is logged and the next
         source is tried.
      3. <models_root>/<model_name>/weights/best.pt, if it exists ('local').

    Raises FileNotFoundError when none of the above yields a path.
    """
    import logging

    env_w = os.getenv("DET_WEIGHTS", "").strip()
    if env_w:
        if env_w.startswith("s3://"):
            return _download_s3_to_tmp(env_w), "s3"
        return env_w, "env"

    if model_name:
        man = _read_manifest_dict(models_root)
        # Tolerate malformed manifests: anything that is not a mapping is
        # treated as "no entry" instead of raising AttributeError on .get().
        detectors = man.get("detectors") if isinstance(man, dict) else None
        det = detectors.get(model_name) if isinstance(detectors, dict) else None
        if isinstance(det, dict):
            hf = det.get("hf")
            if isinstance(hf, dict) and hf.get("repo_id") and hf.get("filename"):
                try:
                    from huggingface_hub import hf_hub_download
                    token = os.getenv("HF_TOKEN")
                    wpath = hf_hub_download(hf["repo_id"], hf["filename"], token=token)
                    return wpath, "manifest-hf"
                except Exception:
                    # Best-effort source: keep falling through to s3/local,
                    # but leave a trace of why the hub download was skipped.
                    logging.getLogger(__name__).warning(
                        "Hugging Face download failed for detector %r; "
                        "trying the next source.",
                        model_name,
                        exc_info=True,
                    )

            if det.get("s3"):
                bkt = os.getenv("AWS_S3_BUCKET", "")
                pfx = os.getenv("AWS_S3_PREFIX", "").strip("/")
                key = det["s3"].lstrip("/")
                # Join bucket/prefix/key, collapse the doubled slashes the
                # join produces, then restore the "s3://" scheme separator.
                s3_uri = f"s3://{bkt}/{('/'+pfx) if pfx else ''}/{key}".replace("//", "/").replace("s3:/", "s3://")
                return _download_s3_to_tmp(s3_uri, tmp_subdir=model_name), "manifest-s3"

            if det.get("local"):
                return str((models_root / det["local"]).resolve()), "manifest-local"

        # Last resort: the conventional location inside the model folder.
        guess = models_root / model_name / "weights" / "best.pt"
        if guess.exists():
            return str(guess), "local"

    raise FileNotFoundError("Detector weights not found. Set DET_WEIGHTS or add to manifest.json.")
|
|
@st.cache_resource(show_spinner=False)
def get_detector_cached(models_root: str, model_name: str):
    """
    Build a YOLODetector for ``model_name`` behind ``st.cache_resource``.

    Returns (detector, meta) where meta has the keys
    'name', 'path', 'source' and 'device'; 'device' is "auto" when the
    detector object has no ``device`` attribute.
    """
    weights_path, origin = resolve_detector_weights(Path(models_root), model_name)
    detector = YOLODetector(weights_path)
    return detector, {
        "name": model_name,
        "path": weights_path,
        "source": origin,
        "device": getattr(detector, "device", "auto"),
    }
|
|
|
|