re-type committed on
Commit
22db390
·
verified ·
1 Parent(s): 74167c4

Update predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +375 -406
predictor.py CHANGED
@@ -1,414 +1,383 @@
1
- # -*- coding: utf-8 -*-
2
- """predictor.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1JURb-0j-R4LWK3oxeGrNxpJm3V6nnX02
8
- """
9
-
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
  import numpy as np
14
- from typing import List, Tuple, Dict, Optional
15
- import logging
16
  import re
17
-
18
- # Configure logging
19
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
-
21
- # ============================= MODEL COMPONENTS =============================
22
-
23
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    A shared trunk (multi-scale 1-D convolutions -> BiLSTM -> self-attention)
    feeds three sibling classifier heads, each producing per-position
    2-class logits for one task: gene body, start boundary, end boundary.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Parallel convolutions with growing receptive fields (3/7/15/31).
        # Each emits hidden_dim//4 channels so the concatenation is exactly
        # hidden_dim wide; odd kernels with padding=k//2 preserve length.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim//4, kernel_size=k, padding=k//2)
            for k in [3, 7, 15, 31]
        ])
        # Bidirectional LSTM: hidden_dim//2 per direction keeps output width
        # at hidden_dim.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, num_layers,
                           batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Three identical 2-class heads, one per task.
        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the trunk and all three heads.

        Args:
            x: feature tensor unpacked here as (batch, seq_len, input_dim).
            lengths: optional per-sequence valid lengths; when given, the
                LSTM consumes a packed sequence so padded positions are
                skipped.

        Returns:
            Dict with 'gene', 'start' and 'end' logit tensors, each
            shaped (batch, seq_len, 2).
        """
        batch_size, seq_len, _ = x.shape
        # Conv1d expects (batch, channels, seq); transpose in and back out.
        x_conv = x.transpose(1, 2)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)

        if lengths is not None:
            # Pack so the LSTM ignores padding; lengths must live on CPU.
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        # Self-attention (Q=K=V) over the sequence to pool boundary context.
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended)
        }
82
 
83
- # ============================= DATA PREPROCESSING =============================
84
-
85
class DNAProcessor:
    """DNA sequence processor with boundary-aware features.

    Converts a raw DNA string into the 14-channel per-base feature matrix
    the boundary-aware model consumes.
    """

    def __init__(self):
        # Vocabulary: the four canonical bases plus N for anything unknown.
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map a DNA string to a 1-D long tensor of base indices.

        Unrecognised characters fall back to the index of N.
        """
        fallback = self.base_to_idx['N']
        codes = [self.base_to_idx.get(base, fallback) for base in sequence.upper()]
        return torch.tensor(codes, dtype=torch.long)

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the per-base feature matrix of shape (len(sequence), 14).

        Channel layout: one-hot bases (5) + start-codon indicators (3)
        + stop-codon indicators (3) + windowed GC fraction (1)
        + sinusoidal position encoding (2).
        """
        sequence = sequence.upper()
        seq_len = len(sequence)
        encoded = self.encode_sequence(sequence)

        # One-hot encoding via identity-matrix row lookup.
        one_hot = torch.eye(5)[encoded]

        # Codon indicator channels. Start codons carry graded weights
        # (ATG strongest, then GTG, then TTG); stop codons are binary.
        start_channels = {'ATG': (0, 1.0), 'GTG': (1, 0.9), 'TTG': (2, 0.8)}
        stop_channels = {'TAA': 0, 'TAG': 1, 'TGA': 2}
        start_indicators = torch.zeros(seq_len, 3)
        stop_indicators = torch.zeros(seq_len, 3)
        for i in range(seq_len - 2):
            codon = sequence[i:i+3]
            if codon in start_channels:
                channel, weight = start_channels[codon]
                start_indicators[i:i+3, channel] = weight
            if codon in stop_channels:
                stop_indicators[i:i+3, stop_channels[codon]] = 1.0

        # Local GC fraction over a centred 50-base window (clipped at edges).
        window_size = 50
        half = window_size // 2
        gc_content = torch.zeros(seq_len, 1)
        for i in range(seq_len):
            lo, hi = max(0, i - half), min(seq_len, i + half)
            window = sequence[lo:hi]
            if window:
                gc_content[i, 0] = (window.count('G') + window.count('C')) / len(window)

        # Low-frequency sinusoidal position encoding.
        positions = torch.arange(seq_len, dtype=torch.float)
        pos_encoding = torch.stack((torch.sin(positions / 10000),
                                    torch.cos(positions / 10000)), dim=1)

        # 5 + 3 + 3 + 1 + 2 = 14 channels
        return torch.cat((one_hot, start_indicators, stop_indicators,
                          gc_content, pos_encoding), dim=1)
152
-
153
- # ============================= POST-PROCESSING =============================
154
-
155
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Converts per-position gene/start/end probabilities into a binary
    per-position gene mask, snapping region boundaries to nearby codons
    (when the sequence is available) and enforcing length/frame
    constraints.
    """

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Accepted gene length range, in nucleotides.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        gene_probs/start_probs/end_probs are (seq_len, 2) probability
        arrays; column 1 is the positive class. Returns a 0/1 array of
        length seq_len.
        """
        # More conservative thresholds: gene body needs 0.6, start 0.4, end 0.5.
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        if sequence is not None:
            # Codon-aware refinement when the raw sequence is available.
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)

        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted region to the best-scoring nearby start/stop codon."""
        # NOTE(review): this initial copy is immediately overwritten by
        # np.zeros_like below, so it is effectively dead.
        refined = gene_pred.copy()
        sequence = sequence.upper()

        # Index every start codon (by its first base) and every stop codon
        # (by the position just past its last base).
        start_codon_positions = []
        stop_codon_positions = []

        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)

        # Edges of the 0->1 / 1->0 transitions give region [start, end) pairs.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        refined = np.zeros_like(gene_pred)

        for g_start, g_end in zip(gene_starts, gene_ends):
            best_start = g_start
            start_window = 100  # Increased from 50
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]

            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        # Weight ATG > GTG > TTG, combined with the model's
                        # start-boundary signal minus a distance penalty.
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        distance_penalty = abs(pos - g_start) / start_window * 0.2  # Add distance penalty
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))

                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            best_end = g_end
            end_window = 100
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]

            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Reward in-frame ends; gently penalise lengths
                            # far from 1000 nt.
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))

                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            # Keep the region only if the snapped boundaries form a valid span.
            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                gene_length <= self.max_gene_length and
                best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Sequence-free fallback: snap regions to the nearest boundary hits."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Look for a start-boundary hit within +/-30 of the region start.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # NOTE(review): the else-branch writes the slice back onto
                # itself (a no-op); only the new_start > g_start case clears.
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start

            # Look for an end-boundary hit within +/-50 of the region end.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    # Same self-assignment no-op pattern as above.
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]

        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: str = None) -> np.ndarray:
        """Drop out-of-range regions; trim to a multiple of 3 when a sequence is given."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Trim the tail to restore the reading frame, or drop the
                    # region entirely if trimming makes it too short.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0

        return processed
308
-
309
- # ============================= PREDICTION =============================
310
-
311
class GenePredictor:
    """Handles gene prediction using the trained boundary-aware model.

    Loads model weights at construction time (raising on failure), then
    exposes prediction, region extraction, and evaluation helpers.
    """

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            # Construction fails loudly if weights cannot be loaded.
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()

    def predict(self, sequence: str) -> Tuple[np.ndarray, Dict[str, np.ndarray], float]:
        """Predict gene regions for one sequence.

        Returns (binary per-position predictions, dict of per-task
        softmax probability arrays, mean gene probability over predicted
        positions — 0.0 when nothing is predicted).
        """
        sequence = sequence.upper()
        # Sanitize: anything outside ACTGN becomes N.
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning("Sequence contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features)
            # [0] drops the singleton batch dimension.
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        # Confidence = mean positive-class probability over predicted positions.
        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0

        return predictions, {'gene': gene_probs, 'start': start_probs, 'end': end_probs}, confidence

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Turn the binary mask into a list of region dicts with codon info."""
        regions = []
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # Scan backwards in codon steps for the last in-frame stop.
                for i in range(len(gene_seq) - 2, 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break

            regions.append({
                'start': int(start),  # Convert to Python int for JSON serialization
                'end': int(end),
                'sequence': gene_seq,  # Return full sequence
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                'in_frame': (end - start) % 3 == 0
            })

        return regions

    def compute_accuracy(self, predictions: np.ndarray, labels: List[int]) -> Dict:
        """Position-wise accuracy/precision/recall/F1 against 0/1 labels.

        Arrays are truncated to the shorter of the two lengths before
        comparison.
        """
        min_len = min(len(predictions), len(labels))
        predictions = predictions[:min_len]
        labels = np.array(labels[:min_len])

        accuracy = np.mean(predictions == labels)
        true_pos = np.sum((predictions == 1) & (labels == 1))
        false_neg = np.sum((predictions == 0) & (labels == 1))
        false_pos = np.sum((predictions == 1) & (labels == 0))

        # Guard each ratio against a zero denominator.
        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': int(true_pos),
            'false_positives': int(false_pos),
            'false_negatives': int(false_neg)
        }

    def labels_from_coordinates(self, seq_len: int, start: int, end: int) -> List[int]:
        """Build a 0/1 label list marking [start, end), clamped to the sequence."""
        labels = [0] * seq_len
        start = max(0, min(start, seq_len - 1))
        end = max(start, min(end, seq_len))
        for i in range(start, end):
            labels[i] = 1
        return labels
 
