tueniuu commited on
Commit
b1e5a57
Β·
verified Β·
1 Parent(s): e7d30e6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -32
app.py CHANGED
@@ -8,13 +8,15 @@ import random
8
  import selfies as sf
9
  import matplotlib.pyplot as plt
10
  import seaborn as sns
 
 
11
  from rdkit import Chem
12
  from rdkit.Chem import SaltRemover
13
  from rdkit.Chem.MolStandardize import rdMolStandardize
14
  from transformers import AutoTokenizer, AutoModel, pipeline as hf_pipeline
15
 
16
  # =================================================================
17
- # PART 0: THE BRIDGE (Brain Setup)
18
  # =================================================================
19
  st.set_page_config(page_title="PFAS Discovery AI", layout="wide")
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -96,7 +98,7 @@ def mutate_smart(s):
96
 
97
  # Action 2: Cap Ends to increase Solubility
98
  if random.random() < 0.6:
99
- chars.append(random.choice(["[O]", "[N]", "[C][=O][O]"])) # Add Acid/Alcohol group
100
 
101
  return sf.decoder("".join(chars))
102
  except: return s
@@ -119,10 +121,10 @@ if clf is None:
119
  st.stop()
120
 
121
  # =================================================================
122
- # PART 2: THE UI (With Evolutionary Search)
123
  # =================================================================
124
  st.title("πŸ§ͺ End-to-End PFAS Discovery AI")
125
- st.markdown("Powered by **Evolutionary Optimization** (Generating 20 $\\to$ Keeping Top 3)")
126
 
127
  st.sidebar.header("1. Input Data")
128
  input_type = st.sidebar.radio("Source:", ["Single Molecule", "Batch CSV"])
@@ -154,29 +156,26 @@ if st.sidebar.button("πŸš€ Run Pipeline") and data:
154
  # --- PATH A: EVOLUTIONARY DISCOVERY ---
155
  if mode == "Discovery (Optimize)":
156
  seeds = valid_df['Clean_SMILES'].tolist()
157
-
158
  progress_bar = st.progress(0)
159
 
160
  for i, s in enumerate(seeds):
161
- # 1. SPAWN POPULATION (Generate 20 mutants)
162
- population = [s] # Include original
163
  for _ in range(20):
164
  new_mol = mutate_smart(s)
165
  if new_mol not in population: population.append(new_mol)
166
 
167
- # 2. SCORE POPULATION (The Filter)
168
  feats = get_descriptors(population)
169
- preds = clf.predict(feats) # Class
170
- scores_b = reg_b.predict(feats) # Bioaccumulation (Target)
171
  scores_p = reg_p.predict(feats)
172
  scores_m = reg_m.predict(feats)
173
 
174
- # 3. RANK & SELECT (Survival of the Safest)
175
  ranked_candidates = []
176
  for j, cand in enumerate(population):
177
- # Apply Logic Layer
178
  final_cls = sanity_check_class(cand, preds[j])
179
-
180
  ranked_candidates.append({
181
  "Candidate": cand,
182
  "Type": "Original" if cand == s else "Optimized",
@@ -186,12 +185,9 @@ if st.sidebar.button("πŸš€ Run Pipeline") and data:
186
  "Mobility": scores_m[j]
187
  })
188
 
189
- # SORT: Lowest Bioaccumulation first
190
  ranked_candidates.sort(key=lambda x: x['Bioaccumulation'])
191
-
192
- # KEEP: Only the Top 3 Best
193
  results.extend(ranked_candidates[:3])
194
-
195
  progress_bar.progress((i + 1) / len(seeds))
196
 
197
  # --- PATH B: SCREENING ---
@@ -217,22 +213,68 @@ if st.sidebar.button("πŸš€ Run Pipeline") and data:
217
  "Tox_Result": tox
218
  })
219
 
220
- # Results Display
 
 
221
  res_df = pd.DataFrame(results)
222
- st.subheader("πŸ“Š Optimization Results")
 
 
223
  st.dataframe(res_df)
224
  st.download_button("Download CSV", res_df.to_csv(index=False).encode('utf-8'), "results.csv", "text/csv")
 
 
 
225
 
