re-type committed on
Commit
31e359a
·
1 Parent(s): 1cebb06

Add model file and app files

Browse files
Files changed (3) hide show
  1. app.py +257 -0
  2. best_boundary_aware_model.pth +3 -0
  3. predictor.py +414 -0
app.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import json
5
+ from typing import Optional, List, Dict, Tuple
6
+ import logging
7
+ from predictor import GenePredictor
8
+
9
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize the predictor globally, once at import time, so every request
# reuses the same loaded model. A load failure must not crash the UI:
# `predictor` is left as None and predict_gene_regions reports the problem.
try:
    predictor = GenePredictor(model_path='best_boundary_aware_model.pth')
    logger.info("Gene predictor initialized successfully")
except Exception as e:
    # exc_info=True keeps the traceback; the message alone rarely says which
    # part of model construction or checkpoint loading failed.
    logger.error(f"Failed to initialize predictor: {e}", exc_info=True)
    predictor = None
20
+
21
def _json_default(value):
    """Convert NumPy scalars/arrays to built-in types so json.dumps accepts them."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def predict_gene_regions(sequence: str,
                         ground_truth_labels: Optional[str] = None,
                         ground_truth_start: Optional[int] = None,
                         ground_truth_end: Optional[int] = None) -> Tuple[str, str, str]:
    """
    Main prediction function for Gradio interface

    Args:
        sequence: DNA sequence; case-insensitive, may contain whitespace or
            line breaks (they are removed before validation).
        ground_truth_labels: Optional comma-separated 0/1 labels, one per base.
            Takes precedence over the start/end coordinates.
        ground_truth_start: Optional ground-truth start coordinate.
        ground_truth_end: Optional ground-truth end coordinate; used only
            when both coordinates are given and no labels are provided.

    Returns:
    - regions_display: Formatted string showing predicted regions
    - metrics_display: Formatted string showing accuracy metrics (if ground truth provided)
    - detailed_json: JSON string with full prediction details

    On any validation or prediction error the first element carries the
    error message and the other two are empty strings.
    """

    if predictor is None:
        error_msg = "❌ Model not loaded. Please check the model file."
        return error_msg, "", ""

    # Input validation. Remove *all* whitespace, not only the outer padding:
    # the textbox is multi-line, so pasted sequences routinely contain line
    # breaks that would otherwise be rejected as invalid characters.
    sequence = "".join((sequence or "").split()).upper()

    if not sequence:
        error_msg = "❌ Sequence cannot be empty"
        return error_msg, "", ""

    if not all(c in 'ACTGN' for c in sequence):
        error_msg = "❌ Sequence contains invalid characters. Only A, C, T, G, N allowed"
        return error_msg, "", ""

    # Process ground truth if provided
    labels = None
    try:
        if ground_truth_labels and ground_truth_labels.strip():
            labels = [int(x) for x in ground_truth_labels.split(',')]
            if len(labels) != len(sequence):
                error_msg = f"❌ Labels length ({len(labels)}) must match sequence length ({len(sequence)})"
                return error_msg, "", ""
            if not all(x in (0, 1) for x in labels):
                error_msg = "❌ Labels must be 0 or 1"
                return error_msg, "", ""
        elif ground_truth_start is not None and ground_truth_end is not None:
            start = int(ground_truth_start)
            end = int(ground_truth_end)
            if start < 0 or end > len(sequence) or start >= end:
                error_msg = f"❌ Invalid coordinates: start={start}, end={end}"
                return error_msg, "", ""
            labels = predictor.labels_from_coordinates(len(sequence), start, end)
    except ValueError as e:
        error_msg = f"❌ Invalid ground truth format: {str(e)}"
        return error_msg, "", ""

    # Make prediction
    try:
        predictions, probs_dict, confidence = predictor.predict(sequence)
        regions = predictor.extract_gene_regions(predictions, sequence)

        # Format regions display
        regions_display = format_regions_display(regions, confidence)

        # Compute metrics if ground truth provided
        metrics_display = ""
        metrics = None
        if labels is not None:
            metrics = predictor.compute_accuracy(predictions, labels)
            metrics_display = format_metrics_display(metrics)

        # Create detailed JSON response
        detailed_response = {
            "regions": regions,
            "confidence": float(confidence),
            "metrics": metrics,
            "sequence_length": len(sequence),
            "num_predicted_genes": len(regions),
            "prediction_summary": {
                "total_gene_positions": int(np.sum(predictions)),
                "gene_coverage": float(np.sum(predictions) / len(predictions))
            }
        }

        # Region and metric dicts can carry NumPy scalars (for example a
        # numpy.bool_ `in_frame` flag), which json.dumps rejects by default.
        detailed_json = json.dumps(detailed_response, indent=2, default=_json_default)

        return regions_display, metrics_display, detailed_json

    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        error_msg = f"❌ Prediction failed: {str(e)}"
        return error_msg, "", ""
107
+
108
def format_regions_display(regions: List[Dict], confidence: float) -> str:
    """Render the predicted regions as Markdown for the Gradio results panel.

    Args:
        regions: Region dicts providing the keys 'start', 'end', 'length',
            'start_codon', 'stop_codon', 'in_frame' and 'sequence'.
        confidence: Overall confidence, shown with three decimals.

    Returns:
        A Markdown string: one section per region, or a "nothing detected"
        notice when ``regions`` is empty.
    """
    if not regions:
        return (
            f"🔍 **No gene regions detected** (Confidence: {confidence:.3f})\n\n"
            "The model did not identify any gene regions in the provided sequence."
        )

    chunks = [
        f"🧬 **Found {len(regions)} gene region(s)** (Overall Confidence: {confidence:.3f})\n\n"
    ]
    for number, region in enumerate(regions, 1):
        gene_seq = region['sequence']
        # Only the first 60 bases are shown; longer sequences get an ellipsis.
        suffix = '...' if len(gene_seq) > 60 else ''
        start_codon = region['start_codon'] or 'None detected'
        stop_codon = region['stop_codon'] or 'None detected'
        frame_text = '✅ Yes' if region['in_frame'] else '❌ No'
        chunks.extend([
            f"**Gene {number}:**\n",
            f" • Position: {region['start']} - {region['end']}\n",
            f" • Length: {region['length']} bp\n",
            f" • Start Codon: {start_codon}\n",
            f" • Stop Codon: {stop_codon}\n",
            f" • In Frame: {frame_text}\n",
            f" • Sequence Preview: {gene_seq[:60]}{suffix}\n\n",
        ])

    return "".join(chunks)
126
+
127
def format_metrics_display(metrics: Dict) -> str:
    """Render accuracy metrics as Markdown for the Gradio metrics panel.

    Args:
        metrics: Dict providing 'accuracy', 'precision', 'recall', 'f1',
            'true_positives', 'false_positives' and 'false_negatives'.

    Returns:
        A Markdown string with the scores followed by the confusion counts.
    """
    accuracy = metrics['accuracy']
    lines = [
        "📊 **Accuracy Metrics** (vs Ground Truth)\n\n",
        f" • **Accuracy:** {accuracy:.3f} ({accuracy*100:.1f}%)\n",
        f" • **Precision:** {metrics['precision']:.3f}\n",
        f" • **Recall:** {metrics['recall']:.3f}\n",
        f" • **F1 Score:** {metrics['f1']:.3f}\n\n",
        "**Confusion Matrix:**\n",
        f" • True Positives: {metrics['true_positives']}\n",
        f" • False Positives: {metrics['false_positives']}\n",
        f" • False Negatives: {metrics['false_negatives']}\n",
    ]
    return "".join(lines)
141
+
142
def load_example_sequence():
    """Load an example DNA sequence for testing

    Returns:
        The example sequence as one contiguous string of bases. Any
        whitespace or line break inside the literal is removed so the value
        passes the A/C/G/T/N validation applied to user input.
    """
    example = """ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGACGCGTACAGGAAACACAGAAAAAAGCCCGCACCTGACAGTGCGGGCTTTTTTTTTCGACCAAAGGTAACGAGGTAACAACCATGCGAGTGTTGAAGTTCGGCGGTACATCAGTGGCAAATGCAGAACGTTTTCTGCGGGTTGCCGATATTCTGGAAAGCAATGCCAGGCAGGGGCAGGTGGCCACCGTCCTCTCTGCCCCCGCCAAAATCACCAACCACCTGGTGGCGATGATTGAAAAAACCATTAGCGGCCAGGATGCTTTACCCAATATCAGCGATGCCGAACGTATTTTTGCCGAACTTTTGACGGGACTCGCCGCCGCCCAGCCGGGGTTCCCGCTGGCGCAATTGAAAACTTTCGTCGATCAGGAATTTGCCCAAATAAAACATGTCCTGCATGGCATTAGTTTGTTGGGGCAGTGCCCGGATAGCATCAACGCTGCGCTGATTTGCCGTGGCGAGAAAATGTCGATCGCCATTATGGCCGGCGTATTAGAAGCGCGCGGTCACAACGTTACTGTTATCGATCCGGTCGAAAAACTGCTGGCAGTGGGGCATTACCTCGAATCTACCGTCGATATTGCTGAGTCCACCCGCCGTATTGCGGCAAGCCGCATTCCGGCTGATCACATGGTGCTGATGGCAGGTTTCACCGCCGGTAATGAAAAAGGCGAACTGGTGGTGCTTGGACGCAACGGTTCCGACTACTCTGCTGCGGTGCTGGCTGCCTGTTTACGCGCCGATTGTTGCGAGATTTGGACGGACGTTGACGGGGTCTATACCTGCGACCCGCGTCAGGTGCCCGATGCGAGGTTGTTGAAGTCGATGTCCTACCAGGAAGCGATGGAGCTTTCCTACTTCGGCGCTAAAGTTCTTCACCCCCGCACCATTACCCCCATCGCCCAGTTCCAGATCCCTTGCCTGATTAAAAATACCGGAAATCCTCAAGCACCAGGTACGCTCATTGGTGCCAGCCGTGATGAAGACGAATTACCGGTCAAGGGCATTTCCAATCTGAATAACATGGCAATGTTCAGCGTTTCCGGCCCGGGGATGAAAGGGATGGTCGGCATGGCGGCGCGCGTCTTTGCAGCGATGTCACGCGCCCGTATTTCCGTGGTGCTGATTACGCAATCATCTTCCGAATACAGCATCAGTTTCTGCGTTCCACAAAGCGACTGTGTGCGAGCTGAACGGGCAATGCAGGAAGAGTTCTACCTGGAACTGAAAGAAGGCTTACTGGAGCCGCTGGCAGTGACGGAACGGCTGGCCATTATCTCGGTGGTAGGTGATGGTATGCGCACCTTGCGTGGGATCTCGGCACCAGCGAAAGACGGTGGGCCGTGGATAAAGCGCGGCGTCTCGGCGTTTTCGGACCCCGCGGTCTCTTAACCCGAGTCCGAAAATTGTGATCGGGGCCGGGTTTAACGATGGAGCGATCGGGTCAATTGGGGCTGCACCGTTTGACCTGAAGACGCCGGCGGGAAACCGCGTTTCGTTTGCCAGGCGTGAGAGTATTCTTTCCGGCTCCGGTATAGCTGAAACATGAAATGCTTTCCCCTGCGCTTGGCCGATACGCTGGTTTAAGACTTCGGATCGCCGGGAAAGTCGCCCCCCACATTCTGCCAACGATTTGGTTAAAATAGTGACATTGGTGGAAACGGGGAAATGGGTTGACGGTTTTGAAGGGCGTGTCACACCATCGGTTGTTGGCGTTGACAAACGCGATCCGTATAATGAAACTGAATTTGTACACTTTCGCGTCGGGGATGTGGTCAGCAGTTAGGCTCCAATTGATGCCACGTTGACATGATCAATACCTGCGTGCCGGTCACAATCACCTTACCACCCAGTCCGATCAACGCCTGCGCGGGTGCGCAGATACGCGTGGTGTGTCTCGCGAACCGGGATCGTCGCAC
