File size: 9,286 Bytes
0ff9972
942bf87
51a3749
ea9a1bf
4222f98
0ff9972
 
1dcb272
0ff9972
bd01e5d
0ff9972
bd01e5d
 
f0f9b27
0ff9972
 
 
 
 
 
59d7aab
8a9cc7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59d7aab
 
 
 
 
bd01e5d
0ff9972
1dcb272
bd01e5d
59d7aab
 
1dcb272
 
0ff9972
44f5cf9
63d3a19
 
 
 
 
1dcb272
bd01e5d
 
 
 
1dcb272
4222f98
 
bd01e5d
4222f98
 
03f381c
4222f98
bd01e5d
 
03f381c
bd01e5d
0ff9972
63d3a19
bd01e5d
 
63d3a19
 
 
0ff9972
44f5cf9
63d3a19
 
 
 
0ff9972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59d7aab
 
 
 
 
 
 
0ff9972
59d7aab
 
0ff9972
 
 
 
 
 
 
 
44f5cf9
63d3a19
 
 
1dcb272
bd01e5d
 
1dcb272
bd01e5d
 
 
 
 
b206439
1dcb272
bd01e5d
63d3a19
59d7aab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5745f40
63d3a19
 
 
 
 
 
 
0ff9972
2739a59
bd01e5d
44f5cf9
63d3a19
 
 
 
 
44f5cf9
68ded6f
59d7aab
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from functools import lru_cache
from math import expm1

import gradio as gr
import joblib
import numpy as np
import pandas as pd
import torch
from lime.lime_tabular import LimeTabularExplainer
from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
from transformers import BertTokenizer, BertModel

# Artifacts for the AMP / Non-AMP classifier: the random-forest model and the
# scaler that extract_features applies to the raw propy descriptors.
model = joblib.load("RF.joblib")
scaler = joblib.load("norm (4).joblib")

# ProtBert encoder used by predictmic to embed sequences for the MIC
# regressors. Runs on CUDA when available, otherwise CPU, in eval mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
protbert_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

# Define selected features (146 RFE-selected features)
# Names of the descriptor columns kept after feature selection. Order matters:
# extract_features indexes the scaled descriptor frame with this list, so it
# fixes the column order of the row passed to `model` (presumably the same
# order used at training time — keep in sync with the training pipeline).
# The same list also supplies LIME's feature_names.
# By name, entries appear to come from propy's CTD ("_...C/T/D..."),
# amino-acid / dipeptide composition ("A", "AR", ...), autocorrelation
# ("MoreauBrotoAuto_", "MoranAuto_", "GearyAuto_") and GetAPseudoAAC
# ("APAAC...") outputs, the four groups computed in extract_features.
selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
"_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
"_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
"_SecondaryStrD1001", "_SecondaryStrD1075", "_SecondaryStrD2001", "_SecondaryStrD3001", "_ChargeD1001",
"_ChargeD1025", "_ChargeD2001", "_ChargeD3075", "_ChargeD3100", "_PolarityD1001", "_PolarityD1050",
"_PolarityD2001", "_PolarityD3001", "_NormalizedVDWVD1001", "_NormalizedVDWVD2001", "_NormalizedVDWVD2025",
"_NormalizedVDWVD2050", "_NormalizedVDWVD3001", "_HydrophobicityD1001", "_HydrophobicityD2001",
"_HydrophobicityD3001", "_HydrophobicityD3025", "A", "R", "D", "C", "E", "Q", "H", "I", "M", "P", "Y", "V",
"AR", "AV", "RC", "RL", "RV", "CR", "CC", "CL", "CK", "EE", "EI", "EL", "HC", "IA", "IL", "IV", "LA", "LC", "LE",
"LI", "LT", "LV", "KC", "MA", "MS", "SC", "TC", "TV", "YC", "VC", "VE", "VL", "VK", "VV",
"MoreauBrotoAuto_FreeEnergy30", "MoranAuto_Hydrophobicity2", "MoranAuto_Hydrophobicity4",
"GearyAuto_Hydrophobicity20", "GearyAuto_Hydrophobicity24", "GearyAuto_Hydrophobicity26",
"GearyAuto_Hydrophobicity27", "GearyAuto_Hydrophobicity28", "GearyAuto_Hydrophobicity29",
"GearyAuto_Hydrophobicity30", "GearyAuto_AvFlexibility22", "GearyAuto_AvFlexibility26",
"GearyAuto_AvFlexibility27", "GearyAuto_AvFlexibility28", "GearyAuto_AvFlexibility29", "GearyAuto_AvFlexibility30",
"GearyAuto_Polarizability22", "GearyAuto_Polarizability24", "GearyAuto_Polarizability25",
"GearyAuto_Polarizability27", "GearyAuto_Polarizability28", "GearyAuto_Polarizability29",
"GearyAuto_Polarizability30", "GearyAuto_FreeEnergy24", "GearyAuto_FreeEnergy25", "GearyAuto_FreeEnergy30",
"GearyAuto_ResidueASA21", "GearyAuto_ResidueASA22", "GearyAuto_ResidueASA23", "GearyAuto_ResidueASA24",
"GearyAuto_ResidueASA30", "GearyAuto_ResidueVol21", "GearyAuto_ResidueVol24", "GearyAuto_ResidueVol25",
"GearyAuto_ResidueVol26", "GearyAuto_ResidueVol28", "GearyAuto_ResidueVol29", "GearyAuto_ResidueVol30",
"GearyAuto_Steric18", "GearyAuto_Steric21", "GearyAuto_Steric26", "GearyAuto_Steric27", "GearyAuto_Steric28",
"GearyAuto_Steric29", "GearyAuto_Steric30", "GearyAuto_Mutability23", "GearyAuto_Mutability25",
"GearyAuto_Mutability26", "GearyAuto_Mutability27", "GearyAuto_Mutability28", "GearyAuto_Mutability29",
"GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
"APAAC15", "APAAC18", "APAAC19", "APAAC24"]

