|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import joblib |
|
|
import os |
|
|
import torch |
|
|
import random |
|
|
import selfies as sf |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import SaltRemover, Descriptors |
|
|
from rdkit.Chem.MolStandardize import rdMolStandardize |
|
|
from transformers import AutoTokenizer, AutoModel, pipeline as hf_pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Global app configuration ---------------------------------------------
# Streamlit page setup; must run before any other st.* call in the script.
st.set_page_config(page_title="PFAS Discovery AI", layout="wide", initial_sidebar_state="expanded")

# Run the BERT featurizer on GPU when available, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face model id of the SMILES-pretrained BERT used as featurizer.
MODEL_NAME = "JuIm/SMILES_BERT"
|
|
|
|
|
@st.cache_resource
def load_bert_brain():
    """Download and cache the SMILES-BERT tokenizer/model pair.

    Returns:
        (tokenizer, model) on success; (None, None) when loading fails,
        after surfacing the error in the Streamlit UI.
    """
    try:
        with st.spinner("Downloading BERT Brain..."):
            tok = AutoTokenizer.from_pretrained(MODEL_NAME)
            mdl = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
            # Inference-only use: disable dropout etc.
            mdl.eval()
        return tok, mdl
    except Exception as e:
        st.error(f"❌ Error loading BERT: {e}")
        return None, None
|
|
|
|
|
tokenizer, bert_model = load_bert_brain() |
|
|
|
|
|
def get_descriptors(smiles_list, batch_size=16):
    """Embed SMILES strings with the cached SMILES-BERT encoder.

    Mean-pools the encoder's last hidden state over tokens to produce one
    vector per molecule.  Rows whose SMILES is NaN — and every row when the
    model failed to load — are left as zero vectors, so the output is always
    shape (len(smiles_list), hidden_size), aligned with the input order.

    Args:
        smiles_list: sequence of SMILES strings (NaN entries allowed).
        batch_size: number of molecules tokenized/encoded per forward pass.

    Returns:
        np.ndarray of shape (len(smiles_list), hidden_size).
    """
    # Derive the embedding width from the model instead of hardcoding 768,
    # so swapping MODEL_NAME for a differently-sized encoder still works.
    hidden = getattr(getattr(bert_model, "config", None), "hidden_size", 768) or 768
    final_output = np.zeros((len(smiles_list), hidden))
    if bert_model is None:
        return final_output
    for i in range(0, len(smiles_list), batch_size):
        window = smiles_list[i:i + batch_size]       # slice once, reuse below
        batch = [s for s in window if pd.notna(s)]   # drop missing SMILES
        if not batch:
            continue
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
        with torch.no_grad():  # inference only — no autograd graph needed
            outputs = bert_model(**inputs)
        # Mean-pool token embeddings -> one (hidden,) vector per molecule.
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        # Scatter the embeddings back to their original positions, leaving
        # zero rows where the input SMILES was NaN.
        valid_idx = 0
        for j, s in enumerate(window):
            if pd.notna(s):
                final_output[i + j] = embeddings[valid_idx]
                valid_idx += 1
    return final_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared RDKit standardisation helpers, constructed once at module load.
# Strips common counter-ions / salt fragments from parsed molecules.
remover = SaltRemover.SaltRemover()
# Neutralises charged species where chemically possible.
uncharger = rdMolStandardize.Uncharger()
|
|
|
|
|
def clean_mol(s):
    """Standardise a raw SMILES: strip salts, neutralise charges, canonicalise.

    Args:
        s: raw SMILES string (may be NaN for missing rows).

    Returns:
        Canonical SMILES string, or None when the input is missing or
        unparsable (callers drop those rows via dropna).
    """
    try:
        if pd.isna(s):
            return None
        m = Chem.MolFromSmiles(s)
        if not m:
            return None
        # dontRemoveEverything keeps at least one fragment even if every
        # fragment matches the salt list.
        m = uncharger.uncharge(remover.StripMol(m, dontRemoveEverything=True))
        return Chem.MolToSmiles(m, canonical=True)
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # KeyboardInterrupt/SystemExit; RDKit failures still map to None.
        return None
