Spaces:
No application file
No application file
| # -*- coding: utf-8 -*- | |
| """ | |
| Gene Prediction Model - predictor.py | |
| Boundary-aware deep learning model for gene prediction | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Tuple, Dict, Optional | |
| import logging | |
| import re | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') | |
| # ============================= MODEL COMPONENTS ============================= | |
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    Per-position logits are produced by three independent two-layer heads
    ('gene', 'start', 'end') on top of a shared multi-scale conv + BiLSTM +
    self-attention encoder.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Four parallel conv branches at different receptive fields; each emits
        # hidden_dim//4 channels so their concatenation is hidden_dim wide.
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(input_dim, hidden_dim // 4, kernel_size=k, padding=k // 2)
            for k in (3, 7, 15, 31)
        )
        # Bidirectional LSTM: hidden_dim//2 per direction -> hidden_dim total.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        def _head() -> nn.Sequential:
            # Two-layer MLP emitting 2 logits (background / positive) per position.
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 2),
            )

        self.gene_classifier = _head()
        self.start_classifier = _head()
        self.end_classifier = _head()

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Encode `x` (batch, seq, input_dim) and return per-head logits.

        If `lengths` is given, the LSTM runs on a packed sequence so padded
        positions do not pollute the recurrent state.
        """
        # Conv1d wants (batch, channels, seq); transpose in, concat, transpose back.
        channels_first = x.transpose(1, 2)
        multi_scale = [F.relu(conv(channels_first)) for conv in self.conv_layers]
        fused = torch.cat(multi_scale, dim=1).transpose(1, 2)
        if lengths is None:
            lstm_out, _ = self.lstm(fused)
        else:
            packed = nn.utils.rnn.pack_padded_sequence(
                fused, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            packed_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        normed = self.norm(lstm_out)
        # Self-attention over positions to sharpen boundary evidence.
        attended, _ = self.boundary_attention(normed, normed, normed)
        attended = self.dropout(attended)
        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended),
        }
| # ============================= DATA PREPROCESSING ============================= | |
class DNAProcessor:
    """DNA sequence processor with boundary-aware features.

    Produces a 14-channel per-base feature tensor:
    one-hot (5) + start-codon indicators (3) + stop-codon indicators (3)
    + windowed GC content (1) + sinusoidal position encoding (2).
    """

    def __init__(self):
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map bases to integer indices; unknown characters become 'N'."""
        sequence = sequence.upper()
        encoded = [self.base_to_idx.get(base, self.base_to_idx['N']) for base in sequence]
        return torch.tensor(encoded, dtype=torch.long)

    def _codon_indicators(self, sequence: str,
                          codon_channels: Dict[str, Tuple[int, float]]) -> torch.Tensor:
        """Return a (seq_len, 3) tensor marking codon occurrences.

        `codon_channels` maps a codon to (channel, weight); the weight is
        written over all three bases the codon spans.
        """
        seq_len = len(sequence)
        indicators = torch.zeros(seq_len, 3)
        for i in range(seq_len - 2):
            info = codon_channels.get(sequence[i:i + 3])
            if info is not None:
                channel, weight = info
                indicators[i:i + 3, channel] = weight
        return indicators

    def _gc_content(self, sequence: str, window_size: int = 50) -> torch.Tensor:
        """Per-position GC fraction over a centered window.

        Uses prefix sums so the whole pass is O(n) instead of the naive
        O(n * window_size) count-per-position.
        """
        seq_len = len(sequence)
        gc_content = torch.zeros(seq_len, 1)
        # prefix[i] = number of G/C bases in sequence[:i]
        prefix = [0] * (seq_len + 1)
        for i, base in enumerate(sequence):
            prefix[i + 1] = prefix[i] + (1 if base in ('G', 'C') else 0)
        half = window_size // 2
        for i in range(seq_len):
            start = max(0, i - half)
            end = min(seq_len, i + half)
            width = end - start
            gc_content[i, 0] = (prefix[end] - prefix[start]) / width if width > 0 else 0
        return gc_content

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the (seq_len, 14) feature tensor for `sequence`."""
        sequence = sequence.upper()
        seq_len = len(sequence)
        encoded = self.encode_sequence(sequence)
        # One-hot encoding over A/C/G/T/N.
        one_hot = torch.zeros(seq_len, 5)
        one_hot.scatter_(1, encoded.unsqueeze(1), 1)
        features = [one_hot]
        # Start codon indicators; GTG/TTG carry lower weights than ATG
        # (increased from 0.7 / 0.5 in an earlier revision).
        features.append(self._codon_indicators(
            sequence, {'ATG': (0, 1.0), 'GTG': (1, 0.9), 'TTG': (2, 0.8)}))
        # Stop codon indicators, one channel per codon.
        features.append(self._codon_indicators(
            sequence, {'TAA': (0, 1.0), 'TAG': (1, 1.0), 'TGA': (2, 1.0)}))
        # Windowed GC content.
        features.append(self._gc_content(sequence))
        # Very low-frequency sinusoidal position encoding.
        pos_encoding = torch.zeros(seq_len, 2)
        positions = torch.arange(seq_len, dtype=torch.float)
        pos_encoding[:, 0] = torch.sin(positions / 10000)
        pos_encoding[:, 1] = torch.cos(positions / 10000)
        features.append(pos_encoding)
        return torch.cat(features, dim=1)  # 5 + 3 + 3 + 1 + 2 = 14
