sobinalosious92 committed on
Commit
f15e949
·
verified ·
1 Parent(s): a09feaf

Upload 6 files

Browse files
pages/1_Property_Probe.py CHANGED
@@ -11,14 +11,23 @@ from src.lookup import (
11
  get_polyinfo,
12
  )
13
  from src.predictor_router import RouterPredictor
14
- from src.ui_style import apply_global_style
15
 
16
  st.set_page_config(page_title="Property Probe", layout="wide")
17
  apply_global_style()
18
- st.title("Quick Polymer Property Check")
 
 
 
 
 
 
 
 
 
19
 
20
- db = load_all_sources()
21
- predictor = RouterPredictor(device="cpu")
22
 
23
 
24
  def resolve_smiles_from_polymer_name(db_obj, polymer_name_query: str) -> tuple[str | None, str | None]:
@@ -72,6 +81,8 @@ selected_label = st.selectbox("Select property", options)
72
  prop = label_to_key[selected_label]
73
 
74
  if st.button("Search", type="primary"):
 
 
75
  if input_mode == "SMILES":
76
  s_canon = canonicalize_smiles(query_value)
77
  if s_canon is None:
@@ -137,4 +148,5 @@ if st.button("Search", type="primary"):
137
  })
138
 
139
  out = pd.DataFrame(rows)
 
140
  st.table(out)
 
11
  get_polyinfo,
12
  )
13
  from src.predictor_router import RouterPredictor
14
+ from src.ui_style import apply_global_style, render_page_header
15
 
16
  st.set_page_config(page_title="Property Probe", layout="wide")
17
  apply_global_style()
18
+ render_page_header(
19
+ title="Quick Polymer Property Check",
20
+ subtitle="Check one polymer at a time using source lookups plus ensemble ML prediction.",
21
+ badge="Property Probe",
22
+ )
23
+
24
+
25
+ @st.cache_resource(show_spinner=False)
26
+ def get_router_predictor() -> RouterPredictor:
27
+ return RouterPredictor(device="cpu")
28
 
29
+
30
+ predictor = get_router_predictor()
31
 
32
 
33
  def resolve_smiles_from_polymer_name(db_obj, polymer_name_query: str) -> tuple[str | None, str | None]:
 
81
  prop = label_to_key[selected_label]
82
 
83
  if st.button("Search", type="primary"):
84
+ db = load_all_sources()
85
+
86
  if input_mode == "SMILES":
87
  s_canon = canonicalize_smiles(query_value)
88
  if s_canon is None:
 
148
  })
149
 
150
  out = pd.DataFrame(rows)
151
+ out.index = range(1, len(out) + 1)
152
  st.table(out)
pages/2_Batch_Prediction.py CHANGED
@@ -2,17 +2,41 @@ import io
2
  import pandas as pd
3
  import streamlit as st
4
 
5
- from src.lookup import PROPERTY_META, canonicalize_smiles
 
 
 
 
 
 
 
6
  from src.predictor_router import RouterPredictor
7
- from src.ui_style import apply_global_style
8
 
9
  st.set_page_config(page_title="Batch Prediction", layout="wide")
10
  apply_global_style()
11
- st.title("Bulk Polymer Property Prediction")
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- predictor = RouterPredictor(device="cpu")
 
 
 
14
 
15
  MAX_RENDER_ROWS = 5000 # above this -> download only (no dataframe render)
 
 
16
 
17
 
18
  # -----------------------------
@@ -82,7 +106,17 @@ for k in prop_keys:
82
  label_to_key[label] = k
83
 
84
  selected_labels = st.multiselect("Select properties to predict", options=prop_options)
85
- include_std = st.checkbox("Include model std (ensemble spread)", value=False)
 
 
 
 
 
 
 
 
 
 
86
 
87
  st.divider()
88
 
@@ -129,14 +163,15 @@ else:
129
  )
130
  dataset_path = dataset[1]
131
 
132
- # For PI1M, force an N limit for the live web MVP
133
- st.caption("Note: large selections will switch to download-only to avoid crashing the page.")
134
 
135
  pick_mode = st.radio("Row selection", options=["First N", "Random sample N"], horizontal=True)
136
  mode = "first" if pick_mode == "First N" else "random"
137
 
138
- default_n = 13000 if dataset_path.endswith("PI.csv") else 2000
139
- max_n = 13000 if dataset_path.endswith("PI.csv") else 50000 # sensible web limit for MVP
 
140
 
141
  n = st.number_input(
142
  "How many SMILES to use",
@@ -169,6 +204,23 @@ if run:
169
  st.stop()
170
 
171
  props = [label_to_key[lbl] for lbl in selected_labels]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  # Decide whether to render table
174
  render_table = len(smiles_list) <= MAX_RENDER_ROWS
@@ -204,6 +256,13 @@ if run:
204
  if include_std:
205
  row[col_name + " [std]"] = std
206
 
 
 
 
 
 
 
 
207
  rows.append(row)
208
  if total > 0:
209
  progress.progress(int(100 * i / total))
@@ -224,7 +283,9 @@ if run:
224
  # Render table only if safe
225
  if render_table:
226
  st.subheader("Predictions")
227
- st.dataframe(out_df, width="stretch")
 
 
228
 
229
  # Download
230
  csv_bytes = out_df.to_csv(index=False).encode("utf-8")
 
2
  import pandas as pd
3
  import streamlit as st
4
 
5
+ from src.lookup import (
6
+ PROPERTY_META,
7
+ SOURCES,
8
+ SOURCE_LABELS,
9
+ canonicalize_smiles,
10
+ get_value,
11
+ load_all_sources,
12
+ )
13
  from src.predictor_router import RouterPredictor
14
+ from src.ui_style import apply_global_style, render_page_header
15
 
16
  st.set_page_config(page_title="Batch Prediction", layout="wide")
17
  apply_global_style()
18
+ render_page_header(
19
+ title="Bulk Polymer Property Prediction",
20
+ subtitle="Predict multiple target properties for large candidate sets with downloadable results.",
21
+ badge="Batch Prediction",
22
+ )
23
+
24
+
25
+ @st.cache_resource(show_spinner=False)
26
+ def get_router_predictor() -> RouterPredictor:
27
+ return RouterPredictor(device="cpu")
28
+
29
+
30
+ predictor = get_router_predictor()
31
 
32
+
33
+ @st.cache_resource(show_spinner=False)
34
+ def get_lookup_db():
35
+ return load_all_sources()
36
 
37
  MAX_RENDER_ROWS = 5000 # above this -> download only (no dataframe render)
38
+ MAX_BATCH_SMILES = 3000
39
+ MAX_BATCH_PREDICTIONS = 25000
40
 
41
 
42
  # -----------------------------
 
106
  label_to_key[label] = k
107
 
108
  selected_labels = st.multiselect("Select properties to predict", options=prop_options)
109
+ opt_col1, opt_col2 = st.columns([1, 2])
110
+ with opt_col1:
111
+ include_std = st.checkbox("Include model std (ensemble spread)", value=False)
112
+ with opt_col2:
113
+ selected_source_labels = st.multiselect(
114
+ "Include source database values",
115
+ options=[SOURCE_LABELS.get(src, src) for src in SOURCES],
116
+ placeholder="Select Experiment, MD, DFT, and/or GC",
117
+ )
118
+ source_label_to_key = {SOURCE_LABELS.get(src, src): src for src in SOURCES}
119
+ selected_sources = [source_label_to_key[label] for label in selected_source_labels]
120
 
121
  st.divider()
122
 
 
163
  )
164
  dataset_path = dataset[1]
165
 
166
+ # Website-safe cap: render mode is not enough, inference itself must stay bounded.
167
+ st.caption("Website-safe limits apply. Large jobs should be run offline rather than in the live app.")
168
 
169
  pick_mode = st.radio("Row selection", options=["First N", "Random sample N"], horizontal=True)
170
  mode = "first" if pick_mode == "First N" else "random"
171
 
172
+ is_virtual_pi1m = dataset_path.endswith("PI1M.csv")
173
+ default_n = 1000 if is_virtual_pi1m else 2000
174
+ max_n = MAX_BATCH_SMILES
175
 
