praveen2302 committed on
Commit
b682b6c
·
verified ·
1 Parent(s): a552e6f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from Bio.Align import PairwiseAligner
8
+ from Bio.SeqUtils.ProtParam import ProteinAnalysis
9
+ from sklearn.pipeline import Pipeline
10
+ from sklearn.preprocessing import StandardScaler
11
+ from sklearn.ensemble import RandomForestClassifier
12
+ import joblib
13
+ import streamlit as st
14
+
15
+ # Optional heavy deps
16
# Probe for the optional embedding stack. HAS_EMB records whether torch and
# transformers could be imported; callers check it before building a
# ProtBertProvider.
try:
    import torch
    import transformers
    from transformers import AutoTokenizer, AutoModel
    HAS_EMB = True
except Exception:
    # Deliberately broad: a missing package raises ImportError, but a broken
    # native library can raise other errors at import time. Unlike a bare
    # ``except:``, this does not swallow KeyboardInterrupt or SystemExit.
    HAS_EMB = False

# Probe for XGBoost. HAS_XGB gates the "xgb" model type in train_model.
try:
    import xgboost as xgb
    HAS_XGB = True
except Exception:
    # Same rationale as above: stay best-effort without trapping
    # interpreter-exit signals.
    HAS_XGB = False
29
+
30
+ # -------------------------
31
+ # GLOBALS
32
+ # -------------------------
33
# Extein +1 residues that make the `plus1_good` feature equal to 1.
PREFERRED_PLUS1 = {'C', 'S', 'T'}

# Shared aligner instance used by seq_identity(); configured for global
# alignment.
aligner = PairwiseAligner()
aligner.mode = "global"
36
+
37
+ # -------------------------
38
+ # Basic functions
39
+ # -------------------------
40
def seq_identity(a, b):
    """Return a length-normalised similarity score for two sequences.

    The module-level ``aligner`` scores the pair and the score is divided by
    the length of the longer sequence. If scoring raises, the function falls
    back to counting position-wise equal characters over the overlapping
    prefix, again divided by the longer length.

    Args:
        a: First sequence (string).
        b: Second sequence (string).

    Returns:
        0.0 when either input is empty or falsy, otherwise the normalised
        score as a float.
    """
    if not a or not b:
        return 0.0

    longest = max(len(a), len(b))
    try:
        return aligner.score(a, b) / longest
    except Exception:
        # Best-effort fallback when the aligner cannot score the pair.
        # ``except Exception`` (not a bare ``except``) keeps
        # KeyboardInterrupt and SystemExit propagating.
        matches = sum(x == y for x, y in zip(a, b))
        return matches / longest
49
+
50
+
51
def aa_comp_props(seq):
    """Build a dict of composition and physicochemical descriptors.

    The result has one ``aa_pct_<residue>`` entry for each of the 20 letters
    in "ACDEFGHIKLMNPQRSTVWY", followed by ``aromaticity``,
    ``instability_index`` and ``isoelectric_point``. For a non-empty
    sequence the values come from ``ProteinAnalysis``; for an empty
    sequence every value is 0.0.

    Args:
        seq: Protein sequence string; may be empty.

    Returns:
        dict mapping feature name to float.
    """
    residues = "ACDEFGHIKLMNPQRSTVWY"

    if seq:
        analysis = ProteinAnalysis(seq)
        per_residue = analysis.get_amino_acids_percent()
        props = {}
        for residue in residues:
            props[f'aa_pct_{residue}'] = per_residue.get(residue, 0.0)
        props['aromaticity'] = analysis.aromaticity()
        props['instability_index'] = analysis.instability_index()
        props['isoelectric_point'] = analysis.isoelectric_point()
        return props

    # Empty sequence: same keys, all zeros, without touching ProteinAnalysis.
    zeros = dict.fromkeys((f'aa_pct_{residue}' for residue in residues), 0.0)
    zeros["aromaticity"] = 0.0
    zeros["instability_index"] = 0.0
    zeros["isoelectric_point"] = 0.0
    return zeros
64
+
65
+
66
+ # -------------------------
67
+ # Embedding Provider
68
+ # -------------------------
69
class ProtBertProvider:
    """Produce mean-pooled sequence embeddings from a Hugging Face model.

    The tokenizer and model are loaded once at construction time and the
    model is switched to evaluation mode.
    """

    def __init__(self, model_name="Rostlab/prot_bert"):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, do_lower_case=False
        )
        loaded = AutoModel.from_pretrained(model_name)
        self.model = loaded
        loaded.eval()

    def embed(self, seq):
        """Return a NumPy embedding for ``seq``.

        An empty sequence yields a zero vector of length 1024. Otherwise the
        residues are space-separated, tokenized, run through the model with
        gradients disabled, and the last hidden state is averaged over
        dimension 1.
        """
        if seq:
            spaced = " ".join(seq)
            encoded = self.tokenizer(spaced, return_tensors="pt")
            with torch.no_grad():
                hidden = self.model(**encoded).last_hidden_state
                pooled = hidden.mean(dim=1)
            return pooled.squeeze().numpy()

        return np.zeros(1024)
