re-type committed on
Commit
2c6a591
·
verified ·
1 Parent(s): 2438908

Delete predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +0 -628
predictor.py DELETED
@@ -1,628 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import numpy as np
5
- from typing import List, Tuple, Dict, Optional, Union
6
- import logging
7
- import re
8
- import os
9
- from pathlib import Path
10
-
11
- # Configure logging
12
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
-
14
- # ============================= FILE READERS =============================
15
-
16
class FileReader:
    """Handles reading DNA sequences from various file formats.

    Stateless helper: every method is a @staticmethod, so callers may use
    either an instance or the class directly.
    """

    @staticmethod
    def read_fasta(file_path: str) -> Dict[str, str]:
        """Read a FASTA file and return a dict of sequence_id -> sequence.

        The id is the full header line minus the leading '>'; sequence
        lines are concatenated with spaces/tabs stripped. Lines before the
        first header are ignored. I/O or decoding errors are logged and
        re-raised.
        """
        sequences = {}
        current_id = None
        current_seq = []

        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                for line in file:
                    line = line.strip()
                    if line.startswith('>'):
                        # Flush the previous record before starting a new one.
                        if current_id is not None:
                            sequences[current_id] = ''.join(current_seq)
                        current_id = line[1:]  # Remove '>' character
                        current_seq = []
                    elif line and current_id is not None:
                        # Accumulate sequence data, dropping internal whitespace.
                        current_seq.append(line.replace(' ', '').replace('\t', ''))

            # Flush the final record (no trailing '>' to trigger it).
            if current_id is not None:
                sequences[current_id] = ''.join(current_seq)

        except Exception as e:
            logging.error(f"Error reading FASTA file {file_path}: {e}")
            raise

        return sequences

    @staticmethod
    def read_txt(file_path: str) -> str:
        """Read a plain-text file and return its DNA content.

        Everything that is not A/C/T/G/N (case-insensitive) — whitespace,
        newlines, digits, punctuation — is silently dropped; the result is
        upper-cased. Errors are logged and re-raised.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read().strip()
            # Keep only valid DNA letters, normalised to upper case.
            sequence = ''.join(c.upper() for c in content if c.upper() in 'ACTGN')
            return sequence
        except Exception as e:
            logging.error(f"Error reading TXT file {file_path}: {e}")
            raise

    @staticmethod
    def detect_file_type(file_path: str) -> str:
        """Classify a file as 'fasta' or 'txt'.

        Extension wins when recognised; otherwise the first line is
        sniffed ('>' marks FASTA). Unreadable files default to 'txt'
        with a warning.
        """
        file_path = Path(file_path)
        extension = file_path.suffix.lower()

        if extension in ['.fasta', '.fa', '.fas', '.fna']:
            return 'fasta'
        elif extension in ['.txt', '.seq']:
            return 'txt'
        else:
            # Unknown extension: sniff the first line. Catch only
            # file-access/decoding errors — the previous bare `except:`
            # also swallowed KeyboardInterrupt and SystemExit.
            try:
                with open(file_path, 'r', encoding='utf-8') as file:
                    first_line = file.readline().strip()
                if first_line.startswith('>'):
                    return 'fasta'
                else:
                    return 'txt'
            except (OSError, UnicodeDecodeError):
                logging.warning(f"Could not detect file type for {file_path}, assuming txt")
                return 'txt'
92
-
93
- # ============================= ORIGINAL MODEL COMPONENTS =============================
94
- # (Including all the original classes: BoundaryAwareGenePredictor, DNAProcessor, EnhancedPostProcessor)
95
-
96
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    A shared trunk (multi-scale Conv1d -> BiLSTM -> self-attention) feeds
    three structurally identical per-position classifier heads, one per
    task ('gene', 'start', 'end'), each emitting 2-class logits.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Multi-scale convolutions: four kernel sizes, each producing
        # hidden_dim//4 channels, so their concatenation is hidden_dim
        # wide. padding=k//2 keeps sequence length unchanged (k is odd).
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim//4, kernel_size=k, padding=k//2)
            for k in [3, 7, 15, 31]
        ])
        # Bidirectional LSTM: per-direction size hidden_dim//2 keeps the
        # concatenated output width at hidden_dim.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Self-attention over LSTM states; intended to sharpen boundary
        # context (query = key = value = LSTM output).
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Three identical 2-layer MLP heads; kept separate so each task
        # learns its own decision surface.
        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the shared trunk and the three heads.

        Args:
            x: (batch, seq_len, input_dim) feature tensor.
            lengths: optional per-sequence valid lengths; when given, the
                LSTM runs on a packed sequence so padding is skipped.

        Returns:
            Dict with keys 'gene', 'start', 'end'; each value is a
            (batch, seq_len, 2) logits tensor (unnormalised — callers
            apply softmax).
        """
        batch_size, seq_len, _ = x.shape
        # Conv1d expects (batch, channels, seq); transpose in, then back.
        x_conv = x.transpose(1, 2)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)

        if lengths is not None:
            # Pack so the LSTM ignores padded positions.
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended)
        }
