"""柳暗花明 (flowillower): Streamlit dashboard for browsing Studies, Trials and metrics."""

import time
from pathlib import Path

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

# --- Page Configuration ---
# Must be the first Streamlit command executed by the script.
st.set_page_config(
    layout="wide",
    page_title="柳暗花明 (flowillower)",
    page_icon=":sunrise_over_mountains:",
    initial_sidebar_state="expanded",
)

# --- Logo ---
st.logo("logo.png", icon_image="logo.png")

# Import the refactored project modules. They are imported after the page is
# configured so that any Streamlit call they might make at import time cannot
# precede set_page_config.
try:
    from infra import DATA_ROOT_PATH, AppMode
    from data_models import Study, Trial  # used for type annotations later on
    from data_loader import discover_studies_cached, ensure_data_directory_exists
    from theme_selector import render_theme_selector
except ImportError as e:
    st.error(
        f"导入模块失败,请确保 infra.py, data_models.py, data_loader.py, theme_selector.py 文件存在于正确的位置: {e}"
    )
    st.stop()

# --- Application state ---
# Default value for every session_state key the app relies on. A key is only
# initialised when it is absent, so user choices survive Streamlit reruns.
_SESSION_DEFAULTS = {
    "selected_study_name": None,
    "selected_trial_name": None,
    "app_mode": AppMode.VIEWING,
    # global_step highlighted across all charts (None = not chosen yet).
    "shared_selected_global_step": None,
    # Auto-play state.
    "is_auto_playing": False,
    "auto_play_speed": 1.0,
    "auto_play_needs_rerun": False,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value

# --- UI Rendering ---
# --- Header ---
# Columns: title | study selector | current study | add | settings | user | theme selector
header_cols = st.columns([2, 3, 1.5, 0.5, 0.5, 0.5, 1])
with header_cols[0]:
    st.markdown("## 柳暗花明")
    st.caption("flowillower")

ensure_data_directory_exists(DATA_ROOT_PATH)
all_study_objects = discover_studies_cached(DATA_ROOT_PATH)
study_names = list(all_study_objects.keys())
# --- Header: study selection ---
if study_names:
    with header_cols[1]:
        # Fall back to the first study when the remembered one is gone.
        if st.session_state.selected_study_name not in study_names:
            st.session_state.selected_study_name = study_names[0]
        selected_study_name_from_ui = st.selectbox(
            "选择 Study (Select Study)",
            study_names,
            index=study_names.index(st.session_state.selected_study_name),
            label_visibility="collapsed",
            key="study_selector_main_ui",
        )
        if selected_study_name_from_ui != st.session_state.selected_study_name:
            st.session_state.selected_study_name = selected_study_name_from_ui
            st.session_state.selected_trial_name = None
            # A different study invalidates the highlighted step.
            st.session_state.shared_selected_global_step = None
            st.rerun()
    with header_cols[2]:
        if st.session_state.selected_study_name:
            st.write(f"当前 Study: **{st.session_state.selected_study_name}**")
else:
    st.warning(
        f"在 {DATA_ROOT_PATH} 未找到任何 Study。请确保您的数据结构正确或使用 flowillower API 开始记录实验。"
    )
    with header_cols[1]:
        st.info("没有可用的 Study。")

# Placeholder buttons (not implemented yet).
with header_cols[3]:
    st.button("➕", help="添加 (Add)", disabled=True)
with header_cols[4]:
    st.button("⚙️", help="设置 (Settings)", disabled=True)
with header_cols[5]:
    st.button("👤", help="用户 (User)", disabled=True)
with header_cols[6]:
    # Theme selector column.
    with st.container():
        render_theme_selector()

st.markdown("---")

# --- Sidebar ---
current_study: Study | None = None
if (
    st.session_state.selected_study_name
    and st.session_state.selected_study_name in all_study_objects
):
    current_study = all_study_objects[st.session_state.selected_study_name]
    if not current_study.trials:
        current_study.discover_trials_cached()

trial_names = list(current_study.trials.keys()) if current_study else []

with st.sidebar:
    st.markdown("### Study")
    if current_study:
        st.markdown(f"##### {current_study.name}")
        if st.button("刷新 Study 数据 (Refresh Study Data)", use_container_width=True):
            current_study.clear_cache()
            st.rerun()
        if st.button("概览 (Overview)", use_container_width=True, disabled=True):
            st.toast("功能待实现")
        if st.button(
            "图表对比视图 (Chart Comparison View)",
            use_container_width=True,
            disabled=True,
        ):
            st.toast("功能待实现")
    else:
        st.markdown("未选择 Study")

    st.markdown("---")
    st.markdown("### Trial")
    if current_study and trial_names:
        # Fall back to the first trial when the remembered one is gone.
        if st.session_state.selected_trial_name not in trial_names:
            st.session_state.selected_trial_name = trial_names[0]
        selected_trial_name_from_ui = st.radio(
            "选择 Trial (Select Trial)",
            trial_names,
            index=trial_names.index(st.session_state.selected_trial_name),
            label_visibility="collapsed",
            key="trial_selector_sidebar_ui",
        )
        if selected_trial_name_from_ui != st.session_state.selected_trial_name:
            st.session_state.selected_trial_name = selected_trial_name_from_ui
            # A different trial invalidates the highlighted step.
            st.session_state.shared_selected_global_step = None
            st.rerun()
        if st.session_state.selected_trial_name:
            st.markdown(f"当前选择: **{st.session_state.selected_trial_name}**")
    elif current_study:
        st.info(f"Study '{current_study.name}' 中没有 Trial。")
    else:
        st.info("请先选择一个 Study。")

    st.markdown("---")
    if st.button("⚙️ App 设置 (App Settings)", use_container_width=True, disabled=True):
        st.toast("功能待实现")


# --- Helpers for the metrics tab ---


def _lookup_track_value(df_metric: pd.DataFrame, track, current_step):
    """Return ``(value, step_found, delta)`` for one track at ``current_step``.

    The value is taken from the latest logged ``global_step`` that is
    ``<= current_step``, so a track not logged at exactly this step shows its
    most recent earlier value. When several rows share that step, the first
    one in DataFrame order is used. ``delta`` is that value minus the value at
    the previous logged step of the same track, or None when there is no
    earlier step. All three are None when the track has no data at or before
    ``current_step``, when ``current_step`` is None, or when the frame has no
    ``global_step`` column.

    ``track`` None means "the whole frame" (no ``track`` column).
    """
    if current_step is None or "global_step" not in df_metric.columns:
        return None, None, None
    rows = df_metric if track is None else df_metric[df_metric["track"] == track]
    rows = rows[rows["global_step"] <= current_step]
    if rows.empty:
        return None, None, None

    step_found = rows["global_step"].max()
    value = rows.loc[rows["global_step"] == step_found, "value"].iloc[0]

    delta = None
    earlier = rows[rows["global_step"] < step_found]
    if not earlier.empty:
        prev_step = earlier["global_step"].max()
        prev_value = earlier.loc[earlier["global_step"] == prev_step, "value"].iloc[0]
        delta = value - prev_value
    return value, step_found, delta


def _render_track_metric(track, current_step, value, step_found, delta) -> None:
    """Render one ``st.metric`` tile for ``track`` into the active container.

    The label mentions the step the value came from whenever that differs from
    the selected step (i.e. the value was carried forward from earlier).
    """
    if value is None:
        # No data at or before the selected step.
        track_label = track if track is not None else "数据"
        st.metric(label=f"{track_label}", value="无数据", delta=None)
        return

    if track is not None:
        if step_found != current_step:
            label = f"{track} (Step {step_found})"
        else:
            label = f"{track}"
    elif step_found != current_step:
        label = f"当前值 (Step {step_found})"
    else:
        label = f"当前值 (Step {current_step})"
    st.metric(
        label=label,
        value=f"{value:.4f}",
        delta=f"{delta:.4f}" if delta is not None else None,
    )


def _build_metric_figure(
    df_metric: pd.DataFrame, metric_name: str, highlight_step
) -> go.Figure:
    """Build the line chart for one metric, one trace per ``track`` if present.

    When ``highlight_step`` is not None, a vertical line marks that step, so
    the step shared by all charts is visible in each of them.
    """
    fig = go.Figure()
    if "track" in df_metric.columns:
        tracks = df_metric["track"].unique()
        colors = px.colors.qualitative.Set1[: len(tracks)]
        for k, track in enumerate(tracks):
            track_data = df_metric[df_metric["track"] == track]
            color = colors[k % len(colors)]
            fig.add_trace(
                go.Scatter(
                    x=track_data["global_step"],
                    y=track_data["value"],
                    mode="lines+markers",
                    name=track,
                    line=dict(color=color),
                    marker=dict(
                        size=6,
                        color=color,
                        line=dict(width=1, color="white"),
                    ),
                    customdata=track_data[["global_step", "value", "track"]],
                    hovertemplate="%{fullData.name}<br>"
                    + "Global Step: %{x}<br>"
                    + "Value: %{y}<br>"
                    + "<extra></extra>",
                )
            )
    else:
        tracks = []
        # No track column: draw a single line.
        fig.add_trace(
            go.Scatter(
                x=df_metric["global_step"],
                y=df_metric["value"],
                mode="lines+markers",
                name=metric_name,
                marker=dict(size=6, line=dict(width=1, color="white")),
                customdata=df_metric[["global_step", "value"]],
                hovertemplate="Global Step: %{x}<br>"
                + "Value: %{y}<br>"
                + "<extra></extra>",
            )
        )

    if highlight_step is not None:
        fig.add_vline(
            x=highlight_step,
            line_width=2,
            line_dash="solid",
            line_color="firebrick",
            opacity=0.9,
        )

    fig.update_layout(
        title=None,
        xaxis_title="全局步骤 (Global Step)",
        yaxis_title=metric_name,
        height=400,
        margin=dict(l=0, r=0, t=0, b=0),
        showlegend=len(tracks) > 1,
        hovermode="closest",
    )
    return fig


def _clicked_global_step(chart_event):
    """Return the global_step (x) of the first selected chart point, or None.

    Raises ValueError/TypeError if the clicked x cannot be converted to int.
    """
    if not chart_event or "selection" not in chart_event:
        return None
    selection = chart_event["selection"]
    if "points" not in selection or not selection["points"]:
        return None
    clicked_x = selection["points"][0]["x"]
    return None if clicked_x is None else int(clicked_x)


# --- Main Content Area ---
current_trial: Trial | None = None
if (
    current_study
    and st.session_state.selected_trial_name
    and st.session_state.selected_trial_name in current_study.trials
):
    current_trial = current_study.trials[st.session_state.selected_trial_name]
    current_trial.load_input_variables_cached()
    current_trial.load_metrics_cached()

# Auto-play bookkeeping for THIS run only. Local flags, unlike a session_state
# flag, cannot go stale when a run is interrupted by a user interaction.
pending_auto_play_step = None  # step to advance to once this run is rendered
auto_play_rerun_needed = False

if current_study and current_trial:
    main_title_cols = st.columns([3, 1, 0.5])
    with main_title_cols[0]:
        st.markdown(f"## {current_trial.name}")
        st.caption(f"属于 Study: {current_study.name}")
    with main_title_cols[1]:
        if st.button("刷新 Trial 数据 (Refresh Trial Data)", type="secondary"):
            current_trial.clear_cache()
            st.rerun()
    with main_title_cols[2]:
        st.button("...", help="更多选项 (More Options)", disabled=True)

    # --- Global step controller ---
    if current_trial.metrics_data:
        st.markdown("### 全局步骤控制 (Global Step Control)")

        # Union of the global steps logged by every metric of this trial.
        step_set = set()
        for metric_name in current_trial.metrics_data.keys():
            df_metric = current_trial.get_metric_dataframe(metric_name)
            if (
                df_metric is not None
                and not df_metric.empty
                and "global_step" in df_metric.columns
            ):
                step_set.update(df_metric["global_step"].tolist())
        all_global_steps = sorted(step_set)

        if all_global_steps:
            min_step, max_step = all_global_steps[0], all_global_steps[-1]
            control_cols = st.columns([3, 1, 1, 1])

            with control_cols[0]:
                # Default to the last step.
                if st.session_state.shared_selected_global_step is None:
                    st.session_state.shared_selected_global_step = max_step
                # Snap an out-of-range selection to the closest valid step.
                if st.session_state.shared_selected_global_step not in step_set:
                    target_step = st.session_state.shared_selected_global_step
                    st.session_state.shared_selected_global_step = min(
                        all_global_steps, key=lambda s: abs(s - target_step)
                    )
                selected_step = st.select_slider(
                    "选择全局步骤",
                    options=all_global_steps,
                    value=st.session_state.shared_selected_global_step,
                    format_func=lambda x: f"Step {x}",
                    key="global_step_slider",
                )
                if selected_step != st.session_state.shared_selected_global_step:
                    st.session_state.shared_selected_global_step = selected_step
                    st.rerun()

            with control_cols[1]:
                # Play / pause toggle.
                if st.session_state.is_auto_playing:
                    if st.button("⏸️ 暂停", type="primary", use_container_width=True):
                        st.session_state.is_auto_playing = False
                        st.rerun()
                elif st.button("▶️ 播放", type="primary", use_container_width=True):
                    st.session_state.is_auto_playing = True
                    st.rerun()

            with control_cols[2]:
                speed_options = [0.5, 1.0, 2.0, 4.0]
                speed = st.selectbox(
                    "播放速度",
                    options=speed_options,
                    index=speed_options.index(st.session_state.auto_play_speed),
                    format_func=lambda x: f"{x}x",
                    key="speed_selector",
                )
                if speed != st.session_state.auto_play_speed:
                    st.session_state.auto_play_speed = speed

            with control_cols[3]:
                if st.button("🔄 重置", use_container_width=True):
                    st.session_state.shared_selected_global_step = min_step
                    st.session_state.is_auto_playing = False
                    st.rerun()

            # Schedule the next auto-play frame. The wait and the advance
            # happen at the end of the script, so the whole page (slider,
            # info and charts) is rendered for the same step before pausing.
            if st.session_state.is_auto_playing:
                current_index = all_global_steps.index(
                    st.session_state.shared_selected_global_step
                )
                if current_index < len(all_global_steps) - 1:
                    pending_auto_play_step = all_global_steps[current_index + 1]
                else:
                    # Reached the end: stop, and rerun to refresh the button.
                    st.session_state.is_auto_playing = False
                auto_play_rerun_needed = True

            st.info(
                f"当前选中步骤: **{st.session_state.shared_selected_global_step}** / {max_step}"
            )

        st.markdown("---")

    tab_titles = [
        "图表 (Charts)",
        "参数 (Parameters)",
        "系统 (System)",
        "日志 (Logs)",
        "环境 (Environment)",
    ]
    tab_charts, tab_params, tab_system, tab_logs, tab_env = st.tabs(tab_titles)

    with tab_charts:
        st.header("指标图表 (Metrics Charts)")
        st.markdown("---")
        if not current_trial.metrics_data:
            st.info("当前 Trial 没有可显示的指标数据。")
        else:
            metric_names = sorted(current_trial.metrics_data.keys())
            num_metrics = len(metric_names)
            cols_per_row = st.slider(
                "每行图表数量 (Charts per row)",
                1,
                4,
                min(2, num_metrics),
                key=f"cols_slider_{current_study.name}_{current_trial.name}",
            )
            # The step highlighted in every chart (None if no step is known).
            current_step = st.session_state.shared_selected_global_step

            for row_start in range(0, num_metrics, cols_per_row):
                chart_cols = st.columns(cols_per_row)
                row_metrics = metric_names[row_start : row_start + cols_per_row]
                for chart_col, metric_name in zip(chart_cols, row_metrics):
                    with chart_col:
                        df_metric = current_trial.get_metric_dataframe(metric_name)
                        if df_metric is None or df_metric.empty:
                            st.warning(f"指标 '{metric_name}' 数据不完整或缺失。")
                            continue

                        with st.container(border=True):
                            st.subheader(metric_name)

                            # Metric tiles: current value and change per track.
                            try:
                                all_tracks = (
                                    df_metric["track"].unique()
                                    if "track" in df_metric.columns
                                    else [None]
                                )
                                # One column per track; a single track gets a
                                # plain container so `with` works either way.
                                metric_cols = (
                                    st.columns(len(all_tracks))
                                    if len(all_tracks) > 1
                                    else [st.container()]
                                )
                                for metric_col, track in zip(metric_cols, all_tracks):
                                    value, step_found, delta = _lookup_track_value(
                                        df_metric, track, current_step
                                    )
                                    with metric_col:
                                        _render_track_metric(
                                            track, current_step, value, step_found, delta
                                        )
                            except Exception as e:
                                # Report and keep going: one bad metric must
                                # not take down the whole page.
                                st.warning(f"计算指标值时出错: {e}")

                            # Chart with click-to-select-step support.
                            new_step = None
                            try:
                                fig = _build_metric_figure(
                                    df_metric, metric_name, current_step
                                )
                                chart_event = st.plotly_chart(
                                    fig,
                                    use_container_width=True,
                                    key=f"chart_metric_{current_study.name}_{current_trial.name}_{metric_name}",
                                    on_select="rerun",
                                )
                                new_step = _clicked_global_step(chart_event)
                            except Exception as e:
                                st.error(f"为指标 '{metric_name}' 生成图表时出错: {e}")
                                st.dataframe(df_metric)

                            if (
                                new_step is not None
                                and st.session_state.get("shared_selected_global_step")
                                != new_step
                            ):
                                st.session_state.shared_selected_global_step = new_step
                                # Clicking a chart takes manual control: stop auto-play.
                                st.session_state.is_auto_playing = False
                                st.rerun()

    with tab_params:
        st.header("输入参数 (Input Parameters)")
        if current_trial.input_variables:
            st.json(current_trial.input_variables)
        else:
            st.info("未找到 `input_variables.toml` 或文件为空。")

    for tab_content, name in [
        (tab_system, "系统监控"),
        (tab_logs, "日志"),
        (tab_env, "环境"),
    ]:
        with tab_content:
            st.header(name)
            st.info("此功能待您的 `flowillower` API 提供相关数据后实现。")

elif not st.session_state.selected_study_name:
    st.info("👈 请从顶部选择一个 Study 开始。")
elif not st.session_state.selected_trial_name:
    st.info("👈 请从侧边栏选择一个 Trial。")
else:
    st.info("请选择 Study 和 Trial 以查看数据。")

st.markdown("---")
st.caption("柳暗花明 (flowillower) - 数据可视化App")

# Drive auto-play from the very end of the run, after everything has been
# rendered: hold the current frame, advance one step, then rerun.
if auto_play_rerun_needed:
    if pending_auto_play_step is not None:
        time.sleep(1.0 / st.session_state.auto_play_speed)
        st.session_state.shared_selected_global_step = pending_auto_play_step
    st.rerun()