176
  n = st.number_input(
177
  "How many SMILES to use",
 
204
  st.stop()
205
 
206
  props = [label_to_key[lbl] for lbl in selected_labels]
207
+ lookup_db = get_lookup_db() if selected_sources else None
208
+ requested_smiles = len(smiles_list)
209
+ prediction_cells = requested_smiles * len(props)
210
+
211
+ if requested_smiles > MAX_BATCH_SMILES:
212
+ st.error(
213
+ f"This website currently limits Batch Prediction to {MAX_BATCH_SMILES:,} SMILES per run. "
214
+ "Use a smaller subset or run larger jobs offline."
215
+ )
216
+ st.stop()
217
+
218
+ if prediction_cells > MAX_BATCH_PREDICTIONS:
219
+ st.error(
220
+ f"This request would run {prediction_cells:,} model predictions, which exceeds the website-safe limit "
221
+ f"of {MAX_BATCH_PREDICTIONS:,}. Reduce the number of SMILES or selected properties."
222
+ )
223
+ st.stop()
224
 
225
  # Decide whether to render table
226
  render_table = len(smiles_list) <= MAX_RENDER_ROWS
 
256
  if include_std:
257
  row[col_name + " [std]"] = std
258
 
259
+ if lookup_db is not None:
260
+ for src in selected_sources:
261
+ src_label = SOURCE_LABELS.get(src, src)
262
+ src_col = f"{col_name} [{src_label}]"
263
+ val = get_value(lookup_db, src, s_canon, prop)
264
+ row[src_col] = float("nan") if val is None else val
265
+
266
  rows.append(row)
267
  if total > 0:
268
  progress.progress(int(100 * i / total))
 
283
  # Render table only if safe
284
  if render_table:
285
  st.subheader("Predictions")
286
+ display_df = out_df.copy()
287
+ display_df.index = range(1, len(display_df) + 1)
288
+ st.dataframe(display_df, width="stretch")
289
 
290
  # Download
291
  csv_bytes = out_df.to_csv(index=False).encode("utf-8")
pages/3_Molecular_View.py CHANGED
@@ -10,13 +10,17 @@ from streamlit.components.v1 import html
10
  from rdkit.Chem import Lipinski, Crippen
11
  from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcExactMolWt, CalcFractionCSP3, CalcNumRings, CalcNumAromaticRings
12
 
13
- from src.ui_style import apply_global_style
14
 
15
  RDLogger.DisableLog("rdApp.*")
16
 
17
  st.set_page_config(page_title="Molecular View", layout="wide")
18
  apply_global_style()
19
- st.title("Molecular Structure View")
 
 
 
 
20
 
21
  # -------------------------
22
  # Polymer-safe helpers
@@ -313,7 +317,7 @@ with top_left:
313
 
314
  with top_right:
315
  st.markdown("Molecule Information ")
316
- st.table(
317
  {
318
  "Property": ["Formula", "Molar Weight", "Atoms"],
319
  "Value": [
@@ -323,6 +327,8 @@ with top_right:
323
  ],
324
  }
325
  )
 
 
326
 
327
  # MOL download *below the table*
328
  if mol_block_3d is not None:
@@ -366,4 +372,4 @@ with bottom_right:
366
 
367
  # Legend: include hydrogens + colored dots
368
  # Use capped mol (no '*') for clean element counting
369
- render_element_legend_with_colors(mol_cap, include_hydrogens=True)
 
10
  from rdkit.Chem import Lipinski, Crippen
11
  from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcExactMolWt, CalcFractionCSP3, CalcNumRings, CalcNumAromaticRings
12
 
13
+ from src.ui_style import apply_global_style, render_page_header
14
 
15
  RDLogger.DisableLog("rdApp.*")
16
 
17
  st.set_page_config(page_title="Molecular View", layout="wide")
18
  apply_global_style()
19
+ render_page_header(
20
+ title="Molecular Structure View",
21
+ subtitle="Inspect 2D and 3D polymer structures and review repeat-unit descriptors.",
22
+ badge="Molecular View",
23
+ )
24
 
25
  # -------------------------
26
  # Polymer-safe helpers
 
317
 
318
  with top_right:
319
  st.markdown("Molecule Information ")
320
+ info_df = pd.DataFrame(
321
  {
322
  "Property": ["Formula", "Molar Weight", "Atoms"],
323
  "Value": [
 
327
  ],
328
  }
329
  )
330
+ info_df.index = range(1, len(info_df) + 1)
331
+ st.table(info_df)
332
 
333
  # MOL download *below the table*
334
  if mol_block_3d is not None:
 
372
 
373
  # Legend: include hydrogens + colored dots
374
  # Use capped mol (no '*') for clean element counting
375
+ render_element_legend_with_colors(mol_cap, include_hydrogens=True)
pages/4_Discovery_(Manual).py CHANGED
@@ -13,11 +13,15 @@ import streamlit as st
13
 
14
  from src.discovery import run_discovery, spec_from_dict
15
  from src.lookup import PROPERTY_META
16
- from src.ui_style import apply_global_style
17
 
18
  st.set_page_config(page_title="Discovery (Manual)", layout="wide")
19
  apply_global_style()
20
- st.title("Manual Multi-Objective Discovery")
 
 
 
 
21
 
22
  # ----------------------------
23
  # Files
@@ -699,7 +703,9 @@ if st.session_state.get("discovery_done"):
699
  meta = PROPERTY_META[prop_key]
700
  rename_map[c] = f"{meta['name']} ({meta['unit']})"
701
  preview_df = preview_df.rename(columns=rename_map)
702
- st.dataframe(preview_df.head(50), width="stretch")
 
 
703
 
704
  st.subheader("📥 Download")
705
  buf = io.StringIO()
 
13
 
14
  from src.discovery import run_discovery, spec_from_dict
15
  from src.lookup import PROPERTY_META
16
+ from src.ui_style import apply_global_style, render_page_header
17
 
18
  st.set_page_config(page_title="Discovery (Manual)", layout="wide")
19
  apply_global_style()
20
+ render_page_header(
21
+ title="Manual Multi-Objective Discovery",
22
+ subtitle="Tune objectives and constraints directly to explore Pareto-optimal polymer candidates.",
23
+ badge="Discovery (Manual)",
24
+ )
25
 
26
  # ----------------------------
27
  # Files
 
703
  meta = PROPERTY_META[prop_key]
704
  rename_map[c] = f"{meta['name']} ({meta['unit']})"
705
  preview_df = preview_df.rename(columns=rename_map)
706
+ preview_display = preview_df.head(50).copy()
707
+ preview_display.index = range(1, len(preview_display) + 1)
708
+ st.dataframe(preview_display, width="stretch")
709
 
710
  st.subheader("📥 Download")
711
  buf = io.StringIO()
pages/5_Discovery_(AI).py CHANGED
@@ -8,6 +8,7 @@ import threading
8
  import time
9
  import urllib.request
10
  import urllib.error
 
11
  import zipfile
12
  from pathlib import Path
13
 
@@ -17,11 +18,15 @@ import streamlit as st
17
  from streamlit.components.v1 import html
18
 
19
  from src.discover_llm import PROPERTY_META, run_discovery, spec_from_dict
20
- from src.ui_style import apply_global_style
21
 
22
  st.set_page_config(page_title="DISCOVERY (AI)", layout="wide")
23
  apply_global_style()
24
- st.title("AI-Driven Multi-Objective Discovery")
 
 
 
 
25
 
26
  # ----------------------------
27
  # Files
@@ -307,7 +312,236 @@ def get_webui_base_url() -> str:
307
  ).rstrip("/")
308
 
309
 
310
- def validate_api_access(api_key: str, base_url: str) -> str | None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  """Return None when credentials are usable, else an error message."""
312
  k = str(api_key or "").strip()
313
  u = str(base_url or "").strip().rstrip("/")
@@ -315,8 +549,23 @@ def validate_api_access(api_key: str, base_url: str) -> str | None:
315
  return "API key is required."
316
  if not u.startswith("https://"):
317
  return "API base URL must start with `https://`."
 
 
318
  try:
