| import json, re, torch, torch.nn as nn, torch.nn.functional as F |
| from torchvision import models, transforms |
| from PIL import Image |
| from typing import List, Tuple, Dict |
| import os, tempfile |
| from contextlib import nullcontext |
| from huggingface_hub import hf_hub_download |
| from pathlib import Path |
| from storage import get_store_from_env |
|
|
| |
| from torchcam.methods import GradCAM |
| from torchcam.utils import overlay_mask |
|
|
| |
# Bookkeeping for the most recent weight resolution. _resolve_ckpt() records
# where the weights came from ("local", "s3", "gdrive", "cloud", "hf_manifest"),
# a human-readable detail (path / key / repo:file), and the resolved local path.
LAST_WEIGHT_SOURCE = None
LAST_WEIGHT_DETAIL = ""
LAST_WEIGHT_PATH = ""


def get_last_weight_resolution():
    """Return (source, detail, path) describing the most recently resolved weights."""
    return (LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH)
|
|
| def _cloud_mode_for(model_name: str) -> str: |
| """ |
| CLOUD_WEIGHTS_MODE: off | prefer_hf | prefer_cloud | auto (default) |
| off -> never use cloud (S3/Drive) for weights |
| prefer_hf -> try HF first, then cloud |
| prefer_cloud-> try cloud first, then HF |
| auto -> HF for all models except those listed in CLOUD_WEIGHTS_ALLOW |
| """ |
| mode = (os.getenv("CLOUD_WEIGHTS_MODE", "auto") or "auto").lower() |
| allow = {x.strip() for x in (os.getenv("CLOUD_WEIGHTS_ALLOW","").split(",")) if x.strip()} |
| if mode == "auto": |
| return "prefer_cloud" if (model_name in allow) else "prefer_hf" |
| return mode |
|
|
| def _cloud_probe(model_name: str, ckpt_path: str): |
| """ |
| Returns tuple (cloud_exists, cloud_key, store or None, explanation) |
| Respects AWS_* / GDRIVE_* env to build the store via get_store_from_env(). |
| """ |
| try: |
| from storage import get_store_from_env |
| store = get_store_from_env() |
| fname = os.path.basename(ckpt_path) or "best.pth" |
| wk = (os.getenv("WEIGHTS_KEY") or "").strip() |
| default_key = f"models/{model_name}/weights/{fname}" |
| |
| if wk and (f"/{model_name}/" in wk or wk.startswith(f"models/{model_name}/")): |
| key = wk |
| else: |
| key = default_key |
| exists = store.exists(key) |
| return exists, key, store, "" |
| except Exception as e: |
| return False, "", None, f"cloud probe failed: {e}" |
|
|
| def _cloud_autodiscover(store, model_name: str): |
| |
| folder = f"models/{model_name}/weights/" |
| try: |
| for b in store.list(prefix=folder, recursive=True): |
| if b.key.lower().endswith(".pth"): |
| return True, b.key |
| except Exception: |
| pass |
| return False, "" |
|
|
# Per-channel (RGB) normalization statistics; applied in the inference
# transform built in load_predict_fn(). These are the standard ImageNet values.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
|
|
| def _device(): |
| if torch.backends.mps.is_available(): return torch.device("mps") |
| if torch.cuda.is_available(): return torch.device("cuda") |
| return torch.device("cpu") |
|
|
| def _log(msg: str): |
| print(f"[infer] {msg}", flush=True) |
|
|
| |
| def _is_tensor_dict(d): |
| try: |
| import torch as _t |
| return isinstance(d, dict) and any(isinstance(v, _t.Tensor) for v in d.values()) |
| except Exception: |
| return False |
|
|
def _find_tensor_dict_rec(obj, path="ckpt", depth=4):
    """DFS to locate a dict[str, Tensor] inside nested checkpoints.

    Returns (state_dict, human-readable location string) on success, or
    (None, None) when nothing suitable is found within *depth* nesting levels.
    Preference order: the object itself, then well-known wrapper keys
    ("state_dict", "model"), then a blind scan of all dict values, and
    finally the object's own .state_dict() method.
    """
    if depth < 0:
        return None, None
    if _is_tensor_dict(obj):
        return obj, path
    if isinstance(obj, dict):
        # Well-known wrapper keys are checked before a blind scan of values.
        if "state_dict" in obj and _is_tensor_dict(obj["state_dict"]):
            return obj["state_dict"], f"{path}['state_dict']"
        if "model" in obj:
            m = obj["model"]
            if _is_tensor_dict(m):
                return m, f"{path}['model']"
            if hasattr(m, "state_dict") and callable(getattr(m, "state_dict")):
                # The checkpoint stored a whole module object; pull its weights.
                try:
                    sd = m.state_dict()
                    if _is_tensor_dict(sd):
                        return sd, f"{path}['model'].state_dict()"
                except Exception:
                    pass
        # Fall back to scanning every value, one nesting level deeper.
        for k, v in obj.items():
            sd, where = _find_tensor_dict_rec(v, f"{path}['{k}']", depth-1)
            if sd is not None:
                return sd, where
    # Non-dict objects may still expose their weights via .state_dict().
    if hasattr(obj, "state_dict") and callable(getattr(obj, "state_dict")):
        try:
            sd = obj.state_dict()
            if _is_tensor_dict(sd):
                return sd, f"{path}.state_dict()"
        except Exception:
            pass
    return None, None
