ankitt6174 commited on
Commit
e883774
·
1 Parent(s): 6bd2415

Making more efficient

Browse files
Files changed (2) hide show
  1. app.py +15 -3
  2. predict.py +68 -58
app.py CHANGED
@@ -1,8 +1,19 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from predict import predict
 
4
 
5
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
6
 
7
  class InputData(BaseModel):
8
  dnasequence: str
@@ -104,7 +115,8 @@ def home():
104
 
105
  @app.post("/predict")
106
  def prediction(data: InputData):
107
- result = predict(
 
108
  seq = data.dnasequence,
109
  pos=101,
110
  ref = data.reference,
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from predict import PredictionModel
4
+ from contextlib import asynccontextmanager
5
 
6
+ ml_models = {}
7
+
8
+ @asynccontextmanager
9
+ async def lifespan(app: FastAPI):
10
+ # Load the ML model
11
+ ml_models["dna_mutation_predictor"] = PredictionModel("./model/model.pth")
12
+ yield
13
+ # Clean up the ML models and release the resources
14
+ ml_models.clear()
15
+
16
+ app = FastAPI(lifespan=lifespan)
17
 
18
  class InputData(BaseModel):
19
  dnasequence: str
 
115
 
116
  @app.post("/predict")
117
  def prediction(data: InputData):
118
+ predictor = ml_models["dna_mutation_predictor"]
119
+ result = predictor.predict(
120
  seq = data.dnasequence,
121
  pos=101,
122
  ref = data.reference,
predict.py CHANGED
@@ -5,22 +5,6 @@ import torch
5
  import torch.nn as nn
6
  import math
7
 
8
- print("="*30)
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- print("Using device:", device)
11
- print("="*30)
12
-
13
- checkpoint = torch.load("./model/model.pth", map_location=device, weights_only=False)
14
-
15
- feature_scaler = checkpoint['feature_scaler']
16
- hyperparameters = checkpoint['hyperparameters']
17
- vocab = checkpoint['vocab']
18
-
19
- mutation_type_encoder = checkpoint['encoders']['mutation_type']
20
- chromosome_encoder = checkpoint['encoders']['chromosome']
21
- ref_encoder = checkpoint['encoders']['ref']
22
- alt_encoder = checkpoint['encoders']['alt']
23
-
24
  chrom_lengths = {
25
  'chr1': 248956422,
26
  'chr2': 242193529,
@@ -46,7 +30,7 @@ chrom_lengths = {
46
  'chr22': 50818468,
47
  }
48
 
49
- def get_feature_data(seq, pos, ref, alt, chrom, genomic_pos, mutation_type):
50
  def gc_content(seq):
51
  seq = seq.upper()
52
  gc = seq.count('G') + seq.count('C')
@@ -81,15 +65,15 @@ def get_feature_data(seq, pos, ref, alt, chrom, genomic_pos, mutation_type):
81
  return genomic_pos / chrom_length
82
 
83
  def get_dummies(mutation_type, chrom, ref, alt):
84
- mutation_type_df = pd.DataFrame([[mutation_type]], columns=mutation_type_encoder.feature_names_in_)
85
- chromosome_df = pd.DataFrame([[chrom]], columns=chromosome_encoder.feature_names_in_)
86
- ref_df = pd.DataFrame([[ref]], columns=ref_encoder.feature_names_in_)
87
- alt_df = pd.DataFrame([[alt]], columns=alt_encoder.feature_names_in_)
88
 
89
- mutation_type_encoded = mutation_type_encoder.transform(mutation_type_df).toarray()[0]
90
- chromosome_encoded = chromosome_encoder.transform(chromosome_df).toarray()[0]
91
- ref_encoded = ref_encoder.transform(ref_df).toarray()[0]
92
- alt_encoded = alt_encoder.transform(alt_df).toarray()[0]
93
 
94
  return np.concatenate([mutation_type_encoded, chromosome_encoded, ref_encoded, alt_encoded])
95
 
@@ -125,11 +109,11 @@ def get_feature_data(seq, pos, ref, alt, chrom, genomic_pos, mutation_type):
125
 
126
  return result
127
 
128
- def get_codon(seq, k=hyperparameters['k-mers']):
129
  return [seq[i:i+k] for i in range(len(seq) - k + 1)]
130
 
131
- def get_tensor(text):
132
- return [vocab[codons.lower()] for codons in get_codon(text)]
133
 
134
  class PositionalEncoding(nn.Module):
135
  def __init__(self, embed_dim, max_len=5000):
@@ -215,33 +199,59 @@ class CNNTransformerHybrid(nn.Module):
215
  output = self.fc_layers(combined_features)
216
  return output
217
 
218
- model = CNNTransformerHybrid(
219
- vocab_size = len(vocab),
220
- embed_dim = hyperparameters['embed_dim'],
221
- num_classes = 2,
222
- max_len = hyperparameters['max_len'],
223
- dropout = hyperparameters['dropout'],
224
- num_heads = hyperparameters['num_heads'],
225
- num_transformer_layers = hyperparameters['num_transformer_layers'],
226
- ff_dim = hyperparameters['ff_dim'],
227
- cnn_out_channels = hyperparameters['cnn_out_channels'],
228
- num_extra_features = 39,
229
- )
230
-
231
- model.load_state_dict(checkpoint['model_state_dict'])
232
- model.to(device)
233
- model.eval()
234
-
235
- def predict(seq, pos, ref, alt, chrom, genomic_pos, mutation_type):
236
- features = get_feature_data(seq, pos, ref, alt, chrom, genomic_pos, mutation_type)['Array']
237
- scaled_features = feature_scaler.transform(features.reshape(1, -1))
238
-
239
- with torch.no_grad():
240
- input_tensor = torch.tensor(get_tensor(seq)).unsqueeze(0).to(device)
241
- features_tensor = torch.tensor(scaled_features, dtype=torch.float32).to(device)
242
- output = model(input_tensor, features_tensor)
243
-
244
- return {
245
- 'Prediction': torch.softmax(output, dim=1).argmax(dim=1).item(),
246
- 'Confidence': torch.softmax(output, dim=1)[0]
247
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch.nn as nn
6
  import math
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  chrom_lengths = {
9
  'chr1': 248956422,
10
  'chr2': 242193529,
 
30
  'chr22': 50818468,
31
  }
32
 
33
+ def get_feature_data(seq, pos, ref, alt, chrom, genomic_pos, mutation_type, encoders):
34
  def gc_content(seq):
35
  seq = seq.upper()
36
  gc = seq.count('G') + seq.count('C')
 
65
  return genomic_pos / chrom_length
66
 
67
  def get_dummies(mutation_type, chrom, ref, alt):
68
+ mutation_type_df = pd.DataFrame([[mutation_type]], columns=encoders['mutation_type'].feature_names_in_)
69
+ chromosome_df = pd.DataFrame([[chrom]], columns=encoders['chromosome'].feature_names_in_)
70
+ ref_df = pd.DataFrame([[ref]], columns=encoders['ref'].feature_names_in_)
71
+ alt_df = pd.DataFrame([[alt]], columns=encoders['alt'].feature_names_in_)
72
 
73
+ mutation_type_encoded = encoders['mutation_type'].transform(mutation_type_df).toarray()[0]
74
+ chromosome_encoded = encoders['chromosome'].transform(chromosome_df).toarray()[0]
75
+ ref_encoded = encoders['ref'].transform(ref_df).toarray()[0]
76
+ alt_encoded = encoders['alt'].transform(alt_df).toarray()[0]
77
 
78
  return np.concatenate([mutation_type_encoded, chromosome_encoded, ref_encoded, alt_encoded])
79
 
 
109
 
110
  return result
111
 
112
+ def get_codon(seq, k):
113
  return [seq[i:i+k] for i in range(len(seq) - k + 1)]
114
 
115
+ def get_tensor(text, vocab, k):
116
+ return [vocab[codons.lower()] for codons in get_codon(text, k)]
117
 
118
  class PositionalEncoding(nn.Module):
119
  def __init__(self, embed_dim, max_len=5000):
 
199
  output = self.fc_layers(combined_features)
200
  return output
201
 
202
+ class PredictionModel:
203
+ def __init__(self, model_path: str):
204
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
205
+ print("="*30)
206
+ print(f"Loading model on device: {self.device}")
207
+
208
+ checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
209
+
210
+ self.feature_scaler = checkpoint['feature_scaler']
211
+ self.hyperparameters = checkpoint['hyperparameters']
212
+ self.vocab = checkpoint['vocab']
213
+ self.encoders = checkpoint['encoders']
214
+
215
+ self.model = CNNTransformerHybrid(
216
+ vocab_size=len(self.vocab),
217
+ embed_dim=self.hyperparameters['embed_dim'],
218
+ num_classes=2,
219
+ max_len=self.hyperparameters['max_len'],
220
+ dropout=self.hyperparameters['dropout'],
221
+ num_heads=self.hyperparameters['num_heads'],
222
+ num_transformer_layers=self.hyperparameters['num_transformer_layers'],
223
+ ff_dim=self.hyperparameters['ff_dim'],
224
+ cnn_out_channels=self.hyperparameters['cnn_out_channels'],
225
+ num_extra_features=39,
226
+ )
227
+
228
+ self.model.load_state_dict(checkpoint['model_state_dict'])
229
+ self.model.to(self.device)
230
+ self.model.eval()
231
+ print("Model loaded successfully.")
232
+ print("="*30)
233
+
234
+ def predict(self, seq, pos, ref, alt, chrom, genomic_pos, mutation_type):
235
+ features = get_feature_data(
236
+ seq, pos, ref, alt, chrom, genomic_pos, mutation_type, self.encoders
237
+ )['Array']
238
+ scaled_features = self.feature_scaler.transform(features.reshape(1, -1))
239
+
240
+ with torch.no_grad():
241
+ input_tensor = torch.tensor(
242
+ get_tensor(seq, self.vocab, self.hyperparameters['k-mers'])
243
+ ).unsqueeze(0).to(self.device)
244
+
245
+ features_tensor = torch.tensor(
246
+ scaled_features, dtype=torch.float32
247
+ ).to(self.device)
248
+
249
+ output = self.model(input_tensor, features_tensor)
250
+
251
+ confidence = torch.softmax(output, dim=1)[0]
252
+ prediction = confidence.argmax().item()
253
+
254
+ return {
255
+ 'Prediction': prediction,
256
+ 'Confidence': confidence
257
+ }