| """ | |
| Model loader with singleton pattern for CRISPR BERT model. | |
| Ensures the model is loaded only once and reused across requests. | |
| """ | |
| import os | |
| import logging | |
| from pathlib import Path | |
| from typing import Optional | |
| import threading | |
| import numpy as np | |
| import tensorflow as tf | |
| from huggingface_hub import hf_hub_download | |
| from .custom_layers import get_custom_objects | |
| from .tokenizer import WINDOW_SIZE | |
| logger = logging.getLogger(__name__) | |
| # Singleton state | |
| _model: Optional[tf.keras.Model] = None | |
| _embedding_model: Optional[tf.keras.Model] = None | |
| _model_lock = threading.Lock() | |
| # HuggingFace model repository | |
| HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "genomenet/crispr-bert-model") | |
| HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5") | |
| # Local model path (optional override) | |
| DEFAULT_MODEL_PATH = os.environ.get("CRISPR_MODEL_PATH", "") | |
| # Embedding layer name for hidden state extraction | |
| # Note: Fine-tuned model has 22 blocks (0-21), base BERT has 24 (0-23) | |
| EMBEDDING_LAYER = os.environ.get( | |
| "CRISPR_EMBEDDING_LAYER", | |
| "layer_transformer_block_21" | |
| ) | |
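
# Example environment overrides (illustrative values only; adjust to your deployment):
#   export CRISPR_MODEL_PATH=/models/best.h5                   # use a local checkpoint, skip the HF Hub download
#   export CRISPR_EMBEDDING_LAYER=layer_transformer_block_20   # extract a different hidden layer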


def setup_gpu():
    """Configure GPU memory growth to avoid OOM errors."""
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                logger.warning(f"GPU memory growth setting failed: {e}")
        logger.info(f"GPUs available: {[g.name for g in gpus]}")
        return True
    else:
        logger.warning("No GPU found. Running on CPU.")
        return False


def load_model(model_path: Optional[str] = None) -> tf.keras.Model:
    """
    Load the CRISPR detection model.

    Downloads from HuggingFace Hub if no local path is provided.

    Args:
        model_path: Path to model file (.h5 or .keras)

    Returns:
        Loaded Keras model
    """
    # Use provided path, environment variable, or download from HF Hub
    if model_path:
        path = Path(model_path)
    elif DEFAULT_MODEL_PATH:
        path = Path(DEFAULT_MODEL_PATH)
    else:
        # Download from HuggingFace Hub
        logger.info(f"Downloading model from HuggingFace: {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
        path = Path(hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME
        ))
        logger.info(f"Model downloaded to: {path}")

    if not path.exists():
        raise FileNotFoundError(
            f"Model file not found: {path}\n"
            f"Please set CRISPR_MODEL_PATH or ensure the HuggingFace repo '{HF_MODEL_REPO}' is accessible."
        )

    logger.info(f"Loading model from: {path}")
    custom_objects = get_custom_objects()
    model = tf.keras.models.load_model(str(path), custom_objects=custom_objects, compile=False)
    logger.info(f"Model loaded. Input shape: {model.input_shape}, Output shape: {model.output_shape}")
    return model


def build_embedding_model(model: tf.keras.Model, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
    """
    Build a sub-model that outputs hidden states from a specific layer.

    Args:
        model: Full CRISPR detection model
        layer_name: Name of the layer to extract embeddings from

    Returns:
        Keras model that outputs embeddings
    """
    try:
        embedding_output = model.get_layer(layer_name).output
    except ValueError:
        # The requested layer does not exist; list the transformer layers so a valid name can be chosen
        available_layers = [layer.name for layer in model.layers if "transformer" in layer.name.lower()]
        raise ValueError(
            f"Layer '{layer_name}' not found in model. "
            f"Available transformer layers: {available_layers}"
        )

    embedding_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=embedding_output,
        name="embedding_model"
    )
    logger.info(f"Embedding model built. Output shape: {embedding_model.output_shape}")
    return embedding_model


def get_model(model_path: Optional[str] = None) -> tf.keras.Model:
    """
    Get the singleton model instance.

    Thread-safe lazy loading of the model.

    Args:
        model_path: Optional path to model file

    Returns:
        Loaded Keras model
    """
    global _model
    if _model is None:
        with _model_lock:
            if _model is None:
                setup_gpu()
                _model = load_model(model_path)
    return _model


def get_embedding_model(model_path: Optional[str] = None, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
    """
    Get the singleton embedding model instance.

    Args:
        model_path: Optional path to model file
        layer_name: Name of layer to extract embeddings from

    Returns:
        Embedding extraction model
    """
    global _embedding_model
    if _embedding_model is None:
        with _model_lock:
            if _embedding_model is None:
                model = get_model(model_path)
                _embedding_model = build_embedding_model(model, layer_name)
    return _embedding_model
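
# Example (sketch) of extracting hidden states for a single tokenized window; the
# int32 dtype and the output shape (typically (1, WINDOW_SIZE, hidden_dim)) depend
# on the actual checkpoint, so verify against model.inputs[0].dtype first:
#   emb_model = get_embedding_model()
#   tokens = np.ones((1, WINDOW_SIZE), dtype=np.int32)  # dummy token ids
#   hidden = emb_model(tokens, training=False)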


def warmup_model(model: Optional[tf.keras.Model] = None):
    """
    Warm up the model by running a dummy inference.

    This triggers graph compilation and avoids a slow first request.

    Args:
        model: Model to warm up (uses singleton if not provided)
    """
    if model is None:
        model = get_model()

    logger.info("Warming up model...")

    # Determine expected input dtype
    expected_dtype = model.inputs[0].dtype
    if expected_dtype.is_floating:
        dtype = np.float32
    elif expected_dtype == tf.int64:
        dtype = np.int64
    else:
        dtype = np.int32

    # Create dummy input
    dummy = np.ones((1, WINDOW_SIZE), dtype=dtype)

    # Run inference
    _ = model(dummy, training=False)
    logger.info("Model warmup complete.")


def get_model_info() -> dict:
    """
    Get information about the loaded model.

    Returns:
        Dictionary with model metadata
    """
    model = get_model()
    return {
        "input_shape": str(model.input_shape),
        "output_shape": str(model.output_shape),
        "input_dtype": str(model.inputs[0].dtype.name),
        "num_parameters": int(model.count_params()),
        "num_layers": len(model.layers),
    }


def is_model_loaded() -> bool:
    """Check if the model has been loaded."""
    return _model is not None


def get_gpu_status() -> dict:
    """Get GPU availability status."""
    gpus = tf.config.list_physical_devices("GPU")
    return {
        "gpu_available": len(gpus) > 0,
        "gpu_count": len(gpus),
        "gpu_names": [g.name for g in gpus],
    }
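

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumes a local CRISPR_MODEL_PATH or network access
    # to the HF Hub); not part of the serving path.
    logging.basicConfig(level=logging.INFO)
    warmup_model()            # loads the singleton and runs one dummy inference
    print(get_model_info())
    print(get_gpu_status())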