File size: 6,570 Bytes
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64f4176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
64f4176
 
 
 
e72f783
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# api/startup.py
# All model and index loading happens here — once, at FastAPI startup.
# Everything stays in memory for the entire server lifetime.
# Never load models per-request.

import os
import json
import time
import torch
import clip

from src.patchcore import patchcore
from src.retriever import retriever
from src.graph import knowledge_graph
from src.depth import depth_estimator
from src.xai import gradcam, shap_explainer
from src.cache import inference_cache
from src.orchestrator import init_orchestrator
from api.logger import init_logger


# Version label printed at the end of the startup sequence.
MODEL_VERSION = "v1.0"
# Wall-clock time (time.time()) at which load_all() began; None until then.
# get_uptime() derives the /health uptime figure from it.
STARTUP_TIME = None

def download_artifacts():
    """Download all required artifacts from HF Dataset at startup."""
    from huggingface_hub import hf_hub_download, snapshot_download
    import shutil

    HF_REPO = "CaffeinatedCoding/anomalyos-logs"
    token = os.environ.get("HF_TOKEN")
    
    os.makedirs("data", exist_ok=True)

    files_to_download = [
        ("models/pca_256.pkl",              "data/pca_256.pkl"),
        ("configs/thresholds.json",          "data/thresholds.json"),
        ("graph/knowledge_graph.json",       "data/knowledge_graph.json"),
        ("indexes/index1_category.faiss",    "data/index1_category.faiss"),
        ("indexes/index1_metadata.json",     "data/index1_metadata.json"),
        ("indexes/index2_defect.faiss",      "data/index2_defect.faiss"),
        ("indexes/index2_metadata.json",     "data/index2_metadata.json"),
    ]

    # Index 3 β€” one per category
    categories = [
        'bottle','cable','capsule','carpet','grid','hazelnut',
        'leather','metal_nut','pill','screw','tile','toothbrush',
        'transistor','wood','zipper'
    ]
    for cat in categories:
        files_to_download.append((
            f"indexes/index3_{cat}.faiss",
            f"data/index3_{cat}.faiss"
        ))

    for repo_path, local_path in files_to_download:
        if os.path.exists(local_path):
            print(f"Already exists: {local_path}")
            continue
        try:
            print(f"Downloading {repo_path}...")
            downloaded = hf_hub_download(
                repo_id=HF_REPO,
                filename=repo_path,
                repo_type="dataset",
                token=token,
                local_dir="/tmp/artifacts"
            )
            shutil.copy(downloaded, local_path)
            print(f"  β†’ {local_path}")
        except Exception as e:
            print(f"  WARNING: Could not download {repo_path}: {e}")

def load_all():
    """
    Run the AnomalyOS startup sequence. Called once from the FastAPI lifespan.

    Order matters — patchcore before orchestrator, logger before anything logs.

    Steps, in order:
    1. Fetch artifacts (download_artifacts).
    2. Configure PyTorch threads and default dtype.
    3. Init the logger.
    4. Load PatchCore.
    5. Load the FAISS indexes (index 3 is lazy-loaded, not here).
    6. Load the knowledge graph.
    7. Load MiDaS depth (optional).
    8. Load CLIP ViT-B/32.
    9. Load thresholds.json (optional).
    10. Load GradCAM++ (optional).
    11. Load the SHAP background.
    12. Inject CLIP and the thresholds into the orchestrator.

    Side effect: sets the module-level STARTUP_TIME that get_uptime() reads.

    Returns:
        dict with keys "clip_model", "clip_preprocess" and "thresholds"
        (the parsed thresholds.json, or {} when the file is absent).

    Raises:
        Whatever the non-optional loaders raise. Only two failures are
        downgraded to printed warnings: depth (FileNotFoundError) and
        GradCAM++ (any Exception).
    """
    global STARTUP_TIME
    STARTUP_TIME = time.time()

    print("=" * 50)
    print("AnomalyOS startup sequence")
    print("=" * 50)

    # Download artifacts first — every loader below reads from disk.
    download_artifacts()

    # ── CPU thread tuning ─────────────────────────────────────
    # HF Spaces CPU Basic = 2 vCPU
    # Limit PyTorch threads to match — prevents over-subscription
    torch.set_num_threads(2)
    torch.set_default_dtype(torch.float32)
    print(f"PyTorch threads: {torch.get_num_threads()}")

    # ── Logger ────────────────────────────────────────────────
    hf_token = os.environ.get("HF_TOKEN", "")
    init_logger(hf_token)

    # ── PatchCore extractor ───────────────────────────────────
    patchcore.load()

    # ── FAISS indexes ─────────────────────────────────────────
    # Index 3 is lazy-loaded — not loaded here
    retriever.load_indexes()

    # ── Knowledge graph ───────────────────────────────────────
    knowledge_graph.load()

    # ── MiDaS depth estimator ─────────────────────────────────
    try:
        depth_estimator.load()
    except FileNotFoundError as e:
        print(f"WARNING: {e}")
        print("Depth features will return zeros β€” inference continues")

    # ── CLIP model ────────────────────────────────────────────
    # Loaded here, injected into orchestrator
    print("Loading CLIP ViT-B/32...")
    clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
    clip_model.eval()
    print("CLIP loaded")

    # Resolve DATA_DIR once for every path read below.
    # NOTE(review): download_artifacts() above writes into ./data. The
    # thresholds and SHAP paths only agree with it when DATA_DIR is unset
    # or equal to "data".
    data_dir = os.environ.get("DATA_DIR", "data")

    # ── Thresholds ────────────────────────────────────────────
    thresholds_path = os.path.join(data_dir, "thresholds.json")
    if os.path.exists(thresholds_path):
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(thresholds_path, encoding="utf-8") as f:
            thresholds = json.load(f)
        print(f"Thresholds loaded: {len(thresholds)} categories")
    else:
        thresholds = {}
        print("WARNING: thresholds.json not found β€” using score > 0.5 fallback")

    # ── GradCAM++ ─────────────────────────────────────────────
    try:
        gradcam.load()
    except Exception as e:
        print(f"WARNING: GradCAM++ load failed: {e}")
        print("Forensics mode will run without GradCAM++")

    # ── SHAP background ───────────────────────────────────────
    bg_path = os.path.join(data_dir, "shap_background.npy")
    shap_explainer.load_background(bg_path)

    # ── Inject into orchestrator ──────────────────────────────
    init_orchestrator(clip_model, clip_preprocess, thresholds)

    elapsed = time.time() - STARTUP_TIME
    print("=" * 50)
    print(f"Startup complete in {elapsed:.1f}s")
    print(f"Model version: {MODEL_VERSION}")
    print("=" * 50)

    return {
        "clip_model": clip_model,
        "clip_preprocess": clip_preprocess,
        "thresholds": thresholds,
    }


def get_uptime() -> float:
    """Return seconds elapsed since load_all() stamped STARTUP_TIME.

    Before startup has run (STARTUP_TIME still None) this reports 0.0.
    """
    started_at = STARTUP_TIME
    return 0.0 if started_at is None else time.time() - started_at