Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -75,18 +75,6 @@ def extract_features(sequence):
|
|
| 75 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 76 |
return selected_df.values
|
| 77 |
|
| 78 |
-
# AMP Classifier
|
| 79 |
-
def predict(sequence):
|
| 80 |
-
features = extract_features(sequence)
|
| 81 |
-
if isinstance(features, str):
|
| 82 |
-
return features
|
| 83 |
-
prediction = model.predict(features)[0]
|
| 84 |
-
probabilities = model.predict_proba(features)[0]
|
| 85 |
-
if prediction == 0:
|
| 86 |
-
return f"{probabilities[0] * 100:.2f}% chance of being an Antimicrobial Peptide (AMP)"
|
| 87 |
-
else:
|
| 88 |
-
return f"{probabilities[1] * 100:.2f}% chance of being Non-AMP"
|
| 89 |
-
|
| 90 |
# MIC Predictor
|
| 91 |
def predictmic(sequence):
|
| 92 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
|
@@ -99,37 +87,17 @@ def predictmic(sequence):
|
|
| 99 |
outputs = protbert_model(**tokens)
|
| 100 |
embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
|
| 101 |
bacteria_config = {
|
| 102 |
-
"E.coli": {
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
},
|
| 107 |
-
"S.aureus": {
|
| 108 |
-
"model": "aur_xgboost_model.pkl",
|
| 109 |
-
"scaler": "aur_scaler.pkl",
|
| 110 |
-
"pca": None
|
| 111 |
-
},
|
| 112 |
-
"P.aeruginosa": {
|
| 113 |
-
"model": "arg_xgboost_model.pkl",
|
| 114 |
-
"scaler": "arg_scaler.pkl",
|
| 115 |
-
"pca": None
|
| 116 |
-
},
|
| 117 |
-
"K.Pneumonia": {
|
| 118 |
-
"model": "pne_mlp_model.pkl",
|
| 119 |
-
"scaler": "pne_scaler.pkl",
|
| 120 |
-
"pca": "pne_pca.pkl"
|
| 121 |
-
}
|
| 122 |
}
|
| 123 |
mic_results = {}
|
| 124 |
for bacterium, cfg in bacteria_config.items():
|
| 125 |
try:
|
| 126 |
scaler = joblib.load(cfg["scaler"])
|
| 127 |
scaled = scaler.transform(embedding)
|
| 128 |
-
if cfg["pca"]
|
| 129 |
-
pca = joblib.load(cfg["pca"])
|
| 130 |
-
transformed = pca.transform(scaled)
|
| 131 |
-
else:
|
| 132 |
-
transformed = scaled
|
| 133 |
model = joblib.load(cfg["model"])
|
| 134 |
mic_log = model.predict(transformed)[0]
|
| 135 |
mic = round(expm1(mic_log), 3)
|
|
@@ -138,29 +106,29 @@ def predictmic(sequence):
|
|
| 138 |
mic_results[bacterium] = f"Error: {str(e)}"
|
| 139 |
return mic_results
|
| 140 |
|
| 141 |
-
# Combined
|
| 142 |
def full_prediction(sequence):
|
| 143 |
features = extract_features(sequence)
|
| 144 |
if isinstance(features, str):
|
| 145 |
-
return
|
| 146 |
prediction = model.predict(features)[0]
|
| 147 |
probabilities = model.predict_proba(features)[0]
|
| 148 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
| 149 |
confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
|
| 150 |
mic_values = predictmic(sequence)
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
# Gradio Interface
|
| 154 |
iface = gr.Interface(
|
| 155 |
fn=full_prediction,
|
| 156 |
inputs=gr.Textbox(label="Enter Protein Sequence"),
|
| 157 |
-
outputs=
|
| 158 |
-
gr.Label(label="AMP Classification"),
|
| 159 |
-
gr.Label(label="Confidence"),
|
| 160 |
-
gr.JSON(label="Predicted MIC (µM) for Each Bacterium")
|
| 161 |
-
],
|
| 162 |
title="AMP & MIC Predictor",
|
| 163 |
description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values."
|
| 164 |
)
|
| 165 |
|
| 166 |
iface.launch(share=True)
|
|
|
|
|
|
| 75 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 76 |
return selected_df.values
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
# MIC Predictor
|
| 79 |
def predictmic(sequence):
|
| 80 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
|
|
|
| 87 |
outputs = protbert_model(**tokens)
|
| 88 |
embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)
|
| 89 |
bacteria_config = {
|
| 90 |
+
"E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
|
| 91 |
+
"S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
|
| 92 |
+
"P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
|
| 93 |
+
"K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
}
|
| 95 |
mic_results = {}
|
| 96 |
for bacterium, cfg in bacteria_config.items():
|
| 97 |
try:
|
| 98 |
scaler = joblib.load(cfg["scaler"])
|
| 99 |
scaled = scaler.transform(embedding)
|
| 100 |
+
transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
model = joblib.load(cfg["model"])
|
| 102 |
mic_log = model.predict(transformed)[0]
|
| 103 |
mic = round(expm1(mic_log), 3)
|
|
|
|
| 106 |
mic_results[bacterium] = f"Error: {str(e)}"
|
| 107 |
return mic_results
|
| 108 |
|
| 109 |
+
# Combined Output as Single String
|
| 110 |
def full_prediction(sequence):
|
| 111 |
features = extract_features(sequence)
|
| 112 |
if isinstance(features, str):
|
| 113 |
+
return features
|
| 114 |
prediction = model.predict(features)[0]
|
| 115 |
probabilities = model.predict_proba(features)[0]
|
| 116 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
| 117 |
confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
|
| 118 |
mic_values = predictmic(sequence)
|
| 119 |
+
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n\nPredicted MIC Values (µM):\n"
|
| 120 |
+
for organism, mic in mic_values.items():
|
| 121 |
+
result += f"- {organism}: {mic}\n"
|
| 122 |
+
return result
|
| 123 |
|
| 124 |
+
# Gradio Interface (Single Label Output)
|
| 125 |
iface = gr.Interface(
|
| 126 |
fn=full_prediction,
|
| 127 |
inputs=gr.Textbox(label="Enter Protein Sequence"),
|
| 128 |
+
outputs=gr.Textbox(label="AMP & MIC Prediction Summary"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
title="AMP & MIC Predictor",
|
| 130 |
description="Enter an amino acid sequence (≥10 valid letters) to predict AMP class and MIC values."
|
| 131 |
)
|
| 132 |
|
| 133 |
iface.launch(share=True)
|
| 134 |
+
|