import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Dict, Optional, Union
import logging
import re
import os
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


# ============================= FILE READERS =============================

class FileReader:
    """Handles reading DNA sequences from various file formats."""

    @staticmethod
    def read_fasta(file_path: str) -> Dict[str, str]:
        """Read a FASTA file and return a dict of {sequence_id: sequence}.

        Header lines start with '>'; the id is the full header text after '>'.
        Sequence lines are concatenated with embedded whitespace removed.
        Raises whatever I/O error occurred after logging it.
        """
        sequences: Dict[str, str] = {}
        current_id: Optional[str] = None
        current_seq: List[str] = []
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line.startswith('>'):
                        # Save previous sequence if exists
                        if current_id is not None:
                            sequences[current_id] = ''.join(current_seq)
                        # Start new sequence
                        current_id = line[1:]  # Remove '>' character
                        current_seq = []
                    elif line and current_id is not None:
                        # Add sequence line (remove any whitespace)
                        current_seq.append(line.replace(' ', '').replace('\t', ''))
                # Don't forget the last sequence
                if current_id is not None:
                    sequences[current_id] = ''.join(current_seq)
        except Exception as e:
            logging.error(f"Error reading FASTA file {file_path}: {e}")
            raise
        return sequences

    @staticmethod
    def read_txt(file_path: str) -> str:
        """Read a plain-text file containing a DNA sequence.

        Every character that is not A/C/T/G/N (case-insensitive) is dropped;
        the result is returned upper-cased.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read().strip()
            # Remove any whitespace, newlines, and non-DNA characters
            sequence = ''.join(c.upper() for c in content if c.upper() in 'ACTGN')
            return sequence
        except Exception as e:
            logging.error(f"Error reading TXT file {file_path}: {e}")
            raise

    @staticmethod
    def detect_file_type(file_path: str) -> str:
        """Return 'fasta' or 'txt' based on extension, falling back to content.

        Unknown extensions are sniffed: a first line starting with '>' means
        FASTA. Unreadable files default to 'txt' with a warning.
        """
        file_path = Path(file_path)
        extension = file_path.suffix.lower()
        if extension in ['.fasta', '.fa', '.fas', '.fna']:
            return 'fasta'
        elif extension in ['.txt', '.seq']:
            return 'txt'
        else:
            # Try to detect by content
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    first_line = file.readline().strip()
                return 'fasta' if first_line.startswith('>') else 'txt'
            except (OSError, UnicodeDecodeError):  # narrowed from bare except
                logging.warning(f"Could not detect file type for {file_path}, assuming txt")
                return 'txt'


# ============================= ORIGINAL MODEL COMPONENTS =============================
# (Including all the original classes: BoundaryAwareGenePredictor, DNAProcessor,
#  EnhancedPostProcessor)

class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    Pipeline: multi-scale 1-D convolutions -> BiLSTM -> LayerNorm ->
    self-attention -> three parallel 2-class heads (per-position logits).
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Four kernel widths capture motifs at different scales; each branch
        # emits hidden_dim//4 channels so the concatenation is hidden_dim.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim // 4, kernel_size=k, padding=k // 2)
            for k in [3, 7, 15, 31]
        ])
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8,
                                                        batch_first=True)
        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(hidden_dim // 2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(hidden_dim // 2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
            nn.Dropout(dropout), nn.Linear(hidden_dim // 2, 2)
        )

    def forward(self, x: torch.Tensor,
                lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the model.

        Args:
            x: (batch, seq_len, input_dim) feature tensor.
            lengths: optional per-sequence lengths for packed LSTM processing.

        Returns:
            Dict with 'gene', 'start', 'end' logits, each (batch, seq_len, 2).
        """
        # Conv1d expects (batch, channels, seq_len)
        x_conv = x.transpose(1, 2)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)

        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended),
        }


