VParka commited on
Commit
c6c590f
ยท
verified ยท
1 Parent(s): a532578

Upload inference.py via DNA Console (Portable Version)

Browse files
Files changed (1) hide show
  1. inference.py +37 -4
inference.py CHANGED
@@ -125,7 +125,15 @@ class SequenceFeatureExtractor:
125
  for b in 'ATGC': res.append(p_c[p].get(b, 0) / t)
126
  return res
127
 
128
- def predict_dna(sequence):
 
 
 
 
 
 
 
 
129
  # Load Models
130
  rf_model = joblib.load("dna_classifier.joblib")
131
  rf_scaler = joblib.load("scaler_rf.joblib")
@@ -137,16 +145,41 @@ def predict_dna(sequence):
137
  feat_rf = extractor_rf.transform([sequence])
138
  scaled_rf = rf_scaler.transform(feat_rf)
139
  type_basic = rf_model.predict(scaled_rf)[0]
 
 
140
 
141
- # 2. ViralBoost Prediction (Virus Type)
142
  extractor_gb = SequenceFeatureExtractor()
143
  feat_gb = extractor_gb.transform([sequence])
144
  scaled_gb = gb_scaler.transform(feat_gb)
145
- type_virus = gb_model.predict(scaled_gb)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  return {
148
  "classification": type_basic,
149
- "virus_identity": type_virus
 
 
 
 
150
  }
151
 
152
  if __name__ == "__main__":
 
125
  for b in 'ATGC': res.append(p_c[p].get(b, 0) / t)
126
  return res
127
 
128
+ def predict_dna(sequence, confidence_threshold=0.55, rare_class_threshold=0.65):
129
+ """
130
+ DNA sequence prediction with confidence thresholds.
131
+
132
+ Args:
133
+ sequence: DNA sequence string
134
+ confidence_threshold: Minimum confidence for general classification (default 55%)
135
+ rare_class_threshold: Higher threshold for rare classes like Influenza B (default 65%)
136
+ """
137
  # Load Models
138
  rf_model = joblib.load("dna_classifier.joblib")
139
  rf_scaler = joblib.load("scaler_rf.joblib")
 
145
  feat_rf = extractor_rf.transform([sequence])
146
  scaled_rf = rf_scaler.transform(feat_rf)
147
  type_basic = rf_model.predict(scaled_rf)[0]
148
+ rf_proba = rf_model.predict_proba(scaled_rf)[0]
149
+ rf_confidence = max(rf_proba)
150
 
151
+ # 2. ViralBoost Prediction (Virus Type) with Confidence Check
152
  extractor_gb = SequenceFeatureExtractor()
153
  feat_gb = extractor_gb.transform([sequence])
154
  scaled_gb = gb_scaler.transform(feat_gb)
155
+
156
+ gb_proba = gb_model.predict_proba(scaled_gb)[0]
157
+ gb_confidence = max(gb_proba)
158
+ predicted_idx = gb_proba.argmax()
159
+ predicted_class = gb_model.classes_[predicted_idx]
160
+
161
+ # ํฌ๊ท€ ํด๋ž˜์Šค (Influenza B ๋“ฑ)๋Š” ๋” ๋†’์€ ์‹ ๋ขฐ๋„ ์š”๊ตฌ
162
+ rare_classes = ['Influenza B', 'Chicken anemia virus']
163
+ if predicted_class in rare_classes:
164
+ effective_threshold = rare_class_threshold
165
+ else:
166
+ effective_threshold = confidence_threshold
167
+
168
+ # ์‹ ๋ขฐ๋„ ์ž„๊ณ„๊ฐ’ ๋ฏธ๋‹ฌ ์‹œ 'Unknown'์œผ๋กœ ๋ถ„๋ฅ˜
169
+ if gb_confidence < effective_threshold:
170
+ type_virus = 'Unknown'
171
+ virus_confidence = gb_confidence
172
+ else:
173
+ type_virus = predicted_class
174
+ virus_confidence = gb_confidence
175
 
176
  return {
177
  "classification": type_basic,
178
+ "classification_confidence": float(rf_confidence),
179
+ "virus_identity": type_virus,
180
+ "virus_confidence": float(virus_confidence),
181
+ "raw_prediction": predicted_class, # ์›๋ž˜ ์˜ˆ์ธก (๋””๋ฒ„๊น…์šฉ)
182
+ "raw_confidence": float(gb_confidence)
183
  }
184
 
185
  if __name__ == "__main__":