# Commit 3cc5297 (genomenet): "Minimalist monochrome redesign with Geist Mono font"
"""
Core inference functions for CRISPR Array Detection.
Provides:
- predict_sequence(): Per-position CRISPR probability scores
- embed_sequence(): Hidden state embeddings for state-dynamics visualization
"""
import logging
from typing import Optional, Literal
from dataclasses import dataclass
import numpy as np
import tensorflow as tf
from .model_loader import get_model, get_embedding_model
from .tokenizer import encode_sequence, create_windows, WINDOW_SIZE
logger = logging.getLogger(__name__)
@dataclass
class PredictionResult:
    """Result of CRISPR prediction.

    Produced by predict_sequence(); `positions` and `probabilities` are
    parallel lists covering every position of the tokenized sequence.
    """
    # 0-based position indices, one entry per sequence position
    positions: list[int]
    # Per-position CRISPR probability in [0, 1], parallel to `positions`
    probabilities: list[float]
    # Mean probability over all positions
    overall_score: float
    # Length of the tokenized sequence (positions / bp)
    sequence_length: int
    # Number of sliding windows used for inference
    num_windows: int
@dataclass
class EmbeddingResult:
    """Result of embedding extraction.

    Produced by embed_sequence() for the "mean", "cls" and "max" modes:
    a single pooled vector for the whole sequence.
    """
    # Pooled embedding vector
    embedding: list[float]
    # Tag describing how the vector was pooled (e.g. "mean_pooled_all")
    method: str
    # Dimensionality of `embedding`
    embedding_dim: int
    # Length of the tokenized input sequence
    sequence_length: int
@dataclass
class TrajectoryResult:
    """Result of trajectory embedding extraction.

    Produced by embed_sequence(mode="trajectory"): one mean-pooled vector
    per sliding window, for state-dynamics visualization.
    """
    # One mean-pooled embedding vector per sliding window
    embeddings: list[list[float]]
    # Center position of each window, parallel to `embeddings`
    positions: list[int]
    # Pooling tag ("mean_pooled_per_window")
    method: str
    # Dimensionality of each embedding vector
    embedding_dim: int
    # Length of the tokenized input sequence
    sequence_length: int
    # Number of sliding windows
    num_windows: int
def cast_for_model(X: np.ndarray, expected_dtype: tf.dtypes.DType) -> np.ndarray:
    """Cast input array to the dtype expected by the model.

    Args:
        X: Input array (token ids or features) in any NumPy dtype.
        expected_dtype: dtype of the model's input tensor.

    Returns:
        ``X`` cast to the NumPy dtype matching ``expected_dtype``
        (no copy when the dtype already matches).
    """
    # tf.DType exposes its NumPy counterpart, so cast directly instead of
    # hard-coding float32/int64/int32 branches.  The original branches
    # collapsed every floating model input to float32 (wrong for e.g.
    # float16/float64 models) and every non-int64 integer to int32.
    return X.astype(expected_dtype.as_numpy_dtype, copy=False)
def predict_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Run prediction on batched windows.

    Args:
        model: Keras model producing per-position scores for each window.
        windows: Array of shape (N, window_size).
        batch_size: Batch size for inference.

    Returns:
        Predictions of shape (N, window_size), clipped to [0, 1].

    Raises:
        ValueError: If batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    expected_dtype = model.inputs[0].dtype
    windows = cast_for_model(windows, expected_dtype)
    n_windows = len(windows)
    window_size = windows.shape[1]
    # Pre-allocate the full output instead of accumulating a list of batches.
    predictions = np.empty((n_windows, window_size), dtype=np.float32)
    for i in range(0, n_windows, batch_size):
        batch = windows[i:i + batch_size]
        pred = model(batch, training=False)
        # Normalize a 3-D (batch, window, channels) output down to 2-D.
        if len(pred.shape) == 3:
            if pred.shape[-1] == 1:
                # Single sigmoid channel per position: drop the channel axis.
                pred = tf.squeeze(pred, axis=-1)
            else:
                # Multi-class output: take class 1 (CRISPR positive).
                # Fix: the original `pred[..., 1] if pred.shape[-1] > 1 else
                # pred[..., 0]` fallback was unreachable dead code — this
                # branch already implies pred.shape[-1] != 1.
                pred = pred[..., 1]
        pred = tf.clip_by_value(pred, 0.0, 1.0).numpy()
        predictions[i:i + len(batch)] = pred
    return predictions
