hadir123 committed on
Commit
24d1289
·
1 Parent(s): 99965bb

Add new model prediction

Browse files
Files changed (4) hide show
  1. best_boundary_aware_model.pth +3 -0
  2. predict_app.py +91 -0
  3. predictor.py +414 -0
  4. requirement.txt +5 -0
best_boundary_aware_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13c92e4883bba94b680ba84904e2c36a3c01105196c2a935c979b583fe0dc30c
3
+ size 6410291
predict_app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """predict_app.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/18tTfrNXDQWf7MfzUe4SaNZe2yftvvJsn
8
+ """
9
+
10
+ # app/main.py
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
+ from typing import Optional
14
+ from model.predictor import GenePredictor
15
+ import logging
16
+
17
# Root-logger configuration shared by the whole service (module import time).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# FastAPI application instance; routes are registered on it below.
app = FastAPI(title="F Gene Prediction API", description="API for predicting f gene start and end positions in DNA sequences")
20
+
21
class SequenceInput(BaseModel):
    """Request body for /predict.

    Ground truth is optional and may be given either as a comma-separated
    0/1 label string or as start/end coordinates; when present it is used
    to compute accuracy metrics for the prediction.
    """
    # DNA sequence; only A, C, T, G, N are accepted by the endpoint.
    sequence: str
    # Comma-separated per-base labels, e.g. "0,0,1,1,0" (one per base).
    ground_truth_labels: Optional[str] = None
    # Alternative to labels: 0-based half-open [start, end) gene coordinates.
    ground_truth_start: Optional[int] = None
    ground_truth_end: Optional[int] = None
26
+
27
class PredictionResponse(BaseModel):
    """Response body for /predict."""
    # Gene regions found by the predictor (start/end/sequence/codons/...).
    regions: list
    # Mean gene-class probability over predicted gene bases; 0.0 when none.
    confidence: float
    # Accuracy/precision/recall/F1 vs. ground truth, when it was supplied.
    metrics: Optional[dict] = None
    # Human-readable status message.
    message: str
32
+
33
# Load the trained model once at import time.  Fail fast (after logging) if
# the checkpoint or predictor cannot be initialized, since the API cannot
# serve requests without it.
try:
    predictor = GenePredictor(model_path='model/best_boundary_aware_model.pth')
except Exception as e:
    logging.error(f"Failed to initialize predictor: {e}")
    raise
38
+
39
@app.post("/predict", response_model=PredictionResponse)
async def predict_gene(input_data: SequenceInput):
    """Predict gene regions in a DNA sequence.

    Validates the sequence and any optional ground truth, runs the model,
    and returns regions, confidence, and (when ground truth was supplied)
    accuracy metrics.  Raises 400 on bad input, 500 on prediction failure.
    """
    sequence = input_data.sequence.strip().upper()

    # --- sequence validation ------------------------------------------------
    if not sequence:
        raise HTTPException(status_code=400, detail="Sequence cannot be empty")
    if any(base not in 'ACTGN' for base in sequence):
        raise HTTPException(status_code=400, detail="Sequence contains invalid characters. Only A, C, T, G, N allowed")

    # --- optional ground truth ---------------------------------------------
    labels = None
    if input_data.ground_truth_labels:
        try:
            labels = [int(token) for token in input_data.ground_truth_labels.split(',')]
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid labels format. Use comma-separated 0s and 1s")
        if len(labels) != len(sequence):
            raise HTTPException(status_code=400, detail=f"Labels length ({len(labels)}) must match sequence length ({len(sequence)})")
        if any(value not in (0, 1) for value in labels):
            raise HTTPException(status_code=400, detail="Labels must be 0 or 1")
    elif input_data.ground_truth_start is not None and input_data.ground_truth_end is not None:
        start = input_data.ground_truth_start
        end = input_data.ground_truth_end
        try:
            if start < 0 or end > len(sequence) or start >= end:
                raise HTTPException(status_code=400, detail=f"Invalid coordinates: start={start}, end={end}")
            labels = predictor.labels_from_coordinates(len(sequence), start, end)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid start/end coordinates")

    # --- inference ----------------------------------------------------------
    try:
        predictions, _probs, confidence = predictor.predict(sequence)
        regions = predictor.extract_gene_regions(predictions, sequence)
        metrics = predictor.compute_accuracy(predictions, labels) if labels is not None else None
        return {
            "regions": regions,
            "confidence": float(confidence),
            "metrics": metrics,
            "message": "Prediction successful",
        }
    except Exception as e:
        logging.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
