File size: 5,766 Bytes
aa85173
 
 
 
 
 
bffa89b
 
8e0600f
aa85173
aa07613
 
 
 
 
aa85173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa07613
 
aa85173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e0600f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
import joblib
from huggingface_hub import hf_hub_download
import pandas as pd
import numpy as np
from collections import Counter
import os
# Turn off Gradio server-side rendering and usage analytics.
# NOTE(review): these are set after `import gradio` above; confirm the
# installed Gradio version reads them lazily rather than at import time.
os.environ["GRADIO_SSR_MODE"] = "false"
os.environ["GRADIO_ANALYTICS_ENABLED"] = "false"

import time
# Startup banner. NOTE(review): the 2-second sleep adds fixed startup latency
# and its purpose is not visible in this file — confirm it is still needed.
print("STARTING APP...")
time.sleep(2)
print("APP READY")

# Model state starts empty and is filled in lazily by init_model() on the
# first prediction instead of at import time.
model_dict = None
feature_columns = None
model_package = None

# NOTE: the per-residue lookup tables used by extract_features_app()
# (aa_list, dipeptides, aa_mass, aa_charge, ...) are NOT defined at module
# level; they only exist once init_model() has run.

def init_model():
    """Download and unpack the model package on first use.

    Fetches ``xgb_multilabel_model_full.pkl`` from the Hugging Face Hub repo
    ``Ym420/Peptide-Function``, loads it with ``joblib`` and publishes its
    contents as module globals: ``model_package``, ``feature_columns``, the
    per-residue lookup tables read by ``extract_features_app`` (optional
    entries default to an empty list/dict) and, last, ``model_dict``.
    Once ``model_dict`` is set, further calls return immediately.

    Raises:
        KeyError: if the package has no ``'model'`` or ``'feature_columns'``
            entry. In that case no global is modified, so the next call
            retries the load instead of running with a half-initialised state.
    """
    global model_dict, feature_columns, model_package
    global aa_list, dipeptides, hydrophobicity_scale, eisenberg_scale
    global aa_mass, aa_charge, aa_boman, aa_flexibility
    global aa_polarizability, aa_aliphatic, aa_deltaG, aa_pucker

    if model_dict is not None:
        return

    repo_id = "Ym420/Peptide-Function"
    model_filename = "xgb_multilabel_model_full.pkl"

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only load packages from a repository you trust.
    package = joblib.load(model_path)

    # Read the required entries first so a malformed package fails before
    # any global is touched.
    models = package['model']
    columns = package['feature_columns']

    aa_list = package.get('aa_list', [])
    dipeptides = package.get('dipeptides', [])
    hydrophobicity_scale = package.get('hydrophobicity_scale', {})
    eisenberg_scale = package.get('eisenberg_scale', {})
    aa_mass = package.get('aa_mass', {})
    aa_charge = package.get('aa_charge', {})
    aa_boman = package.get('aa_boman', {})
    aa_flexibility = package.get('aa_flexibility', {})
    aa_polarizability = package.get('aa_polarizability', {})
    aa_aliphatic = package.get('aa_aliphatic', {})
    aa_deltaG = package.get('aa_deltaG', {})
    aa_pucker = package.get('aa_pucker', {})

    model_package = package
    feature_columns = columns
    # model_dict doubles as the "loaded" flag checked above, so assign it
    # last. If requests are handled concurrently, one arriving mid-load
    # cannot then see the flag set while the tables are still missing.
    model_dict = models

# --- Target cells ---
# Prediction targets, in the order predict_peptide() returns its rows.
# Each name is looked up as a key in model_dict; a name with no matching
# classifier produces a row with a None confidence.
TARGET_CELLS = ["Gram+", "Fungus", "Mammalian Cell", "Cancer", "Gram-"]

