# api/startup.py
# All model and index loading happens here — once at FastAPI startup
# Everything stays in memory for the entire server lifetime
# Never load models per-request
import os
import json
import time
import torch
import clip
from src.patchcore import patchcore
from src.retriever import retriever
from src.graph import knowledge_graph
from src.depth import depth_estimator
from src.xai import gradcam, shap_explainer
from src.cache import inference_cache
from src.orchestrator import init_orchestrator
from api.logger import init_logger
# Startup timestamp — used for uptime calculation in /health.
# Assigned time.time() by load_all(); while it is still None,
# get_uptime() reports 0.0.
STARTUP_TIME = None
# Version label printed by load_all() in the "Startup complete" banner.
MODEL_VERSION = "v1.0"
def download_artifacts():
    """Download all required artifacts from the HF Dataset repo at startup.

    Each file is fetched with ``hf_hub_download`` into ``/tmp/artifacts`` and
    then copied into the data directory.  The data directory is taken from
    the ``DATA_DIR`` environment variable (default ``"data"``), which is the
    same lookup ``load_all`` uses to locate ``thresholds.json``, so downloaded
    files land where they are later read.

    A file that already exists in the data directory is skipped.  Downloads
    are best-effort: a failure is reported on stdout and the remaining files
    are still attempted; no exception from a download reaches the caller.
    """
    from huggingface_hub import hf_hub_download
    import shutil

    HF_REPO = "CaffeinatedCoding/anomalyos-logs"
    token = os.environ.get("HF_TOKEN")

    # Resolve the target directory the same way load_all() does, instead of
    # hard-coding "data", so a custom DATA_DIR is honoured end to end.
    data_dir = os.environ.get("DATA_DIR", "data")
    os.makedirs(data_dir, exist_ok=True)

    repo_paths = [
        "models/pca_256.pkl",
        "configs/thresholds.json",
        "graph/knowledge_graph.json",
        "indexes/index1_category.faiss",
        "indexes/index1_metadata.json",
        "indexes/index2_defect.faiss",
        "indexes/index2_metadata.json",
    ]

    # Index 3 - one FAISS file per category
    categories = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
        'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
        'transistor', 'wood', 'zipper',
    ]
    repo_paths.extend(f"indexes/index3_{cat}.faiss" for cat in categories)

    for repo_path in repo_paths:
        # Every artifact is stored flat in the data directory under its
        # repo basename (e.g. "indexes/index2_defect.faiss" ->
        # "<data_dir>/index2_defect.faiss").
        local_path = os.path.join(data_dir, os.path.basename(repo_path))

        if os.path.exists(local_path):
            print(f"Already exists: {local_path}")
            continue

        try:
            print(f"Downloading {repo_path}...")
            downloaded = hf_hub_download(
                repo_id=HF_REPO,
                filename=repo_path,
                repo_type="dataset",
                token=token,
                local_dir="/tmp/artifacts",
            )
            shutil.copy(downloaded, local_path)
            print(f" β {local_path}")
        except Exception as e:
            # Deliberately broad: one missing or unreachable artifact must
            # not abort startup; the failure is surfaced as a warning.
            print(f" WARNING: Could not download {repo_path}: {e}")
def load_all():
    """
    Run the one-time startup sequence (called from the FastAPI lifespan).

    The steps execute in a fixed order: download artifacts, configure PyTorch
    threading, initialise the logger, load PatchCore, the FAISS indexes, the
    knowledge graph, the depth estimator, CLIP, the thresholds file,
    GradCAM++ and the SHAP background, then hand the CLIP model, its
    preprocess object and the thresholds to the orchestrator.

    Also records the start time in the module-level ``STARTUP_TIME``.

    Returns a dict with the keys ``"clip_model"``, ``"clip_preprocess"`` and
    ``"thresholds"``.
    """
    global STARTUP_TIME
    STARTUP_TIME = time.time()

    banner = "=" * 50
    print(banner)
    print("AnomalyOS startup sequence")
    print(banner)

    # Fetch artifacts before anything else runs
    download_artifacts()

    # -- CPU thread tuning --
    # HF Spaces CPU Basic = 2 vCPU; cap PyTorch threads to match so the
    # process does not over-subscribe the cores.
    torch.set_num_threads(2)
    torch.set_default_dtype(torch.float32)
    print(f"PyTorch threads: {torch.get_num_threads()}")

    # -- Logger --
    init_logger(os.environ.get("HF_TOKEN", ""))

    # -- PatchCore extractor --
    patchcore.load()

    # -- FAISS indexes --
    # Index 3 is lazy-loaded, so it is not loaded here
    retriever.load_indexes()

    # -- Knowledge graph --
    knowledge_graph.load()

    # -- MiDaS depth estimator --
    # A FileNotFoundError here is tolerated: startup carries on without it.
    try:
        depth_estimator.load()
    except FileNotFoundError as missing:
        print(f"WARNING: {missing}")
        print("Depth features will return zeros β inference continues")

    # -- CLIP model --
    # Loaded here and passed to the orchestrator further down
    print("Loading CLIP ViT-B/32...")
    model, preprocess = clip.load("ViT-B/32", device="cpu")
    model.eval()
    print("CLIP loaded")

    data_dir = os.environ.get("DATA_DIR", "data")

    # -- Thresholds --
    thresholds_file = os.path.join(data_dir, "thresholds.json")
    if not os.path.exists(thresholds_file):
        per_category = {}
        print("WARNING: thresholds.json not found β using score > 0.5 fallback")
    else:
        with open(thresholds_file) as fh:
            per_category = json.load(fh)
        print(f"Thresholds loaded: {len(per_category)} categories")

    # -- GradCAM++ --
    # Any failure is tolerated: startup carries on without it.
    try:
        gradcam.load()
    except Exception as err:
        print(f"WARNING: GradCAM++ load failed: {err}")
        print("Forensics mode will run without GradCAM++")

    # -- SHAP background --
    shap_explainer.load_background(
        os.path.join(data_dir, "shap_background.npy")
    )

    # -- Inject into orchestrator --
    init_orchestrator(model, preprocess, per_category)

    startup_seconds = time.time() - STARTUP_TIME
    print(banner)
    print(f"Startup complete in {startup_seconds:.1f}s")
    print(f"Model version: {MODEL_VERSION}")
    print(banner)

    return dict(
        clip_model=model,
        clip_preprocess=preprocess,
        thresholds=per_category,
    )
def get_uptime() -> float:
    """Return the seconds elapsed since ``load_all`` set ``STARTUP_TIME``.

    Returns 0.0 while ``STARTUP_TIME`` is still ``None``, i.e. before the
    startup sequence has recorded its start time.
    """
    if STARTUP_TIME is None:
        return 0.0
    return time.time() - STARTUP_TIME