File size: 1,948 Bytes
51e944e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from functools import lru_cache
import numpy as np
import tensorflow as tf
from api_backend.configs import logger
import os

# Resolve model files relative to this module's own location (not the
# process's current working directory).
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODEL_DIR = os.path.join(BASE_DIR, "models")

# Model Registry
# Maps a short model name to the settings needed to load and use it:
#   path        -- saved Keras model file inside MODEL_DIR
#   preprocess  -- keras.applications preprocess_input for the model family
#   decode      -- keras.applications decode_predictions for the model family
#   input_size  -- spatial input size; passed to load_model, which builds a
#                  (1, *input_size, 3) zero batch from it for warm-up
MODEL_REGISTRY = dict(
    efficientnet=dict(
        path=os.path.join(MODEL_DIR, "efficientnet.keras"),
        preprocess=tf.keras.applications.efficientnet_v2.preprocess_input,
        decode=tf.keras.applications.efficientnet_v2.decode_predictions,
        input_size=(480, 480),
    ),
    resnet=dict(
        path=os.path.join(MODEL_DIR, "resnet50_imagenet.keras"),
        preprocess=tf.keras.applications.resnet50.preprocess_input,
        decode=tf.keras.applications.resnet50.decode_predictions,
        input_size=(224, 224),
    ),
)

# Exceptions
class ModelNotFoundError(Exception):
    """Raised when a caller asks for a model that cannot be found."""

class InvalidImageError(Exception):
    """Raised when an input image cannot be processed."""

# Model Loading
@lru_cache(maxsize=None)
def load_model(model_path: str, input_size: tuple) -> tf.keras.Model:
    """Load a Keras model from disk and run one warm-up prediction.

    Results are memoized per ``(model_path, input_size)`` pair, so repeated
    calls with the same arguments return the same model object without
    reloading it. Both arguments are used as cache keys and must therefore be
    hashable (``input_size`` must be a tuple, not a list). A call that raises
    is not cached, so a failed load is retried on the next call.

    Args:
        model_path: Filesystem path of the saved model passed to
            ``tf.keras.models.load_model``.
        input_size: Spatial input size; a zero batch of shape
            ``(1, *input_size, 3)`` is pushed through the model to warm it up.

    Returns:
        The loaded model.

    Raises:
        RuntimeError: If loading or the warm-up prediction fails. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        model = tf.keras.models.load_model(model_path)
        # One dummy inference pays one-time setup cost (e.g. function tracing)
        # here instead of on the first real request. float32 matches Keras'
        # default input dtype; verbose=0 stops predict() from printing a
        # progress bar to stdout on every load.
        dummy_input = np.zeros((1, *input_size, 3), dtype=np.float32)
        model.predict(dummy_input, verbose=0)
    except Exception as e:
        logger.error(f"Failed to load model from {model_path}: {e}")
        raise RuntimeError(f"Failed to load model from {model_path}: {e}") from e
    logger.info(f"Successfully loaded model from {model_path}")
    return model

# Initialize models
def _load_registered_models() -> dict:
    """Load every model listed in MODEL_REGISTRY, skipping any that fail.

    Each failure is logged and that model is left out of the result. One
    broken model file therefore does not stop the module, or the other
    models, from loading.

    Returns:
        Mapping of registry name to loaded model, containing only the models
        that loaded successfully.
    """
    loaded = {}
    for model_name, model_config in MODEL_REGISTRY.items():
        try:
            loaded[model_name] = load_model(
                model_config["path"], model_config["input_size"]
            )
        except Exception as e:
            # Deliberately best-effort: continue with whatever models loaded.
            logger.error(f"Could not load model {model_name}: {str(e)}")
    return loaded


# Models are loaded at import time. The loop lives in a function so its loop
# variables do not leak into the module namespace.
models = _load_registered_models()