nonzeroexit commited on
Commit
59d7aab
·
verified ·
1 Parent(s): 0ff9972

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -20
app.py CHANGED
@@ -3,7 +3,6 @@ import joblib
3
  import numpy as np
4
  import pandas as pd
5
  from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
6
- from sklearn.preprocessing import MinMaxScaler
7
  import torch
8
  from transformers import BertTokenizer, BertModel
9
  from lime.lime_tabular import LimeTabularExplainer
@@ -19,7 +18,7 @@ protbert_model = BertModel.from_pretrained("Rostlab/prot_bert")
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  protbert_model = protbert_model.to(device).eval()
21
 
22
- # Define selected features (put your complete list here)
23
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
24
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
25
  "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
@@ -47,13 +46,17 @@ selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondarySt
47
  "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
48
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
49
 
50
- # Dummy data for LIME
51
- sample_data = np.random.rand(100, len(selected_features))
 
 
 
52
  explainer = LimeTabularExplainer(
53
  training_data=sample_data,
54
  feature_names=selected_features,
55
  class_names=["AMP", "Non-AMP"],
56
- mode="classification"
 
57
  )
58
 
59
  # Feature extraction function
@@ -111,11 +114,16 @@ def predictmic(sequence):
111
  mic_results = {}
112
  for bacterium, cfg in bacteria_config.items():
113
  try:
114
- scaler = joblib.load(cfg["scaler"])
115
- scaled = scaler.transform(embedding)
 
 
 
 
 
116
  transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
117
- model = joblib.load(cfg["model"])
118
- mic_log = model.predict(transformed)[0]
119
  mic = round(expm1(mic_log), 3)
120
  mic_results[bacterium] = mic
121
  except Exception as e:
@@ -141,6 +149,22 @@ def full_prediction(sequence):
141
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
142
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  if prediction == 0:
145
  mic_values = predictmic(sequence)
146
  result += "\nPredicted MIC Values (μM):\n"
@@ -149,16 +173,6 @@ def full_prediction(sequence):
149
  else:
150
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
151
 
152
- explanation = explainer.explain_instance(
153
- data_row=features[0],
154
- predict_fn=model.predict_proba,
155
- num_features=10
156
- )
157
-
158
- result += "\nTop Features Influencing Prediction:\n"
159
- for feat, weight in explanation.as_list():
160
- result += f"- {feat}: {round(weight, 4)}\n"
161
-
162
  return result
163
 
164
  # Gradio UI
@@ -170,4 +184,6 @@ iface = gr.Interface(
170
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
171
  )
172
 
173
- iface.launch(share=True)
 
 
 
3
  import numpy as np
4
  import pandas as pd
5
  from propy import AAComposition, Autocorrelation, CTD, PseudoAAC
 
6
  import torch
7
  from transformers import BertTokenizer, BertModel
8
  from lime.lime_tabular import LimeTabularExplainer
 
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  protbert_model = protbert_model.to(device).eval()
20
 
21
+ # Define selected features (146 RFE-selected features)
22
  selected_features = ["_SolventAccessibilityC3", "_SecondaryStrC1", "_SecondaryStrC3", "_ChargeC1", "_PolarityC1",
23
  "_NormalizedVDWVC1", "_HydrophobicityC3", "_SecondaryStrT23", "_PolarizabilityD1001", "_PolarizabilityD2001",
24
  "_PolarizabilityD3001", "_SolventAccessibilityD1001", "_SolventAccessibilityD2001", "_SolventAccessibilityD3001",
 
46
  "GearyAuto_Mutability30", "APAAC1", "APAAC4", "APAAC5", "APAAC6", "APAAC8", "APAAC9", "APAAC12", "APAAC13",
47
  "APAAC15", "APAAC18", "APAAC19", "APAAC24"]
48
 
49
+ # --- FIX (LIME): seed the random background so explanations are reproducible
50
+ # across Space restarts. (Loading a real saved training sample here would
51
+ # produce more faithful weights; see build_lime_background.py for that path.)
52
+ np.random.seed(42)
53
+ sample_data = np.random.rand(500, len(selected_features))
54
  explainer = LimeTabularExplainer(
55
  training_data=sample_data,
56
  feature_names=selected_features,
57
  class_names=["AMP", "Non-AMP"],
58
+ mode="classification",
59
+ random_state=42,
60
  )
61
 
62
  # Feature extraction function
 
114
  mic_results = {}
115
  for bacterium, cfg in bacteria_config.items():
116
  try:
117
+ # --- FIX (variable shadowing): renamed locals so the global `scaler`
118
+ # and `model` (the AMP RF + its MinMax scaler) are NEVER overwritten.
119
+ # The original code reused the names `scaler` and `model` here, which
120
+ # silently broke the AMP classifier on every prediction after the
121
+ # first MIC run.
122
+ mic_scaler = joblib.load(cfg["scaler"])
123
+ scaled = mic_scaler.transform(embedding)
124
  transformed = joblib.load(cfg["pca"]).transform(scaled) if cfg["pca"] else scaled
125
+ mic_model = joblib.load(cfg["model"])
126
+ mic_log = mic_model.predict(transformed)[0]
127
  mic = round(expm1(mic_log), 3)
128
  mic_results[bacterium] = mic
129
  except Exception as e:
 
149
  amp_result = "Antimicrobial Peptide (AMP)" if prediction == 0 else "Non-AMP"
150
  result = f"Prediction: {amp_result}\nConfidence: {confidence}%\n"
151
 
152
+ # --- LIME first (per spec: LIME before SHAP in the HTML report).
153
+ # explain_instance perturbs THIS single input sequence's feature row 2000
154
+ # times and fits a local linear model; weights describe this specific input.
155
+ try:
156
+ explanation = explainer.explain_instance(
157
+ data_row=features[0],
158
+ predict_fn=model.predict_proba,
159
+ num_features=10,
160
+ num_samples=2000,
161
+ )
162
+ result += "\nTop Features Influencing Prediction (LIME):\n"
163
+ for feat, weight in explanation.as_list():
164
+ result += f"- {feat}: {round(weight, 4)}\n"
165
+ except Exception as e:
166
+ result += f"\nLIME explanation failed: {str(e)}\n"
167
+
168
  if prediction == 0:
169
  mic_values = predictmic(sequence)
170
  result += "\nPredicted MIC Values (μM):\n"
 
173
  else:
174
  result += "\nMIC prediction skipped for Non-AMP sequences.\n"
175
 
 
 
 
 
 
 
 
 
 
 
176
  return result
177
 
178
  # Gradio UI
 
184
  description="Paste an amino acid sequence (≥10 characters). Get AMP classification, MIC predictions, and LIME interpretability insights."
185
  )
186
 
187
+ # --- FIX (launch): removed share=True. On Hugging Face Spaces the public URL
188
+ # is provided by the platform; share=True is for local dev only.
189
+ iface.launch()