|
|
def _extract_state_dict(ckpt):
    """Pull a dict[str, Tensor] out of an arbitrarily wrapped checkpoint.

    Delegates the search to _find_tensor_dict_rec(); raises RuntimeError when
    no tensor dict can be located anywhere in the checkpoint.
    """
    state_dict, location = _find_tensor_dict_rec(ckpt, "ckpt", depth=4)
    if state_dict is None:
        top_level = list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt)
        raise RuntimeError(
            f"Unsupported checkpoint format: could not locate a dict[str,Tensor]. "
            f"Top: {top_level}"
        )
    print(f"[weights] extracted state_dict from {location}", flush=True)
    return state_dict
|
|
| def _strip_prefix_in_state_dict(sd, prefixes=("module.", "model.")): |
| new = {} |
| for k, v in sd.items(): |
| nk = k |
| for p in prefixes: |
| if nk.startswith(p): |
| nk = nk[len(p):] |
| new[nk] = v |
| return new |
|
|
| |
| def _get_feature_dim(model, arch: str) -> int: |
| if arch.startswith("resnet"): return model.fc.in_features |
| if arch.startswith("convnext"): return model.classifier[2].in_features |
| if arch.startswith("efficientnet_v2"): return model.classifier[1].in_features |
| if arch.startswith("mobilenet_v3"): return model.classifier[0].in_features |
| raise ValueError(f"Unsupported arch: {arch}") |
|
|
| def _attach_head(model, arch: str, head: nn.Module) -> nn.Module: |
| if arch.startswith("resnet"): model.fc = head |
| elif arch.startswith("convnext"): |
| ln = model.classifier[0] |
| flat = model.classifier[1] |
| model.classifier = nn.Sequential(ln, flat, head) |
| return model |
| elif arch.startswith("efficientnet_v2"): model.classifier = nn.Sequential(nn.Dropout(p=0.0), head) |
| elif arch.startswith("mobilenet_v3"): model.classifier = head |
| else: raise ValueError(f"Unsupported arch: {arch}") |
| return model |
|
|
| def _make_head(in_dim: int, num_classes: int, hidden: List[int], dropout: float, norm: str): |
| layers=[]; last=in_dim |
| for h in hidden: |
| layers.append(nn.Linear(last, h)) |
| if norm=="batchnorm": layers.append(nn.BatchNorm1d(h)) |
| elif norm=="layernorm": layers.append(nn.LayerNorm(h)) |
| layers += [nn.ReLU(inplace=True), nn.Dropout(dropout)] |
| last=h |
| layers.append(nn.Linear(last, num_classes)) |
| return nn.Sequential(*layers) |
|
|
def build_model_from_config(cfg: dict, num_classes: int):
    """Instantiate the backbone named in cfg["model"]["arch"] with a fresh head.

    Returns (model, arch). Backbones are built with weights=None (random init):
    ImageNet-pretrained weights are deliberately disabled here because the
    task checkpoint is loaded on top afterwards.
    """
    arch = cfg["model"]["arch"]
    backbones = {
        "resnet50": models.resnet50,
        "convnext_tiny": models.convnext_tiny,
        "convnext_small": models.convnext_small,
        "efficientnet_v2_s": models.efficientnet_v2_s,
        "mobilenet_v3_large": models.mobilenet_v3_large,
    }
    if arch not in backbones:
        raise ValueError(f"Unknown arch: {arch}")
    model = backbones[arch](weights=None)

    head_cfg = cfg["model"]
    in_dim = _get_feature_dim(model, arch)
    head = _make_head(in_dim, num_classes,
                      head_cfg["head_hidden"],
                      head_cfg["head_dropout"],
                      head_cfg["head_norm"])
    return _attach_head(model, arch, head), arch
