File size: 12,498 Bytes
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
 
52e5b45
 
 
 
 
 
 
 
3cc5297
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
 
52e5b45
 
 
 
 
 
 
 
3cc5297
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5297
 
 
 
 
 
 
 
 
 
52e5b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
"""
Core inference functions for CRISPR Array Detection.

Provides:
- predict_sequence(): Per-position CRISPR probability scores
- embed_sequence(): Hidden state embeddings for state-dynamics visualization
"""

import logging
from typing import Optional, Literal
from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from .model_loader import get_model, get_embedding_model
from .tokenizer import encode_sequence, create_windows, WINDOW_SIZE

logger = logging.getLogger(__name__)


@dataclass
class PredictionResult:
    """Result of CRISPR prediction (see predict_sequence)."""
    # 0-based index for every position in the encoded sequence.
    positions: list[int]
    # Per-position CRISPR probability in [0, 1], aligned with `positions`.
    probabilities: list[float]
    # Mean of `probabilities` over all positions.
    overall_score: float
    # Length of the encoded sequence (== len(positions)).
    sequence_length: int
    # Number of sliding windows run through the model.
    num_windows: int

@dataclass
class EmbeddingResult:
    """Result of embedding extraction (see embed_sequence, non-trajectory modes)."""
    # Pooled embedding vector; length equals `embedding_dim`.
    embedding: list[float]
    # Label describing how the embedding was pooled,
    # e.g. "mean_pooled_all", "max_pooled_all", "first_position_first_window".
    method: str
    # Dimensionality of `embedding`.
    embedding_dim: int
    # Length of the encoded input sequence.
    sequence_length: int

@dataclass
class TrajectoryResult:
    """Result of trajectory embedding extraction (embed_sequence mode="trajectory")."""
    # One mean-pooled embedding per window; each inner list has length
    # `embedding_dim`.
    embeddings: list[list[float]]
    # Center position of each window (window start + WINDOW_SIZE // 2),
    # aligned with `embeddings`.
    positions: list[int]
    # Pooling label; set to "mean_pooled_per_window" by embed_sequence.
    method: str
    # Dimensionality of each per-window embedding.
    embedding_dim: int
    # Length of the encoded input sequence.
    sequence_length: int
    # Number of sliding windows (== len(embeddings)).
    num_windows: int

def cast_for_model(X: np.ndarray, expected_dtype: tf.dtypes.DType) -> np.ndarray:
    """Convert an input array to the dtype the model's input layer expects.

    Floating-point inputs become float32; integer inputs become int64 when
    the model declares int64, otherwise int32. ``copy=False`` skips the copy
    when the array already has the target dtype.
    """
    if expected_dtype.is_floating:
        target = np.float32
    elif expected_dtype == tf.int64:
        target = np.int64
    else:
        target = np.int32
    return X.astype(target, copy=False)


def predict_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Run prediction on batched windows.

    Args:
        model: Keras model producing per-position probabilities; output may be
            (B, W), (B, W, 1) for a sigmoid head, or (B, W, C) for a softmax
            head where class 1 is treated as "CRISPR positive"
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Predictions of shape (N, window_size) with probabilities in [0, 1]

    Raises:
        ValueError: If batch_size is not a positive integer
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Match the dtype declared by the model's first input tensor.
    expected_dtype = model.inputs[0].dtype
    windows = cast_for_model(windows, expected_dtype)

    n_windows = len(windows)
    window_size = windows.shape[1]

    # Pre-allocate output instead of accumulating a list of arrays.
    predictions = np.empty((n_windows, window_size), dtype=np.float32)

    for i in range(0, n_windows, batch_size):
        batch = windows[i:i + batch_size]
        pred = model(batch, training=False)

        # Normalize 3-D outputs to (B, W).
        if len(pred.shape) == 3:
            if pred.shape[-1] == 1:
                # Sigmoid head: drop the trailing singleton channel.
                pred = tf.squeeze(pred, axis=-1)
            else:
                # Softmax head: take the class-1 (CRISPR positive) probability.
                # Reaching this branch means shape[-1] != 1, so the former
                # `if pred.shape[-1] > 1 else pred[..., 0]` re-check was dead code.
                pred = pred[..., 1]

        # Clamp numerical drift outside [0, 1] before storing.
        pred = tf.clip_by_value(pred, 0.0, 1.0).numpy()
        predictions[i:i + len(batch)] = pred

    return predictions