class DNAProcessor:
    """DNA sequence processor with boundary-aware features."""

    def __init__(self):
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map bases to integer indices; unknown characters become 'N'."""
        sequence = sequence.upper()
        encoded = [self.base_to_idx.get(base, self.base_to_idx['N'])
                   for base in sequence]
        return torch.tensor(encoded, dtype=torch.long)

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the (seq_len, 14) feature tensor consumed by the model.

        Channels: 5 one-hot + 3 start-codon + 3 stop-codon + 1 GC-content
        + 2 sinusoidal position = 14.
        """
        sequence = sequence.upper()
        seq_len = len(sequence)
        encoded = self.encode_sequence(sequence)

        # One-hot encoding
        one_hot = torch.zeros(seq_len, 5)
        one_hot.scatter_(1, encoded.unsqueeze(1), 1)
        features = [one_hot]

        # Start codon indicators (weights reflect relative start-codon usage)
        start_indicators = torch.zeros(seq_len, 3)
        for i in range(seq_len - 2):
            codon = sequence[i:i + 3]
            if codon == 'ATG':
                start_indicators[i:i + 3, 0] = 1.0
            elif codon == 'GTG':
                start_indicators[i:i + 3, 1] = 0.9
            elif codon == 'TTG':
                start_indicators[i:i + 3, 2] = 0.8
        features.append(start_indicators)

        # Stop codon indicators
        stop_indicators = torch.zeros(seq_len, 3)
        for i in range(seq_len - 2):
            codon = sequence[i:i + 3]
            if codon == 'TAA':
                stop_indicators[i:i + 3, 0] = 1.0
            elif codon == 'TAG':
                stop_indicators[i:i + 3, 1] = 1.0
            elif codon == 'TGA':
                stop_indicators[i:i + 3, 2] = 1.0
        features.append(stop_indicators)

        # GC content over a sliding window centered on each position
        gc_content = torch.zeros(seq_len, 1)
        window_size = 50
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2)
            window = sequence[start:end]
            gc_count = window.count('G') + window.count('C')
            gc_content[i, 0] = gc_count / len(window) if len(window) > 0 else 0
        features.append(gc_content)

        # Position encoding
        pos_encoding = torch.zeros(seq_len, 2)
        positions = torch.arange(seq_len, dtype=torch.float)
        pos_encoding[:, 0] = torch.sin(positions / 10000)
        pos_encoding[:, 1] = torch.cos(positions / 10000)
        features.append(pos_encoding)

        return torch.cat(features, dim=1)  # 5 + 3 + 3 + 1 + 2 = 14


