Ym420's picture
Update app.py
86445c5 verified
import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
import pandas as pd
import numpy as np
from collections import Counter
# --- Model artifact on the Hugging Face Hub ---
repo_id = "Ym420/Peptide-Function"
model_filename = "xgb_multilabel_model_full.pkl"

# hf_hub_download returns a local file path; joblib then unpickles the package.
# NOTE(review): joblib.load unpickles, which can execute arbitrary code. Only load
# artifacts from a repository you control. Also consider pinning `revision=` so a
# new upload cannot silently change feature_columns under a running app.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
model_package = joblib.load(model_path)

# --- Required entries (a missing key raises KeyError at startup) ---
# One classifier per target label, e.g. {'Gram+': XGBClassifier(...), 'Fungus': ...}.
model_dict = model_package["model"]
# Training-time column names and order. extract_features_app must produce exactly
# these columns in this order; add new features there and here together.
feature_columns = model_package["feature_columns"]

# --- Optional lookup tables (each falls back to a fresh empty container) ---
# Amino-acid alphabet and dipeptide vocabulary.
aa_list, dipeptides = (model_package.get(name, []) for name in ("aa_list", "dipeptides"))
# Per-residue property tables, looked up by single residue letter in extract_features_app.
# New tables that new features depend on belong in this tuple as well.
(
    hydrophobicity_scale,
    eisenberg_scale,
    aa_mass,
    aa_charge,
    aa_boman,
    aa_flexibility,
    aa_polarizability,
    aa_aliphatic,
    aa_deltaG,
    aa_pucker,
) = (
    model_package.get(name, {})
    for name in (
        "hydrophobicity_scale",
        "eisenberg_scale",
        "aa_mass",
        "aa_charge",
        "aa_boman",
        "aa_flexibility",
        "aa_polarizability",
        "aa_aliphatic",
        "aa_deltaG",
        "aa_pucker",
    )
)

# --- Target cells (row order of the results table) ---
# Hard-coded to keep the display order stable. To follow the model automatically,
# use TARGET_CELLS = list(model_dict.keys()). A target with no classifier in
# model_dict is reported with a None score by predict_peptide.
TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
# --- Feature extraction ---
# When adding new features, compute them here and make sure their names match feature_columns
def extract_features_app(seq: str) -> pd.DataFrame:
    """Build the single-row feature matrix the classifiers expect for *seq*.

    The row has two parts, in this order:
    1. The dipeptide composition: one relative frequency per entry of the
       module-level ``dipeptides`` vocabulary.
    2. 13 physicochemical descriptors, each computed over the sequence's
       overlapping dipeptides.

    Residues missing from a lookup table contribute 0.

    Args:
        seq: Peptide sequence. It is upper-cased here, but whitespace is not
            stripped (``predict_peptide`` does that before calling).

    Returns:
        A one-row ``float32`` DataFrame whose columns are ``feature_columns``.

    Raises:
        ValueError: If the number of computed features differs from
            ``len(feature_columns)``, i.e. this function and the model package
            disagree about the feature set.
    """
    seq = seq.upper()
    # Overlapping dipeptides, e.g. "ACD" -> ["AC", "CD"]; empty when len(seq) < 2.
    dipeptides_seq = [seq[i:i + 2] for i in range(len(seq) - 1)]

    # --- 1. Dipeptide composition ---
    # Relative frequency of each vocabulary dipeptide. max(..., 1) avoids
    # dividing by zero for sequences shorter than two residues.
    count = Counter(dipeptides_seq)
    total = max(len(dipeptides_seq), 1)
    dipep_features = [count.get(dp, 0) / total for dp in dipeptides]

    # --- 2. Physicochemical features ---
    def g(aa, table):
        # Per-residue lookup; residues missing from the table contribute 0.
        return table.get(aa, 0)

    def h(dp, table):
        # A dipeptide's value is the mean of its two residues' values.
        return (g(dp[0], table) + g(dp[1], table)) / 2.0

    def avg(table):
        # Mean dipeptide value of `table` across the whole sequence.
        return np.mean([h(dp, table) for dp in dipeptides_seq])

    if len(seq) < 2:
        # No dipeptides to average over: fill all 13 physicochemical features with zeros.
        # Keep this count in sync with the list built in the else-branch.
        physchem_features = [0] * 13
    else:
        # "pI" table: 8 for K/R/H, 6 for D/E, 7 for every other residue in aa_list.
        # Residues outside aa_list get 0 via g(). Built once here rather than
        # once per dipeptide.
        pi_table = {aa: 7 + (int(aa in 'KRH') - int(aa in 'DE')) for aa in aa_list}

        mw = avg(aa_mass)
        charge = avg(aa_charge)
        hydro = avg(hydrophobicity_scale)
        aromatic = np.mean([(dp[0] in 'FWY') + (dp[1] in 'FWY') for dp in dipeptides_seq]) / 2.0
        pI = avg(pi_table)
        instability = np.mean([((dp[0] in 'DEKR') + (dp[1] in 'DEKR')) / 2.0 for dp in dipeptides_seq])
        # Root-mean-square of the per-dipeptide Eisenberg values.
        # NOTE(review): this is an RMS, not a vector hydrophobic moment; it must stay
        # identical to whatever the training code computed.
        hydro_moment = np.sqrt(np.mean([(h(dp, eisenberg_scale)) ** 2 for dp in dipeptides_seq]))
        aliphatic = avg(aa_aliphatic)
        boman = avg(aa_boman)
        flexibility = avg(aa_flexibility)
        polarizability = avg(aa_polarizability)
        deltag = avg(aa_deltaG)
        pucker = avg(aa_pucker)

        # Order matters: it must match the trailing entries of feature_columns.
        physchem_features = [mw, charge, hydro, aromatic, pI, instability,
                             hydro_moment, aliphatic, boman, flexibility,
                             polarizability, deltag, pucker]

    # --- Combine features ---
    features = dipep_features + physchem_features

    # --- Align with feature_columns ---
    # Columns are assigned positionally, so the order above must match training.
    # Fail with an explicit message if the counts have drifted apart.
    if len(features) != len(feature_columns):
        raise ValueError(
            f"extract_features_app produced {len(features)} features but the model "
            f"expects {len(feature_columns)} (feature_columns); the feature code and "
            "the model package are out of sync."
        )
    df = pd.DataFrame([features], columns=feature_columns)
    return df.astype('float32')
# --- Prediction function ---
# Returns probability for each target cell
def predict_peptide(sequence: str):
    """Score *sequence* against every target in ``TARGET_CELLS``.

    Args:
        sequence: Raw peptide text. All whitespace is removed and the result is
            upper-cased. ``None`` is treated like an empty string.

    Returns:
        A list of ``[target, probability]`` rows in ``TARGET_CELLS`` order.
        ``probability`` is the classifier's positive-class probability rounded
        to 4 decimals, or ``None`` when ``model_dict`` has no classifier for that
        target. Returns ``[]`` for empty or whitespace-only input.
    """
    # Guard against None (e.g. a missing value sent to the API endpoint) so it
    # yields an empty result instead of an AttributeError.
    seq = "".join((sequence or "").split()).upper()
    if not seq:
        return []

    X = extract_features_app(seq)
    table = []
    for target in TARGET_CELLS:
        clf = model_dict.get(target)
        if clf is None:
            # Target listed for display but absent from the loaded model package.
            table.append([target, None])
            continue
        # Row 0 is the single input sample; column 1 is the positive-class probability (0-1).
        prob = clf.predict_proba(X)[0][1]
        table.append([target, round(float(prob), 4)])
    return table
# --- Gradio Interface ---
# Extra CSS injected into the page: hides Gradio's footer.
custom_css = """
footer, .footer {display:none !important;}
"""
# `demo` is the app object; it is launched at the bottom when this file runs as a script.
with gr.Blocks(css=custom_css, theme="default") as demo:
    gr.Markdown("## AMP Spectrum")
    seq_input = gr.Textbox(label="Enter Peptide Sequence")
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    # Read-only results table; rows are predict_peptide's [target, probability] pairs.
    table_output = gr.Dataframe(
        headers=["Target", "Confidence"],
        datatype=["str","number"],
        interactive=False
    )
    predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
    # Reset: empty string for the textbox, empty list for the table.
    clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
    # API endpoint for iOS app
    # NOTE(review): Gradio event listeners default their api_name to the handler's
    # function name, so the Predict click above may already occupy "predict_peptide".
    # Check the app's "Use via API" page to confirm which endpoint the iOS client
    # actually hits. The click endpoint returns Dataframe-formatted output; this one
    # returns the raw list.
    gr.api(predict_peptide, api_name="predict_peptide")
if __name__ == "__main__":
    # show_error=True surfaces Python exceptions in the browser UI.
    demo.launch(show_error=True)
# --- Notes for manual update ---
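# 1. When you add new features to the model in Colab:
#    - Compute the new feature in extract_features_app.
#    - Update feature_columns in the model package if needed.
#    - Add any new lookup tables the feature uses to the model package, and load them
#      alongside aa_mass and the other tables near the top of this file.
#    - For a new physicochemical feature, also update the hard-coded count of 13 zeros
#      that extract_features_app uses for sequences shorter than two residues.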
# 2. If you add new target labels:
# - Add them to TARGET_CELLS manually
# - Or switch to dynamic TARGET_CELLS = list(model_dict.keys()) for auto-detection
# 3. Always ensure the DataFrame returned from extract_features_app matches feature_columns in order and names