Dixing (Dex) Xu
committed on
:sparkles: Add validation plot and score for webui (#28)
Browse files
* :sparkles: Add validation plot and score for webui
* Add validation plot
* Add validation score
* Update style.css
* :art: Put the best validation score under the tab
* :rotating_light: update lint and example text
- aide/webui/app.py +112 -15
- aide/webui/style.css +1 -1
aide/webui/app.py
CHANGED
|
@@ -158,24 +158,35 @@ class WebUI:
|
|
| 158 |
Returns:
|
| 159 |
list: List of uploaded or example files.
|
| 160 |
"""
|
| 161 |
-
if
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
if st.session_state.get("example_files"):
|
| 167 |
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
|
| 168 |
with st.expander("View Loaded Files", expanded=False):
|
| 169 |
for file in st.session_state.example_files:
|
| 170 |
st.text(f"📄 {file['name']}")
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
"Upload Data Files",
|
| 175 |
-
accept_multiple_files=True,
|
| 176 |
-
type=["csv", "txt", "json", "md"],
|
| 177 |
-
)
|
| 178 |
-
return uploaded_files
|
| 179 |
|
| 180 |
def handle_user_inputs(self):
|
| 181 |
"""
|
|
@@ -187,12 +198,12 @@ class WebUI:
|
|
| 187 |
goal_text = st.text_area(
|
| 188 |
"Goal",
|
| 189 |
value=st.session_state.get("goal", ""),
|
| 190 |
-
placeholder="Example: Predict house
|
| 191 |
)
|
| 192 |
eval_text = st.text_area(
|
| 193 |
"Evaluation Criteria",
|
| 194 |
value=st.session_state.get("eval", ""),
|
| 195 |
-
placeholder="Example: Use RMSE metric",
|
| 196 |
)
|
| 197 |
num_steps = st.slider(
|
| 198 |
"Number of Steps",
|
|
@@ -450,7 +461,16 @@ class WebUI:
|
|
| 450 |
st.header("Results")
|
| 451 |
if st.session_state.get("results"):
|
| 452 |
results = st.session_state.results
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
with tabs[0]:
|
| 456 |
self.render_tree_visualization(results)
|
|
@@ -460,6 +480,12 @@ class WebUI:
|
|
| 460 |
self.render_config(results)
|
| 461 |
with tabs[3]:
|
| 462 |
self.render_journal(results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
else:
|
| 464 |
st.info("No results to display. Please run an experiment.")
|
| 465 |
|
|
@@ -529,6 +555,77 @@ class WebUI:
|
|
| 529 |
else:
|
| 530 |
st.info("No journal available.")
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
|
| 533 |
if __name__ == "__main__":
|
| 534 |
app = WebUI()
|
|
|
|
| 158 |
Returns:
|
| 159 |
list: List of uploaded or example files.
|
| 160 |
"""
|
| 161 |
+
# Only show file uploader if no example files are loaded
|
| 162 |
+
if not st.session_state.get("example_files"):
|
| 163 |
+
uploaded_files = st.file_uploader(
|
| 164 |
+
"Upload Data Files",
|
| 165 |
+
accept_multiple_files=True,
|
| 166 |
+
type=["csv", "txt", "json", "md"],
|
| 167 |
+
label_visibility="collapsed",
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
if uploaded_files:
|
| 171 |
+
st.session_state.pop(
|
| 172 |
+
"example_files", None
|
| 173 |
+
) # Remove example files if any
|
| 174 |
+
return uploaded_files
|
| 175 |
+
|
| 176 |
+
# Only show example button if no files are uploaded
|
| 177 |
+
if st.button(
|
| 178 |
+
"Load Example Experiment", type="primary", use_container_width=True
|
| 179 |
+
):
|
| 180 |
+
st.session_state.example_files = self.load_example_files()
|
| 181 |
|
| 182 |
if st.session_state.get("example_files"):
|
| 183 |
st.info("Example files loaded! Click 'Run AIDE' to proceed.")
|
| 184 |
with st.expander("View Loaded Files", expanded=False):
|
| 185 |
for file in st.session_state.example_files:
|
| 186 |
st.text(f"📄 {file['name']}")
|
| 187 |
+
return st.session_state.example_files
|
| 188 |
+
|
| 189 |
+
return [] # Return empty list if no files are uploaded or loaded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
def handle_user_inputs(self):
|
| 192 |
"""
|
|
|
|
| 198 |
goal_text = st.text_area(
|
| 199 |
"Goal",
|
| 200 |
value=st.session_state.get("goal", ""),
|
| 201 |
+
placeholder="Example: Predict the sales price for each house",
|
| 202 |
)
|
| 203 |
eval_text = st.text_area(
|
| 204 |
"Evaluation Criteria",
|
| 205 |
value=st.session_state.get("eval", ""),
|
| 206 |
+
placeholder="Example: Use the RMSE metric between the logarithm of the predicted and observed values.",
|
| 207 |
)
|
| 208 |
num_steps = st.slider(
|
| 209 |
"Number of Steps",
|
|
|
|
| 461 |
st.header("Results")
|
| 462 |
if st.session_state.get("results"):
|
| 463 |
results = st.session_state.results
|
| 464 |
+
|
| 465 |
+
tabs = st.tabs(
|
| 466 |
+
[
|
| 467 |
+
"Tree Visualization",
|
| 468 |
+
"Best Solution",
|
| 469 |
+
"Config",
|
| 470 |
+
"Journal",
|
| 471 |
+
"Validation Plot",
|
| 472 |
+
]
|
| 473 |
+
)
|
| 474 |
|
| 475 |
with tabs[0]:
|
| 476 |
self.render_tree_visualization(results)
|
|
|
|
| 480 |
self.render_config(results)
|
| 481 |
with tabs[3]:
|
| 482 |
self.render_journal(results)
|
| 483 |
+
with tabs[4]:
|
| 484 |
+
# Display best score before the plot
|
| 485 |
+
best_metric = self.get_best_metric(results)
|
| 486 |
+
if best_metric is not None:
|
| 487 |
+
st.metric("Best Validation Score", f"{best_metric:.4f}")
|
| 488 |
+
self.render_validation_plot(results)
|
| 489 |
else:
|
| 490 |
st.info("No results to display. Please run an experiment.")
|
| 491 |
|
|
|
|
| 555 |
else:
|
| 556 |
st.info("No journal available.")
|
| 557 |
|
| 558 |
+
@staticmethod
def get_best_metric(results, *, maximize=True):
    """
    Extract the best validation metric from results.

    Args:
        results: Mapping whose "journal" entry is a JSON string. The JSON
            is expected to decode to a list of node dicts, each of which
            may carry a "metric" value (number, numeric string, or None).
        maximize (bool): When True (the default, matching the previous
            behavior) the largest metric is returned; when False the
            smallest is returned, for lower-is-better metrics.

    Returns:
        float | None: The best finite metric found, or None when the
        journal is missing, cannot be parsed, or has no usable metric.
    """
    # Local import keeps this block self-contained regardless of which
    # modules the file imports at the top.
    import math

    try:
        journal_data = json.loads(results["journal"])
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a None journal or a non-subscriptable `results`.
        return None

    if not isinstance(journal_data, list):
        return None

    metrics = []
    for node in journal_data:
        # A malformed node should not discard the valid ones around it.
        if not isinstance(node, dict):
            continue
        raw_metric = node.get("metric")
        if raw_metric is None:
            continue
        try:
            # Convert string metric to float
            metric_value = float(raw_metric)
        except (ValueError, TypeError):
            continue
        # NaN makes max()/min() order-dependent and inf is not a real
        # score, so only finite values are candidates.
        if math.isfinite(metric_value):
            metrics.append(metric_value)

    if not metrics:
        return None
    return max(metrics) if maximize else min(metrics)
|
| 577 |
+
|
| 578 |
+
@staticmethod
def render_validation_plot(results):
    """
    Render the validation score plot.

    Reads the JSON string stored under results["journal"], collects the
    ("step", "metric") pairs of every node whose metric converts to a
    finite float, and draws them as a Plotly line chart via Streamlit.

    Args:
        results: Mapping whose "journal" entry is a JSON string that is
            expected to decode to a list of node dicts with "metric" and
            "step" keys.

    Returns:
        None. Output is written to the Streamlit page: the chart, an info
        message when there is nothing to plot, or an error message when
        the journal cannot be parsed.
    """
    # Local import keeps this block self-contained regardless of which
    # modules the file imports at the top.
    import math

    steps = []
    metrics = []
    try:
        journal_data = json.loads(results["journal"])

        for node in journal_data:
            raw_metric = node["metric"]
            if raw_metric is None:
                continue
            try:
                # float() rejects the string "None" with ValueError, so
                # no string-only check is needed; numeric metrics, which
                # have no .lower(), are accepted as well.
                metric_value = float(raw_metric)
            except (ValueError, TypeError):
                continue
            # Skip NaN/inf so the plot agrees with a finite best score.
            if not math.isfinite(metric_value):
                continue
            steps.append(node["step"])
            metrics.append(metric_value)
    except (json.JSONDecodeError, KeyError, TypeError):
        # TypeError covers a None journal or nodes that are not dicts.
        st.error("Could not parse validation metrics data.")
        return

    if not metrics:
        st.info("No validation metrics available to plot.")
        return

    # Imported lazily so plotly is only loaded when there is data to draw.
    import plotly.graph_objects as go

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=steps,
            y=metrics,
            mode="lines+markers",
            name="Validation Score",
            line=dict(color="#F04370"),
            marker=dict(color="#F04370"),
        )
    )

    fig.update_layout(
        title="Validation Score Progress",
        xaxis_title="Step",
        yaxis_title="Validation Score",
        template="plotly_white",
        hovermode="x unified",
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    st.plotly_chart(fig, use_container_width=True)
|
| 628 |
+
|
| 629 |
|
| 630 |
if __name__ == "__main__":
|
| 631 |
app = WebUI()
|
aide/webui/style.css
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
/* Main colors */
|
| 2 |
:root {
|
| 3 |
--background: #F2F0E7;
|
| 4 |
-
--background-shaded: #
|
| 5 |
--card: #FFFFFF;
|
| 6 |
--primary: #0D0F18;
|
| 7 |
--accent: #F04370;
|
|
|
|
| 1 |
/* Main colors */
|
| 2 |
:root {
|
| 3 |
--background: #F2F0E7;
|
| 4 |
+
--background-shaded: #FFFFFF;
|
| 5 |
--card: #FFFFFF;
|
| 6 |
--primary: #0D0F18;
|
| 7 |
--accent: #F04370;
|