class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection."""

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
        """Threshold the three probability tracks and refine gene boundaries.

        Returns a binary per-position array (1 = inside a predicted gene).
        """
        # Asymmetric thresholds: stricter on 'gene', looser on 'start'
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)
        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray,
                                           end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap predicted gene boundaries to nearby start/stop codons.

        Each candidate boundary is scored from codon identity, the boundary
        head's prediction, and a distance penalty; the best-scoring pair that
        satisfies the length constraints defines the refined gene.
        """
        sequence = sequence.upper()

        start_codon_positions = []
        stop_codon_positions = []
        for i in range(len(sequence) - 2):
            codon = sequence[i:i + 3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)  # position just past the stop

        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        refined = np.zeros_like(gene_pred)
        for g_start, g_end in zip(gene_starts, gene_ends):
            # --- choose the best start ---
            best_start = g_start
            start_window = 100
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]
            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos + 3]
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        distance_penalty = abs(pos - g_start) / start_window * 0.2
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))
                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            # --- choose the best end ---
            best_end = g_end
            end_window = 100
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]
            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Bonus for ends that keep the ORF in frame
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))
                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length
                    and gene_length <= self.max_gene_length
                    and best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Boundary refinement without sequence context (no codon info)."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            start_window = slice(max(0, g_start - 30), min(len(start_pred), g_start + 30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                relative_positions = start_candidates + max(0, g_start - 30)
                distances = np.abs(relative_positions - g_start)
                new_start = relative_positions[np.argmin(distances)]
                if new_start > g_start:
                    # Shrinking from the left: clear the abandoned prefix
                    refined[g_start:new_start] = 0
                refined[new_start:g_end] = 1
                g_start = new_start

            end_window = slice(max(0, g_end - 50), min(len(end_pred), g_end + 50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end - 50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    if new_end < g_end:
                        # Shrinking from the right: clear the abandoned suffix
                        refined[new_end:g_end] = 0

        return refined

    def _apply_constraints(self, predictions: np.ndarray,
                           sequence: str = None) -> np.ndarray:
        """Drop genes outside the length bounds; trim to a multiple of 3."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0

        return processed


# ============================= ENHANCED GENE PREDICTOR =============================

class EnhancedGenePredictor:
    """Enhanced Gene Predictor with file input support."""

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            # SECURITY NOTE: torch.load uses pickle; only load checkpoints
            # from trusted sources (consider weights_only=True on torch>=2.0).
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()
        self.file_reader = FileReader()

    def predict_from_file(self, file_path: str) -> Dict[str, Dict]:
        """
        Predict genes from a file (.txt or .fasta)
        Returns a dictionary with sequence_id as keys and prediction results as values
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        file_type = self.file_reader.detect_file_type(file_path)
        logging.info(f"Detected file type: {file_type}")

        results = {}
        if file_type == 'fasta':
            sequences = self.file_reader.read_fasta(file_path)
            for seq_id, sequence in sequences.items():
                logging.info(f"Processing sequence: {seq_id} (length: {len(sequence)})")
                result = self.predict_sequence(sequence, seq_id)
                results[seq_id] = result
        else:  # txt file
            sequence = self.file_reader.read_txt(file_path)
            seq_id = Path(file_path).stem  # Use filename as sequence ID
            logging.info(f"Processing sequence from {file_path} (length: {len(sequence)})")
            result = self.predict_sequence(sequence, seq_id)
            results[seq_id] = result

        return results

    def predict_sequence(self, sequence: str, seq_id: str = "sequence") -> Dict:
        """
        Predict genes from a single DNA sequence string
        """
        sequence = sequence.upper()
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning(f"Sequence {seq_id} contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        # Handle very long sequences by chunking if needed
        max_chunk_size = 50000  # Adjust based on your GPU memory
        if len(sequence) > max_chunk_size:
            return self._predict_long_sequence(sequence, seq_id, max_chunk_size)

        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(features)
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )

        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def _predict_long_sequence(self, sequence: str, seq_id: str,
                               chunk_size: int) -> Dict:
        """
        Handle very long sequences by processing in chunks with overlap
        """
        overlap = 1000  # Overlap between chunks to avoid missing genes at boundaries
        all_predictions = []
        all_gene_probs = []
        all_start_probs = []
        all_end_probs = []

        for i in range(0, len(sequence), chunk_size - overlap):
            end_pos = min(i + chunk_size, len(sequence))
            chunk = sequence[i:end_pos]
            logging.info(f"Processing chunk {i//chunk_size + 1} of sequence {seq_id}")

            features = self.processor.create_enhanced_features(chunk).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(features)
                gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
                start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
                end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

            chunk_predictions = self.post_processor.process_predictions(
                gene_probs, start_probs, end_probs, chunk
            )

            # Handle overlaps by taking the first chunk fully and subsequent
            # chunks without the region already covered by the previous chunk
            if i == 0:
                all_predictions.extend(chunk_predictions)
                all_gene_probs.extend(gene_probs)
                all_start_probs.extend(start_probs)
                all_end_probs.extend(end_probs)
            else:
                # Skip overlap region
                skip = min(overlap, len(chunk_predictions))
                all_predictions.extend(chunk_predictions[skip:])
                all_gene_probs.extend(gene_probs[skip:])
                all_start_probs.extend(start_probs[skip:])
                all_end_probs.extend(end_probs[skip:])

        predictions = np.array(all_predictions)
        gene_probs = np.array(all_gene_probs)
        start_probs = np.array(all_start_probs)
        end_probs = np.array(all_end_probs)

        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def predict_from_text(self, sequence: str) -> Dict:
        """
        Predict genes from a text string (backward compatibility)
        """
        return self.predict_sequence(sequence)

    def extract_gene_regions(self, predictions: np.ndarray,
                             sequence: str) -> List[Dict]:
        """Extract gene regions from predictions"""
        regions = []
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # BUGFIX: scan in-frame codon positions from the tail.
                # The previous range(len-2, 2, -3) only visited offsets that
                # are out of frame, so the true terminal stop codon at
                # gene_seq[-3:] was never examined.
                for i in range(len(gene_seq) - 3, -1, -3):
                    codon = gene_seq[i:i + 3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break

            length = int(end - start)  # plain int -> 'in_frame' is JSON-safe bool
            regions.append({
                'start': int(start),
                'end': int(end),
                'sequence': gene_seq,
                'length': length,
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                'in_frame': length % 3 == 0
            })

        return regions

    def save_results(self, results: Dict[str, Dict], output_path: str,
                     format: str = 'json'):
        """
        Save prediction results to file
        """
        import json

        if format.lower() == 'json':
            with open(output_path, 'w') as f:
                json.dump(results, f, indent=2)
        elif format.lower() == 'csv':
            import csv
            with open(output_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['sequence_id', 'gene_start', 'gene_end',
                                 'gene_length', 'start_codon', 'stop_codon',
                                 'in_frame', 'confidence'])
                for seq_id, result in results.items():
                    for gene in result['gene_regions']:
                        writer.writerow([
                            seq_id, gene['start'], gene['end'], gene['length'],
                            gene['start_codon'], gene['stop_codon'],
                            gene['in_frame'], result['confidence']
                        ])

        logging.info(f"Results saved to {output_path}")


# ============================= USAGE EXAMPLE =============================

def main():
    """Example usage of the enhanced gene predictor"""
    # Initialize predictor
    predictor = EnhancedGenePredictor(model_path='model/best_boundary_aware_model.pth')

    # Example 1: Predict from FASTA file
    try:
        fasta_results = predictor.predict_from_file('example.fasta')
        predictor.save_results(fasta_results, 'fasta_predictions.json')
        print("FASTA predictions saved to fasta_predictions.json")
    except FileNotFoundError:
        print("example.fasta not found, skipping FASTA example")

    # Example 2: Predict from TXT file
    try:
        txt_results = predictor.predict_from_file('example.txt')
        predictor.save_results(txt_results, 'txt_predictions.csv', format='csv')
        print("TXT predictions saved to txt_predictions.csv")
    except FileNotFoundError:
        print("example.txt not found, skipping TXT example")

    # Example 3: Predict from text string (original functionality)
    example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGA"
    text_results = predictor.predict_from_text(example_sequence)
    print(f"Found {text_results['total_genes_found']} genes in example sequence")


if __name__ == "__main__":
    main()