Boolean / app.py
praveen2302's picture
Create app.py
b682b6c verified
import os
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
from Bio.Align import PairwiseAligner
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
import joblib
import streamlit as st
# Optional heavy deps: the app keeps working without them (no embeddings /
# no XGBoost). ``except Exception`` instead of a bare ``except:`` so that
# Ctrl-C / SystemExit during a slow import are not swallowed. It is broader
# than ImportError on purpose: a broken torch/transformers install may fail
# with other errors (e.g. OSError while loading native libraries).
try:
    import torch
    import transformers
    from transformers import AutoTokenizer, AutoModel
    HAS_EMB = True
except Exception:
    HAS_EMB = False
try:
    import xgboost as xgb
    HAS_XGB = True
except Exception:
    HAS_XGB = False
# -------------------------
# GLOBALS
# -------------------------
# +1 extein residues counted as "preferred" (the plus1_good feature).
PREFERRED_PLUS1 = set("CST")
# One module-wide aligner, reused by seq_identity, in global (end-to-end) mode.
aligner = PairwiseAligner()
aligner.mode = "global"
# -------------------------
# Basic functions
# -------------------------
def seq_identity(a, b):
    """Return a length-normalised similarity between sequences *a* and *b*.

    The score of the module-level ``aligner`` is divided by the length of the
    longer sequence. If the aligner call fails, the fallback is the number of
    equal characters compared position-by-position from the start (no gaps),
    divided by the same length.

    Returns 0.0 when either sequence is empty.
    """
    if not a or not b:
        return 0.0
    longest = max(len(a), len(b))
    try:
        return aligner.score(a, b) / longest
    except Exception:
        # Best-effort ungapped fallback. Narrowed from a bare ``except:`` so
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        matches = sum(x == y for x, y in zip(a, b))
        return matches / longest
