# Author: Tan Zi Xu
# gdrive integration
# commit 695209d
import json, re, torch, torch.nn as nn, torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
from typing import List, Tuple, Dict
import os, tempfile
from contextlib import nullcontext
from huggingface_hub import hf_hub_download
from pathlib import Path
from storage import get_store_from_env
# Grad-CAM imports
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
# Track how the most recent weight file was resolved, for the Diagnostics tab.
# These module-level globals are written by _resolve_ckpt and read back via
# get_last_weight_resolution().
LAST_WEIGHT_SOURCE = None  # one of: local | s3 | gdrive | hf_manifest
LAST_WEIGHT_DETAIL = ""  # key/repo/filename etc
LAST_WEIGHT_PATH = ""  # final local path used
def get_last_weight_resolution():
    """Return the (source, detail, path) trio recorded by the last weight resolution."""
    return (LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH)
def _cloud_mode_for(model_name: str) -> str:
"""
CLOUD_WEIGHTS_MODE: off | prefer_hf | prefer_cloud | auto (default)
off -> never use cloud (S3/Drive) for weights
prefer_hf -> try HF first, then cloud
prefer_cloud-> try cloud first, then HF
auto -> HF for all models except those listed in CLOUD_WEIGHTS_ALLOW
"""
mode = (os.getenv("CLOUD_WEIGHTS_MODE", "auto") or "auto").lower()
allow = {x.strip() for x in (os.getenv("CLOUD_WEIGHTS_ALLOW","").split(",")) if x.strip()}
if mode == "auto":
return "prefer_cloud" if (model_name in allow) else "prefer_hf"
return mode
def _cloud_probe(model_name: str, ckpt_path: str):
"""
Returns tuple (cloud_exists, cloud_key, store or None, explanation)
Respects AWS_* / GDRIVE_* env to build the store via get_store_from_env().
"""
try:
from storage import get_store_from_env
store = get_store_from_env()
fname = os.path.basename(ckpt_path) or "best.pth"
wk = (os.getenv("WEIGHTS_KEY") or "").strip()
default_key = f"models/{model_name}/weights/{fname}"
# only honor WEIGHTS_KEY if it clearly targets this model
if wk and (f"/{model_name}/" in wk or wk.startswith(f"models/{model_name}/")):
key = wk
else:
key = default_key
exists = store.exists(key)
return exists, key, store, ""
except Exception as e:
return False, "", None, f"cloud probe failed: {e}"
def _cloud_autodiscover(store, model_name: str):
# find ANY .pth under models/<model>/weights/
folder = f"models/{model_name}/weights/"
try:
for b in store.list(prefix=folder, recursive=True):
if b.key.lower().endswith(".pth"):
return True, b.key
except Exception:
pass
return False, ""
# Standard ImageNet channel mean/std, used by the Normalize transform in
# load_predict_fn below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
def _device():
if torch.backends.mps.is_available(): return torch.device("mps")
if torch.cuda.is_available(): return torch.device("cuda")
return torch.device("cpu")
def _log(msg: str):
print(f"[infer] {msg}", flush=True)
# ---- robust checkpoint extraction helpers ----
def _is_tensor_dict(d):
try:
import torch as _t
return isinstance(d, dict) and any(isinstance(v, _t.Tensor) for v in d.values())
except Exception:
return False
def _find_tensor_dict_rec(obj, path="ckpt", depth=4):
    """DFS to locate a dict[str, Tensor] inside nested checkpoints.

    Returns (state_dict, breadcrumb) where the breadcrumb is a human-readable
    description of where the dict was found (used only for logging), or
    (None, None) when nothing is found within *depth* recursion levels.
    """
    if depth < 0:
        return None, None
    if _is_tensor_dict(obj):
        return obj, path
    if isinstance(obj, dict):
        # common wrappers — checked before the generic DFS so the canonical
        # trainer keys win over arbitrary nested matches
        if "state_dict" in obj and _is_tensor_dict(obj["state_dict"]):
            return obj["state_dict"], f"{path}['state_dict']"
        if "model" in obj:
            m = obj["model"]
            if _is_tensor_dict(m):
                return m, f"{path}['model']"
            # "model" may be a pickled module object; pull its state_dict()
            if hasattr(m, "state_dict") and callable(getattr(m, "state_dict")):
                try:
                    sd = m.state_dict()
                    if _is_tensor_dict(sd):
                        return sd, f"{path}['model'].state_dict()"
                except Exception:
                    pass
        # generic DFS over every value, one level deeper
        for k, v in obj.items():
            sd, where = _find_tensor_dict_rec(v, f"{path}['{k}']", depth-1)
            if sd is not None:
                return sd, where
    # last resort: object with state_dict()
    if hasattr(obj, "state_dict") and callable(getattr(obj, "state_dict")):
        try:
            sd = obj.state_dict()
            if _is_tensor_dict(sd):
                return sd, f"{path}.state_dict()"
        except Exception:
            pass
    return None, None