226
- st.subheader("⚠️ Safety Dashboard")
227
- fig, ax = plt.subplots(figsize=(10, 6))
228
- palette = {"Non-PFAS": "green", "General PFAS": "red", "PFCA": "darkred", "PFSA": "purple"}
229
- for u in res_df['Subclass'].unique():
230
- if u not in palette: palette[u] = "gray"
231
-
232
- sns.scatterplot(
233
- data=res_df, x='Bioaccumulation', y='Mobility', hue='Subclass', style='Subclass',
234
- size='Persistence', sizes=(50, 300), palette=palette, ax=ax, alpha=0.8, edgecolor='black'
235
- )
236
- plt.axvline(x=3.5, color='orange', linestyle='--', label='Bioacc Limit')
237
- plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
238
- st.pyplot(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  import selfies as sf
9
  import matplotlib.pyplot as plt
10
  import seaborn as sns
11
+ import plotly.express as px # Interactive Graphs
12
+ import plotly.graph_objects as go
13
  from rdkit import Chem
14
  from rdkit.Chem import SaltRemover
15
  from rdkit.Chem.MolStandardize import rdMolStandardize
16
  from transformers import AutoTokenizer, AutoModel, pipeline as hf_pipeline
17
 
18
  # =================================================================
19
+ # PART 0: THE BRIDGE (Automatic Brain Setup)
20
  # =================================================================
21
  st.set_page_config(page_title="PFAS Discovery AI", layout="wide")
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
98
 
99
  # Action 2: Cap Ends to increase Solubility
100
  if random.random() < 0.6:
101
+ chars.append(random.choice(["[O]", "[N]", "[C][=O][O]"]))
102
 
103
  return sf.decoder("".join(chars))
104
  except: return s
 
121
  st.stop()
122
 
123
  # =================================================================
124
+ # PART 2: THE UI (With 4-Graph Dashboard)
125
  # =================================================================
126
  st.title("πŸ§ͺ End-to-End PFAS Discovery AI")
127
+ st.markdown("Powered by **Evolutionary Optimization** & **Advanced Visualization**")
128
 
129
  st.sidebar.header("1. Input Data")
130
  input_type = st.sidebar.radio("Source:", ["Single Molecule", "Batch CSV"])
 
156
  # --- PATH A: EVOLUTIONARY DISCOVERY ---
157
  if mode == "Discovery (Optimize)":
158
  seeds = valid_df['Clean_SMILES'].tolist()
 
159
  progress_bar = st.progress(0)
160
 
161
  for i, s in enumerate(seeds):
162
+ # 1. SPAWN POPULATION
163
+ population = [s]
164
  for _ in range(20):
165
  new_mol = mutate_smart(s)
166
  if new_mol not in population: population.append(new_mol)
167
 
168
+ # 2. SCORE
169
  feats = get_descriptors(population)
170
+ preds = clf.predict(feats)
171
+ scores_b = reg_b.predict(feats)
172
  scores_p = reg_p.predict(feats)
173
  scores_m = reg_m.predict(feats)
174
 
175
+ # 3. RANK
176
  ranked_candidates = []
177
  for j, cand in enumerate(population):
 
178
  final_cls = sanity_check_class(cand, preds[j])
 
179
  ranked_candidates.append({
180
  "Candidate": cand,
181
  "Type": "Original" if cand == s else "Optimized",
 
185
  "Mobility": scores_m[j]
186
  })
187
 
188
+ # Select Top 3 Best
189
  ranked_candidates.sort(key=lambda x: x['Bioaccumulation'])
 
 
190
  results.extend(ranked_candidates[:3])
 
191
  progress_bar.progress((i + 1) / len(seeds))
192
 
193
  # --- PATH B: SCREENING ---
 
213
  "Tox_Result": tox
214
  })
215
 
216
+ # ------------------------------------------------------------------
217
+ # VISUALIZATION DASHBOARD
218
+ # ------------------------------------------------------------------
219
  res_df = pd.DataFrame(results)
220
+
221
+ # 1. RESULTS TABLE
222
+ st.subheader("πŸ“Š Data Table")
223
  st.dataframe(res_df)
224
  st.download_button("Download CSV", res_df.to_csv(index=False).encode('utf-8'), "results.csv", "text/csv")
225
+
226
+ st.markdown("---")
227
+ st.header("πŸ“ˆ Advanced Analytics Dashboard")
228
 
229
+ col1, col2 = st.columns(2)
230
+
231
+ color_map = {"Non-PFAS": "green", "PFCA": "red", "PFSA": "purple", "General PFAS": "orange"}
232
+
233
+ # GRAPH 1: 3D DISCOVERY CUBE
234
+ with col1:
235
+ st.subheader("🧊 1. Multi-Dimensional Risk")
236
+ fig_3d = px.scatter_3d(
237
+ res_df,
238
+ x='Bioaccumulation', y='Mobility', z='Persistence',
239
+ color='Subclass', symbol='Type' if 'Type' in res_df.columns else 'Subclass',
240
+ color_discrete_map=color_map, opacity=0.8, size_max=10,
241
+ title="Bioacc vs Mobility vs Persistence"
242
+ )
243
+ fig_3d.update_layout(margin=dict(l=0, r=0, b=0, t=30))
244
+ st.plotly_chart(fig_3d, use_container_width=True)
245
+
246
+ # GRAPH 2: CLASS DISTRIBUTION (Bar Chart)
247
+ with col2:
248
+ st.subheader("πŸ“Š 2. Class Composition")
249
+ fig_bar = px.bar(
250
+ res_df, x="Subclass", color="Subclass",
251
+ title="Count of Molecules by Class",
252
+ color_discrete_map=color_map
253
+ )
254
+ st.plotly_chart(fig_bar, use_container_width=True)
255
+
256
+ col3, col4 = st.columns(2)
257
+
258
+ # GRAPH 3: PARALLEL COORDINATES (The "Trace" Graph)
259
+ with col3:
260
+ st.subheader("πŸ“‰ 3. Property Tracing")
261
+ # Normalize Subclass to integer for coloring if needed, or use Bioacc
262
+ fig_para = px.parallel_coordinates(
263
+ res_df,
264
+ dimensions=['Persistence', 'Mobility', 'Bioaccumulation'],
265
+ color="Bioaccumulation",
266
+ color_continuous_scale=px.colors.diverging.TealRose,
267
+ title="Trace: Persist -> Mobile -> Bioacc"
268
+ )
269
+ st.plotly_chart(fig_para, use_container_width=True)
270
+
271
+ # GRAPH 4: DISTRIBUTION VIOLIN PLOT
272
+ with col4:
273
+ st.subheader("🎻 4. Risk Distribution")
274
+ fig_vio = px.violin(
275
+ res_df, y="Bioaccumulation", x="Subclass",
276
+ color="Subclass", box=True, points="all",
277
+ color_discrete_map=color_map,
278
+ title="Bioaccumulation Spread per Class"
279
+ )
280
+ st.plotly_chart(fig_vio, use_container_width=True)