Upload inference.py via DNA Console (Portable Version)
Browse files- inference.py +37 -4
inference.py
CHANGED
|
@@ -125,7 +125,15 @@ class SequenceFeatureExtractor:
|
|
| 125 |
for b in 'ATGC': res.append(p_c[p].get(b, 0) / t)
|
| 126 |
return res
|
| 127 |
|
| 128 |
-
def predict_dna(sequence):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# Load Models
|
| 130 |
rf_model = joblib.load("dna_classifier.joblib")
|
| 131 |
rf_scaler = joblib.load("scaler_rf.joblib")
|
|
@@ -137,16 +145,41 @@ def predict_dna(sequence):
|
|
| 137 |
feat_rf = extractor_rf.transform([sequence])
|
| 138 |
scaled_rf = rf_scaler.transform(feat_rf)
|
| 139 |
type_basic = rf_model.predict(scaled_rf)[0]
|
|
|
|
|
|
|
| 140 |
|
| 141 |
-
# 2. ViralBoost Prediction (Virus Type)
|
| 142 |
extractor_gb = SequenceFeatureExtractor()
|
| 143 |
feat_gb = extractor_gb.transform([sequence])
|
| 144 |
scaled_gb = gb_scaler.transform(feat_gb)
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
return {
|
| 148 |
"classification": type_basic,
|
| 149 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
}
|
| 151 |
|
| 152 |
if __name__ == "__main__":
|
|
|
|
| 125 |
for b in 'ATGC': res.append(p_c[p].get(b, 0) / t)
|
| 126 |
return res
|
| 127 |
|
| 128 |
+
def predict_dna(sequence, confidence_threshold=0.55, rare_class_threshold=0.65):
|
| 129 |
+
"""
|
| 130 |
+
DNA sequence prediction with confidence thresholds.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
sequence: DNA sequence string
|
| 134 |
+
confidence_threshold: Minimum confidence for general classification (default 55%)
|
| 135 |
+
rare_class_threshold: Higher threshold for rare classes like Influenza B (default 65%)
|
| 136 |
+
"""
|
| 137 |
# Load Models
|
| 138 |
rf_model = joblib.load("dna_classifier.joblib")
|
| 139 |
rf_scaler = joblib.load("scaler_rf.joblib")
|
|
|
|
| 145 |
feat_rf = extractor_rf.transform([sequence])
|
| 146 |
scaled_rf = rf_scaler.transform(feat_rf)
|
| 147 |
type_basic = rf_model.predict(scaled_rf)[0]
|
| 148 |
+
rf_proba = rf_model.predict_proba(scaled_rf)[0]
|
| 149 |
+
rf_confidence = max(rf_proba)
|
| 150 |
|
| 151 |
+
# 2. ViralBoost Prediction (Virus Type) with Confidence Check
|
| 152 |
extractor_gb = SequenceFeatureExtractor()
|
| 153 |
feat_gb = extractor_gb.transform([sequence])
|
| 154 |
scaled_gb = gb_scaler.transform(feat_gb)
|
| 155 |
+
|
| 156 |
+
gb_proba = gb_model.predict_proba(scaled_gb)[0]
|
| 157 |
+
gb_confidence = max(gb_proba)
|
| 158 |
+
predicted_idx = gb_proba.argmax()
|
| 159 |
+
predicted_class = gb_model.classes_[predicted_idx]
|
| 160 |
+
|
| 161 |
+
# ํฌ๊ท ํด๋์ค (Influenza B ๋ฑ)๋ ๋ ๋์ ์ ๋ขฐ๋ ์๊ตฌ
|
| 162 |
+
rare_classes = ['Influenza B', 'Chicken anemia virus']
|
| 163 |
+
if predicted_class in rare_classes:
|
| 164 |
+
effective_threshold = rare_class_threshold
|
| 165 |
+
else:
|
| 166 |
+
effective_threshold = confidence_threshold
|
| 167 |
+
|
| 168 |
+
# ์ ๋ขฐ๋ ์๊ณ๊ฐ ๋ฏธ๋ฌ ์ 'Unknown'์ผ๋ก ๋ถ๋ฅ
|
| 169 |
+
if gb_confidence < effective_threshold:
|
| 170 |
+
type_virus = 'Unknown'
|
| 171 |
+
virus_confidence = gb_confidence
|
| 172 |
+
else:
|
| 173 |
+
type_virus = predicted_class
|
| 174 |
+
virus_confidence = gb_confidence
|
| 175 |
|
| 176 |
return {
|
| 177 |
"classification": type_basic,
|
| 178 |
+
"classification_confidence": float(rf_confidence),
|
| 179 |
+
"virus_identity": type_virus,
|
| 180 |
+
"virus_confidence": float(virus_confidence),
|
| 181 |
+
"raw_prediction": predicted_class, # ์๋ ์์ธก (๋๋ฒ๊น
์ฉ)
|
| 182 |
+
"raw_confidence": float(gb_confidence)
|
| 183 |
}
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|