"""
Core inference functions for CRISPR Array Detection.

Provides:
- predict_sequence(): Per-position CRISPR probability scores
- embed_sequence(): Hidden state embeddings for state-dynamics visualization
"""

import logging
from typing import Optional, Literal
from dataclasses import dataclass

import numpy as np
import tensorflow as tf

from .model_loader import get_model, get_embedding_model
from .tokenizer import encode_sequence, create_windows, WINDOW_SIZE

logger = logging.getLogger(__name__)


@dataclass
class PredictionResult:
    """Result of CRISPR prediction."""

    # 0-based position indices, one per token of the encoded sequence.
    positions: list[int]
    # Aggregated CRISPR probability for each entry of `positions`.
    probabilities: list[float]
    # Mean of `probabilities` (0.0 when the sequence is empty).
    overall_score: float
    # Number of tokens produced by encode_sequence().
    sequence_length: int
    # Number of sliding windows that were scored.
    num_windows: int


@dataclass
class EmbeddingResult:
    """Result of embedding extraction."""

    # Single pooled embedding vector for the whole sequence.
    embedding: list[float]
    # Name of the pooling strategy that produced `embedding`.
    method: str
    # Length of `embedding`.
    embedding_dim: int
    # Number of tokens produced by encode_sequence().
    sequence_length: int


@dataclass
class TrajectoryResult:
    """Result of trajectory embedding extraction."""

    # One pooled embedding vector per window.
    embeddings: list[list[float]]
    # Window start + WINDOW_SIZE // 2 for each entry of `embeddings`.
    positions: list[int]
    # Name of the pooling strategy that produced `embeddings`.
    method: str
    # Length of each vector in `embeddings`.
    embedding_dim: int
    # Number of tokens produced by encode_sequence().
    sequence_length: int
    # Number of sliding windows that were embedded.
    num_windows: int


def cast_for_model(X: np.ndarray, expected_dtype: tf.dtypes.DType) -> np.ndarray:
    """
    Cast input array to the dtype expected by the model.

    Args:
        X: Input array
        expected_dtype: dtype of the model's first input

    Returns:
        X as float32 for any floating dtype, int64 for tf.int64, and int32
        otherwise. No copy is made when X already has the target dtype.
    """
    if expected_dtype.is_floating:
        return X.astype(np.float32, copy=False)
    if expected_dtype == tf.int64:
        return X.astype(np.int64, copy=False)
    return X.astype(np.int32, copy=False)


def predict_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Run prediction on batched windows.

    Args:
        model: Keras model
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Predictions of shape (N, window_size) with probabilities

    Raises:
        ValueError: If batch_size is not positive
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    expected_dtype = model.inputs[0].dtype
    windows = cast_for_model(windows, expected_dtype)

    n_windows = len(windows)
    window_size = windows.shape[1]

    # Pre-allocate output
    predictions = np.empty((n_windows, window_size), dtype=np.float32)

    for i in range(0, n_windows, batch_size):
        batch = windows[i:i + batch_size]
        pred = model(batch, training=False)

        # Handle different output shapes
        if len(pred.shape) == 3:
            if pred.shape[-1] == 1:
                pred = tf.squeeze(pred, axis=-1)
            else:
                # Multi-class output - take class 1 probability (CRISPR positive)
                pred = pred[..., 1] if pred.shape[-1] > 1 else pred[..., 0]

        # Clamp to [0, 1] so downstream code can treat values as probabilities
        pred = tf.clip_by_value(pred, 0.0, 1.0).numpy()
        predictions[i:i + len(batch)] = pred

    return predictions