1
+ # Improved F Gene Prediction Functions
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
 
 
3
  import re
4
+ import logging
5
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
6
+
7
def preprocess_sequence_for_ndv_f_gene(sequence):
    """Normalise a raw nucleotide string for NDV F gene analysis.

    Uppercases and strips the input, removes every character that is not
    A/T/C/G/N, re-anchors at the first ATG when the sequence does not
    already begin with one, and trims the tail so the length is a
    multiple of three. Best effort: on an unexpected error the current
    value of the sequence is returned unchanged.
    """
    try:
        sequence = sequence.upper().strip()
        sequence = re.sub(r'[^ATCGN]', '', sequence)

        # NDV F gene is typically ~1662-1800 nt; only warn on short input.
        if len(sequence) < 1000:
            logging.warning(f"Sequence length ({len(sequence)}) shorter than typical NDV F gene (1662-1800 nt)")

        # Re-anchor at the first ATG when the sequence does not start with one.
        if not sequence.startswith('ATG'):
            logging.warning("Sequence doesn't start with ATG start codon")
            anchor = sequence.find('ATG')
            if anchor != -1:
                sequence = sequence[anchor:]
                logging.info(f"Found ATG at position {anchor}, using sequence from there")

        # Drop trailing bases so the length is a multiple of three.
        overhang = len(sequence) % 3
        if overhang:
            sequence = sequence[:-overhang]
            logging.info(f"Trimmed sequence to maintain reading frame: {len(sequence)} nt")

        # NDV F protein landmarks (fusion peptide ~117-137, heptad repeats)
        # are noted here but not checked in this step.

        return sequence

    except Exception as e:
        logging.error(f"Sequence preprocessing failed: {e}")
        return sequence
