import streamlit as st
# --- Page Configuration ---
# Browser-tab title and icon, wide page layout, and a sidebar that starts expanded.
st.set_page_config(
    page_title="柳暗花明 (flowillower)",
    page_icon=":sunrise_over_mountains:",
    layout="wide",
    initial_sidebar_state="expanded",
)
from pathlib import Path
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import pandas as pd
import time
# --- Logo ---
# The same image file is passed both as the logo and as its `icon_image` variant.
_LOGO_FILE = "logo.png"
st.logo(_LOGO_FILE, icon_image=_LOGO_FILE)
# Import the refactored project modules. If any of them is missing, show an
# error in the app and halt the script run instead of raising a traceback.
try:
    from infra import DATA_ROOT_PATH, AppMode
    from data_models import Study, Trial  # used in the type annotations below
    from data_loader import discover_studies_cached, ensure_data_directory_exists
    from theme_selector import render_theme_selector  # header theme picker
except ImportError as e:
    # The file list must match the modules actually imported above
    # (`infra`, not the old `utils`).
    st.error(
        f"导入模块失败,请确保 infra.py, data_models.py, data_loader.py, theme_selector.py 文件存在于正确的位置: {e}"
    )
    st.stop()
# --- Application state management ---
# Seed every session-state key with its default on the first run only;
# later reruns keep whatever value the user (or the app) has set since.
_SESSION_DEFAULTS = {
    "selected_study_name": None,
    "selected_trial_name": None,
    "app_mode": AppMode.VIEWING,
    # global_step shared by all charts (highlight line and metric read-out)
    "shared_selected_global_step": None,
    # auto-play state
    "is_auto_playing": False,
    "auto_play_speed": 1.0,
    "auto_play_needs_rerun": False,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# --- UI Rendering ---
# --- Header ---
# Columns: title | study picker | current study | add | settings | user | theme selector
header_cols = st.columns([2, 3, 1.5, 0.5, 0.5, 0.5, 1])
with header_cols[0]:
    st.markdown("## 柳暗花明")
    st.caption("flowillower")

# Make sure the data root exists, then discover every study beneath it.
ensure_data_directory_exists(DATA_ROOT_PATH)
all_study_objects = discover_studies_cached(DATA_ROOT_PATH)
study_names = list(all_study_objects.keys())

if not study_names:
    st.warning(
        f"在 {DATA_ROOT_PATH} 未找到任何 Study。请确保您的数据结构正确或使用 flowillower API 开始记录实验。"
    )
    with header_cols[1]:
        st.info("没有可用的 Study。")
else:
    with header_cols[1]:
        # Fall back to the first study when the remembered one no longer exists.
        if st.session_state.selected_study_name not in study_names:
            st.session_state.selected_study_name = study_names[0]
        picked_study = st.selectbox(
            "选择 Study (Select Study)",
            study_names,
            index=study_names.index(st.session_state.selected_study_name),
            label_visibility="collapsed",
            key="study_selector_main_ui",
        )
        if picked_study != st.session_state.selected_study_name:
            # A new study invalidates the selected trial and the highlighted step.
            st.session_state.selected_study_name = picked_study
            st.session_state.selected_trial_name = None
            st.session_state.shared_selected_global_step = None
            st.rerun()
    with header_cols[2]:
        if st.session_state.selected_study_name:
            st.write(f"当前 Study: **{st.session_state.selected_study_name}**")

# Placeholder action buttons (disabled until implemented).
for _header_col, (_button_icon, _button_help) in zip(
    header_cols[3:6],
    [("➕", "添加 (Add)"), ("⚙️", "设置 (Settings)"), ("👤", "用户 (User)")],
):
    with _header_col:
        st.button(_button_icon, help=_button_help, disabled=True)

# Last column: theme selector.
with header_cols[6], st.container():
    render_theme_selector()

st.markdown("---")
# --- Sidebar ---
# Resolve the selected Study object; discover its trials lazily if none are loaded yet.
current_study: Study | None = None
_wanted_study = st.session_state.selected_study_name
if _wanted_study and _wanted_study in all_study_objects:
    current_study = all_study_objects[_wanted_study]
    if not current_study.trials:
        current_study.discover_trials_cached()
trial_names = [] if not current_study else list(current_study.trials.keys())
with st.sidebar:
    # --- Study section ---
    st.markdown("### Study")
    if not current_study:
        st.markdown("未选择 Study")
    else:
        st.markdown(f"##### {current_study.name}")
        if st.button("刷新 Study 数据 (Refresh Study Data)", use_container_width=True):
            current_study.clear_cache()
            st.rerun()
        # Placeholder actions (disabled until implemented).
        if st.button("概览 (Overview)", use_container_width=True, disabled=True):
            st.toast("功能待实现")
        if st.button(
            "图表对比视图 (Chart Comparison View)",
            use_container_width=True,
            disabled=True,
        ):
            st.toast("功能待实现")
    st.markdown("---")

    # --- Trial section ---
    st.markdown("### Trial")
    if current_study and trial_names:
        # Fall back to the first trial when the remembered one no longer exists.
        if st.session_state.selected_trial_name not in trial_names:
            st.session_state.selected_trial_name = trial_names[0]
        picked_trial = st.radio(
            "选择 Trial (Select Trial)",
            trial_names,
            index=trial_names.index(st.session_state.selected_trial_name),
            label_visibility="collapsed",
            key="trial_selector_sidebar_ui",
        )
        if picked_trial != st.session_state.selected_trial_name:
            # A new trial invalidates the highlighted global step.
            st.session_state.selected_trial_name = picked_trial
            st.session_state.shared_selected_global_step = None
            st.rerun()
        if st.session_state.selected_trial_name:
            st.markdown(f"当前选择: **{st.session_state.selected_trial_name}**")
    elif current_study:
        st.info(f"Study '{current_study.name}' 中没有 Trial。")
    else:
        st.info("请先选择一个 Study。")
    st.markdown("---")

    if st.button("⚙️ App 设置 (App Settings)", use_container_width=True, disabled=True):
        st.toast("功能待实现")
# --- Main Content Area ---
# Resolve the selected Trial object and make sure its inputs and metrics are loaded.
current_trial: Trial | None = None
_wanted_trial = st.session_state.selected_trial_name
if current_study and _wanted_trial and _wanted_trial in current_study.trials:
    current_trial = current_study.trials[_wanted_trial]
    current_trial.load_input_variables_cached()
    current_trial.load_metrics_cached()
# --- Main content helpers ---


def _collect_global_steps(trial: Trial) -> list:
    """Return the sorted union of all ``global_step`` values across *trial*'s metrics.

    Metrics whose DataFrame is missing, empty, or lacks a ``global_step``
    column are skipped. Always returns a list (possibly empty).
    """
    steps = set()
    for metric_name in trial.metrics_data.keys():
        df_metric = trial.get_metric_dataframe(metric_name)
        if (
            df_metric is not None
            and not df_metric.empty
            and "global_step" in df_metric.columns
        ):
            steps.update(df_metric["global_step"].tolist())
    return sorted(steps)


def _render_global_step_controls(all_global_steps: list) -> None:
    """Render the shared-step slider and the play/pause, speed and reset controls.

    Args:
        all_global_steps: Non-empty, sorted list of valid global steps.

    Keeps ``st.session_state.shared_selected_global_step`` on a value from
    *all_global_steps*. While auto-play is on, it sleeps ``1 / auto_play_speed``
    seconds and advances one step. It only sets ``auto_play_needs_rerun``
    instead of rerunning immediately, so the rest of the page renders first.
    """
    min_step, max_step = all_global_steps[0], all_global_steps[-1]
    control_cols = st.columns([3, 1, 1, 1])

    with control_cols[0]:
        # Default to the last step; snap a stale value to the closest valid step.
        if st.session_state.shared_selected_global_step is None:
            st.session_state.shared_selected_global_step = max_step
        if st.session_state.shared_selected_global_step not in all_global_steps:
            stale_step = st.session_state.shared_selected_global_step
            st.session_state.shared_selected_global_step = min(
                all_global_steps, key=lambda step: abs(step - stale_step)
            )
        selected_step = st.select_slider(
            "选择全局步骤",
            options=all_global_steps,
            value=st.session_state.shared_selected_global_step,
            format_func=lambda x: f"Step {x}",
            key="global_step_slider",
        )
        if selected_step != st.session_state.shared_selected_global_step:
            st.session_state.shared_selected_global_step = selected_step
            st.rerun()

    with control_cols[1]:
        # Play / pause toggle.
        if st.session_state.is_auto_playing:
            if st.button("⏸️ 暂停", type="primary", use_container_width=True):
                st.session_state.is_auto_playing = False
                st.rerun()
        elif st.button("▶️ 播放", type="primary", use_container_width=True):
            st.session_state.is_auto_playing = True
            st.rerun()

    with control_cols[2]:
        speed_options = [0.5, 1.0, 2.0, 4.0]
        speed = st.selectbox(
            "播放速度",
            options=speed_options,
            index=speed_options.index(st.session_state.auto_play_speed),
            format_func=lambda x: f"{x}x",
            key="speed_selector",
        )
        if speed != st.session_state.auto_play_speed:
            st.session_state.auto_play_speed = speed

    with control_cols[3]:
        if st.button("🔄 重置", use_container_width=True):
            st.session_state.shared_selected_global_step = min_step
            st.session_state.is_auto_playing = False
            st.rerun()

    # Auto-play: advance one step (or stop at the end) and flag a deferred rerun.
    if st.session_state.is_auto_playing:
        current_index = all_global_steps.index(
            st.session_state.shared_selected_global_step
        )
        if current_index < len(all_global_steps) - 1:
            time.sleep(1.0 / st.session_state.auto_play_speed)
            st.session_state.shared_selected_global_step = all_global_steps[
                current_index + 1
            ]
        else:
            st.session_state.is_auto_playing = False
        st.session_state.auto_play_needs_rerun = True

    st.info(
        f"当前选中步骤: **{st.session_state.shared_selected_global_step}** / {max_step}"
    )


def _latest_value_and_delta(df_metric: pd.DataFrame, track, step) -> tuple:
    """Return ``(value, found_step, delta)`` for *track* at or before *step*.

    If *step* has no row for *track*, the nearest earlier step with data is
    used, and ``found_step`` reports which one. ``delta`` is the change from
    the closest earlier step that has data, or ``None`` if there is none.
    When several rows share a step, the first one in DataFrame order is used.
    Returns ``(None, None, None)`` when nothing at or before *step* exists.

    Args:
        df_metric: Metric rows with ``global_step`` and ``value`` columns, plus
            an optional ``track`` column.
        track: Track to select, or ``None`` to use every row.
        step: The currently selected global step (may be ``None``).
    """
    if step is None or "global_step" not in df_metric.columns:
        return None, None, None
    rows = df_metric if track is None else df_metric[df_metric["track"] == track]
    rows = rows[rows["global_step"] <= step]
    if rows.empty:
        return None, None, None
    # A stable sort keeps DataFrame order within a step, so keep="first" picks
    # the same row the old row-by-row scan (".iloc[0]") picked.
    per_step = rows.sort_values("global_step", kind="stable").drop_duplicates(
        subset="global_step", keep="first"
    )
    found_step = per_step["global_step"].iloc[-1]
    value = per_step["value"].iloc[-1]
    delta = value - per_step["value"].iloc[-2] if len(per_step) > 1 else None
    return value, found_step, delta


def _render_metric_summary(df_metric: pd.DataFrame, current_step) -> None:
    """Show one ``st.metric`` per track with its value at *current_step* and delta.

    With several tracks the metrics are laid out side by side in columns.
    With one track (or no ``track`` column) they go into a plain container;
    ``st.container()`` is a context manager, like the column objects.
    """
    all_tracks = (
        df_metric["track"].unique() if "track" in df_metric.columns else [None]
    )
    metric_cols = (
        st.columns(len(all_tracks)) if len(all_tracks) > 1 else [st.container()]
    )
    for metric_col, track in zip(metric_cols, all_tracks):
        value, found_step, delta = _latest_value_and_delta(
            df_metric, track, current_step
        )
        with metric_col:
            if value is None:
                track_label = track if track is not None else "数据"
                st.metric(label=f"{track_label}", value="无数据", delta=None)
                continue
            # Values carried over from an earlier step say which step they come from.
            if track is not None:
                label = (
                    f"{track} (Step {found_step})"
                    if found_step != current_step
                    else f"{track}"
                )
            else:
                label = f"当前值 (Step {found_step})"
            st.metric(
                label=label,
                value=f"{value:.4f}",
                delta=f"{delta:.4f}" if delta is not None else None,
            )


def _build_metric_figure(
    df_metric: pd.DataFrame, metric_name: str, selected_step
) -> go.Figure:
    """Build the Plotly line chart for one metric: one trace per track.

    If *selected_step* is not ``None``, a vertical line highlights that step.
    The legend is shown only when there is more than one track.
    """
    fig = go.Figure()
    has_tracks = "track" in df_metric.columns
    if has_tracks:
        tracks = df_metric["track"].unique()
        colors = px.colors.qualitative.Set1[: len(tracks)]
        for k, track in enumerate(tracks):
            track_data = df_metric[df_metric["track"] == track]
            color = colors[k % len(colors)]
            fig.add_trace(
                go.Scatter(
                    x=track_data["global_step"],
                    y=track_data["value"],
                    mode="lines+markers",
                    name=track,
                    line=dict(color=color),
                    marker=dict(size=6, color=color, line=dict(width=1, color="white")),
                    customdata=track_data[["global_step", "value", "track"]],
                    hovertemplate="%{fullData.name}<br>"
                    + "Global Step: %{x}<br>"
                    + "Value: %{y}"
                    + "<extra></extra>",
                )
            )
    else:
        # No track column: a single line for the whole metric.
        fig.add_trace(
            go.Scatter(
                x=df_metric["global_step"],
                y=df_metric["value"],
                mode="lines+markers",
                name=metric_name,
                marker=dict(size=6, line=dict(width=1, color="white")),
                customdata=df_metric[["global_step", "value"]],
                hovertemplate="Global Step: %{x}<br>"
                + "Value: %{y}"
                + "<extra></extra>",
            )
        )
    if selected_step is not None:
        fig.add_vline(
            x=selected_step,
            line_width=2,
            line_dash="solid",
            line_color="firebrick",
            opacity=0.9,
        )
    fig.update_layout(
        title=None,
        xaxis_title="全局步骤 (Global Step)",
        yaxis_title=metric_name,
        height=400,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=has_tracks and len(df_metric["track"].unique()) > 1,
        hovermode="closest",
    )
    return fig


def _clicked_global_step(event):
    """Return the x value (global_step) of the first selected point in a chart event, or None."""
    if not event or "selection" not in event:
        return None
    selection = event["selection"]
    if "points" not in selection or not selection["points"]:
        return None
    clicked_x = selection["points"][0]["x"]
    return None if clicked_x is None else int(clicked_x)


def _render_metric_chart(
    df_metric: pd.DataFrame, metric_name: str, chart_key: str
) -> None:
    """Render one metric chart; clicking a point selects its global step.

    If the chart fails to build, an error and the raw DataFrame are shown
    instead. A click on a new step stops auto-play and reruns the script.
    """
    try:
        fig = _build_metric_figure(
            df_metric, metric_name, st.session_state.shared_selected_global_step
        )
        event = st.plotly_chart(
            fig,
            use_container_width=True,
            key=chart_key,
            on_select="rerun",
        )
        new_step = _clicked_global_step(event)
    except Exception as e:
        st.error(f"为指标 '{metric_name}' 生成图表时出错: {e}")
        st.dataframe(df_metric)
        return
    # Rerun outside the try so the broad except above can never swallow it.
    if (
        new_step is not None
        and st.session_state.get("shared_selected_global_step") != new_step
    ):
        st.session_state.shared_selected_global_step = new_step
        st.session_state.is_auto_playing = False  # clicking a chart stops auto-play
        st.rerun()


if current_study and current_trial:
    # --- Trial title bar ---
    main_title_cols = st.columns([3, 1, 0.5])
    with main_title_cols[0]:
        st.markdown(f"## {current_trial.name}")
        st.caption(f"属于 Study: {current_study.name}")
    with main_title_cols[1]:
        if st.button("刷新 Trial 数据 (Refresh Trial Data)", type="secondary"):
            current_trial.clear_cache()
            st.rerun()
    with main_title_cols[2]:
        st.button("...", help="更多选项 (More Options)", disabled=True)

    # --- Global step controller shared by every chart ---
    if current_trial.metrics_data:
        st.markdown("### 全局步骤控制 (Global Step Control)")
        all_global_steps = _collect_global_steps(current_trial)
        if all_global_steps:
            _render_global_step_controls(all_global_steps)
        st.markdown("---")

    tab_titles = [
        "图表 (Charts)",
        "参数 (Parameters)",
        "系统 (System)",
        "日志 (Logs)",
        "环境 (Environment)",
    ]
    tab_charts, tab_params, tab_system, tab_logs, tab_env = st.tabs(tab_titles)

    with tab_charts:
        st.header("指标图表 (Metrics Charts)")
        st.markdown("---")
        if not current_trial.metrics_data:
            st.info("当前 Trial 没有可显示的指标数据。")
        else:
            num_metrics = len(current_trial.metrics_data)
            cols_per_row = st.slider(
                "每行图表数量 (Charts per row)",
                1,
                4,
                min(2, num_metrics),  # num_metrics >= 1 in this branch
                key=f"cols_slider_{current_study.name}_{current_trial.name}",
            )
            metric_names = sorted(current_trial.metrics_data.keys())
            for row_start in range(0, num_metrics, cols_per_row):
                chart_cols = st.columns(cols_per_row)
                row_metrics = metric_names[row_start : row_start + cols_per_row]
                for chart_col, metric_name in zip(chart_cols, row_metrics):
                    with chart_col:
                        df_metric = current_trial.get_metric_dataframe(metric_name)
                        if df_metric is None or df_metric.empty:
                            st.warning(f"指标 '{metric_name}' 数据不完整或缺失。")
                            continue
                        with st.container(border=True):
                            st.subheader(metric_name)
                            # Current value + delta per track. A failure here is
                            # reported and the chart below is still drawn
                            # (previously it was re-raised and aborted the page).
                            try:
                                _render_metric_summary(
                                    df_metric,
                                    st.session_state.shared_selected_global_step,
                                )
                            except Exception as e:
                                st.warning(f"计算指标值时出错: {e}")
                            _render_metric_chart(
                                df_metric,
                                metric_name,
                                f"chart_metric_{current_study.name}_{current_trial.name}_{metric_name}",
                            )

    with tab_params:
        st.header("输入参数 (Input Parameters)")
        if current_trial.input_variables:
            st.json(current_trial.input_variables)
        else:
            st.info("未找到 `input_variables.toml` 或文件为空。")

    # Tabs that are not implemented yet.
    for tab_content, name in [
        (tab_system, "系统监控"),
        (tab_logs, "日志"),
        (tab_env, "环境"),
    ]:
        with tab_content:
            st.header(name)
            st.info("此功能待您的 `flowillower` API 提供相关数据后实现。")
elif not st.session_state.selected_study_name:
    st.info("👈 请从顶部选择一个 Study 开始。")
elif not st.session_state.selected_trial_name:
    st.info("👈 请从侧边栏选择一个 Trial。")
else:
    st.info("请选择 Study 和 Trial 以查看数据。")
# --- Footer ---
st.markdown("---")
st.caption("柳暗花明 (flowillower) - 数据可视化App")

# Auto-play only sets a flag while the page renders; the actual rerun is
# issued here, after everything above has been drawn. The flag is cleared
# first so it does not fire again on the next run.
_needs_autoplay_rerun = st.session_state.get("auto_play_needs_rerun", False)
if _needs_autoplay_rerun:
    st.session_state.auto_play_needs_rerun = False
    st.rerun()