319
- _ = webui_request(u, k, "/api/models", payload=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  except Exception as e:
321
  return f"API key validation failed: {e}"
322
  return None
@@ -326,40 +575,37 @@ def clear_byok_key() -> None:
326
  st.session_state["discover_llm_byok_key"] = ""
327
 
328
 
329
- def webui_request(base_url: str, api_key: str, path: str, payload: dict | None = None) -> dict:
330
- url = f"{base_url}{path}"
331
- req = urllib.request.Request(
332
- url=url,
333
- data=(json.dumps(payload).encode("utf-8") if payload is not None else None),
334
- headers={
335
- "Authorization": f"Bearer {api_key}",
336
- "Content-Type": "application/json",
337
- },
338
- method=("POST" if payload is not None else "GET"),
339
- )
340
- try:
341
- with urllib.request.urlopen(req, timeout=60) as resp:
342
- return json.loads(resp.read().decode("utf-8"))
343
- except urllib.error.HTTPError as e:
344
- detail = e.read().decode("utf-8", errors="ignore")
345
- raise RuntimeError(f"WebUI API HTTP {e.code}: {detail}") from e
346
- except Exception as e:
347
- raise RuntimeError(f"WebUI API call failed: {e}") from e
348
-
349
-
350
- def list_available_models(api_key: str | None = None, base_url: str | None = None) -> list[str]:
351
  api_key = (api_key or get_webui_api_key()).strip()
352
  if not api_key:
353
  return []
354
  base_url = (base_url or get_webui_base_url()).rstrip("/")
355
- raw = webui_request(base_url, api_key, "/api/models", payload=None)
356
- items = raw.get("data", raw) if isinstance(raw, dict) else raw
 
 
 
 
 
 
 
 
 
 
 
 
357
  if not isinstance(items, list):
358
  return []
359
  out = []
360
  for m in items:
361
  if isinstance(m, dict):
362
  mid = str(m.get("id", m.get("name", ""))).strip()
 
 
363
  else:
364
  mid = str(m).strip()
365
  if mid:
@@ -368,7 +614,11 @@ def list_available_models(api_key: str | None = None, base_url: str | None = Non
368
 
369
 
370
  def generate_spec_from_llm(
371
- user_query: str, model: str, api_key: str | None = None, base_url: str | None = None
 
 
 
 
372
  ) -> dict:
373
  api_key = (api_key or get_webui_api_key()).strip()
374
  if not api_key:
@@ -401,20 +651,15 @@ def generate_spec_from_llm(
401
  user_prompt = (
402
  "User request:\n" + user_query.strip()
403
  )
404
-
405
- payload = {
406
- "model": model,
407
- "messages": [
408
- {"role": "system", "content": system_prompt},
409
- {"role": "user", "content": user_prompt},
410
- ],
411
- }
412
- raw = webui_request(base_url, api_key, "/api/chat/completions", payload=payload)
413
-
414
- try:
415
- content = raw["choices"][0]["message"]["content"]
416
- except Exception:
417
- raise RuntimeError("Unexpected LLM response format.")
418
 
419
  try:
420
  parsed = extract_first_json_object(content)
@@ -506,11 +751,6 @@ def render_copyable_prompt(prompt_text: str, box_height: int = 220) -> None:
506
  html(snippet, height=box_height + 54)
507
 
508
 
509
- @st.cache_data(ttl=300, show_spinner=False)
510
- def list_available_models_cached() -> list[str]:
511
- return list_available_models()
512
-
513
-
514
  def _local_reasoning_fallback(spec_obj: dict, stats: dict) -> str:
515
  objectives = spec_obj.get("objectives", []) if isinstance(spec_obj, dict) else []
516
  constraints = spec_obj.get("hard_constraints", {}) if isinstance(spec_obj, dict) else {}
@@ -599,6 +839,7 @@ def generate_selection_reasoning(
599
  model: str,
600
  api_key: str | None = None,
601
  base_url: str | None = None,
 
602
  ) -> str:
603
  api_key = (api_key or get_webui_api_key()).strip()
604
  if not api_key:
@@ -674,19 +915,15 @@ def generate_selection_reasoning(
674
  "You can add brief clarifying bullets if helpful, but keep it concise and focused.\n\n"
675
  f"INPUT:\n{json.dumps(user_payload, indent=2)}"
676
  )
677
- payload = {
678
- "model": model,
679
- "messages": [
680
- {"role": "system", "content": system_prompt},
681
- {"role": "user", "content": user_prompt},
682
- ],
683
- }
684
- raw = webui_request(base_url, api_key, "/api/chat/completions", payload=payload)
685
- try:
686
- content = raw["choices"][0]["message"]["content"]
687
- return str(content).strip()
688
- except Exception:
689
- raise RuntimeError("Unexpected LLM response format for reasoning.")
690
 
691
 
692
  def pareto_publication_plot(plot_df: pd.DataFrame, obj_props: list[str]):
@@ -1011,13 +1248,17 @@ if "discover_llm_query_text" not in st.session_state:
1011
  if "discover_llm_last_example_choice" not in st.session_state:
1012
  st.session_state["discover_llm_last_example_choice"] = "Select an example prompt…"
1013
  if "discover_llm_mode" not in st.session_state:
1014
- st.session_state["discover_llm_mode"] = "Built-in API"
1015
  if "discover_llm_external_response" not in st.session_state:
1016
  st.session_state["discover_llm_external_response"] = ""
1017
  if "discover_llm_byok_key" not in st.session_state:
1018
  st.session_state["discover_llm_byok_key"] = ""
1019
  if "discover_llm_byok_base_url" not in st.session_state:
1020
- st.session_state["discover_llm_byok_base_url"] = get_webui_base_url()
 
 
 
 
1021
 
1022
  # Apply deferred JSON updates before any JSON editor widget is instantiated.
1023
  pending_spec_text = st.session_state.get("discover_llm_spec_text_next")
@@ -1049,7 +1290,7 @@ with st.container(border=True):
1049
  )
1050
  mode = st.radio(
1051
  "LLM setup",
1052
- options=["Built-in API", "Bring Your Own Key", "External LLM (manual copy–paste)"],
1053
  key="discover_llm_mode",
1054
  horizontal=True,
1055
  )
@@ -1058,75 +1299,73 @@ external_response_text = st.session_state.get("discover_llm_external_response",
1058
  selected_model = "external-copy-paste"
1059
  active_api_key = ""
1060
  active_base_url = get_webui_base_url()
 
1061
  api_config_invalid = False
1062
- default_model = (
1063
- get_config_value("CRC_OPENWEBUI_MODEL", "")
1064
- or get_config_value("OPENWEBUI_MODEL", "")
1065
- or get_config_value("OPENAI_MODEL", "")
1066
- or "gpt-oss:latest"
1067
- )
1068
 
1069
- if mode in {"Built-in API", "Bring Your Own Key"}:
1070
- if mode == "Built-in API":
1071
- active_api_key = get_webui_api_key()
1072
- active_base_url = get_webui_base_url()
1073
- if not active_api_key:
1074
- st.warning(
1075
- "No API key found. Set `CRC_OPENWEBUI_API_KEY` in `.streamlit/secrets.toml` "
1076
- "or an environment variable."
1077
- )
1078
- else:
1079
- with st.container(border=True):
1080
- st.caption(
1081
- "Bring Your Own Key mode: key is used only for this session and never written to files."
1082
- )
1083
- st.text_input(
1084
- "Your API key",
1085
- key="discover_llm_byok_key",
1086
- type="password",
1087
- placeholder="Paste your API key",
1088
- )
1089
- st.text_input(
1090
- "API base URL",
1091
- key="discover_llm_byok_base_url",
1092
- placeholder="https://openwebui.crc.nd.edu",
1093
- )
1094
- st.button("Clear API key", key="clear_byok_key", on_click=clear_byok_key)
1095
- active_api_key = str(st.session_state.get("discover_llm_byok_key", "")).strip()
1096
- active_base_url = str(st.session_state.get("discover_llm_byok_base_url", "")).strip().rstrip("/")
1097
- if active_base_url and not active_base_url.startswith("https://"):
1098
- st.error("API base URL must start with `https://`.")
1099
- api_config_invalid = True
1100
- if not active_api_key:
1101
- st.warning("Enter your API key to enable in-app LLM generation.")
 
 
 
 
 
 
 
 
 
 
 
 
 
1102
 
1103
  available_models: list[str] = []
1104
  models_error = ""
1105
  if active_api_key and not api_config_invalid:
1106
  try:
1107
- if mode == "Built-in API":
1108
- available_models = list_available_models_cached()
1109
- else:
1110
- available_models = list_available_models(active_api_key, active_base_url)
1111
  except Exception as e:
1112
  models_error = str(e)
1113
 
1114
  if available_models:
1115
- model_index = available_models.index(default_model) if default_model in available_models else 0
1116
- selected_model = st.selectbox(
1117
- "LLM model",
1118
- options=available_models,
1119
- index=model_index,
1120
- help="Model used only to translate your natural language request into JSON.",
1121
- )
1122
  else:
1123
  if models_error:
1124
- st.warning(f"Could not load model list from API. Enter model name manually. Error: {models_error}")
1125
- selected_model = st.text_input(
1126
- "LLM model",
1127
- value=default_model,
1128
- help="Use a valid model id from your CRC Open WebUI instance (for example `gpt-oss:latest`).",
1129
- )
1130
  else:
1131
  with st.container(border=True):
1132
  st.caption(
@@ -1148,7 +1387,7 @@ generate_json_btn = False
1148
  if show_json_editor:
1149
  generate_json_btn = st.button(
1150
  "Generate JSON from LLM"
1151
- if mode in {"Built-in API", "Bring Your Own Key"}
1152
  else "Generate JSON from pasted response"
1153
  )
1154
 
@@ -1554,7 +1793,11 @@ def _build_runnable_spec(raw_obj: dict) -> tuple[dict, list[str], list[str]]:
1554
 
1555
 
1556
  def _raw_spec_from_prompt(
1557
- user_query: str, model_name: str, api_key: str | None = None, base_url: str | None = None
 
 
 
 
1558
  ) -> tuple[dict, list[str], str | None]:
1559
  notes: list[str] = []
1560
  extracted = {}
@@ -1562,7 +1805,13 @@ def _raw_spec_from_prompt(
1562
  return {}, notes, "Please provide a prompt before generating or running discovery."
1563
  with st.spinner("Interpreting prompt and preparing discovery config..."):
1564
  try:
1565
- extracted = generate_spec_from_llm(user_query, model_name, api_key=api_key, base_url=base_url)
 
 
 
 
 
 
1566
  except Exception as e:
1567
  return {}, notes, f"LLM generation failed: {e}"
1568
 
@@ -1629,17 +1878,21 @@ def _raw_spec_from_external_response(user_query: str, response_text: str) -> tup
1629
 
1630
 
1631
  if show_json_editor and generate_json_btn:
1632
- if mode in {"Built-in API", "Bring Your Own Key"} and not llm_query.strip():
1633
  st.error("Please provide a prompt before generating JSON.")
1634
  st.stop()
1635
  if mode == "Bring Your Own Key":
1636
- byok_err = validate_api_access(active_api_key, active_base_url)
1637
  if byok_err:
1638
  st.error(f"BYOK validation failed: {byok_err}")
1639
  st.stop()
1640
- if mode in {"Built-in API", "Bring Your Own Key"}:
1641
  raw_spec_obj, prep_notes, parse_error = _raw_spec_from_prompt(
1642
- llm_query, selected_model, api_key=active_api_key, base_url=active_base_url
 
 
 
 
1643
  )
1644
  if parse_error:
1645
  for msg in prep_notes:
@@ -1665,11 +1918,11 @@ if show_json_editor and generate_json_btn:
1665
  run_btn = st.button("Run discovery", type="primary")
1666
 
1667
  if run_btn:
1668
- if mode in {"Built-in API", "Bring Your Own Key"} and not llm_query.strip():
1669
  st.error("Please provide a prompt before running discovery.")
1670
  st.stop()
1671
  if mode == "Bring Your Own Key":
1672
- byok_err = validate_api_access(active_api_key, active_base_url)
1673
  if byok_err:
1674
  st.error(f"BYOK validation failed: {byok_err}")
1675
  st.stop()
@@ -1686,9 +1939,13 @@ if run_btn:
1686
  raw_spec_obj = {}
1687
  prep_notes.append("Invalid JSON detected. Using fixed template defaults.")
1688
  else:
1689
- if mode in {"Built-in API", "Bring Your Own Key"}:
1690
  raw_spec_obj, llm_notes, parse_error = _raw_spec_from_prompt(
1691
- llm_query, selected_model, api_key=active_api_key, base_url=active_base_url
 
 
 
 
1692
  )
1693
  if parse_error:
1694
  for msg in llm_notes:
@@ -1728,6 +1985,7 @@ if run_btn:
1728
  st.session_state["discovery_mode_used"] = mode
1729
  st.session_state["discovery_api_key"] = active_api_key if mode == "Bring Your Own Key" else ""
1730
  st.session_state["discovery_api_base_url"] = active_base_url if mode == "Bring Your Own Key" else ""
 
1731
  st.session_state["discovery_reasoning_text"] = None
1732
  st.session_state["discovery_reasoning_key"] = None
1733
  st.session_state["discovery_reasoning_note"] = None
@@ -1799,13 +2057,15 @@ if st.session_state.get("discovery_done"):
1799
  c3.metric("Pareto pool", int(stats.get("n_pareto_pool", 0)))
1800
  c4.metric("Selected", int(stats.get("n_selected", 0)))
1801
 
1802
- if mode_used in {"Built-in API", "Bring Your Own Key"}:
1803
  reasoning_api_key = st.session_state.get("discovery_api_key", "")
1804
  reasoning_api_base_url = st.session_state.get("discovery_api_base_url", "")
 
1805
  reasoning_key_obj = {
1806
  "spec": resolved_spec,
1807
  "model": model_used,
1808
  "mode": mode_used,
 
1809
  "selected_smiles_head": (
1810
  out_df["SMILES"].astype(str).head(20).tolist()
1811
  if isinstance(out_df, pd.DataFrame) and "SMILES" in out_df.columns
@@ -1826,6 +2086,7 @@ if st.session_state.get("discovery_done"):
1826
  model_used,
1827
  api_key=(str(reasoning_api_key).strip() or None),
1828
  base_url=(str(reasoning_api_base_url).strip() or None),
 
1829
  )
1830
  st.session_state["discovery_reasoning_note"] = None
1831
  except Exception as e:
@@ -1869,7 +2130,9 @@ if st.session_state.get("discovery_done"):
1869
  meta = PROPERTY_META[prop_key]
1870
  rename_map[c] = f"{meta['name']} ({meta['unit']})"
1871
  preview_df = preview_df.rename(columns=rename_map)
1872
- st.dataframe(preview_df.head(50), width="stretch")
 
 
1873
 
1874
  st.subheader("📥 Download")
1875
  buf = io.StringIO()
 
8
  import time
9
  import urllib.request
10
  import urllib.error
11
+ import urllib.parse
12
  import zipfile
13
  from pathlib import Path
14
 
 
18
  from streamlit.components.v1 import html
19
 
20
  from src.discover_llm import PROPERTY_META, run_discovery, spec_from_dict
21
+ from src.ui_style import apply_global_style, render_page_header
22
 
23
  st.set_page_config(page_title="DISCOVERY (AI)", layout="wide")
24
  apply_global_style()
25
+ render_page_header(
26
+ title="AI-Driven Multi-Objective Discovery",
27
+ subtitle="Describe target behavior in plain language and run auto-configured multi-objective search.",
28
+ badge="Discovery (AI)",
29
+ )
30
 
31
  # ----------------------------
32
  # Files
 
312
  ).rstrip("/")
313
 
314
 
315
# Display label for every supported LLM API provider key. The keys double as
# the canonical provider identifiers used throughout this module.
PROVIDER_LABELS: dict[str, str] = {
    "auto": "Auto detect",
    "openwebui": "OpenWebUI",
    "openai_compatible": "OpenAI-compatible",
    "anthropic": "Anthropic",
    "gemini": "Gemini",
}

# Provider keys in declaration order (iterating a dict yields its keys).
PROVIDER_OPTIONS: list[str] = list(PROVIDER_LABELS)


def _provider_label(provider: str) -> str:
    """Return the display label for *provider*.

    Unknown provider keys are returned unchanged so callers always get a
    printable string.
    """
    return PROVIDER_LABELS.get(provider, provider)
328
+
329
+
330
def default_model_for_provider(provider: str) -> str:
    """Pick the fallback model name for *provider*.

    The provider is normalized first. For each provider a fixed sequence of
    configuration keys is consulted through ``get_config_value`` (defined
    elsewhere in this file); the first non-empty value wins. When none is
    set, a hard-coded model name is returned. Providers without a dedicated
    entry (including ``"auto"``) use the ``OPENAI_MODEL`` setting and then
    ``"gpt-4o-mini"``.
    """
    # provider -> (config keys tried in order, hard-coded last resort)
    lookup_plan = {
        "openwebui": (
            ("CRC_OPENWEBUI_MODEL", "OPENWEBUI_MODEL", "OPENAI_MODEL"),
            "gpt-oss:latest",
        ),
        "openai_compatible": (
            ("OPENAI_MODEL", "OPENWEBUI_MODEL", "CRC_OPENWEBUI_MODEL"),
            "gpt-4o-mini",
        ),
        "anthropic": (("ANTHROPIC_MODEL",), "claude-3-5-sonnet-latest"),
        "gemini": (("GEMINI_MODEL",), "gemini-2.0-flash"),
    }
    config_keys, last_resort = lookup_plan.get(
        _normalize_provider(provider),
        (("OPENAI_MODEL",), "gpt-4o-mini"),
    )
    for config_key in config_keys:
        configured = get_config_value(config_key, "")
        if configured:
            return configured
    return last_resort
351
+
352
+
353
def _normalize_provider(provider: str | None) -> str:
    """Map a free-form provider name onto a key of ``PROVIDER_LABELS``.

    The input is stringified, trimmed and lower-cased, and hyphens and spaces
    are replaced by underscores. Anything that is not a known provider key
    afterwards (including ``None`` and the empty string) becomes ``"auto"``.
    """
    key = str(provider or "").strip().lower()
    key = key.replace("-", "_").replace(" ", "_")
    return key if key in PROVIDER_LABELS else "auto"
358
+
359
+
360
def detect_api_provider(base_url: str) -> str:
    """Guess the API provider from substrings of *base_url*.

    Matching is case-insensitive and checks the whole URL string, in this
    order:

    * contains ``openwebui``      -> ``"openwebui"``
    * contains ``anthropic.com``  -> ``"anthropic"``
    * contains ``googleapis.com`` -> ``"gemini"``
    * anything else (including an empty or ``None`` URL) -> ``"openai_compatible"``

    Returns:
        One of the provider keys listed above; never ``"auto"``.
    """
    url = str(base_url or "").strip().lower()
    if "openwebui" in url:
        return "openwebui"
    if "anthropic.com" in url:
        return "anthropic"
    # "googleapis.com" also covers "generativelanguage.googleapis.com", so a
    # separate check for the longer host name is unnecessary.
    if "googleapis.com" in url:
        return "gemini"
    # Every remaining endpoint (api.openai.com, openrouter.ai, generic "/v1"
    # gateways, unknown hosts) is treated as OpenAI-compatible.
    return "openai_compatible"
371
+
372
+
373
def resolve_api_provider(base_url: str, provider: str | None = None) -> str:
    """Return the provider to use for *base_url*.

    An explicitly chosen, recognized provider is returned as its normalized
    key. When the choice normalizes to ``"auto"`` (which also happens for
    ``None`` or unrecognized names), the provider is detected from
    *base_url* instead.
    """
    normalized = _normalize_provider(provider)
    return detect_api_provider(base_url) if normalized == "auto" else normalized
378
+
379
+
380
def _provider_root(base_url: str, provider: str) -> str:
    """Return the API root URL that endpoint paths are appended to.

    *base_url* is trimmed and stripped of trailing slashes, then a version
    segment is appended when it is not already present:

    * ``openai_compatible`` / ``anthropic``: ensure the root ends in ``/v1``.
    * ``gemini``: keep an existing ``/v1`` or ``/v1beta`` suffix, otherwise
      append ``/v1beta``.
    * any other provider (e.g. ``openwebui``): the cleaned URL unchanged.
    """
    root = str(base_url or "").strip().rstrip("/")
    if provider in ("openai_compatible", "anthropic"):
        return root if root.endswith("/v1") else f"{root}/v1"
    if provider == "gemini":
        return root if root.endswith(("/v1", "/v1beta")) else f"{root}/v1beta"
    return root
393
+
394
+
395
def _join_url(base_url: str, path: str) -> str:
    """Append *path* to *base_url* without doubling or dropping the slash.

    Trailing slashes on *base_url* are removed. A *path* that is empty or
    already starts with ``/``, ``?`` or ``#`` is appended verbatim; any other
    path gets a single ``/`` separator so that ``("https://host", "models")``
    yields ``https://host/models`` rather than ``https://hostmodels``.
    """
    base = base_url.rstrip("/")
    if not path or path.startswith(("/", "?", "#")):
        return f"{base}{path}"
    return f"{base}/{path}"
397
+
398
+
399
+ def _http_json_request(
400
+ url: str,
401
+ headers: dict[str, str] | None = None,
402
+ payload: dict | None = None,
403
+ method: str | None = None,
404
+ timeout: int = 60,
405
+ ) -> dict:
406
+ req = urllib.request.Request(
407
+ url=url,
408
+ data=(json.dumps(payload).encode("utf-8") if payload is not None else None),
409
+ headers=(headers or {}),
410
+ method=(method or ("POST" if payload is not None else "GET")),
411
+ )
412
+ try:
413
+ with urllib.request.urlopen(req, timeout=timeout) as resp:
414
+ return json.loads(resp.read().decode("utf-8"))
415
+ except urllib.error.HTTPError as e:
416
+ detail = e.read().decode("utf-8", errors="ignore")
417
+ raise RuntimeError(f"HTTP {e.code}: {detail}") from e
418
+ except Exception as e:
419
+ raise RuntimeError(str(e)) from e
420
+
421
+
422
def _flatten_text_content(content) -> str:
    """Collapse a chat message ``content`` value into one plain string.

    * ``str``: returned stripped.
    * ``list``: string items are kept as-is; dict items contribute their
      stripped ``"text"`` value when it is non-empty; other items are
      ignored. The pieces are joined with newlines and the result stripped.
    * anything else: ``str(content)`` stripped, with falsy values (``None``,
      ``0``, ...) mapped to ``""``.
    """
    if isinstance(content, str):
        return content.strip()
    if isinstance(content, list):
        parts = []
        for item in content:
            if isinstance(item, str):
                parts.append(item)
            elif isinstance(item, dict):
                raw_text = item.get("text")
                # An explicit null "text" must not become the literal "None".
                text = "" if raw_text is None else str(raw_text).strip()
                if text:
                    parts.append(text)
        return "\n".join(p for p in parts if p).strip()
    return str(content or "").strip()
436
+
437
+
438
def provider_request(
    base_url: str,
    api_key: str,
    provider: str,
    path: str,
    payload: dict | None = None,
) -> dict:
    """Call *path* on the given provider's API with the right credentials.

    The endpoint is the provider root for *base_url* joined with *path*.
    Credentials are attached per provider:

    * ``openwebui`` / ``openai_compatible``: ``Authorization: Bearer <key>``
    * ``anthropic``: ``x-api-key`` plus ``anthropic-version: 2023-06-01``
    * ``gemini``: URL-quoted ``key`` query parameter

    Any other provider value is sent without credentials.

    Raises:
        RuntimeError: when the underlying request fails; the message is
            prefixed with the provider's display label.
    """
    endpoint = _join_url(_provider_root(base_url, provider), path)
    request_headers = {"Content-Type": "application/json"}

    if provider == "gemini":
        joiner = "&" if "?" in endpoint else "?"
        endpoint = f"{endpoint}{joiner}key={urllib.parse.quote(api_key, safe='')}"
    elif provider == "anthropic":
        request_headers["x-api-key"] = api_key
        request_headers["anthropic-version"] = "2023-06-01"
    elif provider in {"openwebui", "openai_compatible"}:
        request_headers["Authorization"] = f"Bearer {api_key}"

    try:
        return _http_json_request(endpoint, headers=request_headers, payload=payload)
    except Exception as e:
        raise RuntimeError(f"{_provider_label(provider)} API call failed: {e}") from e
462
+
463
+
464
def chat_text_request(
    base_url: str,
    api_key: str,
    provider: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    max_tokens: int = 1024,
) -> str:
    """Run a single system+user chat turn and return the reply text.

    The provider is resolved from *provider* / *base_url* first, then the
    request is shaped for that provider's API:

    * ``openwebui`` -> ``/api/chat/completions``; ``openai_compatible`` ->
      ``/chat/completions``. *max_tokens* is not included in this payload.
    * ``anthropic`` -> ``/messages``; text blocks of the response are joined
      with newlines.
    * ``gemini`` -> ``/models/<model>:generateContent`` with temperature 0;
      a leading ``models/`` on the model name is dropped and the text parts
      of the first candidate are joined with newlines.

    Raises:
        RuntimeError: for an unsupported provider, a failed request, or a
            response whose structure cannot be read.
    """
    provider = resolve_api_provider(base_url, provider)

    if provider in {"openwebui", "openai_compatible"}:
        endpoint = (
            "/chat/completions" if provider == "openai_compatible" else "/api/chat/completions"
        )
        body = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        }
        raw = provider_request(base_url, api_key, provider, endpoint, payload=body)
        try:
            return _flatten_text_content(raw["choices"][0]["message"]["content"])
        except Exception as e:
            raise RuntimeError("Unexpected chat-completions response format.") from e

    if provider == "anthropic":
        body = {
            "model": model,
            "system": system_prompt,
            "max_tokens": int(max_tokens),
            "messages": [{"role": "user", "content": user_prompt}],
        }
        raw = provider_request(base_url, api_key, provider, "/messages", payload=body)
        try:
            text_blocks = [
                str(part.get("text", "")).strip()
                for part in raw.get("content", [])
                if isinstance(part, dict) and str(part.get("type", "")) == "text"
            ]
            return "\n".join(text_blocks).strip()
        except Exception as e:
            raise RuntimeError("Unexpected Anthropic response format.") from e

    if provider == "gemini":
        # The endpoint path already contains "models/", so drop that prefix
        # from the model name if the caller supplied it.
        model_name = str(model or "").strip().removeprefix("models/")
        endpoint = f"/models/{urllib.parse.quote(model_name, safe='')}:generateContent"
        body = {
            "system_instruction": {"parts": [{"text": system_prompt}]},
            "contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
            "generationConfig": {"temperature": 0.0, "maxOutputTokens": int(max_tokens)},
        }
        raw = provider_request(base_url, api_key, provider, endpoint, payload=body)
        try:
            candidates = raw.get("candidates", [])
            parts = candidates[0]["content"]["parts"] if candidates else []
            texts = [
                str(part.get("text", "")).strip()
                for part in parts
                if isinstance(part, dict) and str(part.get("text", "")).strip()
            ]
            return "\n".join(texts).strip()
        except Exception as e:
            raise RuntimeError("Unexpected Gemini response format.") from e

    raise RuntimeError(f"Unsupported provider: {provider}")
542
+
543
+
544
+ def validate_api_access(api_key: str, base_url: str, provider: str | None = None, model: str | None = None) -> str | None:
545
  """Return None when credentials are usable, else an error message."""
546
  k = str(api_key or "").strip()
547
  u = str(base_url or "").strip().rstrip("/")
 
549
  return "API key is required."
550
  if not u.startswith("https://"):
551
  return "API base URL must start with `https://`."
552
+
553
+ resolved_provider = resolve_api_provider(u, provider)
554
  try:
555
+ if resolved_provider in {"openwebui", "openai_compatible"}:
556
+ _ = list_available_models(k, u, resolved_provider)
557
+ elif resolved_provider in {"anthropic", "gemini"}:
558
+ if not str(model or "").strip():
559
+ return f"{_provider_label(resolved_provider)} validation requires a model name."
560
+ _ = chat_text_request(
561
+ u,
562
+ k,
563
+ resolved_provider,
564
+ str(model).strip(),
565
+ "Reply with OK.",
566
+ "ping",
567
+ max_tokens=8,
568
+ )
569
  except Exception as e:
570
  return f"API key validation failed: {e}"
571
  return None
 
575
  st.session_state["discover_llm_byok_key"] = ""
576
 
577
 
578
+ def list_available_models(
579
+ api_key: str | None = None,
580
+ base_url: str | None = None,
581
+ provider: str | None = None,
582
+ ) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
  api_key = (api_key or get_webui_api_key()).strip()
584
  if not api_key:
585
  return []
586
  base_url = (base_url or get_webui_base_url()).rstrip("/")
587
+ resolved_provider = resolve_api_provider(base_url, provider)
588
+
589
+ if resolved_provider == "openwebui":
590
+ raw = provider_request(base_url, api_key, resolved_provider, "/api/models", payload=None)
591
+ items = raw.get("data", raw) if isinstance(raw, dict) else raw
592
+ elif resolved_provider == "openai_compatible":
593
+ raw = provider_request(base_url, api_key, resolved_provider, "/models", payload=None)
594
+ items = raw.get("data", raw) if isinstance(raw, dict) else raw
595
+ elif resolved_provider == "gemini":
596
+ raw = provider_request(base_url, api_key, resolved_provider, "/models", payload=None)
597
+ items = raw.get("models", raw.get("data", raw)) if isinstance(raw, dict) else raw
598
+ else:
599
+ return []
600
+
601
  if not isinstance(items, list):
602
  return []
603
  out = []
604
  for m in items:
605
  if isinstance(m, dict):
606
  mid = str(m.get("id", m.get("name", ""))).strip()
607
+ if resolved_provider == "gemini" and mid.startswith("models/"):
608
+ mid = mid.split("/", 1)[1]
609
  else:
610
  mid = str(m).strip()
611
  if mid:
 
614
 
615
 
616
  def generate_spec_from_llm(
617
+ user_query: str,
618
+ model: str,
619
+ api_key: str | None = None,
620
+ base_url: str | None = None,
621
+ provider: str | None = None,
622
  ) -> dict:
623
  api_key = (api_key or get_webui_api_key()).strip()
624
  if not api_key:
 
651
  user_prompt = (
652
  "User request:\n" + user_query.strip()
653
  )
654
+ content = chat_text_request(
655
+ base_url,
656
+ api_key,
657
+ resolve_api_provider(base_url, provider),
658
+ model,
659
+ system_prompt,
660
+ user_prompt,
661
+ max_tokens=1024,
662
+ )
 
 
 
 
 
663
 
664
  try:
665
  parsed = extract_first_json_object(content)
 
751
  html(snippet, height=box_height + 54)
752
 
753
 
 
 
 
 
 
754
  def _local_reasoning_fallback(spec_obj: dict, stats: dict) -> str:
755
  objectives = spec_obj.get("objectives", []) if isinstance(spec_obj, dict) else []
756
  constraints = spec_obj.get("hard_constraints", {}) if isinstance(spec_obj, dict) else {}
 
839
  model: str,
840
  api_key: str | None = None,
841
  base_url: str | None = None,
842
+ provider: str | None = None,
843
  ) -> str:
844
  api_key = (api_key or get_webui_api_key()).strip()
845
  if not api_key:
 
915
  "You can add brief clarifying bullets if helpful, but keep it concise and focused.\n\n"
916
  f"INPUT:\n{json.dumps(user_payload, indent=2)}"
917
  )
918
+ return chat_text_request(
919
+ base_url,
920
+ api_key,
921
+ resolve_api_provider(base_url, provider),
922
+ model,
923
+ system_prompt,
924
+ user_prompt,
925
+ max_tokens=900,
926
+ )
 
 
 
 
927
 
928
 
929
  def pareto_publication_plot(plot_df: pd.DataFrame, obj_props: list[str]):
 
1248
  if "discover_llm_last_example_choice" not in st.session_state:
1249
  st.session_state["discover_llm_last_example_choice"] = "Select an example prompt…"
1250
  if "discover_llm_mode" not in st.session_state:
1251
+ st.session_state["discover_llm_mode"] = "Bring Your Own Key"
1252
  if "discover_llm_external_response" not in st.session_state:
1253
  st.session_state["discover_llm_external_response"] = ""
1254
  if "discover_llm_byok_key" not in st.session_state:
1255
  st.session_state["discover_llm_byok_key"] = ""
1256
  if "discover_llm_byok_base_url" not in st.session_state:
1257
+ st.session_state["discover_llm_byok_base_url"] = ""
1258
+ if "discover_llm_byok_provider" not in st.session_state:
1259
+ st.session_state["discover_llm_byok_provider"] = "auto"
1260
+ if st.session_state.get("discover_llm_mode") not in {"Bring Your Own Key", "External LLM (manual copy–paste)"}:
1261
+ st.session_state["discover_llm_mode"] = "Bring Your Own Key"
1262
 
1263
  # Apply deferred JSON updates before any JSON editor widget is instantiated.
1264
  pending_spec_text = st.session_state.get("discover_llm_spec_text_next")
 
1290
  )
1291
  mode = st.radio(
1292
  "LLM setup",
1293
+ options=["Bring Your Own Key", "External LLM (manual copy–paste)"],
1294
  key="discover_llm_mode",
1295
  horizontal=True,
1296
  )
 
1299
  selected_model = "external-copy-paste"
1300
  active_api_key = ""
1301
  active_base_url = get_webui_base_url()
1302
+ active_provider = "openwebui"
1303
  api_config_invalid = False
 
 
 
 
 
 
1304
 
1305
+ if mode == "Bring Your Own Key":
1306
+ with st.container(border=True):
1307
+ st.caption(
1308
+ "Bring Your Own Key mode: key is used only for this session and never written to files."
1309
+ )
1310
+ st.caption(
1311
+ "Enter the service root URL, not a full endpoint path. Examples: "
1312
+ "`https://api.openai.com`, `https://api.anthropic.com`, "
1313
+ "`https://generativelanguage.googleapis.com`, or your OpenWebUI base URL."
1314
+ )
1315
+ st.text_input(
1316
+ "Your API key",
1317
+ key="discover_llm_byok_key",
1318
+ type="password",
1319
+ placeholder="Paste your API key",
1320
+ )
1321
+ st.text_input(
1322
+ "API base URL",
1323
+ key="discover_llm_byok_base_url",
1324
+ placeholder="Enter service root URL",
1325
+ )
1326
+ st.selectbox(
1327
+ "API provider",
1328
+ options=PROVIDER_OPTIONS,
1329
+ key="discover_llm_byok_provider",
1330
+ format_func=_provider_label,
1331
+ help=(
1332
+ "Use Auto detect for most endpoints. "
1333
+ "Choose a provider explicitly if the base URL is a direct Anthropic or Gemini endpoint, "
1334
+ "or if your gateway does not identify itself clearly."
1335
+ ),
1336
+ )
1337
+ st.button("Clear API key", key="clear_byok_key", on_click=clear_byok_key)
1338
+ active_api_key = str(st.session_state.get("discover_llm_byok_key", "")).strip()
1339
+ user_base_url = str(st.session_state.get("discover_llm_byok_base_url", "")).strip().rstrip("/")
1340
+ active_base_url = user_base_url or get_webui_base_url()
1341
+ configured_provider = str(st.session_state.get("discover_llm_byok_provider", "auto")).strip()
1342
+ active_provider = resolve_api_provider(active_base_url, configured_provider) if active_base_url else "auto"
1343
+ fallback_model = default_model_for_provider(active_provider)
1344
+ if user_base_url and not user_base_url.startswith("https://"):
1345
+ st.error("API base URL must start with `https://`.")
1346
+ api_config_invalid = True
1347
+ elif user_base_url:
1348
+ st.caption(f"Detected provider: `{_provider_label(active_provider)}`")
1349
+ if not active_api_key:
1350
+ st.warning("Enter your API key to enable in-app LLM generation.")
1351
 
1352
  available_models: list[str] = []
1353
  models_error = ""
1354
  if active_api_key and not api_config_invalid:
1355
  try:
1356
+ available_models = list_available_models(active_api_key, active_base_url, active_provider)
 
 
 
1357
  except Exception as e:
1358
  models_error = str(e)
1359
 
1360
  if available_models:
1361
+ model_index = available_models.index(fallback_model) if fallback_model in available_models else 0
1362
+ selected_model = available_models[model_index]
1363
+ st.caption(f"Using model: `{selected_model}`")
 
 
 
 
1364
  else:
1365
  if models_error:
1366
+ st.warning(f"Could not load model list from API. Using fallback model `{fallback_model}`. Error: {models_error}")
1367
+ selected_model = fallback_model
1368
+ st.caption(f"Using fallback model: `{selected_model}`")
 
 
 
1369
  else:
1370
  with st.container(border=True):
1371
  st.caption(
 
1387
  if show_json_editor:
1388
  generate_json_btn = st.button(
1389
  "Generate JSON from LLM"
1390
+ if mode == "Bring Your Own Key"
1391
  else "Generate JSON from pasted response"
1392
  )
1393
 
 
1793
 
1794
 
1795
  def _raw_spec_from_prompt(
1796
+ user_query: str,
1797
+ model_name: str,
1798
+ api_key: str | None = None,
1799
+ base_url: str | None = None,
1800
+ provider: str | None = None,
1801
  ) -> tuple[dict, list[str], str | None]:
1802
  notes: list[str] = []
1803
  extracted = {}
 
1805
  return {}, notes, "Please provide a prompt before generating or running discovery."
1806
  with st.spinner("Interpreting prompt and preparing discovery config..."):
1807
  try:
1808
+ extracted = generate_spec_from_llm(
1809
+ user_query,
1810
+ model_name,
1811
+ api_key=api_key,
1812
+ base_url=base_url,
1813
+ provider=provider,
1814
+ )
1815
  except Exception as e:
1816
  return {}, notes, f"LLM generation failed: {e}"
1817
 
 
1878
 
1879
 
1880
  if show_json_editor and generate_json_btn:
1881
+ if mode == "Bring Your Own Key" and not llm_query.strip():
1882
  st.error("Please provide a prompt before generating JSON.")
1883
  st.stop()
1884
  if mode == "Bring Your Own Key":
1885
+ byok_err = validate_api_access(active_api_key, active_base_url, active_provider, selected_model)
1886
  if byok_err:
1887
  st.error(f"BYOK validation failed: {byok_err}")
1888
  st.stop()
1889
+ if mode == "Bring Your Own Key":
1890
  raw_spec_obj, prep_notes, parse_error = _raw_spec_from_prompt(
1891
+ llm_query,
1892
+ selected_model,
1893
+ api_key=active_api_key,
1894
+ base_url=active_base_url,
1895
+ provider=active_provider,
1896
  )
1897
  if parse_error:
1898
  for msg in prep_notes:
 
1918
  run_btn = st.button("Run discovery", type="primary")
1919
 
1920
  if run_btn:
1921
+ if mode == "Bring Your Own Key" and not llm_query.strip():
1922
  st.error("Please provide a prompt before running discovery.")
1923
  st.stop()
1924
  if mode == "Bring Your Own Key":
1925
+ byok_err = validate_api_access(active_api_key, active_base_url, active_provider, selected_model)
1926
  if byok_err:
1927
  st.error(f"BYOK validation failed: {byok_err}")
1928
  st.stop()
 
1939
  raw_spec_obj = {}
1940
  prep_notes.append("Invalid JSON detected. Using fixed template defaults.")
1941
  else:
1942
+ if mode == "Bring Your Own Key":
1943
  raw_spec_obj, llm_notes, parse_error = _raw_spec_from_prompt(
1944
+ llm_query,
1945
+ selected_model,
1946
+ api_key=active_api_key,
1947
+ base_url=active_base_url,
1948
+ provider=active_provider,
1949
  )
1950
  if parse_error:
1951
  for msg in llm_notes:
 
1985
  st.session_state["discovery_mode_used"] = mode
1986
  st.session_state["discovery_api_key"] = active_api_key if mode == "Bring Your Own Key" else ""
1987
  st.session_state["discovery_api_base_url"] = active_base_url if mode == "Bring Your Own Key" else ""
1988
+ st.session_state["discovery_api_provider"] = active_provider if mode == "Bring Your Own Key" else ""
1989
  st.session_state["discovery_reasoning_text"] = None
1990
  st.session_state["discovery_reasoning_key"] = None
1991
  st.session_state["discovery_reasoning_note"] = None
 
2057
  c3.metric("Pareto pool", int(stats.get("n_pareto_pool", 0)))
2058
  c4.metric("Selected", int(stats.get("n_selected", 0)))
2059
 
2060
+ if mode_used == "Bring Your Own Key":
2061
  reasoning_api_key = st.session_state.get("discovery_api_key", "")
2062
  reasoning_api_base_url = st.session_state.get("discovery_api_base_url", "")
2063
+ reasoning_api_provider = st.session_state.get("discovery_api_provider", "openwebui")
2064
  reasoning_key_obj = {
2065
  "spec": resolved_spec,
2066
  "model": model_used,
2067
  "mode": mode_used,
2068
+ "provider": reasoning_api_provider,
2069
  "selected_smiles_head": (
2070
  out_df["SMILES"].astype(str).head(20).tolist()
2071
  if isinstance(out_df, pd.DataFrame) and "SMILES" in out_df.columns
 
2086
  model_used,
2087
  api_key=(str(reasoning_api_key).strip() or None),
2088
  base_url=(str(reasoning_api_base_url).strip() or None),
2089
+ provider=(str(reasoning_api_provider).strip() or None),
2090
  )
2091
  st.session_state["discovery_reasoning_note"] = None
2092
  except Exception as e:
 
2130
  meta = PROPERTY_META[prop_key]
2131
  rename_map[c] = f"{meta['name']} ({meta['unit']})"
2132
  preview_df = preview_df.rename(columns=rename_map)
2133
+ preview_display = preview_df.head(50).copy()
2134
+ preview_display.index = range(1, len(preview_display) + 1)
2135
+ st.dataframe(preview_display, width="stretch")
2136
 
2137
  st.subheader("📥 Download")
2138
  buf = io.StringIO()
pages/6_Novel_SMILES_Generation.py CHANGED
@@ -11,59 +11,95 @@ from src.rnn_smiles.generator import (
11
  load_existing_smiles_set,
12
  load_rnn_model,
13
  )
14
- from src.ui_style import apply_global_style
15
 
16
  st.set_page_config(page_title="Novel SMILES Generation", layout="wide")
17
  apply_global_style()
18
- st.title("Novel SMILES Generation")
19
- st.caption("Generate candidate polymers with an RNN and keep only molecules not seen in local datasets.")
 
 
 
20
 
21
  APP_ROOT = Path(__file__).resolve().parents[1]
22
  MODEL_DIR = APP_ROOT / "models" / "rnn" / "pretrained_model"
23
 
24
  DEFAULT_CKPT = MODEL_DIR / "Prior.ckpt"
25
  DEFAULT_VOC = MODEL_DIR / "voc"
 
 
26
 
27
- NOVELTY_DATASETS = [
28
  APP_ROOT / "data" / "EXP.csv",
29
  APP_ROOT / "data" / "MD.csv",
30
  APP_ROOT / "data" / "DFT.csv",
31
  APP_ROOT / "data" / "GC.csv",
32
  APP_ROOT / "data" / "POLYINFO.csv",
 
 
33
  APP_ROOT / "data" / "PI1M.csv",
34
  ]
35
 
36
- with st.sidebar:
37
- st.subheader("Model Assets")
38
- ckpt_path = st.text_input("Checkpoint path", value=str(DEFAULT_CKPT))
39
- voc_path = st.text_input("Vocabulary path", value=str(DEFAULT_VOC))
40
 
 
 
 
 
 
 
 
 
 
 
41
  st.subheader("Generation Parameters")
42
  target_count = st.number_input("Novel SMILES to return", min_value=1, max_value=5000, value=200, step=25)
43
  max_length = st.number_input("Max token length", min_value=20, max_value=300, value=140, step=10)
44
  temperature = st.slider("Temperature", min_value=0.2, max_value=2.0, value=1.0, step=0.1)
45
  max_attempts = st.number_input("Sampling attempts", min_value=1, max_value=50, value=10, step=1)
 
 
 
 
 
 
 
 
 
46
 
47
  if not Path(ckpt_path).expanduser().exists() or not Path(voc_path).expanduser().exists():
48
  st.error("Model files were not found.")
49
- st.write("Expected default location:")
50
- st.code(str(MODEL_DIR))
51
  st.stop()
52
 
53
- available_datasets = [p for p in NOVELTY_DATASETS if p.exists()]
54
- missing_datasets = [p for p in NOVELTY_DATASETS if not p.exists()]
 
 
 
 
 
 
55
 
56
  if missing_datasets:
57
  st.warning("Some novelty datasets are missing and were skipped.")
58
  for path in missing_datasets:
59
  st.write(f"- {path.name}")
60
 
 
 
 
 
 
61
  if not available_datasets:
62
  st.warning("No novelty datasets found. Results will only be de-duplicated within this run.")
63
 
64
  if st.button("Generate", type="primary"):
65
- with st.spinner("Loading RNN model (cached after first load)..."):
66
- model, voc = load_rnn_model(ckpt_path, voc_path)
 
 
 
 
 
67
 
68
  with st.spinner("Building novelty index (cached after first load)..."):
69
  existing_smiles = load_existing_smiles_set(tuple(str(p) for p in available_datasets)) if available_datasets else set()
@@ -121,7 +157,9 @@ if st.button("Generate", type="primary"):
121
  st.stop()
122
 
123
  result_df = pd.DataFrame({"SMILES": novel})
124
- st.dataframe(result_df, width="stretch")
 
 
125
  st.download_button(
126
  "Download CSV",
127
  data=result_df.to_csv(index=False).encode("utf-8"),
 
11
  load_existing_smiles_set,
12
  load_rnn_model,
13
  )
14
+ from src.ui_style import apply_global_style, render_page_header
15
 
16
  st.set_page_config(page_title="Novel SMILES Generation", layout="wide")
17
  apply_global_style()
18
+ render_page_header(
19
+ title="Novel SMILES Generation",
20
+ subtitle="Generate candidate polymers with an RNN and filter against local datasets for novelty.",
21
+ badge="Novel SMILES Generation",
22
+ )
23
 
24
  APP_ROOT = Path(__file__).resolve().parents[1]
25
  MODEL_DIR = APP_ROOT / "models" / "rnn" / "pretrained_model"
26
 
27
  DEFAULT_CKPT = MODEL_DIR / "Prior.ckpt"
28
  DEFAULT_VOC = MODEL_DIR / "voc"
29
+ ckpt_path = str(DEFAULT_CKPT)
30
+ voc_path = str(DEFAULT_VOC)
31
 
32
+ FAST_NOVELTY_DATASETS = [
33
  APP_ROOT / "data" / "EXP.csv",
34
  APP_ROOT / "data" / "MD.csv",
35
  APP_ROOT / "data" / "DFT.csv",
36
  APP_ROOT / "data" / "GC.csv",
37
  APP_ROOT / "data" / "POLYINFO.csv",
38
+ ]
39
+ SLOW_NOVELTY_DATASETS = [
40
  APP_ROOT / "data" / "PI1M.csv",
41
  ]
42
 
 
 
 
 
43
 
44
def _has_smiles_column(path: Path) -> bool:
    """Return True when the CSV at *path* has a SMILES-like column header.

    Only the header row is parsed. A header qualifies when, after trimming
    and lower-casing, it equals ``smile`` or ``smi`` or contains ``smiles``
    (which covers ``smiles``, ``canonical_smiles``, ``canonical smiles``,
    and similar names). Files that cannot be read or parsed yield False.
    """
    try:
        # nrows=0 reads just the column names, so large datasets stay cheap.
        header = pd.read_csv(path, nrows=0)
    except Exception:
        # Best-effort probe: a missing, unreadable or malformed file is
        # simply reported as "no usable SMILES column".
        return False
    names = (str(col).strip().lower() for col in header.columns)
    return any(name in {"smile", "smi"} or "smiles" in name for name in names)
51
+
52
+
53
+ with st.sidebar:
54
  st.subheader("Generation Parameters")
55
  target_count = st.number_input("Novel SMILES to return", min_value=1, max_value=5000, value=200, step=25)
56
  max_length = st.number_input("Max token length", min_value=20, max_value=300, value=140, step=10)
57
  temperature = st.slider("Temperature", min_value=0.2, max_value=2.0, value=1.0, step=0.1)
58
  max_attempts = st.number_input("Sampling attempts", min_value=1, max_value=50, value=10, step=1)
59
+ include_virtual_novelty = st.checkbox(
60
+ "Include PI1M in novelty filter (slower)",
61
+ value=False,
62
+ help="Off by default for website responsiveness. Enable only if you need novelty checked against the virtual library too.",
63
+ )
64
+
65
+ novelty_datasets = list(FAST_NOVELTY_DATASETS)
66
+ if include_virtual_novelty:
67
+ novelty_datasets.extend(SLOW_NOVELTY_DATASETS)
68
 
69
  if not Path(ckpt_path).expanduser().exists() or not Path(voc_path).expanduser().exists():
70
  st.error("Model files were not found.")
 
 
71
  st.stop()
72
 
73
+ available_datasets = [p for p in novelty_datasets if p.exists() and _has_smiles_column(p)]
74
+ missing_datasets = [p for p in novelty_datasets if not p.exists()]
75
+ invalid_datasets = [p for p in novelty_datasets if p.exists() and not _has_smiles_column(p)]
76
+
77
+ if include_virtual_novelty:
78
+ st.caption("Full novelty mode includes PI1M and may take significantly longer on the first run.")
79
+ else:
80
+ st.caption("Fast novelty mode checks EXP, MD, DFT, GC, and POLYINFO. PI1M is excluded by default for website responsiveness.")
81
 
82
  if missing_datasets:
83
  st.warning("Some novelty datasets are missing and were skipped.")
84
  for path in missing_datasets:
85
  st.write(f"- {path.name}")
86
 
87
+ if invalid_datasets:
88
+ st.warning("Some novelty datasets are malformed or missing a SMILES column and were skipped.")
89
+ for path in invalid_datasets:
90
+ st.write(f"- {path.name}")
91
+
92
  if not available_datasets:
93
  st.warning("No novelty datasets found. Results will only be de-duplicated within this run.")
94
 
95
  if st.button("Generate", type="primary"):
96
+ try:
97
+ with st.spinner("Loading RNN model (cached after first load)..."):
98
+ model, voc = load_rnn_model(ckpt_path, voc_path)
99
+ except Exception as exc:
100
+ st.error(f"Failed to load the RNN checkpoint: {exc}")
101
+ st.info("If you see a Git LFS pointer error, replace `models/rnn/pretrained_model/Prior.ckpt` with the real model file.")
102
+ st.stop()
103
 
104
  with st.spinner("Building novelty index (cached after first load)..."):
105
  existing_smiles = load_existing_smiles_set(tuple(str(p) for p in available_datasets)) if available_datasets else set()
 
157
  st.stop()
158
 
159
  result_df = pd.DataFrame({"SMILES": novel})
160
+ display_df = result_df.copy()
161
+ display_df.index = range(1, len(display_df) + 1)
162
+ st.dataframe(display_df, width="stretch")
163
  st.download_button(
164
  "Download CSV",
165
  data=result_df.to_csv(index=False).encode("utf-8"),