# LIME explainer over the selected feature space.
# The background sample is synthetic: 500 rows of uniform [0, 1) values, one
# column per selected feature. Loading a real saved training sample here would
# produce more faithful weights (the original notes pointed to a
# build_lime_background.py script for that — verify it exists).
#
# A private RandomState(42) yields exactly the values that
# `np.random.seed(42); np.random.rand(...)` produced, but without reseeding
# NumPy's global RNG as an import-time side effect on the rest of the process.
_lime_background_rng = np.random.RandomState(42)
sample_data = _lime_background_rng.rand(500, len(selected_features))

# random_state=42 makes LIME's perturbation sampling deterministic for a given
# sequence of requests after a restart. NOTE(review): the explainer's RNG
# presumably advances on every call, so repeating the same input within one
# session may give slightly different weights.
explainer = LimeTabularExplainer(
    training_data=sample_data,
    feature_names=selected_features,
    class_names=["AMP", "Non-AMP"],
    mode="classification",
    random_state=42,
)

# Feature extraction function
def extract_features(sequence):
    """Turn a raw sequence into the scaled, feature-selected row consumed by ``model``.

    The input is upper-cased and stripped of every character that is not one
    of the 20 standard amino-acid letters. propy descriptors are computed
    (dipeptide composition, CTD, autocorrelation, ``GetAPseudoAAC`` with
    lamda=9), merged into one record, scaled with the module-level ``scaler``
    and reduced to ``selected_features``; missing values become 0.

    Args:
        sequence: Raw amino-acid string.

    Returns:
        A 2-D NumPy array (one row, ``len(selected_features)`` columns) on
        success, otherwise an error-message string beginning with ``"Error"``.
    """
    allowed = "ACDEFGHIKLMNPQRSTVWY"
    cleaned = "".join(residue for residue in sequence.upper() if residue in allowed)
    if len(cleaned) < 10:
        return "Error: Sequence too short."

    try:
        dipeptides = AAComposition.CalculateAADipeptideComposition(cleaned)
        # Keep (at most) the first 420 entries, in the order propy returned them.
        dipeptides = {key: dipeptides[key] for key in list(dipeptides)[:420]}
        ctd = CTD.CalculateCTD(cleaned)
        autocorrelation = Autocorrelation.CalculateAutoTotal(cleaned)
        pseudo_aac = PseudoAAC.GetAPseudoAAC(cleaned, lamda=9)

        # Later groups win on duplicate keys, exactly like successive update() calls;
        # the merge order also fixes the column order fed to the scaler.
        merged = {**ctd, **dipeptides, **autocorrelation, **pseudo_aac}
        raw = pd.DataFrame([merged])
        scaled = pd.DataFrame(scaler.transform(raw.values), columns=raw.columns)

        if set(selected_features) - set(scaled.columns):
            return "Error: Some selected features are missing."

        return scaled[selected_features].fillna(0).values
    except Exception as e:
        # Best-effort: surface any descriptor/scaling failure as a UI message.
        return f"Error in feature extraction: {str(e)}"

