Update app.py
Browse files
app.py
CHANGED
|
@@ -6,12 +6,14 @@ import numpy as np
|
|
| 6 |
from collections import Counter
|
| 7 |
|
| 8 |
# --- Download model from HF Hub ---
|
| 9 |
-
repo_id = "Ym420/Peptide-Function"
|
| 10 |
model_filename = "xgb_multilabel_model_full.pkl"
|
| 11 |
|
| 12 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
| 13 |
model_package = joblib.load(model_path)
|
| 14 |
-
|
|
|
|
|
|
|
| 15 |
feature_columns = model_package['feature_columns']
|
| 16 |
|
| 17 |
# --- Metadata ---
|
|
@@ -34,12 +36,12 @@ TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
|
|
| 34 |
def extract_features_app(seq: str) -> pd.DataFrame:
|
| 35 |
seq = seq.upper()
|
| 36 |
|
| 37 |
-
#
|
| 38 |
count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
|
| 39 |
total = max(len(seq)-1, 1)
|
| 40 |
dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
|
| 41 |
|
| 42 |
-
#
|
| 43 |
def g(aa, table): return table.get(aa, 0)
|
| 44 |
def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
|
| 45 |
|
|
@@ -67,7 +69,7 @@ def extract_features_app(seq: str) -> pd.DataFrame:
|
|
| 67 |
features = dipep_features + physchem_features
|
| 68 |
|
| 69 |
df = pd.DataFrame([features], columns=feature_columns)
|
| 70 |
-
df = df.astype('float32')
|
| 71 |
return df
|
| 72 |
|
| 73 |
# --- Prediction function ---
|
|
@@ -78,16 +80,17 @@ def predict_peptide(sequence: str):
|
|
| 78 |
|
| 79 |
X = extract_features_app(seq)
|
| 80 |
|
| 81 |
-
probs_list = model.predict_proba(X) # list of arrays per target
|
| 82 |
-
# --- probs_list: list of arrays from each estimator ---
|
| 83 |
-
#probs_list = [est.predict_proba(X) for est in model] # model is a list
|
| 84 |
-
|
| 85 |
table = []
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
|
|
|
| 91 |
|
| 92 |
# --- Gradio Interface ---
|
| 93 |
custom_css = """
|
|
@@ -115,8 +118,5 @@ with gr.Blocks(css=custom_css, theme="default") as demo:
|
|
| 115 |
# API endpoint for iOS app
|
| 116 |
gr.api(predict_peptide, api_name="predict_peptide")
|
| 117 |
|
| 118 |
-
#if __name__ == "__main__":
|
| 119 |
-
# demo.launch()
|
| 120 |
-
|
| 121 |
if __name__ == "__main__":
|
| 122 |
demo.launch(show_error=True)
|
|
|
|
| 6 |
from collections import Counter
|
| 7 |
|
| 8 |
# --- Download model from HF Hub ---
|
| 9 |
+
repo_id = "Ym420/Peptide-Function"
|
| 10 |
model_filename = "xgb_multilabel_model_full.pkl"
|
| 11 |
|
| 12 |
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
|
| 13 |
model_package = joblib.load(model_path)
|
| 14 |
+
|
| 15 |
+
# --- Unwrap model dict ---
|
| 16 |
+
model_dict = model_package['model'] # dict: {'Gram+': XGBClassifier, ...}
|
| 17 |
feature_columns = model_package['feature_columns']
|
| 18 |
|
| 19 |
# --- Metadata ---
|
|
|
|
| 36 |
def extract_features_app(seq: str) -> pd.DataFrame:
|
| 37 |
seq = seq.upper()
|
| 38 |
|
| 39 |
+
# Dipeptide composition
|
| 40 |
count = Counter([seq[i:i+2] for i in range(len(seq)-1)])
|
| 41 |
total = max(len(seq)-1, 1)
|
| 42 |
dipep_features = [count.get(dp, 0) / total for dp in dipeptides]
|
| 43 |
|
| 44 |
+
# Physicochemical features
|
| 45 |
def g(aa, table): return table.get(aa, 0)
|
| 46 |
def h(dp, table): return (g(dp[0], table) + g(dp[1], table)) / 2.0
|
| 47 |
|
|
|
|
| 69 |
features = dipep_features + physchem_features
|
| 70 |
|
| 71 |
df = pd.DataFrame([features], columns=feature_columns)
|
| 72 |
+
df = df.astype('float32')
|
| 73 |
return df
|
| 74 |
|
| 75 |
# --- Prediction function ---
|
|
|
|
| 80 |
|
| 81 |
X = extract_features_app(seq)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
table = []
|
| 84 |
+
# Iterate over each target classifier
|
| 85 |
+
for target in TARGET_CELLS:
|
| 86 |
+
clf = model_dict.get(target)
|
| 87 |
+
if clf is not None:
|
| 88 |
+
prob = clf.predict_proba(X)[0][1] # positive-class probability (0-1)
|
| 89 |
+
table.append([target, round(float(prob), 4)])
|
| 90 |
+
else:
|
| 91 |
+
table.append([target, None])
|
| 92 |
|
| 93 |
+
return table
|
| 94 |
|
| 95 |
# --- Gradio Interface ---
|
| 96 |
custom_css = """
|
|
|
|
| 118 |
# API endpoint for iOS app
|
| 119 |
gr.api(predict_peptide, api_name="predict_peptide")
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
if __name__ == "__main__":
|
| 122 |
demo.launch(show_error=True)
|