Update app.py
Browse files
app.py
CHANGED
|
@@ -7,10 +7,16 @@ from collections import Counter
|
|
| 7 |
|
| 8 |
# --- Download model from HF Hub ---
|
| 9 |
repo_id = "Ym420/Peptide-Function" # replace with your HF repo
|
| 10 |
-
model_filename = "
|
| 11 |
|
| 12 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
| 13 |
model_package = joblib.load(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
model = model_package['model']
|
| 15 |
feature_columns = model_package['feature_columns']
|
| 16 |
|
|
@@ -34,12 +40,12 @@ TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
|
|
| 34 |
def extract_features_app(seq: str) -> pd.DataFrame:
|
| 35 |
seq = seq.upper()
|
| 36 |
|
| 37 |
-
#
|
| 38 |
count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
|
| 39 |
total = max(len(seq)-1, 1)
|
| 40 |
dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
|
| 41 |
|
| 42 |
-
#
|
| 43 |
def g(aa, table): return table.get(aa, 0)
|
| 44 |
def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
|
| 45 |
|
|
@@ -77,18 +83,20 @@ def predict_peptide(sequence: str):
|
|
| 77 |
return []
|
| 78 |
|
| 79 |
X = extract_features_app(seq)
|
| 80 |
-
|
| 81 |
-
#
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
| 84 |
|
| 85 |
table = []
|
| 86 |
for i, target in enumerate(TARGET_CELLS):
|
| 87 |
prob = float(probs_list[i][0][1])
|
| 88 |
table.append([target, round(prob, 4)])
|
|
|
|
| 89 |
return table
|
| 90 |
|
| 91 |
-
|
| 92 |
# --- Gradio Interface ---
|
| 93 |
custom_css = """
|
| 94 |
footer, .footer {display:none !important;}
|
|
@@ -112,11 +120,9 @@ with gr.Blocks(css=custom_css, theme="default") as demo:
|
|
| 112 |
predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
|
| 113 |
clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
|
| 114 |
|
| 115 |
-
# API endpoint for iOS app
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
#if __name__ == "__main__":
|
| 119 |
-
# demo.launch()
|
| 120 |
|
| 121 |
if __name__ == "__main__":
|
| 122 |
demo.launch(show_error=True)
|
|
|
|
| 7 |
|
| 8 |
# --- Download model from HF Hub ---
|
| 9 |
repo_id = "Ym420/Peptide-Function" # replace with your HF repo
|
| 10 |
+
model_filename = "xgb_multilabel_model_full.pkl"
|
| 11 |
|
| 12 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
| 13 |
model_package = joblib.load(model_path)
|
| 14 |
+
|
| 15 |
+
# --- Debug: Check what we loaded ---
|
| 16 |
+
print("Loaded model_package type:", type(model_package))
|
| 17 |
+
print("model_package keys:", list(model_package.keys()))
|
| 18 |
+
print("Type of model:", type(model_package['model']))
|
| 19 |
+
|
| 20 |
model = model_package['model']
|
| 21 |
feature_columns = model_package['feature_columns']
|
| 22 |
|
|
|
|
| 40 |
def extract_features_app(seq: str) -> pd.DataFrame:
|
| 41 |
seq = seq.upper()
|
| 42 |
|
| 43 |
+
# 1. Dipeptide composition
|
| 44 |
count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
|
| 45 |
total = max(len(seq)-1, 1)
|
| 46 |
dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
|
| 47 |
|
| 48 |
+
# 2. Physicochemical features
|
| 49 |
def g(aa, table): return table.get(aa, 0)
|
| 50 |
def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
|
| 51 |
|
|
|
|
| 83 |
return []
|
| 84 |
|
| 85 |
X = extract_features_app(seq)
|
| 86 |
+
|
| 87 |
+
# --- Handle model being a list of estimators ---
|
| 88 |
+
if isinstance(model, list):
|
| 89 |
+
probs_list = [est.predict_proba(X) for est in model]
|
| 90 |
+
else:
|
| 91 |
+
probs_list = [model.predict_proba(X)] # single model
|
| 92 |
|
| 93 |
table = []
|
| 94 |
for i, target in enumerate(TARGET_CELLS):
|
| 95 |
prob = float(probs_list[i][0][1])
|
| 96 |
table.append([target, round(prob, 4)])
|
| 97 |
+
|
| 98 |
return table
|
| 99 |
|
|
|
|
| 100 |
# --- Gradio Interface ---
|
| 101 |
custom_css = """
|
| 102 |
footer, .footer {display:none !important;}
|
|
|
|
| 120 |
predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
|
| 121 |
clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
|
| 122 |
|
| 123 |
+
# Optional API endpoint for iOS app
|
| 124 |
+
# Note: use only if Gradio version supports `api`
|
| 125 |
+
# gr.api(predict_peptide, api_name="predict_peptide")
|
|
|
|
|
|
|
| 126 |
|
| 127 |
if __name__ == "__main__":
|
| 128 |
demo.launch(show_error=True)
|