Spaces:
Sleeping
Sleeping
| """Model loading and prediction utilities.""" | |
| import os | |
| import numpy as np | |
| import streamlit as st | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| from app.config import MODEL_REGISTRY | |
| from app.labels import ID_TO_PATTERN | |
@st.cache_resource
def load_model(model_key: str) -> tf.keras.Model:
    """Load and cache a Keras model from HF Hub or local fallback.

    The registry entry ``MODEL_REGISTRY[model_key]`` may provide
    ``hf_repo_id`` (a Hub repo expected to contain ``model.keras``) and/or
    ``path`` (a local model file). The Hub is tried first. If the download
    or load fails for any reason, the local path is tried instead.
    ``st.cache_resource`` keeps the loaded model in memory across Streamlit
    reruns and sessions, so the download/load happens once per key.

    Args:
        model_key: Key into ``MODEL_REGISTRY``.

    Returns:
        The loaded Keras model.

    Raises:
        KeyError: If ``model_key`` is not present in ``MODEL_REGISTRY``.
        RuntimeError: If the local model file exists but cannot be loaded.
        FileNotFoundError: If no source produced a model. When a Hub attempt
            was made and failed, that error is chained as ``__cause__``.
    """
    info = MODEL_REGISTRY[model_key]
    hf_repo_id = info.get("hf_repo_id")
    local_path = info.get("path")

    # Remembered so a total failure can still explain why the Hub did not work.
    hf_error = None

    # Try HF Hub first.
    if hf_repo_id:
        try:
            model_path = hf_hub_download(
                repo_id=hf_repo_id,
                filename="model.keras",
                cache_dir=".cache",
            )
            return tf.keras.models.load_model(str(model_path))
        except Exception as e:
            # Deliberate best-effort: any Hub/network/load failure falls back
            # to the local copy below instead of crashing the app.
            hf_error = e

    # Fallback to local path.
    if local_path and os.path.exists(str(local_path)):
        try:
            return tf.keras.models.load_model(str(local_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    # ``from None`` is a no-op here when no Hub attempt was made (no context).
    raise FileNotFoundError(f"Model not found for key: {model_key}") from hf_error
def predict_single(model: tf.keras.Model, prepared_input: np.ndarray) -> dict:
    """Classify one prepared wafer map.

    Args:
        model: Keras model whose ``predict`` output rows are treated as
            per-class scores (presumably softmax probabilities — confirm).
        prepared_input: batch holding one wafer, shape (1, 52, 52, 3) float32.

    Returns:
        Dict with ``class_id`` (index of the highest score), ``pattern_name``
        (``ID_TO_PATTERN`` label for that index), ``confidence`` (the winning
        score as a Python float) and ``probabilities`` (the full score row).
    """
    batch_scores = model.predict(prepared_input, verbose=0)
    # Only one sample was submitted, so its scores are the first row.
    scores = batch_scores[0]
    best_id = int(np.argmax(scores))
    best_score = float(scores[best_id])
    label = ID_TO_PATTERN[best_id]
    return dict(
        class_id=best_id,
        pattern_name=label,
        confidence=best_score,
        probabilities=scores,
    )
def predict_batch(model: tf.keras.Model, prepared_inputs: np.ndarray) -> list[dict]:
    """Run prediction on a batch of prepared wafer maps.

    Args:
        model: Keras model whose ``predict`` output rows are treated as
            per-class scores (presumably softmax probabilities — confirm).
        prepared_inputs: shape (N, 52, 52, 3) float32. N may be 0.

    Returns:
        List of result dicts, one per wafer and in input order. Each dict has
        ``index`` (position in the batch), ``class_id``, ``pattern_name``,
        ``confidence`` and ``probabilities`` (that wafer's score row). An
        empty batch yields an empty list.
    """
    # Keras ``predict`` can fail on a zero-length batch (TF 2.x raises a
    # ValueError about empty outputs), so skip the model entirely.
    if len(prepared_inputs) == 0:
        return []

    all_probs = model.predict(prepared_inputs, verbose=0)
    results = []
    for i, probs in enumerate(all_probs):
        class_id = int(np.argmax(probs))
        results.append({
            "index": i,
            "class_id": class_id,
            "pattern_name": ID_TO_PATTERN[class_id],
            "confidence": float(probs[class_id]),
            "probabilities": probs,
        })
    return results