# MIC prediction function
def predictmic(sequence):
    sequence = ''.join([aa for aa in sequence.upper() if aa in "ACDEFGHIKLMNPQRSTVWY"])
    if len(sequence) < 10:
        return {"Error": "Sequence too short or invalid."}

    seq_spaced = ' '.join(list(sequence))
    tokens = tokenizer(seq_spaced, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    tokens = {k: v.to(device) for k, v in tokens.items()}

    with torch.no_grad():
        outputs = protbert_model(**tokens)
        embedding = outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy().reshape(1, -1)

    bacteria_config = {
        "E.coli": {"model": "coli_xgboost_model.pkl", "scaler": "coli_scaler.pkl", "pca": None},
        "S.aureus": {"model": "aur_xgboost_model.pkl", "scaler": "aur_scaler.pkl", "pca": None},
        "P.aeruginosa": {"model": "arg_xgboost_model.pkl", "scaler": "arg_scaler.pkl", "pca": None},
        "K.Pneumonia": {"model": "pne_mlp_model.pkl", "scaler": "pne_scaler.pkl", "pca": "pne_pca.pkl"}
    }

    mic_results = {}
    for bacterium, cfg in bacteria_config.items():
        try:
            # --- FIX (variable shadowing): renamed locals so the global `scaler`
            # and `model` (the AMP RF + its MinMax scaler) are NEVER overwritten.
            # The original code reused the names `scaler` and `model` here, which
            # silently broke the AMP classifier on every prediction after the
            # first MIC run.
            mic_scaler = joblib.load(cfg["scaler"])
            scaled = mic_scaler.transform(embedding)
            transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
            mic_model = joblib.load(cfg["model"])
            mic_log = mic_model.predict(transformed)[0]
            mic = round(expm1(mic_log), 3)
            mic_results[bacterium] = mic
        except Exception as e:
            mic_results[bacterium] = f"Error: {str(e)}"

    return mic_results

# Main prediction function
def full_prediction(sequence):
    """Build the text report shown in the Gradio UI for one input sequence.

    Steps:
    1. Extract and scale descriptors.
    2. Classify AMP vs Non-AMP with the RF ``model``.
    3. Explain the *predicted* class with LIME.
    4. For AMPs only, append per-organism MIC predictions from ``predictmic``.

    Args:
        sequence: Raw amino-acid string entered by the user.

    Returns:
        The multi-line report string, or the error string returned by
        ``extract_features`` when feature extraction fails.
    """
    features = extract_features(sequence)
    if isinstance(features, str):
        # extract_features reports failures as an error-message string.
        return features

    prediction = model.predict(features)[0]
    probabilities = model.predict_proba(features)[0]

    # Position of the predicted class in model.classes_. This is also its
    # column in predict_proba output, and therefore the LIME label to explain.
    try:
        class_index = list(model.classes_).index(prediction)
        confidence_text = f"{round(probabilities[class_index] * 100, 2)}%"
    except (AttributeError, IndexError, ValueError):
        class_index = None
        confidence_text = "Unknown"

    # Class 0 is treated as AMP (consistent with LIME's class_names order).
    amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
    result = f"Prediction: {amp_result}\nConfidence: {confidence_text}\n"

    # LIME perturbs this input's feature row num_samples times and fits a
    # local linear model per requested label. Left to its defaults, LIME
    # explains only label 1 ("Non-AMP"), so AMP predictions used to show
    # evidence for the opposite class. Explain the predicted class instead;
    # fall back to LIME's default label 1 if the index is unknown.
    lime_label = 1 if class_index is None else class_index
    try:
        explanation = explainer.explain_instance(
            data_row=features[0],
            predict_fn=model.predict_proba,
            labels=(lime_label,),
            num_features=10,
            num_samples=2000,
        )
        result += "\nTop Features Influencing Prediction (LIME):\n"
        for feat, weight in explanation.as_list(label=lime_label):
            result += f"- {feat}: {round(weight, 4)}\n"
    except Exception as e:
        # Best-effort: a LIME failure must not hide the prediction itself.
        result += f"\nLIME explanation failed: {str(e)}\n"

    if prediction == 0:
        mic_values = predictmic(sequence)
        result += "\nPredicted MIC Values (μM):\n"
        for org, mic in mic_values.items():
            result += f"- {org}: {mic}\n"
    else:
        result += "\nMIC prediction skipped for Non-AMP sequences.\n"

    return result

# Gradio UI: a single free-text input and a single text output that renders
# the report string returned by full_prediction.
sequence_box = gr.Textbox(label="Enter Protein Sequence")
report_box = gr.Textbox(label="Results")

iface = gr.Interface(
    fn=full_prediction,
    inputs=sequence_box,
    outputs=report_box,
    title="AMP & MIC Predictor + LIME Explanation",
    description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights.",
)

# No share=True: on Hugging Face Spaces the platform serves the public URL;
# a gradio share link is only useful when running locally.
iface.launch()