# simplified_tree_AI / predictor.py
# NOTE: the lines below were Hugging Face page chrome captured in a copy-paste
# ("Update predictor.py", commit 9fc06ce verified, "raw / history blame / 26.2 kB");
# they are preserved here as comments so the module parses.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Dict, Optional, Union
import logging
import re
import os
from pathlib import Path
# Configure logging: timestamped INFO-level records on the root logger.
# NOTE: logging.basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# ============================= FILE READERS =============================
class FileReader:
    """Read DNA sequences from FASTA and plain-text files."""

    @staticmethod
    def read_fasta(file_path: str) -> Dict[str, str]:
        """Parse a FASTA file into ``{sequence_id: sequence}``.

        Header lines start with ``>``; the rest of the header line (verbatim)
        is used as the sequence id.  Sequence lines are concatenated with all
        whitespace removed.  Lines appearing before the first header are
        ignored.

        Raises:
            Exception: any read/decode error is logged and re-raised.
        """
        sequences: Dict[str, str] = {}
        current_id = None
        current_seq: List[str] = []
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line.startswith('>'):
                        # Flush the previous record before starting a new one.
                        if current_id is not None:
                            sequences[current_id] = ''.join(current_seq)
                        current_id = line[1:]  # drop the '>' marker
                        current_seq = []
                    elif line and current_id is not None:
                        # ''.join(line.split()) removes ALL whitespace, not
                        # just spaces and tabs as the old replace() chain did.
                        current_seq.append(''.join(line.split()))
            # Flush the final record.
            if current_id is not None:
                sequences[current_id] = ''.join(current_seq)
        except Exception as e:
            logging.error(f"Error reading FASTA file {file_path}: {e}")
            raise
        return sequences

    @staticmethod
    def read_txt(file_path: str) -> str:
        """Read a plain-text DNA sequence.

        Every character that is not one of A/C/T/G/N (case-insensitive) is
        dropped, so whitespace and stray symbols are filtered automatically.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()
            return ''.join(c.upper() for c in content if c.upper() in 'ACTGN')
        except Exception as e:
            logging.error(f"Error reading TXT file {file_path}: {e}")
            raise

    @staticmethod
    def detect_file_type(file_path: str) -> str:
        """Classify *file_path* as ``'fasta'`` or ``'txt'``.

        The extension is checked first; for unknown extensions the first line
        is sniffed (a leading '>' means FASTA).  Unreadable files default to
        'txt' with a warning.
        """
        file_path = Path(file_path)
        extension = file_path.suffix.lower()
        if extension in ('.fasta', '.fa', '.fas', '.fna'):
            return 'fasta'
        if extension in ('.txt', '.seq'):
            return 'txt'
        # Unknown extension: sniff the first line of content.
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                first_line = file.readline().strip()
            return 'fasta' if first_line.startswith('>') else 'txt'
        except (OSError, UnicodeDecodeError):
            # FIX: was a bare `except:`, which would also swallow
            # KeyboardInterrupt and SystemExit.
            logging.warning(f"Could not detect file type for {file_path}, assuming txt")
            return 'txt'
# ============================= ORIGINAL MODEL COMPONENTS =============================
# (Including all the original classes: BoundaryAwareGenePredictor, DNAProcessor, EnhancedPostProcessor)
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model scoring each position for gene membership, gene
    start, and gene end.

    Pipeline: multi-scale Conv1d bank -> bidirectional LSTM -> LayerNorm ->
    self-attention -> three independent two-class heads.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Four kernel widths, each contributing hidden_dim // 4 channels, so
        # the concatenated feature map is exactly hidden_dim wide.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim // 4, kernel_size=width, padding=width // 2)
            for width in [3, 7, 15, 31]
        ])
        # hidden_dim // 2 per direction keeps the BiLSTM output at hidden_dim.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        def make_head() -> nn.Sequential:
            # Small two-layer MLP producing per-position logits for 2 classes.
            return nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim // 2, 2),
            )

        # Creation order (gene, start, end) is preserved so RNG-seeded
        # initialisation matches existing checkpoints.
        self.gene_classifier = make_head()
        self.start_classifier = make_head()
        self.end_classifier = make_head()

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the network.

        Args:
            x: float tensor of shape (batch, seq_len, input_dim).
            lengths: optional per-sample lengths; when given, the LSTM runs
                over a packed sequence so padded positions are skipped.

        Returns:
            Dict with keys 'gene', 'start', 'end'; each value holds logits
            of shape (batch, seq_len, 2).
        """
        # Conv1d expects (batch, channels, seq_len).
        channels_first = x.transpose(1, 2)
        multi_scale = [F.relu(conv(channels_first)) for conv in self.conv_layers]
        fused = torch.cat(multi_scale, dim=1).transpose(1, 2)

        if lengths is None:
            recurrent, _ = self.lstm(fused)
        else:
            packed = nn.utils.rnn.pack_padded_sequence(
                fused, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            recurrent, _ = self.lstm(packed)
            recurrent, _ = nn.utils.rnn.pad_packed_sequence(recurrent, batch_first=True)

        recurrent = self.norm(recurrent)
        attended, _ = self.boundary_attention(recurrent, recurrent, recurrent)
        attended = self.dropout(attended)
        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended),
        }
