# Hugging Face Space application file (listed size: 9,286 bytes).
# The web file viewer's page chrome, per-line commit-hash column and
# line-number gutter were removed here; they are not part of the program.
import gradio as gr
import joblib
import numpy as np
import pandas as pd
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
import torch
from transformers import BertTokenizer, BertModel
from lime.lime_tabular import LimeTabularExplainer
from math import expm1
# AMP classifier and the scaler applied to its input features. Both are
# module-level globals read by extract_features() and full_prediction().
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# ProtBert tokenizer and encoder, used by predictmic() to embed a sequence.
# The device is chosen once at import time (GPU when available, else CPU) and
# the encoder is moved there and switched to eval mode for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()
# Names of the 146 feature columns fed to the AMP classifier (described by the
# original author as RFE-selected). extract_features() slices the scaled
# feature table by these names, in this order, and the LIME explainer uses the
# same list as its feature labels, so both order and spelling matter.
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
"_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
"_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
"_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
"_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
"LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
"GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
"GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
"GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
"GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
"GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
"GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]
# LIME background data: 500 uniformly random rows, one column per selected
# feature. The global NumPy RNG is seeded first so the background (and with
# random_state=42 the explainer itself) is reproducible across Space restarts.
# A real saved training sample would give more faithful weights; see
# build_lime_background.py for that path.
np.random.seed(42)
sample_data = np.random.rand(500, len(selected_features))
explainer = LimeTabularExplainer(
    sample_data,
    mode="classification",
    feature_names=selected_features,
    # Label order follows full_prediction(), which treats class 0 as AMP.
    class_names=["AMP", "Non-AMP"],
    random_state=42,
)
# Feature extraction function
def extract_features(sequence):
    """Build the scaled model-input feature row for one peptide sequence.

    The input is upper-cased and every character that is not one of the 20
    standard amino-acid letters is dropped. propy descriptors (composition /
    dipeptide, CTD, autocorrelation, pseudo amino-acid composition with
    ``lamda=9``) are merged into one row, passed through the module-level
    ``scaler`` and reduced to the columns named in ``selected_features``.

    Args:
        sequence: Raw sequence text. ``None`` or any non-string value is
            treated as an empty sequence instead of raising.

    Returns:
        On success, a 2-D array with a single row and one column per entry of
        ``selected_features`` (NaNs replaced by 0). On failure, a ``str``
        starting with ``"Error"``; callers tell the two apart with
        ``isinstance(result, str)``.
    """
    # Guard against None / non-string input so the caller gets the regular
    # "too short" message rather than an AttributeError from .upper().
    raw = sequence.upper() if isinstance(sequence, str) else ""
    sequence = "".join(aa for aa in raw if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return "Error: Sequence too short."

    try:
        dipeptide_features = AAComposition.CalculateAADipeptideComposition(sequence)
        # Only the first 420 entries (dict insertion order) are kept --
        # presumably the 20 single-residue + 400 dipeptide compositions;
        # TODO confirm against the propy version in use.
        filtered_dipeptide_features = {
            key: dipeptide_features[key] for key in list(dipeptide_features)[:420]
        }
        ctd_features = CTD.CalculateCTD(sequence)
        auto_features = Autocorrelation.CalculateAutoTotal(sequence)
        pseudo_features = PseudoAAC.GetAPseudoAAC(sequence, lamda=9)

        # Merge order fixes the column order handed to scaler.transform(),
        # so it must stay CTD -> dipeptide -> autocorrelation -> pseudo-AAC.
        all_features_dict = {
            **ctd_features,
            **filtered_dipeptide_features,
            **auto_features,
            **pseudo_features,
        }
        feature_df_all = pd.DataFrame([all_features_dict])

        normalized_array = scaler.transform(feature_df_all.values)
        normalized_df = pd.DataFrame(normalized_array, columns=feature_df_all.columns)

        if not set(selected_features).issubset(normalized_df.columns):
            return "Error: Some selected features are missing."

        return normalized_df[selected_features].fillna(0).values
    except Exception as e:
        # Top-level boundary for the UI: any descriptor or scaler failure is
        # reported to the user as text instead of crashing the request.
        return f"Error in feature extraction: {str(e)}"
# MIC prediction function

# Per-organism artifacts: regression model, input scaler and optional PCA.
_MIC_BACTERIA_CONFIG = {
    "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
    "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
    "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
    "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"},
}

# Artifacts already read from disk, keyed by file path. Only successful loads
# are stored, so a missing file is retried (and reported) on the next call.
_MIC_ARTIFACT_CACHE = {}