46
+
47
def enhanced_keras_prediction(sequence, keras_model, kmer_to_index, kmer_size=6):
    """Score a sequence with the k-mer based Keras model.

    Returns a metrics dict on success; otherwise a human-readable error
    string (model missing, sequence too short, or prediction failure).
    """
    try:
        # Both the model and its k-mer vocabulary are required.
        if not keras_model or not kmer_to_index:
            return "Keras model not available"

        processed_seq = preprocess_sequence_for_ndv_f_gene(sequence)

        if len(processed_seq) < kmer_size:
            return f"Sequence too short for k-mer prediction (minimum {kmer_size} nucleotides required)"

        # Slide a kmer_size window across the whole sequence.
        kmers = [processed_seq[i:i + kmer_size]
                 for i in range(len(processed_seq) - kmer_size + 1)]

        # Vocabulary lookup; out-of-vocabulary k-mers map to index 0.
        indices = [kmer_to_index.get(kmer, 0) for kmer in kmers]
        unknown_kmers = sum(1 for kmer in kmers if kmer not in kmer_to_index)

        logging.info(f"Generated {len(kmers)} k-mers, {unknown_kmers} unknown k-mers")

        # Single-sample batch for the model.
        prediction = keras_model.predict(np.array([indices]), verbose=0)[0]

        # Summary statistics over the model's output vector.
        max_prob = np.max(prediction)
        mean_prob = np.mean(prediction)
        confidence_score = max_prob
        consistency_score = 1.0 - np.std(prediction)  # Lower std = more consistent

        return {
            'raw_prediction': prediction.tolist(),
            'max_probability': float(max_prob),
            'mean_probability': float(mean_prob),
            'confidence_score': float(confidence_score),
            'consistency_score': float(consistency_score),
            'sequence_length': len(processed_seq),
            'kmers_generated': len(kmers),
            'unknown_kmers': unknown_kmers,
            'kmer_coverage': 1.0 - (unknown_kmers / len(kmers)) if kmers else 0.0
        }

    except Exception as e:
        logging.error(f"Enhanced Keras prediction failed: {e}")
        return f"Enhanced Keras prediction failed: {str(e)}"
