# PFAS-Analyzer / app.py
# tueniuu's picture
# Update app.py
# 80ace75 verified
import streamlit as st
import pandas as pd
import numpy as np
import joblib
import os
import torch
import random
import selfies as sf
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from rdkit import Chem
from rdkit.Chem import SaltRemover, Descriptors
from rdkit.Chem.MolStandardize import rdMolStandardize
from transformers import AutoTokenizer, AutoModel, pipeline as hf_pipeline
# =================================================================
# PART 0: THE BRIDGE (Automatic Brain Setup)
# =================================================================
# Checkpoint used to embed SMILES strings, and the torch device it runs on
# (GPU when torch reports one, CPU otherwise).
MODEL_NAME = "JuIm/SMILES_BERT"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Page-level settings for the Streamlit app.
st.set_page_config(
    page_title="PFAS Discovery AI",
    layout="wide",
    initial_sidebar_state="expanded",
)
@st.cache_resource
def load_bert_brain():
    """Load the SMILES tokenizer and encoder named by ``MODEL_NAME``.

    The encoder is moved to ``DEVICE`` and switched to eval mode.

    Returns:
        ``(tokenizer, model)`` on success. If anything raises while
        loading, the error is shown with ``st.error`` and ``(None, None)``
        is returned instead.
    """
    try:
        with st.spinner("Downloading BERT Brain..."):
            loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            encoder = AutoModel.from_pretrained(MODEL_NAME)
            encoder = encoder.to(DEVICE)
            encoder.eval()
    except Exception as e:
        st.error(f"❌ Error loading BERT: {e}")
        return None, None
    return loaded_tokenizer, encoder


# Module-level handles used by get_descriptors.
tokenizer, bert_model = load_bert_brain()
def get_descriptors(smiles_list, batch_size=16):
    """Embed SMILES strings with the BERT encoder using mean pooling.

    Args:
        smiles_list: Sequence of SMILES strings. Entries for which
            ``pd.notna`` is false are skipped and keep an all-zero row.
        batch_size: Number of input positions consumed per forward pass.

    Returns:
        ``numpy.ndarray`` of shape ``(len(smiles_list), hidden)`` where
        ``hidden`` is the encoder's configured ``hidden_size`` (768 when no
        encoder is loaded or the config does not expose one). Row ``k``
        holds the embedding of ``smiles_list[k]``.
    """
    fallback_dim = 768
    if bert_model is None:
        return np.zeros((len(smiles_list), fallback_dim))

    # Size the output from the loaded model rather than assuming 768.
    model_config = getattr(bert_model, "config", None)
    hidden_dim = getattr(model_config, "hidden_size", None) or fallback_dim

    bert_model.eval()
    final_output = np.zeros((len(smiles_list), hidden_dim))

    for start in range(0, len(smiles_list), batch_size):
        chunk = smiles_list[start:start + batch_size]
        # Remember which output rows the non-missing entries belong to so
        # the embeddings can be written back in a single assignment.
        valid_rows = []
        batch = []
        for offset, smi in enumerate(chunk):
            if pd.notna(smi):
                valid_rows.append(start + offset)
                batch.append(smi)
        if not batch:
            continue

        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128,
        ).to(DEVICE)
        with torch.no_grad():
            outputs = bert_model(**inputs)
            # NOTE(review): the mean runs over every token position,
            # padding included, so a row's embedding can depend on what
            # else is in its batch. Confirm this matches how the
            # downstream models' training features were built before
            # switching to attention-mask-weighted pooling.
            embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()

        final_output[valid_rows] = embeddings
    return final_output