def aggregate_predictions(
    predictions: np.ndarray,
    starts: np.ndarray,
    seq_length: int,
    window_size: Optional[int] = None,
    aggregation: Literal["mean", "max"] = "mean"
) -> np.ndarray:
    """
    Aggregate overlapping window predictions into per-position scores.

    Args:
        predictions: Array of shape (N, window_size) with per-window predictions.
        starts: Start positions of each window (sorted ascending).
        seq_length: Total sequence length.
        window_size: Size of each window; defaults to WINDOW_SIZE when None
            (late-bound so the module constant is read at call time).
        aggregation: Aggregation method ("mean" or "max").

    Returns:
        float32 per-position probability array of shape (seq_length,).
        Positions covered by no window are 0.

    Raises:
        ValueError: If aggregation is not "mean" or "max".
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")
    if window_size is None:
        window_size = WINDOW_SIZE
    scores = np.zeros(seq_length, dtype=np.float32)
    counts = np.zeros(seq_length, dtype=np.int32)
    for i, start in enumerate(starts):
        end = min(start + window_size, seq_length)
        pred_len = end - start
        if pred_len <= 0:
            # Window starts at/after the sequence end: nothing to aggregate.
            # (Previously a negative pred_len caused a shape-mismatch error.)
            continue
        if aggregation == "max":
            scores[start:end] = np.maximum(scores[start:end], predictions[i, :pred_len])
        else:  # mean
            scores[start:end] += predictions[i, :pred_len]
            counts[start:end] += 1
    if aggregation == "mean":
        # Avoid division by zero for uncovered positions.
        counts = np.maximum(counts, 1)
        # Divide by a float32 count so the result stays float32 — the
        # original float32/int32 division silently promoted to float64,
        # making the mean path's dtype differ from the max path's.
        scores = scores / counts.astype(np.float32)
    return scores
def predict_sequence(
    sequence: str,
    stride: int = 100,
    batch_size: int = 32,
    aggregation: Literal["mean", "max"] = "mean",
    model: Optional[tf.keras.Model] = None
) -> PredictionResult:
    """
    Predict CRISPR array probability for each position in a sequence.

    The sequence is tokenized, sliced into overlapping fixed-size windows,
    scored window-by-window by the model, and the window scores are combined
    into a single per-position probability track.

    Args:
        sequence: DNA sequence string.
        stride: Step size between sliding windows (default 100).
        batch_size: Batch size for inference (default 32).
        aggregation: How overlapping window scores are combined ("mean" or "max").
        model: Optional model instance (falls back to the shared singleton).

    Returns:
        PredictionResult with per-position probabilities.

    Raises:
        ValueError: If aggregation is invalid or batch_size is not positive.
    """
    if aggregation not in {"mean", "max"}:
        raise ValueError("aggregation must be 'mean' or 'max'")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    inference_model = get_model() if model is None else model
    logger.info(f"Processing sequence: {seq_length} bp, {len(windows)} windows (stride={stride})")

    window_scores = predict_batch(inference_model, windows, batch_size=batch_size)
    per_position = aggregate_predictions(
        window_scores, starts, seq_length,
        window_size=WINDOW_SIZE, aggregation=aggregation
    )

    return PredictionResult(
        positions=list(range(seq_length)),
        probabilities=[float(p) for p in per_position],
        overall_score=float(per_position.mean()),
        sequence_length=seq_length,
        num_windows=len(windows)
    )
def embed_batch(
    model: tf.keras.Model,
    windows: np.ndarray,
    batch_size: int = 32
) -> np.ndarray:
    """
    Extract embeddings from batched windows.

    Args:
        model: Embedding model.
        windows: Array of shape (N, window_size).
        batch_size: Batch size for inference.

    Returns:
        Embeddings of shape (N, window_size, embed_dim) or (N, embed_dim),
        depending on the model's output.

    Raises:
        ValueError: If batch_size is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    inputs = cast_for_model(windows, model.inputs[0].dtype)
    # Run the model one batch at a time, then stitch the results together.
    chunks = [
        model(inputs[offset:offset + batch_size], training=False).numpy()
        for offset in range(0, len(inputs), batch_size)
    ]
    return np.concatenate(chunks, axis=0)