class DNAProcessor:
    """Turn raw DNA strings into the 14-dim per-base feature tensor consumed
    by the model (one-hot + start/stop codon cues + GC content + positional
    encoding)."""

    def __init__(self):
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map bases to integer indices; unknown characters map to 'N'."""
        sequence = sequence.upper()
        n_idx = self.base_to_idx['N']
        encoded = [self.base_to_idx.get(base, n_idx) for base in sequence]
        return torch.tensor(encoded, dtype=torch.long)

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the (seq_len, 14) feature tensor.

        Column layout:
            0-4   one-hot base (A, C, G, T, N)
            5-7   start-codon cues (ATG=1.0, GTG=0.9, TTG=0.8)
            8-10  stop-codon flags (TAA, TAG, TGA)
            11    GC fraction over a +/- 25 bp window
            12-13 sin/cos positional encoding
        """
        sequence = sequence.upper()
        seq_len = len(sequence)
        encoded = self.encode_sequence(sequence)
        # One-hot encoding of the 5 base classes.
        one_hot = torch.zeros(seq_len, 5)
        if seq_len > 0:
            one_hot.scatter_(1, encoded.unsqueeze(1), 1)
        features = [one_hot]
        # Start-codon cues: all three bases of a codon are flagged, weighted
        # by how canonical the codon is.
        start_indicators = torch.zeros(seq_len, 3)
        start_cues = {'ATG': (0, 1.0), 'GTG': (1, 0.9), 'TTG': (2, 0.8)}
        for i in range(seq_len - 2):
            cue = start_cues.get(sequence[i:i + 3])
            if cue is not None:
                col, weight = cue
                start_indicators[i:i + 3, col] = weight
        features.append(start_indicators)
        # Stop-codon flags (binary, one column per codon).
        stop_indicators = torch.zeros(seq_len, 3)
        stop_columns = {'TAA': 0, 'TAG': 1, 'TGA': 2}
        for i in range(seq_len - 2):
            col = stop_columns.get(sequence[i:i + 3])
            if col is not None:
                stop_indicators[i:i + 3, col] = 1.0
        features.append(stop_indicators)
        # GC fraction over a sliding window.  A prefix-sum array makes this
        # O(n) instead of the original O(n * window) while producing the
        # exact same values.
        gc_content = torch.zeros(seq_len, 1)
        window_size = 50
        gc_prefix = [0] * (seq_len + 1)
        for i, base in enumerate(sequence):
            gc_prefix[i + 1] = gc_prefix[i] + (1 if base in 'GC' else 0)
        for i in range(seq_len):
            start = max(0, i - window_size // 2)
            end = min(seq_len, i + window_size // 2)
            width = end - start
            gc_content[i, 0] = (gc_prefix[end] - gc_prefix[start]) / width if width > 0 else 0
        features.append(gc_content)
        # Coarse absolute-position encoding (single sin/cos pair).
        pos_encoding = torch.zeros(seq_len, 2)
        positions = torch.arange(seq_len, dtype=torch.float)
        pos_encoding[:, 0] = torch.sin(positions / 10000)
        pos_encoding[:, 1] = torch.cos(positions / 10000)
        features.append(pos_encoding)
        return torch.cat(features, dim=1)  # 5 + 3 + 3 + 1 + 2 = 14
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Turns per-position class probabilities into a binary gene mask, then
    refines region boundaries using codon positions (when the sequence is
    available) or boundary predictions alone, and finally enforces length
    and frame constraints.
    """
    # NOTE(review): the thresholds and scoring weights throughout this class
    # are hand-tuned heuristic constants; they are not derived from anything
    # visible in this file.
    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Regions outside [min_gene_length, max_gene_length] are discarded.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}
    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        Args:
            gene_probs, start_probs, end_probs: (seq_len, 2) softmax outputs.
            sequence: optional DNA string; when given, codon locations drive
                boundary refinement.

        Returns:
            Binary (seq_len,) int array marking predicted gene positions.
        """
        # Asymmetric thresholds: gene calls are conservative (0.6), start
        # calls permissive (0.4), end calls neutral (0.5).
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)
        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)
        processed = self._apply_constraints(processed, sequence)
        return processed
    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted region to the best-scoring nearby start/stop
        codon, combining codon identity, boundary prediction, distance, and
        reading-frame cues."""
        # NOTE(review): this copy is dead — `refined` is reassigned to a
        # zero array below before being used.
        refined = gene_pred.copy()
        sequence = sequence.upper()
        start_codon_positions = []
        stop_codon_positions = []
        # Index every start codon (by its first base) and every stop codon
        # (by the position just AFTER its last base).
        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)
        # Region boundaries via the +1/-1 edges of the padded difference.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]
        refined = np.zeros_like(gene_pred)
        for g_start, g_end in zip(gene_starts, gene_ends):
            best_start = g_start
            start_window = 100
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]
            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        # Codon canonicality: ATG > GTG > TTG.
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        # Up to 0.2 penalty at the edge of the search window.
                        distance_penalty = abs(pos - g_start) / start_window * 0.2
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))
                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]
            best_end = g_end
            end_window = 100
            # Candidate stop positions strictly after the region start, at
            # most `end_window` past the predicted end.
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]
            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Reward stops in-frame with the chosen start.
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            # Mild preference for ~1000 bp genes.
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))
                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]
            gene_length = best_end - best_start
            # Keep the region only if the refined boundaries still satisfy
            # the length constraints and remain ordered.
            if (gene_length >= self.min_gene_length and
                gene_length <= self.max_gene_length and
                best_start < best_end):
                refined[best_start:best_end] = 1
        return refined
    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Fallback refinement (no sequence available): snap each region edge
        to the nearest predicted start/end position within a fixed window."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]
        for g_start, g_end in zip(gene_starts, gene_ends):
            # +/- 30 bp window around the region start.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                # Convert window-relative indices back to absolute positions.
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # Zero the trimmed prefix when the start moved right; the
                # else-branch self-assignment is a deliberate no-op.
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start
            # +/- 50 bp window around the region end.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                # Only ends that keep the gene within the length bounds.
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]
        return refined
    def _apply_constraints(self, predictions: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Final pass: drop regions outside the length bounds and, when the
        sequence is known, trim each region down to a multiple of 3."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]
        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Truncate to the largest in-frame length; drop the
                    # region entirely if that falls below the minimum.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0
        return processed
# ============================= ENHANCED GENE PREDICTOR =============================
class EnhancedGenePredictor:
    """End-to-end gene predictor: loads the trained boundary-aware model and
    predicts gene regions from files (.txt / .fasta) or raw sequence strings."""

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load model weights from *model_path* onto *device*.

        Raises:
            Exception: any load/state-dict error is logged and re-raised.
        """
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load checkpoints from a trusted source.
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()
        self.file_reader = FileReader()

    def predict_from_file(self, file_path: str) -> Dict[str, Dict]:
        """Predict genes from a file (.txt or .fasta).

        Returns:
            Dict mapping sequence_id -> per-sequence result dict (see
            ``predict_sequence`` for the result schema).

        Raises:
            FileNotFoundError: if *file_path* does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")
        file_type = self.file_reader.detect_file_type(file_path)
        logging.info(f"Detected file type: {file_type}")
        results = {}
        if file_type == 'fasta':
            for seq_id, sequence in self.file_reader.read_fasta(file_path).items():
                logging.info(f"Processing sequence: {seq_id} (length: {len(sequence)})")
                results[seq_id] = self.predict_sequence(sequence, seq_id)
        else:  # plain text: one sequence, named after the file
            sequence = self.file_reader.read_txt(file_path)
            seq_id = Path(file_path).stem
            logging.info(f"Processing sequence from {file_path} (length: {len(sequence)})")
            results[seq_id] = self.predict_sequence(sequence, seq_id)
        return results

    def _model_probs(self, sequence: str):
        """Run the network on one sequence; return (gene, start, end)
        per-position softmax probability arrays, each shaped (seq_len, 2)."""
        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(features)
        gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
        start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
        end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]
        return gene_probs, start_probs, end_probs

    def _build_result(self, seq_id: str, sequence: str, predictions: np.ndarray,
                      gene_probs: np.ndarray, start_probs: np.ndarray,
                      end_probs: np.ndarray) -> Dict:
        """Assemble the JSON-serializable result dict shared by all entry points."""
        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        gene_regions = self.extract_gene_regions(predictions, sequence)
        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def predict_sequence(self, sequence: str, seq_id: str = "sequence") -> Dict:
        """Predict genes in a single DNA string.

        Non-ACTGN characters are replaced with 'N' (with a warning).  Very
        long sequences are delegated to the chunked path.

        Returns:
            Dict with keys: sequence_id, sequence_length, predictions,
            probabilities (gene/start/end), confidence, gene_regions,
            total_genes_found.
        """
        sequence = sequence.upper()
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning(f"Sequence {seq_id} contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)
        if not sequence:
            # FIX: an empty sequence previously crashed inside the model;
            # return an empty but fully-shaped result instead.
            empty = np.zeros(0, dtype=int)
            zero = np.zeros((0, 2))
            return self._build_result(seq_id, sequence, empty, zero, zero, zero)
        max_chunk_size = 50000  # keeps per-chunk memory bounded; tune for your GPU
        if len(sequence) > max_chunk_size:
            return self._predict_long_sequence(sequence, seq_id, max_chunk_size)
        gene_probs, start_probs, end_probs = self._model_probs(sequence)
        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        return self._build_result(seq_id, sequence, predictions, gene_probs, start_probs, end_probs)

    def _predict_long_sequence(self, sequence: str, seq_id: str, chunk_size: int) -> Dict:
        """Process a very long sequence in overlapping chunks.

        Chunks overlap by 1000 bp so genes spanning a chunk boundary are not
        lost; every chunk after the first drops its leading ``overlap``
        positions when stitching, so each position is emitted exactly once.
        """
        overlap = 1000
        all_predictions = []
        all_gene_probs = []
        all_start_probs = []
        all_end_probs = []
        for i in range(0, len(sequence), chunk_size - overlap):
            end_pos = min(i + chunk_size, len(sequence))
            chunk = sequence[i:end_pos]
            logging.info(f"Processing chunk {i//chunk_size + 1} of sequence {seq_id}")
            gene_probs, start_probs, end_probs = self._model_probs(chunk)
            chunk_predictions = self.post_processor.process_predictions(
                gene_probs, start_probs, end_probs, chunk
            )
            # First chunk is kept whole; later chunks skip the overlap prefix.
            skip = 0 if i == 0 else min(overlap, len(chunk_predictions))
            all_predictions.extend(chunk_predictions[skip:])
            all_gene_probs.extend(gene_probs[skip:])
            all_start_probs.extend(start_probs[skip:])
            all_end_probs.extend(end_probs[skip:])
        predictions = np.array(all_predictions)
        gene_probs = np.array(all_gene_probs)
        start_probs = np.array(all_start_probs)
        end_probs = np.array(all_end_probs)
        return self._build_result(seq_id, sequence, predictions, gene_probs, start_probs, end_probs)

    def predict_from_text(self, sequence: str) -> Dict:
        """Predict genes from a text string (backward compatibility alias)."""
        return self.predict_sequence(sequence)

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Extract contiguous predicted regions and annotate their codons.

        Returns:
            List of dicts with 0-based half-open start/end, the region
            sequence, its length, detected start/stop codon (or None), and
            an in_frame flag (length divisible by 3).
        """
        regions = []
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]
        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None
            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon
            if len(gene_seq) >= 6:
                # FIX: scan codons in the reading frame of the region start
                # (offsets 0, 3, 6, ...).  The original began at
                # len(gene_seq) - 2, which is out of frame whenever the
                # region length is a multiple of 3, so it missed the true
                # terminal stop codon.
                last_in_frame = 3 * ((len(gene_seq) - 3) // 3)
                for i in range(last_in_frame, 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break
            regions.append({
                'start': int(start),
                'end': int(end),
                'sequence': gene_seq,
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                # FIX: cast to plain bool — np.bool_ is not JSON serializable
                # and previously broke save_results(format='json').
                'in_frame': bool((end - start) % 3 == 0)
            })
        return regions

    def save_results(self, results: Dict[str, Dict], output_path: str, format: str = 'json'):
        """Write results to *output_path*.

        format='json' dumps the full result dicts; format='csv' writes one
        row per detected gene region.
        """
        import json
        if format.lower() == 'json':
            with open(output_path, 'w') as f:
                json.dump(results, f, indent=2)
        elif format.lower() == 'csv':
            import csv
            with open(output_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['sequence_id', 'gene_start', 'gene_end', 'gene_length',
                                 'start_codon', 'stop_codon', 'in_frame', 'confidence'])
                for seq_id, result in results.items():
                    for gene in result['gene_regions']:
                        writer.writerow([
                            seq_id, gene['start'], gene['end'], gene['length'],
                            gene['start_codon'], gene['stop_codon'], gene['in_frame'],
                            result['confidence']
                        ])
        logging.info(f"Results saved to {output_path}")
# ============================= USAGE EXAMPLE =============================
def main():
    """Demonstrate the three input modes of EnhancedGenePredictor:
    FASTA file -> JSON, TXT file -> CSV, and a raw sequence string."""
    predictor = EnhancedGenePredictor(model_path='model/best_boundary_aware_model.pth')

    # File-based examples: (input, output, save kwargs, success msg, missing msg).
    file_examples = [
        ('example.fasta', 'fasta_predictions.json', {},
         "FASTA predictions saved to fasta_predictions.json",
         "example.fasta not found, skipping FASTA example"),
        ('example.txt', 'txt_predictions.csv', {'format': 'csv'},
         "TXT predictions saved to txt_predictions.csv",
         "example.txt not found, skipping TXT example"),
    ]
    for in_path, out_path, save_kwargs, ok_msg, missing_msg in file_examples:
        try:
            file_results = predictor.predict_from_file(in_path)
            predictor.save_results(file_results, out_path, **save_kwargs)
            print(ok_msg)
        except FileNotFoundError:
            print(missing_msg)

    # Raw-string example (original predict_from_text API).
    example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGA"
    text_results = predictor.predict_from_text(example_sequence)
    print(f"Found {text_results['total_genes_found']} genes in example sequence")


if __name__ == "__main__":
    main()