|
|
| def _load_class_names(classes_path: str): |
| with open(classes_path, "r", encoding="utf-8") as f: |
| raw = json.load(f) |
| |
| if isinstance(raw, list) and raw and isinstance(raw[0], dict) and "name" in raw[0]: |
| id2name = {int(x["id"]): x["name"] for x in raw} |
| names = [id2name[i] for i in sorted(id2name.keys())] |
| return names |
| |
| if isinstance(raw, dict): |
| try: |
| id2name = {int(k): v for k, v in raw.items()} |
| names = [id2name[i] for i in sorted(id2name.keys())] |
| return names |
| except Exception: |
| |
| items = sorted(raw.items(), key=lambda kv: int(kv[1])) |
| return [k for k, _ in items] |
| |
| if isinstance(raw, list): |
| return list(raw) |
| raise ValueError("Unsupported class_map.json format") |
|
|
def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
    """Resolve a usable local filesystem path to the model's weights.

    Probes local disk, cloud storage (direct key, then folder auto-discovery)
    and an HF-hub manifest, trying them in an order chosen by
    CLOUD_WEIGHTS_MODE (see _cloud_mode_for). A local file always wins when
    present. Updates the LAST_WEIGHT_* globals with the winning source and
    raises FileNotFoundError with per-probe details when nothing resolves.
    """
    from pathlib import Path
    import tempfile
    global LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH

    # Probe every candidate up front so the failure message can report them all.
    avail_local = os.path.exists(ckpt_path)
    cloud_exists, cloud_key, cloud_store, cloud_err = _cloud_probe(model_name, ckpt_path)

    mode = _cloud_mode_for(model_name)

    def _use_cloud_with(key: str) -> Tuple[str, str]:
        # Download *key* to the temp dir; returns (local_path, source_name).
        nonlocal cloud_store
        tmp = Path(tempfile.gettempdir()) / (os.path.basename(key) or "best.pth")
        cloud_store.download_to(key, tmp)
        # Source label is inferred from which backend's env vars are set.
        src = "s3" if os.getenv("AWS_S3_BUCKET") else ("gdrive" if os.getenv("GDRIVE_FOLDER_ID") else "cloud")
        _log(f"Downloaded weights from {src}: {key} -> {tmp}")
        return str(tmp), src

    # Resolution order per mode; "local" is always first.
    if mode == "off":
        order = ["local", "hf_manifest"]
    elif mode == "prefer_cloud":
        order = ["local", "cloud", "cloud_auto", "hf_manifest"]
    elif mode == "prefer_hf":
        order = ["local", "hf_manifest", "cloud", "cloud_auto"]
    else:
        # Unrecognized mode string: fall back to the auto-style decision.
        order = ["local", "cloud", "cloud_auto", "hf_manifest"] if _cloud_mode_for(model_name) == "prefer_cloud" \
            else ["local", "hf_manifest", "cloud", "cloud_auto"]

    for src in order:
        if src == "local" and avail_local:
            LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "local", ckpt_path, ckpt_path
            _log(f"Using local weights: {ckpt_path}")
            return ckpt_path

        if src == "cloud" and cloud_exists and cloud_store:
            path, sname = _use_cloud_with(cloud_key)
            LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = sname, cloud_key, path
            return path

        if src == "cloud_auto" and cloud_store:
            # Direct key missed: scan the model's weights folder for any .pth.
            ok, auto_key = _cloud_autodiscover(cloud_store, model_name)
            if ok:
                path, sname = _use_cloud_with(auto_key)
                LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = sname, auto_key, path
                return path

        if src == "hf_manifest":
            # manifest.json maps model_name -> {"repo_id": ..., "filename": ...}.
            mpath = Path(models_root) / "manifest.json"
            if mpath.exists():
                try:
                    with open(mpath, "r", encoding="utf-8") as f:
                        manifest = json.load(f)
                    if model_name in manifest:
                        rec = manifest[model_name]
                        repo = rec["repo_id"]
                        fname = rec.get("filename", os.path.basename(ckpt_path) or "best.pth")
                        _log(f"Downloading from manifest → repo={repo} file={fname}")
                        path = hf_hub_download(repo_id=repo, filename=fname, token=os.getenv("HF_TOKEN"))
                        LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "hf_manifest", f"{repo}:{fname}", path
                        _log(f"Downloaded to cache: {path}")
                        return path
                except Exception as e:
                    # Best-effort: a bad manifest should not abort the search.
                    _log(f"Manifest read error: {e}")

    # Nothing resolved: summarize what each probe saw to ease debugging.
    parts = [f"local={'yes' if avail_local else 'no'}",
             f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
             f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}"]
    raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