def aggregate_predictions(
    predictions: np.ndarray,
    starts: np.ndarray,
    seq_length: int,
    window_size: int = WINDOW_SIZE,
    aggregation: Literal["mean", "max"] = "mean"
) -> np.ndarray:
    """
    Aggregate overlapping window predictions into per-position scores.

    Args:
        predictions: Array of shape (N, window_size) with per-window predictions
        starts: Start positions of each window
        seq_length: Total sequence length
        window_size: Size of each window
        aggregation: Aggregation method ("mean" or "max")

    Returns:
        Per-position probability array of shape (seq_length,). Positions not
        covered by any window keep a score of 0.

    Raises:
        ValueError: If aggregation is not "mean" or "max"
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")

    scores = np.zeros(seq_length, dtype=np.float32)
    counts = np.zeros(seq_length, dtype=np.int32)

    for i, start in enumerate(starts):
        # Windows may run past the end of the sequence; only keep the part
        # that maps onto real positions.
        end = min(start + window_size, seq_length)
        pred_len = end - start

        if aggregation == "max":
            # Scores start at 0, so this relies on predictions being >= 0
            scores[start:end] = np.maximum(scores[start:end], predictions[i, :pred_len])
        else:  # mean
            scores[start:end] += predictions[i, :pred_len]
            counts[start:end] += 1

    if aggregation == "mean":
        # Avoid division by zero
        counts = np.maximum(counts, 1)
        scores = scores / counts

    return scores


def predict_sequence(
    sequence: str,
    stride: int = 100,
    batch_size: int = 32,
    aggregation: Literal["mean", "max"] = "mean",
    model: Optional[tf.keras.Model] = None
) -> PredictionResult:
    """
    Predict CRISPR array probability for each position in a sequence.

    Args:
        sequence: DNA sequence string
        stride: Step size between sliding windows (default 100)
        batch_size: Batch size for inference (default 32)
        aggregation: How to aggregate overlapping predictions ("mean" or "max")
        model: Optional model instance (uses singleton if not provided)

    Returns:
        PredictionResult with per-position probabilities

    Raises:
        ValueError: If aggregation or batch_size is invalid
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Tokenize sequence
    tokens = encode_sequence(sequence)
    seq_length = len(tokens)

    # Create sliding windows
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    if model is None:
        model = get_model()

    logger.info(
        "Processing sequence: %s bp, %s windows (stride=%s)",
        seq_length, len(windows), stride,
    )

    # Run batched prediction
    predictions = predict_batch(model, windows, batch_size=batch_size)

    # Aggregate to per-position scores
    scores = aggregate_predictions(
        predictions, starts, seq_length,
        window_size=WINDOW_SIZE,
        aggregation=aggregation
    )

    # Calculate overall score (mean of all positions). np.mean of an empty
    # array is NaN (with a RuntimeWarning), so report 0.0 for empty input.
    overall_score = float(np.mean(scores)) if seq_length else 0.0

    return PredictionResult(
        positions=list(range(seq_length)),
        probabilities=scores.tolist(),
        overall_score=overall_score,
        sequence_length=seq_length,
        num_windows=len(windows)
    )


def embed_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Extract embeddings from batched windows.

    Args:
        model: Embedding model
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim)

    Raises:
        ValueError: If batch_size is not positive or windows is empty
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    expected_dtype = model.inputs[0].dtype
    windows = cast_for_model(windows, expected_dtype)

    n_windows = len(windows)
    if n_windows == 0:
        # np.concatenate([]) would raise an opaque ValueError; fail clearly.
        raise ValueError("windows must contain at least one window")

    embeddings = []
    for i in range(0, n_windows, batch_size):
        batch = windows[i:i + batch_size]
        embeddings.append(model(batch, training=False).numpy())

    return np.concatenate(embeddings, axis=0)