155
-
156
class DNAProcessor:
    """DNA sequence processor with boundary-aware features.

    Produces the 14-channel per-position feature matrix consumed by
    BoundaryAwareGenePredictor.
    """

    def __init__(self):
        # Fixed 5-letter alphabet; anything outside it maps to 'N'.
        self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
        self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def encode_sequence(self, sequence: str) -> torch.Tensor:
        """Map a DNA string to a LongTensor of alphabet indices.

        Input is upper-cased first; unknown characters become 'N' (4).
        """
        sequence = sequence.upper()
        unknown = self.base_to_idx['N']
        encoded = [self.base_to_idx.get(base, unknown) for base in sequence]
        return torch.tensor(encoded, dtype=torch.long)

    def create_enhanced_features(self, sequence: str) -> torch.Tensor:
        """Build the (seq_len, 14) per-position feature matrix.

        Channel layout:
            0-4   one-hot base (A, C, G, T, N)
            5-7   weighted start-codon indicators (ATG=1.0, GTG=0.9, TTG=0.8),
                  painted across all 3 bases of the codon
            8-10  binary stop-codon indicators (TAA, TAG, TGA), likewise
            11    GC fraction over a centred 50-bp window
            12-13 single-frequency sin/cos position encoding
        """
        sequence = sequence.upper()
        seq_len = len(sequence)
        encoded = self.encode_sequence(sequence)

        # One-hot encoding of the 5-letter alphabet.
        one_hot = torch.zeros(seq_len, 5)
        one_hot.scatter_(1, encoded.unsqueeze(1), 1)
        features = [one_hot]

        # Single pass over codons: a trinucleotide cannot be both a start
        # and a stop, so one scan fills both indicator tracks.
        start_indicators = torch.zeros(seq_len, 3)
        stop_indicators = torch.zeros(seq_len, 3)
        start_weights = {'ATG': (0, 1.0), 'GTG': (1, 0.9), 'TTG': (2, 0.8)}
        stop_channels = {'TAA': 0, 'TAG': 1, 'TGA': 2}
        for i in range(seq_len - 2):
            codon = sequence[i:i+3]
            if codon in start_weights:
                channel, weight = start_weights[codon]
                start_indicators[i:i+3, channel] = weight
            elif codon in stop_channels:
                stop_indicators[i:i+3, stop_channels[codon]] = 1.0
        features.append(start_indicators)
        features.append(stop_indicators)

        # Windowed GC fraction via prefix sums: O(n) instead of rescanning
        # a 50-bp window at every position (O(n * window)). Values are
        # identical to the naive version: window = [i-25, i+25) clamped.
        window_size = 50
        half = window_size // 2
        gc_prefix = [0] * (seq_len + 1)
        for i, base in enumerate(sequence):
            gc_prefix[i + 1] = gc_prefix[i] + (1 if base in 'GC' else 0)
        gc_content = torch.zeros(seq_len, 1)
        for i in range(seq_len):
            start = max(0, i - half)
            end = min(seq_len, i + half)
            width = end - start
            gc_content[i, 0] = (gc_prefix[end] - gc_prefix[start]) / width if width > 0 else 0
        features.append(gc_content)

        # Coarse absolute-position signal (one sin/cos pair, period 2*pi*10000).
        pos_encoding = torch.zeros(seq_len, 2)
        positions = torch.arange(seq_len, dtype=torch.float)
        pos_encoding[:, 0] = torch.sin(positions / 10000)
        pos_encoding[:, 1] = torch.cos(positions / 10000)
        features.append(pos_encoding)

        return torch.cat(features, dim=1)  # 5 + 3 + 3 + 1 + 2 = 14
