Spaces:
No application file
No application file
Update predictor.py
Browse files- predictor.py +277 -63
predictor.py
CHANGED
|
@@ -1,24 +1,97 @@
|
|
| 1 |
-
# -*- coding: utf-8 -*-
|
| 2 |
-
"""predictor.ipynb
|
| 3 |
-
|
| 4 |
-
Automatically generated by Colab.
|
| 5 |
-
|
| 6 |
-
Original file is located at
|
| 7 |
-
https://colab.research.google.com/drive/1JURb-0j-R4LWK3oxeGrNxpJm3V6nnX02
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
import torch
|
| 11 |
import torch.nn as nn
|
| 12 |
import torch.nn.functional as F
|
| 13 |
import numpy as np
|
| 14 |
-
from typing import List, Tuple, Dict, Optional
|
| 15 |
import logging
|
| 16 |
import re
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Configure logging
|
| 19 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 20 |
|
| 21 |
-
# =============================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
class BoundaryAwareGenePredictor(nn.Module):
|
| 24 |
"""Multi-task model predicting genes, starts, and ends separately."""
|
|
@@ -80,8 +153,6 @@ class BoundaryAwareGenePredictor(nn.Module):
|
|
| 80 |
'end': self.end_classifier(attended)
|
| 81 |
}
|
| 82 |
|
| 83 |
-
# ============================= DATA PREPROCESSING =============================
|
| 84 |
-
|
| 85 |
class DNAProcessor:
|
| 86 |
"""DNA sequence processor with boundary-aware features."""
|
| 87 |
|
|
@@ -106,16 +177,16 @@ class DNAProcessor:
|
|
| 106 |
one_hot.scatter_(1, encoded.unsqueeze(1), 1)
|
| 107 |
features = [one_hot]
|
| 108 |
|
| 109 |
-
# Start codon indicators
|
| 110 |
start_indicators = torch.zeros(seq_len, 3)
|
| 111 |
for i in range(seq_len - 2):
|
| 112 |
codon = sequence[i:i+3]
|
| 113 |
if codon == 'ATG':
|
| 114 |
start_indicators[i:i+3, 0] = 1.0
|
| 115 |
elif codon == 'GTG':
|
| 116 |
-
start_indicators[i:i+3, 1] = 0.9
|
| 117 |
elif codon == 'TTG':
|
| 118 |
-
start_indicators[i:i+3, 2] = 0.8
|
| 119 |
features.append(start_indicators)
|
| 120 |
|
| 121 |
# Stop codon indicators
|
|
@@ -150,8 +221,6 @@ class DNAProcessor:
|
|
| 150 |
|
| 151 |
return torch.cat(features, dim=1) # 5 + 3 + 3 + 1 + 2 = 14
|
| 152 |
|
| 153 |
-
# ============================= POST-PROCESSING =============================
|
| 154 |
-
|
| 155 |
class EnhancedPostProcessor:
|
| 156 |
"""Enhanced post-processor with stricter boundary detection."""
|
| 157 |
|
|
@@ -164,8 +233,6 @@ class EnhancedPostProcessor:
|
|
| 164 |
def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
|
| 165 |
end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
|
| 166 |
"""Process predictions with enhanced boundary detection."""
|
| 167 |
-
|
| 168 |
-
# More conservative thresholds
|
| 169 |
gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
|
| 170 |
start_pred = (start_probs[:, 1] > 0.4).astype(int)
|
| 171 |
end_pred = (end_probs[:, 1] > 0.5).astype(int)
|
|
@@ -178,7 +245,6 @@ class EnhancedPostProcessor:
|
|
| 178 |
processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)
|
| 179 |
|
| 180 |
processed = self._apply_constraints(processed, sequence)
|
| 181 |
-
|
| 182 |
return processed
|
| 183 |
|
| 184 |
def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
|
|
@@ -205,7 +271,7 @@ class EnhancedPostProcessor:
|
|
| 205 |
|
| 206 |
for g_start, g_end in zip(gene_starts, gene_ends):
|
| 207 |
best_start = g_start
|
| 208 |
-
start_window = 100
|
| 209 |
nearby_starts = [pos for pos in start_codon_positions
|
| 210 |
if abs(pos - g_start) <= start_window]
|
| 211 |
|
|
@@ -216,7 +282,7 @@ class EnhancedPostProcessor:
|
|
| 216 |
codon = sequence[pos:pos+3]
|
| 217 |
codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
|
| 218 |
boundary_score = start_pred[pos]
|
| 219 |
-
distance_penalty = abs(pos - g_start) / start_window * 0.2
|
| 220 |
score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
|
| 221 |
start_scores.append((score, pos))
|
| 222 |
|
|
@@ -306,10 +372,10 @@ class EnhancedPostProcessor:
|
|
| 306 |
|
| 307 |
return processed
|
| 308 |
|
| 309 |
-
# =============================
|
| 310 |
|
| 311 |
-
class
|
| 312 |
-
"""
|
| 313 |
|
| 314 |
def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
|
| 315 |
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
|
@@ -324,13 +390,50 @@ class GenePredictor:
|
|
| 324 |
self.model.eval()
|
| 325 |
self.processor = DNAProcessor()
|
| 326 |
self.post_processor = EnhancedPostProcessor()
|
| 327 |
-
|
| 328 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 329 |
sequence = sequence.upper()
|
| 330 |
if not re.match('^[ACTGN]+$', sequence):
|
| 331 |
-
logging.warning("Sequence contains invalid characters. Using 'N' for unknowns.")
|
| 332 |
sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)
|
| 335 |
|
| 336 |
with torch.no_grad():
|
|
@@ -343,10 +446,95 @@ class GenePredictor:
|
|
| 343 |
gene_probs, start_probs, end_probs, sequence
|
| 344 |
)
|
| 345 |
confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
|
| 349 |
def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
|
|
|
|
| 350 |
regions = []
|
| 351 |
changes = np.diff(np.concatenate(([0], predictions, [0])))
|
| 352 |
starts = np.where(changes == 1)[0]
|
|
@@ -370,9 +558,9 @@ class GenePredictor:
|
|
| 370 |
break
|
| 371 |
|
| 372 |
regions.append({
|
| 373 |
-
'start': int(start),
|
| 374 |
'end': int(end),
|
| 375 |
-
'sequence': gene_seq,
|
| 376 |
'length': int(end - start),
|
| 377 |
'start_codon': actual_start_codon,
|
| 378 |
'stop_codon': actual_stop_codon,
|
|
@@ -381,34 +569,60 @@ class GenePredictor:
|
|
| 381 |
|
| 382 |
return regions
|
| 383 |
|
| 384 |
-
def
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import numpy as np
|
| 5 |
+
from typing import List, Tuple, Dict, Optional, Union
|
| 6 |
import logging
|
| 7 |
import re
|
| 8 |
+
import os
|
| 9 |
+
from pathlib import Path
|
| 10 |
|
| 11 |
# Configure logging
|
| 12 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 13 |
|
| 14 |
+
# ============================= FILE READERS =============================
|
| 15 |
+
|
| 16 |
+
class FileReader:
    """Handles reading DNA sequences from various file formats."""

    @staticmethod
    def read_fasta(file_path: str) -> Dict[str, str]:
        """
        Read a FASTA file and return a dict of sequence_id -> sequence.

        Header lines start with '>'; the text after '>' becomes the id.
        Sequence lines under a header are concatenated with internal
        spaces/tabs removed. Lines before the first header are ignored.

        Raises:
            OSError / UnicodeDecodeError: propagated (after logging) if the
                file cannot be opened or decoded.
        """
        sequences: Dict[str, str] = {}
        current_id = None
        current_seq: List[str] = []

        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line.startswith('>'):
                        # Save previous sequence if exists
                        if current_id is not None:
                            sequences[current_id] = ''.join(current_seq)
                        # Start new sequence
                        current_id = line[1:]  # Remove '>' character
                        current_seq = []
                    elif line and current_id is not None:
                        # Add sequence line (remove any whitespace)
                        current_seq.append(line.replace(' ', '').replace('\t', ''))

            # Don't forget the last sequence
            if current_id is not None:
                sequences[current_id] = ''.join(current_seq)

        except Exception as e:
            logging.error(f"Error reading FASTA file {file_path}: {e}")
            raise

        return sequences

    @staticmethod
    def read_txt(file_path: str) -> str:
        """
        Read a plain text file containing a DNA sequence.

        All characters other than A/C/T/G/N (case-insensitive) are
        discarded; the result is uppercased.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read().strip()
                # Remove any whitespace, newlines, and non-DNA characters
                sequence = ''.join(c.upper() for c in content if c.upper() in 'ACTGN')
                return sequence
        except Exception as e:
            logging.error(f"Error reading TXT file {file_path}: {e}")
            raise

    @staticmethod
    def detect_file_type(file_path: str) -> str:
        """
        Detect file type ('fasta' or 'txt') by extension, falling back to
        sniffing the first line for a '>' FASTA header.
        """
        file_path = Path(file_path)
        extension = file_path.suffix.lower()

        if extension in ['.fasta', '.fa', '.fas', '.fna']:
            return 'fasta'
        elif extension in ['.txt', '.seq']:
            return 'txt'
        else:
            # Try to detect by content
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    first_line = file.readline().strip()
                    if first_line.startswith('>'):
                        return 'fasta'
                    else:
                        return 'txt'
            # FIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit; only file-read failures should
            # trigger the txt fallback.
            except (OSError, UnicodeDecodeError):
                logging.warning(f"Could not detect file type for {file_path}, assuming txt")
                return 'txt'
|
| 92 |
+
|
| 93 |
+
# ============================= ORIGINAL MODEL COMPONENTS =============================
|
| 94 |
+
# (Including all the original classes: BoundaryAwareGenePredictor, DNAProcessor, EnhancedPostProcessor)
|
| 95 |
|
| 96 |
class BoundaryAwareGenePredictor(nn.Module):
|
| 97 |
"""Multi-task model predicting genes, starts, and ends separately."""
|
|
|
|
| 153 |
'end': self.end_classifier(attended)
|
| 154 |
}
|
| 155 |
|
|
|
|
|
|
|
| 156 |
class DNAProcessor:
|
| 157 |
"""DNA sequence processor with boundary-aware features."""
|
| 158 |
|
|
|
|
| 177 |
one_hot.scatter_(1, encoded.unsqueeze(1), 1)
|
| 178 |
features = [one_hot]
|
| 179 |
|
| 180 |
+
# Start codon indicators
|
| 181 |
start_indicators = torch.zeros(seq_len, 3)
|
| 182 |
for i in range(seq_len - 2):
|
| 183 |
codon = sequence[i:i+3]
|
| 184 |
if codon == 'ATG':
|
| 185 |
start_indicators[i:i+3, 0] = 1.0
|
| 186 |
elif codon == 'GTG':
|
| 187 |
+
start_indicators[i:i+3, 1] = 0.9
|
| 188 |
elif codon == 'TTG':
|
| 189 |
+
start_indicators[i:i+3, 2] = 0.8
|
| 190 |
features.append(start_indicators)
|
| 191 |
|
| 192 |
# Stop codon indicators
|
|
|
|
| 221 |
|
| 222 |
return torch.cat(features, dim=1) # 5 + 3 + 3 + 1 + 2 = 14
|
| 223 |
|
|
|
|
|
|
|
| 224 |
class EnhancedPostProcessor:
|
| 225 |
"""Enhanced post-processor with stricter boundary detection."""
|
| 226 |
|
|
|
|
| 233 |
def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
|
| 234 |
end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
|
| 235 |
"""Process predictions with enhanced boundary detection."""
|
|
|
|
|
|
|
| 236 |
gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
|
| 237 |
start_pred = (start_probs[:, 1] > 0.4).astype(int)
|
| 238 |
end_pred = (end_probs[:, 1] > 0.5).astype(int)
|
|
|
|
| 245 |
processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)
|
| 246 |
|
| 247 |
processed = self._apply_constraints(processed, sequence)
|
|
|
|
| 248 |
return processed
|
| 249 |
|
| 250 |
def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
|
|
|
|
| 271 |
|
| 272 |
for g_start, g_end in zip(gene_starts, gene_ends):
|
| 273 |
best_start = g_start
|
| 274 |
+
start_window = 100
|
| 275 |
nearby_starts = [pos for pos in start_codon_positions
|
| 276 |
if abs(pos - g_start) <= start_window]
|
| 277 |
|
|
|
|
| 282 |
codon = sequence[pos:pos+3]
|
| 283 |
codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
|
| 284 |
boundary_score = start_pred[pos]
|
| 285 |
+
distance_penalty = abs(pos - g_start) / start_window * 0.2
|
| 286 |
score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
|
| 287 |
start_scores.append((score, pos))
|
| 288 |
|
|
|
|
| 372 |
|
| 373 |
return processed
|
| 374 |
|
| 375 |
+
# ============================= ENHANCED GENE PREDICTOR =============================
|
| 376 |
|
| 377 |
+
class EnhancedGenePredictor:
|
| 378 |
+
"""Enhanced Gene Predictor with file input support."""
|
| 379 |
|
| 380 |
def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
|
| 381 |
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
|
|
|
|
| 390 |
self.model.eval()
|
| 391 |
self.processor = DNAProcessor()
|
| 392 |
self.post_processor = EnhancedPostProcessor()
|
| 393 |
+
self.file_reader = FileReader()
|
| 394 |
+
|
| 395 |
+
def predict_from_file(self, file_path: str) -> Dict[str, Dict]:
    """
    Predict genes for every sequence contained in *file_path*.

    Supports FASTA (one entry per header) and plain-text files (whole file
    treated as a single sequence named after the file). Returns a mapping
    of sequence_id -> prediction result dict.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    file_type = self.file_reader.detect_file_type(file_path)
    logging.info(f"Detected file type: {file_type}")

    if file_type == 'fasta':
        results = {}
        for record_id, record_seq in self.file_reader.read_fasta(file_path).items():
            logging.info(f"Processing sequence: {record_id} (length: {len(record_seq)})")
            results[record_id] = self.predict_sequence(record_seq, record_id)
        return results

    # Plain-text input: the whole file is one sequence, named after the file.
    dna = self.file_reader.read_txt(file_path)
    record_id = Path(file_path).stem  # Use filename as sequence ID
    logging.info(f"Processing sequence from {file_path} (length: {len(dna)})")
    return {record_id: self.predict_sequence(dna, record_id)}
|
| 422 |
+
|
| 423 |
+
def predict_sequence(self, sequence: str, seq_id: str = "sequence") -> Dict:
|
| 424 |
+
"""
|
| 425 |
+
Predict genes from a single DNA sequence string
|
| 426 |
+
"""
|
| 427 |
sequence = sequence.upper()
|
| 428 |
if not re.match('^[ACTGN]+$', sequence):
|
| 429 |
+
logging.warning(f"Sequence {seq_id} contains invalid characters. Using 'N' for unknowns.")
|
| 430 |
sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)
|
| 431 |
|
| 432 |
+
# Handle very long sequences by chunking if needed
|
| 433 |
+
max_chunk_size = 50000 # Adjust based on your GPU memory
|
| 434 |
+
if len(sequence) > max_chunk_size:
|
| 435 |
+
return self._predict_long_sequence(sequence, seq_id, max_chunk_size)
|
| 436 |
+
|
| 437 |
features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)
|
| 438 |
|
| 439 |
with torch.no_grad():
|
|
|
|
| 446 |
gene_probs, start_probs, end_probs, sequence
|
| 447 |
)
|
| 448 |
confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
|
| 449 |
+
|
| 450 |
+
gene_regions = self.extract_gene_regions(predictions, sequence)
|
| 451 |
+
|
| 452 |
+
return {
|
| 453 |
+
'sequence_id': seq_id,
|
| 454 |
+
'sequence_length': len(sequence),
|
| 455 |
+
'predictions': predictions.tolist(),
|
| 456 |
+
'probabilities': {
|
| 457 |
+
'gene': gene_probs.tolist(),
|
| 458 |
+
'start': start_probs.tolist(),
|
| 459 |
+
'end': end_probs.tolist()
|
| 460 |
+
},
|
| 461 |
+
'confidence': float(confidence),
|
| 462 |
+
'gene_regions': gene_regions,
|
| 463 |
+
'total_genes_found': len(gene_regions)
|
| 464 |
+
}
|
| 465 |
|
| 466 |
+
def _predict_long_sequence(self, sequence: str, seq_id: str, chunk_size: int) -> Dict:
    """
    Run prediction on a sequence too long for a single forward pass.

    The sequence is scanned in windows of ``chunk_size`` characters whose
    starts are ``chunk_size - 1000`` apart (a 1000-position overlap); for
    every window after the first, the overlapped prefix is dropped before
    stitching, so each position contributes exactly once.
    """
    overlap = 1000  # Overlap between chunks to avoid missing genes at boundaries
    stitched_preds = []
    stitched_gene = []
    stitched_start = []
    stitched_end = []

    for offset in range(0, len(sequence), chunk_size - overlap):
        window = sequence[offset:min(offset + chunk_size, len(sequence))]

        logging.info(f"Processing chunk {offset//chunk_size + 1} of sequence {seq_id}")

        feats = self.processor.create_enhanced_features(window).unsqueeze(0).to(self.device)

        with torch.no_grad():
            raw = self.model(feats)
            win_gene = F.softmax(raw['gene'], dim=-1).cpu().numpy()[0]
            win_start = F.softmax(raw['start'], dim=-1).cpu().numpy()[0]
            win_end = F.softmax(raw['end'], dim=-1).cpu().numpy()[0]

        win_preds = self.post_processor.process_predictions(
            win_gene, win_start, win_end, window
        )

        # First window is kept whole; later windows drop the overlapped prefix.
        drop = 0 if offset == 0 else min(overlap, len(win_preds))
        stitched_preds.extend(win_preds[drop:])
        stitched_gene.extend(win_gene[drop:])
        stitched_start.extend(win_start[drop:])
        stitched_end.extend(win_end[drop:])

    predictions = np.array(stitched_preds)
    gene_probs = np.array(stitched_gene)
    start_probs = np.array(stitched_start)
    end_probs = np.array(stitched_end)

    # Mean gene-class probability over predicted-gene positions; 0.0 when none.
    confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
    gene_regions = self.extract_gene_regions(predictions, sequence)

    return {
        'sequence_id': seq_id,
        'sequence_length': len(sequence),
        'predictions': predictions.tolist(),
        'probabilities': {
            'gene': gene_probs.tolist(),
            'start': start_probs.tolist(),
            'end': end_probs.tolist()
        },
        'confidence': float(confidence),
        'gene_regions': gene_regions,
        'total_genes_found': len(gene_regions)
    }