def aggregate_predictions(
    predictions: np.ndarray,
    starts: np.ndarray,
    seq_length: int,
    window_size: int = WINDOW_SIZE,
    aggregation: Literal["mean", "max"] = "mean"
) -> np.ndarray:
    """
    Combine overlapping window predictions into a single per-position score.

    Args:
        predictions: Array of shape (N, window_size), one row per window
        starts: Start position of each window within the sequence
        seq_length: Total sequence length
        window_size: Size of each window
        aggregation: "mean" averages overlapping windows; "max" keeps the
            highest score seen at each position

    Returns:
        float32 array of shape (seq_length,) with per-position probabilities

    Raises:
        ValueError: If aggregation is not "mean" or "max"
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")

    use_max = aggregation == "max"
    scores = np.zeros(seq_length, dtype=np.float32)
    counts = np.zeros(seq_length, dtype=np.int32)

    for window_pred, start in zip(predictions, starts):
        # Windows may run past the end of the sequence; clip to seq_length.
        end = min(start + window_size, seq_length)
        span = end - start

        if use_max:
            # Element-wise running maximum, written in place into the slice.
            np.maximum(scores[start:end], window_pred[:span], out=scores[start:end])
        else:
            scores[start:end] += window_pred[:span]
            counts[start:end] += 1

    if not use_max:
        # Positions covered by no window keep a count of 1 so the division
        # is safe (their score stays 0).
        scores /= np.maximum(counts, 1)

    return scores


def predict_sequence(
    sequence: str,
    stride: int = 100,
    batch_size: int = 32,
    aggregation: Literal["mean", "max"] = "mean",
    model: Optional[tf.keras.Model] = None
) -> PredictionResult:
    """
    Predict CRISPR array probability for each position in a sequence.

    Args:
        sequence: DNA sequence string
        stride: Step size between sliding windows (default 100)
        batch_size: Batch size for inference (default 32)
        aggregation: How to aggregate overlapping predictions ("mean" or "max")
        model: Optional model instance (uses singleton if not provided)

    Returns:
        PredictionResult with per-position probabilities

    Raises:
        ValueError: If aggregation or batch_size is invalid
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Encode the raw string into integer tokens, then slice into windows.
    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    # Fall back to the process-wide singleton when no model is supplied.
    active_model = model if model is not None else get_model()

    logger.info(f"Processing sequence: {seq_length} bp, {len(windows)} windows (stride={stride})")

    per_window = predict_batch(active_model, windows, batch_size=batch_size)

    # Collapse overlapping windows into one score per position.
    scores = aggregate_predictions(
        per_window, starts, seq_length,
        window_size=WINDOW_SIZE, aggregation=aggregation
    )

    return PredictionResult(
        positions=list(range(seq_length)),
        probabilities=[float(p) for p in scores],
        # Overall score is simply the mean across all positions.
        overall_score=float(scores.mean()),
        sequence_length=seq_length,
        num_windows=len(windows)
    )


def embed_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Extract embeddings from batched windows.

    Args:
        model: Embedding model
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim),
        depending on the model's output

    Raises:
        ValueError: If batch_size is not a positive integer
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    windows = cast_for_model(windows, model.inputs[0].dtype)

    # Run each batch through the model and stitch the results back together.
    chunks = [
        model(windows[offset:offset + batch_size], training=False).numpy()
        for offset in range(0, len(windows), batch_size)
    ]
    return np.concatenate(chunks, axis=0)