88
+
89
@app.get("/health")
async def health_check():
    """Liveness probe: report that the API process is up."""
    payload = {"status": "API is running"}
    return payload
predictor.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """predictor.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1JURb-0j-R4LWK3oxeGrNxpJm3V6nnX02
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import numpy as np
14
+ from typing import List, Tuple, Dict, Optional
15
+ import logging
16
+ import re
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
20
+
21
+ # ============================= MODEL COMPONENTS =============================
22
+
23
class BoundaryAwareGenePredictor(nn.Module):
    """Multi-task model predicting genes, starts, and ends separately.

    Three per-position 2-class heads share one trunk: multi-scale 1D
    convolutions -> bidirectional LSTM -> LayerNorm -> self-attention.

    NOTE(review): the attribute names below are baked into the saved
    checkpoint's state_dict keys -- renaming any of them breaks
    load_state_dict on 'best_boundary_aware_model.pth'.
    """

    def __init__(self, input_dim: int = 14, hidden_dim: int = 256,
                 num_layers: int = 3, dropout: float = 0.3):
        super().__init__()
        # Multi-scale convolutions: each branch emits hidden_dim//4 channels,
        # so concatenating the 4 branches restores hidden_dim.  Odd kernels
        # with padding=k//2 preserve sequence length.
        self.conv_layers = nn.ModuleList([
            nn.Conv1d(input_dim, hidden_dim//4, kernel_size=k, padding=k//2)
            for k in [3, 7, 15, 31]
        ])
        # BiLSTM: hidden_dim//2 per direction -> hidden_dim total output.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, num_layers,
                            batch_first=True, bidirectional=True, dropout=dropout)
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        # Self-attention over the LSTM output to sharpen boundary context.
        self.boundary_attention = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

        # Per-position 2-class classification heads.
        self.gene_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.start_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )
        self.end_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim//2, 2)
        )

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Return per-position logits {'gene', 'start', 'end'}, each (B, L, 2).

        x: (batch, seq_len, input_dim) feature tensor.
        lengths: optional true lengths for packed handling of padded batches.
        """
        batch_size, seq_len, _ = x.shape  # unpacked for clarity; values unused below
        x_conv = x.transpose(1, 2)  # Conv1d expects (batch, channels, seq)
        conv_features = [F.relu(conv(x_conv)) for conv in self.conv_layers]
        features = torch.cat(conv_features, dim=1).transpose(1, 2)  # back to (B, L, hidden)

        if lengths is not None:
            # Pack so the LSTM skips padded positions.
            packed = nn.utils.rnn.pack_padded_sequence(
                features, lengths.cpu(), batch_first=True, enforce_sorted=False
            )
            lstm_out, _ = self.lstm(packed)
            lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)
        else:
            lstm_out, _ = self.lstm(features)

        lstm_out = self.norm(lstm_out)
        attended, _ = self.boundary_attention(lstm_out, lstm_out, lstm_out)
        attended = self.dropout(attended)

        return {
            'gene': self.gene_classifier(attended),
            'start': self.start_classifier(attended),
            'end': self.end_classifier(attended)
        }
