Spaces:
No application file
No application file
Delete predictor.py
Browse files- predictor.py +0 -628
predictor.py
DELETED
|
@@ -1,628 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn as nn
|
| 3 |
-
import torch.nn.functional as F
|
| 4 |
-
import numpy as np
|
| 5 |
-
from typing import List, Tuple, Dict, Optional, Union
|
| 6 |
-
import logging
|
| 7 |
-
import re
|
| 8 |
-
import os
|
| 9 |
-
from pathlib import Path
|
| 10 |
-
|
| 11 |
-
# Configure logging
|
| 12 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 13 |
-
|
| 14 |
-
# ============================= FILE READERS =============================
|
| 15 |
-
|
| 16 |
-
class FileReader:
    """Reads DNA sequences from FASTA or plain-text files.

    All methods are static; the class is purely a namespace for
    file-format helpers.
    """

    # Characters accepted as DNA bases (uppercase) when sanitising text.
    _DNA_BASES = frozenset('ACTGN')

    @staticmethod
    def read_fasta(file_path: str) -> Dict[str, str]:
        """Read a FASTA file and return ``{sequence_id: sequence}``.

        Whitespace embedded in sequence lines is removed. Lines appearing
        before the first ``>`` header are ignored. If the same id occurs
        twice, the later record overwrites the earlier one.

        Raises:
            OSError / UnicodeDecodeError: propagated after being logged.
        """
        sequences: Dict[str, str] = {}
        current_id: Optional[str] = None
        current_seq: List[str] = []

        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line.startswith('>'):
                        # Flush the record collected so far before starting a new one.
                        if current_id is not None:
                            sequences[current_id] = ''.join(current_seq)
                        current_id = line[1:]  # drop the '>' marker
                        current_seq = []
                    elif line and current_id is not None:
                        # Sequence line: drop embedded blanks and tabs.
                        current_seq.append(line.replace(' ', '').replace('\t', ''))
            # Flush the final record.
            if current_id is not None:
                sequences[current_id] = ''.join(current_seq)
        except (OSError, UnicodeDecodeError) as e:
            # Narrowed from `except Exception`: only I/O and decoding errors
            # are expected here; anything else propagates un-logged.
            logging.error(f"Error reading FASTA file {file_path}: {e}")
            raise

        return sequences

    @staticmethod
    def read_txt(file_path: str) -> str:
        """Read a plain-text file and return the sanitised DNA sequence.

        The content is uppercased once and every character outside
        ``ACTGN`` (whitespace, digits, ambiguity codes, ...) is discarded.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            logging.error(f"Error reading TXT file {file_path}: {e}")
            raise
        # Uppercase once (the original called .upper() twice per character),
        # then keep only recognised bases.
        return ''.join(c for c in content.upper() if c in FileReader._DNA_BASES)

    @staticmethod
    def detect_file_type(file_path: str) -> str:
        """Return ``'fasta'`` or ``'txt'`` for *file_path*.

        Decision is by extension first; unknown extensions fall back to
        sniffing the first line (FASTA headers start with ``>``). Unreadable
        files are assumed to be ``'txt'`` with a warning.
        """
        extension = Path(file_path).suffix.lower()
        if extension in {'.fasta', '.fa', '.fas', '.fna'}:
            return 'fasta'
        if extension in {'.txt', '.seq'}:
            return 'txt'
        # Unknown extension: sniff the first line.
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                first_line = file.readline().strip()
        except (OSError, UnicodeDecodeError):
            # Was a bare `except:`, which also swallowed KeyboardInterrupt /
            # SystemExit; narrowed to the errors open()/readline() can raise.
            logging.warning(f"Could not detect file type for {file_path}, assuming txt")
            return 'txt'
        return 'fasta' if first_line.startswith('>') else 'txt'
|
| 92 |
-
|
| 93 |
-
# ============================= ORIGINAL MODEL COMPONENTS =============================
|
| 94 |
-
# (Including all the original classes: BoundaryAwareGenePredictor, DNAProcessor, EnhancedPostProcessor)
|
| 95 |
-
|
| 96 |
-
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    Pipeline: four parallel 1-D convolutions with different kernel sizes
    (multi-scale motif detection) -> bidirectional LSTM -> self-attention
    over the LSTM outputs -> three independent per-position heads
    ('gene' / 'start' / 'end'), each producing 2 logits per position.

    NOTE: submodule attribute names and construction order are part of the
    checkpoint contract (state_dict keys, seeded init order) — do not
    rename or reorder them.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Each conv contributes hidden_dim//4 channels, so concatenating the
        # four restores hidden_dim. Odd kernels with padding=k//2 preserve
        # the sequence length.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim//4, kernel_size=k, padding=k//2)
            for k in [3, 7, 15, 31]
        ])
        # Bidirectional: hidden_dim//2 per direction -> hidden_dim output.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Self-attention lets every position attend to candidate boundaries
        # anywhere in the sequence.
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Three structurally identical heads; 2 logits per position
        # (class 0 = background, class 1 = gene/start/end — see downstream
        # use of probs[:, 1] in EnhancedPostProcessor).
        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the model.

        Args:
            x: feature tensor, shape (batch, seq_len, input_dim) — implied by
                the transpose for Conv1d and batch_first LSTM below.
            lengths: optional true lengths per sample; when given, the LSTM
                runs on a packed sequence so padding is ignored.

        Returns:
            Dict with keys 'gene', 'start', 'end'; each value is a
            (batch, seq_len, 2) logits tensor.
        """
        batch_size, seq_len, _ = x.shape
        # Conv1d wants (batch, channels, length); transpose back afterwards.
        x_conv = x.transpose(1, 2)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)

        if lengths is not None:
            # enforce_sorted=False lets callers pass unsorted batches.
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        # Self-attention: queries, keys and values are all the LSTM output.
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended)
        }
