"""Model loading and prediction utilities.""" import os import numpy as np import streamlit as st import tensorflow as tf from huggingface_hub import hf_hub_download from app.config import MODEL_REGISTRY from app.labels import ID_TO_PATTERN @st.cache_resource def load_model(model_key: str) -> tf.keras.Model: """Load and cache a Keras model from HF Hub or local fallback. Tries HF Hub first, falls back to local path if HF unavailable. Cached across reruns and sessions. """ info = MODEL_REGISTRY[model_key] hf_repo_id = info.get("hf_repo_id") local_path = info.get("path") # Try HF Hub first if hf_repo_id: try: model_path = hf_hub_download( repo_id=hf_repo_id, filename="model.keras", cache_dir=".cache", ) model = tf.keras.models.load_model(str(model_path)) return model except Exception: pass # Fallback to local path if local_path and os.path.exists(str(local_path)): try: model = tf.keras.models.load_model(str(local_path)) return model except Exception as e: raise RuntimeError(f"Failed to load model: {e}") from e raise FileNotFoundError(f"Model not found for key: {model_key}") def predict_single(model: tf.keras.Model, prepared_input: np.ndarray) -> dict: """Run prediction on a single prepared wafer map. Args: prepared_input: shape (1, 52, 52, 3) float32 Returns: dict with class_id, pattern_name, confidence, probabilities """ probs = model.predict(prepared_input, verbose=0)[0] class_id = int(np.argmax(probs)) return { "class_id": class_id, "pattern_name": ID_TO_PATTERN[class_id], "confidence": float(probs[class_id]), "probabilities": probs, } def predict_batch(model: tf.keras.Model, prepared_inputs: np.ndarray) -> list[dict]: """Run prediction on a batch of prepared wafer maps. Args: prepared_inputs: shape (N, 52, 52, 3) float32 Returns: list of result dicts, one per wafer """ all_probs = model.predict(prepared_inputs, verbose=0) results = [] for i, probs in enumerate(all_probs): class_id = int(np.argmax(probs)) results.append({ "index": i, "class_id": class_id, "pattern_name": ID_TO_PATTERN[class_id], "confidence": float(probs[class_id]), "probabilities": probs, }) return results