Ym420 commited on
Commit
4e7dcad
·
verified ·
1 Parent(s): d8d7b6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -13
app.py CHANGED
@@ -7,10 +7,16 @@ from collections import Counter
7
 
8
  # --- Download model from HF Hub ---
9
  repo_id = "Ym420/Peptide-Function" # replace with your HF repo
10
- model_filename = "xgb_multilabel_model_full_extendedFuture.pkl"
11
 
12
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
13
  model_package = joblib.load(model_path)
 
 
 
 
 
 
14
  model = model_package['model']
15
  feature_columns = model_package['feature_columns']
16
 
@@ -34,12 +40,12 @@ TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
34
  def extract_features_app(seq: str) -> pd.DataFrame:
35
  seq = seq.upper()
36
 
37
- # --- 1. Dipeptide composition ---
38
  count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
39
  total = max(len(seq)-1, 1)
40
  dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
41
 
42
- # --- 2. Physicochemical features ---
43
  def g(aa, table): return table.get(aa, 0)
44
  def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
45
 
@@ -77,18 +83,20 @@ def predict_peptide(sequence: str):
77
  return []
78
 
79
  X = extract_features_app(seq)
80
-
81
- #probs_list = model.predict_proba(X) # list of arrays per target
82
- # --- probs_list: list of arrays from each estimator ---
83
- probs_list = [est.predict_proba(X) for est in model] # model is a list
 
 
84
 
85
  table = []
86
  for i, target in enumerate(TARGET_CELLS):
87
  prob = float(probs_list[i][0][1])
88
  table.append([target, round(prob, 4)])
 
89
  return table
90
 
91
-
92
  # --- Gradio Interface ---
93
  custom_css = """
94
  footer, .footer {display:none !important;}
@@ -112,11 +120,9 @@ with gr.Blocks(css=custom_css, theme="default") as demo:
112
  predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
113
  clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
114
 
115
- # API endpoint for iOS app
116
- gr.api(predict_peptide, api_name="predict_peptide")
117
-
118
- #if __name__ == "__main__":
119
- # demo.launch()
120
 
121
  if __name__ == "__main__":
122
  demo.launch(show_error=True)
 
7
 
8
  # --- Download model from HF Hub ---
9
  repo_id = "Ym420/Peptide-Function" # replace with your HF repo
10
+ model_filename = "xgb_multilabel_model_full.pkl"
11
 
12
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
13
  model_package = joblib.load(model_path)
14
+
15
+ # --- Debug: Check what we loaded ---
16
+ print("Loaded model_package type:", type(model_package))
17
+ print("model_package keys:", list(model_package.keys()))
18
+ print("Type of model:", type(model_package['model']))
19
+
20
  model = model_package['model']
21
  feature_columns = model_package['feature_columns']
22
 
 
40
  def extract_features_app(seq: str) -> pd.DataFrame:
41
  seq = seq.upper()
42
 
43
+ # 1. Dipeptide composition
44
  count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
45
  total = max(len(seq)-1, 1)
46
  dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
47
 
48
+ # 2. Physicochemical features
49
  def g(aa, table): return table.get(aa, 0)
50
  def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
51
 
 
83
  return []
84
 
85
  X = extract_features_app(seq)
86
+
87
+ # --- Handle model being a list of estimators ---
88
+ if isinstance(model, list):
89
+ probs_list = [est.predict_proba(X) for est in model]
90
+ else:
91
+ probs_list = [model.predict_proba(X)] # single model
92
 
93
  table = []
94
  for i, target in enumerate(TARGET_CELLS):
95
  prob = float(probs_list[i][0][1])
96
  table.append([target, round(prob, 4)])
97
+
98
  return table
99
 
 
100
  # --- Gradio Interface ---
101
  custom_css = """
102
  footer, .footer {display:none !important;}
 
120
  predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
121
  clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
122
 
123
+ # Optional API endpoint for iOS app
124
+ # Note: use only if Gradio version supports `api`
125
+ # gr.api(predict_peptide, api_name="predict_peptide")
 
 
126
 
127
  if __name__ == "__main__":
128
  demo.launch(show_error=True)