# =================================================================
# PART 1: LOGIC & MUTATION ENGINE
# =================================================================
# RDKit standardization helpers, constructed once at module level and
# reused by clean_mol for every molecule.
uncharger = rdMolStandardize.Uncharger()
remover = SaltRemover.SaltRemover()
def clean_mol(s):
    """Standardize a SMILES string: strip salts, uncharge, canonicalize.

    Args:
        s: Raw SMILES text. Missing values (``None``/NaN), non-string
            values and blank strings are treated as invalid.

    Returns:
        The canonical SMILES of the cleaned molecule, or ``None`` when the
        input is missing or blank, cannot be parsed, produces an empty
        SMILES, or RDKit raises while standardizing it.
    """
    # Reject missing/blank input up front so an empty string can never be
    # reported as a "valid" cleaned structure.
    if not isinstance(s, str) or not s.strip():
        return None
    try:
        m = Chem.MolFromSmiles(s)
        if not m:
            return None
        m = uncharger.uncharge(remover.StripMol(m, dontRemoveEverything=True))
        return Chem.MolToSmiles(m, canonical=True) or None
    except Exception:
        # Best-effort cleaner: any RDKit failure marks the row as invalid,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        return None
# --- PHYSCHEM CALCULATOR ---
def calculate_props(smiles):
    """Compute basic physicochemical descriptors for one SMILES string.

    Args:
        smiles: SMILES text to parse with RDKit.

    Returns:
        Dict with keys ``MW``, ``LogP``, ``TPSA``, ``HBD``, ``HBA``,
        ``F_Count`` and ``F/C_Ratio``. When the SMILES cannot be parsed or
        RDKit raises, every key is present with value ``0`` so callers
        always receive the same set of columns.
    """
    # Same keys on every failure path; the dashboard indexes these columns.
    zero_props = {
        "MW": 0, "LogP": 0, "TPSA": 0, "HBD": 0, "HBA": 0,
        "F_Count": 0, "F/C_Ratio": 0,
    }
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return zero_props

        # Count atoms by atomic number (9 = F, 6 = C) instead of compiling
        # and matching single-atom SMARTS patterns on every call.
        atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
        f_count = atomic_numbers.count(9)
        c_count = atomic_numbers.count(6)
        fc_ratio = f_count / c_count if c_count > 0 else 0

        return {
            "MW": round(Descriptors.MolWt(mol), 1),
            "LogP": round(Descriptors.MolLogP(mol), 2),
            "TPSA": round(Descriptors.TPSA(mol), 1),
            "HBD": Descriptors.NumHDonors(mol),
            "HBA": Descriptors.NumHAcceptors(mol),
            "F_Count": f_count,
            "F/C_Ratio": round(fc_ratio, 2),
        }
    except Exception:
        return zero_props
def sanity_check_class(smiles, predicted_class):
    """Override the model's subclass with substructure-based rules.

    Args:
        smiles: SMILES text of the molecule.
        predicted_class: Subclass label produced by the classifier.

    Returns:
        ``"Invalid"`` if the SMILES cannot be parsed; ``"Non-PFAS"`` if
        none of the fluorinated backbone motifs match; ``"PFCA"`` or
        ``"PFSA"`` if the corresponding acid head group matches (checked
        in that order); otherwise ``predicted_class``. If RDKit raises,
        ``predicted_class`` is returned unchanged.
    """
    # At least one of these must be present for the molecule to count as
    # PFAS: a CF2-CF2 unit, a CF2=CF2 unit, or a CF3 next to a carbonyl.
    backbone_motifs = (
        "[#6](F)(F)-[#6](F)(F)",
        "[#6](F)(F)=[#6](F)(F)",
        "[CX4](F)(F)(F)C(=O)",
    )
    # Head-group rules, first match wins.
    headgroup_rules = (
        ("[CX4](F)(F)C(=O)[OH,O-]", "PFCA"),
        ("[CX4](F)(F)S(=O)(=O)[OH,O-]", "PFSA"),
    )
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            return "Invalid"

        def matches(smarts):
            return mol.HasSubstructMatch(Chem.MolFromSmarts(smarts))

        if not any(matches(motif) for motif in backbone_motifs):
            return "Non-PFAS"
        for smarts, label in headgroup_rules:
            if matches(smarts):
                return label
        return predicted_class
    except Exception:
        # Best effort: fall back to the classifier's answer on RDKit errors.
        return predicted_class
