""" Model loader with singleton pattern for CRISPR BERT model. Ensures the model is loaded only once and reused across requests. """ import os import logging from pathlib import Path from typing import Optional import threading import numpy as np import tensorflow as tf from huggingface_hub import hf_hub_download from .custom_layers import get_custom_objects from .tokenizer import WINDOW_SIZE logger = logging.getLogger(__name__) # Singleton state _model: Optional[tf.keras.Model] = None _embedding_model: Optional[tf.keras.Model] = None _model_lock = threading.Lock() # HuggingFace model repository HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "genomenet/crispr-bert-model") HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5") # Local model path (optional override) DEFAULT_MODEL_PATH = os.environ.get("CRISPR_MODEL_PATH", "") # Embedding layer name for hidden state extraction # Note: Fine-tuned model has 22 blocks (0-21), base BERT has 24 (0-23) EMBEDDING_LAYER = os.environ.get( "CRISPR_EMBEDDING_LAYER", "layer_transformer_block_21" ) def setup_gpu(): """Configure GPU memory growth to avoid OOM errors.""" gpus = tf.config.list_physical_devices("GPU") if gpus: for gpu in gpus: try: tf.config.experimental.set_memory_growth(gpu, True) except RuntimeError as e: logger.warning(f"GPU memory growth setting failed: {e}") logger.info(f"GPUs available: {[g.name for g in gpus]}") return True else: logger.warning("No GPU found. Running on CPU.") return False def load_model(model_path: Optional[str] = None) -> tf.keras.Model: """ Load the CRISPR detection model. Downloads from HuggingFace Hub if no local path is provided. Args: model_path: Path to model file (.h5 or .keras) Returns: Loaded Keras model """ # Use provided path, environment variable, or download from HF Hub if model_path: path = Path(model_path) elif DEFAULT_MODEL_PATH: path = Path(DEFAULT_MODEL_PATH) else: # Download from HuggingFace Hub logger.info(f"Downloading model from HuggingFace: {HF_MODEL_REPO}/{HF_MODEL_FILENAME}") path = Path(hf_hub_download( repo_id=HF_MODEL_REPO, filename=HF_MODEL_FILENAME )) logger.info(f"Model downloaded to: {path}") if not path.exists(): raise FileNotFoundError( f"Model file not found: {path}\n" f"Please set CRISPR_MODEL_PATH or ensure HF_MODEL_REPO is accessible." ) logger.info(f"Loading model from: {path}") custom_objects = get_custom_objects() model = tf.keras.models.load_model(str(path), custom_objects=custom_objects, compile=False) logger.info(f"Model loaded. Input shape: {model.input_shape}, Output shape: {model.output_shape}") return model def build_embedding_model(model: tf.keras.Model, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model: """ Build a sub-model that outputs hidden states from a specific layer. Args: model: Full CRISPR detection model layer_name: Name of the layer to extract embeddings from Returns: Keras model that outputs embeddings """ try: embedding_output = model.get_layer(layer_name).output except ValueError: # Try to find a suitable layer available_layers = [l.name for l in model.layers if "transformer" in l.name.lower()] raise ValueError( f"Layer '{layer_name}' not found in model. " f"Available transformer layers: {available_layers}" ) embedding_model = tf.keras.Model( inputs=model.inputs, outputs=embedding_output, name="embedding_model" ) logger.info(f"Embedding model built. Output shape: {embedding_model.output_shape}") return embedding_model def get_model(model_path: Optional[str] = None) -> tf.keras.Model: """ Get the singleton model instance. Thread-safe lazy loading of the model. Args: model_path: Optional path to model file Returns: Loaded Keras model """ global _model if _model is None: with _model_lock: if _model is None: setup_gpu() _model = load_model(model_path) return _model def get_embedding_model(model_path: Optional[str] = None, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model: """ Get the singleton embedding model instance. Args: model_path: Optional path to model file layer_name: Name of layer to extract embeddings from Returns: Embedding extraction model """ global _embedding_model if _embedding_model is None: with _model_lock: if _embedding_model is None: model = get_model(model_path) _embedding_model = build_embedding_model(model, layer_name) return _embedding_model def warmup_model(model: Optional[tf.keras.Model] = None): """ Warm up the model by running a dummy inference. This triggers graph compilation and avoids slow first request. Args: model: Model to warm up (uses singleton if not provided) """ if model is None: model = get_model() logger.info("Warming up model...") # Determine expected input dtype expected_dtype = model.inputs[0].dtype if expected_dtype.is_floating: dtype = np.float32 elif expected_dtype == tf.int64: dtype = np.int64 else: dtype = np.int32 # Create dummy input dummy = np.ones((1, WINDOW_SIZE), dtype=dtype) # Run inference _ = model(dummy, training=False) logger.info("Model warmup complete.") def get_model_info() -> dict: """ Get information about the loaded model. Returns: Dictionary with model metadata """ model = get_model() return { "input_shape": str(model.input_shape), "output_shape": str(model.output_shape), "input_dtype": str(model.inputs[0].dtype.name), "num_parameters": int(model.count_params()), "num_layers": len(model.layers), } def is_model_loaded() -> bool: """Check if the model has been loaded.""" return _model is not None def get_gpu_status() -> dict: """Get GPU availability status.""" gpus = tf.config.list_physical_devices("GPU") return { "gpu_available": len(gpus) > 0, "gpu_count": len(gpus), "gpu_names": [g.name for g in gpus], }