|
| 529 |
+
|
| 530 |
+
def predict_from_text(self, sequence: str) -> Dict:
    """Backward-compatible alias: predict genes directly from a string."""
    return self.predict_sequence(sequence)
|
| 535 |
|
| 536 |
def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
|
| 537 |
+
"""Extract gene regions from predictions"""
|
| 538 |
regions = []
|
| 539 |
changes = np.diff(np.concatenate(([0], predictions, [0])))
|
| 540 |
starts = np.where(changes == 1)[0]
|
|
|
|
| 558 |
break
|
| 559 |
|
| 560 |
regions.append({
|
| 561 |
+
'start': int(start),
|
| 562 |
'end': int(end),
|
| 563 |
+
'sequence': gene_seq,
|
| 564 |
'length': int(end - start),
|
| 565 |
'start_codon': actual_start_codon,
|
| 566 |
'stop_codon': actual_stop_codon,
|
|
|
|
| 569 |
|
| 570 |
return regions
|
| 571 |
|
| 572 |
+
def save_results(self, results: Dict[str, Dict], output_path: str, format: str = 'json'):
    """
    Save prediction results to file.

    Args:
        results: Mapping of sequence_id -> prediction dict, as produced by
            predict_sequence / predict_from_file.
        output_path: Destination file path.
        format: 'json' (full result dicts) or 'csv' (one row per gene region).

    Raises:
        ValueError: if *format* is neither 'json' nor 'csv'. (FIX: previously
            an unknown format silently wrote nothing while still logging
            "Results saved".)
    """
    import json

    if format.lower() == 'json':
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)
    elif format.lower() == 'csv':
        import csv
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['sequence_id', 'gene_start', 'gene_end', 'gene_length',
                             'start_codon', 'stop_codon', 'in_frame', 'confidence'])

            for seq_id, result in results.items():
                for gene in result['gene_regions']:
                    writer.writerow([
                        seq_id, gene['start'], gene['end'], gene['length'],
                        gene['start_codon'], gene['stop_codon'], gene['in_frame'],
                        result['confidence']
                    ])
    else:
        raise ValueError(f"Unsupported output format: {format}")

    logging.info(f"Results saved to {output_path}")
|
| 597 |
+
|
| 598 |
+
# ============================= USAGE EXAMPLE =============================
|
| 599 |
+
|
| 600 |
+
def main():
    """Example usage of the enhanced gene predictor."""
    predictor = EnhancedGenePredictor(model_path='model/best_boundary_aware_model.pth')

    # FASTA file input -> JSON output
    try:
        fasta_results = predictor.predict_from_file('example.fasta')
        predictor.save_results(fasta_results, 'fasta_predictions.json')
        print("FASTA predictions saved to fasta_predictions.json")
    except FileNotFoundError:
        print("example.fasta not found, skipping FASTA example")

    # Plain-text file input -> CSV output
    try:
        txt_results = predictor.predict_from_file('example.txt')
        predictor.save_results(txt_results, 'txt_predictions.csv', format='csv')
        print("TXT predictions saved to txt_predictions.csv")
    except FileNotFoundError:
        print("example.txt not found, skipping TXT example")

    # Direct string input (original functionality)
    example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGA"
    text_results = predictor.predict_from_text(example_sequence)
    print(f"Found {text_results['total_genes_found']} genes in example sequence")

if __name__ == "__main__":
    main()
|