|
|
|
|
|
|
|
|
def calculate_props(smiles):
    """Calculates MW, LogP, TPSA, H-bond donors/acceptors, and fluorine stats.

    Always returns the same fixed key set so downstream DataFrames keep a
    stable column schema; invalid or failing SMILES yield all-zero values.

    Args:
        smiles: SMILES string to profile.

    Returns:
        dict with keys MW, LogP, TPSA, HBD, HBA, F_Count, F/C_Ratio.
    """
    # Single failure payload.  Previously an unparsable SMILES returned {}
    # while an internal error returned zeros — the empty dict would drop
    # columns the results table later indexes by name.
    zeros = {"MW": 0, "LogP": 0, "TPSA": 0, "HBD": 0, "HBA": 0, "F_Count": 0, "F/C_Ratio": 0}
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return zeros

        mw = Descriptors.MolWt(mol)
        logp = Descriptors.MolLogP(mol)
        tpsa = Descriptors.TPSA(mol)
        hbd = Descriptors.NumHDonors(mol)
        hba = Descriptors.NumHAcceptors(mol)

        # Fluorine load — PFAS are characterised by high F/C ratios.
        f_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts("[F]")))
        c_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts("[#6]")))
        fc_ratio = f_count / c_count if c_count > 0 else 0

        return {
            "MW": round(mw, 1),
            "LogP": round(logp, 2),
            "TPSA": round(tpsa, 1),
            "HBD": hbd,
            "HBA": hba,
            "F_Count": f_count,
            "F/C_Ratio": round(fc_ratio, 2)
        }
    except Exception:
        # Narrowed from a bare `except:`; any RDKit descriptor failure
        # still degrades gracefully to the zero profile.
        return zeros
