import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import torch
import random
import selfies as sf
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from rdkit import Chem
from rdkit.Chem import SaltRemover, Descriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from transformers import AutoTokenizer, AutoModel, pipeline as hf_pipeline

# =================================================================
# PART 0: THE BRIDGE (Automatic Brain Setup)
# =================================================================
st.set_page_config(page_title="PFAS Discovery AI", layout="wide", initial_sidebar_state="expanded")

# Run on GPU when available; all tensors/models are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "JuIm/SMILES_BERT"

@st.cache_resource
def load_bert_brain():
    """Download and cache the SMILES-BERT encoder used to featurize molecules.

    Returns:
        (tokenizer, model) on success, or (None, None) on failure — the
        error is surfaced in the Streamlit UI rather than raised, so the
        app keeps rendering.
    """
    try:
        with st.spinner("Downloading BERT Brain..."):
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
            model.eval()  # inference only — disable dropout etc.
            return tokenizer, model
    except Exception as e:
        st.error(f"❌ Error loading BERT: {e}")
        return None, None

tokenizer, bert_model = load_bert_brain()

def get_descriptors(smiles_list, batch_size=16):
    """Embed SMILES strings with SMILES-BERT (mean-pooled last hidden state).

    Rows for NaN entries are left as all-zero vectors so the output stays
    index-aligned with ``smiles_list``.

    Args:
        smiles_list: sequence of SMILES strings (may contain NaN entries).
        batch_size: number of molecules tokenized/encoded per forward pass.

    Returns:
        np.ndarray of shape (len(smiles_list), hidden_size); all zeros if
        the BERT model failed to load.
    """
    # FIX: derive the embedding width from the model config instead of
    # hard-coding 768, so swapping MODEL_NAME for a checkpoint with a
    # different hidden size still produces correctly shaped features.
    hidden = bert_model.config.hidden_size if bert_model is not None else 768
    final_output = np.zeros((len(smiles_list), hidden))
    if bert_model is None:
        return final_output
    for i in range(0, len(smiles_list), batch_size):
        chunk = smiles_list[i:i + batch_size]  # slice once per batch (was sliced twice)
        batch = [s for s in chunk if pd.notna(s)]
        if not batch:
            continue
        inputs = tokenizer(batch, return_tensors="pt", padding=True,
                           truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():
            outputs = bert_model(**inputs)
        # NOTE(review): mean over ALL token positions, padding included —
        # deliberately kept as-is so embeddings match whatever the
        # downstream .pkl models were trained on; confirm before "fixing"
        # with attention-mask-weighted pooling.
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        # Scatter the valid embeddings back to their original positions,
        # skipping NaN slots (which keep their zero rows).
        valid_idx = 0
        for j, s in enumerate(chunk):
            if pd.notna(s):
                final_output[i + j] = embeddings[valid_idx]
                valid_idx += 1
    return final_output

# =================================================================
# PART 1: LOGIC & MUTATION ENGINE
#
# =================================================================

# RDKit standardization helpers shared by clean_mol (module-level so they
# are built once, not per molecule).
remover = SaltRemover.SaltRemover()
uncharger = rdMolStandardize.Uncharger()

def clean_mol(s):
    """Standardize a SMILES: strip salts, neutralize charges, canonicalize.

    Returns the canonical SMILES string, or None for NaN / unparseable /
    otherwise failing input (RDKit errors are swallowed deliberately so a
    bad row never aborts a batch).
    """
    try:
        if pd.isna(s): return None
        m = Chem.MolFromSmiles(s)
        if not m: return None
        # dontRemoveEverything keeps at least one fragment even if the whole
        # molecule matches the salt list.
        m = uncharger.uncharge(remover.StripMol(m, dontRemoveEverything=True))
        return Chem.MolToSmiles(m, canonical=True)
    except: return None

# --- PHYSCHEM CALCULATOR ---
def calculate_props(smiles):
    """Calculates MW, LogP, TPSA, and Fluorine Counts"""
    try:
        mol = Chem.MolFromSmiles(smiles)
        # NOTE(review): unparseable SMILES returns {} here, but an RDKit
        # exception below returns a zero-filled dict — if EVERY row hits
        # this branch the dashboard's column selection would KeyError on
        # 'MW'; confirm whether {} should also be the zero-filled dict.
        if not mol: return {}
        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        tpsa = Descriptors.TPSA(mol)
        hbd = Descriptors.NumHDonors(mol)
        hba = Descriptors.NumHAcceptors(mol)
        # Fluorine/carbon counts via SMARTS matching; F/C ratio is a crude
        # "how perfluorinated is this" signal.
        f_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts("[F]")))
        c_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts("[#6]")))
        fc_ratio = f_count / c_count if c_count > 0 else 0
        return { "MW": round(mw, 1), "LogP": round(logp, 2), "TPSA": round(tpsa, 1), "HBD": hbd, "HBA": hba, "F_Count": f_count, "F/C_Ratio": round(fc_ratio, 2) }
    except:
        return {"MW":0, "LogP":0, "TPSA":0, "HBD":0, "HBA":0, "F_Count":0, "F/C_Ratio":0}