|
|
| |
def load_predict_fn(config_path: str, ckpt_path: str, classes_path: str,
                    model_name: str = None, models_root: str = None):
    """Build a ready-to-use inference pipeline from config, checkpoint and class map.

    Returns (predict_pils, class_names, model, arch), where predict_pils maps a
    list of PIL images to per-image lists of (class_name, probability) pairs.
    When model_name / models_root are None they are derived from config_path's
    directory layout (parent dir name, and a sibling "models" dir respectively).
    """
    _log(f"Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Infer identity from layout: .../<model_name>/config.json, models root one level up.
    if model_name is None:
        from pathlib import Path
        model_name = Path(config_path).parent.name
    if models_root is None:
        from pathlib import Path
        models_root = str(Path(config_path).resolve().parents[1] / "models")

    _log(f"Loading classes: {classes_path}")
    class_names = _load_class_names(classes_path)
    _log(f"Classes loaded: {len(class_names)}")

    _log("Building model…")
    model, arch = build_model_from_config(cfg, len(class_names))
    dev = _device()
    _log("Resolving checkpoint…")
    ckpt_path = _resolve_ckpt(ckpt_path, model_name, models_root)
    _log(f"Loading weights from: {ckpt_path}")
    # Load on CPU first; the model is moved to the chosen device afterwards.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    state_dict = _extract_state_dict(ckpt)
    state_dict = _strip_prefix_in_state_dict(state_dict)

    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        _log("[weights] strict load failed (likely head/classes/arch mismatch).")
        _log(str(e))
        # Opt-in escape hatch: accept missing/unexpected keys instead of failing.
        if os.getenv("ALLOW_PARTIAL_WEIGHTS", "0") == "1":
            _log("[weights] retrying with strict=False due to ALLOW_PARTIAL_WEIGHTS=1")
            model.load_state_dict(state_dict, strict=False)
        else:
            raise
    _log("Weights loaded OK")
    model.to(dev).eval()
    _log(f"Model on device: {dev}")

    # Inference preprocessing: square resize + ImageNet normalization.
    size = int(cfg["preprocess"].get("image_size", 224))
    tfm = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    num_classes = len(class_names)

    @torch.no_grad()
    def predict_pils(pils, topk: int = 3):
        """Classify a batch of PIL images; returns top-k (name, prob) per image."""
        if not pils:
            return []
        xb = torch.stack([tfm(im.convert("RGB")) for im in pils], 0).to(dev)
        # Mixed precision on accelerators; plain fp32 path on CPU.
        amp_ctx = torch.amp.autocast(device_type=dev.type) if dev.type in ("cuda","mps") else nullcontext()
        with amp_ctx:
            logits = model(xb)
        probs = F.softmax(logits, dim=1)
        k = min(topk, num_classes)  # never ask for more classes than exist
        top_ids = probs.topk(k=k, dim=1).indices.cpu().numpy()
        probs_np = probs.cpu().numpy()
        out = [
            [(class_names[int(j)], float(probs_np[i, int(j)])) for j in ids]
            for i, ids in enumerate(top_ids)
        ]
        return out

    _log("Predict function ready")
    return predict_pils, class_names, model, arch
|
|
| |
| def _get_last_conv_layer_name(arch: str) -> str: |
| """Determine the target layer name for GradCAM based on architecture.""" |
| if arch.startswith("resnet"): |
| return "layer4" |
| elif arch.startswith("convnext"): |
| if arch == "convnext_tiny": |
| return "features.6" |
| elif arch == "convnext_small": |
| return "features.11" |
| else: |
|
|
| raise ValueError(f"Last conv layer name for ConvNeXt variant '{arch}' needs to be defined.") |
| elif arch.startswith("efficientnet_v2"): |
|
|
| if arch == "efficientnet_v2_s": |
| return "features.7" |
| else: |
| |
| raise ValueError(f"Last conv layer name for EfficientNet-V2 variant '{arch}' needs to be defined.") |
| elif arch.startswith("mobilenet_v3"): |
|
|
| if arch == "mobilenet_v3_large": |
| return "features.16" |
| else: |
| |
| raise ValueError(f"Last conv layer name for MobileNet-V3 variant '{arch}' needs to be defined.") |
| else: |
| raise ValueError(f"Unsupported arch for GradCAM target layer: {arch}") |