Spaces:
Sleeping
Sleeping
| import streamlit as st | |
# --- Page configuration ---
# Issued right after importing streamlit, before any other Streamlit command.
_PAGE_CONFIG = dict(
    page_title="柳暗花明 (flowillower)",
    page_icon=":sunrise_over_mountains:",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
| from pathlib import Path | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| from plotly.subplots import make_subplots | |
| import pandas as pd | |
| import time | |
# --- Logo ---
# The same image file serves as both the logo and its `icon_image` variant.
_LOGO_FILE = "logo.png"
st.logo(_LOGO_FILE, icon_image=_LOGO_FILE)
# Import the project's own modules. If any is missing, show the error in the
# page and stop the script instead of crashing with a traceback.
try:
    from utils import DATA_ROOT_PATH, AppMode
    from data_models import Study, Trial  # used in type annotations below
    from data_loader import discover_studies_cached, ensure_data_directory_exists
    from theme_selector import render_theme_selector  # theme picker in the header
except ImportError as import_error:
    st.error(
        "导入模块失败,请确保 utils.py, data_models.py, data_loader.py, theme_selector.py 文件存在于正确的位置: "
        f"{import_error}"
    )
    st.stop()
# --- Application state ---
# Session-state keys and their first-run defaults. A key that already exists
# is left untouched, so values persist across Streamlit reruns.
_SESSION_DEFAULTS = {
    "selected_study_name": None,
    "selected_trial_name": None,
    "app_mode": AppMode.VIEWING,
    # global_step highlighted across all charts (shared between charts)
    "shared_selected_global_step": None,
    # auto-play state
    "is_auto_playing": False,
    "auto_play_speed": 1.0,
    "auto_play_needs_rerun": False,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
# --- UI rendering ---
# --- Header ---
# Column layout: title | study picker | current-study label |
# three placeholder icon buttons | theme selector.
header_cols = st.columns([2, 3, 1.5, 0.5, 0.5, 0.5, 1])
with header_cols[0]:
    st.markdown("## 柳暗花明")
    st.caption("flowillower")

ensure_data_directory_exists(DATA_ROOT_PATH)
all_study_objects = discover_studies_cached(DATA_ROOT_PATH)
study_names = list(all_study_objects.keys())

if not study_names:
    st.warning(
        f"在 {DATA_ROOT_PATH} 未找到任何 Study。请确保您的数据结构正确或使用 flowillower API 开始记录实验。"
    )
    with header_cols[1]:
        st.info("没有可用的 Study。")
else:
    with header_cols[1]:
        # Fall back to the first study when the remembered one is not available.
        if st.session_state.selected_study_name not in study_names:
            st.session_state.selected_study_name = study_names[0]
        picked_study = st.selectbox(
            "选择 Study (Select Study)",
            study_names,
            index=study_names.index(st.session_state.selected_study_name),
            label_visibility="collapsed",
            key="study_selector_main_ui",
        )
        if picked_study != st.session_state.selected_study_name:
            # A different study invalidates the trial choice and the
            # highlighted global step.
            st.session_state.selected_study_name = picked_study
            st.session_state.selected_trial_name = None
            st.session_state.shared_selected_global_step = None
            st.rerun()
    with header_cols[2]:
        if st.session_state.selected_study_name:
            st.write(f"当前 Study: **{st.session_state.selected_study_name}**")

# Placeholder actions: rendered, but disabled until implemented.
for icon_col, (icon, tooltip) in zip(
    header_cols[3:6],
    (("➕", "添加 (Add)"), ("⚙️", "设置 (Settings)"), ("👤", "用户 (User)")),
):
    with icon_col:
        st.button(icon, help=tooltip, disabled=True)

# Theme selector lives in its own container in the last header column.
with header_cols[6], st.container():
    render_theme_selector()

st.markdown("---")
# --- Sidebar ---
# Resolve the selected Study object. Its trials are discovered on demand when
# the study does not have any loaded yet.
current_study: Study | None = None
_chosen_study_name = st.session_state.selected_study_name
if _chosen_study_name and _chosen_study_name in all_study_objects:
    current_study = all_study_objects[_chosen_study_name]
    if not current_study.trials:
        current_study.discover_trials_cached()
trial_names = list(current_study.trials.keys()) if current_study else []

with st.sidebar:
    # Study section
    st.markdown("### Study")
    if not current_study:
        st.markdown("未选择 Study")
    else:
        st.markdown(f"##### {current_study.name}")
        if st.button("刷新 Study 数据 (Refresh Study Data)", use_container_width=True):
            current_study.clear_cache()
            st.rerun()
        # Placeholder buttons: disabled, so their toast branches never fire yet.
        if st.button("概览 (Overview)", use_container_width=True, disabled=True):
            st.toast("功能待实现")
        if st.button(
            "图表对比视图 (Chart Comparison View)",
            use_container_width=True,
            disabled=True,
        ):
            st.toast("功能待实现")
    st.markdown("---")

    # Trial section
    st.markdown("### Trial")
    if not current_study:
        st.info("请先选择一个 Study。")
    elif not trial_names:
        st.info(f"Study '{current_study.name}' 中没有 Trial。")
    else:
        # Fall back to the first trial when the remembered one is not available.
        if st.session_state.selected_trial_name not in trial_names:
            st.session_state.selected_trial_name = trial_names[0]
        picked_trial = st.radio(
            "选择 Trial (Select Trial)",
            trial_names,
            index=trial_names.index(st.session_state.selected_trial_name),
            label_visibility="collapsed",
            key="trial_selector_sidebar_ui",
        )
        if picked_trial != st.session_state.selected_trial_name:
            # A different trial clears the highlighted global step.
            st.session_state.selected_trial_name = picked_trial
            st.session_state.shared_selected_global_step = None
            st.rerun()
        if st.session_state.selected_trial_name:
            st.markdown(f"当前选择: **{st.session_state.selected_trial_name}**")
    st.markdown("---")

    if st.button("⚙️ App 设置 (App Settings)", use_container_width=True, disabled=True):
        st.toast("功能待实现")
# --- Main content area ---
# Resolve the selected Trial inside the current Study and load its input
# variables and metrics before anything is rendered.
current_trial: Trial | None = None
_chosen_trial_name = st.session_state.selected_trial_name
if current_study and _chosen_trial_name and _chosen_trial_name in current_study.trials:
    current_trial = current_study.trials[_chosen_trial_name]
    current_trial.load_input_variables_cached()
    current_trial.load_metrics_cached()
# --- Main content: helpers ---
# Auto-play speed multipliers; one step advances every 1/speed seconds.
_SPEED_OPTIONS = [0.5, 1.0, 2.0, 4.0]


def _render_trial_header(study, trial):
    """Render the trial title row with its refresh and placeholder buttons."""
    title_cols = st.columns([3, 1, 0.5])
    with title_cols[0]:
        st.markdown(f"## {trial.name}")
        st.caption(f"属于 Study: {study.name}")
    with title_cols[1]:
        if st.button("刷新 Trial 数据 (Refresh Trial Data)", type="secondary"):
            trial.clear_cache()
            st.rerun()
    with title_cols[2]:
        st.button("...", help="更多选项 (More Options)", disabled=True)


def _collect_global_steps(trial):
    """Return the sorted union of ``global_step`` values over all metrics of *trial*.

    Always returns a list (possibly empty) so callers can index into it.
    """
    steps = set()
    for metric_name in trial.metrics_data.keys():
        df_metric = trial.get_metric_dataframe(metric_name)
        if (
            df_metric is not None
            and not df_metric.empty
            and "global_step" in df_metric.columns
        ):
            # NaN steps cannot be ordered or selected, so they are skipped.
            steps.update(df_metric["global_step"].dropna().tolist())
    return sorted(steps)


def _on_global_step_slider_change():
    """Copy the slider's new value into the shared step (runs before the rerun)."""
    st.session_state.shared_selected_global_step = st.session_state["global_step_slider"]


def _render_global_step_control(all_global_steps):
    """Render the step slider, play/pause, speed and reset controls.

    ``all_global_steps`` must be a non-empty sorted list. The shared step is
    first snapped onto it: an unset step defaults to the last step, and an
    unknown step moves to the closest valid one.
    """
    min_step, max_step = all_global_steps[0], all_global_steps[-1]

    selected = st.session_state.shared_selected_global_step
    if selected is None:
        selected = max_step
    elif selected not in all_global_steps:
        requested = selected
        selected = min(all_global_steps, key=lambda step: abs(step - requested))
    st.session_state.shared_selected_global_step = selected

    control_cols = st.columns([3, 1, 1, 1])
    with control_cols[0]:
        # Drive the keyed slider through its session-state entry rather than
        # `value=`. Depending on the Streamlit version, a keyed widget can keep
        # its own state across reruns. It would then overwrite step changes made
        # by auto-play, reset or chart clicks. The on_change callback copies user
        # drags back into the shared step.
        st.session_state["global_step_slider"] = selected
        st.select_slider(
            "选择全局步骤",
            options=all_global_steps,
            format_func=lambda x: f"Step {x}",
            key="global_step_slider",
            on_change=_on_global_step_slider_change,
        )
    with control_cols[1]:
        if st.session_state.is_auto_playing:
            if st.button("⏸️ 暂停", type="primary", use_container_width=True):
                st.session_state.is_auto_playing = False
                st.rerun()
        elif st.button("▶️ 播放", type="primary", use_container_width=True):
            st.session_state.is_auto_playing = True
            st.rerun()
    with control_cols[2]:
        st.session_state.auto_play_speed = st.selectbox(
            "播放速度",
            options=_SPEED_OPTIONS,
            index=_SPEED_OPTIONS.index(st.session_state.auto_play_speed),
            format_func=lambda x: f"{x}x",
            key="speed_selector",
        )
    with control_cols[3]:
        if st.button("🔄 重置", use_container_width=True):
            st.session_state.shared_selected_global_step = min_step
            st.session_state.is_auto_playing = False
            st.rerun()

    st.info(
        f"当前选中步骤: **{st.session_state.shared_selected_global_step}** / {max_step}"
    )


def _latest_value_and_delta(df_metric, track, current_step):
    """Find the value shown for *track* at *current_step* and its change.

    Takes the latest row whose ``global_step`` is <= *current_step*. A track
    not logged at exactly that step therefore shows its most recent earlier
    value. The delta is measured against the latest row at a strictly earlier
    step. When several rows share a step, the first one in frame order is used.
    *track* None means "no track column: use every row".

    Returns:
        ``(value, step_found, delta)``. ``value`` and ``step_found`` are None
        when no row exists at or before *current_step*. ``delta`` is None when
        there is no earlier step to compare against.
    """
    if current_step is None or "global_step" not in df_metric.columns:
        return None, None, None
    rows = df_metric if track is None else df_metric[df_metric["track"] == track]
    rows = rows[rows["global_step"] <= current_step]
    if rows.empty:
        return None, None, None
    steps = rows["global_step"]
    step_found = steps.max()
    value = rows.loc[steps == step_found, "value"].iloc[0]
    earlier = rows[steps < step_found]
    if earlier.empty:
        return value, step_found, None
    prev_step = earlier["global_step"].max()
    prev_value = earlier.loc[earlier["global_step"] == prev_step, "value"].iloc[0]
    return value, step_found, value - prev_value


def _render_metric_values(df_metric, current_step):
    """Render one ``st.metric`` per track showing its value at *current_step*."""
    tracks = (
        list(df_metric["track"].unique()) if "track" in df_metric.columns else [None]
    )
    # Several tracks share a row of columns. A single value goes into a plain
    # container: the bare `st` module is not a context manager and cannot be
    # used with `with`.
    slots = st.columns(len(tracks)) if len(tracks) > 1 else [st.container()]
    for slot, track in zip(slots, tracks):
        value, step_found, delta = _latest_value_and_delta(
            df_metric, track, current_step
        )
        with slot:
            if value is None:
                track_label = track if track is not None else "数据"
                st.metric(label=f"{track_label}", value="无数据", delta=None)
                continue
            if track is None:
                label = f"当前值 (Step {step_found})"
            elif step_found != current_step:
                # Value carried forward from an earlier step: say which one.
                label = f"{track} (Step {step_found})"
            else:
                label = f"{track}"
            st.metric(
                label=label,
                value=f"{value:.4f}",
                delta=f"{delta:.4f}" if delta is not None else None,
            )


def _build_metric_figure(df_metric, metric_name, current_step):
    """Build the line chart for one metric: one trace per track plus a step marker."""
    fig = go.Figure()
    if "track" in df_metric.columns:
        tracks = df_metric["track"].unique()
        palette = px.colors.qualitative.Set1
        for k, track in enumerate(tracks):
            track_data = df_metric[df_metric["track"] == track]
            color = palette[k % len(palette)]  # cycle when tracks outnumber colours
            fig.add_trace(
                go.Scatter(
                    x=track_data["global_step"],
                    y=track_data["value"],
                    mode="lines+markers",
                    name=track,
                    line=dict(color=color),
                    marker=dict(size=6, color=color, line=dict(width=1, color="white")),
                    hovertemplate="<b>%{fullData.name}</b><br>"
                    + "Global Step: %{x}<br>"
                    + "Value: %{y}<br>"
                    + "<extra></extra>",
                )
            )
        show_legend = len(tracks) > 1
    else:
        # No track column: a single line for the whole metric.
        fig.add_trace(
            go.Scatter(
                x=df_metric["global_step"],
                y=df_metric["value"],
                mode="lines+markers",
                name=metric_name,
                marker=dict(size=6, line=dict(width=1, color="white")),
                hovertemplate="Global Step: %{x}<br>"
                + "Value: %{y}<br>"
                + "<extra></extra>",
            )
        )
        show_legend = False

    # Vertical marker at the step shared across all charts.
    if current_step is not None:
        fig.add_vline(
            x=current_step,
            line_width=2,
            line_dash="solid",
            line_color="firebrick",
            opacity=0.9,
        )
    fig.update_layout(
        title=None,
        xaxis_title="全局步骤 (Global Step)",
        yaxis_title=metric_name,
        height=400,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=show_legend,
        hovermode="closest",
    )
    return fig


def _handle_chart_click(event, chart_key):
    """Move the shared step to the point clicked on a chart and stop auto-play.

    The last applied click is remembered per chart. This matters if Streamlit
    keeps a chart's selection across reruns: without the marker, a stale
    selection would keep snapping the shared step back after the slider,
    reset or auto-play moved it. An empty selection clears the marker, so the
    same point can be clicked again later.
    """
    marker_key = f"{chart_key}__last_click"
    points = []
    if event and "selection" in event:
        selection = event["selection"]
        if "points" in selection:
            points = selection["points"]
    if not points or points[0]["x"] is None:
        if marker_key in st.session_state:
            del st.session_state[marker_key]
        return
    new_step = int(points[0]["x"])
    if st.session_state.get(marker_key) == new_step:
        return  # this selection was already applied on an earlier run
    st.session_state[marker_key] = new_step
    if st.session_state.get("shared_selected_global_step") != new_step:
        st.session_state.shared_selected_global_step = new_step
        st.session_state.is_auto_playing = False  # a manual pick stops playback
        st.rerun()


def _render_metric_card(study, trial, metric_name):
    """Render one bordered card: metric name, per-track values and the chart."""
    df_metric = trial.get_metric_dataframe(metric_name)
    if df_metric is None or df_metric.empty:
        st.warning(f"指标 '{metric_name}' 数据不完整或缺失。")
        return
    with st.container(border=True):
        st.subheader(metric_name)
        current_step = st.session_state.shared_selected_global_step
        try:
            _render_metric_values(df_metric, current_step)
        except Exception as e:  # UI boundary: report and keep rendering the chart
            st.warning(f"计算指标值时出错: {e}")
        try:
            fig = _build_metric_figure(df_metric, metric_name, current_step)
            chart_key = f"chart_metric_{study.name}_{trial.name}_{metric_name}"
            event = st.plotly_chart(
                fig,
                use_container_width=True,
                key=chart_key,
                on_select="rerun",
            )
            _handle_chart_click(event, chart_key)
        except Exception as e:  # UI boundary: fall back to the raw table
            st.error(f"为指标 '{metric_name}' 生成图表时出错: {e}")
            st.dataframe(df_metric)


def _advance_auto_play(all_global_steps):
    """After the page has rendered, wait and advance auto-play by one step.

    Only raises ``auto_play_needs_rerun``; the rerun itself happens at the very
    end of the script. Running this after rendering keeps the slider, the info
    line and the charts showing the same step during playback.
    """
    if not (st.session_state.is_auto_playing and all_global_steps):
        return
    current_index = all_global_steps.index(
        st.session_state.shared_selected_global_step
    )
    if current_index < len(all_global_steps) - 1:
        time.sleep(1.0 / st.session_state.auto_play_speed)
        st.session_state.shared_selected_global_step = all_global_steps[
            current_index + 1
        ]
    else:
        # Reached the last step: stop, and rerun so the play button reappears.
        st.session_state.is_auto_playing = False
    st.session_state.auto_play_needs_rerun = True


# --- Main content ---
if current_study and current_trial:
    _render_trial_header(current_study, current_trial)

    # Global step controller shared by every chart of this trial.
    all_global_steps = []
    if current_trial.metrics_data:
        st.markdown("### 全局步骤控制 (Global Step Control)")
        all_global_steps = _collect_global_steps(current_trial)
        if all_global_steps:
            _render_global_step_control(all_global_steps)
        st.markdown("---")

    tab_titles = [
        "图表 (Charts)",
        "参数 (Parameters)",
        "系统 (System)",
        "日志 (Logs)",
        "环境 (Environment)",
    ]
    tab_charts, tab_params, tab_system, tab_logs, tab_env = st.tabs(tab_titles)

    with tab_charts:
        st.header("指标图表 (Metrics Charts)")
        st.markdown("---")
        if not current_trial.metrics_data:
            st.info("当前 Trial 没有可显示的指标数据。")
        else:
            num_metrics = len(current_trial.metrics_data)
            cols_per_row = st.slider(
                "每行图表数量 (Charts per row)",
                1,
                4,
                min(2, num_metrics),
                key=f"cols_slider_{current_study.name}_{current_trial.name}",
            )
            metric_names = sorted(current_trial.metrics_data.keys())
            for row_start in range(0, num_metrics, cols_per_row):
                row_metrics = metric_names[row_start : row_start + cols_per_row]
                for chart_col, metric_name in zip(st.columns(cols_per_row), row_metrics):
                    with chart_col:
                        _render_metric_card(current_study, current_trial, metric_name)

    with tab_params:
        st.header("输入参数 (Input Parameters)")
        if current_trial.input_variables:
            st.json(current_trial.input_variables)
        else:
            st.info("未找到 `input_variables.toml` 或文件为空。")

    # Tabs awaiting data from the flowillower API.
    for placeholder_tab, placeholder_title in [
        (tab_system, "系统监控"),
        (tab_logs, "日志"),
        (tab_env, "环境"),
    ]:
        with placeholder_tab:
            st.header(placeholder_title)
            st.info("此功能待您的 `flowillower` API 提供相关数据后实现。")

    _advance_auto_play(all_global_steps)
elif not st.session_state.selected_study_name:
    st.info("👈 请从顶部选择一个 Study 开始。")
elif not st.session_state.selected_trial_name:
    st.info("👈 请从侧边栏选择一个 Trial。")
else:
    st.info("请选择 Study 和 Trial 以查看数据。")
# --- Footer ---
st.markdown("---")
st.caption("柳暗花明 (flowillower) - 数据可视化App")

# Auto-play only raises a flag while the page renders. The rerun that shows
# the next step is triggered here, once everything above has been drawn.
_rerun_for_auto_play = st.session_state.get("auto_play_needs_rerun", False)
if _rerun_for_auto_play:
    st.session_state.auto_play_needs_rerun = False
    st.rerun()