def _load_mic_artifact(path):
    """Return the joblib artifact stored at *path*, reading it at most once."""
    try:
        return _MIC_ARTIFACT_CACHE[path]
    except KeyError:
        artifact = joblib.load(path)
        _MIC_ARTIFACT_CACHE[path] = artifact
        return artifact


def predictmic(sequence):
    """Predict a MIC value per organism from a ProtBert embedding.

    The sequence is cleaned the same way as in extract_features(), embedded
    with ProtBert (mean of the last hidden state over the token axis) and
    fed to each organism's scaler -> optional PCA -> model pipeline. The
    model output is passed through ``expm1`` and rounded to 3 decimals.

    Args:
        sequence: Raw sequence text. ``None`` or any non-string value is
            treated as an empty sequence instead of raising.

    Returns:
        ``{"Error": "Sequence too short or invalid."}`` when fewer than 10
        valid residues remain; otherwise a dict mapping each organism label
        to its predicted MIC, or to an ``"Error: ..."`` string when that
        organism's pipeline failed.
    """
    raw = sequence.upper() if isinstance(sequence, str) else ""
    sequence = "".join(aa for aa in raw if aa in "ACDEFGHIKLMNPQRSTVWY")
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    # The tokenizer is given one residue per whitespace-separated token.
    seq_spaced = " ".join(sequence)
    encoded = tokenizer(
        seq_spaced,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        outputs = protbert_model(**encoded)
    # NOTE(review): the mean runs over all 512 positions, padding included
    # (no attention-mask weighting). Presumably this matches how the training
    # embeddings were produced -- confirm before changing it.
    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    mic_results = {}
    for bacterium, cfg in _MIC_BACTERIA_CONFIG.items():
        try:
            # Local names are kept distinct from the module-level `scaler`
            # and `model`, which belong to the AMP classifier.
            mic_scaler = _load_mic_artifact(cfg["scaler"])
            transformed = mic_scaler.transform(embedding)
            if cfg["pca"]:
                transformed = _load_mic_artifact(cfg["pca"]).transform(transformed)
            mic_model = _load_mic_artifact(cfg["model"])
            mic_log = mic_model.predict(transformed)[0]
            mic_results[bacterium] = round(expm1(mic_log), 3)
        except Exception as e:
            # Best effort per organism: one broken pipeline must not hide the
            # results of the others.
            mic_results[bacterium] = f"Error: {str(e)}"
    return mic_results
# Main prediction function
def full_prediction(sequence):
    """Classify a sequence, explain the result with LIME and add MIC values.

    Args:
        sequence: Raw sequence text from the UI.

    Returns:
        A plain-text report. If feature extraction fails, the error string
        from extract_features() is returned unchanged. Otherwise the report
        holds the predicted class and confidence, the top 10 LIME feature
        weights (or a failure note), and either the per-organism MIC values
        (class 0, i.e. AMP) or a note that MIC prediction was skipped.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]
    try:
        class_index = list(model.classes_).index(prediction)
        # The percent sign belongs to the number only; the fallback below
        # must not be rendered as "Unknown%".
        confidence = f"{round(probabilities[class_index] * 100, 2)}%"
    except Exception:
        confidence = "Unknown"

    is_amp = prediction == 0
    amp_result = "Antimicrobial Peptide (AMP)" if is_amp else "Non-AMP"
    report = [f"Prediction: {amp_result}\nConfidence: {confidence}\n"]

    # LIME perturbs this single feature row 2000 times and fits a local
    # linear model, so the weights describe this specific input only.
    try:
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=model.predict_proba,
            num_features=10,
            num_samples=2000,
        )
        report.append("\nTop Features Influencing Prediction (LIME):\n")
        for feat, weight in explanation.as_list():
            report.append(f"- {feat}: {round(weight, 4)}\n")
    except Exception as e:
        # The explanation is optional; the classification is still reported.
        report.append(f"\nLIME explanation failed: {str(e)}\n")

    if is_amp:
        mic_values = predictmic(sequence)
        report.append("\nPredicted MIC Values (μM):\n")
        for org, mic in mic_values.items():
            report.append(f"- {org}: {mic}\n")
    else:
        report.append("\nMIC prediction skipped for Non-AMP sequences.\n")

    return "".join(report)
# Gradio UI: one text box in, one text box out, backed by full_prediction().
iface = gr.Interface(
    fn=full_prediction,
    inputs=gr.Textbox(label="Enter Protein Sequence"),
    outputs=gr.Textbox(label="Results"),
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)
# --- FIX (launch): removed share=True. On Hugging Face Spaces the public URL
# is provided by the platform; share=True is for local dev only.
iface.launch()