106
+
107
def enhanced_classify_sequence(sequence, classifier_model, classifier_kmer_to_index, classifier_maxlen, labels):
    """Enhanced classification with NDV F gene specific improvements.

    Encodes the (preprocessed) sequence as 6-mer indices, pads to the
    classifier's fixed input length, predicts a label, then applies
    F-gene-specific heuristics. Always returns a dict of the shape:

        {"status": "success" | "warning" | "error",
         "message": str,
         "confidence": float | None,
         "predicted_label": str | None,
         "details": dict}
    """
    try:
        # All three classifier artefacts are required.
        if not classifier_model or not classifier_kmer_to_index or classifier_maxlen is None:
            return {
                "status": "error",
                "message": "Classification model not available",
                "confidence": None,
                "predicted_label": None,
                "details": {}
            }

        # Preprocess sequence (cleanup, ATG anchoring, frame trim).
        processed_seq = preprocess_sequence_for_ndv_f_gene(sequence)

        # NDV F gene specific length check — bail early on short input.
        if len(processed_seq) < 1000:
            return {
                "status": "warning",
                "message": f"Sequence shorter than typical NDV F gene ({len(processed_seq)} < 1000 nt)",
                "confidence": None,
                "predicted_label": None,
                "details": {"sequence_length": len(processed_seq)}
            }

        # Generate overlapping k-mers (6-mers).
        kmer_size = 6
        tokens = [processed_seq[i:i+kmer_size] for i in range(len(processed_seq)-kmer_size+1)]

        # Encode k-mers; out-of-vocabulary k-mers map to index 0 and are counted.
        encoded = []
        unknown_count = 0
        for kmer in tokens:
            if kmer in classifier_kmer_to_index:
                encoded.append(classifier_kmer_to_index[kmer])
            else:
                encoded.append(0)  # Unknown k-mer
                unknown_count += 1

        # Pad to the classifier's fixed input length.
        padded = pad_sequences([encoded], maxlen=classifier_maxlen, padding='post')

        # Predict and pick the arg-max class.
        pred = classifier_model.predict(padded, verbose=0)
        predicted_class = int(np.argmax(pred))
        confidence = float(np.max(pred))
        predicted_label = labels[predicted_class] if predicted_class < len(labels) else "Unknown"

        # Additional quality metrics.
        kmer_coverage = 1.0 - (unknown_count / len(tokens)) if tokens else 0.0
        prediction_entropy = -np.sum(pred[0] * np.log(pred[0] + 1e-10))  # Lower entropy = more confident

        details = {
            "sequence_length": len(processed_seq),
            "kmers_generated": len(tokens),
            "unknown_kmers": unknown_count,
            "kmer_coverage": kmer_coverage,
            "prediction_entropy": float(prediction_entropy),
            "all_probabilities": {labels[i]: float(pred[0][i]) for i in range(len(labels)) if i < len(pred[0])},
            "starts_with_atg": processed_seq.startswith('ATG'),
            "length_in_frame": len(processed_seq) % 3 == 0
        }

        # Enhanced decision logic for NDV F gene.
        if predicted_label == "F":
            # Multiply bonuses onto the raw confidence; the boosted score is
            # capped at 1.0 only in the reported details, not in the branching.
            f_gene_score = confidence

            # Bonus for good k-mer coverage.
            if kmer_coverage > 0.8:
                f_gene_score *= 1.1

            # Bonus for proper start codon.
            if processed_seq.startswith('ATG'):
                f_gene_score *= 1.05

            # Bonus for proper reading frame.
            if len(processed_seq) % 3 == 0:
                f_gene_score *= 1.05

            # Bonus for appropriate length (NDV F gene is ~1662-1800 nt).
            if 1500 <= len(processed_seq) <= 2000:
                f_gene_score *= 1.1

            details["enhanced_f_score"] = min(f_gene_score, 1.0)

            if f_gene_score > 0.7:
                return {
                    "status": "success",
                    "message": "NDV F gene detected with high confidence",
                    "confidence": confidence,
                    "predicted_label": predicted_label,
                    "details": details
                }
            elif f_gene_score > 0.5:
                return {
                    "status": "success",
                    "message": "NDV F gene detected with moderate confidence",
                    "confidence": confidence,
                    "predicted_label": predicted_label,
                    "details": details
                }
            else:
                return {
                    "status": "warning",
                    "message": "Possible F gene but low confidence - check sequence quality",
                    "confidence": confidence,
                    "predicted_label": predicted_label,
                    "details": details
                }

        elif predicted_label == "Random":
            # "Random" with poor coverage is reported as a quality problem
            # rather than a definitive negative.
            if kmer_coverage < 0.5:
                return {
                    "status": "error",
                    "message": f"Poor sequence quality detected (coverage: {kmer_coverage:.1%}). Check for sequencing errors.",
                    "confidence": confidence,
                    "predicted_label": predicted_label,
                    "details": details
                }
            else:
                return {
                    "status": "error",
                    "message": "Sequence does not appear to be NDV F gene. Verify input sequence.",
                    "confidence": confidence,
                    "predicted_label": predicted_label,
                    "details": details
                }

        else:
            # Some other gene label was predicted.
            return {
                "status": "error",
                "message": f"Detected as {predicted_label} gene, not F gene. Please provide NDV F gene sequence.",
                "confidence": confidence,
                "predicted_label": predicted_label,
                "details": details
            }

    except Exception as e:
        logging.error(f"Enhanced classification failed: {e}")
        return {
            "status": "error",
            "message": f"Classification failed: {str(e)}",
            "confidence": None,
            "predicted_label": None,
            "details": {"error": str(e)}
        }