| # ============================= POST-PROCESSING ============================= | |
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Converts per-position gene/start/end probabilities into a binary
    per-position gene mask, snapping predicted gene boundaries to nearby
    start/stop codons when the DNA sequence is available, then enforcing
    length and (when possible) reading-frame constraints.
    """

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Length bounds (in bases) applied to every candidate gene region.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        The three *_probs arrays are (seq_len, 2) softmax outputs; column 1
        is the positive-class probability.  Returns a 0/1 int array of
        per-position gene calls.
        """
        # More conservative thresholds: gene calls need >0.6 confidence;
        # start boundaries use a lower 0.4 cut (recall-biased), ends 0.5.
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)
        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)
        processed = self._apply_constraints(processed, sequence)
        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Rebuild each predicted gene region with codon-anchored boundaries.

        For every contiguous run of 1s in gene_pred, the start is snapped to
        the best-scoring nearby start codon and the end to the best-scoring
        downstream stop codon; regions failing the length bounds are dropped.
        """
        refined = gene_pred.copy()  # NOTE: overwritten by the zeros array below
        sequence = sequence.upper()
        start_codon_positions = []
        stop_codon_positions = []
        # Collect every codon position once; stop positions are stored as the
        # index just past the codon (i + 3), i.e. an exclusive region end.
        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)
        # 0->1 / 1->0 transitions delimit half-open [start, end) runs of 1s.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]
        # Start from an empty mask; only regions that pass all checks are set.
        refined = np.zeros_like(gene_pred)
        for g_start, g_end in zip(gene_starts, gene_ends):
            best_start = g_start
            start_window = 100  # Increased from 50
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]
            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        # ATG is the canonical start; GTG/TTG are discounted.
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        # start_pred is binary, so this contributes 0 or 0.4.
                        boundary_score = start_pred[pos]
                        distance_penalty = abs(pos - g_start) / start_window * 0.2  # Add distance penalty
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))
                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]
            best_end = g_end
            end_window = 100
            # Candidate stops: strictly after the raw start, allowed to extend
            # up to end_window bases past the raw end.
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]
            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Reward stops that keep the ORF in frame with best_start.
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            # Mild bias toward ~1000 bp genes.
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))
                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]
            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                    gene_length <= self.max_gene_length and
                    best_start < best_end):
                refined[best_start:best_end] = 1
        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Sequence-free fallback: snap region edges to nearby boundary hits.

        Uses only the binary start/end boundary predictions, searching +/-30
        positions around each region start and +/-50 around each region end.
        """
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]
        for g_start, g_end in zip(gene_starts, gene_ends):
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            # Indices here are relative to the window slice ...
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                # ... so add the window offset to make them absolute.
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # Clear positions before the snapped start; when new_start <=
                # g_start the slice is empty and this line is a no-op.
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                # Keep only ends that leave the region within the length bounds.
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    # Clear the tail past the snapped end; no-op when new_end >= g_end.
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]
        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Final filter: enforce length bounds and, if `sequence` was
        provided, trim each region to a multiple of 3 (reading frame).

        Note `sequence` is only used as a flag here; its contents are not read.
        """
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]
        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                # Too short or too long: drop the whole region.
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Trim the tail down to a whole number of codons.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0
        return processed
| # ============================= PREDICTION ============================= | |
class GenePredictor:
    """Handles gene prediction using the trained boundary-aware model."""

    def __init__(self, model_path: str = 'best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load trained weights onto `device` and build the pre/post processors.

        Re-raises (after logging) whatever torch.load / load_state_dict
        raises if the checkpoint is missing or incompatible.
        """
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (weights_only=True is safer on
            # torch versions that support it).
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()

    def predict(self, sequence: str) -> Tuple[np.ndarray, Dict[str, np.ndarray], float]:
        """Predict per-position gene calls for a DNA sequence.

        Returns (binary predictions, per-head softmax probability arrays,
        mean gene probability over the predicted-gene positions).
        """
        sequence = sequence.upper()
        if not sequence:
            # Empty input: the model cannot run on a zero-length sequence.
            empty = np.zeros((0, 2))
            return np.zeros(0, dtype=int), {'gene': empty, 'start': empty, 'end': empty}, 0.0
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning("Sequence contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)
        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(features)
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]
        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        return predictions, {'gene': gene_probs, 'start': start_probs, 'end': end_probs}, confidence

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Convert a 0/1 prediction mask into a list of gene-region dicts.

        Each dict carries half-open [start, end) coordinates, the region
        sequence, its start/stop codons (if recognizable), and a frame flag.
        """
        regions = []
        # 0->1 / 1->0 transitions delimit contiguous predicted regions.
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]
        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None
            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon
            if len(gene_seq) >= 6:
                # BUGFIX: scan backwards over in-frame codons only (offsets
                # that are multiples of 3).  The previous scan started at
                # len-2, which is never the start of a complete in-frame
                # codon, so the genuine terminal stop codon was always missed.
                last_codon_start = 3 * ((len(gene_seq) - 3) // 3)
                for i in range(last_codon_start, 2, -3):
                    codon = gene_seq[i:i + 3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break
            regions.append({
                'start': int(start),  # Convert to Python int for JSON serialization
                'end': int(end),
                'sequence': gene_seq,  # Return full sequence
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                # bool() so JSON serialization gets a plain bool, not numpy.bool_.
                'in_frame': bool((end - start) % 3 == 0)
            })
        return regions

    def compute_accuracy(self, predictions: np.ndarray, labels: List[int]) -> Dict:
        """Position-wise accuracy / precision / recall / F1 against 0/1 labels.

        The two arrays are truncated to their common length before comparing.
        """
        min_len = min(len(predictions), len(labels))
        predictions = predictions[:min_len]
        labels = np.array(labels[:min_len])
        accuracy = np.mean(predictions == labels)
        true_pos = np.sum((predictions == 1) & (labels == 1))
        false_neg = np.sum((predictions == 0) & (labels == 1))
        false_pos = np.sum((predictions == 1) & (labels == 0))
        # Guard every ratio against a zero denominator.
        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': int(true_pos),
            'false_positives': int(false_pos),
            'false_negatives': int(false_neg)
        }

    def labels_from_coordinates(self, seq_len: int, start: int, end: int) -> List[int]:
        """Build a 0/1 label list of length `seq_len` with [start, end) set to 1.

        Coordinates are clamped into range, so out-of-bounds inputs are safe.
        """
        labels = [0] * seq_len
        start = max(0, min(start, seq_len - 1))
        end = max(start, min(end, seq_len))
        for i in range(start, end):
            labels[i] = 1
        return labels