Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- pages/1_Property_Probe.py +16 -4
- pages/2_Batch_Prediction.py +71 -10
- pages/3_Molecular_View.py +10 -4
- pages/4_Discovery_(Manual).py +9 -3
- pages/5_Discovery_(AI).py +396 -133
- pages/6_Novel_SMILES_Generation.py +53 -15
pages/1_Property_Probe.py
CHANGED
|
@@ -11,14 +11,23 @@ from src.lookup import (
|
|
| 11 |
get_polyinfo,
|
| 12 |
)
|
| 13 |
from src.predictor_router import RouterPredictor
|
| 14 |
-
from src.ui_style import apply_global_style
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Property Probe", layout="wide")
|
| 17 |
apply_global_style()
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
predictor =
|
| 22 |
|
| 23 |
|
| 24 |
def resolve_smiles_from_polymer_name(db_obj, polymer_name_query: str) -> tuple[str | None, str | None]:
|
|
@@ -72,6 +81,8 @@ selected_label = st.selectbox("Select property", options)
|
|
| 72 |
prop = label_to_key[selected_label]
|
| 73 |
|
| 74 |
if st.button("Search", type="primary"):
|
|
|
|
|
|
|
| 75 |
if input_mode == "SMILES":
|
| 76 |
s_canon = canonicalize_smiles(query_value)
|
| 77 |
if s_canon is None:
|
|
@@ -137,4 +148,5 @@ if st.button("Search", type="primary"):
|
|
| 137 |
})
|
| 138 |
|
| 139 |
out = pd.DataFrame(rows)
|
|
|
|
| 140 |
st.table(out)
|
|
|
|
| 11 |
get_polyinfo,
|
| 12 |
)
|
| 13 |
from src.predictor_router import RouterPredictor
|
| 14 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Property Probe", layout="wide")
|
| 17 |
apply_global_style()
|
| 18 |
+
render_page_header(
|
| 19 |
+
title="Quick Polymer Property Check",
|
| 20 |
+
subtitle="Check one polymer at a time using source lookups plus ensemble ML prediction.",
|
| 21 |
+
badge="Property Probe",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@st.cache_resource(show_spinner=False)
|
| 26 |
+
def get_router_predictor() -> RouterPredictor:
|
| 27 |
+
return RouterPredictor(device="cpu")
|
| 28 |
|
| 29 |
+
|
| 30 |
+
predictor = get_router_predictor()
|
| 31 |
|
| 32 |
|
| 33 |
def resolve_smiles_from_polymer_name(db_obj, polymer_name_query: str) -> tuple[str | None, str | None]:
|
|
|
|
| 81 |
prop = label_to_key[selected_label]
|
| 82 |
|
| 83 |
if st.button("Search", type="primary"):
|
| 84 |
+
db = load_all_sources()
|
| 85 |
+
|
| 86 |
if input_mode == "SMILES":
|
| 87 |
s_canon = canonicalize_smiles(query_value)
|
| 88 |
if s_canon is None:
|
|
|
|
| 148 |
})
|
| 149 |
|
| 150 |
out = pd.DataFrame(rows)
|
| 151 |
+
out.index = range(1, len(out) + 1)
|
| 152 |
st.table(out)
|
pages/2_Batch_Prediction.py
CHANGED
|
@@ -2,17 +2,41 @@ import io
|
|
| 2 |
import pandas as pd
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
-
from src.lookup import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from src.predictor_router import RouterPredictor
|
| 7 |
-
from src.ui_style import apply_global_style
|
| 8 |
|
| 9 |
st.set_page_config(page_title="Batch Prediction", layout="wide")
|
| 10 |
apply_global_style()
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
MAX_RENDER_ROWS = 5000 # above this -> download only (no dataframe render)
|
|
|
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
# -----------------------------
|
|
@@ -82,7 +106,17 @@ for k in prop_keys:
|
|
| 82 |
label_to_key[label] = k
|
| 83 |
|
| 84 |
selected_labels = st.multiselect("Select properties to predict", options=prop_options)
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
st.divider()
|
| 88 |
|
|
@@ -129,14 +163,15 @@ else:
|
|
| 129 |
)
|
| 130 |
dataset_path = dataset[1]
|
| 131 |
|
| 132 |
-
#
|
| 133 |
-
st.caption("
|
| 134 |
|
| 135 |
pick_mode = st.radio("Row selection", options=["First N", "Random sample N"], horizontal=True)
|
| 136 |
mode = "first" if pick_mode == "First N" else "random"
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
|
|
|
| 140 |
|
| 141 |
n = st.number_input(
|
| 142 |
"How many SMILES to use",
|
|
@@ -169,6 +204,23 @@ if run:
|
|
| 169 |
st.stop()
|
| 170 |
|
| 171 |
props = [label_to_key[lbl] for lbl in selected_labels]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
# Decide whether to render table
|
| 174 |
render_table = len(smiles_list) <= MAX_RENDER_ROWS
|
|
@@ -204,6 +256,13 @@ if run:
|
|
| 204 |
if include_std:
|
| 205 |
row[col_name + " [std]"] = std
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
rows.append(row)
|
| 208 |
if total > 0:
|
| 209 |
progress.progress(int(100 * i / total))
|
|
@@ -224,7 +283,9 @@ if run:
|
|
| 224 |
# Render table only if safe
|
| 225 |
if render_table:
|
| 226 |
st.subheader("Predictions")
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
|
| 229 |
# Download
|
| 230 |
csv_bytes = out_df.to_csv(index=False).encode("utf-8")
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
+
from src.lookup import (
|
| 6 |
+
PROPERTY_META,
|
| 7 |
+
SOURCES,
|
| 8 |
+
SOURCE_LABELS,
|
| 9 |
+
canonicalize_smiles,
|
| 10 |
+
get_value,
|
| 11 |
+
load_all_sources,
|
| 12 |
+
)
|
| 13 |
from src.predictor_router import RouterPredictor
|
| 14 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Batch Prediction", layout="wide")
|
| 17 |
apply_global_style()
|
| 18 |
+
render_page_header(
|
| 19 |
+
title="Bulk Polymer Property Prediction",
|
| 20 |
+
subtitle="Predict multiple target properties for large candidate sets with downloadable results.",
|
| 21 |
+
badge="Batch Prediction",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@st.cache_resource(show_spinner=False)
|
| 26 |
+
def get_router_predictor() -> RouterPredictor:
|
| 27 |
+
return RouterPredictor(device="cpu")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
predictor = get_router_predictor()
|
| 31 |
|
| 32 |
+
|
| 33 |
+
@st.cache_resource(show_spinner=False)
|
| 34 |
+
def get_lookup_db():
|
| 35 |
+
return load_all_sources()
|
| 36 |
|
| 37 |
MAX_RENDER_ROWS = 5000 # above this -> download only (no dataframe render)
|
| 38 |
+
MAX_BATCH_SMILES = 3000
|
| 39 |
+
MAX_BATCH_PREDICTIONS = 25000
|
| 40 |
|
| 41 |
|
| 42 |
# -----------------------------
|
|
|
|
| 106 |
label_to_key[label] = k
|
| 107 |
|
| 108 |
selected_labels = st.multiselect("Select properties to predict", options=prop_options)
|
| 109 |
+
opt_col1, opt_col2 = st.columns([1, 2])
|
| 110 |
+
with opt_col1:
|
| 111 |
+
include_std = st.checkbox("Include model std (ensemble spread)", value=False)
|
| 112 |
+
with opt_col2:
|
| 113 |
+
selected_source_labels = st.multiselect(
|
| 114 |
+
"Include source database values",
|
| 115 |
+
options=[SOURCE_LABELS.get(src, src) for src in SOURCES],
|
| 116 |
+
placeholder="Select Experiment, MD, DFT, and/or GC",
|
| 117 |
+
)
|
| 118 |
+
source_label_to_key = {SOURCE_LABELS.get(src, src): src for src in SOURCES}
|
| 119 |
+
selected_sources = [source_label_to_key[label] for label in selected_source_labels]
|
| 120 |
|
| 121 |
st.divider()
|
| 122 |
|
|
|
|
| 163 |
)
|
| 164 |
dataset_path = dataset[1]
|
| 165 |
|
| 166 |
+
# Website-safe cap: render mode is not enough, inference itself must stay bounded.
|
| 167 |
+
st.caption("Website-safe limits apply. Large jobs should be run offline rather than in the live app.")
|
| 168 |
|
| 169 |
pick_mode = st.radio("Row selection", options=["First N", "Random sample N"], horizontal=True)
|
| 170 |
mode = "first" if pick_mode == "First N" else "random"
|
| 171 |
|
| 172 |
+
is_virtual_pi1m = dataset_path.endswith("PI1M.csv")
|
| 173 |
+
default_n = 1000 if is_virtual_pi1m else 2000
|
| 174 |
+
max_n = MAX_BATCH_SMILES
|
| 175 |
|
| 176 |
n = st.number_input(
|
| 177 |
"How many SMILES to use",
|
|
|
|
| 204 |
st.stop()
|
| 205 |
|
| 206 |
props = [label_to_key[lbl] for lbl in selected_labels]
|
| 207 |
+
lookup_db = get_lookup_db() if selected_sources else None
|
| 208 |
+
requested_smiles = len(smiles_list)
|
| 209 |
+
prediction_cells = requested_smiles * len(props)
|
| 210 |
+
|
| 211 |
+
if requested_smiles > MAX_BATCH_SMILES:
|
| 212 |
+
st.error(
|
| 213 |
+
f"This website currently limits Batch Prediction to {MAX_BATCH_SMILES:,} SMILES per run. "
|
| 214 |
+
"Use a smaller subset or run larger jobs offline."
|
| 215 |
+
)
|
| 216 |
+
st.stop()
|
| 217 |
+
|
| 218 |
+
if prediction_cells > MAX_BATCH_PREDICTIONS:
|
| 219 |
+
st.error(
|
| 220 |
+
f"This request would run {prediction_cells:,} model predictions, which exceeds the website-safe limit "
|
| 221 |
+
f"of {MAX_BATCH_PREDICTIONS:,}. Reduce the number of SMILES or selected properties."
|
| 222 |
+
)
|
| 223 |
+
st.stop()
|
| 224 |
|
| 225 |
# Decide whether to render table
|
| 226 |
render_table = len(smiles_list) <= MAX_RENDER_ROWS
|
|
|
|
| 256 |
if include_std:
|
| 257 |
row[col_name + " [std]"] = std
|
| 258 |
|
| 259 |
+
if lookup_db is not None:
|
| 260 |
+
for src in selected_sources:
|
| 261 |
+
src_label = SOURCE_LABELS.get(src, src)
|
| 262 |
+
src_col = f"{col_name} [{src_label}]"
|
| 263 |
+
val = get_value(lookup_db, src, s_canon, prop)
|
| 264 |
+
row[src_col] = float("nan") if val is None else val
|
| 265 |
+
|
| 266 |
rows.append(row)
|
| 267 |
if total > 0:
|
| 268 |
progress.progress(int(100 * i / total))
|
|
|
|
| 283 |
# Render table only if safe
|
| 284 |
if render_table:
|
| 285 |
st.subheader("Predictions")
|
| 286 |
+
display_df = out_df.copy()
|
| 287 |
+
display_df.index = range(1, len(display_df) + 1)
|
| 288 |
+
st.dataframe(display_df, width="stretch")
|
| 289 |
|
| 290 |
# Download
|
| 291 |
csv_bytes = out_df.to_csv(index=False).encode("utf-8")
|
pages/3_Molecular_View.py
CHANGED
|
@@ -10,13 +10,17 @@ from streamlit.components.v1 import html
|
|
| 10 |
from rdkit.Chem import Lipinski, Crippen
|
| 11 |
from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcExactMolWt, CalcFractionCSP3, CalcNumRings, CalcNumAromaticRings
|
| 12 |
|
| 13 |
-
from src.ui_style import apply_global_style
|
| 14 |
|
| 15 |
RDLogger.DisableLog("rdApp.*")
|
| 16 |
|
| 17 |
st.set_page_config(page_title="Molecular View", layout="wide")
|
| 18 |
apply_global_style()
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# -------------------------
|
| 22 |
# Polymer-safe helpers
|
|
@@ -313,7 +317,7 @@ with top_left:
|
|
| 313 |
|
| 314 |
with top_right:
|
| 315 |
st.markdown("Molecule Information ")
|
| 316 |
-
|
| 317 |
{
|
| 318 |
"Property": ["Formula", "Molar Weight", "Atoms"],
|
| 319 |
"Value": [
|
|
@@ -323,6 +327,8 @@ with top_right:
|
|
| 323 |
],
|
| 324 |
}
|
| 325 |
)
|
|
|
|
|
|
|
| 326 |
|
| 327 |
# MOL download *below the table*
|
| 328 |
if mol_block_3d is not None:
|
|
@@ -366,4 +372,4 @@ with bottom_right:
|
|
| 366 |
|
| 367 |
# Legend: include hydrogens + colored dots
|
| 368 |
# Use capped mol (no '*') for clean element counting
|
| 369 |
-
render_element_legend_with_colors(mol_cap, include_hydrogens=True)
|
|
|
|
| 10 |
from rdkit.Chem import Lipinski, Crippen
|
| 11 |
from rdkit.Chem.rdMolDescriptors import CalcTPSA, CalcExactMolWt, CalcFractionCSP3, CalcNumRings, CalcNumAromaticRings
|
| 12 |
|
| 13 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 14 |
|
| 15 |
RDLogger.DisableLog("rdApp.*")
|
| 16 |
|
| 17 |
st.set_page_config(page_title="Molecular View", layout="wide")
|
| 18 |
apply_global_style()
|
| 19 |
+
render_page_header(
|
| 20 |
+
title="Molecular Structure View",
|
| 21 |
+
subtitle="Inspect 2D and 3D polymer structures and review repeat-unit descriptors.",
|
| 22 |
+
badge="Molecular View",
|
| 23 |
+
)
|
| 24 |
|
| 25 |
# -------------------------
|
| 26 |
# Polymer-safe helpers
|
|
|
|
| 317 |
|
| 318 |
with top_right:
|
| 319 |
st.markdown("Molecule Information ")
|
| 320 |
+
info_df = pd.DataFrame(
|
| 321 |
{
|
| 322 |
"Property": ["Formula", "Molar Weight", "Atoms"],
|
| 323 |
"Value": [
|
|
|
|
| 327 |
],
|
| 328 |
}
|
| 329 |
)
|
| 330 |
+
info_df.index = range(1, len(info_df) + 1)
|
| 331 |
+
st.table(info_df)
|
| 332 |
|
| 333 |
# MOL download *below the table*
|
| 334 |
if mol_block_3d is not None:
|
|
|
|
| 372 |
|
| 373 |
# Legend: include hydrogens + colored dots
|
| 374 |
# Use capped mol (no '*') for clean element counting
|
| 375 |
+
render_element_legend_with_colors(mol_cap, include_hydrogens=True)
|
pages/4_Discovery_(Manual).py
CHANGED
|
@@ -13,11 +13,15 @@ import streamlit as st
|
|
| 13 |
|
| 14 |
from src.discovery import run_discovery, spec_from_dict
|
| 15 |
from src.lookup import PROPERTY_META
|
| 16 |
-
from src.ui_style import apply_global_style
|
| 17 |
|
| 18 |
st.set_page_config(page_title="Discovery (Manual)", layout="wide")
|
| 19 |
apply_global_style()
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
# ----------------------------
|
| 23 |
# Files
|
|
@@ -699,7 +703,9 @@ if st.session_state.get("discovery_done"):
|
|
| 699 |
meta = PROPERTY_META[prop_key]
|
| 700 |
rename_map[c] = f"{meta['name']} ({meta['unit']})"
|
| 701 |
preview_df = preview_df.rename(columns=rename_map)
|
| 702 |
-
|
|
|
|
|
|
|
| 703 |
|
| 704 |
st.subheader("📥 Download")
|
| 705 |
buf = io.StringIO()
|
|
|
|
| 13 |
|
| 14 |
from src.discovery import run_discovery, spec_from_dict
|
| 15 |
from src.lookup import PROPERTY_META
|
| 16 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 17 |
|
| 18 |
st.set_page_config(page_title="Discovery (Manual)", layout="wide")
|
| 19 |
apply_global_style()
|
| 20 |
+
render_page_header(
|
| 21 |
+
title="Manual Multi-Objective Discovery",
|
| 22 |
+
subtitle="Tune objectives and constraints directly to explore Pareto-optimal polymer candidates.",
|
| 23 |
+
badge="Discovery (Manual)",
|
| 24 |
+
)
|
| 25 |
|
| 26 |
# ----------------------------
|
| 27 |
# Files
|
|
|
|
| 703 |
meta = PROPERTY_META[prop_key]
|
| 704 |
rename_map[c] = f"{meta['name']} ({meta['unit']})"
|
| 705 |
preview_df = preview_df.rename(columns=rename_map)
|
| 706 |
+
preview_display = preview_df.head(50).copy()
|
| 707 |
+
preview_display.index = range(1, len(preview_display) + 1)
|
| 708 |
+
st.dataframe(preview_display, width="stretch")
|
| 709 |
|
| 710 |
st.subheader("📥 Download")
|
| 711 |
buf = io.StringIO()
|
pages/5_Discovery_(AI).py
CHANGED
|
@@ -8,6 +8,7 @@ import threading
|
|
| 8 |
import time
|
| 9 |
import urllib.request
|
| 10 |
import urllib.error
|
|
|
|
| 11 |
import zipfile
|
| 12 |
from pathlib import Path
|
| 13 |
|
|
@@ -17,11 +18,15 @@ import streamlit as st
|
|
| 17 |
from streamlit.components.v1 import html
|
| 18 |
|
| 19 |
from src.discover_llm import PROPERTY_META, run_discovery, spec_from_dict
|
| 20 |
-
from src.ui_style import apply_global_style
|
| 21 |
|
| 22 |
st.set_page_config(page_title="DISCOVERY (AI)", layout="wide")
|
| 23 |
apply_global_style()
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# ----------------------------
|
| 27 |
# Files
|
|
@@ -307,7 +312,236 @@ def get_webui_base_url() -> str:
|
|
| 307 |
).rstrip("/")
|
| 308 |
|
| 309 |
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
"""Return None when credentials are usable, else an error message."""
|
| 312 |
k = str(api_key or "").strip()
|
| 313 |
u = str(base_url or "").strip().rstrip("/")
|
|
@@ -315,8 +549,23 @@ def validate_api_access(api_key: str, base_url: str) -> str | None:
|
|
| 315 |
return "API key is required."
|
| 316 |
if not u.startswith("https://"):
|
| 317 |
return "API base URL must start with `https://`."
|
|
|
|
|
|
|
| 318 |
try:
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
except Exception as e:
|
| 321 |
return f"API key validation failed: {e}"
|
| 322 |
return None
|
|
@@ -326,40 +575,37 @@ def clear_byok_key() -> None:
|
|
| 326 |
st.session_state["discover_llm_byok_key"] = ""
|
| 327 |
|
| 328 |
|
| 329 |
-
def
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
headers={
|
| 335 |
-
"Authorization": f"Bearer {api_key}",
|
| 336 |
-
"Content-Type": "application/json",
|
| 337 |
-
},
|
| 338 |
-
method=("POST" if payload is not None else "GET"),
|
| 339 |
-
)
|
| 340 |
-
try:
|
| 341 |
-
with urllib.request.urlopen(req, timeout=60) as resp:
|
| 342 |
-
return json.loads(resp.read().decode("utf-8"))
|
| 343 |
-
except urllib.error.HTTPError as e:
|
| 344 |
-
detail = e.read().decode("utf-8", errors="ignore")
|
| 345 |
-
raise RuntimeError(f"WebUI API HTTP {e.code}: {detail}") from e
|
| 346 |
-
except Exception as e:
|
| 347 |
-
raise RuntimeError(f"WebUI API call failed: {e}") from e
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
def list_available_models(api_key: str | None = None, base_url: str | None = None) -> list[str]:
|
| 351 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 352 |
if not api_key:
|
| 353 |
return []
|
| 354 |
base_url = (base_url or get_webui_base_url()).rstrip("/")
|
| 355 |
-
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
if not isinstance(items, list):
|
| 358 |
return []
|
| 359 |
out = []
|
| 360 |
for m in items:
|
| 361 |
if isinstance(m, dict):
|
| 362 |
mid = str(m.get("id", m.get("name", ""))).strip()
|
|
|
|
|
|
|
| 363 |
else:
|
| 364 |
mid = str(m).strip()
|
| 365 |
if mid:
|
|
@@ -368,7 +614,11 @@ def list_available_models(api_key: str | None = None, base_url: str | None = Non
|
|
| 368 |
|
| 369 |
|
| 370 |
def generate_spec_from_llm(
|
| 371 |
-
user_query: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
) -> dict:
|
| 373 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 374 |
if not api_key:
|
|
@@ -401,20 +651,15 @@ def generate_spec_from_llm(
|
|
| 401 |
user_prompt = (
|
| 402 |
"User request:\n" + user_query.strip()
|
| 403 |
)
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
try:
|
| 415 |
-
content = raw["choices"][0]["message"]["content"]
|
| 416 |
-
except Exception:
|
| 417 |
-
raise RuntimeError("Unexpected LLM response format.")
|
| 418 |
|
| 419 |
try:
|
| 420 |
parsed = extract_first_json_object(content)
|
|
@@ -506,11 +751,6 @@ def render_copyable_prompt(prompt_text: str, box_height: int = 220) -> None:
|
|
| 506 |
html(snippet, height=box_height + 54)
|
| 507 |
|
| 508 |
|
| 509 |
-
@st.cache_data(ttl=300, show_spinner=False)
|
| 510 |
-
def list_available_models_cached() -> list[str]:
|
| 511 |
-
return list_available_models()
|
| 512 |
-
|
| 513 |
-
|
| 514 |
def _local_reasoning_fallback(spec_obj: dict, stats: dict) -> str:
|
| 515 |
objectives = spec_obj.get("objectives", []) if isinstance(spec_obj, dict) else []
|
| 516 |
constraints = spec_obj.get("hard_constraints", {}) if isinstance(spec_obj, dict) else {}
|
|
@@ -599,6 +839,7 @@ def generate_selection_reasoning(
|
|
| 599 |
model: str,
|
| 600 |
api_key: str | None = None,
|
| 601 |
base_url: str | None = None,
|
|
|
|
| 602 |
) -> str:
|
| 603 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 604 |
if not api_key:
|
|
@@ -674,19 +915,15 @@ def generate_selection_reasoning(
|
|
| 674 |
"You can add brief clarifying bullets if helpful, but keep it concise and focused.\n\n"
|
| 675 |
f"INPUT:\n{json.dumps(user_payload, indent=2)}"
|
| 676 |
)
|
| 677 |
-
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
content = raw["choices"][0]["message"]["content"]
|
| 687 |
-
return str(content).strip()
|
| 688 |
-
except Exception:
|
| 689 |
-
raise RuntimeError("Unexpected LLM response format for reasoning.")
|
| 690 |
|
| 691 |
|
| 692 |
def pareto_publication_plot(plot_df: pd.DataFrame, obj_props: list[str]):
|
|
@@ -1011,13 +1248,17 @@ if "discover_llm_query_text" not in st.session_state:
|
|
| 1011 |
if "discover_llm_last_example_choice" not in st.session_state:
|
| 1012 |
st.session_state["discover_llm_last_example_choice"] = "Select an example prompt…"
|
| 1013 |
if "discover_llm_mode" not in st.session_state:
|
| 1014 |
-
st.session_state["discover_llm_mode"] = "
|
| 1015 |
if "discover_llm_external_response" not in st.session_state:
|
| 1016 |
st.session_state["discover_llm_external_response"] = ""
|
| 1017 |
if "discover_llm_byok_key" not in st.session_state:
|
| 1018 |
st.session_state["discover_llm_byok_key"] = ""
|
| 1019 |
if "discover_llm_byok_base_url" not in st.session_state:
|
| 1020 |
-
st.session_state["discover_llm_byok_base_url"] =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
|
| 1022 |
# Apply deferred JSON updates before any JSON editor widget is instantiated.
|
| 1023 |
pending_spec_text = st.session_state.get("discover_llm_spec_text_next")
|
|
@@ -1049,7 +1290,7 @@ with st.container(border=True):
|
|
| 1049 |
)
|
| 1050 |
mode = st.radio(
|
| 1051 |
"LLM setup",
|
| 1052 |
-
options=["
|
| 1053 |
key="discover_llm_mode",
|
| 1054 |
horizontal=True,
|
| 1055 |
)
|
|
@@ -1058,75 +1299,73 @@ external_response_text = st.session_state.get("discover_llm_external_response",
|
|
| 1058 |
selected_model = "external-copy-paste"
|
| 1059 |
active_api_key = ""
|
| 1060 |
active_base_url = get_webui_base_url()
|
|
|
|
| 1061 |
api_config_invalid = False
|
| 1062 |
-
default_model = (
|
| 1063 |
-
get_config_value("CRC_OPENWEBUI_MODEL", "")
|
| 1064 |
-
or get_config_value("OPENWEBUI_MODEL", "")
|
| 1065 |
-
or get_config_value("OPENAI_MODEL", "")
|
| 1066 |
-
or "gpt-oss:latest"
|
| 1067 |
-
)
|
| 1068 |
|
| 1069 |
-
if mode
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
| 1095 |
-
|
| 1096 |
-
|
| 1097 |
-
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1102 |
|
| 1103 |
available_models: list[str] = []
|
| 1104 |
models_error = ""
|
| 1105 |
if active_api_key and not api_config_invalid:
|
| 1106 |
try:
|
| 1107 |
-
|
| 1108 |
-
available_models = list_available_models_cached()
|
| 1109 |
-
else:
|
| 1110 |
-
available_models = list_available_models(active_api_key, active_base_url)
|
| 1111 |
except Exception as e:
|
| 1112 |
models_error = str(e)
|
| 1113 |
|
| 1114 |
if available_models:
|
| 1115 |
-
model_index = available_models.index(
|
| 1116 |
-
selected_model =
|
| 1117 |
-
|
| 1118 |
-
options=available_models,
|
| 1119 |
-
index=model_index,
|
| 1120 |
-
help="Model used only to translate your natural language request into JSON.",
|
| 1121 |
-
)
|
| 1122 |
else:
|
| 1123 |
if models_error:
|
| 1124 |
-
st.warning(f"Could not load model list from API.
|
| 1125 |
-
selected_model =
|
| 1126 |
-
|
| 1127 |
-
value=default_model,
|
| 1128 |
-
help="Use a valid model id from your CRC Open WebUI instance (for example `gpt-oss:latest`).",
|
| 1129 |
-
)
|
| 1130 |
else:
|
| 1131 |
with st.container(border=True):
|
| 1132 |
st.caption(
|
|
@@ -1148,7 +1387,7 @@ generate_json_btn = False
|
|
| 1148 |
if show_json_editor:
|
| 1149 |
generate_json_btn = st.button(
|
| 1150 |
"Generate JSON from LLM"
|
| 1151 |
-
if mode
|
| 1152 |
else "Generate JSON from pasted response"
|
| 1153 |
)
|
| 1154 |
|
|
@@ -1554,7 +1793,11 @@ def _build_runnable_spec(raw_obj: dict) -> tuple[dict, list[str], list[str]]:
|
|
| 1554 |
|
| 1555 |
|
| 1556 |
def _raw_spec_from_prompt(
|
| 1557 |
-
user_query: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1558 |
) -> tuple[dict, list[str], str | None]:
|
| 1559 |
notes: list[str] = []
|
| 1560 |
extracted = {}
|
|
@@ -1562,7 +1805,13 @@ def _raw_spec_from_prompt(
|
|
| 1562 |
return {}, notes, "Please provide a prompt before generating or running discovery."
|
| 1563 |
with st.spinner("Interpreting prompt and preparing discovery config..."):
|
| 1564 |
try:
|
| 1565 |
-
extracted = generate_spec_from_llm(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1566 |
except Exception as e:
|
| 1567 |
return {}, notes, f"LLM generation failed: {e}"
|
| 1568 |
|
|
@@ -1629,17 +1878,21 @@ def _raw_spec_from_external_response(user_query: str, response_text: str) -> tup
|
|
| 1629 |
|
| 1630 |
|
| 1631 |
if show_json_editor and generate_json_btn:
|
| 1632 |
-
if mode
|
| 1633 |
st.error("Please provide a prompt before generating JSON.")
|
| 1634 |
st.stop()
|
| 1635 |
if mode == "Bring Your Own Key":
|
| 1636 |
-
byok_err = validate_api_access(active_api_key, active_base_url)
|
| 1637 |
if byok_err:
|
| 1638 |
st.error(f"BYOK validation failed: {byok_err}")
|
| 1639 |
st.stop()
|
| 1640 |
-
if mode
|
| 1641 |
raw_spec_obj, prep_notes, parse_error = _raw_spec_from_prompt(
|
| 1642 |
-
llm_query,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1643 |
)
|
| 1644 |
if parse_error:
|
| 1645 |
for msg in prep_notes:
|
|
@@ -1665,11 +1918,11 @@ if show_json_editor and generate_json_btn:
|
|
| 1665 |
run_btn = st.button("Run discovery", type="primary")
|
| 1666 |
|
| 1667 |
if run_btn:
|
| 1668 |
-
if mode
|
| 1669 |
st.error("Please provide a prompt before running discovery.")
|
| 1670 |
st.stop()
|
| 1671 |
if mode == "Bring Your Own Key":
|
| 1672 |
-
byok_err = validate_api_access(active_api_key, active_base_url)
|
| 1673 |
if byok_err:
|
| 1674 |
st.error(f"BYOK validation failed: {byok_err}")
|
| 1675 |
st.stop()
|
|
@@ -1686,9 +1939,13 @@ if run_btn:
|
|
| 1686 |
raw_spec_obj = {}
|
| 1687 |
prep_notes.append("Invalid JSON detected. Using fixed template defaults.")
|
| 1688 |
else:
|
| 1689 |
-
if mode
|
| 1690 |
raw_spec_obj, llm_notes, parse_error = _raw_spec_from_prompt(
|
| 1691 |
-
llm_query,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1692 |
)
|
| 1693 |
if parse_error:
|
| 1694 |
for msg in llm_notes:
|
|
@@ -1728,6 +1985,7 @@ if run_btn:
|
|
| 1728 |
st.session_state["discovery_mode_used"] = mode
|
| 1729 |
st.session_state["discovery_api_key"] = active_api_key if mode == "Bring Your Own Key" else ""
|
| 1730 |
st.session_state["discovery_api_base_url"] = active_base_url if mode == "Bring Your Own Key" else ""
|
|
|
|
| 1731 |
st.session_state["discovery_reasoning_text"] = None
|
| 1732 |
st.session_state["discovery_reasoning_key"] = None
|
| 1733 |
st.session_state["discovery_reasoning_note"] = None
|
|
@@ -1799,13 +2057,15 @@ if st.session_state.get("discovery_done"):
|
|
| 1799 |
c3.metric("Pareto pool", int(stats.get("n_pareto_pool", 0)))
|
| 1800 |
c4.metric("Selected", int(stats.get("n_selected", 0)))
|
| 1801 |
|
| 1802 |
-
if mode_used
|
| 1803 |
reasoning_api_key = st.session_state.get("discovery_api_key", "")
|
| 1804 |
reasoning_api_base_url = st.session_state.get("discovery_api_base_url", "")
|
|
|
|
| 1805 |
reasoning_key_obj = {
|
| 1806 |
"spec": resolved_spec,
|
| 1807 |
"model": model_used,
|
| 1808 |
"mode": mode_used,
|
|
|
|
| 1809 |
"selected_smiles_head": (
|
| 1810 |
out_df["SMILES"].astype(str).head(20).tolist()
|
| 1811 |
if isinstance(out_df, pd.DataFrame) and "SMILES" in out_df.columns
|
|
@@ -1826,6 +2086,7 @@ if st.session_state.get("discovery_done"):
|
|
| 1826 |
model_used,
|
| 1827 |
api_key=(str(reasoning_api_key).strip() or None),
|
| 1828 |
base_url=(str(reasoning_api_base_url).strip() or None),
|
|
|
|
| 1829 |
)
|
| 1830 |
st.session_state["discovery_reasoning_note"] = None
|
| 1831 |
except Exception as e:
|
|
@@ -1869,7 +2130,9 @@ if st.session_state.get("discovery_done"):
|
|
| 1869 |
meta = PROPERTY_META[prop_key]
|
| 1870 |
rename_map[c] = f"{meta['name']} ({meta['unit']})"
|
| 1871 |
preview_df = preview_df.rename(columns=rename_map)
|
| 1872 |
-
|
|
|
|
|
|
|
| 1873 |
|
| 1874 |
st.subheader("📥 Download")
|
| 1875 |
buf = io.StringIO()
|
|
|
|
| 8 |
import time
|
| 9 |
import urllib.request
|
| 10 |
import urllib.error
|
| 11 |
+
import urllib.parse
|
| 12 |
import zipfile
|
| 13 |
from pathlib import Path
|
| 14 |
|
|
|
|
| 18 |
from streamlit.components.v1 import html
|
| 19 |
|
| 20 |
from src.discover_llm import PROPERTY_META, run_discovery, spec_from_dict
|
| 21 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 22 |
|
| 23 |
st.set_page_config(page_title="DISCOVERY (AI)", layout="wide")
|
| 24 |
apply_global_style()
|
| 25 |
+
render_page_header(
|
| 26 |
+
title="AI-Driven Multi-Objective Discovery",
|
| 27 |
+
subtitle="Describe target behavior in plain language and run auto-configured multi-objective search.",
|
| 28 |
+
badge="Discovery (AI)",
|
| 29 |
+
)
|
| 30 |
|
| 31 |
# ----------------------------
|
| 32 |
# Files
|
|
|
|
| 312 |
).rstrip("/")
|
| 313 |
|
| 314 |
|
| 315 |
+
PROVIDER_LABELS = {
|
| 316 |
+
"auto": "Auto detect",
|
| 317 |
+
"openwebui": "OpenWebUI",
|
| 318 |
+
"openai_compatible": "OpenAI-compatible",
|
| 319 |
+
"anthropic": "Anthropic",
|
| 320 |
+
"gemini": "Gemini",
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
PROVIDER_OPTIONS = list(PROVIDER_LABELS.keys())
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _provider_label(provider: str) -> str:
|
| 327 |
+
return PROVIDER_LABELS.get(provider, provider)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def default_model_for_provider(provider: str) -> str:
|
| 331 |
+
p = _normalize_provider(provider)
|
| 332 |
+
if p == "openwebui":
|
| 333 |
+
return (
|
| 334 |
+
get_config_value("CRC_OPENWEBUI_MODEL", "")
|
| 335 |
+
or get_config_value("OPENWEBUI_MODEL", "")
|
| 336 |
+
or get_config_value("OPENAI_MODEL", "")
|
| 337 |
+
or "gpt-oss:latest"
|
| 338 |
+
)
|
| 339 |
+
if p == "openai_compatible":
|
| 340 |
+
return (
|
| 341 |
+
get_config_value("OPENAI_MODEL", "")
|
| 342 |
+
or get_config_value("OPENWEBUI_MODEL", "")
|
| 343 |
+
or get_config_value("CRC_OPENWEBUI_MODEL", "")
|
| 344 |
+
or "gpt-4o-mini"
|
| 345 |
+
)
|
| 346 |
+
if p == "anthropic":
|
| 347 |
+
return get_config_value("ANTHROPIC_MODEL", "") or "claude-3-5-sonnet-latest"
|
| 348 |
+
if p == "gemini":
|
| 349 |
+
return get_config_value("GEMINI_MODEL", "") or "gemini-2.0-flash"
|
| 350 |
+
return get_config_value("OPENAI_MODEL", "") or "gpt-4o-mini"
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def _normalize_provider(provider: str | None) -> str:
|
| 354 |
+
s = str(provider or "").strip().lower().replace("-", "_").replace(" ", "_")
|
| 355 |
+
if s in PROVIDER_LABELS:
|
| 356 |
+
return s
|
| 357 |
+
return "auto"
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def detect_api_provider(base_url: str) -> str:
|
| 361 |
+
u = str(base_url or "").strip().lower()
|
| 362 |
+
if "openwebui" in u:
|
| 363 |
+
return "openwebui"
|
| 364 |
+
if "anthropic.com" in u:
|
| 365 |
+
return "anthropic"
|
| 366 |
+
if "generativelanguage.googleapis.com" in u or "googleapis.com" in u:
|
| 367 |
+
return "gemini"
|
| 368 |
+
if "api.openai.com" in u or "/v1" in u or "openrouter.ai" in u:
|
| 369 |
+
return "openai_compatible"
|
| 370 |
+
return "openai_compatible"
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def resolve_api_provider(base_url: str, provider: str | None = None) -> str:
|
| 374 |
+
p = _normalize_provider(provider)
|
| 375 |
+
if p == "auto":
|
| 376 |
+
return detect_api_provider(base_url)
|
| 377 |
+
return p
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def _provider_root(base_url: str, provider: str) -> str:
|
| 381 |
+
u = str(base_url or "").strip().rstrip("/")
|
| 382 |
+
if provider == "openwebui":
|
| 383 |
+
return u
|
| 384 |
+
if provider == "openai_compatible":
|
| 385 |
+
return u if u.endswith("/v1") else f"{u}/v1"
|
| 386 |
+
if provider == "anthropic":
|
| 387 |
+
return u if u.endswith("/v1") else f"{u}/v1"
|
| 388 |
+
if provider == "gemini":
|
| 389 |
+
if u.endswith("/v1") or u.endswith("/v1beta"):
|
| 390 |
+
return u
|
| 391 |
+
return f"{u}/v1beta"
|
| 392 |
+
return u
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def _join_url(base_url: str, path: str) -> str:
|
| 396 |
+
return f"{base_url.rstrip('/')}{path}"
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def _http_json_request(
|
| 400 |
+
url: str,
|
| 401 |
+
headers: dict[str, str] | None = None,
|
| 402 |
+
payload: dict | None = None,
|
| 403 |
+
method: str | None = None,
|
| 404 |
+
timeout: int = 60,
|
| 405 |
+
) -> dict:
|
| 406 |
+
req = urllib.request.Request(
|
| 407 |
+
url=url,
|
| 408 |
+
data=(json.dumps(payload).encode("utf-8") if payload is not None else None),
|
| 409 |
+
headers=(headers or {}),
|
| 410 |
+
method=(method or ("POST" if payload is not None else "GET")),
|
| 411 |
+
)
|
| 412 |
+
try:
|
| 413 |
+
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
| 414 |
+
return json.loads(resp.read().decode("utf-8"))
|
| 415 |
+
except urllib.error.HTTPError as e:
|
| 416 |
+
detail = e.read().decode("utf-8", errors="ignore")
|
| 417 |
+
raise RuntimeError(f"HTTP {e.code}: {detail}") from e
|
| 418 |
+
except Exception as e:
|
| 419 |
+
raise RuntimeError(str(e)) from e
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def _flatten_text_content(content) -> str:
|
| 423 |
+
if isinstance(content, str):
|
| 424 |
+
return content.strip()
|
| 425 |
+
if isinstance(content, list):
|
| 426 |
+
parts = []
|
| 427 |
+
for item in content:
|
| 428 |
+
if isinstance(item, str):
|
| 429 |
+
parts.append(item)
|
| 430 |
+
elif isinstance(item, dict):
|
| 431 |
+
txt = str(item.get("text", "")).strip()
|
| 432 |
+
if txt:
|
| 433 |
+
parts.append(txt)
|
| 434 |
+
return "\n".join(p for p in parts if p).strip()
|
| 435 |
+
return str(content or "").strip()
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def provider_request(
|
| 439 |
+
base_url: str,
|
| 440 |
+
api_key: str,
|
| 441 |
+
provider: str,
|
| 442 |
+
path: str,
|
| 443 |
+
payload: dict | None = None,
|
| 444 |
+
) -> dict:
|
| 445 |
+
root = _provider_root(base_url, provider)
|
| 446 |
+
headers = {"Content-Type": "application/json"}
|
| 447 |
+
url = _join_url(root, path)
|
| 448 |
+
|
| 449 |
+
if provider in {"openwebui", "openai_compatible"}:
|
| 450 |
+
headers["Authorization"] = f"Bearer {api_key}"
|
| 451 |
+
elif provider == "anthropic":
|
| 452 |
+
headers["x-api-key"] = api_key
|
| 453 |
+
headers["anthropic-version"] = "2023-06-01"
|
| 454 |
+
elif provider == "gemini":
|
| 455 |
+
sep = "&" if "?" in url else "?"
|
| 456 |
+
url = f"{url}{sep}key={urllib.parse.quote(api_key, safe='')}"
|
| 457 |
+
|
| 458 |
+
try:
|
| 459 |
+
return _http_json_request(url, headers=headers, payload=payload)
|
| 460 |
+
except Exception as e:
|
| 461 |
+
raise RuntimeError(f"{_provider_label(provider)} API call failed: {e}") from e
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def chat_text_request(
|
| 465 |
+
base_url: str,
|
| 466 |
+
api_key: str,
|
| 467 |
+
provider: str,
|
| 468 |
+
model: str,
|
| 469 |
+
system_prompt: str,
|
| 470 |
+
user_prompt: str,
|
| 471 |
+
max_tokens: int = 1024,
|
| 472 |
+
) -> str:
|
| 473 |
+
provider = resolve_api_provider(base_url, provider)
|
| 474 |
+
if provider in {"openwebui", "openai_compatible"}:
|
| 475 |
+
raw = provider_request(
|
| 476 |
+
base_url,
|
| 477 |
+
api_key,
|
| 478 |
+
provider,
|
| 479 |
+
"/chat/completions" if provider == "openai_compatible" else "/api/chat/completions",
|
| 480 |
+
payload={
|
| 481 |
+
"model": model,
|
| 482 |
+
"messages": [
|
| 483 |
+
{"role": "system", "content": system_prompt},
|
| 484 |
+
{"role": "user", "content": user_prompt},
|
| 485 |
+
],
|
| 486 |
+
},
|
| 487 |
+
)
|
| 488 |
+
try:
|
| 489 |
+
return _flatten_text_content(raw["choices"][0]["message"]["content"])
|
| 490 |
+
except Exception as e:
|
| 491 |
+
raise RuntimeError("Unexpected chat-completions response format.") from e
|
| 492 |
+
|
| 493 |
+
if provider == "anthropic":
|
| 494 |
+
raw = provider_request(
|
| 495 |
+
base_url,
|
| 496 |
+
api_key,
|
| 497 |
+
provider,
|
| 498 |
+
"/messages",
|
| 499 |
+
payload={
|
| 500 |
+
"model": model,
|
| 501 |
+
"system": system_prompt,
|
| 502 |
+
"max_tokens": int(max_tokens),
|
| 503 |
+
"messages": [{"role": "user", "content": user_prompt}],
|
| 504 |
+
},
|
| 505 |
+
)
|
| 506 |
+
try:
|
| 507 |
+
return "\n".join(
|
| 508 |
+
str(part.get("text", "")).strip()
|
| 509 |
+
for part in raw.get("content", [])
|
| 510 |
+
if isinstance(part, dict) and str(part.get("type", "")) == "text"
|
| 511 |
+
).strip()
|
| 512 |
+
except Exception as e:
|
| 513 |
+
raise RuntimeError("Unexpected Anthropic response format.") from e
|
| 514 |
+
|
| 515 |
+
if provider == "gemini":
|
| 516 |
+
model_name = str(model or "").strip()
|
| 517 |
+
if model_name.startswith("models/"):
|
| 518 |
+
model_name = model_name.split("/", 1)[1]
|
| 519 |
+
raw = provider_request(
|
| 520 |
+
base_url,
|
| 521 |
+
api_key,
|
| 522 |
+
provider,
|
| 523 |
+
f"/models/{urllib.parse.quote(model_name, safe='')}:generateContent",
|
| 524 |
+
payload={
|
| 525 |
+
"system_instruction": {"parts": [{"text": system_prompt}]},
|
| 526 |
+
"contents": [{"role": "user", "parts": [{"text": user_prompt}]}],
|
| 527 |
+
"generationConfig": {"temperature": 0.0, "maxOutputTokens": int(max_tokens)},
|
| 528 |
+
},
|
| 529 |
+
)
|
| 530 |
+
try:
|
| 531 |
+
candidates = raw.get("candidates", [])
|
| 532 |
+
parts = candidates[0]["content"]["parts"] if candidates else []
|
| 533 |
+
return "\n".join(
|
| 534 |
+
str(part.get("text", "")).strip()
|
| 535 |
+
for part in parts
|
| 536 |
+
if isinstance(part, dict) and str(part.get("text", "")).strip()
|
| 537 |
+
).strip()
|
| 538 |
+
except Exception as e:
|
| 539 |
+
raise RuntimeError("Unexpected Gemini response format.") from e
|
| 540 |
+
|
| 541 |
+
raise RuntimeError(f"Unsupported provider: {provider}")
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def validate_api_access(api_key: str, base_url: str, provider: str | None = None, model: str | None = None) -> str | None:
|
| 545 |
"""Return None when credentials are usable, else an error message."""
|
| 546 |
k = str(api_key or "").strip()
|
| 547 |
u = str(base_url or "").strip().rstrip("/")
|
|
|
|
| 549 |
return "API key is required."
|
| 550 |
if not u.startswith("https://"):
|
| 551 |
return "API base URL must start with `https://`."
|
| 552 |
+
|
| 553 |
+
resolved_provider = resolve_api_provider(u, provider)
|
| 554 |
try:
|
| 555 |
+
if resolved_provider in {"openwebui", "openai_compatible"}:
|
| 556 |
+
_ = list_available_models(k, u, resolved_provider)
|
| 557 |
+
elif resolved_provider in {"anthropic", "gemini"}:
|
| 558 |
+
if not str(model or "").strip():
|
| 559 |
+
return f"{_provider_label(resolved_provider)} validation requires a model name."
|
| 560 |
+
_ = chat_text_request(
|
| 561 |
+
u,
|
| 562 |
+
k,
|
| 563 |
+
resolved_provider,
|
| 564 |
+
str(model).strip(),
|
| 565 |
+
"Reply with OK.",
|
| 566 |
+
"ping",
|
| 567 |
+
max_tokens=8,
|
| 568 |
+
)
|
| 569 |
except Exception as e:
|
| 570 |
return f"API key validation failed: {e}"
|
| 571 |
return None
|
|
|
|
| 575 |
st.session_state["discover_llm_byok_key"] = ""
|
| 576 |
|
| 577 |
|
| 578 |
+
def list_available_models(
|
| 579 |
+
api_key: str | None = None,
|
| 580 |
+
base_url: str | None = None,
|
| 581 |
+
provider: str | None = None,
|
| 582 |
+
) -> list[str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 583 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 584 |
if not api_key:
|
| 585 |
return []
|
| 586 |
base_url = (base_url or get_webui_base_url()).rstrip("/")
|
| 587 |
+
resolved_provider = resolve_api_provider(base_url, provider)
|
| 588 |
+
|
| 589 |
+
if resolved_provider == "openwebui":
|
| 590 |
+
raw = provider_request(base_url, api_key, resolved_provider, "/api/models", payload=None)
|
| 591 |
+
items = raw.get("data", raw) if isinstance(raw, dict) else raw
|
| 592 |
+
elif resolved_provider == "openai_compatible":
|
| 593 |
+
raw = provider_request(base_url, api_key, resolved_provider, "/models", payload=None)
|
| 594 |
+
items = raw.get("data", raw) if isinstance(raw, dict) else raw
|
| 595 |
+
elif resolved_provider == "gemini":
|
| 596 |
+
raw = provider_request(base_url, api_key, resolved_provider, "/models", payload=None)
|
| 597 |
+
items = raw.get("models", raw.get("data", raw)) if isinstance(raw, dict) else raw
|
| 598 |
+
else:
|
| 599 |
+
return []
|
| 600 |
+
|
| 601 |
if not isinstance(items, list):
|
| 602 |
return []
|
| 603 |
out = []
|
| 604 |
for m in items:
|
| 605 |
if isinstance(m, dict):
|
| 606 |
mid = str(m.get("id", m.get("name", ""))).strip()
|
| 607 |
+
if resolved_provider == "gemini" and mid.startswith("models/"):
|
| 608 |
+
mid = mid.split("/", 1)[1]
|
| 609 |
else:
|
| 610 |
mid = str(m).strip()
|
| 611 |
if mid:
|
|
|
|
| 614 |
|
| 615 |
|
| 616 |
def generate_spec_from_llm(
|
| 617 |
+
user_query: str,
|
| 618 |
+
model: str,
|
| 619 |
+
api_key: str | None = None,
|
| 620 |
+
base_url: str | None = None,
|
| 621 |
+
provider: str | None = None,
|
| 622 |
) -> dict:
|
| 623 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 624 |
if not api_key:
|
|
|
|
| 651 |
user_prompt = (
|
| 652 |
"User request:\n" + user_query.strip()
|
| 653 |
)
|
| 654 |
+
content = chat_text_request(
|
| 655 |
+
base_url,
|
| 656 |
+
api_key,
|
| 657 |
+
resolve_api_provider(base_url, provider),
|
| 658 |
+
model,
|
| 659 |
+
system_prompt,
|
| 660 |
+
user_prompt,
|
| 661 |
+
max_tokens=1024,
|
| 662 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
|
| 664 |
try:
|
| 665 |
parsed = extract_first_json_object(content)
|
|
|
|
| 751 |
html(snippet, height=box_height + 54)
|
| 752 |
|
| 753 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
def _local_reasoning_fallback(spec_obj: dict, stats: dict) -> str:
|
| 755 |
objectives = spec_obj.get("objectives", []) if isinstance(spec_obj, dict) else []
|
| 756 |
constraints = spec_obj.get("hard_constraints", {}) if isinstance(spec_obj, dict) else {}
|
|
|
|
| 839 |
model: str,
|
| 840 |
api_key: str | None = None,
|
| 841 |
base_url: str | None = None,
|
| 842 |
+
provider: str | None = None,
|
| 843 |
) -> str:
|
| 844 |
api_key = (api_key or get_webui_api_key()).strip()
|
| 845 |
if not api_key:
|
|
|
|
| 915 |
"You can add brief clarifying bullets if helpful, but keep it concise and focused.\n\n"
|
| 916 |
f"INPUT:\n{json.dumps(user_payload, indent=2)}"
|
| 917 |
)
|
| 918 |
+
return chat_text_request(
|
| 919 |
+
base_url,
|
| 920 |
+
api_key,
|
| 921 |
+
resolve_api_provider(base_url, provider),
|
| 922 |
+
model,
|
| 923 |
+
system_prompt,
|
| 924 |
+
user_prompt,
|
| 925 |
+
max_tokens=900,
|
| 926 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 927 |
|
| 928 |
|
| 929 |
def pareto_publication_plot(plot_df: pd.DataFrame, obj_props: list[str]):
|
|
|
|
| 1248 |
if "discover_llm_last_example_choice" not in st.session_state:
|
| 1249 |
st.session_state["discover_llm_last_example_choice"] = "Select an example prompt…"
|
| 1250 |
if "discover_llm_mode" not in st.session_state:
|
| 1251 |
+
st.session_state["discover_llm_mode"] = "Bring Your Own Key"
|
| 1252 |
if "discover_llm_external_response" not in st.session_state:
|
| 1253 |
st.session_state["discover_llm_external_response"] = ""
|
| 1254 |
if "discover_llm_byok_key" not in st.session_state:
|
| 1255 |
st.session_state["discover_llm_byok_key"] = ""
|
| 1256 |
if "discover_llm_byok_base_url" not in st.session_state:
|
| 1257 |
+
st.session_state["discover_llm_byok_base_url"] = ""
|
| 1258 |
+
if "discover_llm_byok_provider" not in st.session_state:
|
| 1259 |
+
st.session_state["discover_llm_byok_provider"] = "auto"
|
| 1260 |
+
if st.session_state.get("discover_llm_mode") not in {"Bring Your Own Key", "External LLM (manual copy–paste)"}:
|
| 1261 |
+
st.session_state["discover_llm_mode"] = "Bring Your Own Key"
|
| 1262 |
|
| 1263 |
# Apply deferred JSON updates before any JSON editor widget is instantiated.
|
| 1264 |
pending_spec_text = st.session_state.get("discover_llm_spec_text_next")
|
|
|
|
| 1290 |
)
|
| 1291 |
mode = st.radio(
|
| 1292 |
"LLM setup",
|
| 1293 |
+
options=["Bring Your Own Key", "External LLM (manual copy–paste)"],
|
| 1294 |
key="discover_llm_mode",
|
| 1295 |
horizontal=True,
|
| 1296 |
)
|
|
|
|
| 1299 |
selected_model = "external-copy-paste"
|
| 1300 |
active_api_key = ""
|
| 1301 |
active_base_url = get_webui_base_url()
|
| 1302 |
+
active_provider = "openwebui"
|
| 1303 |
api_config_invalid = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
|
| 1305 |
+
if mode == "Bring Your Own Key":
|
| 1306 |
+
with st.container(border=True):
|
| 1307 |
+
st.caption(
|
| 1308 |
+
"Bring Your Own Key mode: key is used only for this session and never written to files."
|
| 1309 |
+
)
|
| 1310 |
+
st.caption(
|
| 1311 |
+
"Enter the service root URL, not a full endpoint path. Examples: "
|
| 1312 |
+
"`https://api.openai.com`, `https://api.anthropic.com`, "
|
| 1313 |
+
"`https://generativelanguage.googleapis.com`, or your OpenWebUI base URL."
|
| 1314 |
+
)
|
| 1315 |
+
st.text_input(
|
| 1316 |
+
"Your API key",
|
| 1317 |
+
key="discover_llm_byok_key",
|
| 1318 |
+
type="password",
|
| 1319 |
+
placeholder="Paste your API key",
|
| 1320 |
+
)
|
| 1321 |
+
st.text_input(
|
| 1322 |
+
"API base URL",
|
| 1323 |
+
key="discover_llm_byok_base_url",
|
| 1324 |
+
placeholder="Enter service root URL",
|
| 1325 |
+
)
|
| 1326 |
+
st.selectbox(
|
| 1327 |
+
"API provider",
|
| 1328 |
+
options=PROVIDER_OPTIONS,
|
| 1329 |
+
key="discover_llm_byok_provider",
|
| 1330 |
+
format_func=_provider_label,
|
| 1331 |
+
help=(
|
| 1332 |
+
"Use Auto detect for most endpoints. "
|
| 1333 |
+
"Choose a provider explicitly if the base URL is a direct Anthropic or Gemini endpoint, "
|
| 1334 |
+
"or if your gateway does not identify itself clearly."
|
| 1335 |
+
),
|
| 1336 |
+
)
|
| 1337 |
+
st.button("Clear API key", key="clear_byok_key", on_click=clear_byok_key)
|
| 1338 |
+
active_api_key = str(st.session_state.get("discover_llm_byok_key", "")).strip()
|
| 1339 |
+
user_base_url = str(st.session_state.get("discover_llm_byok_base_url", "")).strip().rstrip("/")
|
| 1340 |
+
active_base_url = user_base_url or get_webui_base_url()
|
| 1341 |
+
configured_provider = str(st.session_state.get("discover_llm_byok_provider", "auto")).strip()
|
| 1342 |
+
active_provider = resolve_api_provider(active_base_url, configured_provider) if active_base_url else "auto"
|
| 1343 |
+
fallback_model = default_model_for_provider(active_provider)
|
| 1344 |
+
if user_base_url and not user_base_url.startswith("https://"):
|
| 1345 |
+
st.error("API base URL must start with `https://`.")
|
| 1346 |
+
api_config_invalid = True
|
| 1347 |
+
elif user_base_url:
|
| 1348 |
+
st.caption(f"Detected provider: `{_provider_label(active_provider)}`")
|
| 1349 |
+
if not active_api_key:
|
| 1350 |
+
st.warning("Enter your API key to enable in-app LLM generation.")
|
| 1351 |
|
| 1352 |
available_models: list[str] = []
|
| 1353 |
models_error = ""
|
| 1354 |
if active_api_key and not api_config_invalid:
|
| 1355 |
try:
|
| 1356 |
+
available_models = list_available_models(active_api_key, active_base_url, active_provider)
|
|
|
|
|
|
|
|
|
|
| 1357 |
except Exception as e:
|
| 1358 |
models_error = str(e)
|
| 1359 |
|
| 1360 |
if available_models:
|
| 1361 |
+
model_index = available_models.index(fallback_model) if fallback_model in available_models else 0
|
| 1362 |
+
selected_model = available_models[model_index]
|
| 1363 |
+
st.caption(f"Using model: `{selected_model}`")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1364 |
else:
|
| 1365 |
if models_error:
|
| 1366 |
+
st.warning(f"Could not load model list from API. Using fallback model `{fallback_model}`. Error: {models_error}")
|
| 1367 |
+
selected_model = fallback_model
|
| 1368 |
+
st.caption(f"Using fallback model: `{selected_model}`")
|
|
|
|
|
|
|
|
|
|
| 1369 |
else:
|
| 1370 |
with st.container(border=True):
|
| 1371 |
st.caption(
|
|
|
|
| 1387 |
if show_json_editor:
|
| 1388 |
generate_json_btn = st.button(
|
| 1389 |
"Generate JSON from LLM"
|
| 1390 |
+
if mode == "Bring Your Own Key"
|
| 1391 |
else "Generate JSON from pasted response"
|
| 1392 |
)
|
| 1393 |
|
|
|
|
| 1793 |
|
| 1794 |
|
| 1795 |
def _raw_spec_from_prompt(
|
| 1796 |
+
user_query: str,
|
| 1797 |
+
model_name: str,
|
| 1798 |
+
api_key: str | None = None,
|
| 1799 |
+
base_url: str | None = None,
|
| 1800 |
+
provider: str | None = None,
|
| 1801 |
) -> tuple[dict, list[str], str | None]:
|
| 1802 |
notes: list[str] = []
|
| 1803 |
extracted = {}
|
|
|
|
| 1805 |
return {}, notes, "Please provide a prompt before generating or running discovery."
|
| 1806 |
with st.spinner("Interpreting prompt and preparing discovery config..."):
|
| 1807 |
try:
|
| 1808 |
+
extracted = generate_spec_from_llm(
|
| 1809 |
+
user_query,
|
| 1810 |
+
model_name,
|
| 1811 |
+
api_key=api_key,
|
| 1812 |
+
base_url=base_url,
|
| 1813 |
+
provider=provider,
|
| 1814 |
+
)
|
| 1815 |
except Exception as e:
|
| 1816 |
return {}, notes, f"LLM generation failed: {e}"
|
| 1817 |
|
|
|
|
| 1878 |
|
| 1879 |
|
| 1880 |
if show_json_editor and generate_json_btn:
|
| 1881 |
+
if mode == "Bring Your Own Key" and not llm_query.strip():
|
| 1882 |
st.error("Please provide a prompt before generating JSON.")
|
| 1883 |
st.stop()
|
| 1884 |
if mode == "Bring Your Own Key":
|
| 1885 |
+
byok_err = validate_api_access(active_api_key, active_base_url, active_provider, selected_model)
|
| 1886 |
if byok_err:
|
| 1887 |
st.error(f"BYOK validation failed: {byok_err}")
|
| 1888 |
st.stop()
|
| 1889 |
+
if mode == "Bring Your Own Key":
|
| 1890 |
raw_spec_obj, prep_notes, parse_error = _raw_spec_from_prompt(
|
| 1891 |
+
llm_query,
|
| 1892 |
+
selected_model,
|
| 1893 |
+
api_key=active_api_key,
|
| 1894 |
+
base_url=active_base_url,
|
| 1895 |
+
provider=active_provider,
|
| 1896 |
)
|
| 1897 |
if parse_error:
|
| 1898 |
for msg in prep_notes:
|
|
|
|
| 1918 |
run_btn = st.button("Run discovery", type="primary")
|
| 1919 |
|
| 1920 |
if run_btn:
|
| 1921 |
+
if mode == "Bring Your Own Key" and not llm_query.strip():
|
| 1922 |
st.error("Please provide a prompt before running discovery.")
|
| 1923 |
st.stop()
|
| 1924 |
if mode == "Bring Your Own Key":
|
| 1925 |
+
byok_err = validate_api_access(active_api_key, active_base_url, active_provider, selected_model)
|
| 1926 |
if byok_err:
|
| 1927 |
st.error(f"BYOK validation failed: {byok_err}")
|
| 1928 |
st.stop()
|
|
|
|
| 1939 |
raw_spec_obj = {}
|
| 1940 |
prep_notes.append("Invalid JSON detected. Using fixed template defaults.")
|
| 1941 |
else:
|
| 1942 |
+
if mode == "Bring Your Own Key":
|
| 1943 |
raw_spec_obj, llm_notes, parse_error = _raw_spec_from_prompt(
|
| 1944 |
+
llm_query,
|
| 1945 |
+
selected_model,
|
| 1946 |
+
api_key=active_api_key,
|
| 1947 |
+
base_url=active_base_url,
|
| 1948 |
+
provider=active_provider,
|
| 1949 |
)
|
| 1950 |
if parse_error:
|
| 1951 |
for msg in llm_notes:
|
|
|
|
| 1985 |
st.session_state["discovery_mode_used"] = mode
|
| 1986 |
st.session_state["discovery_api_key"] = active_api_key if mode == "Bring Your Own Key" else ""
|
| 1987 |
st.session_state["discovery_api_base_url"] = active_base_url if mode == "Bring Your Own Key" else ""
|
| 1988 |
+
st.session_state["discovery_api_provider"] = active_provider if mode == "Bring Your Own Key" else ""
|
| 1989 |
st.session_state["discovery_reasoning_text"] = None
|
| 1990 |
st.session_state["discovery_reasoning_key"] = None
|
| 1991 |
st.session_state["discovery_reasoning_note"] = None
|
|
|
|
| 2057 |
c3.metric("Pareto pool", int(stats.get("n_pareto_pool", 0)))
|
| 2058 |
c4.metric("Selected", int(stats.get("n_selected", 0)))
|
| 2059 |
|
| 2060 |
+
if mode_used == "Bring Your Own Key":
|
| 2061 |
reasoning_api_key = st.session_state.get("discovery_api_key", "")
|
| 2062 |
reasoning_api_base_url = st.session_state.get("discovery_api_base_url", "")
|
| 2063 |
+
reasoning_api_provider = st.session_state.get("discovery_api_provider", "openwebui")
|
| 2064 |
reasoning_key_obj = {
|
| 2065 |
"spec": resolved_spec,
|
| 2066 |
"model": model_used,
|
| 2067 |
"mode": mode_used,
|
| 2068 |
+
"provider": reasoning_api_provider,
|
| 2069 |
"selected_smiles_head": (
|
| 2070 |
out_df["SMILES"].astype(str).head(20).tolist()
|
| 2071 |
if isinstance(out_df, pd.DataFrame) and "SMILES" in out_df.columns
|
|
|
|
| 2086 |
model_used,
|
| 2087 |
api_key=(str(reasoning_api_key).strip() or None),
|
| 2088 |
base_url=(str(reasoning_api_base_url).strip() or None),
|
| 2089 |
+
provider=(str(reasoning_api_provider).strip() or None),
|
| 2090 |
)
|
| 2091 |
st.session_state["discovery_reasoning_note"] = None
|
| 2092 |
except Exception as e:
|
|
|
|
| 2130 |
meta = PROPERTY_META[prop_key]
|
| 2131 |
rename_map[c] = f"{meta['name']} ({meta['unit']})"
|
| 2132 |
preview_df = preview_df.rename(columns=rename_map)
|
| 2133 |
+
preview_display = preview_df.head(50).copy()
|
| 2134 |
+
preview_display.index = range(1, len(preview_display) + 1)
|
| 2135 |
+
st.dataframe(preview_display, width="stretch")
|
| 2136 |
|
| 2137 |
st.subheader("📥 Download")
|
| 2138 |
buf = io.StringIO()
|
pages/6_Novel_SMILES_Generation.py
CHANGED
|
@@ -11,59 +11,95 @@ from src.rnn_smiles.generator import (
|
|
| 11 |
load_existing_smiles_set,
|
| 12 |
load_rnn_model,
|
| 13 |
)
|
| 14 |
-
from src.ui_style import apply_global_style
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Novel SMILES Generation", layout="wide")
|
| 17 |
apply_global_style()
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
APP_ROOT = Path(__file__).resolve().parents[1]
|
| 22 |
MODEL_DIR = APP_ROOT / "models" / "rnn" / "pretrained_model"
|
| 23 |
|
| 24 |
DEFAULT_CKPT = MODEL_DIR / "Prior.ckpt"
|
| 25 |
DEFAULT_VOC = MODEL_DIR / "voc"
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
APP_ROOT / "data" / "EXP.csv",
|
| 29 |
APP_ROOT / "data" / "MD.csv",
|
| 30 |
APP_ROOT / "data" / "DFT.csv",
|
| 31 |
APP_ROOT / "data" / "GC.csv",
|
| 32 |
APP_ROOT / "data" / "POLYINFO.csv",
|
|
|
|
|
|
|
| 33 |
APP_ROOT / "data" / "PI1M.csv",
|
| 34 |
]
|
| 35 |
|
| 36 |
-
with st.sidebar:
|
| 37 |
-
st.subheader("Model Assets")
|
| 38 |
-
ckpt_path = st.text_input("Checkpoint path", value=str(DEFAULT_CKPT))
|
| 39 |
-
voc_path = st.text_input("Vocabulary path", value=str(DEFAULT_VOC))
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
st.subheader("Generation Parameters")
|
| 42 |
target_count = st.number_input("Novel SMILES to return", min_value=1, max_value=5000, value=200, step=25)
|
| 43 |
max_length = st.number_input("Max token length", min_value=20, max_value=300, value=140, step=10)
|
| 44 |
temperature = st.slider("Temperature", min_value=0.2, max_value=2.0, value=1.0, step=0.1)
|
| 45 |
max_attempts = st.number_input("Sampling attempts", min_value=1, max_value=50, value=10, step=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
if not Path(ckpt_path).expanduser().exists() or not Path(voc_path).expanduser().exists():
|
| 48 |
st.error("Model files were not found.")
|
| 49 |
-
st.write("Expected default location:")
|
| 50 |
-
st.code(str(MODEL_DIR))
|
| 51 |
st.stop()
|
| 52 |
|
| 53 |
-
available_datasets = [p for p in
|
| 54 |
-
missing_datasets = [p for p in
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
if missing_datasets:
|
| 57 |
st.warning("Some novelty datasets are missing and were skipped.")
|
| 58 |
for path in missing_datasets:
|
| 59 |
st.write(f"- {path.name}")
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
if not available_datasets:
|
| 62 |
st.warning("No novelty datasets found. Results will only be de-duplicated within this run.")
|
| 63 |
|
| 64 |
if st.button("Generate", type="primary"):
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
with st.spinner("Building novelty index (cached after first load)..."):
|
| 69 |
existing_smiles = load_existing_smiles_set(tuple(str(p) for p in available_datasets)) if available_datasets else set()
|
|
@@ -121,7 +157,9 @@ if st.button("Generate", type="primary"):
|
|
| 121 |
st.stop()
|
| 122 |
|
| 123 |
result_df = pd.DataFrame({"SMILES": novel})
|
| 124 |
-
|
|
|
|
|
|
|
| 125 |
st.download_button(
|
| 126 |
"Download CSV",
|
| 127 |
data=result_df.to_csv(index=False).encode("utf-8"),
|
|
|
|
| 11 |
load_existing_smiles_set,
|
| 12 |
load_rnn_model,
|
| 13 |
)
|
| 14 |
+
from src.ui_style import apply_global_style, render_page_header
|
| 15 |
|
| 16 |
st.set_page_config(page_title="Novel SMILES Generation", layout="wide")
|
| 17 |
apply_global_style()
|
| 18 |
+
render_page_header(
|
| 19 |
+
title="Novel SMILES Generation",
|
| 20 |
+
subtitle="Generate candidate polymers with an RNN and filter against local datasets for novelty.",
|
| 21 |
+
badge="Novel SMILES Generation",
|
| 22 |
+
)
|
| 23 |
|
| 24 |
APP_ROOT = Path(__file__).resolve().parents[1]
|
| 25 |
MODEL_DIR = APP_ROOT / "models" / "rnn" / "pretrained_model"
|
| 26 |
|
| 27 |
DEFAULT_CKPT = MODEL_DIR / "Prior.ckpt"
|
| 28 |
DEFAULT_VOC = MODEL_DIR / "voc"
|
| 29 |
+
ckpt_path = str(DEFAULT_CKPT)
|
| 30 |
+
voc_path = str(DEFAULT_VOC)
|
| 31 |
|
| 32 |
+
FAST_NOVELTY_DATASETS = [
|
| 33 |
APP_ROOT / "data" / "EXP.csv",
|
| 34 |
APP_ROOT / "data" / "MD.csv",
|
| 35 |
APP_ROOT / "data" / "DFT.csv",
|
| 36 |
APP_ROOT / "data" / "GC.csv",
|
| 37 |
APP_ROOT / "data" / "POLYINFO.csv",
|
| 38 |
+
]
|
| 39 |
+
SLOW_NOVELTY_DATASETS = [
|
| 40 |
APP_ROOT / "data" / "PI1M.csv",
|
| 41 |
]
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
def _has_smiles_column(path: Path) -> bool:
|
| 45 |
+
try:
|
| 46 |
+
header = pd.read_csv(path, nrows=0)
|
| 47 |
+
except Exception:
|
| 48 |
+
return False
|
| 49 |
+
cols = [str(c).strip().lower() for c in header.columns]
|
| 50 |
+
return any(c in {"smiles", "canonical_smiles", "canonical smiles", "smile", "smi"} or "smiles" in c for c in cols)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
with st.sidebar:
|
| 54 |
st.subheader("Generation Parameters")
|
| 55 |
target_count = st.number_input("Novel SMILES to return", min_value=1, max_value=5000, value=200, step=25)
|
| 56 |
max_length = st.number_input("Max token length", min_value=20, max_value=300, value=140, step=10)
|
| 57 |
temperature = st.slider("Temperature", min_value=0.2, max_value=2.0, value=1.0, step=0.1)
|
| 58 |
max_attempts = st.number_input("Sampling attempts", min_value=1, max_value=50, value=10, step=1)
|
| 59 |
+
include_virtual_novelty = st.checkbox(
|
| 60 |
+
"Include PI1M in novelty filter (slower)",
|
| 61 |
+
value=False,
|
| 62 |
+
help="Off by default for website responsiveness. Enable only if you need novelty checked against the virtual library too.",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
novelty_datasets = list(FAST_NOVELTY_DATASETS)
|
| 66 |
+
if include_virtual_novelty:
|
| 67 |
+
novelty_datasets.extend(SLOW_NOVELTY_DATASETS)
|
| 68 |
|
| 69 |
if not Path(ckpt_path).expanduser().exists() or not Path(voc_path).expanduser().exists():
|
| 70 |
st.error("Model files were not found.")
|
|
|
|
|
|
|
| 71 |
st.stop()
|
| 72 |
|
| 73 |
+
available_datasets = [p for p in novelty_datasets if p.exists() and _has_smiles_column(p)]
|
| 74 |
+
missing_datasets = [p for p in novelty_datasets if not p.exists()]
|
| 75 |
+
invalid_datasets = [p for p in novelty_datasets if p.exists() and not _has_smiles_column(p)]
|
| 76 |
+
|
| 77 |
+
if include_virtual_novelty:
|
| 78 |
+
st.caption("Full novelty mode includes PI1M and may take significantly longer on the first run.")
|
| 79 |
+
else:
|
| 80 |
+
st.caption("Fast novelty mode checks EXP, MD, DFT, GC, and POLYINFO. PI1M is excluded by default for website responsiveness.")
|
| 81 |
|
| 82 |
if missing_datasets:
|
| 83 |
st.warning("Some novelty datasets are missing and were skipped.")
|
| 84 |
for path in missing_datasets:
|
| 85 |
st.write(f"- {path.name}")
|
| 86 |
|
| 87 |
+
if invalid_datasets:
|
| 88 |
+
st.warning("Some novelty datasets are malformed or missing a SMILES column and were skipped.")
|
| 89 |
+
for path in invalid_datasets:
|
| 90 |
+
st.write(f"- {path.name}")
|
| 91 |
+
|
| 92 |
if not available_datasets:
|
| 93 |
st.warning("No novelty datasets found. Results will only be de-duplicated within this run.")
|
| 94 |
|
| 95 |
if st.button("Generate", type="primary"):
|
| 96 |
+
try:
|
| 97 |
+
with st.spinner("Loading RNN model (cached after first load)..."):
|
| 98 |
+
model, voc = load_rnn_model(ckpt_path, voc_path)
|
| 99 |
+
except Exception as exc:
|
| 100 |
+
st.error(f"Failed to load the RNN checkpoint: {exc}")
|
| 101 |
+
st.info("If you see a Git LFS pointer error, replace `models/rnn/pretrained_model/Prior.ckpt` with the real model file.")
|
| 102 |
+
st.stop()
|
| 103 |
|
| 104 |
with st.spinner("Building novelty index (cached after first load)..."):
|
| 105 |
existing_smiles = load_existing_smiles_set(tuple(str(p) for p in available_datasets)) if available_datasets else set()
|
|
|
|
| 157 |
st.stop()
|
| 158 |
|
| 159 |
result_df = pd.DataFrame({"SMILES": novel})
|
| 160 |
+
display_df = result_df.copy()
|
| 161 |
+
display_df.index = range(1, len(display_df) + 1)
|
| 162 |
+
st.dataframe(display_df, width="stretch")
|
| 163 |
st.download_button(
|
| 164 |
"Download CSV",
|
| 165 |
data=result_df.to_csv(index=False).encode("utf-8"),
|