Spaces:
Sleeping
Sleeping
| from functools import lru_cache | |
| import numpy as np | |
| import tensorflow as tf | |
| from api_backend.configs import logger | |
| import os | |
# Resolve the model directory relative to this source file (not the process's
# working directory), so lookups work no matter where the server is launched.
BASE_DIR = os.path.split(os.path.abspath(__file__))[0]
MODEL_DIR = os.path.join(BASE_DIR, "models")
# Model Registry
# Maps each public model name to its saved model file, the matching Keras
# preprocess/decode helpers, and its spatial input size. The input size is
# passed to load_model at startup to shape the warm-up batch.
MODEL_REGISTRY = dict(
    efficientnet=dict(
        path=os.path.join(MODEL_DIR, "efficientnet.keras"),
        preprocess=tf.keras.applications.efficientnet_v2.preprocess_input,
        decode=tf.keras.applications.efficientnet_v2.decode_predictions,
        input_size=(480, 480),
    ),
    resnet=dict(
        path=os.path.join(MODEL_DIR, "resnet50_imagenet.keras"),
        preprocess=tf.keras.applications.resnet50.preprocess_input,
        decode=tf.keras.applications.resnet50.decode_predictions,
        input_size=(224, 224),
    ),
)
# Exceptions
class ModelNotFoundError(Exception):
    """Signal that a requested model could not be found."""


class InvalidImageError(Exception):
    """Signal that processing an input image failed."""
# Model Loading
def load_model(model_path: str, input_size: tuple) -> tf.keras.Model:
    """Load a Keras model from disk and run one warm-up prediction.

    The warm-up pushes a single all-zeros batch through ``model.predict`` so
    one-off setup work happens at load time rather than on the first real
    request. Nothing is cached here: every call reloads the file.

    Args:
        model_path: Path handed to ``tf.keras.models.load_model``.
        input_size: Spatial size of the model input (e.g. ``(224, 224)``).
            The warm-up batch has shape ``(1, *input_size, 3)``.

    Returns:
        The loaded ``tf.keras.Model``.

    Raises:
        RuntimeError: If loading or the warm-up prediction fails. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        model = tf.keras.models.load_model(model_path)
        # float32 matches Keras' default floatx, avoiding a float64 -> float32
        # cast of the warm-up batch.
        dummy_input = np.zeros((1, *input_size, 3), dtype=np.float32)
        # verbose=0: the warm-up should not print a progress bar to stdout.
        model.predict(dummy_input, verbose=0)
    except Exception as e:
        logger.error(f"Failed to load model from {model_path}: {str(e)}")
        raise RuntimeError(f"Failed to load model from {model_path}: {str(e)}") from e
    logger.info(f"Successfully loaded model from {model_path}")
    return model
# Initialize models
# Load every registered model at import time, best-effort: a failure is logged
# and skipped so the remaining models can still be served. A model that failed
# to load is simply absent from ``models``.
models = {}
for name, config in MODEL_REGISTRY.items():
    try:
        models[name] = load_model(config["path"], config["input_size"])
    except Exception as e:
        logger.error(f"Could not load model {name}: {str(e)}")

if MODEL_REGISTRY and not models:
    # Every registered model failed; make the unusable state loud at startup
    # instead of surfacing it later as per-request lookup failures.
    logger.error("No models could be loaded; 'models' is empty.")