def _extract_state_dict(ckpt):
    """Unwrap a (possibly nested) checkpoint down to a plain dict[str, Tensor].

    Raises RuntimeError with a summary of the top-level structure when no
    tensor dict can be located.
    """
    sd, where = _find_tensor_dict_rec(ckpt, "ckpt", depth=4)
    if sd is not None:
        print(f"[weights] extracted state_dict from {where}", flush=True)
        return sd
    top = list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt)
    raise RuntimeError(
        f"Unsupported checkpoint format: could not locate a dict[str,Tensor]. "
        f"Top: {top}"
    )
def _strip_prefix_in_state_dict(sd, prefixes=("module.", "model.")):
new = {}
for k, v in sd.items():
nk = k
for p in prefixes:
if nk.startswith(p):
nk = nk[len(p):]
new[nk] = v
return new
# --- heads & model builders (mirrors train.py) ---
def _get_feature_dim(model, arch: str) -> int:
if arch.startswith("resnet"): return model.fc.in_features
if arch.startswith("convnext"): return model.classifier[2].in_features
if arch.startswith("efficientnet_v2"): return model.classifier[1].in_features
if arch.startswith("mobilenet_v3"): return model.classifier[0].in_features
raise ValueError(f"Unsupported arch: {arch}")
def _attach_head(model, arch: str, head: nn.Module) -> nn.Module:
if arch.startswith("resnet"): model.fc = head
elif arch.startswith("convnext"):
ln = model.classifier[0] # LayerNorm
flat = model.classifier[1] # Flatten(1)
model.classifier = nn.Sequential(ln, flat, head)
return model
elif arch.startswith("efficientnet_v2"): model.classifier = nn.Sequential(nn.Dropout(p=0.0), head)
elif arch.startswith("mobilenet_v3"): model.classifier = head
else: raise ValueError(f"Unsupported arch: {arch}")
return model
def _make_head(in_dim: int, num_classes: int, hidden: List[int], dropout: float, norm: str):
layers=[]; last=in_dim
for h in hidden:
layers.append(nn.Linear(last, h))
if norm=="batchnorm": layers.append(nn.BatchNorm1d(h))
elif norm=="layernorm": layers.append(nn.LayerNorm(h))
layers += [nn.ReLU(inplace=True), nn.Dropout(dropout)]
last=h
layers.append(nn.Linear(last, num_classes))
return nn.Sequential(*layers)
def build_model_from_config(cfg: dict, num_classes: int):
    """Instantiate the backbone named by cfg["model"]["arch"] with the MLP head
    described by cfg["model"] (head_hidden / head_dropout / head_norm).

    Returns (model, arch). Pretrained ImageNet weights are never downloaded
    here; real weights are loaded later from a checkpoint.
    """
    arch = cfg["model"]["arch"]
    pretr = False  # FORCE: do not download ImageNet weights on Spaces
    # ctor paired with a lazy weights-enum accessor (only touched when pretr)
    registry = {
        "resnet50": (models.resnet50, lambda: models.ResNet50_Weights.IMAGENET1K_V2),
        "convnext_tiny": (models.convnext_tiny, lambda: models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1),
        "convnext_small": (models.convnext_small, lambda: models.ConvNeXt_Small_Weights.IMAGENET1K_V1),
        "efficientnet_v2_s": (models.efficientnet_v2_s, lambda: models.EfficientNet_V2_S_Weights.IMAGENET1K_V1),
        "mobilenet_v3_large": (models.mobilenet_v3_large, lambda: models.MobileNet_V3_Large_Weights.IMAGENET1K_V2),
    }
    if arch not in registry:
        raise ValueError(f"Unknown arch: {arch}")
    ctor, weights_of = registry[arch]
    model = ctor(weights=weights_of() if pretr else None)
    head_cfg = cfg["model"]
    head = _make_head(_get_feature_dim(model, arch), num_classes,
                      head_cfg["head_hidden"],
                      head_cfg["head_dropout"],
                      head_cfg["head_norm"])
    return _attach_head(model, arch, head), arch