GGGCATGGAACACTATGGTGAGCAAGGGCGAGGAGTGATTACGCCTGATCTGCTGTTGAGAAGAAGCGCGTCTACCCCTCGGGACAAGGCAAAGAATTTGCTGCAGAAATACGCTGGAGATTGAAGGTTCTGGGAAACGTTTTGTTGACAGTTTACCTCCTGGACGATCCCGCGCCCGCAGGCTGGCGTCGCGATGAAACGAATTTCGGTTCACGGCCGGTGTAAGACGATCGATGGGCAGGGAATTGATGCCGATGCGGATGCCGCACCCGGGAAAGAACACGCTGCTGTGTACTGTCGGGTCGAAGAAAAGCTTGAAAGCGGGCGAAATTTTTCGCGCACCGTCGATGATCCGCACCCGCGAATTCGACCAGTGAAAGCGACTCGCGATGCGGCCGCGCTACAGGTTGTTAACCTGAATGAGGGCTAG"""
    # The literal spans more than one line; split()/join drops every
    # whitespace character so the textbox receives a clean sequence.
    return "".join(example.split())
146
+
147
+ # Create the Gradio interface
148
def create_interface():
    """Build the Gradio Blocks UI for the gene prediction tool.

    The layout has an input column (sequence box, action buttons, optional
    ground-truth fields) and an output column (regions, metrics, raw JSON).
    The buttons and the sequence box are wired to load_example_sequence and
    predict_gene_regions.

    Returns:
        The constructed ``gr.Blocks`` object; the caller launches it.
    """
    # NOTE(review): the nesting of the `with` contexts below was reconstructed
    # from a whitespace-stripped diff view — confirm against the repository.
    with gr.Blocks(title="F Gene Prediction Tool", theme=gr.themes.Soft()) as interface:

        gr.Markdown("""
        # 🧬 F Gene Prediction Tool

        This tool predicts gene regions in DNA sequences using a boundary-aware deep learning model.
        The model identifies start and end positions of genes, along with confidence scores and detailed analysis.
        """)

        with gr.Row():
            with gr.Column(scale=2):
                # Input section
                gr.Markdown("## 📝 Input")

                sequence_input = gr.Textbox(
                    label="DNA Sequence",
                    placeholder="Enter your DNA sequence (A, C, T, G, N only)...",
                    lines=5,
                    max_lines=10
                )

                with gr.Row():
                    example_btn = gr.Button("📋 Load Example Sequence", variant="secondary")
                    predict_btn = gr.Button("🔬 Predict Genes", variant="primary")

                # Ground truth section (optional)
                gr.Markdown("## 🎯 Ground Truth (Optional)")
                gr.Markdown("*Provide ground truth data to calculate accuracy metrics*")

                with gr.Row():
                    gt_start = gr.Number(
                        label="Ground Truth Start Position",
                        precision=0,
                        value=None
                    )
                    gt_end = gr.Number(
                        label="Ground Truth End Position",
                        precision=0,
                        value=None
                    )

                gt_labels = gr.Textbox(
                    label="Ground Truth Labels (comma-separated 0s and 1s)",
                    placeholder="0,0,1,1,1,0,0... (optional, alternative to start/end)",
                    lines=2
                )

            with gr.Column(scale=3):
                # Output section
                gr.Markdown("## 🔬 Prediction Results")

                regions_output = gr.Markdown(
                    label="Predicted Gene Regions",
                    value="*Results will appear here after prediction...*"
                )

                with gr.Row():
                    with gr.Column():
                        metrics_output = gr.Markdown(
                            label="Accuracy Metrics",
                            value="*Metrics will appear here if ground truth is provided...*"
                        )

                # Detailed JSON output (collapsible)
                with gr.Accordion("📄 Detailed JSON Output", open=False):
                    json_output = gr.Code(
                        label="Full Prediction Details",
                        language="json",
                        value="{}",
                        lines=20
                    )

        # Event handlers
        # The example button only fills the sequence box; it does not predict.
        example_btn.click(
            fn=load_example_sequence,
            outputs=sequence_input
        )

        # Input order must match predict_gene_regions' parameter order:
        # (sequence, ground_truth_labels, ground_truth_start, ground_truth_end).
        predict_btn.click(
            fn=predict_gene_regions,
            inputs=[sequence_input, gt_labels, gt_start, gt_end],
            outputs=[regions_output, metrics_output, json_output]
        )

        # Also trigger prediction on Enter in the sequence box
        # NOTE(review): the box is multi-line (lines=5); presumably plain Enter
        # inserts a newline there and a modifier key is needed to submit — confirm.
        sequence_input.submit(
            fn=predict_gene_regions,
            inputs=[sequence_input, gt_labels, gt_start, gt_end],
            outputs=[regions_output, metrics_output, json_output]
        )

        # Footer
        gr.Markdown("""
        ---
        **Model Info:** Boundary-aware gene prediction using multi-task deep learning
        **Supported:** DNA sequences with A, C, T, G, N nucleotides
        **Output:** Gene regions with start/end positions, codons, and confidence scores
        """)

    return interface
249
+
250
# Launch the app
# Runs only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch(
        server_name="0.0.0.0",  # Required for Hugging Face Spaces
        server_port=7860,  # Standard port for HF Spaces
        # NOTE(review): share=True asks Gradio for a public share link;
        # presumably redundant when the app is hosted on Spaces — confirm.
        share=True
    )
best_boundary_aware_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13c92e4883bba94b680ba84904e2c36a3c01105196c2a935c979b583fe0dc30c
3
+ size 6410291
predictor.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """predictor.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1JURb-0j-R4LWK3oxeGrNxpJm3V6nnX02
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from typing import List, Tuple, Dict, Optional
15
+ import logging
16
+ import re
17
+
18
# Configure logging
# NOTE(review): app.py imports this module before making its own identical
# basicConfig call, so this one presumably is the call that takes effect — confirm.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ # ============================= MODEL COMPONENTS =============================
22
+
23
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    Pipeline: four parallel 1-D convolutions (kernel sizes 3, 7, 15, 31) are
    concatenated on the channel axis, fed through a bidirectional LSTM, layer
    normalised, passed through self-attention and dropout, and finally
    scored by three independent two-class heads.

    Args:
        input_dim: Number of features per sequence position.
        hidden_dim: Width of the shared representation. The layer sizes used
            here require it to be divisible by 4 (conv branches), by 2 (LSTM
            directions) and by 8 (attention heads).
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability used in the LSTM, after attention and
            inside each head.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Attribute names and layer order are part of the checkpoint format
        # (state_dict keys), so they must not be renamed or reordered.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim//4, kernel_size=k, padding=k//2)
            for k in [3, 7, 15, 31]
        ])
        self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Compute per-position logits for the three tasks.

        Args:
            x: Input of shape (batch, seq_len, input_dim).
            lengths: Optional true lengths per batch element. When given, the
                LSTM runs on a packed sequence so padding is skipped.

        Returns:
            Dict with keys 'gene', 'start' and 'end', each of shape
            (batch, seq_len, 2).
        """
        # The unpack also rejects inputs that are not 3-D.
        _, seq_len, _ = x.shape
        x_conv = x.transpose(1, 2)  # Conv1d expects (batch, channels, seq_len)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)

        if lengths is not None:
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            # total_length keeps the time axis equal to the input's; without
            # it the output is truncated to max(lengths) and no longer lines
            # up with per-position targets.
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
                lstm_out, batch_first=True, total_length=seq_len
            )
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended)
        }
82
+
83
+ # ============================= DATA PREPROCESSING =============================
84
+
85
class DNAProcessor:
    """DNA sequence processor with boundary-aware features."""

    def __init__(self):
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {index: base for base, index in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map each base to its integer index; unknown characters map to 'N'."""
        unknown = self.base_to_idx['N']
        indices = [self.base_to_idx.get(ch, unknown) for ch in sequence.upper()]
        return torch.tensor(indices, dtype=torch.long)

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the (seq_len, 14) per-position feature matrix.

        Columns: 5 one-hot base channels, 3 start-codon channels
        (ATG/GTG/TTG), 3 stop-codon channels (TAA/TAG/TGA), 1 windowed GC
        fraction and 2 sinusoidal position channels.
        """
        sequence = sequence.upper()
        n = len(sequence)

        # One-hot encoding of the base index.
        one_hot = torch.zeros(n, 5)
        one_hot.scatter_(1, self.encode_sequence(sequence).unsqueeze(1), 1)

        # codon -> (column, weight); every position of a matching codon is
        # marked, in any reading frame.
        start_marks = {'ATG': (0, 1.0), 'GTG': (1, 0.9), 'TTG': (2, 0.8)}
        stop_marks = {'TAA': (0, 1.0), 'TAG': (1, 1.0), 'TGA': (2, 1.0)}
        start_indicators = torch.zeros(n, 3)
        stop_indicators = torch.zeros(n, 3)
        for i in range(n - 2):
            codon = sequence[i:i + 3]
            if codon in start_marks:
                column, weight = start_marks[codon]
                start_indicators[i:i + 3, column] = weight
            elif codon in stop_marks:
                column, weight = stop_marks[codon]
                stop_indicators[i:i + 3, column] = weight

        # GC fraction over a window of up to 50 bases around each position,
        # clipped at the sequence ends.
        half_window = 50 // 2
        gc_values = []
        for i in range(n):
            window = sequence[max(0, i - half_window):min(n, i + half_window)]
            gc_total = window.count('G') + window.count('C')
            gc_values.append(gc_total / len(window) if len(window) > 0 else 0)
        gc_content = torch.tensor(gc_values, dtype=one_hot.dtype).reshape(n, 1)

        # Position encoding: sin and cos of position / 10000.
        positions = torch.arange(n, dtype=torch.float)
        pos_encoding = torch.stack(
            (torch.sin(positions / 10000), torch.cos(positions / 10000)), dim=1
        ).to(one_hot.dtype)

        # 5 + 3 + 3 + 1 + 2 = 14
        return torch.cat(
            [one_hot, start_indicators, stop_indicators, gc_content, pos_encoding], dim=1
        )
152
+
153
+ # ============================= POST-PROCESSING =============================
154
+
155
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Turns per-position class probabilities into a 0/1 gene mask, snapping
    region boundaries to nearby start/stop codons when the sequence is
    available and discarding regions outside the allowed length range.
    """

    # NOTE(review): block nesting in this class was reconstructed from a
    # whitespace-stripped diff view — confirm against the repository.

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Regions shorter or longer than these bounds (in bases) are dropped.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        Each probability array is indexed as [:, 1], i.e. column 1 is read as
        the positive class. Returns the refined 0/1 mask.
        """

        # More conservative thresholds
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        # Codon-aware refinement needs the sequence; otherwise only the
        # predicted boundary flags are used.
        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)

        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted region to the best-scoring nearby start and stop codon."""
        # This copy is overwritten by the zeros_like assignment further down.
        refined = gene_pred.copy()
        sequence = sequence.upper()

        start_codon_positions = []
        stop_codon_positions = []

        # Start positions are the codon's first base; stop positions are the
        # index just past the codon (i + 3), i.e. an exclusive region end.
        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)

        # Zero padding on both sides makes +1 / -1 in the diff mark the
        # inclusive start and exclusive end of every run of ones.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        refined = np.zeros_like(gene_pred)

        for g_start, g_end in zip(gene_starts, gene_ends):
            best_start = g_start
            start_window = 100  # Increased from 50
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]

            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        distance_penalty = abs(pos - g_start) / start_window * 0.2  # Add distance penalty
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))

                if start_scores:
                    # max() keeps the first candidate on ties.
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            best_end = g_end
            end_window = 100
            # Candidates are filtered against the original g_start, while the
            # length check below uses the refined best_start.
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]

            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        # A stop codon ending exactly at the end of the
                        # sequence has pos == len(end_pred) and is skipped.
                        if pos < len(end_pred):
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))

                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            # Regions that still violate the length bounds are left as zeros.
            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                    gene_length <= self.max_gene_length and
                    best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Adjust region edges to the nearest predicted start/end flag (no sequence needed)."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Look for a predicted start within 30 positions of the region start.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                # Window-relative indices -> absolute positions.
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # Clears the skipped prefix when the start moves right; when it
                # moves left the slice is empty and nothing changes.
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start

            # Look for a predicted end within 50 positions of the region end.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    # Clears the dropped suffix when the end moves left.
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]

        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: Optional[str] = None) -> np.ndarray:
        """Drop regions outside the length bounds and, when a sequence is given, trim to whole codons."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Trim the tail so the region length is a multiple of 3.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    # NOTE(review): this else is assumed to pair with the
                    # new_length check (region too short after trimming) — confirm.
                    else:
                        processed[start:end] = 0

        return processed
308
+
309
+ # ============================= PREDICTION =============================
310
+
311
class GenePredictor:
    """Handles gene prediction using the trained boundary-aware model.

    Bundles the network, the feature builder and the post-processor, and
    offers helpers to describe predicted regions and score them against
    ground truth.
    """

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the checkpoint at ``model_path`` onto ``device``.

        Raises:
            Exception: whatever loading the checkpoint raises is logged and
                re-raised unchanged.
        """
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()

    def predict(self, sequence: str) -> Tuple[np.ndarray, Dict[str, np.ndarray], float]:
        """Predict a per-position gene mask for ``sequence``.

        Returns:
            ``(predictions, probabilities, confidence)``: the post-processed
            0/1 mask, a dict of raw softmax outputs under 'gene', 'start'
            and 'end', and the mean gene probability over predicted
            positions (0.0 when nothing is predicted).
        """
        sequence = sequence.upper()
        # fullmatch, unlike match with '$', also rejects a trailing newline.
        if not re.fullmatch('[ACTGN]+', sequence):
            logging.warning("Sequence contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features)
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        predicted = predictions == 1
        # Built-in float so the value matches the annotation and serializes.
        confidence = float(np.mean(gene_probs[:, 1][predicted])) if np.any(predicted) else 0.0

        return predictions, {'gene': gene_probs, 'start': start_probs, 'end': end_probs}, confidence

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Describe every contiguous run of ones in ``predictions``.

        Returns:
            One dict per region with 'start' (inclusive), 'end' (exclusive),
            'sequence', 'length', 'start_codon', 'stop_codon' and 'in_frame'.
            All values are built-in Python types, so the list is JSON
            serializable.
        """
        regions = []
        # Zero padding makes +1 / -1 in the diff mark run starts and ends.
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            start, end = int(start), int(end)
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ('ATG', 'GTG', 'TTG'):
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # Walk backwards codon by codon from the final three bases.
                # Starting at len - 3 keeps every slice a full triplet; the
                # lower bound of 2 keeps the start codon out of the scan.
                for i in range(len(gene_seq) - 3, 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ('TAA', 'TAG', 'TGA'):
                        actual_stop_codon = codon
                        break

            regions.append({
                'start': start,
                'end': end,
                'sequence': gene_seq,  # Return full sequence
                'length': end - start,
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                # Plain bool: numpy.bool_ is rejected by json.dumps.
                'in_frame': bool((end - start) % 3 == 0)
            })

        return regions

    def compute_accuracy(self, predictions: np.ndarray, labels: List[int]) -> Dict:
        """Score ``predictions`` against ``labels`` position by position.

        The longer input is truncated to the shorter one. Returns accuracy,
        precision, recall and F1 as floats plus the confusion counts as
        ints; ratios with a zero denominator are reported as 0.0.
        """
        min_len = min(len(predictions), len(labels))
        predictions = np.asarray(predictions[:min_len])
        labels = np.array(labels[:min_len])

        # The mean of an empty array is NaN; report 0.0 instead.
        accuracy = float(np.mean(predictions == labels)) if min_len > 0 else 0.0
        true_pos = int(np.sum((predictions == 1) & (labels == 1)))
        false_neg = int(np.sum((predictions == 0) & (labels == 1)))
        false_pos = int(np.sum((predictions == 1) & (labels == 0)))

        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_pos,
            'false_positives': false_pos,
            'false_negatives': false_neg
        }

    def labels_from_coordinates(self, seq_len: int, start: int, end: int) -> List[int]:
        """Return ``seq_len`` labels with ones on ``[start, end)``; coordinates are clamped to the sequence."""
        labels = [0] * seq_len
        start = max(0, min(start, seq_len - 1))
        end = max(start, min(end, seq_len))
        labels[start:end] = [1] * (end - start)
        return labels