import os import warnings warnings.filterwarnings("ignore") import numpy as np import pandas as pd from Bio.Align import PairwiseAligner from Bio.SeqUtils.ProtParam import ProteinAnalysis from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler from sklearn.ensemble import RandomForestClassifier import joblib import streamlit as st # Optional heavy deps try: import torch import transformers from transformers import AutoTokenizer, AutoModel HAS_EMB = True except: HAS_EMB = False try: import xgboost as xgb HAS_XGB = True except: HAS_XGB = False # ------------------------- # GLOBALS # ------------------------- PREFERRED_PLUS1 = set(['C', 'S', 'T']) aligner = PairwiseAligner() aligner.mode = "global" # ------------------------- # Basic functions # ------------------------- def seq_identity(a, b): if not a or not b: return 0.0 try: score = aligner.score(a, b) return score / max(len(a), len(b)) except: matches = sum(x == y for x, y in zip(a, b)) return matches / max(len(a), len(b)) def aa_comp_props(seq): if not seq: res = {f'aa_pct_{aa}': 0.0 for aa in "ACDEFGHIKLMNPQRSTVWY"} res.update({"aromaticity": 0.0, "instability_index": 0.0, "isoelectric_point": 0.0}) return res pa = ProteinAnalysis(seq) comp = pa.get_amino_acids_percent() out = {f'aa_pct_{aa}': comp.get(aa, 0.0) for aa in "ACDEFGHIKLMNPQRSTVWY"} out['aromaticity'] = pa.aromaticity() out['instability_index'] = pa.instability_index() out['isoelectric_point'] = pa.isoelectric_point() return out # ------------------------- # Embedding Provider # ------------------------- class ProtBertProvider: def __init__(self, model_name="Rostlab/prot_bert"): self.tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False) self.model = AutoModel.from_pretrained(model_name) self.model.eval() def embed(self, seq): if not seq: return np.zeros(1024) tokens = " ".join(list(seq)) inputs = self.tokenizer(tokens, return_tensors="pt") with torch.no_grad(): output = self.model(**inputs).last_hidden_state.mean(dim=1) return output.squeeze().numpy() # ------------------------- # Feature Extraction # ------------------------- def extract_row(row, use_emb=False, emb=None): nseq = str(row.get('n_intein_seq', "")) cseq = str(row.get('c_intein_seq', "")) plus1 = str(row.get('extein_plus1', "")).upper() feats = { "pair_identity": seq_identity(nseq, cseq), "len_N": len(nseq), "len_C": len(cseq), "plus1_good": 1 if plus1 in PREFERRED_PLUS1 else 0, "plus1_code": ord(plus1[0]) - 65 if plus1 else -1, "cognate": int(row.get('cognate', 0)), "docking_score": float(row.get('docking_score', 0)), "pLDDT_N": float(row.get('pLDDT_N', row.get("struct_confidence", 0))), "pLDDT_C": float(row.get('pLDDT_C', row.get("struct_confidence", 0))) } # AA properties nprops = aa_comp_props(nseq) cprops = aa_comp_props(cseq) for k, v in nprops.items(): feats[f"N_{k}"] = v for k, v in cprops.items(): feats[f"C_{k}"] = v # embeddings if use_emb and emb: n_emb = emb.embed(nseq) c_emb = emb.embed(cseq) for i, x in enumerate(n_emb[:256]): feats[f"N_emb_{i}"] = float(x) for i, x in enumerate(c_emb[:256]): feats[f"C_emb_{i}"] = float(x) return feats def build_matrix(df, use_emb=False, emb=None): feat_rows = [] for _, r in df.iterrows(): feat_rows.append(extract_row(r, use_emb, emb)) return pd.DataFrame(feat_rows).fillna(0.0) # ------------------------- # Train Model # ------------------------- def train_model(df, use_emb=False, model_type="rf"): emb = ProtBertProvider() if (use_emb and HAS_EMB) else None X = build_matrix(df, use_emb, emb) y = df['label'].astype(int) if model_type == "xgb": if not HAS_XGB: st.error("XGBoost unavailable.") return None scaler = StandardScaler() Xs = scaler.fit_transform(X) model = xgb.XGBClassifier(objective='multi:softprob', num_class=3) model.fit(Xs, y) return {"model": model, "scaler": scaler, "cols": list(X.columns)} # RandomForest pipe = Pipeline([ ("scale", StandardScaler()), ("clf", RandomForestClassifier(n_estimators=300, class_weight="balanced")) ]) pipe.fit(X, y) return {"pipeline": pipe, "cols": list(X.columns)} # ------------------------- # Predict # ------------------------- def run_predict(df, saved, use_emb=False): emb = ProtBertProvider() if (use_emb and HAS_EMB) else None X = build_matrix(df, use_emb, emb) if "pipeline" in saved: pipe = saved["pipeline"] preds = pipe.predict(X) probs = pipe.predict_proba(X) else: model = saved["model"] scaler = saved["scaler"] cols = saved["cols"] Xs = scaler.transform(X[cols]) preds = model.predict(Xs) probs = model.predict_proba(Xs) df["pred_label"] = preds for i in range(probs.shape[1]): df[f"prob_{i}"] = probs[:, i] return df # ------------------------- # Streamlit UI for Hugging Face # ------------------------- st.title("🔬 Intein Splice Predictor — Hugging Face Space") st.write("Upload CSV containing columns:") st.write("`n_intein_seq`, `c_intein_seq`, `extein_plus1`, `cognate`, `docking_score`, `struct_confidence`") mode = st.radio("Choose mode:", ["Train Model", "Predict With Model"]) # ------------------------------------ # MODE 1: TRAIN # ------------------------------------ if mode == "Train Model": train_file = st.file_uploader("Upload training CSV (must contain column: label)", type=["csv"]) use_emb = st.checkbox("Use ProtBert embeddings (slow, needs GPU)", value=False) model_type = st.selectbox("Model Type", ["rf", "xgb"]) if st.button("Train"): if train_file: df = pd.read_csv(train_file) saved = train_model(df, use_emb, model_type) joblib.dump(saved, "intein_model.joblib") st.success("Model trained & saved as intein_model.joblib") else: st.error("Upload a CSV first.") # ------------------------------------ # MODE 2: PREDICT # ------------------------------------ else: pred_file = st.file_uploader("Upload CSV for prediction", type=["csv"]) model_file = st.file_uploader("Upload your intein_model.joblib", type=["joblib"]) use_emb = st.checkbox("Use embeddings (same setting used during training)") if st.button("Predict"): if pred_file and model_file: df = pd.read_csv(pred_file) saved = joblib.load(model_file) out = run_predict(df, saved, use_emb) out.to_csv("predictions.csv", index=False) st.success("Predictions generated!") st.download_button("Download predictions.csv", out.to_csv(index=False), "predictions.csv") else: st.error("Upload both CSV and model file.")