|
|
|
|
|
def sanity_check_class(smiles, predicted_class):
    """Override the ML subclass prediction with rule-based structural checks.

    Args:
        smiles: SMILES string of the molecule.
        predicted_class: label produced by the classifier.

    Returns:
        "Invalid" for unparsable SMILES, "Non-PFAS" when no perfluorinated
        motif is found, "PFCA"/"PFSA" when a matching acid headgroup is
        present, otherwise the classifier's own prediction.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return "Invalid"
        # PFAS gate: two adjacent CF2 carbons (single- or double-bonded),
        # or a trifluoroacetyl group.
        pfas_chain = mol.HasSubstructMatch(Chem.MolFromSmarts("[#6](F)(F)-[#6](F)(F)"))
        pfas_double = mol.HasSubstructMatch(Chem.MolFromSmarts("[#6](F)(F)=[#6](F)(F)"))
        tfa_group = mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)(F)C(=O)"))
        if not (pfas_chain or pfas_double or tfa_group):
            return "Non-PFAS"
        # Headgroup rules trump the model: fluorinated carboxylic acid ->
        # PFCA; fluorinated sulfonic acid -> PFSA.
        if mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)C(=O)[OH,O-]")):
            return "PFCA"
        if mol.HasSubstructMatch(Chem.MolFromSmarts("[CX4](F)(F)S(=O)(=O)[OH,O-]")):
            return "PFSA"
        return predicted_class
    except Exception:
        # Narrowed from a bare `except:`; on any RDKit failure fall back
        # to the model's prediction rather than crash the pipeline.
        return predicted_class
|
|
|
|
|
def mutate_smart(s):
    """Randomly mutate a molecule in SELFIES space (always decodes validly).

    With probability 0.9 inserts one heteroatom/fragment token at a random
    position, and with probability 0.6 appends a polar tail token; the two
    mutations are independent and may both occur.

    Args:
        s: seed SMILES string.

    Returns:
        Mutated SMILES string, or the seed unchanged if SELFIES
        encoding/decoding fails.
    """
    try:
        chars = list(sf.split_selfies(sf.encoder(s)))
        if random.random() < 0.9:
            insert_idx = random.randint(0, len(chars))
            atom = random.choice(["[O]", "[N]", "[C][=O]", "[C][O]"])
            chars.insert(insert_idx, atom)
        if random.random() < 0.6:
            chars.append(random.choice(["[O]", "[N]", "[C][=O][O]"]))
        return sf.decoder("".join(chars))
    except Exception:
        # Narrowed from a bare `except:`; mutation is best-effort, so any
        # SELFIES failure returns the seed rather than aborting the search.
        return s
|
|
|
|
|
@st.cache_resource
def load_downstream_models():
    """Load the cached downstream predictors from disk / the Hub.

    Returns:
        (classifier, persistence_regressor, mobility_regressor,
        bioaccumulation_regressor, toxicity_oracle), or a 5-tuple of None
        when any artifact is missing — the caller stops the app with a
        user-facing message in that case.
    """
    try:
        clf = joblib.load("PFAS_Subclass_Classifier.pkl")
        reg_p = joblib.load("PFAS_Persistence_Regressor.pkl")
        reg_m = joblib.load("PFAS_Mobility_Regressor.pkl")
        reg_b = joblib.load("PFAS_Bioaccumulation_Regressor.pkl")
        # Auxiliary hazard classifier run directly on SMILES text.
        oracle = hf_pipeline("text-classification", model="DeepChem/ChemBERTa-77M-MLM")
        return clf, reg_p, reg_m, reg_b, oracle
    except Exception:
        # Narrowed from a bare `except:`; missing .pkl files or a failed
        # Hub download both surface as the all-None sentinel.
        return None, None, None, None, None
|
|
|
|
|
# Load the pretrained downstream models once at startup.
clf, reg_p, reg_m, reg_b, oracle = load_downstream_models()

# load_downstream_models returns all Nones together on failure, so checking
# one sentinel is enough; abort the app with a clear message.
if clf is None:
    st.error("❌ Models Missing! Upload the 4 .pkl files.")
    st.stop()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Page header -----------------------------------------------------------
st.title("End-to-End PFAS Discovery AI")
st.markdown("### Powered by Evolutionary Optimization & Deep Learning")
st.markdown("---")

# --- Sidebar: input selection ----------------------------------------------
st.sidebar.header("1. Input Data")
input_type = st.sidebar.radio("Source:", ["Single Molecule", "Batch CSV"])

# `data` is a list of row dicts; each row needs at least a 'SMILES' key.
data = []
if input_type == "Single Molecule":
    # Default example is a long-chain perfluorinated carboxylic acid.
    smi = st.sidebar.text_input("SMILES:", "OC(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F")
    name = st.sidebar.text_input("Name:", "PFDoA_Test")
    if smi: data = [{'SMILES': smi, 'ID': 'User', 'NAME': name}]
else:
    f = st.sidebar.file_uploader("Upload CSV", type=["csv"])
    if f:
        df_in = pd.read_csv(f)
        # NOTE(review): files without a 'SMILES' column are silently ignored
        # (data stays empty) — consider surfacing a warning to the user.
        if 'SMILES' in df_in.columns: data = df_in.to_dict('records')

# --- Sidebar: pipeline mode ------------------------------------------------
st.sidebar.header("2. Pipeline Mode")
mode = st.sidebar.selectbox("Mode:", ["Screening (Analyze)", "Discovery (Optimize)"])

# One fixed colour per subclass so every plot below stays visually consistent.
COLOR_MAP = {
    "Non-PFAS": "#2ecc71",
    "PFCA": "#e74c3c",
    "PFSA": "#9b59b6",
    "General PFAS": "#f39c12",
    "Invalid": "#95a5a6"
}
|
|
|
|
|
# ---------------------------------------------------------------------------
# Main pipeline: runs once when "Run Pipeline" is pressed and input exists.
# ---------------------------------------------------------------------------
if st.sidebar.button("Run Pipeline") and data:
    st.info(f"Running **{mode}** on {len(data)} molecules...")

    # Standardise every input SMILES and drop rows that fail to parse.
    df_proc = pd.DataFrame(data)
    df_proc['Clean_SMILES'] = df_proc['SMILES'].apply(clean_mol)
    valid_df = df_proc.dropna(subset=['Clean_SMILES'])

    # One result-row dict per molecule/candidate, accumulated by either mode.
    results = []

    if mode == "Discovery (Optimize)":
        # Evolutionary search: mutate each seed and keep the candidates
        # predicted to bioaccumulate the least.
        seeds = valid_df['Clean_SMILES'].tolist()
        progress_bar = st.progress(0)

        for i, s in enumerate(seeds):
            # Population = seed plus up to 20 unique mutants.
            # NOTE(review): every mutant is derived from the seed `s`, not
            # from earlier mutants — this is one-generation sampling rather
            # than iterated evolution; confirm that is intended.
            population = [s]
            for _ in range(20):
                new_mol = mutate_smart(s)
                if new_mol not in population: population.append(new_mol)

            # Featurise and score the whole population in one batch.
            feats = get_descriptors(population)
            preds = clf.predict(feats)
            scores_b = reg_b.predict(feats)
            scores_p = reg_p.predict(feats)
            scores_m = reg_m.predict(feats)

            ranked_candidates = []
            for j, cand in enumerate(population):
                # Rule-based structural check may override the classifier.
                final_cls = sanity_check_class(cand, preds[j])
                props = calculate_props(cand)

                entry = {
                    "Structure (SMILES)": cand,
                    "Type": "Original" if cand == s else "Optimized",
                    "Subclass": final_cls,
                    "Bioaccumulation": scores_b[j],
                    "Persistence": scores_p[j],
                    "Mobility": scores_m[j]
                }
                entry.update(props)
                ranked_candidates.append(entry)

            # Lower predicted bioaccumulation is better: keep the best 3
            # candidates per seed.
            ranked_candidates.sort(key=lambda x: x['Bioaccumulation'])
            results.extend(ranked_candidates[:3])
            progress_bar.progress((i + 1) / len(seeds))

    else:
        # Screening mode: score the cleaned input molecules as-is.
        smiles_list = valid_df['Clean_SMILES'].tolist()
        feats = get_descriptors(smiles_list)
        preds = clf.predict(feats)
        scores_p = reg_p.predict(feats)
        scores_m = reg_m.predict(feats)
        scores_b = reg_b.predict(feats)
        # Auxiliary toxicity oracle: text-classification over raw SMILES.
        hazards = oracle(smiles_list)

        for i, row in enumerate(valid_df.itertuples()):
            final_cls = sanity_check_class(row.Clean_SMILES, preds[i])
            # Only report the oracle's label above a 0.70 confidence cutoff.
            tox = hazards[i]['label'] if hazards[i]['score'] > 0.70 else "Inconclusive"
            props = calculate_props(row.Clean_SMILES)

            entry = {
                "ID": getattr(row, 'ID', 'N/A'),
                "Structure (SMILES)": row.Clean_SMILES,
                "Subclass": final_cls,
                "Bioaccumulation": round(scores_b[i], 2),
                "Persistence": round(scores_p[i], 2),
                "Mobility": round(scores_m[i], 2),
                "Tox_Result": tox
            }
            entry.update(props)
            results.append(entry)

    # -----------------------------------------------------------------------
    # Results table
    # -----------------------------------------------------------------------
    res_df = pd.DataFrame(results)

    st.markdown("### Analysis Results")

    # Put the headline columns first, then whatever else each mode produced.
    main_cols = ['Structure (SMILES)', 'Subclass', 'Bioaccumulation', 'Persistence', 'MW', 'LogP', 'F_Count']
    remaining_cols = [c for c in res_df.columns if c not in main_cols]
    final_cols = main_cols + remaining_cols

    # Highlight the lowest (safest) bioaccumulation value in the table.
    st.dataframe(
        res_df[final_cols].style.highlight_min(axis=0, subset=['Bioaccumulation'], color='#d4edda'),
        use_container_width=True
    )

    st.download_button("Download Full CSV", res_df.to_csv(index=False).encode('utf-8'), "results.csv", "text/csv")

    # -----------------------------------------------------------------------
    # Dashboard: four interactive Plotly views of the scored molecules.
    # -----------------------------------------------------------------------
    st.markdown("---")
    st.header("Advanced Analytics Dashboard")

    col1, col2 = st.columns(2)

    with col1:
        # 3-D risk landscape across the three predicted endpoints.
        st.subheader("1. Multi-Dimensional Risk Space")
        fig_3d = px.scatter_3d(
            res_df,
            x='Bioaccumulation', y='Mobility', z='Persistence',
            color='Subclass',
            # 'Type' only exists in Discovery mode; fall back to Subclass.
            symbol='Type' if 'Type' in res_df.columns else 'Subclass',
            color_discrete_map=COLOR_MAP,
            opacity=0.9,
            size_max=12,
            template="plotly_white",
            hover_data=['Structure (SMILES)', 'MW', 'LogP'],
            title="Risk Landscape (Interactive)"
        )
        fig_3d.update_layout(margin=dict(l=0, r=0, b=0, t=40), height=500)
        st.plotly_chart(fig_3d, use_container_width=True)

    with col2:
        # Molecule counts per predicted subclass.
        st.subheader("2. Safety Classification")
        count_df = res_df['Subclass'].value_counts().reset_index()
        count_df.columns = ['Subclass', 'Count']
        fig_bar = px.bar(
            count_df, x="Subclass", y="Count", color="Subclass",
            title="Molecule Counts by Class",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            text_auto=True
        )
        fig_bar.update_layout(height=500)
        st.plotly_chart(fig_bar, use_container_width=True)

    col3, col4 = st.columns(2)

    with col3:
        # Parallel-coordinates trace linking chemical properties to risk.
        st.subheader("3. Property Trace")
        fig_para = px.parallel_coordinates(
            res_df,
            dimensions=['Persistence', 'Mobility', 'Bioaccumulation', 'LogP', 'MW'],
            color="Bioaccumulation",
            color_continuous_scale="Spectral_r",
            title="Trace: Chem Properties -> Risk",
            template="plotly_dark"
        )
        fig_para.update_layout(height=500)
        st.plotly_chart(fig_para, use_container_width=True)

    with col4:
        # Per-subclass distribution of predicted bioaccumulation.
        st.subheader("4. Bioaccumulation Spread")
        fig_vio = px.violin(
            res_df, y="Bioaccumulation", x="Subclass",
            color="Subclass", box=True, points="all",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            title="Distribution Density"
        )
        # Reference line at 3.5 — presumably a regulatory bioaccumulation
        # cutoff on this model's scale; confirm with the model authors.
        fig_vio.add_hline(y=3.5, line_dash="dash", line_color="orange", annotation_text="Safety Limit")
        fig_vio.update_layout(height=500)
        st.plotly_chart(fig_vio, use_container_width=True)