def aa_comp_props(seq):
    """Compute amino-acid composition and ProtParam descriptors for *seq*.

    Returns a flat dict with ``aa_pct_<X>`` for the 20 standard residues
    (taken from ``ProteinAnalysis.get_amino_acids_percent``) plus
    ``aromaticity``, ``instability_index`` and ``isoelectric_point``.
    A missing, empty or whitespace-only sequence yields all-zero features.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    # Tolerate stray whitespace / lowercase coming from hand-edited CSVs.
    seq = (seq or "").strip().upper()
    if not seq:
        res = {f"aa_pct_{aa}": 0.0 for aa in amino_acids}
        res.update({"aromaticity": 0.0, "instability_index": 0.0, "isoelectric_point": 0.0})
        return res
    pa = ProteinAnalysis(seq)
    comp = pa.get_amino_acids_percent()
    out = {f"aa_pct_{aa}": comp.get(aa, 0.0) for aa in amino_acids}
    out["aromaticity"] = pa.aromaticity()
    try:
        out["instability_index"] = pa.instability_index()
    except KeyError:
        # instability_index() looks every dipeptide up in a table keyed by the
        # standard residues, so letters such as 'X' raise KeyError. Use the
        # same 0.0 as for empty input rather than aborting the whole batch.
        out["instability_index"] = 0.0
    out["isoelectric_point"] = pa.isoelectric_point()
    return out
# -------------------------
# Embedding Provider
# -------------------------
class ProtBertProvider:
    """Per-sequence embeddings from a ProtBert-style Hugging Face encoder.

    Loads the tokenizer and model once. ``embed`` returns the mean of the last
    hidden state over the token dimension for a single sequence.
    """

    # ProtBert was pre-trained with the rare residues U, Z, O, B mapped to X
    # (Rostlab/prot_bert model card); apply the same mapping at inference.
    _RARE_TO_X = str.maketrans("UZOB", "XXXX")
    # Width of the zero vector returned for empty input; replaced by the real
    # model width in __init__.
    hidden_size = 1024

    def __init__(self, model_name="Rostlab/prot_bert"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        # Keep the empty-input placeholder the same width as real embeddings.
        self.hidden_size = self.model.config.hidden_size

    @classmethod
    def _prepare(cls, seq):
        """Normalise *seq* to ProtBert input: upper-case residues separated by spaces.

        Whitespace is removed. Upper-casing matters because the tokenizer is
        created with ``do_lower_case=False``.
        """
        residues = "".join((seq or "").split()).upper().translate(cls._RARE_TO_X)
        return " ".join(residues)

    def embed(self, seq):
        """Return the mean-pooled last hidden state for *seq* as a 1-D numpy array.

        Empty or blank input returns ``np.zeros(hidden_size)``.
        """
        tokens = self._prepare(seq)
        if not tokens:
            return np.zeros(self.hidden_size)
        inputs = self.tokenizer(tokens, return_tensors="pt")
        with torch.no_grad():
            output = self.model(**inputs).last_hidden_state.mean(dim=1)
        return output.squeeze().numpy()
# -------------------------
# Feature Extraction
# -------------------------
def _is_missing(value):
    """Return True for None / NaN / pd.NA (what pandas yields for empty CSV cells)."""
    return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))


def _cell_text(row, key):
    """Return ``row[key]`` as a stripped string, or "" if the column/cell is missing."""
    value = row.get(key)
    return "" if _is_missing(value) else str(value).strip()


def _first_number(row, keys, default=0.0):
    """Return the first non-missing value among *keys* as a float, else *default*."""
    for key in keys:
        value = row.get(key)
        if not _is_missing(value):
            return float(value)
    return default


def extract_row(row, use_emb=False, emb=None):
    """Build the flat feature dict for one input row.

    *row* is anything with ``.get`` (``build_matrix`` passes pandas Series).
    It reads:

    - ``n_intein_seq`` and ``c_intein_seq``
    - ``extein_plus1``
    - ``cognate`` and ``docking_score``
    - ``pLDDT_N`` / ``pLDDT_C``, each falling back to ``struct_confidence``

    Missing cells (absent column, None, NaN) become "" for text and 0 for
    numbers. When *use_emb* is true and *emb* is given, the first 256 values
    of ``emb.embed(...)`` for each half are added as ``N_emb_i`` / ``C_emb_i``.
    """
    # Guard against str(NaN) == "nan", which would otherwise be featurised as
    # a 3-residue peptide.
    nseq = _cell_text(row, "n_intein_seq")
    cseq = _cell_text(row, "c_intein_seq")
    plus1 = _cell_text(row, "extein_plus1").upper()
    # Both +1 features describe the same (first) residue.
    residue = plus1[:1]
    feats = {
        "pair_identity": seq_identity(nseq, cseq),
        "len_N": len(nseq),
        "len_C": len(cseq),
        "plus1_good": 1 if residue in PREFERRED_PLUS1 else 0,
        "plus1_code": ord(residue) - 65 if residue else -1,
        # int(NaN) raises, so missing cognate flags default to 0.
        "cognate": int(_first_number(row, ["cognate"])),
        "docking_score": _first_number(row, ["docking_score"]),
        # Fall back to struct_confidence when the per-half column is absent OR empty.
        "pLDDT_N": _first_number(row, ["pLDDT_N", "struct_confidence"]),
        "pLDDT_C": _first_number(row, ["pLDDT_C", "struct_confidence"]),
    }
    # AA properties
    for prefix, seq in (("N", nseq), ("C", cseq)):
        for k, v in aa_comp_props(seq).items():
            feats[f"{prefix}_{k}"] = v
    # embeddings
    if use_emb and emb is not None:
        for prefix, seq in (("N", nseq), ("C", cseq)):
            for i, x in enumerate(emb.embed(seq)[:256]):
                feats[f"{prefix}_emb_{i}"] = float(x)
    return feats
def build_matrix(df, use_emb=False, emb=None):
    """Featurise every row of *df* with extract_row and stack the results.

    Returns a DataFrame with one row per input row and a fresh RangeIndex.
    Features absent from a given row's dict are filled with 0.0.
    """
    records = [extract_row(record, use_emb, emb) for _, record in df.iterrows()]
    matrix = pd.DataFrame(records)
    return matrix.fillna(0.0)
# -------------------------
# Train Model
# -------------------------
def train_model(df, use_emb=False, model_type="rf"):
    """Fit a classifier on *df* (features from build_matrix, target ``df['label']``).

    Returns a bundle for run_predict:

    - ``model_type == "xgb"``: ``{"model", "scaler", "cols"}``
    - otherwise (RandomForest pipeline): ``{"pipeline", "cols"}``
    - ``None`` when XGBoost is requested but not installed; an error is shown
      in the Streamlit UI.
    """
    # Check this before loading ProtBert, so we don't pay for a model load
    # only to bail out.
    if model_type == "xgb" and not HAS_XGB:
        st.error("XGBoost unavailable.")
        return None
    if use_emb and not HAS_EMB:
        # Previously a silent fallback; prediction must then also run without embeddings.
        st.warning(
            "ProtBert embeddings requested but torch/transformers are unavailable; "
            "training without embeddings."
        )
    emb = ProtBertProvider() if (use_emb and HAS_EMB) else None
    X = build_matrix(df, use_emb, emb)
    y = df["label"].astype(int)
    if model_type == "xgb":
        scaler = StandardScaler()
        Xs = scaler.fit_transform(X)
        # Let XGBClassifier infer objective / num_class from y instead of
        # hard-coding 3 classes. NOTE(review): XGBoost expects labels 0..K-1.
        model = xgb.XGBClassifier()
        model.fit(Xs, y)
        return {"model": model, "scaler": scaler, "cols": list(X.columns)}
    # RandomForest
    pipe = Pipeline([
        ("scale", StandardScaler()),
        ("clf", RandomForestClassifier(n_estimators=300, class_weight="balanced")),
    ])
    pipe.fit(X, y)
    return {"pipeline": pipe, "cols": list(X.columns)}
# -------------------------
# Predict
# -------------------------
def run_predict(df, saved, use_emb=False):
    """Score *df* with a bundle produced by train_model.

    Returns a COPY of *df* (the input is not modified) with these columns added:

    - ``pred_label``
    - one ``prob_<class>`` column per model class

    Raises ValueError when feature columns the model was trained on are
    missing. This typically means the embeddings setting differs from the
    one used in training.
    """
    emb = ProtBertProvider() if (use_emb and HAS_EMB) else None
    X = build_matrix(df, use_emb, emb)
    cols = saved["cols"]
    missing = [c for c in cols if c not in X.columns]
    if missing:
        raise ValueError(
            f"{len(missing)} feature column(s) used in training are missing "
            f"(e.g. {missing[0]!r}); use the same embeddings setting as during training."
        )
    # Exactly the training columns, in training order (extra columns are dropped).
    X = X[cols]
    if "pipeline" in saved:
        estimator = saved["pipeline"]
        features = X
    else:
        estimator = saved["model"]
        features = saved["scaler"].transform(X)
    preds = estimator.predict(features)
    probs = estimator.predict_proba(features)
    out = df.copy()
    out["pred_label"] = preds
    classes = list(getattr(estimator, "classes_", []))
    if len(classes) != probs.shape[1]:
        # Fall back to positional names if classes_ does not match the probability columns.
        classes = list(range(probs.shape[1]))
    for i, cls in enumerate(classes):
        out[f"prob_{cls}"] = probs[:, i]
    return out
# -------------------------
# Streamlit UI for Hugging Face
# -------------------------
import io  # in-memory buffer for offering the trained model as a download

st.title("🔬 Intein Splice Predictor — Hugging Face Space")
st.write("Upload CSV containing columns:")
st.write("`n_intein_seq`, `c_intein_seq`, `extein_plus1`, `cognate`, `docking_score`, `struct_confidence`")
mode = st.radio("Choose mode:", ["Train Model", "Predict With Model"])
# ------------------------------------
# MODE 1: TRAIN
# ------------------------------------
if mode == "Train Model":
    train_file = st.file_uploader("Upload training CSV (must contain column: label)", type=["csv"])
    use_emb = st.checkbox("Use ProtBert embeddings (slow, needs GPU)", value=False)
    model_type = st.selectbox("Model Type", ["rf", "xgb"])
    if st.button("Train"):
        if not train_file:
            st.error("Upload a CSV first.")
        else:
            df = pd.read_csv(train_file)
            if "label" not in df.columns:
                st.error("Training CSV must contain a `label` column.")
            else:
                saved = train_model(df, use_emb, model_type)
                # train_model returns None (after showing its own error) when
                # XGBoost is unavailable: don't save or report success then.
                if saved is not None:
                    joblib.dump(saved, "intein_model.joblib")
                    st.success("Model trained & saved as intein_model.joblib")
                    # The file above stays on the server; offer it for download
                    # so it can be uploaded in "Predict With Model" mode.
                    buffer = io.BytesIO()
                    joblib.dump(saved, buffer)
                    st.download_button(
                        "Download intein_model.joblib",
                        buffer.getvalue(),
                        "intein_model.joblib",
                    )
# ------------------------------------
# MODE 2: PREDICT
# ------------------------------------
else:
    pred_file = st.file_uploader("Upload CSV for prediction", type=["csv"])
    model_file = st.file_uploader("Upload your intein_model.joblib", type=["joblib"])
    use_emb = st.checkbox("Use embeddings (same setting used during training)")
    if st.button("Predict"):
        if pred_file and model_file:
            df = pd.read_csv(pred_file)
            # SECURITY(review): joblib.load unpickles the upload, which can run
            # arbitrary code on this server. Only load model files you trust.
            saved = joblib.load(model_file)
            try:
                out = run_predict(df, saved, use_emb)
            except (KeyError, ValueError) as err:
                st.error(f"Prediction failed: {err}")
            else:
                out.to_csv("predictions.csv", index=False)
                st.success("Predictions generated!")
                st.download_button("Download predictions.csv", out.to_csv(index=False), "predictions.csv")
        else:
            st.error("Upload both CSV and model file.")