Spaces:
Sleeping
Sleeping
| """ | |
| Core inference functions for CRISPR Array Detection. | |
| Provides: | |
| - predict_sequence(): Per-position CRISPR probability scores | |
| - embed_sequence(): Hidden state embeddings for state-dynamics visualization | |
| """ | |
| import logging | |
| from typing import Optional, Literal | |
| from dataclasses import dataclass | |
| import numpy as np | |
| import tensorflow as tf | |
| from .model_loader import get_model, get_embedding_model | |
| from .tokenizer import encode_sequence, create_windows, WINDOW_SIZE | |
| logger = logging.getLogger(__name__) | |
@dataclass
class PredictionResult:
    """Result of CRISPR prediction.

    Bug fix: the class is constructed with keyword arguments in
    predict_sequence(), so it must be a dataclass for the generated
    __init__ to exist (the `dataclass` import was previously unused).
    """
    positions: list[int]        # 0-based position index for each score
    probabilities: list[float]  # per-position CRISPR probability in [0, 1]
    overall_score: float        # mean probability over all positions
    sequence_length: int        # length of the tokenized input sequence
    num_windows: int            # number of sliding windows evaluated
@dataclass
class EmbeddingResult:
    """Result of embedding extraction.

    Bug fix: constructed with keyword arguments in embed_sequence(), so
    the @dataclass decorator is required to generate __init__.
    """
    embedding: list[float]  # pooled embedding vector
    method: str             # pooling strategy used (e.g. "mean_pooled_all")
    embedding_dim: int      # dimensionality of `embedding`
    sequence_length: int    # length of the tokenized input sequence
@dataclass
class TrajectoryResult:
    """Result of trajectory embedding extraction.

    Bug fix: constructed with keyword arguments in embed_sequence()
    (trajectory mode), so the @dataclass decorator is required to
    generate __init__.
    """
    embeddings: list[list[float]]  # one pooled vector per window
    positions: list[int]           # center position of each window
    method: str                    # pooling strategy identifier
    embedding_dim: int             # dimensionality of each vector
    sequence_length: int           # length of the tokenized input sequence
    num_windows: int               # number of windows / vectors
def cast_for_model(X: np.ndarray, expected_dtype: tf.dtypes.DType) -> np.ndarray:
    """Cast input array to the dtype expected by the model.

    Floating model inputs always receive float32; integer inputs receive
    int64 only when the model explicitly asks for it, int32 otherwise.
    """
    if expected_dtype.is_floating:
        target = np.float32
    elif expected_dtype == tf.int64:
        target = np.int64
    else:
        target = np.int32
    # copy=False avoids a copy when the array already has the target dtype.
    return X.astype(target, copy=False)
def predict_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Run prediction on batched windows.

    Args:
        model: Keras model
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Predictions of shape (N, window_size) with probabilities

    Raises:
        ValueError: If batch_size is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    windows = cast_for_model(windows, model.inputs[0].dtype)
    total, width = windows.shape[0], windows.shape[1]

    # Fill a preallocated float32 buffer one batch at a time.
    out = np.empty((total, width), dtype=np.float32)
    for offset in range(0, total, batch_size):
        chunk = windows[offset:offset + batch_size]
        raw = model(chunk, training=False)
        # Collapse a trailing channel axis when the model emits 3-D output.
        if len(raw.shape) == 3:
            if raw.shape[-1] == 1:
                raw = tf.squeeze(raw, axis=-1)
            else:
                # Multi-class head: keep the CRISPR-positive class (index 1).
                raw = raw[..., 1] if raw.shape[-1] > 1 else raw[..., 0]
        out[offset:offset + len(chunk)] = tf.clip_by_value(raw, 0.0, 1.0).numpy()
    return out