223
-
224
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Turns per-position class probabilities into a binary gene track,
    snapping region edges to nearby start/stop codons when the raw
    sequence is available, then enforcing length/frame constraints.
    """

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        Args:
            gene_probs/start_probs/end_probs: (seq_len, 2) probability
                arrays; column 1 is the positive class.
            sequence: raw DNA string; when given, codon positions are used
                to refine boundaries.

        Returns:
            Binary (seq_len,) array: 1 inside predicted genes.
        """
        # Asymmetric thresholds: the gene track is conservative (0.6),
        # starts are permissive (0.4), ends in between (0.5).
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)
        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted gene region to the best-scoring nearby
        start codon and in-range stop codon; drop regions that end up
        outside the allowed length range."""
        refined = gene_pred.copy()
        sequence = sequence.upper()

        start_codon_positions = []
        stop_codon_positions = []

        # Index every start/stop codon once up front. Stop positions are
        # stored as the index just past the codon (i + 3), i.e. the
        # exclusive end of a gene ending there.
        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)

        # Run-length boundaries of the binary track: +1 where a region
        # starts, -1 one past where it ends.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        # The copy above is discarded here; output is rebuilt from scratch.
        refined = np.zeros_like(gene_pred)

        for g_start, g_end in zip(gene_starts, gene_ends):
            best_start = g_start
            start_window = 100
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]

            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        # ATG is canonical; GTG/TTG are down-weighted.
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        # Linear penalty, up to 0.2 at the window edge.
                        distance_penalty = abs(pos - g_start) / start_window * 0.2
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))

                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            best_end = g_end
            end_window = 100
            # Candidate stops: strictly after the region start, at most
            # end_window past the raw region end.
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]

            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            # Bonus when the stop is in frame with the
                            # chosen start; soft preference for ~1 kb genes.
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))

                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                gene_length <= self.max_gene_length and
                best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Sequence-free fallback: snap region edges to the nearest
        positive position in the start/end prediction tracks."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Search +/-30 around the raw start for predicted-start hits.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                # Convert window-relative indices to absolute positions.
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # No-op when new_start <= g_start (empty slice / self-assign).
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start

            # Search +/-50 around the raw end, keeping only candidates
            # that leave the gene inside the allowed length range.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    # No-op when new_end >= g_end (empty slice / self-assign).
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]

        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: str = None) -> np.ndarray:
        """Final filter: drop regions outside [min, max] length and, when
        the sequence is known, trim each region down to a multiple of 3
        (dropping it entirely if trimming would go below the minimum)."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Trim the tail so the region is a whole number of codons.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0

        return processed
