# Minimalist monochrome redesign with Geist Mono font (genomenet, commit 3cc5297)
"""
DNA Sequence Tokenization for CRISPR BERT model.
Token mapping:
- 0: PAD/OOV (padding or unknown)
- 1: A (Adenine)
- 2: C (Cytosine)
- 3: G (Guanine)
- 4: T (Thymine)
- 5: AMB (Ambiguous - N and other IUPAC codes)
"""
import numpy as np
VOCAB_SIZE = 6
WINDOW_SIZE = 1000
# Lookup table: ASCII -> token ID
# Default to 5 (AMB) for any unknown character
_LUT = np.full(256, 5, dtype=np.uint8)
_LUT[ord("A")] = 1
_LUT[ord("C")] = 2
_LUT[ord("G")] = 3
_LUT[ord("T")] = 4
# Also handle lowercase
_LUT[ord("a")] = 1
_LUT[ord("c")] = 2
_LUT[ord("g")] = 3
_LUT[ord("t")] = 4
def _coerce_positive_int(name: str, value) -> int:
"""Accept int-like values from UI/API inputs and reject unsafe strides."""
if isinstance(value, bool):
raise ValueError(f"{name} must be a positive integer")
if isinstance(value, (int, np.integer)):
parsed = int(value)
elif isinstance(value, float) and value.is_integer():
parsed = int(value)
else:
raise ValueError(f"{name} must be a positive integer")
if parsed <= 0:
raise ValueError(f"{name} must be a positive integer")
return parsed
def encode_sequence(sequence: str) -> np.ndarray:
    """
    Convert DNA sequence string to integer token array.

    The lookup table already maps both upper- and lowercase bases (and
    every other ASCII byte to 5/AMB), so no case normalization — and no
    extra full-string copy — is needed. This also keeps the non-ASCII
    contract consistent: previously ``str.upper()`` could map a few
    non-ASCII characters (e.g. 'ı' -> 'I') into ASCII before the check.

    Args:
        sequence: DNA sequence string (A, C, G, T, N, etc.)
    Returns:
        numpy array of uint8 token IDs
    Raises:
        ValueError: if the sequence contains non-ASCII characters.
    """
    try:
        seq_bytes = np.frombuffer(sequence.encode("ascii"), dtype=np.uint8)
    except UnicodeEncodeError as exc:
        raise ValueError("Sequence contains non-ASCII characters") from exc
    # Fancy indexing allocates a fresh writable array, so the read-only
    # frombuffer view is never exposed to callers.
    return _LUT[seq_bytes]
def validate_sequence(sequence: str) -> tuple[bool, str]:
    """
    Validate a DNA sequence for API input.

    Checks, in order: non-emptiness, minimum length (WINDOW_SIZE), and
    that every character is a standard IUPAC nucleotide code.

    Args:
        sequence: Input DNA sequence
    Returns:
        Tuple of (is_valid, error_message); the message is "" when valid.
    """
    if not sequence:
        return False, "Sequence is empty"
    if len(sequence) < WINDOW_SIZE:
        return False, f"Sequence must be at least {WINDOW_SIZE} nucleotides (got {len(sequence)})"
    # Standard IUPAC nucleotide codes, both cases.
    allowed = frozenset("ACGTNacgtnRYSWKMBDHVryswkmbdhv")
    unexpected = [ch for ch in sorted(set(sequence)) if ch not in allowed]
    if unexpected:
        listing = ", ".join(repr(ch) for ch in unexpected)
        return False, f"Invalid characters in sequence: {listing}"
    return True, ""
def strip_fasta_header(text: str) -> str:
    """
    Remove FASTA header lines from input text.

    Header lines (those starting with ">") and blank lines are dropped;
    the surviving lines are stripped of surrounding whitespace and
    concatenated into one sequence string.

    Args:
        text: Input text that may contain FASTA headers
    Returns:
        Sequence string with headers removed
    """
    kept = [
        stripped
        for raw_line in text.strip().splitlines()
        if (stripped := raw_line.strip()) and not stripped.startswith(">")
    ]
    return "".join(kept)
def create_windows(
    tokens: np.ndarray,
    window_size: int = WINDOW_SIZE,
    stride: int = 100
) -> tuple[np.ndarray, np.ndarray]:
    """
    Create sliding windows from tokenized sequence.

    Args:
        tokens: Tokenized sequence array
        window_size: Size of each window (default 1000)
        stride: Step size between windows (default 100)
    Returns:
        Tuple of (windows, starts): windows is a 2-D array of shape
        (n_windows, window_size); starts is a 1-D int64 array of the
        window start offsets. Previously the two code paths disagreed
        on the starts dtype (default int64 when padding vs int32
        otherwise, with np.append able to promote again); int64 is now
        used consistently.
    Raises:
        ValueError: if window_size or stride is not a positive integer.
    """
    window_size = _coerce_positive_int("window_size", window_size)
    stride = _coerce_positive_int("stride", stride)
    seq_len = len(tokens)
    if seq_len < window_size:
        # Short sequence: zero-pad (0 == PAD token) up to one full window.
        padded = np.zeros(window_size, dtype=tokens.dtype)
        padded[:seq_len] = tokens
        return padded.reshape(1, -1), np.zeros(1, dtype=np.int64)
    # All stride-aligned starts; the stop bound seq_len - window_size + 1
    # is equivalent to the former n_windows * stride arithmetic.
    starts = np.arange(0, seq_len - window_size + 1, stride, dtype=np.int64)
    # Add a final window flush with the sequence end if the stride grid
    # would otherwise leave a tail uncovered.
    if starts[-1] + window_size < seq_len:
        starts = np.append(starts, np.int64(seq_len - window_size))
    # Materialize each window; np.stack yields the (n_windows, window_size)
    # 2-D result explicitly.
    windows = np.stack([tokens[s:s + window_size] for s in starts])
    return windows, starts