def mutate_smart(s):
    """Return a randomly mutated variant of a SMILES string via SELFIES.

    The input is encoded to SELFIES and split into symbols. With
    probability 0.9 one fragment is inserted at a random position, and
    with probability 0.6 one fragment is appended; the result is decoded
    back to SMILES.

    Args:
        s: Seed SMILES string.

    Returns:
        The mutated SMILES, or ``s`` itself when encoding/decoding raises
        or the decoded string is empty.
    """
    try:
        chars = list(sf.split_selfies(sf.encoder(s)))
        if random.random() < 0.9:
            insert_idx = random.randint(0, len(chars))
            atom = random.choice(["[O]", "[N]", "[C][=O]", "[C][O]"])
            chars.insert(insert_idx, atom)
        if random.random() < 0.6:
            chars.append(random.choice(["[O]", "[N]", "[C][=O][O]"]))
        mutated = sf.decoder("".join(chars))
        # An empty decode is not a usable candidate; keep the seed instead.
        return mutated or s
    except Exception:
        return s
@st.cache_resource
def load_downstream_models():
    """Load the subclass classifier, three regressors and the tox pipeline.

    Returns:
        ``(clf, reg_p, reg_m, reg_b, oracle)`` on success. If any load
        step raises, the underlying error is shown with ``st.error`` and a
        tuple of five ``None`` values is returned.
    """
    try:
        # NOTE: joblib.load unpickles arbitrary Python objects; these files
        # must come from a trusted source.
        clf = joblib.load("PFAS_Subclass_Classifier.pkl")
        reg_p = joblib.load("PFAS_Persistence_Regressor.pkl")
        reg_m = joblib.load("PFAS_Mobility_Regressor.pkl")
        reg_b = joblib.load("PFAS_Bioaccumulation_Regressor.pkl")
        # NOTE(review): the checkpoint name ends in "-MLM" but is used for
        # text-classification; confirm it ships a trained classification
        # head, otherwise the labels/scores it emits are not meaningful.
        oracle = hf_pipeline("text-classification", model="DeepChem/ChemBERTa-77M-MLM")
        return clf, reg_p, reg_m, reg_b, oracle
    except Exception as e:
        # Surface the real cause (missing file, download failure, ...)
        # instead of discarding it.
        st.error(f"❌ Error loading downstream models: {e}")
        return None, None, None, None, None


clf, reg_p, reg_m, reg_b, oracle = load_downstream_models()
# Nothing below can run without the models, so halt the script here.
if clf is None:
    st.error("❌ Models Missing! Upload the 4 .pkl files.")
    st.stop()
# =================================================================
# PART 2: THE UI
# =================================================================
st.title("End-to-End PFAS Discovery AI")
st.markdown("### Powered by Evolutionary Optimization & Deep Learning")
st.markdown("---")

# --- Sidebar step 1: collect input records (list of dicts with 'SMILES') ---
st.sidebar.header("1. Input Data")
input_type = st.sidebar.radio("Source:", ["Single Molecule", "Batch CSV"])
data = []
if input_type == "Single Molecule":
    smi = st.sidebar.text_input("SMILES:", "OC(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F")
    name = st.sidebar.text_input("Name:", "PFDoA_Test")
    if smi:
        data = [{'SMILES': smi, 'ID': 'User', 'NAME': name}]
else:
    f = st.sidebar.file_uploader("Upload CSV", type=["csv"])
    if f:
        try:
            df_in = pd.read_csv(f)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            # Tell the user why the upload was ignored instead of crashing.
            st.sidebar.error(f"Could not read CSV: {err}")
            df_in = None
        if df_in is not None:
            if 'SMILES' in df_in.columns:
                data = df_in.to_dict('records')
            else:
                # Without this message the Run button silently does nothing.
                st.sidebar.error("CSV must contain a 'SMILES' column.")

