File size: 6,156 Bytes
d2a0224 86445c5 d2a0224 b9cef5d d2a0224 6780752 d2a0224 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
import pandas as pd
import numpy as np
from collections import Counter
# --- Download model from HF Hub ---
# Hub repository and file name of the serialized model package.
repo_id = "Ym420/Peptide-Function"
model_filename = "xgb_multilabel_model_full.pkl"
# Download the saved model package and load it with joblib.
# NOTE(review): joblib.load unpickles the file, and unpickling can execute
# arbitrary code -- only point repo_id at a repository you trust.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
model_package = joblib.load(model_path)
# --- Unwrap model dict ---
# model_dict is looked up by target-cell name in predict_peptide; it is
# expected to hold one classifier per target cell,
# e.g., {'Gram+': XGBClassifier(...), 'Fungus': XGBClassifier(...), ...}
model_dict = model_package['model']
# feature_columns supplies the column names (and therefore the order) of the
# DataFrame built by extract_features_app. If you add new features, ensure
# they are included both here and in extract_features_app.
feature_columns = model_package['feature_columns']
# --- Metadata (all restored) ---
# Lookup tables read by extract_features_app. Each one falls back to an empty
# list/dict when the key is absent from the package, in which case the
# features derived from it are empty or 0 rather than raising here.
# If you add new features that depend on new tables or scales, add them here.
aa_list = model_package.get('aa_list', [])
dipeptides = model_package.get('dipeptides', [])
hydrophobicity_scale = model_package.get('hydrophobicity_scale', {})
eisenberg_scale = model_package.get('eisenberg_scale', {})
aa_mass = model_package.get('aa_mass', {})
aa_charge = model_package.get('aa_charge', {})
aa_boman = model_package.get('aa_boman', {})
aa_flexibility = model_package.get('aa_flexibility', {})
aa_polarizability = model_package.get('aa_polarizability', {})
aa_aliphatic = model_package.get('aa_aliphatic', {})
aa_deltaG = model_package.get('aa_deltaG', {})
aa_pucker = model_package.get('aa_pucker', {})
# --- Target cells ---
# Labels reported by predict_peptide, in display order. A label with no
# matching key in model_dict is reported with a confidence of None.
# If you add new labels in the model, you can update this list manually
# Or make it dynamic: TARGET_CELLS = list(model_dict.keys())
TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]
# --- Feature extraction ---
# When adding new features, compute them here and make sure their names match feature_columns
def extract_features_app(seq: str) -> pd.DataFrame:
    """Build the single-row feature frame used as classifier input.

    The features are the relative frequency of every entry of the
    module-level ``dipeptides`` list among the overlapping residue pairs of
    ``seq``, followed by 13 physicochemical descriptors averaged over those
    same pairs. Residues missing from a lookup table contribute 0.

    Args:
        seq: Peptide sequence; it is upper-cased before processing.

    Returns:
        A one-row ``float32`` DataFrame whose columns are ``feature_columns``.
        The number of computed features must equal ``len(feature_columns)``,
        otherwise the DataFrame construction fails.
    """
    seq = seq.upper()

    # Overlapping residue pairs, computed once and shared by both feature groups.
    dipeptides_seq = [seq[i:i + 2] for i in range(len(seq) - 1)]

    # --- 1. Dipeptide composition ---
    count = Counter(dipeptides_seq)
    # max(..., 1) avoids dividing by zero for sequences shorter than 2 residues.
    total = max(len(dipeptides_seq), 1)
    dipep_features = [count.get(dp, 0) / total for dp in dipeptides]

    # --- 2. Physicochemical features ---
    def g(aa, table):
        # Per-residue lookup; unknown residues count as 0.
        return table.get(aa, 0)

    def h(dp, table):
        # Average of the two residues' table values for one dipeptide.
        return (g(dp[0], table) + g(dp[1], table)) / 2.0

    def pair_mean(table):
        # Mean of the per-dipeptide averages over the whole sequence.
        return np.mean([h(dp, table) for dp in dipeptides_seq])

    if len(seq) < 2:
        # No dipeptides to average over: fill the physchem features with
        # zeros. 13 must equal the number of descriptors listed below.
        physchem_features = [0] * 13
    else:
        # Crude per-residue pI proxy: 7, +1 for K/R/H, -1 for D/E. It depends
        # only on aa_list, so it is built once instead of once per dipeptide.
        pi_table = {aa: 7 + (int(aa in 'KRH') - int(aa in 'DE')) for aa in aa_list}

        mw = pair_mean(aa_mass)
        charge = pair_mean(aa_charge)
        hydro = pair_mean(hydrophobicity_scale)
        # Fraction of aromatic residues (F, W, Y) per dipeptide position.
        aromatic = np.mean([(dp[0] in 'FWY') + (dp[1] in 'FWY') for dp in dipeptides_seq]) / 2.0
        pI = pair_mean(pi_table)
        # Fraction of charged residues (D, E, K, R) per dipeptide position.
        instability = np.mean([((dp[0] in 'DEKR') + (dp[1] in 'DEKR')) / 2.0 for dp in dipeptides_seq])
        # Root-mean-square of the per-dipeptide Eisenberg averages.
        hydro_moment = np.sqrt(np.mean([(h(dp, eisenberg_scale)) ** 2 for dp in dipeptides_seq]))
        aliphatic = pair_mean(aa_aliphatic)
        boman = pair_mean(aa_boman)
        flexibility = pair_mean(aa_flexibility)
        polarizability = pair_mean(aa_polarizability)
        deltag = pair_mean(aa_deltaG)
        pucker = pair_mean(aa_pucker)
        physchem_features = [mw, charge, hydro, aromatic, pI, instability,
                             hydro_moment, aliphatic, boman, flexibility,
                             polarizability, deltag, pucker]

    # --- Combine features ---
    features = dipep_features + physchem_features

    # --- Align with feature_columns ---
    # Always ensure the order and names match the training data
    df = pd.DataFrame([features], columns=feature_columns)
    return df.astype('float32')