# --- Feature extraction ---
def extract_features_app(seq: str) -> pd.DataFrame:
    """Build the one-row feature frame the classifiers are fed for *seq*.

    Features, in column order:

    * one value per entry of ``dipeptides``: the count of that overlapping
      dipeptide in *seq* divided by ``max(len(seq) - 1, 1)``;
    * 13 physicochemical descriptors, each averaged over the overlapping
      dipeptides of *seq*. For table-based descriptors a dipeptide's value is
      the mean of its two residues' table values, and residues missing from a
      table count as 0. Sequences shorter than 2 residues get 13 zeros.

    Reads the module-level tables (``dipeptides``, ``aa_list``, ``aa_mass``,
    ...) and ``feature_columns`` that ``init_model()`` populates.
    NOTE(review): these formulas presumably mirror the training-time feature
    pipeline, so their numeric results must not change.

    Args:
        seq: peptide sequence; case-insensitive.

    Returns:
        A 1-row float32 DataFrame whose columns are ``feature_columns``.

    Raises:
        ValueError: raised by pandas when the number of computed features
            differs from ``len(feature_columns)``.
    """
    seq = seq.upper()

    # Overlapping dipeptides, e.g. "ACD" -> ["AC", "CD"]; empty if len(seq) < 2.
    dipeptides_seq = [seq[i:i+2] for i in range(len(seq)-1)]

    # Dipeptide composition; max() avoids dividing by zero for short input.
    count = Counter(dipeptides_seq)
    total = max(len(seq)-1, 1)
    dipep_features = [count.get(dp, 0) / total for dp in dipeptides]

    def g(aa, table):
        """Value of residue *aa* in *table*; unknown residues count as 0."""
        return table.get(aa, 0)

    def h(dp, table):
        """Mean of the two residue values of dipeptide *dp*."""
        return (g(dp[0], table) + g(dp[1], table)) / 2.0

    def mean_h(table):
        """Average of h() over every overlapping dipeptide of the sequence."""
        return np.mean([h(dp, table) for dp in dipeptides_seq])

    if len(seq) < 2:
        physchem_features = [0]*13
    else:
        # Crude pI proxy: baseline 7, +1 for K/R/H, -1 for D/E. It does not
        # depend on the dipeptide, so build it once instead of per dipeptide.
        pi_table = {aa: 7 + (int(aa in 'KRH') - int(aa in 'DE')) for aa in aa_list}

        mw = mean_h(aa_mass)
        charge = mean_h(aa_charge)
        hydro = mean_h(hydrophobicity_scale)
        # Fraction of aromatic (F/W/Y) residues, averaged per dipeptide.
        aromatic = np.mean([(dp[0] in 'FWY') + (dp[1] in 'FWY') for dp in dipeptides_seq]) / 2.0
        pI = mean_h(pi_table)
        # NOTE: despite the name, this is the mean fraction of charged
        # residues (D/E/K/R) per dipeptide.
        instability = np.mean([((dp[0] in 'DEKR') + (dp[1] in 'DEKR')) / 2.0 for dp in dipeptides_seq])
        # NOTE: root-mean-square of the per-dipeptide Eisenberg values; no
        # helical angle is involved.
        hydro_moment = np.sqrt(np.mean([h(dp, eisenberg_scale)**2 for dp in dipeptides_seq]))
        aliphatic = mean_h(aa_aliphatic)
        boman = mean_h(aa_boman)
        flexibility = mean_h(aa_flexibility)
        polarizability = mean_h(aa_polarizability)
        deltag = mean_h(aa_deltaG)
        pucker = mean_h(aa_pucker)

        # Order must match the trailing 13 entries of feature_columns.
        physchem_features = [mw, charge, hydro, aromatic, pI, instability,
                             hydro_moment, aliphatic, boman, flexibility, polarizability, deltag, pucker]

    features = dipep_features + physchem_features

    df = pd.DataFrame([features], columns=feature_columns)
    df = df.astype('float32')
    return df

# --- Prediction function ---
def predict_peptide(sequence: str):
    """Return per-target confidence rows for a peptide sequence.

    All whitespace in *sequence* is removed and the rest is upper-cased.
    Empty input (including ``None``) returns ``[]`` without loading the
    model. Otherwise the model is loaded on demand via ``init_model()``.

    Returns:
        A list of ``[target, confidence]`` rows, one per entry of
        ``TARGET_CELLS`` and in that order. ``confidence`` is column 1 of the
        classifier's ``predict_proba`` output (presumably the positive class)
        rounded to 4 decimals, or ``None`` when ``model_dict`` has no
        classifier for that target.
    """
    # None may arrive through the programmatic API endpoint; treat it as "".
    seq = "".join((sequence or "").split()).upper()
    if not seq:
        return []

    init_model()  # lazy: the model is only downloaded/loaded when needed

    X = extract_features_app(seq)

    table = []
    for target in TARGET_CELLS:
        clf = model_dict.get(target)
        if clf is None:
            table.append([target, None])
            continue
        prob = clf.predict_proba(X)[0][1]
        table.append([target, round(float(prob), 4)])

    return table

# --- Gradio Interface ---
# Hide Gradio's page footer.
custom_css = """
footer, .footer {display:none !important;}
"""

# Theme is left at the Gradio default (no theme argument passed).
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("## AMP Spectrum")
    
    # Free-text input; predict_peptide() strips whitespace and upper-cases it.
    seq_input = gr.Textbox(label="Enter Peptide Sequence")
    
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    
    # Read-only results table; its rows are the [target, confidence] pairs
    # returned by predict_peptide().
    table_output = gr.Dataframe(
        headers=["Target", "Confidence"],
        datatype=["str","number"],
        interactive=False
    )
    
    predict_btn.click(fn=predict_peptide, inputs=seq_input, outputs=table_output)
    # Reset both the input box and the results table.
    clear_btn.click(fn=lambda: ("", []), outputs=[seq_input, table_output])
    
    # Also expose predict_peptide as a programmatic endpoint named
    # "predict_peptide" that has no UI component attached.
    # NOTE(review): per the Gradio docs, gr.api describes the endpoint from
    # the function's type hints, so changing predict_peptide's signature
    # changes the public API schema.
    gr.api(predict_peptide, api_name="predict_peptide")

if __name__ == "__main__":
    # Listen on every network interface at port 7860, surface handler errors
    # in the browser, and keep server-side rendering disabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "root_path": "/",
        "show_error": True,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)