def aggregate_predictions(
    predictions: np.ndarray,
    starts: np.ndarray,
    seq_length: int,
    window_size: int = WINDOW_SIZE,
    aggregation: Literal["mean", "max"] = "mean"
) -> np.ndarray:
    """
    Aggregate overlapping window predictions into per-position scores.

    Args:
        predictions: Array of shape (N, window_size) with per-window predictions
        starts: Start positions of each window
        seq_length: Total sequence length
        window_size: Size of each window
        aggregation: Aggregation method ("mean" or "max")

    Returns:
        Per-position probability array of shape (seq_length,)

    Raises:
        ValueError: If aggregation is not "mean" or "max".
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")

    use_max = aggregation == "max"
    scores = np.zeros(seq_length, dtype=np.float32)
    counts = np.zeros(seq_length, dtype=np.int32)

    for window_pred, start in zip(predictions, starts):
        # Windows near the sequence end may extend past it; truncate.
        end = min(start + window_size, seq_length)
        span = end - start
        if use_max:
            np.maximum(scores[start:end], window_pred[:span], out=scores[start:end])
        else:
            scores[start:end] += window_pred[:span]
            counts[start:end] += 1

    if not use_max:
        # Uncovered positions keep count 1 so the division is safe
        # (their score stays 0).
        scores = scores / np.maximum(counts, 1)
    return scores
def predict_sequence(
    sequence: str,
    stride: int = 100,
    batch_size: int = 32,
    aggregation: Literal["mean", "max"] = "mean",
    model: Optional[tf.keras.Model] = None
) -> PredictionResult:
    """
    Predict CRISPR array probability for each position in a sequence.

    Args:
        sequence: DNA sequence string
        stride: Step size between sliding windows (default 100)
        batch_size: Batch size for inference (default 32)
        aggregation: How to aggregate overlapping predictions ("mean" or "max")
        model: Optional model instance (uses singleton if not provided)

    Returns:
        PredictionResult with per-position probabilities

    Raises:
        ValueError: If aggregation or batch_size is invalid.
    """
    # Validate early, before any model loading or tokenization.
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    active_model = model if model is not None else get_model()
    logger.info(f"Processing sequence: {seq_length} bp, {len(windows)} windows (stride={stride})")

    per_window = predict_batch(active_model, windows, batch_size=batch_size)
    per_position = aggregate_predictions(
        per_window, starts, seq_length,
        window_size=WINDOW_SIZE, aggregation=aggregation
    )

    return PredictionResult(
        positions=list(range(seq_length)),
        probabilities=[float(p) for p in per_position],
        # Overall score is the mean over every position.
        overall_score=float(np.mean(per_position)),
        sequence_length=seq_length,
        num_windows=len(windows)
    )
def embed_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Extract embeddings from batched windows.

    Args:
        model: Embedding model
        windows: Array of shape (N, window_size)
        batch_size: Batch size for inference

    Returns:
        Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim)

    Raises:
        ValueError: If batch_size is not a positive integer.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    windows = cast_for_model(windows, model.inputs[0].dtype)
    # Run each batch through the model and stitch the results together.
    chunks = [
        model(windows[i:i + batch_size], training=False).numpy()
        for i in range(0, len(windows), batch_size)
    ]
    return np.concatenate(chunks, axis=0)
def embed_sequence(
    sequence: str,
    mode: Literal["mean", "cls", "max", "trajectory"] = "mean",
    stride: int = 100,
    batch_size: int = 32,
    model: Optional[tf.keras.Model] = None
) -> EmbeddingResult | TrajectoryResult:
    """
    Extract hidden state embeddings from a sequence.

    Args:
        sequence: DNA sequence string
        mode: Embedding mode:
            - "mean": Mean pool over all positions and windows
            - "cls": Use first position embedding (similar to BERT [CLS])
            - "max": Max pool over all positions
            - "trajectory": Return per-window embeddings for state-dynamics
        stride: Step size between windows (for trajectory mode)
        batch_size: Batch size for inference
        model: Optional embedding model (uses singleton if not provided)

    Returns:
        EmbeddingResult (for mean/cls/max) or TrajectoryResult (for trajectory)

    Raises:
        ValueError: If mode or batch_size is invalid.
    """
    if mode not in {"mean", "cls", "max", "trajectory"}:
        raise ValueError("mode must be one of: mean, cls, max, trajectory")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    embedder = model if model is not None else get_embedding_model()
    logger.info(f"Extracting embeddings: {seq_length} bp, {len(windows)} windows")

    # Per-window hidden states — assumed shape (N, window_size, embed_dim).
    hidden = embed_batch(embedder, windows, batch_size=batch_size)

    if mode == "trajectory":
        # One mean-pooled vector per window, anchored at the window center.
        pooled = hidden.mean(axis=1)
        return TrajectoryResult(
            embeddings=[vec.tolist() for vec in pooled],
            positions=[int(s + WINDOW_SIZE // 2) for s in starts],
            method="mean_pooled_per_window",
            embedding_dim=int(pooled.shape[-1]),
            sequence_length=seq_length,
            num_windows=len(windows)
        )

    if mode == "cls":
        # First position of the first window, akin to BERT's [CLS] token.
        vector = hidden[0, 0, :]
        method = "first_position_first_window"
    elif mode == "max":
        # Max pool across all windows and positions at once.
        vector = hidden.max(axis=(0, 1))
        method = "max_pooled_all"
    else:
        # Default: mean pool across all windows and positions.
        vector = hidden.mean(axis=(0, 1))
        method = "mean_pooled_all"

    return EmbeddingResult(
        embedding=vector.tolist(),
        method=method,
        embedding_dim=int(vector.shape[0]),
        sequence_length=seq_length
    )
def detect_crispr_regions(
    sequence: str,
    threshold: float = 0.3,
    min_length: int = 160,
    merge_gap: int = 80,
    stride: int = 100,
    model: Optional[tf.keras.Model] = None,
    prediction_result: Optional[PredictionResult] = None
) -> list[dict]:
    """
    Detect CRISPR array regions in a sequence.

    Args:
        sequence: DNA sequence string
        threshold: Probability threshold for calling CRISPR (default 0.3)
        min_length: Minimum region length in bp (default 160)
        merge_gap: Maximum gap to merge adjacent regions (default 80)
        stride: Sliding window stride (default 100)
        model: Optional model instance
        prediction_result: Optional precomputed result from predict_sequence();
            when provided, the model is not run again (useful in UI flows that
            need both per-position scores and regions)

    Returns:
        List of detected regions with 1-based coordinates and max/mean scores.
        region_id values are consecutive starting at 1.

    Raises:
        ValueError: If threshold, min_length, or merge_gap is out of range.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    if min_length < 1:
        raise ValueError("min_length must be at least 1")
    if merge_gap < 0:
        raise ValueError("merge_gap must be non-negative")

    # Reuse a caller-provided result to avoid running the model twice.
    # Explicit `is None` check (rather than truthiness) so any valid
    # PredictionResult is honored.
    if prediction_result is None:
        prediction_result = predict_sequence(sequence, stride=stride, model=model)
    scores = np.asarray(prediction_result.probabilities)

    # Threshold to a binary mask and extract contiguous True runs.
    mask = scores >= threshold
    regions = []
    in_region = False
    start = 0
    for i, is_crispr in enumerate(mask):
        if is_crispr and not in_region:
            start = i
            in_region = True
        elif not is_crispr and in_region:
            regions.append((start, i))
            in_region = False
    if in_region:
        regions.append((start, len(mask)))

    # Merge regions separated by at most merge_gap positions. Runs are
    # sorted and disjoint, so each new end extends the previous one;
    # max() guards against any degenerate overlap.
    if regions:
        merged = [list(regions[0])]
        for s, e in regions[1:]:
            if s <= merged[-1][1] + merge_gap:
                merged[-1][1] = max(merged[-1][1], e)
            else:
                merged.append([s, e])
        regions = [(s, e) for s, e in merged]

    # Filter by length. Bug fix: number the *surviving* regions
    # consecutively — the previous code numbered the pre-filter list, so
    # IDs had gaps whenever a short region was dropped.
    detected = []
    for start, end in regions:
        length = end - start
        if length < min_length:
            continue
        region_scores = scores[start:end]
        detected.append({
            "region_id": len(detected) + 1,
            "start": int(start + 1),   # 1-based
            "end": int(end),           # 1-based inclusive
            "length": length,
            "max_score": float(np.max(region_scores)),
            "mean_score": float(np.mean(region_scores)),
        })
    return detected