|
| 155 |
-
|
| 156 |
-
class DNAProcessor:
|
| 157 |
-
"""DNA sequence processor with boundary-aware features."""
|
| 158 |
-
|
| 159 |
-
def __init__(self):
|
| 160 |
-
self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
|
| 161 |
-
self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
|
| 162 |
-
self.start_codons = {'ATG', 'GTG', 'TTG'}
|
| 163 |
-
self.stop_codons = {'TAA', 'TAG', 'TGA'}
|
| 164 |
-
|
| 165 |
-
def encode_sequence(self, sequence: str) -> torch.Tensor:
|
| 166 |
-
sequence = sequence.upper()
|
| 167 |
-
encoded = [self.base_to_idx.get(base, self.base_to_idx['N']) for base in sequence]
|
| 168 |
-
return torch.tensor(encoded, dtype=torch.long)
|
| 169 |
-
|
| 170 |
-
def create_enhanced_features(self, sequence: str) -> torch.Tensor:
|
| 171 |
-
sequence = sequence.upper()
|
| 172 |
-
seq_len = len(sequence)
|
| 173 |
-
encoded = self.encode_sequence(sequence)
|
| 174 |
-
|
| 175 |
-
# One-hot encoding
|
| 176 |
-
one_hot = torch.zeros(seq_len, 5)
|
| 177 |
-
one_hot.scatter_(1, encoded.unsqueeze(1), 1)
|
| 178 |
-
features = [one_hot]
|
| 179 |
-
|
| 180 |
-
# Start codon indicators
|
| 181 |
-
start_indicators = torch.zeros(seq_len, 3)
|
| 182 |
-
for i in range(seq_len - 2):
|
| 183 |
-
codon = sequence[i:i+3]
|
| 184 |
-
if codon == 'ATG':
|
| 185 |
-
start_indicators[i:i+3, 0] = 1.0
|
| 186 |
-
elif codon == 'GTG':
|
| 187 |
-
start_indicators[i:i+3, 1] = 0.9
|
| 188 |
-
elif codon == 'TTG':
|
| 189 |
-
start_indicators[i:i+3, 2] = 0.8
|
| 190 |
-
features.append(start_indicators)
|
| 191 |
-
|
| 192 |
-
# Stop codon indicators
|
| 193 |
-
stop_indicators = torch.zeros(seq_len, 3)
|
| 194 |
-
for i in range(seq_len - 2):
|
| 195 |
-
codon = sequence[i:i+3]
|
| 196 |
-
if codon == 'TAA':
|
| 197 |
-
stop_indicators[i:i+3, 0] = 1.0
|
| 198 |
-
elif codon == 'TAG':
|
| 199 |
-
stop_indicators[i:i+3, 1] = 1.0
|
| 200 |
-
elif codon == 'TGA':
|
| 201 |
-
stop_indicators[i:i+3, 2] = 1.0
|
| 202 |
-
features.append(stop_indicators)
|
| 203 |
-
|
| 204 |
-
# GC content
|
| 205 |
-
gc_content = torch.zeros(seq_len, 1)
|
| 206 |
-
window_size = 50
|
| 207 |
-
for i in range(seq_len):
|
| 208 |
-
start = max(0, i - window_size//2)
|
| 209 |
-
end = min(seq_len, i + window_size//2)
|
| 210 |
-
window = sequence[start:end]
|
| 211 |
-
gc_count = window.count('G') + window.count('C')
|
| 212 |
-
gc_content[i, 0] = gc_count / len(window) if len(window) > 0 else 0
|
| 213 |
-
features.append(gc_content)
|
| 214 |
-
|
| 215 |
-
# Position encoding
|
| 216 |
-
pos_encoding = torch.zeros(seq_len, 2)
|
| 217 |
-
positions = torch.arange(seq_len, dtype=torch.float)
|
| 218 |
-
pos_encoding[:, 0] = torch.sin(positions / 10000)
|
| 219 |
-
pos_encoding[:, 1] = torch.cos(positions / 10000)
|
| 220 |
-
features.append(pos_encoding)
|
| 221 |
-
|
| 222 |
-
return torch.cat(features, dim=1) # 5 + 3 + 3 + 1 + 2 = 14
|
| 223 |
-
|
| 224 |
-
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Turns per-position gene/start/end probabilities into a binary
    per-position gene mask, snapping region boundaries to start/stop
    codons when the sequence is available and enforcing length/frame
    constraints.
    """

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Regions outside [min_gene_length, max_gene_length] are discarded.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        Args:
            gene_probs/start_probs/end_probs: (seq_len, 2) probability
                arrays; column 1 is the positive class.
            sequence: optional DNA string enabling codon-aware refinement.

        Returns:
            Binary (seq_len,) int array marking predicted gene positions.
        """
        # Asymmetric thresholds: gene calls are conservative (0.6), start
        # calls permissive (0.4), end calls neutral (0.5).
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)
        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted region to the best-scoring nearby start/stop codon."""
        refined = gene_pred.copy()
        sequence = sequence.upper()

        # Index all codon positions once. Stop positions record the base
        # AFTER the codon (i + 3), i.e. an exclusive region end.
        start_codon_positions = []
        stop_codon_positions = []

        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)

        # diff of the 0-padded mask: +1 marks a region start, -1 an
        # (exclusive) region end.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        # Rebuild the mask from scratch; only snapped regions survive.
        refined = np.zeros_like(gene_pred)

        for g_start, g_end in zip(gene_starts, gene_ends):
            # --- choose the best start codon within +/-100 bases ---
            best_start = g_start
            start_window = 100
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]

            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        # ATG preferred over GTG over TTG.
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        # Linear penalty, up to 0.2 at the window edge.
                        distance_penalty = abs(pos - g_start) / start_window * 0.2
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))

                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            # --- choose the best stop codon inside the region or up to
            # 100 bases past its end ---
            best_end = g_end
            end_window = 100
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]

            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Reward in-frame stops and lengths near 1000 bp.
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))

                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            # Keep the snapped region only if it still satisfies the
            # length constraints.
            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                gene_length <= self.max_gene_length and
                best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Sequence-free fallback: snap boundaries to nearby start/end calls only."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Look for a predicted start within +/-30 bases and move the
            # region start to the closest one.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                # Convert window-relative indices back to absolute positions.
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # Clear positions ahead of the new start (no-op slice when
                # new_start <= g_start).
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start

            # Same idea for the end, within +/-50 bases, restricted to
            # candidates that keep the region length in bounds.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]

        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: str = None) -> np.ndarray:
        """Final pass: drop out-of-range regions; trim to codon frame when possible."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                # Force the region length to a multiple of 3 by trimming the
                # tail; drop the region if trimming makes it too short.
                if gene_length % 3 != 0:
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0

        return processed
|
| 374 |
-
|
| 375 |
-
# ============================= ENHANCED GENE PREDICTOR =============================
|
| 376 |
-
|
| 377 |
-
class EnhancedGenePredictor:
    """Enhanced Gene Predictor with file input support.

    Wires together FileReader (input), DNAProcessor (features),
    BoundaryAwareGenePredictor (model), and EnhancedPostProcessor
    (refinement) into a file-in / dict-out prediction pipeline.
    """

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        # NOTE: the `device` default is evaluated once at class-definition
        # time, not per call.
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            # Logged for context, then re-raised: the predictor is unusable
            # without weights.
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()
        self.file_reader = FileReader()

    def predict_from_file(self, file_path: str) -> Dict[str, Dict]:
        """
        Predict genes from a file (.txt or .fasta).

        Returns a dictionary with sequence_id as keys and prediction
        results (see predict_sequence) as values.

        Raises:
            FileNotFoundError: if *file_path* does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        file_type = self.file_reader.detect_file_type(file_path)
        logging.info(f"Detected file type: {file_type}")

        results = {}

        if file_type == 'fasta':
            # One result per FASTA record, keyed by the header id.
            sequences = self.file_reader.read_fasta(file_path)
            for seq_id, sequence in sequences.items():
                logging.info(f"Processing sequence: {seq_id} (length: {len(sequence)})")
                result = self.predict_sequence(sequence, seq_id)
                results[seq_id] = result
        else:  # txt file
            sequence = self.file_reader.read_txt(file_path)
            seq_id = Path(file_path).stem  # Use filename as sequence ID
            logging.info(f"Processing sequence from {file_path} (length: {len(sequence)})")
            result = self.predict_sequence(sequence, seq_id)
            results[seq_id] = result

        return results

    def predict_sequence(self, sequence: str, seq_id: str = "sequence") -> Dict:
        """
        Predict genes from a single DNA sequence string.

        Returns a dict with per-position predictions/probabilities, a mean
        confidence over predicted positions, and extracted gene regions.
        """
        sequence = sequence.upper()
        # Replace anything outside ACTGN with 'N' (the warning also fires
        # for an empty sequence, which the regex does not match).
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning(f"Sequence {seq_id} contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        # Handle very long sequences by chunking if needed
        max_chunk_size = 50000  # Adjust based on your GPU memory
        if len(sequence) > max_chunk_size:
            return self._predict_long_sequence(sequence, seq_id, max_chunk_size)

        # (1, seq_len, 14) feature batch on the model's device.
        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features)
            # Softmax over the 2-class logits; [0] drops the batch dim.
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        # Mean positive-class probability over positions finally kept.
        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0

        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def _predict_long_sequence(self, sequence: str, seq_id: str, chunk_size: int) -> Dict:
        """
        Handle very long sequences by processing in chunks with overlap.

        Chunks advance by (chunk_size - overlap); for every chunk after the
        first, the first `overlap` positions are discarded so concatenated
        outputs line up with the full sequence.
        """
        overlap = 1000  # Overlap between chunks to avoid missing genes at boundaries
        all_predictions = []
        all_gene_probs = []
        all_start_probs = []
        all_end_probs = []

        for i in range(0, len(sequence), chunk_size - overlap):
            end_pos = min(i + chunk_size, len(sequence))
            chunk = sequence[i:end_pos]

            logging.info(f"Processing chunk {i//chunk_size + 1} of sequence {seq_id}")

            features = self.processor.create_enhanced_features(chunk).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(features)
                gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
                start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
                end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

            # Each chunk is post-processed independently; NOTE(review): a
            # gene spanning a chunk boundary relies on the overlap region to
            # be recovered — confirm overlap >= max expected gene length.
            chunk_predictions = self.post_processor.process_predictions(
                gene_probs, start_probs, end_probs, chunk
            )

            # Handle overlaps by taking the first chunk fully and subsequent chunks without overlap
            if i == 0:
                all_predictions.extend(chunk_predictions)
                all_gene_probs.extend(gene_probs)
                all_start_probs.extend(start_probs)
                all_end_probs.extend(end_probs)
            else:
                # Skip overlap region
                skip = min(overlap, len(chunk_predictions))
                all_predictions.extend(chunk_predictions[skip:])
                all_gene_probs.extend(gene_probs[skip:])
                all_start_probs.extend(start_probs[skip:])
                all_end_probs.extend(end_probs[skip:])

        predictions = np.array(all_predictions)
        gene_probs = np.array(all_gene_probs)
        start_probs = np.array(all_start_probs)
        end_probs = np.array(all_end_probs)

        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def predict_from_text(self, sequence: str) -> Dict:
        """
        Predict genes from a text string (backward compatibility).
        """
        return self.predict_sequence(sequence)

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Extract gene regions from predictions.

        Each region dict carries coordinates, the sub-sequence, the
        detected start codon (first 3 bases, if canonical) and the last
        in-frame-from-the-end stop codon found scanning backwards.
        """
        regions = []
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # Scan backwards in steps of 3 from the tail; NOTE(review):
                # this frame is anchored at len(gene_seq)-2, which matches
                # the start-anchored frame only when the region length is a
                # multiple of 3 — confirm intended.
                for i in range(len(gene_seq) - 2, 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break

            regions.append({
                'start': int(start),
                'end': int(end),
                'sequence': gene_seq,
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                'in_frame': (end - start) % 3 == 0
            })

        return regions

    def save_results(self, results: Dict[str, Dict], output_path: str, format: str = 'json'):
        """
        Save prediction results to file.

        format='json' dumps the full result dicts; format='csv' writes one
        row per gene region. Any other value writes nothing (only the final
        log line runs).
        """
        import json

        if format.lower() == 'json':
            with open(output_path, 'w') as f:
                json.dump(results, f, indent=2)
        elif format.lower() == 'csv':
            import csv
            with open(output_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['sequence_id', 'gene_start', 'gene_end', 'gene_length',
                                 'start_codon', 'stop_codon', 'in_frame', 'confidence'])

                for seq_id, result in results.items():
                    for gene in result['gene_regions']:
                        writer.writerow([
                            seq_id, gene['start'], gene['end'], gene['length'],
                            gene['start_codon'], gene['stop_codon'], gene['in_frame'],
                            result['confidence']
                        ])

        logging.info(f"Results saved to {output_path}")
|
| 597 |
-
|
| 598 |
-
# ============================= USAGE EXAMPLE =============================
|
| 599 |
-
|
| 600 |
-
def main():
    """Example usage of the enhanced gene predictor"""

    predictor = EnhancedGenePredictor(model_path='model/best_boundary_aware_model.pth')

    # Examples 1 & 2: file-based prediction (FASTA -> JSON, TXT -> CSV);
    # each example is skipped gracefully when its input file is absent.
    examples = (
        ('example.fasta', 'fasta_predictions.json', 'json', 'FASTA'),
        ('example.txt', 'txt_predictions.csv', 'csv', 'TXT'),
    )
    for src, dest, fmt, label in examples:
        try:
            file_results = predictor.predict_from_file(src)
            predictor.save_results(file_results, dest, format=fmt)
            print(f"{label} predictions saved to {dest}")
        except FileNotFoundError:
            print(f"{src} not found, skipping {label} example")

    # Example 3: Predict from text string (original functionality)
    example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGA"
    text_results = predictor.predict_from_text(example_sequence)
    print(f"Found {text_results['total_genes_found']} genes in example sequence")

if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|