Ym420 commited on
Commit
0fe07d0
·
verified ·
1 Parent(s): d8b9b3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -6,12 +6,14 @@ import numpy as np
6
  from collections import Counter
7
 
8
  # --- Download model from HF Hub ---
9
- repo_id = "Ym420/Peptide-Function" # replace with your HF repo
10
  model_filename = "xgb_multilabel_model_full.pkl"
11
 
12
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
13
  model_package = joblib.load(model_path)
14
- model = model_package['model']
 
 
15
  feature_columns = model_package['feature_columns']
16
 
17
  # --- Metadata ---
@@ -34,12 +36,12 @@ TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
34
  def extract_features_app(seq: str) -> pd.DataFrame:
35
  seq = seq.upper()
36
 
37
- # --- 1. Dipeptide composition ---
38
  count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
39
  total = max(len(seq)-1, 1)
40
  dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
41
 
42
- # --- 2. Physicochemical features ---
43
  def g(aa, table): return table.get(aa, 0)
44
  def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
45
 
@@ -67,7 +69,7 @@ def extract_features_app(seq: str) -> pd.DataFrame:
67
  features = dipep_features + physchem_features
68
 
69
  df = pd.DataFrame([features], columns=feature_columns)
70
- df = df.astype('float32') # ensure same type as training
71
  return df
72
 
73
  # --- Prediction function ---
@@ -78,16 +80,17 @@ def predict_peptide(sequence: str):
78
 
79
  X = extract_features_app(seq)
80
 
81
- probs_list = model.predict_proba(X) # list of arrays per target
82
- # --- probs_list: list of arrays from each estimator ---
83
- #probs_list = [est.predict_proba(X) for est in model] # model is a list
84
-
85
  table = []
86
- for i, target in enumerate(TARGET_CELLS):
87
- prob = float(probs_list[i][0][1])
88
- table.append([target, round(prob, 4)])
89
- return table
 
 
 
 
90
 
 
91
 
92
  # --- Gradio Interface ---
93
  custom_css = """
@@ -115,8 +118,5 @@ with gr.Blocks(css=custom_css, theme="default") as demo:
115
  # API endpoint for iOS app
116
  gr.api(predict_peptide, api_name="predict_peptide")
117
 
118
- #if __name__ == "__main__":
119
- # demo.launch()
120
-
121
  if __name__ == "__main__":
122
  demo.launch(show_error=True)
 
6
  from collections import Counter
7
 
8
  # --- Download model from HF Hub ---
9
+ repo_id = "Ym420/Peptide-Function"
10
  model_filename = "xgb_multilabel_model_full.pkl"
11
 
12
  model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
13
  model_package = joblib.load(model_path)
14
+
15
+ # --- Unwrap model dict ---
16
+ model_dict = model_package['model'] # dict: {'Gram+': XGBClassifier, ...}
17
  feature_columns = model_package['feature_columns']
18
 
19
  # --- Metadata ---
 
36
  def extract_features_app(seq: str) -> pd.DataFrame:
37
  seq = seq.upper()
38
 
39
+ # Dipeptide composition
40
  count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
41
  total = max(len(seq)-1, 1)
42
  dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
43
 
44
+ # Physicochemical features
45
  def g(aa, table): return table.get(aa, 0)
46
  def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
47
 
 
69
  features = dipep_features + physchem_features
70
 
71
  df = pd.DataFrame([features], columns=feature_columns)
72
+ df = df.astype('float32')
73
  return df
74
 
75
  # --- Prediction function ---
 
80
 
81
  X = extract_features_app(seq)
82
 
 
 
 
 
83
  table = []
84
+ # Iterate over each target classifier
85
+ for target in TARGET_CELLS:
86
+ clf = model_dict.get(target)
87
+ if clf is not None:
88
+ prob = clf.predict_proba(X)[0][1] # positive-class probability (0-1)
89
+ table.append([target, round(float(prob), 4)])
90
+ else:
91
+ table.append([target, None])
92
 
93
+ return table
94
 
95
  # --- Gradio Interface ---
96
  custom_css = """
 
118
  # API endpoint for iOS app
119
  gr.api(predict_peptide, api_name="predict_peptide")
120
 
 
 
 
121
  if __name__ == "__main__":
122
  demo.launch(show_error=True)