# --- Prediction function ---
# Returns probability for each target cell
def predict_peptide(sequence: str):
    """Return the positive-class probability of a peptide for each target cell.

    Args:
        sequence: Peptide sequence. All whitespace is removed and the rest is
            upper-cased before feature extraction. ``None`` (for example a
            null API payload) is treated like an empty string.

    Returns:
        A list of ``[target, confidence]`` rows, one per entry of
        ``TARGET_CELLS`` and in that order. ``confidence`` is the probability
        rounded to 4 decimals, or ``None`` when ``model_dict`` has no
        classifier for that target. An empty list is returned when the input
        contains no residues.
    """
    # Guard against a missing value before calling str methods on it.
    if sequence is None:
        return []

    seq = "".join(sequence.split()).upper()
    if not seq:
        return []

    X = extract_features_app(seq)
    table = []
    for target in TARGET_CELLS:
        clf = model_dict.get(target)
        if clf is None:
            # Keep a row for the target so the table shape stays stable.
            table.append([target, None])
            continue
        # Positive-class probability between 0-1
        prob = clf.predict_proba(X)[0][1]
        table.append([target, round(float(prob), 4)])
    return table
# --- Gradio Interface ---
# CSS override that hides the page footer.
custom_css = """
footer, .footer {display:none !important;}
"""

with gr.Blocks(theme="default", css=custom_css) as demo:
    gr.Markdown("## AMP Spectrum")

    # Free-text input for the peptide sequence.
    seq_input = gr.Textbox(label="Enter Peptide Sequence")

    # Action buttons share one row.
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")

    # Read-only results table: one row per target cell.
    table_output = gr.Dataframe(
        interactive=False,
        headers=["Target", "Confidence"],
        datatype=["str", "number"],
    )

    # "Predict" fills the table; "Clear" resets both the input and the table.
    predict_btn.click(
        fn=predict_peptide,
        inputs=seq_input,
        outputs=table_output,
    )
    clear_btn.click(
        fn=lambda: ("", []),
        outputs=[seq_input, table_output],
    )

    # API endpoint for iOS app
    gr.api(predict_peptide, api_name="predict_peptide")

# Start the server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)
# --- Notes for manual update ---
# 1. When adding new features in your Colab model:
# - Add the new feature computation in extract_features_app
# - Update feature_columns in the model package if needed
# - Add any new metadata tables to the model_package if used
# 2. If you add new target labels:
# - Add them to TARGET_CELLS manually
# - Or switch to dynamic TARGET_CELLS = list(model_dict.keys()) for auto-detection
# 3. Always ensure the DataFrame returned from extract_features_app matches feature_columns in order and names
|