85
+
86
+
87
+ # -------------------------
88
+ # Feature Extraction
89
+ # -------------------------
90
def _is_missing(value):
    """Return True for None or a scalar missing value (NaN, NaT, pd.NA)."""
    if value is None:
        return True
    try:
        return bool(pd.isna(value))
    except (TypeError, ValueError):
        # pd.isna on a list/array is element-wise and has no single truth
        # value; such inputs are not treated as missing.
        return False


def _text_or_empty(value):
    """Return ``str(value)``, or "" when the value is missing."""
    return "" if _is_missing(value) else str(value)


def _number_or_zero(value):
    """Return ``value`` unchanged, or 0 when the value is missing."""
    return 0 if _is_missing(value) else value


def extract_row(row, use_emb=False, emb=None):
    """Turn one input record into a flat feature dict.

    Missing cells (None/NaN, as produced by pandas for empty CSV fields) are
    treated as an empty sequence or 0. Previously they became the literal
    sequence "nan", and a missing ``cognate`` raised ValueError.

    Args:
        row: Mapping-like record exposing ``.get`` (e.g. a pandas Series).
            Keys read: ``n_intein_seq``, ``c_intein_seq``, ``extein_plus1``,
            ``cognate``, ``docking_score``, ``pLDDT_N``, ``pLDDT_C`` and
            ``struct_confidence`` (used when a pLDDT key is absent).
        use_emb: When true and ``emb`` is truthy, embedding features are added.
        emb: Object with an ``embed(seq)`` method, or None.

    Returns:
        dict of feature name to numeric value. Scalar features come first,
        then ``N_``/``C_`` prefixed outputs of ``aa_comp_props``, then up to
        256 ``N_emb_i`` and 256 ``C_emb_i`` values if embeddings are enabled.
    """
    nseq = _text_or_empty(row.get('n_intein_seq', ""))
    cseq = _text_or_empty(row.get('c_intein_seq', ""))
    plus1 = _text_or_empty(row.get('extein_plus1', "")).upper()

    # Evaluated eagerly, as in the original nested ``row.get`` default.
    struct_conf = row.get("struct_confidence", 0)

    feats = {
        "pair_identity": seq_identity(nseq, cseq),
        "len_N": len(nseq),
        "len_C": len(cseq),
        "plus1_good": 1 if plus1 in PREFERRED_PLUS1 else 0,
        # Offset of the first character from 'A'; -1 when no residue is given.
        "plus1_code": ord(plus1[0]) - 65 if plus1 else -1,
        "cognate": int(_number_or_zero(row.get('cognate', 0))),
        "docking_score": float(_number_or_zero(row.get('docking_score', 0))),
        "pLDDT_N": float(_number_or_zero(row.get('pLDDT_N', struct_conf))),
        "pLDDT_C": float(_number_or_zero(row.get('pLDDT_C', struct_conf))),
    }

    # Amino-acid composition and physicochemical properties per fragment.
    feats.update({f"N_{k}": v for k, v in aa_comp_props(nseq).items()})
    feats.update({f"C_{k}": v for k, v in aa_comp_props(cseq).items()})

    # Optional embeddings: only the first 256 components of each are kept.
    if use_emb and emb:
        n_emb = emb.embed(nseq)
        c_emb = emb.embed(cseq)
        for i, x in enumerate(n_emb[:256]):
            feats[f"N_emb_{i}"] = float(x)
        for i, x in enumerate(c_emb[:256]):
            feats[f"C_emb_{i}"] = float(x)

    return feats
125
+
126
+
127
def build_matrix(df, use_emb=False, emb=None):
    """Assemble the feature matrix for every row of ``df``.

    Each row is passed to ``extract_row`` with the same ``use_emb`` and
    ``emb`` arguments. The resulting dicts are stacked into a DataFrame and
    any missing cells are filled with 0.0.
    """
    records = [extract_row(record, use_emb, emb) for _, record in df.iterrows()]
    matrix = pd.DataFrame(records)
    return matrix.fillna(0.0)