256
 
257
def validate_ndv_f_gene_sequence(sequence):
    """Additional validation specific to NDV F gene characteristics.

    Checks length, start/stop codons, reading frame, GC content and
    ambiguous-base (N) content against typical NDV F gene expectations.

    Parameters
    ----------
    sequence : str
        Nucleotide sequence (expected uppercase A/T/C/G/N).

    Returns
    -------
    tuple[list[str], list[str]]
        (issues, suggestions) — human-readable validation findings;
        both empty when nothing is flagged.
    """
    issues = []
    suggestions = []

    # Length check — NDV F gene is typically 1662-1800 nt.
    if len(sequence) < 1500:
        issues.append(f"Sequence length ({len(sequence)}) shorter than typical NDV F gene (1662-1800 nt)")
        suggestions.append("Verify complete F gene sequence was provided")
    elif len(sequence) > 2000:
        issues.append(f"Sequence length ({len(sequence)}) longer than typical NDV F gene")
        suggestions.append("Check if sequence contains additional regions beyond F gene")

    # Start codon check.
    if not sequence.startswith('ATG'):
        issues.append("Sequence doesn't start with ATG start codon")
        suggestions.append("Ensure sequence starts from the translation start site")

    # Reading frame check.
    if len(sequence) % 3 != 0:
        issues.append("Sequence length not divisible by 3 (reading frame issue)")
        suggestions.append("Check for insertions/deletions or trim to proper reading frame")

    # Stop codon check.
    if len(sequence) >= 3:
        last_codon = sequence[-3:]
        stop_codons = ['TAA', 'TAG', 'TGA']
        if last_codon not in stop_codons:
            issues.append(f"Sequence doesn't end with stop codon (ends with {last_codon})")
            suggestions.append("Verify complete F gene sequence including stop codon")

    # Composition checks need a non-empty sequence — guard the divisions
    # (the previous version raised ZeroDivisionError on "").
    if sequence:
        # Nucleotide composition check.
        gc_content = (sequence.count('G') + sequence.count('C')) / len(sequence) * 100
        if gc_content < 30 or gc_content > 70:
            issues.append(f"Unusual GC content: {gc_content:.1f}% (typical range: 35-65%)")
            suggestions.append("Verify sequence quality and correct nucleotide composition")

        # Check for too many N's (ambiguous nucleotides).
        n_content = sequence.count('N') / len(sequence) * 100
        if n_content > 5:
            issues.append(f"High ambiguous nucleotide content: {n_content:.1f}% N's")
            suggestions.append("Consider resequencing regions with ambiguous nucleotides")

    return issues, suggestions