def embed_sequence(
    sequence: str,
    mode: Literal["mean", "cls", "max", "trajectory"] = "mean",
    stride: int = 100,
    batch_size: int = 32,
    model: Optional[tf.keras.Model] = None
) -> EmbeddingResult | TrajectoryResult:
    """
    Extract hidden state embeddings from a sequence.

    Args:
        sequence: DNA sequence string
        mode: Embedding mode:
            - "mean": Mean pool over all positions and windows
            - "cls": Use first position embedding (similar to BERT [CLS])
            - "max": Max pool over all positions
            - "trajectory": Return per-window embeddings for state-dynamics
        stride: Step size between windows (for trajectory mode)
        batch_size: Batch size for inference
        model: Optional embedding model (uses singleton if not provided)

    Returns:
        EmbeddingResult (for mean/cls/max) or TrajectoryResult (for trajectory)

    Raises:
        ValueError: If mode or batch_size is invalid
    """
    if mode not in {"mean", "cls", "max", "trajectory"}:
        raise ValueError("mode must be one of: mean, cls, max, trajectory")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Tokenize and slice into fixed-size windows.
    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    active_model = model if model is not None else get_embedding_model()

    logger.info(f"Extracting embeddings: {seq_length} bp, {len(windows)} windows")

    # Expected shape: (N windows, window_size, embed_dim).
    embeddings = embed_batch(active_model, windows, batch_size=batch_size)

    if mode == "trajectory":
        # One vector per window (mean over positions) for state-dynamics plots.
        per_window = np.mean(embeddings, axis=1)
        return TrajectoryResult(
            embeddings=[vec.tolist() for vec in per_window],
            # Report each window by its center position.
            positions=[int(s + WINDOW_SIZE // 2) for s in starts],
            method="mean_pooled_per_window",
            embedding_dim=int(per_window.shape[-1]),
            sequence_length=seq_length,
            num_windows=len(windows)
        )

    # Map each remaining mode to (pooling function, method label).
    poolers = {
        "cls": (lambda e: e[0, 0, :], "first_position_first_window"),
        "max": (lambda e: np.max(e, axis=(0, 1)), "max_pooled_all"),
        "mean": (lambda e: np.mean(e, axis=(0, 1)), "mean_pooled_all"),
    }
    pool, method = poolers[mode]
    final_embedding = pool(embeddings)

    return EmbeddingResult(
        embedding=final_embedding.tolist(),
        method=method,
        embedding_dim=int(final_embedding.shape[0]),
        sequence_length=seq_length
    )


def detect_crispr_regions(
    sequence: str,
    threshold: float = 0.3,
    min_length: int = 160,
    merge_gap: int = 80,
    stride: int = 100,
    model: Optional[tf.keras.Model] = None,
    prediction_result: Optional[PredictionResult] = None
) -> list[dict]:
    """
    Detect CRISPR array regions in a sequence.

    Args:
        sequence: DNA sequence string
        threshold: Probability threshold for calling CRISPR (default 0.3)
        min_length: Minimum region length in bp (default 160)
        merge_gap: Maximum gap to merge adjacent regions (default 80)
        stride: Sliding window stride (default 100)
        model: Optional model instance
        prediction_result: Optional precomputed result from predict_sequence();
            when provided, the model is not run again and `sequence`, `stride`
            and `model` are not used for scoring

    Returns:
        List of detected regions with coordinates and scores. Each dict has
        "region_id" (1-based, consecutive over the returned list), "start"/"end"
        (1-based inclusive), "length", "max_score" and "mean_score".

    Raises:
        ValueError: If threshold, min_length or merge_gap is out of range
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    if min_length < 1:
        raise ValueError("min_length must be at least 1")
    if merge_gap < 0:
        raise ValueError("merge_gap must be non-negative")

    # Get per-position predictions, or reuse a caller-provided result to avoid
    # running the model twice in UI flows that need both scores and regions.
    result = prediction_result or predict_sequence(sequence, stride=stride, model=model)
    scores = np.array(result.probabilities)

    # Threshold to binary mask.
    mask = scores >= threshold

    # Find contiguous True runs as half-open (start, end) intervals.
    regions = []
    in_region = False
    start = 0
    for i, is_crispr in enumerate(mask):
        if is_crispr and not in_region:
            start = i
            in_region = True
        elif not is_crispr and in_region:
            regions.append((start, i))
            in_region = False
    if in_region:
        # Run extends to the end of the sequence.
        regions.append((start, len(mask)))

    # Merge regions separated by at most merge_gap positions.
    if regions:
        merged = [list(regions[0])]
        for s, e in regions[1:]:
            prev_s, prev_e = merged[-1]
            if s <= prev_e + merge_gap:
                merged[-1][1] = e
            else:
                merged.append([s, e])
        regions = [(s, e) for s, e in merged]

    # Filter by length and compute per-region scores.
    detected = []
    for start, end in regions:
        length = end - start
        if length < min_length:
            continue
        region_scores = scores[start:end]
        detected.append({
            # Number consecutively over *returned* regions; the previous
            # pre-filter enumeration left gaps when short regions were dropped.
            "region_id": len(detected) + 1,
            "start": int(start + 1),  # 1-based
            "end": int(end),  # 1-based inclusive
            "length": length,
            "max_score": float(np.max(region_scores)),
            "mean_score": float(np.mean(region_scores)),
        })

    return detected