132
+
133
+
134
+ # -------------------------
135
+ # Train Model
136
+ # -------------------------
137
def train_model(df, use_emb=False, model_type="rf"):
    """Fit a classifier on features derived from ``df``.

    Args:
        df: Training DataFrame; must contain a ``label`` column castable to int.
        use_emb: Request embedding features. They are only used when the
            embedding stack imported successfully (``HAS_EMB``).
        model_type: "xgb" selects XGBoost; any other value trains the
            random-forest pipeline.

    Returns:
        For the random forest: ``{"pipeline": ..., "cols": [...]}``.
        For XGBoost: ``{"model": ..., "scaler": ..., "cols": [...]}``.
        ``None`` if XGBoost was requested but is unavailable (an error is
        shown through Streamlit in that case).
    """
    provider = None
    if use_emb and HAS_EMB:
        provider = ProtBertProvider()

    features = build_matrix(df, use_emb, provider)
    labels = df['label'].astype(int)
    feature_names = list(features.columns)

    if model_type != "xgb":
        # Default path: scaling + random forest bundled in a single pipeline.
        steps = [
            ("scale", StandardScaler()),
            ("clf", RandomForestClassifier(n_estimators=300, class_weight="balanced")),
        ]
        forest = Pipeline(steps)
        forest.fit(features, labels)
        return {"pipeline": forest, "cols": feature_names}

    if not HAS_XGB:
        st.error("XGBoost unavailable.")
        return None

    # XGBoost path: the scaler is kept separately so prediction can reuse it.
    scaler = StandardScaler()
    scaled = scaler.fit_transform(features)

    booster = xgb.XGBClassifier(objective='multi:softprob', num_class=3)
    booster.fit(scaled, labels)

    return {"model": booster, "scaler": scaler, "cols": feature_names}
164
+
165
+
166
+ # -------------------------
167
+ # Predict
168
+ # -------------------------
169
def run_predict(df, saved, use_emb=False):
    """Score ``df`` with a model bundle produced by ``train_model``.

    The feature matrix is restricted to, and ordered by, the training-time
    column list stored under ``saved["cols"]`` for both bundle kinds.
    Previously only the XGBoost branch did this, so the pipeline branch
    could be fed columns that did not match training.

    Args:
        df: Input DataFrame. It is modified in place: ``pred_label`` and one
            ``prob_<i>`` column per probability column are added.
        saved: Either ``{"pipeline", "cols"}`` or ``{"model", "scaler", "cols"}``.
        use_emb: Request embedding features (used only if ``HAS_EMB``).

    Returns:
        The same ``df`` object with the prediction columns added.

    Raises:
        KeyError: If a training column is absent from the computed features
            (for example when ``use_emb`` differs from training).
    """
    emb = ProtBertProvider() if (use_emb and HAS_EMB) else None
    X = build_matrix(df, use_emb, emb)

    # Align features with what the model saw during training.
    cols = saved.get("cols")
    if cols is not None:
        X = X[cols]

    if "pipeline" in saved:
        pipe = saved["pipeline"]
        preds = pipe.predict(X)
        probs = pipe.predict_proba(X)
    else:
        model = saved["model"]
        scaler = saved["scaler"]
        Xs = scaler.transform(X)
        preds = model.predict(Xs)
        probs = model.predict_proba(Xs)

    df["pred_label"] = preds
    # Columns are named by position in the probability matrix, not by label.
    for i in range(probs.shape[1]):
        df[f"prob_{i}"] = probs[:, i]

    return df
190
+
191
+
192
+ # -------------------------
193
+ # Streamlit UI for Hugging Face
194
+ # -------------------------
195
# Page header and input format description.
st.title("🔬 Intein Splice Predictor — Hugging Face Space")
st.write("Upload CSV containing columns:")
st.write("`n_intein_seq`, `c_intein_seq`, `extein_plus1`, `cognate`, `docking_score`, `struct_confidence`")

mode = st.radio("Choose mode:", ["Train Model", "Predict With Model"])

# ------------------------------------
# MODE 1: TRAIN
# ------------------------------------
if mode == "Train Model":
    train_file = st.file_uploader("Upload training CSV (must contain column: label)", type=["csv"])
    use_emb = st.checkbox("Use ProtBert embeddings (slow, needs GPU)", value=False)
    model_type = st.selectbox("Model Type", ["rf", "xgb"])

    if st.button("Train"):
        if train_file:
            df = pd.read_csv(train_file)
            saved = train_model(df, use_emb, model_type)
            # train_model returns None (after showing its own error) when the
            # requested backend is unavailable. Do not persist that or claim
            # success.
            if saved is not None:
                joblib.dump(saved, "intein_model.joblib")
                st.success("Model trained & saved as intein_model.joblib")
        else:
            st.error("Upload a CSV first.")

# ------------------------------------
# MODE 2: PREDICT
# ------------------------------------
else:
    pred_file = st.file_uploader("Upload CSV for prediction", type=["csv"])
    model_file = st.file_uploader("Upload your intein_model.joblib", type=["joblib"])
    use_emb = st.checkbox("Use embeddings (same setting used during training)")

    if st.button("Predict"):
        if pred_file and model_file:
            df = pd.read_csv(pred_file)
            # SECURITY NOTE(review): joblib.load unpickles the uploaded file,
            # which can execute arbitrary code. Only load model files from
            # trusted sources; consider restricting this on a public Space.
            saved = joblib.load(model_file)
            out = run_predict(df, saved, use_emb)
            out.to_csv("predictions.csv", index=False)
            st.success("Predictions generated!")
            st.download_button("Download predictions.csv", out.to_csv(index=False), "predictions.csv")
        else:
            st.error("Upload both CSV and model file.")