301
+
302
+ # Updated run_pipeline function with enhanced predictions
303
def enhanced_run_pipeline(dna_input, keras_model, kmer_to_index, classifier_model,
                         classifier_kmer_to_index, classifier_maxlen, labels,
                         similarity_score=95.0, build_ml_tree=False):
    """Enhanced pipeline with improved F gene prediction.

    Sanitises the input, validates it against NDV F gene expectations,
    runs the k-mer Keras predictor and the classifier, and formats the
    findings into a fixed 13-tuple:

        (boundary_output, keras_output, classifier_status,
         classifier_message, classifier_label, confidence_str,
         ml_tree_status, analysis_status, html_summary,
         None, None, None, final_message)

    NOTE(review): the 13-slot ordering presumably mirrors a UI's output
    components — confirm against the caller before changing it.
    NOTE(review): `similarity_score` is never used in this function;
    `build_ml_tree` only toggles a status string.
    """
    try:
        # Input validation and preprocessing.
        dna_input = dna_input.upper().strip()
        if not dna_input:
            return "Empty input", "", "", "", "", "", "", "", "", None, None, None, "No input provided"

        # Clean sequence: anything outside ACTGN becomes N.
        if not re.match('^[ACTGN]+$', dna_input):
            dna_input = ''.join(c if c in 'ACTGN' else 'N' for c in dna_input)
            logging.info("DNA sequence sanitized")

        # Validate NDV F gene characteristics.
        validation_issues, validation_suggestions = validate_ndv_f_gene_sequence(dna_input)

        # Step 1: Enhanced Keras Prediction. A dict means success; any other
        # return value is an error string and is shown as-is.
        keras_result = enhanced_keras_prediction(dna_input, keras_model, kmer_to_index)
        if isinstance(keras_result, dict):
            keras_output = f"Prediction confidence: {keras_result['confidence_score']:.3f}\n"
            keras_output += f"K-mer coverage: {keras_result['kmer_coverage']:.1%}\n"
            keras_output += f"Sequence length: {keras_result['sequence_length']} nt"
            if keras_result['kmer_coverage'] < 0.8:
                keras_output += "\n⚠️ Low k-mer coverage - may affect accuracy"
        else:
            keras_output = str(keras_result)

        # Step 2: Enhanced Classification.
        classifier_result = enhanced_classify_sequence(
            dna_input, classifier_model, classifier_kmer_to_index, classifier_maxlen, labels
        )

        classifier_status = classifier_result["status"]
        classifier_message = classifier_result["message"]
        classifier_label = classifier_result["predicted_label"]
        classifier_confidence = classifier_result["confidence"]

        # Add validation feedback (at most 3 issues / 3 suggestions shown).
        if validation_issues:
            classifier_message += f"\n\n⚠️ Sequence validation issues:\n" + "\n".join(f"• {issue}" for issue in validation_issues[:3])
            if validation_suggestions:
                classifier_message += f"\n\n💡 Suggestions:\n" + "\n".join(f"• {sug}" for sug in validation_suggestions[:3])

        # Enhanced confidence reporting: top-3 class probabilities, if present.
        if classifier_result.get("details"):
            details = classifier_result["details"]
            if "all_probabilities" in details:
                probs = details["all_probabilities"]
                classifier_message += f"\n\nPrediction probabilities:"
                for label, prob in sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]:
                    classifier_message += f"\n• {label}: {prob:.1%}"

        # Return enhanced results.
        boundary_output = f"Enhanced preprocessing applied. Length: {len(dna_input)} bp"
        if validation_issues:
            boundary_output += f"\n{len(validation_issues)} validation issues detected"

        # NOTE(review): a falsy confidence (0.0 or None) renders as "N/A".
        return (
            boundary_output,
            keras_output,
            classifier_status,
            classifier_message,
            classifier_label or "Unknown",
            f"{classifier_confidence:.3f}" if classifier_confidence else "N/A",
            "ML tree not requested" if not build_ml_tree else "ML tree processing...",
            "Enhanced analysis completed",
            "<p>Enhanced F gene analysis completed</p>",
            None, None, None,
            f"Enhanced pipeline completed. Processed {len(dna_input)} bp sequence."
        )

    except Exception as e:
        # Error path preserves the same 13-slot tuple shape.
        error_msg = f"Enhanced pipeline failed: {str(e)}"
        logging.error(error_msg)
        return (
            error_msg, "", "error", error_msg, "Error", "0.000",
            "", "", f"<p>Error: {error_msg}</p>",
            None, None, None, error_msg
        )