# core/model_loader.py
"""Model/detector loading helpers for the Streamlit app.

Resolves classifier model folders and YOLO detector weights from (in
precedence order) environment overrides, a ``models_root/manifest.json``,
or local path conventions, and caches loaded models across Streamlit reruns.
"""
import json
import os
import pathlib  # NOTE(review): unused here (Path is imported directly) — kept in case other code relies on it; verify before removing
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple

import streamlit as st
import torch

from cloud.storage_s3 import S3Store
from core.detect_infer import YOLODetector
from infer import load_predict_fn

# Redirect torch.hub's cache into the system temp dir — presumably for
# deployments where $HOME is not writable (TODO confirm).
TMP = tempfile.gettempdir()
torch.hub.set_dir(f"{TMP}/.cache/torch/hub")


def setup_page(name: str, tagline: str):
    """Apply common Streamlit page chrome: wide layout, title, caption."""
    st.set_page_config(page_title=name, layout="wide")
    st.markdown("", unsafe_allow_html=True)
    st.title(name)
    st.caption(tagline)


def list_model_dirs(root: Path) -> List[str]:
    """Return sorted names of subdirectories of *root* that look like model
    folders, i.e. contain both model_config.json and class_map.json.

    Returns an empty list when *root* does not exist.
    """
    if not root.exists():
        return []
    return [
        p.name
        for p in sorted(root.iterdir())
        if p.is_dir()
        and (p / "model_config.json").exists()
        and (p / "class_map.json").exists()
    ]


@st.cache_resource(show_spinner=False)
def _load(cfg, ckpt, clsf, model_name, models_root):
    """Cached wrapper around ``load_predict_fn`` (cache key = the arguments)."""
    return load_predict_fn(
        str(cfg), str(ckpt), str(clsf),
        model_name=model_name, models_root=str(models_root),
    )


def load_model_cached(model_name: str, models_root: Path):
    """Load a classifier model folder by conventional layout.

    Expects ``models_root/model_name/{model_config.json, class_map.json,
    weights/best.pth}``.

    Returns a 9-tuple: (predict_fn, class_names, model_obj, model_arch,
    image_size, model_dir, cfg_path, ckpt_path, class_map_path).
    """
    model_dir = models_root / model_name
    cfg = model_dir / "model_config.json"
    ckpt = model_dir / "weights" / "best.pth"
    clsf = model_dir / "class_map.json"
    st.write(f"• Reading model folder: {model_dir}")
    predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH = _load(
        cfg, ckpt, clsf, model_name, models_root
    )
    with open(cfg, "r", encoding="utf-8") as f:
        model_config = json.load(f)
    # Default to 224 px when the config omits an explicit image size.
    model_image_size = int(model_config["preprocess"].get("image_size", 224))
    return (predict_fn, CLASS_NAMES, MODEL_OBJ, MODEL_ARCH, model_image_size,
            model_dir, str(cfg), str(ckpt), str(clsf))


# ========= Detector resolution & cache (append to end of core/model_loader.py) =========

def _read_manifest_dict(models_root: Path) -> dict:
    """Best-effort read of ``models_root/manifest.json``; {} on any failure."""
    m = models_root / "manifest.json"
    if not m.exists():
        return {}
    try:
        return json.loads(m.read_text(encoding="utf-8"))
    except Exception:
        # Deliberately best-effort: a corrupt manifest must not crash the app.
        return {}


def _download_s3_to_tmp(s3_uri: str, tmp_subdir: str = "yolo") -> str:
    """Download *s3_uri* into ``<tmpdir>/models/<tmp_subdir>/`` and return the
    local file path.

    Raises:
        ValueError: if *s3_uri* is not an ``s3://`` URI.  (Was an ``assert``,
        which is silently stripped under ``python -O``.)
    """
    if not s3_uri.startswith("s3://"):
        raise ValueError("s3_uri must start with s3://")
    bucket_and_key = s3_uri[len("s3://"):]
    bucket, _, key = bucket_and_key.partition("/")
    store = S3Store(bucket=bucket, prefix="")
    data = store.read_bytes(key)
    out_dir = Path(tempfile.gettempdir()) / "models" / tmp_subdir
    out_dir.mkdir(parents=True, exist_ok=True)
    # Fall back to a stable name if the key ends with "/" (empty basename).
    fname = key.split("/")[-1] or "best.pt"
    out_path = out_dir / fname
    out_path.write_bytes(data)
    return str(out_path)


def resolve_detector_weights(models_root: Path, model_name: Optional[str] = None) -> Tuple[str, str]:
    """
    Returns (weights_path, source_str) where source_str in:
      {'env','manifest-hf','manifest-s3','manifest-local','local','s3'}
    Precedence: DET_WEIGHTS env -> manifest.detectors[model_name](hf->s3->local) -> local guess.

    Raises:
        FileNotFoundError: when no source yields a weights path.
    """
    # 1) explicit env override (s3 or local)
    env_w = os.getenv("DET_WEIGHTS", "").strip()
    if env_w:
        if env_w.startswith("s3://"):
            return _download_s3_to_tmp(env_w), "s3"
        return env_w, "env"

    # 2) manifest
    if model_name:
        man = _read_manifest_dict(models_root)
        det = (man.get("detectors") or {}).get(model_name)
        if det:
            # HF (optional): only attempted when both repo_id and filename are set.
            hf = det.get("hf")
            if isinstance(hf, dict) and hf.get("repo_id") and hf.get("filename"):
                try:
                    from huggingface_hub import hf_hub_download
                    token = os.getenv("HF_TOKEN")
                    wpath = hf_hub_download(hf["repo_id"], hf["filename"], token=token)
                    return wpath, "manifest-hf"
                except Exception:
                    # Best-effort: missing package / network failure falls
                    # through to the S3 / local entries below.
                    pass
            # S3
            if det.get("s3"):
                bkt = os.getenv("AWS_S3_BUCKET", "")
                pfx = os.getenv("AWS_S3_PREFIX", "").strip("/")
                key = det["s3"].lstrip("/")
                # Join only the non-empty parts.  This replaces the old
                # replace("//","/").replace("s3:/","s3://") chain, which
                # patched up double slashes after the fact and would also
                # mangle any legitimate "//" inside the key.
                s3_uri = "s3://" + "/".join(part for part in (bkt, pfx, key) if part)
                return _download_s3_to_tmp(s3_uri, tmp_subdir=model_name), "manifest-s3"
            # Local
            if det.get("local"):
                return str((models_root / det["local"]).resolve()), "manifest-local"

    # 3) fallback local convention
    if model_name:
        guess = models_root / model_name / "weights" / "best.pt"
        if guess.exists():
            return str(guess), "local"

    raise FileNotFoundError("Detector weights not found. Set DET_WEIGHTS or add to manifest.json.")


@st.cache_resource(show_spinner=False)
def get_detector_cached(models_root: str, model_name: str):
    """
    Returns (detector, meta_dict) and caches across reruns.
    meta: {'name','path','source','device'}
    """
    root = Path(models_root)
    wpath, source = resolve_detector_weights(root, model_name)
    det = YOLODetector(wpath)  # auto-selects cuda/cpu
    # getattr with default replaces the hasattr/ternary dance.
    device = getattr(det, "device", "auto")
    meta = {"name": model_name, "path": wpath, "source": source, "device": device}
    return det, meta