def _load_class_names(classes_path: str):
with open(classes_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# list of dicts
if isinstance(raw, list) and raw and isinstance(raw[0], dict) and "name" in raw[0]:
id2name = {int(x["id"]): x["name"] for x in raw}
names = [id2name[i] for i in sorted(id2name.keys())]
return names
# dict id->name (keys may be strings)
if isinstance(raw, dict):
try:
id2name = {int(k): v for k, v in raw.items()}
names = [id2name[i] for i in sorted(id2name.keys())]
return names
except Exception:
# dict name->id
items = sorted(raw.items(), key=lambda kv: int(kv[1]))
return [k for k, _ in items]
# list of names
if isinstance(raw, list):
return list(raw)
raise ValueError("Unsupported class_map.json format")
def _resolve_ckpt(ckpt_path: str, model_name: str, models_root: str) -> str:
    """Resolve model weights to a local file path.

    Sources are tried in an order chosen by CLOUD_WEIGHTS_MODE (see
    _cloud_mode_for): the local checkpoint path, a configured cloud object
    (S3/Drive), cloud auto-discovery of any .pth under models/<model>/weights/,
    and an HF Hub manifest at <models_root>/manifest.json.

    Side effect: records the winning source in the module-level LAST_WEIGHT_*
    globals for the Diagnostics tab.

    Raises FileNotFoundError with a summary of every probe when no source
    yields a usable file.
    """
    global LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH
    # Availability probes (for Diagnostics tab convenience)
    avail_local = os.path.exists(ckpt_path)
    cloud_exists, cloud_key, cloud_store, cloud_err = _cloud_probe(model_name, ckpt_path)
    mode = _cloud_mode_for(model_name)

    def _use_cloud_with(key: str) -> Tuple[str, str]:
        """Download *key* via the probed store into the temp dir; returns (local_path, source_name)."""
        nonlocal cloud_store
        tmp = Path(tempfile.gettempdir()) / (os.path.basename(key) or "best.pth")
        cloud_store.download_to(key, tmp)
        src = "s3" if os.getenv("AWS_S3_BUCKET") else ("gdrive" if os.getenv("GDRIVE_FOLDER_ID") else "cloud")
        _log(f"Downloaded weights from {src}: {key} -> {tmp}")
        return str(tmp), src

    # Decide the probe order. Note: _cloud_mode_for never returns "auto" (it
    # resolves "auto" against the allow-list into prefer_cloud/prefer_hf), so
    # the final branch covers "prefer_hf" and any unrecognized mode value.
    if mode == "off":
        order = ["local", "hf_manifest"]
    elif mode == "prefer_cloud":
        order = ["local", "cloud", "cloud_auto", "hf_manifest"]
    else:  # "prefer_hf" or unrecognized -> HF first, then cloud
        order = ["local", "hf_manifest", "cloud", "cloud_auto"]

    # Try sources in the decided order; first hit wins.
    for src in order:
        if src == "local" and avail_local:
            LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "local", ckpt_path, ckpt_path
            _log(f"Using local weights: {ckpt_path}")
            return ckpt_path
        if src == "cloud" and cloud_exists and cloud_store:
            path, sname = _use_cloud_with(cloud_key)
            LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = sname, cloud_key, path
            return path
        if src == "cloud_auto" and cloud_store:
            ok, auto_key = _cloud_autodiscover(cloud_store, model_name)
            if ok:
                path, sname = _use_cloud_with(auto_key)
                LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = sname, auto_key, path
                return path
        if src == "hf_manifest":
            mpath = Path(models_root) / "manifest.json"
            if mpath.exists():
                try:
                    with open(mpath, "r", encoding="utf-8") as f:
                        manifest = json.load(f)
                    if model_name in manifest:
                        rec = manifest[model_name]
                        repo = rec["repo_id"]
                        fname = rec.get("filename", os.path.basename(ckpt_path) or "best.pth")
                        _log(f"Downloading from manifest → repo={repo} file={fname}")
                        path = hf_hub_download(repo_id=repo, filename=fname, token=os.getenv("HF_TOKEN"))
                        LAST_WEIGHT_SOURCE, LAST_WEIGHT_DETAIL, LAST_WEIGHT_PATH = "hf_manifest", f"{repo}:{fname}", path
                        _log(f"Downloaded to cache: {path}")
                        return path
                except Exception as e:
                    # best-effort: a bad manifest should not abort the other sources
                    _log(f"Manifest read error: {e}")
    # If we got here, nothing worked
    parts = [f"local={'yes' if avail_local else 'no'}",
             f"cloud={'yes' if cloud_exists else 'no'} key={cloud_key or '-'} err={cloud_err or '-'}",
             f"hf_manifest={'present' if (Path(models_root)/'manifest.json').exists() else 'absent'}"]
    raise FileNotFoundError("Could not resolve weights. Probes: " + ", ".join(parts))
#--- main loading function---
def load_predict_fn(config_path: str, ckpt_path: str, classes_path: str,
                    model_name: str = None, models_root: str = None):
    """Build the model described by *config_path*, load its weights, and return
    a batch-prediction closure.

    Returns (predict_pils, class_names, model, arch) — the raw model and arch
    are exposed so callers can hook the backbone (e.g. for Grad-CAM).
    predict_pils(pils, topk=3) yields, per input image, a list of
    (class_name, probability) pairs sorted by descending probability.
    """
    _log(f"Loading config: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    # Derive model_name / models_root from the config location when not given;
    # assumes configs live at <root>/<model_name>/<config file> — TODO confirm
    if model_name is None:
        from pathlib import Path
        model_name = Path(config_path).parent.name
    if models_root is None:
        from pathlib import Path
        models_root = str(Path(config_path).resolve().parents[1] / "models")
    _log(f"Loading classes: {classes_path}")
    class_names = _load_class_names(classes_path)
    _log(f"Classes loaded: {len(class_names)}")
    _log("Building model…")
    model, arch = build_model_from_config(cfg, len(class_names))
    dev = _device()
    _log("Resolving checkpoint…")
    ckpt_path = _resolve_ckpt(ckpt_path, model_name, models_root)
    _log(f"Loading weights from: {ckpt_path}")
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — fine for trusted checkpoints; verify the weight sources are trusted.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # robust unwrap of trainer-style checkpoints (nested dicts, module wrappers)
    state_dict = _extract_state_dict(ckpt)
    state_dict = _strip_prefix_in_state_dict(state_dict)
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        _log("[weights] strict load failed (likely head/classes/arch mismatch).")
        _log(str(e))
        # Optional debug escape hatch: tolerate missing/unexpected keys.
        if os.getenv("ALLOW_PARTIAL_WEIGHTS", "0") == "1":
            _log("[weights] retrying with strict=False due to ALLOW_PARTIAL_WEIGHTS=1")
            model.load_state_dict(state_dict, strict=False)
        else:
            raise
    _log("Weights loaded OK")
    model.to(dev).eval()
    _log(f"Model on device: {dev}")
    size = int(cfg["preprocess"].get("image_size", 224))
    tfm = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])
    # captured by the closure so topk can be clamped without re-reading class_names
    num_classes = len(class_names)
    @torch.no_grad()
    def predict_pils(pils, topk: int = 3):
        """Classify a list of PIL images; returns top-k (name, prob) pairs per image."""
        if not pils:
            return []
        xb = torch.stack([tfm(im.convert("RGB")) for im in pils], 0).to(dev)
        # Use autocast only on CUDA/MPS
        amp_ctx = torch.amp.autocast(device_type=dev.type) if dev.type in ("cuda","mps") else nullcontext()
        with amp_ctx:
            logits = model(xb)
            probs = F.softmax(logits, dim=1)
        k = min(topk, num_classes)  # never request more classes than exist
        top_ids = probs.topk(k=k, dim=1).indices.cpu().numpy()
        probs_np = probs.cpu().numpy()
        out = [
            [(class_names[int(j)], float(probs_np[i, int(j)])) for j in ids]
            for i, ids in enumerate(top_ids)
        ]
        return out
    _log("Predict function ready")
    # Return both the predict function and the raw model object
    return predict_pils, class_names, model, arch
# --- GradCAM helper to get last conv layer ---
def _get_last_conv_layer_name(arch: str) -> str:
"""Determine the target layer name for GradCAM based on architecture."""
if arch.startswith("resnet"):
return "layer4"
elif arch.startswith("convnext"):
if arch == "convnext_tiny":
return "features.6"
elif arch == "convnext_small":
return "features.11"
else:
raise ValueError(f"Last conv layer name for ConvNeXt variant '{arch}' needs to be defined.")
elif arch.startswith("efficientnet_v2"):
if arch == "efficientnet_v2_s":
return "features.7"
else:
# Needs specific handling for other variants
raise ValueError(f"Last conv layer name for EfficientNet-V2 variant '{arch}' needs to be defined.")
elif arch.startswith("mobilenet_v3"):
if arch == "mobilenet_v3_large":
return "features.16"
else:
# Needs specific handling for other variants
raise ValueError(f"Last conv layer name for MobileNet-V3 variant '{arch}' needs to be defined.")
else:
raise ValueError(f"Unsupported arch for GradCAM target layer: {arch}")