374
-
375
- # ============================= ENHANCED GENE PREDICTOR =============================
376
-
377
class EnhancedGenePredictor:
    """Enhanced Gene Predictor with file input support.

    Wires together model inference (BoundaryAwareGenePredictor), feature
    construction (DNAProcessor), prediction cleanup
    (EnhancedPostProcessor) and file reading (FileReader).
    """

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        # NOTE(review): the device default expression is evaluated once at
        # import time, not per instantiation — confirm that is intended.
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()
        self.file_reader = FileReader()

    def predict_from_file(self, file_path: str) -> Dict[str, Dict]:
        """
        Predict genes from a file (.txt or .fasta)
        Returns a dictionary with sequence_id as keys and prediction results as values

        Raises:
            FileNotFoundError: if file_path does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"File not found: {file_path}")

        file_type = self.file_reader.detect_file_type(file_path)
        logging.info(f"Detected file type: {file_type}")

        results = {}

        if file_type == 'fasta':
            # One prediction record per FASTA entry, keyed by header.
            sequences = self.file_reader.read_fasta(file_path)
            for seq_id, sequence in sequences.items():
                logging.info(f"Processing sequence: {seq_id} (length: {len(sequence)})")
                result = self.predict_sequence(sequence, seq_id)
                results[seq_id] = result
        else:  # txt file
            sequence = self.file_reader.read_txt(file_path)
            seq_id = Path(file_path).stem  # Use filename as sequence ID
            logging.info(f"Processing sequence from {file_path} (length: {len(sequence)})")
            result = self.predict_sequence(sequence, seq_id)
            results[seq_id] = result

        return results

    def predict_sequence(self, sequence: str, seq_id: str = "sequence") -> Dict:
        """
        Predict genes from a single DNA sequence string

        Returns a dict with keys: sequence_id, sequence_length,
        predictions (binary list), probabilities (per-task softmax),
        confidence, gene_regions, total_genes_found.
        """
        sequence = sequence.upper()
        # Sanitise: any character outside ACTGN becomes 'N'.
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning(f"Sequence {seq_id} contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        # Handle very long sequences by chunking if needed
        max_chunk_size = 50000  # Adjust based on your GPU memory
        if len(sequence) > max_chunk_size:
            return self._predict_long_sequence(sequence, seq_id, max_chunk_size)

        # (1, seq_len, 14) feature batch on the model's device.
        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features)
            # Drop the batch dim; arrays are (seq_len, 2) probabilities.
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        # Mean gene-class probability over positions kept as gene.
        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0

        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def _predict_long_sequence(self, sequence: str, seq_id: str, chunk_size: int) -> Dict:
        """
        Handle very long sequences by processing in chunks with overlap

        Chunks start every (chunk_size - overlap) positions; for every
        chunk after the first, the first `overlap` positions are dropped,
        which lines each kept span up exactly after the previous chunk's
        end — positions are never duplicated or skipped.
        """
        overlap = 1000  # Overlap between chunks to avoid missing genes at boundaries
        all_predictions = []
        all_gene_probs = []
        all_start_probs = []
        all_end_probs = []

        for i in range(0, len(sequence), chunk_size - overlap):
            end_pos = min(i + chunk_size, len(sequence))
            chunk = sequence[i:end_pos]

            logging.info(f"Processing chunk {i//chunk_size + 1} of sequence {seq_id}")

            features = self.processor.create_enhanced_features(chunk).unsqueeze(0).to(self.device)

            with torch.no_grad():
                outputs = self.model(features)
                gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
                start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
                end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

            # Post-processing sees only the chunk, so genes straddling a
            # chunk boundary rely on the overlap to be caught whole.
            chunk_predictions = self.post_processor.process_predictions(
                gene_probs, start_probs, end_probs, chunk
            )

            # Handle overlaps by taking the first chunk fully and subsequent chunks without overlap
            if i == 0:
                all_predictions.extend(chunk_predictions)
                all_gene_probs.extend(gene_probs)
                all_start_probs.extend(start_probs)
                all_end_probs.extend(end_probs)
            else:
                # Skip overlap region
                skip = min(overlap, len(chunk_predictions))
                all_predictions.extend(chunk_predictions[skip:])
                all_gene_probs.extend(gene_probs[skip:])
                all_start_probs.extend(start_probs[skip:])
                all_end_probs.extend(end_probs[skip:])

        predictions = np.array(all_predictions)
        gene_probs = np.array(all_gene_probs)
        start_probs = np.array(all_start_probs)
        end_probs = np.array(all_end_probs)

        confidence = np.mean(gene_probs[:, 1][predictions == 1]) if np.any(predictions == 1) else 0.0
        gene_regions = self.extract_gene_regions(predictions, sequence)

        return {
            'sequence_id': seq_id,
            'sequence_length': len(sequence),
            'predictions': predictions.tolist(),
            'probabilities': {
                'gene': gene_probs.tolist(),
                'start': start_probs.tolist(),
                'end': end_probs.tolist()
            },
            'confidence': float(confidence),
            'gene_regions': gene_regions,
            'total_genes_found': len(gene_regions)
        }

    def predict_from_text(self, sequence: str) -> Dict:
        """
        Predict genes from a text string (backward compatibility)
        """
        return self.predict_sequence(sequence)

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Extract gene regions from predictions

        Each region dict records start/end (half-open), the subsequence,
        length, detected start/stop codon (or None), and whether the
        length is a multiple of 3.
        """
        regions = []
        # Run-length boundaries of the binary track.
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # Walk backwards in steps of 3 from the tail looking for a
                # stop codon. NOTE(review): this frame is anchored at the
                # region's end, so it matches the start's reading frame
                # only when the region length is a multiple of 3 — confirm.
                for i in range(len(gene_seq) - 2, 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break

            regions.append({
                'start': int(start),
                'end': int(end),
                'sequence': gene_seq,
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                'in_frame': (end - start) % 3 == 0
            })

        return regions

    def save_results(self, results: Dict[str, Dict], output_path: str, format: str = 'json'):
        """
        Save prediction results to file

        format='json' dumps the full results dict; format='csv' writes one
        row per gene region with the parent sequence's confidence.
        """
        import json

        if format.lower() == 'json':
            with open(output_path, 'w') as f:
                json.dump(results, f, indent=2)
        elif format.lower() == 'csv':
            import csv
            with open(output_path, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['sequence_id', 'gene_start', 'gene_end', 'gene_length',
                                 'start_codon', 'stop_codon', 'in_frame', 'confidence'])

                for seq_id, result in results.items():
                    for gene in result['gene_regions']:
                        writer.writerow([
                            seq_id, gene['start'], gene['end'], gene['length'],
                            gene['start_codon'], gene['stop_codon'], gene['in_frame'],
                            result['confidence']
                        ])

        logging.info(f"Results saved to {output_path}")
597
-
598
- # ============================= USAGE EXAMPLE =============================
599
-
600
def main():
    """Demonstrate the enhanced gene predictor on file and string inputs."""

    # One predictor instance shared by all examples.
    predictor = EnhancedGenePredictor(model_path='model/best_boundary_aware_model.pth')

    # (input file, output file, save format, success message, missing-file message)
    file_examples = [
        ('example.fasta', 'fasta_predictions.json', 'json',
         "FASTA predictions saved to fasta_predictions.json",
         "example.fasta not found, skipping FASTA example"),
        ('example.txt', 'txt_predictions.csv', 'csv',
         "TXT predictions saved to txt_predictions.csv",
         "example.txt not found, skipping TXT example"),
    ]
    for in_path, out_path, fmt, ok_msg, missing_msg in file_examples:
        try:
            file_results = predictor.predict_from_file(in_path)
            predictor.save_results(file_results, out_path, format=fmt)
            print(ok_msg)
        except FileNotFoundError:
            print(missing_msg)

    # Direct string input (original functionality).
    example_sequence = "ATGAAACGCATTAGCACCACCATTACCACCACCATCACCATTACCACAGGTAACGGTGCGGGCTGA"
    text_results = predictor.predict_from_text(example_sequence)
    print(f"Found {text_results['total_genes_found']} genes in example sequence")

if __name__ == "__main__":
    main()