def embed_sequence(
    sequence: str,
    mode: Literal["mean", "cls", "max", "trajectory"] = "mean",
    stride: int = 100,
    batch_size: int = 32,
    model: Optional[tf.keras.Model] = None
) -> EmbeddingResult | TrajectoryResult:
    """
    Extract hidden state embeddings from a sequence.

    Args:
        sequence: DNA sequence string.
        mode: Embedding mode:
            - "mean": Mean pool over all positions and windows.
            - "cls": First position of the first window (BERT [CLS]-like).
            - "max": Max pool over all positions and windows.
            - "trajectory": Per-window vectors for state-dynamics plots.
        stride: Step size between windows (relevant for trajectory mode).
        batch_size: Batch size for inference.
        model: Optional embedding model (falls back to the shared singleton).

    Returns:
        EmbeddingResult for "mean"/"cls"/"max"; TrajectoryResult for "trajectory".

    Raises:
        ValueError: If mode is invalid or batch_size is not positive.
    """
    if mode not in {"mean", "cls", "max", "trajectory"}:
        raise ValueError("mode must be one of: mean, cls, max, trajectory")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    tokens = encode_sequence(sequence)
    seq_length = len(tokens)
    windows, starts = create_windows(tokens, window_size=WINDOW_SIZE, stride=stride)

    embedder = get_embedding_model() if model is None else model
    logger.info(f"Extracting embeddings: {seq_length} bp, {len(windows)} windows")

    # Hidden states, typically (N, window_size, embed_dim).
    all_states = embed_batch(embedder, windows, batch_size=batch_size)

    if mode == "trajectory":
        # One mean-pooled vector per window; positions are window centers.
        per_window = all_states.mean(axis=1)
        return TrajectoryResult(
            embeddings=[vec.tolist() for vec in per_window],
            positions=[int(s + WINDOW_SIZE // 2) for s in starts],
            method="mean_pooled_per_window",
            embedding_dim=int(per_window.shape[-1]),
            sequence_length=seq_length,
            num_windows=len(windows)
        )

    # Whole-sequence pooling: dispatch on mode.
    poolers = {
        "cls": (lambda e: e[0, 0, :], "first_position_first_window"),
        "max": (lambda e: e.max(axis=(0, 1)), "max_pooled_all"),
        "mean": (lambda e: e.mean(axis=(0, 1)), "mean_pooled_all"),
    }
    pool, method = poolers[mode]
    vector = pool(all_states)
    return EmbeddingResult(
        embedding=vector.tolist(),
        method=method,
        embedding_dim=int(vector.shape[0]),
        sequence_length=seq_length
    )
def _mask_to_runs(mask: np.ndarray) -> list[tuple[int, int]]:
    """Return half-open (start, end) runs of consecutive True values in mask."""
    runs = []
    in_run = False
    run_start = 0
    for i, flag in enumerate(mask):
        if flag and not in_run:
            run_start = i
            in_run = True
        elif not flag and in_run:
            runs.append((run_start, i))
            in_run = False
    if in_run:
        runs.append((run_start, len(mask)))
    return runs


def _merge_runs(runs: list[tuple[int, int]], merge_gap: int) -> list[tuple[int, int]]:
    """Merge sorted, disjoint runs separated by at most merge_gap positions."""
    if not runs:
        return []
    merged = [list(runs[0])]
    for s, e in runs[1:]:
        if s <= merged[-1][1] + merge_gap:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s, e])
    return [(s, e) for s, e in merged]


def detect_crispr_regions(
    sequence: str,
    threshold: float = 0.3,
    min_length: int = 160,
    merge_gap: int = 80,
    stride: int = 100,
    model: Optional[tf.keras.Model] = None,
    prediction_result: Optional[PredictionResult] = None
) -> list[dict]:
    """
    Detect CRISPR array regions in a sequence.

    Args:
        sequence: DNA sequence string (ignored when prediction_result is given).
        threshold: Probability threshold for calling CRISPR (default 0.3).
        min_length: Minimum region length in bp (default 160).
        merge_gap: Maximum gap to merge adjacent regions (default 80).
        stride: Sliding window stride (default 100).
        model: Optional model instance.
        prediction_result: Optional precomputed per-position predictions;
            reused to avoid running the model twice in UI flows that need
            both the score track and the regions.

    Returns:
        List of detected regions with 1-based inclusive coordinates and
        scores, numbered consecutively from 1.

    Raises:
        ValueError: If threshold, min_length or merge_gap is out of range.
    """
    if not 0.0 <= threshold <= 1.0:
        raise ValueError("threshold must be between 0 and 1")
    if min_length < 1:
        raise ValueError("min_length must be at least 1")
    if merge_gap < 0:
        raise ValueError("merge_gap must be non-negative")
    # Get per-position predictions, or reuse a caller-provided result.
    result = prediction_result or predict_sequence(sequence, stride=stride, model=model)
    scores = np.asarray(result.probabilities)
    # Threshold, find contiguous runs, then merge nearby ones.
    runs = _merge_runs(_mask_to_runs(scores >= threshold), merge_gap)
    # Filter by length and compute scores.
    detected = []
    for start, end in runs:
        length = end - start
        if length < min_length:
            continue
        region_scores = scores[start:end]
        detected.append({
            # Fix: number only the regions that pass the length filter so IDs
            # are consecutive — the original numbered pre-filter runs, leaving
            # gaps whenever a short run was dropped.
            "region_id": len(detected) + 1,
            "start": int(start + 1),  # 1-based
            "end": int(end),          # 1-based inclusive
            "length": length,
            "max_score": float(np.max(region_scores)),
            "mean_score": float(np.mean(region_scores)),
        })
    return detected