"""
Model loader with singleton pattern for CRISPR BERT model.

Ensures the model is loaded only once and reused across requests.
"""

import os
import logging
from pathlib import Path
from typing import Optional
import threading

import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download

from .custom_layers import get_custom_objects
from .tokenizer import WINDOW_SIZE

logger = logging.getLogger(__name__)

# Singleton state
_model: Optional[tf.keras.Model] = None
_embedding_model: Optional[tf.keras.Model] = None
# Re-entrant lock: get_embedding_model() holds it while calling get_model(),
# which re-acquires it when the base model has not been loaded yet.
_model_lock = threading.RLock()

# HuggingFace model repository
HF_MODEL_REPO = os.environ.get("CRISPR_HF_REPO", "genomenet/crispr-bert-model")
HF_MODEL_FILENAME = os.environ.get("CRISPR_HF_FILENAME", "best.h5")

# Local model path (optional override)
DEFAULT_MODEL_PATH = os.environ.get("CRISPR_MODEL_PATH", "")

# Embedding layer name for hidden state extraction
# Note: Fine-tuned model has 22 blocks (0-21), base BERT has 24 (0-23)
EMBEDDING_LAYER = os.environ.get(
    "CRISPR_EMBEDDING_LAYER",
    "layer_transformer_block_21"
)


def setup_gpu():
    """Configure GPU memory growth to avoid OOM errors."""
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for gpu in gpus:
            try:
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                logger.warning(f"GPU memory growth setting failed: {e}")
        logger.info(f"GPUs available: {[g.name for g in gpus]}")
        return True
    else:
        logger.warning("No GPU found. Running on CPU.")
        return False


def load_model(model_path: Optional[str] = None) -> tf.keras.Model:
    """
    Load the CRISPR detection model.

    Downloads from HuggingFace Hub if no local path is provided.

    Args:
        model_path: Path to model file (.h5 or .keras)

    Returns:
        Loaded Keras model
    """
    # Use provided path, environment variable, or download from HF Hub
    if model_path:
        path = Path(model_path)
    elif DEFAULT_MODEL_PATH:
        path = Path(DEFAULT_MODEL_PATH)
    else:
        # Download from HuggingFace Hub
        logger.info(f"Downloading model from HuggingFace: {HF_MODEL_REPO}/{HF_MODEL_FILENAME}")
        path = Path(hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=HF_MODEL_FILENAME
        ))
        logger.info(f"Model downloaded to: {path}")

    if not path.exists():
        raise FileNotFoundError(
            f"Model file not found: {path}\n"
            f"Set CRISPR_MODEL_PATH to a valid local model file, or ensure the "
            f"HuggingFace repo (CRISPR_HF_REPO / CRISPR_HF_FILENAME) is accessible."
        )

    logger.info(f"Loading model from: {path}")

    custom_objects = get_custom_objects()
    model = tf.keras.models.load_model(str(path), custom_objects=custom_objects, compile=False)

    logger.info(f"Model loaded. Input shape: {model.input_shape}, Output shape: {model.output_shape}")

    return model
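
# Example (illustrative): loading from an explicit local checkpoint instead of
# the HuggingFace Hub download. The path below is hypothetical.
#
#   model = load_model("/data/checkpoints/best.h5")
#
# Equivalently, export CRISPR_MODEL_PATH=/data/checkpoints/best.h5 before
# calling load_model() / get_model() with no arguments.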


def build_embedding_model(model: tf.keras.Model, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
    """
    Build a sub-model that outputs hidden states from a specific layer.

    Args:
        model: Full CRISPR detection model
        layer_name: Name of the layer to extract embeddings from

    Returns:
        Keras model that outputs embeddings
    """
    try:
        embedding_output = model.get_layer(layer_name).output
    except ValueError:
        # Layer not found: list the transformer layers so the caller can pick a valid name
        available_layers = [layer.name for layer in model.layers if "transformer" in layer.name.lower()]
        raise ValueError(
            f"Layer '{layer_name}' not found in model. "
            f"Available transformer layers: {available_layers}"
        )

    embedding_model = tf.keras.Model(
        inputs=model.inputs,
        outputs=embedding_output,
        name="embedding_model"
    )

    logger.info(f"Embedding model built. Output shape: {embedding_model.output_shape}")

    return embedding_model
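
# Note (illustrative): if the default layer name does not match a given
# checkpoint, the transformer layer names can be inspected with
#
#   [layer.name for layer in model.layers if "transformer" in layer.name.lower()]
#
# and the desired one passed as layer_name or via CRISPR_EMBEDDING_LAYER.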


def get_model(model_path: Optional[str] = None) -> tf.keras.Model:
    """
    Get the singleton model instance.

    Thread-safe lazy loading of the model.

    Args:
        model_path: Optional path to model file

    Returns:
        Loaded Keras model
    """
    global _model

    if _model is None:
        with _model_lock:
            if _model is None:
                setup_gpu()
                _model = load_model(model_path)

    return _model


def get_embedding_model(model_path: Optional[str] = None, layer_name: str = EMBEDDING_LAYER) -> tf.keras.Model:
    """
    Get the singleton embedding model instance.

    Args:
        model_path: Optional path to model file
        layer_name: Name of layer to extract embeddings from

    Returns:
        Embedding extraction model
    """
    global _embedding_model

    if _embedding_model is None:
        with _model_lock:
            if _embedding_model is None:
                model = get_model(model_path)
                _embedding_model = build_embedding_model(model, layer_name)

    return _embedding_model
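
# Example (illustrative sketch): extracting hidden states for one encoded window.
# The all-ones token array below is a stand-in; real token IDs come from the
# tokenizer module, and the integer dtype mirrors the logic in warmup_model().
#
#   emb_model = get_embedding_model()
#   dummy_ids = np.ones((1, WINDOW_SIZE), dtype=np.int32)
#   hidden = emb_model(dummy_ids, training=False)  # (1, WINDOW_SIZE, hidden_dim)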


def warmup_model(model: Optional[tf.keras.Model] = None):
    """
    Warm up the model by running a dummy inference.

    This triggers graph compilation and avoids slow first request.

    Args:
        model: Model to warm up (uses singleton if not provided)
    """
    if model is None:
        model = get_model()

    logger.info("Warming up model...")

    # Determine the expected input dtype; model.inputs[0].dtype may be a tf.DType
    # or a plain string depending on the Keras version, so normalize it first.
    expected_dtype = tf.as_dtype(model.inputs[0].dtype)
    if expected_dtype.is_floating:
        dtype = np.float32
    elif expected_dtype == tf.int64:
        dtype = np.int64
    else:
        dtype = np.int32

    # Create dummy input
    dummy = np.ones((1, WINDOW_SIZE), dtype=dtype)

    # Run inference
    _ = model(dummy, training=False)

    logger.info("Model warmup complete.")


def get_model_info() -> dict:
    """
    Get information about the loaded model.

    Returns:
        Dictionary with model metadata
    """
    model = get_model()

    return {
        "input_shape": str(model.input_shape),
        "output_shape": str(model.output_shape),
        "input_dtype": str(model.inputs[0].dtype.name),
        "num_parameters": int(model.count_params()),
        "num_layers": len(model.layers),
    }


def is_model_loaded() -> bool:
    """Check if the model has been loaded."""
    return _model is not None


def get_gpu_status() -> dict:
    """Get GPU availability status."""
    gpus = tf.config.list_physical_devices("GPU")
    return {
        "gpu_available": len(gpus) > 0,
        "gpu_count": len(gpus),
        "gpu_names": [g.name for g in gpus],
    }
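

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes either a local CRISPR_MODEL_PATH or
    # network access to the HuggingFace repo). Not used by the service itself.
    logging.basicConfig(level=logging.INFO)
    warmup_model()  # lazily loads the singleton model, then runs a dummy pass
    print(get_model_info())
    print(get_gpu_status())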