# --- Sidebar step 2: choose what the pipeline does with the input ---
st.sidebar.header("2. Pipeline Mode")
mode = st.sidebar.selectbox("Mode:", ["Screening (Analyze)", "Discovery (Optimize)"])

# Fixed colour per subclass label, shared by all charts.
COLOR_MAP = {
    "Non-PFAS": "#2ecc71",
    "PFCA": "#e74c3c",
    "PFSA": "#9b59b6",
    "General PFAS": "#f39c12",
    "Invalid": "#95a5a6",
}
# Runs the selected pipeline on the collected input and renders the results
# table, CSV download and four charts. Executes only on the rerun triggered
# by the "Run Pipeline" button, and only when input records exist.
if st.sidebar.button("Run Pipeline") and data:
    st.info(f"Running **{mode}** on {len(data)} molecules...")
    df_proc = pd.DataFrame(data)
    df_proc['Clean_SMILES'] = df_proc['SMILES'].apply(clean_mol)

    # Keep only rows that standardized to a non-empty SMILES string.
    usable = df_proc['Clean_SMILES'].notna() & (df_proc['Clean_SMILES'] != "")
    valid_df = df_proc[usable]
    n_dropped = len(df_proc) - len(valid_df)
    if valid_df.empty:
        # The models and the dashboard both fail on an empty table.
        st.error("No valid SMILES found in the input; nothing to analyze.")
        st.stop()
    if n_dropped:
        st.warning(f"Skipped {n_dropped} row(s) with missing or unparsable SMILES.")

    results = []
    # --- PATH A: EVOLUTIONARY DISCOVERY ---
    if mode == "Discovery (Optimize)":
        seeds = valid_df['Clean_SMILES'].tolist()
        progress_bar = st.progress(0)
        for i, s in enumerate(seeds):
            # Seed plus up to 20 distinct mutants of that seed.
            population = [s]
            for _ in range(20):
                new_mol = mutate_smart(s)
                if new_mol not in population:
                    population.append(new_mol)
            feats = get_descriptors(population)
            preds = clf.predict(feats)
            scores_b = reg_b.predict(feats)
            scores_p = reg_p.predict(feats)
            scores_m = reg_m.predict(feats)
            ranked_candidates = []
            for j, cand in enumerate(population):
                final_cls = sanity_check_class(cand, preds[j])
                props = calculate_props(cand)
                entry = {
                    "Structure (SMILES)": cand,
                    "Type": "Original" if cand == s else "Optimized",
                    "Subclass": final_cls,
                    "Bioaccumulation": scores_b[j],
                    "Persistence": scores_p[j],
                    "Mobility": scores_m[j],
                }
                entry.update(props)
                ranked_candidates.append(entry)
            # Keep the three candidates with the lowest predicted
            # bioaccumulation for this seed.
            ranked_candidates.sort(key=lambda x: x['Bioaccumulation'])
            results.extend(ranked_candidates[:3])
            progress_bar.progress((i + 1) / len(seeds))
    # --- PATH B: SCREENING ---
    else:
        smiles_list = valid_df['Clean_SMILES'].tolist()
        feats = get_descriptors(smiles_list)
        preds = clf.predict(feats)
        scores_p = reg_p.predict(feats)
        scores_m = reg_m.predict(feats)
        scores_b = reg_b.predict(feats)
        hazards = oracle(smiles_list)
        for i, row in enumerate(valid_df.itertuples()):
            final_cls = sanity_check_class(row.Clean_SMILES, preds[i])
            # Only report the oracle's label when its score exceeds 0.70.
            tox = hazards[i]['label'] if hazards[i]['score'] > 0.70 else "Inconclusive"
            props = calculate_props(row.Clean_SMILES)
            entry = {
                "ID": getattr(row, 'ID', 'N/A'),
                "Structure (SMILES)": row.Clean_SMILES,
                "Subclass": final_cls,
                "Bioaccumulation": round(scores_b[i], 2),
                "Persistence": round(scores_p[i], 2),
                "Mobility": round(scores_m[i], 2),
                "Tox_Result": tox,
            }
            entry.update(props)
            results.append(entry)

    # ------------------------------------------------------------------
    # DASHBOARD
    # ------------------------------------------------------------------
    res_df = pd.DataFrame(results)
    st.markdown("### Analysis Results")

    # Put the key columns first; skip any that are absent so a missing
    # descriptor column cannot raise a KeyError.
    preferred_cols = ['Structure (SMILES)', 'Subclass', 'Bioaccumulation', 'Persistence', 'MW', 'LogP', 'F_Count']
    main_cols = [c for c in preferred_cols if c in res_df.columns]
    remaining_cols = [c for c in res_df.columns if c not in main_cols]
    final_cols = main_cols + remaining_cols

    # Table with the lowest-bioaccumulation row(s) highlighted.
    st.dataframe(
        res_df[final_cols].style.highlight_min(axis=0, subset=['Bioaccumulation'], color='#d4edda'),
        use_container_width=True
    )
    st.download_button("Download Full CSV", res_df.to_csv(index=False).encode('utf-8'), "results.csv", "text/csv")

    st.markdown("---")
    st.header("Advanced Analytics Dashboard")
    col1, col2 = st.columns(2)

    # GRAPH 1: 3D DISCOVERY CUBE
    with col1:
        st.subheader("1. Multi-Dimensional Risk Space")
        hover_cols = [c for c in ['Structure (SMILES)', 'MW', 'LogP'] if c in res_df.columns]
        fig_3d = px.scatter_3d(
            res_df,
            x='Bioaccumulation', y='Mobility', z='Persistence',
            color='Subclass',
            # 'Type' exists only for Discovery results.
            symbol='Type' if 'Type' in res_df.columns else 'Subclass',
            color_discrete_map=COLOR_MAP,
            opacity=0.9,
            size_max=12,
            template="plotly_white",
            hover_data=hover_cols,
            title="Risk Landscape (Interactive)"
        )
        fig_3d.update_layout(margin=dict(l=0, r=0, b=0, t=40), height=500)
        st.plotly_chart(fig_3d, use_container_width=True)

    # GRAPH 2: CLASS DISTRIBUTION
    with col2:
        st.subheader("2. Safety Classification")
        count_df = res_df['Subclass'].value_counts().reset_index()
        count_df.columns = ['Subclass', 'Count']
        fig_bar = px.bar(
            count_df, x="Subclass", y="Count", color="Subclass",
            title="Molecule Counts by Class",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            text_auto=True
        )
        fig_bar.update_layout(height=500)
        st.plotly_chart(fig_bar, use_container_width=True)

    col3, col4 = st.columns(2)

    # GRAPH 3: PARALLEL COORDINATES
    with col3:
        st.subheader("3. Property Trace")
        trace_dims = [
            c for c in ['Persistence', 'Mobility', 'Bioaccumulation', 'LogP', 'MW']
            if c in res_df.columns
        ]
        fig_para = px.parallel_coordinates(
            res_df,
            dimensions=trace_dims,
            color="Bioaccumulation",
            color_continuous_scale="Spectral_r",
            title="Trace: Chem Properties -> Risk",
            template="plotly_dark"
        )
        fig_para.update_layout(height=500)
        st.plotly_chart(fig_para, use_container_width=True)

    # GRAPH 4: DISTRIBUTION VIOLIN PLOT
    with col4:
        st.subheader("4. Bioaccumulation Spread")
        fig_vio = px.violin(
            res_df, y="Bioaccumulation", x="Subclass",
            color="Subclass", box=True, points="all",
            color_discrete_map=COLOR_MAP,
            template="plotly_dark",
            title="Distribution Density"
        )
        fig_vio.add_hline(y=3.5, line_dash="dash", line_color="orange", annotation_text="Safety Limit")
        fig_vio.update_layout(height=500)
        st.plotly_chart(fig_vio, use_container_width=True)