Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,8 +19,33 @@ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
|
|
| 19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
protbert_model = protbert_model.to(device).eval()
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
selected_features = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# LIME Explainer Setup
|
| 26 |
sample_data = np.random.rand(100, len(selected_features))
|
|
@@ -31,9 +56,7 @@ explainer = LimeTabularExplainer(
|
|
| 31 |
mode="classification"
|
| 32 |
)
|
| 33 |
|
| 34 |
-
# Feature Extractor
|
| 35 |
def extract_features(sequence):
|
| 36 |
-
all_features_dict = {}
|
| 37 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 38 |
if len(sequence) < 10:
|
| 39 |
return "Error: Sequence too short."
|
|
@@ -42,6 +65,7 @@ def extract_features(sequence):
|
|
| 42 |
ctd_features = CTD.CalculateCTD(sequence)
|
| 43 |
auto_features = Autocorrelation.CalculateAutoTotal(sequence)
|
| 44 |
pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
|
|
|
|
| 45 |
all_features_dict.update(ctd_features)
|
| 46 |
all_features_dict.update(filtered_dipeptide_features)
|
| 47 |
all_features_dict.update(auto_features)
|
|
@@ -49,10 +73,11 @@ def extract_features(sequence):
|
|
| 49 |
feature_df_all = pd.DataFrame([all_features_dict])
|
| 50 |
normalized_array = scaler.transform(feature_df_all.values)
|
| 51 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
|
|
|
|
|
|
| 52 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 53 |
return selected_df.values
|
| 54 |
|
| 55 |
-
# MIC Predictor
|
| 56 |
def predictmic(sequence):
|
| 57 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 58 |
if len(sequence) < 10:
|
|
@@ -83,46 +108,38 @@ def predictmic(sequence):
|
|
| 83 |
mic_results[bacterium] = f"Error: {str(e)}"
|
| 84 |
return mic_results
|
| 85 |
|
| 86 |
-
# Full Prediction with LIME Explanation
|
| 87 |
def full_prediction(sequence):
|
| 88 |
features = extract_features(sequence)
|
| 89 |
-
if isinstance(features, str):
|
| 90 |
return features
|
| 91 |
-
|
| 92 |
prediction = model.predict(features)[0]
|
| 93 |
probabilities = model.predict_proba(features)[0]
|
| 94 |
amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
|
| 95 |
confidence = round(probabilities[0 if prediction == 0 else 1] * 100, 2)
|
| 96 |
-
|
| 97 |
result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
|
| 98 |
-
|
| 99 |
if prediction == 0:
|
| 100 |
mic_values = predictmic(sequence)
|
| 101 |
-
result += "\nPredicted MIC Values (
|
| 102 |
for org, mic in mic_values.items():
|
| 103 |
result += f"- {org}: {mic}\n"
|
| 104 |
else:
|
| 105 |
result += "\nMIC prediction skipped for Non-AMP sequences.\n"
|
| 106 |
-
|
| 107 |
-
# LIME explanation
|
| 108 |
explanation = explainer.explain_instance(
|
| 109 |
data_row=features[0],
|
| 110 |
predict_fn=model.predict_proba,
|
| 111 |
num_features=10
|
| 112 |
)
|
| 113 |
-
result += "\nTop Features Influencing
|
| 114 |
for feat, weight in explanation.as_list():
|
| 115 |
result += f"- {feat}: {round(weight, 4)}\n"
|
| 116 |
-
|
| 117 |
return result
|
| 118 |
|
| 119 |
-
# Gradio UI
|
| 120 |
iface = gr.Interface(
|
| 121 |
fn=full_prediction,
|
| 122 |
inputs=gr.Textbox(label="Enter Protein Sequence"),
|
| 123 |
-
outputs=gr.Textbox(label="
|
| 124 |
title="AMP & MIC Predictor + LIME Explanation",
|
| 125 |
-
description="Paste an amino acid sequence (
|
| 126 |
)
|
| 127 |
|
| 128 |
-
iface.launch(share=True)
|
|
|
|
| 19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
protbert_model = protbert_model.to(device).eval()
|
| 21 |
|
| 22 |
+
# Full list of selected features
|
| 23 |
+
# Ordered column names picked out of the normalized feature frame before
# prediction. The order is significant: it fixes the column order of the
# array handed to the model and to the LIME explainer. Must stay a list,
# because it is used directly as a DataFrame column selector.
selected_features = [
    # Underscore-prefixed property descriptor names.
    "_SolventAccessibilityC3",
    "_SecondaryStrC1",
    "_SecondaryStrC3",
    "_ChargeC1",
    "_PolarityC1",
    "_NormalizedVDWVC1",
    "_HydrophobicityC3",
    "_SecondaryStrT23",
    "_PolarizabilityD1001",
    "_PolarizabilityD2001",
    "_PolarizabilityD3001",
    "_SolventAccessibilityD1001",
    "_SolventAccessibilityD2001",
    "_SolventAccessibilityD3001",
    "_SecondaryStrD1001",
    "_SecondaryStrD1075",
    "_SecondaryStrD2001",
    "_SecondaryStrD3001",
    "_ChargeD1001",
    "_ChargeD1025",
    "_ChargeD2001",
    "_ChargeD3075",
    "_ChargeD3100",
    "_PolarityD1001",
    "_PolarityD1050",
    "_PolarityD2001",
    "_PolarityD3001",
    "_NormalizedVDWVD1001",
    "_NormalizedVDWVD2001",
    "_NormalizedVDWVD2025",
    "_NormalizedVDWVD2050",
    "_NormalizedVDWVD3001",
    "_HydrophobicityD1001",
    "_HydrophobicityD2001",
    "_HydrophobicityD3001",
    "_HydrophobicityD3025",
    # Single-letter residue names.
    "A", "R", "D", "C", "E", "Q",
    "H", "I", "M", "P", "Y", "V",
    # Two-letter residue-pair names.
    "AR", "AV",
    "RC", "RL", "RV",
    "CR", "CC", "CL", "CK",
    "EE", "EI", "EL",
    "HC",
    "IA", "IL", "IV",
    "LA", "LC", "LE", "LI", "LT", "LV",
    "KC",
    "MA", "MS",
    "SC",
    "TC", "TV",
    "YC",
    "VC", "VE", "VL", "VK", "VV",
    # "*Auto_*" names.
    "MoreauBrotoAuto_FreeEnergy30",
    "MoranAuto_Hydrophobicity2",
    "MoranAuto_Hydrophobicity4",
    "GearyAuto_Hydrophobicity20",
    "GearyAuto_Hydrophobicity24",
    "GearyAuto_Hydrophobicity26",
    "GearyAuto_Hydrophobicity27",
    "GearyAuto_Hydrophobicity28",
    "GearyAuto_Hydrophobicity29",
    "GearyAuto_Hydrophobicity30",
    "GearyAuto_AvFlexibility22",
    "GearyAuto_AvFlexibility26",
    "GearyAuto_AvFlexibility27",
    "GearyAuto_AvFlexibility28",
    "GearyAuto_AvFlexibility29",
    "GearyAuto_AvFlexibility30",
    "GearyAuto_Polarizability22",
    "GearyAuto_Polarizability24",
    "GearyAuto_Polarizability25",
    "GearyAuto_Polarizability27",
    "GearyAuto_Polarizability28",
    "GearyAuto_Polarizability29",
    "GearyAuto_Polarizability30",
    "GearyAuto_FreeEnergy24",
    "GearyAuto_FreeEnergy25",
    "GearyAuto_FreeEnergy30",
    "GearyAuto_ResidueASA21",
    "GearyAuto_ResidueASA22",
    "GearyAuto_ResidueASA23",
    "GearyAuto_ResidueASA24",
    "GearyAuto_ResidueASA30",
    "GearyAuto_ResidueVol21",
    "GearyAuto_ResidueVol24",
    "GearyAuto_ResidueVol25",
    "GearyAuto_ResidueVol26",
    "GearyAuto_ResidueVol28",
    "GearyAuto_ResidueVol29",
    "GearyAuto_ResidueVol30",
    "GearyAuto_Steric18",
    "GearyAuto_Steric21",
    "GearyAuto_Steric26",
    "GearyAuto_Steric27",
    "GearyAuto_Steric28",
    "GearyAuto_Steric29",
    "GearyAuto_Steric30",
    "GearyAuto_Mutability23",
    "GearyAuto_Mutability25",
    "GearyAuto_Mutability26",
    "GearyAuto_Mutability27",
    "GearyAuto_Mutability28",
    "GearyAuto_Mutability29",
    "GearyAuto_Mutability30",
    # "APAAC*" names.
    "APAAC1", "APAAC4", "APAAC5", "APAAC6",
    "APAAC8", "APAAC9", "APAAC12", "APAAC13",
    "APAAC15", "APAAC18", "APAAC19", "APAAC24",
]
|
| 49 |
|
| 50 |
# LIME Explainer Setup
|
| 51 |
sample_data = np.random.rand(100, len(selected_features))
|
|
|
|
| 56 |
mode="classification"
|
| 57 |
)
|
| 58 |
|
|
|
|
| 59 |
def extract_features(sequence):
|
|
|
|
| 60 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 61 |
if len(sequence) < 10:
|
| 62 |
return "Error: Sequence too short."
|
|
|
|
| 65 |
ctd_features = CTD.CalculateCTD(sequence)
|
| 66 |
auto_features = Autocorrelation.CalculateAutoTotal(sequence)
|
| 67 |
pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)
|
| 68 |
+
all_features_dict = {}
|
| 69 |
all_features_dict.update(ctd_features)
|
| 70 |
all_features_dict.update(filtered_dipeptide_features)
|
| 71 |
all_features_dict.update(auto_features)
|
|
|
|
| 73 |
feature_df_all = pd.DataFrame([all_features_dict])
|
| 74 |
normalized_array = scaler.transform(feature_df_all.values)
|
| 75 |
normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)
|
| 76 |
+
if not set(selected_features).issubset(set(normalized_df.columns)):
|
| 77 |
+
return "Error: Some selected features are missing from computed features."
|
| 78 |
selected_df = normalized_df[selected_features].fillna(0)
|
| 79 |
return selected_df.values
|
| 80 |
|
|
|
|
| 81 |
def predictmic(sequence):
|
| 82 |
sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
|
| 83 |
if len(sequence) < 10:
|
|
|
|
| 108 |
mic_results[bacterium] = f"Error: {str(e)}"
|
| 109 |
return mic_results
|
| 110 |
|
|
|
|
| 111 |
def full_prediction(sequence):
    """Classify a peptide sequence and return a plain-text report.

    The report holds the AMP / Non-AMP label with its confidence, the
    per-organism MIC predictions (AMP predictions only), and the ten
    highest-weighted LIME features for the predicted class.

    Args:
        sequence: Raw amino-acid sequence text as entered by the user.

    Returns:
        str: The formatted report, or the error message produced by
        ``extract_features`` when feature extraction fails.
    """
    features = extract_features(sequence)
    # extract_features signals failure by returning a message string
    # instead of a feature array; pass it straight through to the UI.
    if isinstance(features, str):
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Class 0 is reported as AMP; any other label is reported as Non-AMP.
    is_amp = prediction == 0
    # Column of predict_proba that belongs to the predicted class
    # (assumes the classifier's classes are ordered [0, 1] -- TODO confirm).
    class_index = 0 if is_amp else 1

    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    confidence = round(probabilities[class_index] * 100, 2)

    parts = [f"Prediction: {amp_result}\nConfidence: {confidence}%\n"]

    if is_amp:
        mic_values = predictmic(sequence)
        parts.append("\nPredicted MIC Values (\u00b5M):\n")
        parts.extend(f"- {org}: {mic}\n" for org, mic in mic_values.items())
    else:
        parts.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    # LIME explains label 1 by default, which is the Non-AMP class here.
    # Request the predicted class explicitly so the reported weights
    # describe the prediction shown above rather than always class 1.
    explanation = explainer.explain_instance(
        data_row=features[0],
        predict_fn=model.predict_proba,
        num_features=10,
        labels=(class_index,),
    )
    parts.append("\nTop Features Influencing Prediction:\n")
    parts.extend(
        f"- {feat}: {round(weight, 4)}\n"
        for feat, weight in explanation.as_list(label=class_index)
    )

    return "".join(parts)
|
| 136 |
|
|
|
|
| 137 |
# Gradio front end: one text box in, one text box out, with
# full_prediction producing the report shown in the output box.
iface = gr.Interface(
    full_prediction,
    gr.Textbox(label="Enter Protein Sequence"),
    gr.Textbox(label="Results"),
    title="AMP & MIC Predictor + LIME Explanation",
    description=(
        "Paste an amino acid sequence (\u226510 characters). "
        "Get AMP classification, MIC predictions, and LIME interpretability insights."
    ),
)

# Start serving the interface; share=True asks Gradio for a public share link.
iface.launch(share=True)
|