def sanity_check_class(smiles, predicted_class):
    """Override the ML subclass label with rule-based SMARTS checks.

    Structural rules win over the classifier: no perfluorinated motif at
    all -> "Non-PFAS"; a perfluoro carboxylic / sulfonic acid head group
    -> "PFCA" / "PFSA". Otherwise the model's prediction is kept.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol: return "Invalid"
        # Minimal PFAS motifs: saturated -CF2-CF2- chain, unsaturated
        # CF2=CF2, or a trifluoroacetyl (CF3-C=O) group.
        pfas_chain = mol.HasSubstructMatch(Chem.MolFromSmarts("[#6](F)(F)-[#6](F)(F)"))
        pfas_double = mol.HasSubstructMatch(Chem.MolFromSmarts("[#6](F)(F)=[#6](F)(F)"))
        tfa_group = mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)(F)C(=O)"))
        if not (pfas_chain or pfas_double or tfa_group): return "Non-PFAS"
        # Head-group checks accept both protonated and deprotonated forms.
        if mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)C(=O)[OH,O-]")): return "PFCA"
        if mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)S(=O)(=O)[OH,O-]")): return "PFSA"
        return predicted_class
    except:
        return predicted_class

def mutate_smart(s):
    """Randomly mutate a molecule in SELFIES space (always decodes to a
    syntactically valid molecule).

    90% chance: insert one O/N/carbonyl/ether token at a random position;
    60% chance (independent): append a polar tail token. On any encode/
    decode failure the input SMILES is returned unchanged.
    """
    try:
        chars = list(sf.split_selfies(sf.encoder(s)))
        if random.random() < 0.9:
            insert_idx = random.randint(0, len(chars))
            atom = random.choice(["[O]", "[N]", "[C][=O]", "[C][O]"])
            chars.insert(insert_idx, atom)
        if random.random() < 0.6:
            chars.append(random.choice(["[O]", "[N]", "[C][=O][O]"]))
        return sf.decoder("".join(chars))
    except: return s

@st.cache_resource
def load_downstream_models():
    """Load the four pickled scikit-learn models plus the HF tox 'oracle'.

    Returns a 5-tuple, or five Nones if anything is missing/unloadable
    (the caller shows an error and halts the app in that case).

    NOTE(review): "DeepChem/ChemBERTa-77M-MLM" is a masked-LM checkpoint
    being wrapped in a text-classification pipeline — its labels/scores
    may not be meaningful toxicity predictions; confirm the intended
    fine-tuned model name.
    """
    try:
        clf = joblib.load("PFAS_Subclass_Classifier.pkl")
        reg_p = joblib.load("PFAS_Persistence_Regressor.pkl")
        reg_m = joblib.load("PFAS_Mobility_Regressor.pkl")
        reg_b = joblib.load("PFAS_Bioaccumulation_Regressor.pkl")
        oracle = hf_pipeline("text-classification", model="DeepChem/ChemBERTa-77M-MLM")
        return clf, reg_p, reg_m, reg_b, oracle
    except:
        return None, None, None, None, None

clf, reg_p, reg_m, reg_b, oracle = load_downstream_models()
# Hard stop: the rest of the page is useless without the models.
if clf is None:
    st.error("❌ Models Missing! Upload the 4 .pkl files.")
    st.stop()

# =================================================================
# PART 2: THE UI
# =================================================================
st.title("End-to-End PFAS Discovery AI")
st.markdown("### Powered by Evolutionary Optimization & Deep Learning")
st.markdown("---")

# --- Sidebar: input source (single SMILES or CSV with a SMILES column) ---
st.sidebar.header("1. Input Data")
input_type = st.sidebar.radio("Source:", ["Single Molecule", "Batch CSV"])
data = []  # list of row dicts; stays empty if nothing valid was provided
if input_type == "Single Molecule":
    # Default is a long-chain perfluoro carboxylic acid (PFDoA-like).
    smi = st.sidebar.text_input("SMILES:", "OC(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F")
    name = st.sidebar.text_input("Name:", "PFDoA_Test")
    if smi: data = [{'SMILES': smi, 'ID': 'User', 'NAME': name}]
else:
    f = st.sidebar.file_uploader("Upload CSV", type=["csv"])
    if f:
        df_in = pd.read_csv(f)
        # Silently ignored if the CSV lacks a 'SMILES' column (data stays []).
        if 'SMILES' in df_in.columns: data = df_in.to_dict('records')

st.sidebar.header("2. Pipeline Mode")
mode = st.sidebar.selectbox("Mode:", ["Screening (Analyze)", "Discovery (Optimize)"])

# Shared subclass -> color mapping for all plots.
COLOR_MAP = { "Non-PFAS": "#2ecc71", "PFCA": "#e74c3c", "PFSA": "#9b59b6", "General PFAS": "#f39c12", "Invalid": "#95a5a6" }

if st.sidebar.button("Run Pipeline") and data:
    st.info(f"Running **{mode}** on {len(data)} molecules...")
    df_proc = pd.DataFrame(data)
    # Standardize first; rows whose SMILES fail cleaning are dropped.
    df_proc['Clean_SMILES'] = df_proc['SMILES'].apply(clean_mol)
    valid_df = df_proc.dropna(subset=['Clean_SMILES'])
    results = []

    # --- PATH A: EVOLUTIONARY DISCOVERY ---
    if mode == "Discovery (Optimize)":
        seeds = valid_df['Clean_SMILES'].tolist()
        progress_bar = st.progress(0)
        for i, s in enumerate(seeds):
            # Build a small population: the seed + up to 20 unique mutants
            # (each mutation is applied to the SEED, not chained).
            population = [s]
            for _ in range(20):
                new_mol = mutate_smart(s)
                if new_mol not in population: population.append(new_mol)
            # Score the whole population in one featurization pass.
            feats = get_descriptors(population)
            preds = clf.predict(feats)
            scores_b = reg_b.predict(feats)
            scores_p = reg_p.predict(feats)
            scores_m = reg_m.predict(feats)
            ranked_candidates = []
            for j, cand in enumerate(population):
                final_cls = sanity_check_class(cand, preds[j])
                props = calculate_props(cand)
                # NOTE(review): unlike the Screening path these scores are
                # not rounded, and there is no ID/Tox_Result column —
                # confirm whether that asymmetry is intentional.
                entry = {
                    "Structure (SMILES)": cand, # <--- RENAMED FOR CLARITY
                    "Type": "Original" if cand == s else "Optimized",
                    "Subclass": final_cls,
                    "Bioaccumulation": scores_b[j],
                    "Persistence": scores_p[j],
                    "Mobility": scores_m[j]
                }
                entry.update(props)
                ranked_candidates.append(entry)
            # Keep the 3 lowest-bioaccumulation candidates per seed.
            ranked_candidates.sort(key=lambda x: x['Bioaccumulation'])
            results.extend(ranked_candidates[:3])
            progress_bar.progress((i + 1) / len(seeds))

    # --- PATH B: SCREENING ---
    else:
        smiles_list = valid_df['Clean_SMILES'].tolist()
        feats = get_descriptors(smiles_list)
        preds = clf.predict(feats)
        scores_p = reg_p.predict(feats)
        scores_m = reg_m.predict(feats)
        scores_b = reg_b.predict(feats)
        hazards = oracle(smiles_list)
        for i, row in enumerate(valid_df.itertuples()):
            final_cls = sanity_check_class(row.Clean_SMILES, preds[i])
            # Only trust the tox label above a 0.70 confidence threshold.
            tox = hazards[i]['label'] if hazards[i]['score'] > 0.70 else "Inconclusive"
            props = calculate_props(row.Clean_SMILES)
            entry = {
                "ID": getattr(row, 'ID', 'N/A'),
                "Structure (SMILES)": row.Clean_SMILES, # <--- ADDED HERE TOO
                "Subclass": final_cls,
                "Bioaccumulation": round(scores_b[i], 2),
                "Persistence": round(scores_p[i], 2),
                "Mobility": round(scores_m[i], 2),
                "Tox_Result": tox
            }
            entry.update(props)
            results.append(entry)

    # ------------------------------------------------------------------
    # DASHBOARD
    # ------------------------------------------------------------------
    res_df = pd.DataFrame(results)
    st.markdown("### Analysis Results")

    # Reorder columns to put SMILES first
    main_cols = ['Structure (SMILES)', 'Subclass', 'Bioaccumulation', 'Persistence', 'MW', 'LogP', 'F_Count']
    remaining_cols = [c for c in res_df.columns if c not in main_cols]
    final_cols = main_cols + remaining_cols

    # Display Table with Highlighted Safe Options (lowest bioaccumulation
    # cell is tinted green).
    st.dataframe(
        res_df[final_cols].style.highlight_min(axis=0, subset=['Bioaccumulation'], color='#d4edda'),
        use_container_width=True
    )
    st.download_button("Download Full CSV", res_df.to_csv(index=False).encode('utf-8'), "results.csv", "text/csv")

    st.markdown("---")
    st.header("Advanced Analytics Dashboard")
    col1, col2 = st.columns(2)

    # GRAPH 1: 3D DISCOVERY CUBE
    with col1:
        st.subheader("1. Multi-Dimensional Risk Space")
        fig_3d = px.scatter_3d(
            res_df, x='Bioaccumulation', y='Mobility', z='Persistence',
            color='Subclass',
            # 'Type' only exists for Discovery output; fall back gracefully.
            symbol='Type' if 'Type' in res_df.columns else 'Subclass',
            color_discrete_map=COLOR_MAP,
            opacity=0.9, size_max=12,
            template="plotly_white",
            hover_data=['Structure (SMILES)', 'MW', 'LogP'], # Info on hover
            title="Risk Landscape (Interactive)"
        )
        fig_3d.update_layout(margin=dict(l=0, r=0, b=0, t=40), height=500)
        st.plotly_chart(fig_3d, use_container_width=True)

    # GRAPH 2: CLASS DISTRIBUTION
    with col2:
        st.subheader("2. Safety Classification")
        count_df = res_df['Subclass'].value_counts().reset_index()
        count_df.columns = ['Subclass', 'Count']
        fig_bar = px.bar(
            count_df, x="Subclass", y="Count", color="Subclass",
            title="Molecule Counts by Class",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            text_auto=True
        )
        fig_bar.update_layout(height=500)
        st.plotly_chart(fig_bar, use_container_width=True)

    col3, col4 = st.columns(2)

    # GRAPH 3: PARALLEL COORDINATES
    with col3:
        st.subheader("3. Property Trace")
        fig_para = px.parallel_coordinates(
            res_df,
            dimensions=['Persistence', 'Mobility', 'Bioaccumulation', 'LogP', 'MW'],
            color="Bioaccumulation",
            color_continuous_scale="Spectral_r",
            title="Trace: Chem Properties -> Risk",
            template="plotly_dark"
        )
        fig_para.update_layout(height=500)
        st.plotly_chart(fig_para, use_container_width=True)

    # GRAPH 4: DISTRIBUTION VIOLIN PLOT
    with col4:
        st.subheader("4. Bioaccumulation Spread")
        fig_vio = px.violin(
            res_df, y="Bioaccumulation", x="Subclass", color="Subclass",
            box=True, points="all",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            title="Distribution Density"
        )
        # 3.5 "Safety Limit" reference line — presumably a log BCF-style
        # threshold; TODO confirm units with the regressor's training data.
        fig_vio.add_hline(y=3.5, line_dash="dash", line_color="orange", annotation_text="Safety Limit")
        fig_vio.update_layout(height=500)
        st.plotly_chart(fig_vio, use_container_width=True)