def embed_sequence(
    sequence: str,
    mode: Literal["mean", "cls", "max", "trajectory"] = "mean",
    stride: int = 100,
    batch_size: int = 32,
    model: Optional[tf.keras.Model] = None
) -> EmbeddingResult | TrajectoryResult:
    """
    Extract hidden state embeddings from a sequence.

    Args:
        sequence: DNA sequence string
        mode: Embedding mode:
            - "mean": Mean pool over all positions and windows
            - "cls": Use first position embedding (similar to BERT [CLS])
            - "max": Max pool over all positions
            - "trajectory": Return per-window embeddings for state-dynamics
        stride: Step size between windows (for trajectory mode)
        batch_size: Batch size for inference
        model: Optional embedding model (uses singleton if not provided)

    Returns:
        EmbeddingResult (for mean/cls/max) or TrajectoryResult (for trajectory)

    Raises:
        ValueError: If mode or batch_size is invalid, or no windows are produced
    """
    if mode not in {"mean", "cls", "max", "trajectory"}:
        raise ValueError("mode must be one of: mean, cls, max, trajectory")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    # Tokenize sequence
    tokens = encode_sequence(sequence)
    seq_length = len(tokens)

    # Create windows
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    if model is None:
        model = get_embedding_model()

    logger.info("Extracting embeddings: %s bp, %s windows", seq_length, len(windows))

    # Get embeddings (shape: N, window_size, embed_dim)
    embeddings = embed_batch(model, windows, batch_size=batch_size)

    # embed_batch may also return already-pooled (N, embed_dim) output. Add a
    # singleton position axis so the pooling below never reduces over the
    # embedding dimension by mistake.
    if embeddings.ndim == 2:
        embeddings = embeddings[:, np.newaxis, :]

    if mode == "trajectory":
        # Return per-window mean-pooled embeddings for state-dynamics
        # Shape: (N, embed_dim) after mean over positions
        window_embeddings = np.mean(embeddings, axis=1)

        return TrajectoryResult(
            embeddings=[emb.tolist() for emb in window_embeddings],
            # Center position of each window.
            # NOTE(review): the center can lie beyond the last base if a
            # window extends past the sequence end - confirm against
            # create_windows().
            positions=[int(s + WINDOW_SIZE // 2) for s in starts],
            method="mean_pooled_per_window",
            embedding_dim=int(window_embeddings.shape[-1]),
            sequence_length=seq_length,
            num_windows=len(windows)
        )

    # Aggregate embeddings across all windows and positions
    if mode == "cls":
        # Take first position from first window
        final_embedding = embeddings[0, 0, :]
        method = "first_position_first_window"
    elif mode == "max":
        # Max pool across all positions and windows
        final_embedding = np.max(embeddings, axis=(0, 1))
        method = "max_pooled_all"
    else:  # mean
        # Mean pool across all positions and windows
        final_embedding = np.mean(embeddings, axis=(0, 1))
        method = "mean_pooled_all"

    return EmbeddingResult(
        embedding=final_embedding.tolist(),
        method=method,
        embedding_dim=int(final_embedding.shape[0]),
        sequence_length=seq_length
    )


def _find_runs(mask: np.ndarray) -> list[tuple[int, int]]:
    """Return half-open (start, end) index pairs for each run of True in mask."""
    # Pad with False on both sides so runs touching either end still produce
    # a rising and a falling edge.
    padded = np.concatenate(([False], mask, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    # Edges alternate rising/falling, i.e. start/end of consecutive runs.
    return list(zip(edges[0::2].tolist(), edges[1::2].tolist()))


def _merge_runs(runs: list[tuple[int, int]], max_gap: int) -> list[tuple[int, int]]:
    """Merge consecutive runs whose gap to the previous run is at most max_gap."""
    merged: list[list[int]] = []
    for start, end in runs:
        if merged and start <= merged[-1][1] + max_gap:
            merged[-1][1] = end
        else:
            merged.append([start, end])
    return [(start, end) for start, end in merged]


def detect_crispr_regions(
    sequence: str,
    threshold: float = 0.3,
    min_length: int = 160,
    merge_gap: int = 80,
    stride: int = 100,
    model: Optional[tf.keras.Model] = None,
    prediction_result: Optional[PredictionResult] = None
) -> list[dict]:
    """
    Detect CRISPR array regions in a sequence.

    Args:
        sequence: DNA sequence string
        threshold: Probability threshold for calling CRISPR (default 0.3)
        min_length: Minimum region length in bp (default 160)
        merge_gap: Maximum gap to merge adjacent regions (default 80)
        stride: Sliding window stride (default 100)
        model: Optional model instance
        prediction_result: Optional precomputed result of predict_sequence().
            When given, the model is not run and sequence, stride and model
            are ignored.

    Returns:
        List of detected regions with coordinates and scores. Each region is
        a dict with keys region_id, start (1-based), end (1-based inclusive),
        length, max_score and mean_score. region_id numbers the merged
        candidate regions before the length filter, so ids of reported
        regions can skip values.

    Raises:
        ValueError: If threshold, min_length or merge_gap is out of range
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    if min_length < 1:
        raise ValueError("min_length must be at least 1")
    if merge_gap < 0:
        raise ValueError("merge_gap must be non-negative")

    # Get per-position predictions, or reuse a caller-provided result to avoid
    # running the model twice in UI flows that need both scores and regions.
    if prediction_result is not None:
        result = prediction_result
    else:
        result = predict_sequence(sequence, stride=stride, model=model)
    scores = np.array(result.probabilities)

    # Threshold to binary mask
    mask = scores >= threshold

    # Find contiguous regions, then merge nearby ones
    regions = _merge_runs(_find_runs(mask), merge_gap)

    # Filter by length and compute scores
    detected = []
    for i, (start, end) in enumerate(regions):
        length = end - start
        if length < min_length:
            continue
        region_scores = scores[start:end]
        detected.append({
            "region_id": i + 1,
            "start": int(start + 1),  # 1-based
            "end": int(end),  # 1-based inclusive
            "length": length,
            "max_score": float(np.max(region_scores)),
            "mean_score": float(np.mean(region_scores)),
        })

    return detected