82
+
83
+ # ============================= DATA PREPROCESSING =============================
84
+
85
+ class DNAProcessor:
86
+ """DNA sequence processor with boundary-aware features."""
87
+
88
+ def __init__(self):
89
+ self.base_to_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3, 'N': 4}
90
+ self.idx_to_base = {v: k for k, v in self.base_to_idx.items()}
91
+ self.start_codons = {'ATG', 'GTG', 'TTG'}
92
+ self.stop_codons = {'TAA', 'TAG', 'TGA'}
93
+
94
+ def encode_sequence(self, sequence: str) -> torch.Tensor:
95
+ sequence = sequence.upper()
96
+ encoded = [self.base_to_idx.get(base, self.base_to_idx['N']) for base in sequence]
97
+ return torch.tensor(encoded, dtype=torch.long)
98
+
99
+ def create_enhanced_features(self, sequence: str) -> torch.Tensor:
100
+ sequence = sequence.upper()
101
+ seq_len = len(sequence)
102
+ encoded = self.encode_sequence(sequence)
103
+
104
+ # One-hot encoding
105
+ one_hot = torch.zeros(seq_len, 5)
106
+ one_hot.scatter_(1, encoded.unsqueeze(1), 1)
107
+ features = [one_hot]
108
+
109
+ # Start codon indicators (increased weights for GTG and TTG)
110
+ start_indicators = torch.zeros(seq_len, 3)
111
+ for i in range(seq_len - 2):
112
+ codon = sequence[i:i+3]
113
+ if codon == 'ATG':
114
+ start_indicators[i:i+3, 0] = 1.0
115
+ elif codon == 'GTG':
116
+ start_indicators[i:i+3, 1] = 0.9 # Increased from 0.7
117
+ elif codon == 'TTG':
118
+ start_indicators[i:i+3, 2] = 0.8 # Increased from 0.5
119
+ features.append(start_indicators)
120
+
121
+ # Stop codon indicators
122
+ stop_indicators = torch.zeros(seq_len, 3)
123
+ for i in range(seq_len - 2):
124
+ codon = sequence[i:i+3]
125
+ if codon == 'TAA':
126
+ stop_indicators[i:i+3, 0] = 1.0
127
+ elif codon == 'TAG':
128
+ stop_indicators[i:i+3, 1] = 1.0
129
+ elif codon == 'TGA':
130
+ stop_indicators[i:i+3, 2] = 1.0
131
+ features.append(stop_indicators)
132
+
133
+ # GC content
134
+ gc_content = torch.zeros(seq_len, 1)
135
+ window_size = 50
136
+ for i in range(seq_len):
137
+ start = max(0, i - window_size//2)
138
+ end = min(seq_len, i + window_size//2)
139
+ window = sequence[start:end]
140
+ gc_count = window.count('G') + window.count('C')
141
+ gc_content[i, 0] = gc_count / len(window) if len(window) > 0 else 0
142
+ features.append(gc_content)
143
+
144
+ # Position encoding
145
+ pos_encoding = torch.zeros(seq_len, 2)
146
+ positions = torch.arange(seq_len, dtype=torch.float)
147
+ pos_encoding[:, 0] = torch.sin(positions / 10000)
148
+ pos_encoding[:, 1] = torch.cos(positions / 10000)
149
+ features.append(pos_encoding)
150
+
151
+ return torch.cat(features, dim=1) # 5 + 3 + 3 + 1 + 2 = 14
152
+
153
+ # ============================= POST-PROCESSING =============================
154
+
155
class EnhancedPostProcessor:
    """Enhanced post-processor with stricter boundary detection.

    Turns per-position softmax probabilities into hard 0/1 gene calls, snaps
    region boundaries to nearby start/stop codons (or to the model's
    boundary heads when no sequence is given), then enforces length and
    reading-frame constraints.
    """

    def __init__(self, min_gene_length: int = 150, max_gene_length: int = 5000):
        # Accepted gene length range in bases; regions outside are discarded.
        self.min_gene_length = min_gene_length
        self.max_gene_length = max_gene_length
        self.start_codons = {'ATG', 'GTG', 'TTG'}
        self.stop_codons = {'TAA', 'TAG', 'TGA'}

    def process_predictions(self, gene_probs: np.ndarray, start_probs: np.ndarray,
                            end_probs: np.ndarray, sequence: str = None) -> np.ndarray:
        """Process predictions with enhanced boundary detection.

        Each *_probs argument is an (L, 2) probability array; returns an
        (L,) 0/1 array of refined per-base gene calls.
        """

        # More conservative thresholds: gene calls need > 0.6, while the
        # start/end boundary hints use looser cutoffs.
        gene_pred = (gene_probs[:, 1] > 0.6).astype(int)
        start_pred = (start_probs[:, 1] > 0.4).astype(int)
        end_pred = (end_probs[:, 1] > 0.5).astype(int)

        if sequence is not None:
            processed = self._refine_with_codons_and_boundaries(
                gene_pred, start_pred, end_pred, sequence
            )
        else:
            processed = self._refine_with_boundaries(gene_pred, start_pred, end_pred)

        processed = self._apply_constraints(processed, sequence)

        return processed

    def _refine_with_codons_and_boundaries(self, gene_pred: np.ndarray,
                                           start_pred: np.ndarray, end_pred: np.ndarray,
                                           sequence: str) -> np.ndarray:
        """Snap each predicted region to the best-scoring nearby codons."""
        refined = gene_pred.copy()  # NOTE(review): overwritten by zeros below
        sequence = sequence.upper()

        # Candidate start-codon positions and one-past-stop-codon positions.
        start_codon_positions = []
        stop_codon_positions = []

        for i in range(len(sequence) - 2):
            codon = sequence[i:i+3]
            if codon in self.start_codons:
                start_codon_positions.append(i)
            if codon in self.stop_codons:
                stop_codon_positions.append(i + 3)

        # Run boundaries of contiguous 1-runs: diff of the padded array gives
        # +1 at each run start and -1 one past each run end.
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        refined = np.zeros_like(gene_pred)

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Choose the best start codon near the predicted start, scoring
            # codon strength + boundary-head vote - distance penalty.
            best_start = g_start
            start_window = 100  # Increased from 50
            nearby_starts = [pos for pos in start_codon_positions
                             if abs(pos - g_start) <= start_window]

            if nearby_starts:
                start_scores = []
                for pos in nearby_starts:
                    if pos < len(start_pred):
                        codon = sequence[pos:pos+3]
                        codon_weight = 1.0 if codon == 'ATG' else (0.9 if codon == 'GTG' else 0.8)
                        boundary_score = start_pred[pos]
                        distance_penalty = abs(pos - g_start) / start_window * 0.2  # Add distance penalty
                        score = codon_weight * 0.5 + boundary_score * 0.4 - distance_penalty
                        start_scores.append((score, pos))

                if start_scores:
                    best_start = max(start_scores, key=lambda x: x[0])[1]

            # Choose the best stop after the start, rewarding in-frame stops
            # and penalizing lengths far from ~1000 bases.
            best_end = g_end
            end_window = 100
            nearby_ends = [pos for pos in stop_codon_positions
                           if g_start < pos <= g_end + end_window]

            if nearby_ends:
                end_scores = []
                for pos in nearby_ends:
                    gene_length = pos - best_start
                    if self.min_gene_length <= gene_length <= self.max_gene_length:
                        if pos < len(end_pred):
                            frame_bonus = 0.2 if (pos - best_start) % 3 == 0 else 0
                            boundary_score = end_pred[pos]
                            length_penalty = abs(gene_length - 1000) / 10000
                            score = boundary_score + frame_bonus - length_penalty
                            end_scores.append((score, pos))

                if end_scores:
                    best_end = max(end_scores, key=lambda x: x[0])[1]

            # Keep the snapped region only if it satisfies the length bounds.
            gene_length = best_end - best_start
            if (gene_length >= self.min_gene_length and
                gene_length <= self.max_gene_length and
                best_start < best_end):
                refined[best_start:best_end] = 1

        return refined

    def _refine_with_boundaries(self, gene_pred: np.ndarray, start_pred: np.ndarray,
                                end_pred: np.ndarray) -> np.ndarray:
        """Codon-free fallback: snap boundaries using the start/end heads only."""
        refined = gene_pred.copy()
        changes = np.diff(np.concatenate(([0], gene_pred, [0])))
        gene_starts = np.where(changes == 1)[0]
        gene_ends = np.where(changes == -1)[0]

        for g_start, g_end in zip(gene_starts, gene_ends):
            # Closest start-head hit within +/-30 of the run start.
            start_window = slice(max(0, g_start-30), min(len(start_pred), g_start+30))
            start_candidates = np.where(start_pred[start_window])[0]
            if len(start_candidates) > 0:
                relative_positions = start_candidates + max(0, g_start-30)
                distances = np.abs(relative_positions - g_start)
                best_start_idx = np.argmin(distances)
                new_start = relative_positions[best_start_idx]
                # Clears the gap when the start moved right; the else branch
                # is a self-assignment (no-op, and an empty slice otherwise).
                refined[g_start:new_start] = 0 if new_start > g_start else refined[g_start:new_start]
                refined[new_start:g_end] = 1
                g_start = new_start

            # Closest valid end-head hit within +/-50 of the run end.
            end_window = slice(max(0, g_end-50), min(len(end_pred), g_end+50))
            end_candidates = np.where(end_pred[end_window])[0]
            if len(end_candidates) > 0:
                relative_positions = end_candidates + max(0, g_end-50)
                valid_ends = [pos for pos in relative_positions
                              if self.min_gene_length <= pos - g_start <= self.max_gene_length]
                if valid_ends:
                    distances = np.abs(np.array(valid_ends) - g_end)
                    new_end = valid_ends[np.argmin(distances)]
                    refined[g_start:new_end] = 1
                    refined[new_end:g_end] = 0 if new_end < g_end else refined[new_end:g_end]

        return refined

    def _apply_constraints(self, predictions: np.ndarray, sequence: str = None) -> np.ndarray:
        """Drop out-of-range regions; trim in-sequence regions to whole codons."""
        processed = predictions.copy()
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_length = end - start
            if gene_length < self.min_gene_length or gene_length > self.max_gene_length:
                processed[start:end] = 0
                continue
            if sequence is not None:
                if gene_length % 3 != 0:
                    # Trim the tail so the region is a whole number of codons;
                    # discard it entirely if trimming drops below the minimum.
                    new_length = (gene_length // 3) * 3
                    if new_length >= self.min_gene_length:
                        new_end = start + new_length
                        processed[new_end:end] = 0
                    else:
                        processed[start:end] = 0

        return processed
308
+
309
+ # ============================= PREDICTION =============================
310
+
311
class GenePredictor:
    """Handles gene prediction using the trained boundary-aware model."""

    def __init__(self, model_path: str = 'model/best_boundary_aware_model.pth',
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the trained checkpoint and helper components.

        Re-raises any exception from torch.load / load_state_dict after
        logging it (missing file, incompatible architecture, ...).
        """
        self.device = device
        self.model = BoundaryAwareGenePredictor(input_dim=14).to(device)
        try:
            # The checkpoint is a plain state_dict of tensors, so
            # weights_only=True (supported by the pinned torch 2.4.1) avoids
            # unpickling arbitrary objects from the file.
            self.model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
            logging.info(f"Loaded model from {model_path}")
        except Exception as e:
            logging.error(f"Failed to load model: {e}")
            raise
        self.model.eval()
        self.processor = DNAProcessor()
        self.post_processor = EnhancedPostProcessor()

    def predict(self, sequence: str) -> Tuple[np.ndarray, Dict[str, np.ndarray], float]:
        """Run the model on a DNA sequence.

        Returns (per-base 0/1 predictions, {'gene','start','end'} softmax
        probability arrays, mean gene-class probability over predicted
        gene bases -- 0.0 when no base is predicted as gene).
        """
        sequence = sequence.upper()
        if not re.match('^[ACTGN]+$', sequence):
            logging.warning("Sequence contains invalid characters. Using 'N' for unknowns.")
            sequence = ''.join(c if c in 'ACTGN' else 'N' for c in sequence)

        features = self.processor.create_enhanced_features(sequence).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(features)
            gene_probs = F.softmax(outputs['gene'], dim=-1).cpu().numpy()[0]
            start_probs = F.softmax(outputs['start'], dim=-1).cpu().numpy()[0]
            end_probs = F.softmax(outputs['end'], dim=-1).cpu().numpy()[0]

        predictions = self.post_processor.process_predictions(
            gene_probs, start_probs, end_probs, sequence
        )
        # Plain Python float so the value is JSON-safe downstream.
        confidence = float(np.mean(gene_probs[:, 1][predictions == 1])) if np.any(predictions == 1) else 0.0

        return predictions, {'gene': gene_probs, 'start': start_probs, 'end': end_probs}, confidence

    def extract_gene_regions(self, predictions: np.ndarray, sequence: str) -> List[Dict]:
        """Convert a per-base 0/1 prediction array into gene-region dicts."""
        regions = []
        # Run boundaries: +1 marks a gene start, -1 one past the gene end.
        changes = np.diff(np.concatenate(([0], predictions, [0])))
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]

        for start, end in zip(starts, ends):
            gene_seq = sequence[start:end]
            actual_start_codon = None
            actual_stop_codon = None

            if len(gene_seq) >= 3:
                start_codon = gene_seq[:3]
                if start_codon in ['ATG', 'GTG', 'TTG']:
                    actual_start_codon = start_codon

            if len(gene_seq) >= 6:
                # Scan in-frame codons (frame 0 relative to the region start)
                # from the last complete codon backwards.  BUGFIX: the old
                # loop started at len-2, which is off-frame whenever the
                # region length is a multiple of 3, so it could never find
                # the stop codon of an in-frame gene.
                for i in range(3 * ((len(gene_seq) - 3) // 3), 2, -3):
                    codon = gene_seq[i:i+3]
                    if codon in ['TAA', 'TAG', 'TGA']:
                        actual_stop_codon = codon
                        break

            regions.append({
                'start': int(start),  # Convert to Python int for JSON serialization
                'end': int(end),
                'sequence': gene_seq,  # Return full sequence
                'length': int(end - start),
                'start_codon': actual_start_codon,
                'stop_codon': actual_stop_codon,
                'in_frame': (end - start) % 3 == 0
            })

        return regions

    def compute_accuracy(self, predictions: np.ndarray, labels: List[int]) -> Dict:
        """Per-base accuracy/precision/recall/F1 against ground-truth labels.

        Truncates both arrays to the shorter length.  All returned values
        are plain Python floats/ints (not numpy scalars) so the dict is
        JSON-serializable by the API layer.
        """
        min_len = min(len(predictions), len(labels))
        predictions = predictions[:min_len]
        labels = np.array(labels[:min_len])

        accuracy = float(np.mean(predictions == labels))
        true_pos = int(np.sum((predictions == 1) & (labels == 1)))
        false_neg = int(np.sum((predictions == 0) & (labels == 1)))
        false_pos = int(np.sum((predictions == 1) & (labels == 0)))

        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0.0
        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_pos,
            'false_positives': false_pos,
            'false_negatives': false_neg
        }

    def labels_from_coordinates(self, seq_len: int, start: int, end: int) -> List[int]:
        """Build a per-base 0/1 label list with [start, end) clamped to bounds."""
        labels = [0] * seq_len
        start = max(0, min(start, seq_len - 1))
        end = max(start, min(end, seq_len))
        for i in range(start, end):
            labels[i] = 1
        return labels
requirement.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi==0.115.0
2
+ uvicorn==0.30.6
3
+ torch==2.4.1
4
+ pandas==2.2.2
5
+ numpy==1.26.4