# waferguard-ml / app / model_utils.py
# oliversinn's picture
# wip
# 32d72db
"""Model loading and prediction utilities."""
import os
import numpy as np
import streamlit as st
import tensorflow as tf
from huggingface_hub import hf_hub_download
from app.config import MODEL_REGISTRY
from app.labels import ID_TO_PATTERN
@st.cache_resource
def load_model(model_key: str) -> tf.keras.Model:
    """Load a Keras model for ``model_key``, preferring HF Hub over a local file.

    The registry entry for ``model_key`` may provide ``hf_repo_id`` and/or
    ``path``. When ``hf_repo_id`` is set, ``model.keras`` is downloaded from
    that repo (into the ``.cache`` directory) and loaded. If that attempt
    fails for any reason, the failure is logged and the local ``path`` is
    tried instead. The function is decorated with ``st.cache_resource``.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.

    Returns:
        The loaded Keras model.

    Raises:
        KeyError: If ``model_key`` is not present in ``MODEL_REGISTRY``.
        RuntimeError: If the local file exists but cannot be loaded.
        FileNotFoundError: If no source produced a model. When an HF Hub
            attempt was made, its error is chained as ``__cause__`` and
            summarised in the message.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    info = MODEL_REGISTRY[model_key]
    hf_repo_id = info.get("hf_repo_id")
    local_path = info.get("path")

    # Try HF Hub first. This is deliberately best-effort: any failure
    # (network, auth, missing file, deserialisation) falls through to the
    # local path, but the error is recorded rather than silently dropped.
    hub_error = None
    if hf_repo_id:
        try:
            model_path = hf_hub_download(
                repo_id=hf_repo_id,
                filename="model.keras",
                cache_dir=".cache",
            )
            return tf.keras.models.load_model(str(model_path))
        except Exception as exc:
            hub_error = exc
            logger.warning(
                "Loading model %r from HF Hub repo %r failed; "
                "falling back to local path: %s",
                model_key,
                hf_repo_id,
                exc,
            )

    # Fallback to local path.
    if local_path and os.path.exists(str(local_path)):
        try:
            return tf.keras.models.load_model(str(local_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    # Nothing worked: surface the hub failure (if any) so the caller can
    # tell "never configured" apart from "download/load failed".
    message = f"Model not found for key: {model_key}"
    if hub_error is not None:
        message += f" (HF Hub load from {hf_repo_id!r} failed: {hub_error})"
    raise FileNotFoundError(message) from hub_error
def predict_single(model: tf.keras.Model, prepared_input: np.ndarray) -> dict:
    """Classify one prepared wafer map.

    Args:
        model: Model whose ``predict`` method is called on the input.
        prepared_input: Expected shape (1, 52, 52, 3), float32.

    Returns:
        A dict with keys ``class_id`` (index of the highest score),
        ``pattern_name`` (``ID_TO_PATTERN`` lookup of that index),
        ``confidence`` (the highest score as a Python float) and
        ``probabilities`` (the score vector of the first batch entry).
    """
    batch_scores = model.predict(prepared_input, verbose=0)
    # Only the first row of the batch output is used.
    scores = batch_scores[0]
    best = int(np.argmax(scores))
    label = ID_TO_PATTERN[best]
    top_score = float(scores[best])
    return dict(
        class_id=best,
        pattern_name=label,
        confidence=top_score,
        probabilities=scores,
    )
def predict_batch(model: tf.keras.Model, prepared_inputs: np.ndarray) -> list[dict]:
    """Run prediction on a batch of prepared wafer maps.

    Args:
        model: Model whose ``predict`` method is called on the batch.
        prepared_inputs: Expected shape (N, 52, 52, 3), float32. ``N`` may
            be zero, in which case the model is not invoked.

    Returns:
        A list of result dicts, one per wafer and in input order. Each has
        the keys ``index``, ``class_id``, ``pattern_name``, ``confidence``
        and ``probabilities``. An empty batch yields an empty list.
    """
    # Guard the empty batch: Keras ``predict`` can raise on zero samples,
    # and there is nothing to classify anyway.
    if len(prepared_inputs) == 0:
        return []

    all_probs = model.predict(prepared_inputs, verbose=0)
    results = []
    for i, probs in enumerate(all_probs):
        class_id = int(np.argmax(probs))
        results.append({
            "index": i,
            "class_id": class_id,
            "pattern_name": ID_TO_PATTERN[class_id],
            "confidence": float(probs[class_id]),
            "probabilities": probs,
        })
    return results