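# Header text copied from the Hugging Face Hub file page for app.py, kept as
# comments so the file parses: last commit by Ym420 ("Update app.py",
# ae65209, verified); page links: raw, history, blame; file size 3.29 kB.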
from collections import Counter

import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
import numpy as np
import pandas as pd
# --- Download model from HF Hub ---
repo_id = "Ym420/Peptide-Function"  # replace with your HF repo
model_filename = "xgb_multilabel_model_full.pkl"

# Download the serialized model package from the Hub and load it.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
# NOTE(review): joblib.load unpickles the file, which can run arbitrary code;
# only point repo_id at a repository you control.
model_package = joblib.load(model_path)

# Trained estimator plus the exact column order its input frame must have.
model, feature_columns = model_package['model'], model_package['feature_columns']

# Metadata: residue alphabet, dipeptide vocabulary and the per-residue lookup
# tables consumed by extract_features().
(
    aa_list,
    dipeptides,
    hydrophobicity_scale,
    aa_mass,
    aa_charge,
    aa_boman,
    aa_flexibility,
    aa_polarizability,
    aa_aliphatic,
) = (
    model_package[key]
    for key in (
        'aa_list',
        'dipeptides',
        'hydrophobicity_scale',
        'aa_mass',
        'aa_charge',
        'aa_boman',
        'aa_flexibility',
        'aa_polarizability',
        'aa_aliphatic',
    )
)
# --- Feature extraction ---
def extract_features(sequence: str) -> pd.DataFrame:
    """Build the one-row feature frame the classifier expects for *sequence*.

    The sequence is only upper-cased here (no whitespace stripping or
    validation). Features computed:

    * ``AA_<aa>``: fraction of positions equal to each residue in ``aa_list``.
    * ``DP_<dp>``: overlapping occurrences of each dipeptide in ``dipeptides``
      divided by the number of adjacent pairs (``len(seq) - 1``).
    * ``hydrophobicity``, ``mass``, ``charge``, ``boman``, ``flexibility``,
      ``polarizability``, ``aliphatic``: mean per-residue value from the
      matching lookup table. Residues missing from a table contribute 0 but
      still count towards the length.

    A feature is 0 when the sequence is too short to define it (empty for the
    per-residue features, fewer than two residues for dipeptides).

    Args:
        sequence: Peptide sequence, any case.

    Returns:
        A single-row DataFrame whose columns are exactly ``feature_columns``,
        in that order. Expected columns that were not computed are filled
        with 0, and computed features not in ``feature_columns`` are dropped.
    """
    seq = sequence.upper()
    n = len(seq)
    features = {}

    # Amino acid composition
    for aa in aa_list:
        features[f"AA_{aa}"] = seq.count(aa) / n if n > 0 else 0

    # Dipeptide composition. Count every adjacent (overlapping) pair in one
    # pass instead of rescanning the sequence for each dipeptide.
    # Counter returns 0 for pairs that never occur.
    pair_counts = Counter(seq[i:i + 2] for i in range(n - 1))
    for dp in dipeptides:
        features[f"DP_{dp}"] = pair_counts[dp] / (n - 1) if n > 1 else 0

    # Mean per-residue physicochemical properties (insertion order matches
    # the original: hydrophobicity first, then the remaining tables).
    property_tables = {
        'hydrophobicity': hydrophobicity_scale,
        'mass': aa_mass,
        'charge': aa_charge,
        'boman': aa_boman,
        'flexibility': aa_flexibility,
        'polarizability': aa_polarizability,
        'aliphatic': aa_aliphatic,
    }
    for prop, table in property_tables.items():
        features[prop] = sum(table.get(aa, 0) for aa in seq) / n if n > 0 else 0

    # Align to the training-time column order; missing columns become 0.
    return pd.DataFrame([features]).reindex(columns=feature_columns, fill_value=0)
# --- Prediction function ---
def predict_peptide(sequence: str):
    """Return per-target-cell probabilities for a peptide sequence.

    All whitespace is removed and the sequence is upper-cased before feature
    extraction. This function backs both the "Predict" button and the
    ``predict_peptide`` API endpoint.

    Args:
        sequence: Raw user input.

    Returns:
        ``[]`` when nothing remains after removing whitespace. Otherwise a
        list with one ``[target, probability]`` row per entry of
        ``model.classes_``. The probability is column 1 of that target's
        ``predict_proba`` output for the single input row.
    """
    cleaned = "".join(sequence.split()).upper()
    if cleaned == "":
        return []

    features_df = extract_features(cleaned)
    # List with one probability array per target, indexed [row][class].
    # Column 1 is presumably the positive class.
    # NOTE(review): assumes model.classes_ holds one target-cell label per
    # output, aligned with predict_proba's list. Confirm against the saved
    # model: for sklearn's MultiOutputClassifier, classes_ is instead a list
    # of per-output class-label arrays.
    per_target = model.predict_proba(features_df)
    return [
        [label, float(per_target[idx][0][1])]
        for idx, label in enumerate(model.classes_)
    ]
# --- Gradio Interface ---
# Hides any element matching `footer` or `.footer`.
custom_css = """
footer, .footer {display:none !important;}
"""

# Components are declared top to bottom: title, input box, button row,
# results table. Keep statement order when editing.
with gr.Blocks(css=custom_css, theme="default") as demo:
    gr.Markdown("## Peptide Antimicrobial Predictor\nEnter a peptide sequence to predict efficacy/toxicity.")
    seq_input = gr.Textbox(label="Enter Peptide Sequence")
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    # NOTE(review): indentation was lost in this copy of the file. The table
    # is assumed to sit below the button row, not inside it; confirm.
    # Read-only results table, filled with predict_peptide's
    # [target, probability] rows.
    table_output = gr.Dataframe(
        headers=["Target Cell", "Probability of Efficacy/Toxicity"],
        datatype=["str","number"],
        interactive=False
    )
    # Predict: run the model on the textbox contents and show the rows.
    predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
    # Clear: reset the textbox to "" and the table to no rows.
    clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
    # API endpoint for iOS app
    # NOTE(review): predict_btn.click above is registered first and has no
    # api_name. Gradio documents that an event's endpoint defaults to the
    # function name, so it may already claim "predict_peptide", and this
    # route could be renamed with a suffix. Check the app's "Use via API" page
    # to confirm which endpoint (and response shape) the iOS app actually hits.
    gr.api(predict_peptide, api_name="predict_peptide")

if __name__ == "__main__":
    demo.launch()