from functools import lru_cache
import numpy as np
import tensorflow as tf
from api_backend.configs import logger
import os
# Anchor all paths to this module's directory so the service behaves the
# same regardless of the process's current working directory.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")
# Model Registry: one entry per supported architecture. Each entry holds the
# serialized model's path, the matching Keras preprocess/decode callables,
# and the (height, width) resolution the network expects.
_EFFNET_APP = tf.keras.applications.efficientnet_v2
_RESNET_APP = tf.keras.applications.resnet50

MODEL_REGISTRY = {
    "efficientnet": {
        "path": os.path.join(MODEL_DIR, "efficientnet.keras"),
        "preprocess": _EFFNET_APP.preprocess_input,
        "decode": _EFFNET_APP.decode_predictions,
        "input_size": (480, 480),
    },
    "resnet": {
        "path": os.path.join(MODEL_DIR, "resnet50_imagenet.keras"),
        "preprocess": _RESNET_APP.preprocess_input,
        "decode": _RESNET_APP.decode_predictions,
        "input_size": (224, 224),
    },
}
# Exceptions
class ModelNotFoundError(Exception):
    """Raised when a requested model name is absent from the registry."""
class InvalidImageError(Exception):
    """Raised when an uploaded image cannot be decoded or preprocessed."""
# Model Loading
@lru_cache(maxsize=None)
def load_model(model_path: str, input_size: tuple) -> tf.keras.Model:
    """Load and warm up a TensorFlow model, caching by (path, input_size).

    Args:
        model_path: Filesystem path to a serialized ``.keras`` model.
        input_size: (height, width) the network expects; used only to build
            the dummy warm-up tensor.

    Returns:
        The loaded ``tf.keras.Model``, already run once on dummy input so
        the first real request does not pay graph-tracing latency.

    Raises:
        RuntimeError: If loading or the warm-up inference fails. The
            original exception is chained as ``__cause__``.
    """
    try:
        model = tf.keras.models.load_model(model_path)
        # Warm up with a single dummy inference; float32 matches the dtype
        # Keras models expect, and verbose=0 keeps startup logs clean.
        dummy_input = np.zeros((1, *input_size, 3), dtype=np.float32)
        _ = model.predict(dummy_input, verbose=0)
        # Lazy %-formatting avoids building the message when INFO is disabled.
        logger.info("Successfully loaded model from %s", model_path)
        return model
    except Exception as e:
        logger.error("Failed to load model from %s: %s", model_path, e)
        # Chain the original exception so tracebacks keep the root cause.
        raise RuntimeError(f"Failed to load model from {model_path}: {str(e)}") from e
# Initialize models eagerly at import time. A model that fails to load is
# logged and skipped so one bad artifact does not take down the service.
models = {}
for model_name, spec in MODEL_REGISTRY.items():
    try:
        loaded = load_model(spec["path"], spec["input_size"])
    except Exception as e:
        logger.error(f"Could not load model {model_name}: {str(e)}")
    else:
        models[model_name] = loaded