YeCanming commited on
Commit
1e63386
·
1 Parent(s): a4f451c

feat: 初步实现

Browse files
Files changed (6) hide show
  1. .gitignore +154 -1
  2. .streamlit/config.toml +6 -0
  3. src/data_loader.py +188 -0
  4. src/data_models.py +94 -0
  5. src/streamlit_app.py +234 -93
  6. src/utils.py +16 -0
.gitignore CHANGED
@@ -1 +1,154 @@
1
- data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data
2
+ deprecated
3
+ _docs/
4
+ _proc/
5
+
6
+ *.bak
7
+ .gitattributes
8
+ .last_checked
9
+ .gitconfig
10
+ *.bak
11
+ *.log
12
+ *~
13
+ ~*
14
+ _tmp*
15
+ tmp*
16
+ tags
17
+ *.pkg
18
+
19
+ # Byte-compiled / optimized / DLL files
20
+ __pycache__/
21
+ *.py[cod]
22
+ *$py.class
23
+
24
+ # C extensions
25
+ *.so
26
+
27
+ # Distribution / packaging
28
+ .Python
29
+ env/
30
+ build/
31
+ conda/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ *.egg-info/
44
+ .installed.cfg
45
+ *.egg
46
+
47
+ # PyInstaller
48
+ # Usually these files are written by a python script from a template
49
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
50
+ *.manifest
51
+ *.spec
52
+
53
+ # Installer logs
54
+ pip-log.txt
55
+ pip-delete-this-directory.txt
56
+
57
+ # Unit test / coverage reports
58
+ htmlcov/
59
+ .tox/
60
+ .coverage
61
+ .coverage.*
62
+ .cache
63
+ nosetests.xml
64
+ coverage.xml
65
+ *.cover
66
+ .hypothesis/
67
+
68
+ # Translations
69
+ *.mo
70
+ *.pot
71
+
72
+ # Django stuff:
73
+ *.log
74
+ local_settings.py
75
+
76
+ # Flask stuff:
77
+ instance/
78
+ .webassets-cache
79
+
80
+ # Scrapy stuff:
81
+ .scrapy
82
+
83
+ # Sphinx documentation
84
+ docs/_build/
85
+
86
+ # PyBuilder
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # pyenv
93
+ .python-version
94
+
95
+ # celery beat schedule file
96
+ celerybeat-schedule
97
+
98
+ # SageMath parsed files
99
+ *.sage.py
100
+
101
+ # dotenv
102
+ .env
103
+
104
+ # virtualenv
105
+ .venv
106
+ venv/
107
+ ENV/
108
+
109
+ # Spyder project settings
110
+ .spyderproject
111
+ .spyproject
112
+
113
+ # Rope project settings
114
+ .ropeproject
115
+
116
+ # mkdocs documentation
117
+ /site
118
+
119
+ # mypy
120
+ .mypy_cache/
121
+
122
+ .vscode
123
+ *.swp
124
+
125
+ # osx generated files
126
+ .DS_Store
127
+ .DS_Store?
128
+ .Trashes
129
+ ehthumbs.db
130
+ Thumbs.db
131
+ .idea
132
+
133
+ # pytest
134
+ .pytest_cache
135
+
136
+ # tools/trust-doc-nbs
137
+ docs_src/.last_checked
138
+
139
+ # symlinks to fastai
140
+ docs_src/fastai
141
+ tools/fastai
142
+
143
+ # link checker
144
+ checklink/cookies.txt
145
+
146
+ # .gitconfig is now autogenerated
147
+ .gitconfig
148
+
149
+ # Quarto installer
150
+ .deb
151
+ .pkg
152
+
153
+ # Quarto
154
+ .quarto
.streamlit/config.toml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [theme]
2
+ primaryColor="#F39C12"
3
+ backgroundColor="#2E86C1"
4
+ secondaryBackgroundColor="#7F8C8D"
5
+ textColor="#FFFFFF"
6
+ font="monospace"
src/data_loader.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_loader.py
2
+ from pathlib import Path
3
+ from typing import Dict, List, Any
4
+ import pandas as pd
5
+ import tomli
6
+ import streamlit as st
7
+ from functools import lru_cache # For non-Streamlit specific caching if needed
8
+
9
+ # Assuming utils.py is in the same directory
10
+ from utils import DATA_ROOT_PATH # Used for ensuring directory exists
11
+
12
+ # --- Cache Clearing Functions ---
13
+ # These are more specific cache clearing functions that can be called by model methods.
14
+
15
def clear_study_cache():
    """Invalidate the cached results of study discovery."""
    # st.cache_data-decorated functions expose .clear() to drop all entries.
    discover_studies_cached.clear()
    st.toast("所有 Study 发现缓存已清除。")
19
+
20
def clear_trial_cache():
    """Invalidate every trial-related loader cache.

    NOTE(review): this is coarse-grained — it wipes the caches for *all*
    trials, not just one; st.cache_data offers no per-argument invalidation.
    """
    for cached_loader in (
        load_input_variables_from_path,
        load_all_metrics_for_trial_path,
        discover_trials_from_path,
    ):
        cached_loader.clear()
    st.toast("所有 Trial 数据加载缓存已清除。")
27
+
28
def clear_specific_trial_metric_cache(trial_path: Path):
    """Invalidate cached metrics "for" *trial_path*.

    st.cache_data cannot clear entries for specific arguments, so this
    clears the whole metrics-loading cache; the trial name is used only in
    the toast message. (With functools.lru_cache the equivalent call would
    be .cache_clear().)
    """
    load_all_metrics_for_trial_path.clear()
    st.toast(f"Trial '{trial_path.name}' 的指标缓存已清除 (函数级别)。")
36
+
37
+
38
def clear_specific_trial_input_vars_cache(trial_path: Path):
    """Invalidate the cached input variables (whole-function cache scope)."""
    load_input_variables_from_path.clear()
    st.toast(f"Trial '{trial_path.name}' 的参数缓存已清除 (函数级别)。")
41
+
42
def clear_specific_study_trial_discovery_cache(study_path: Path):
    """Invalidate the cached trial-discovery results (whole-function scope)."""
    discover_trials_from_path.clear()
    st.toast(f"Study '{study_path.name}' 的 Trial 发现缓存已清除 (函数级别)。")
45
+
46
+
47
+ # --- Data Discovery and Loading Functions (Cached) ---
48
+
49
def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
    """Make sure the root data directory exists and is a directory.

    Creates it (with parents) when missing; halts the Streamlit script with
    an error when creation fails or when the path exists but is not a
    directory.

    Args:
        data_path: Root directory that holds study/trial data.
    """
    if data_path.is_dir():
        return  # Already usable.
    if data_path.exists():
        # Path exists but is a file (or other non-directory) — cannot proceed.
        st.error(f"路径 {data_path} 已存在但不是一个目录。")
        st.stop()
    try:
        data_path.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        # Narrowed from bare `Exception`: Path.mkdir failures raise OSError.
        st.error(f"创建数据目录 {data_path} 失败: {e}")
        st.stop()
    else:
        st.info(f"数据目录 {data_path} 已创建。")
61
+
62
+
63
@st.cache_data(ttl=3600)  # Cache for 1 hour, or adjust as needed
def discover_studies_cached(_data_root: Path) -> Dict[str, Any]:
    """Scan *_data_root* and return a mapping of study name -> Study object.

    The return type is hinted as ``Any`` to avoid a hard import-time
    dependency on data_models (which itself imports from this module);
    Study is imported locally inside the function instead. The leading
    underscore on the parameter keeps st.cache_data from hashing the Path.
    """
    # Local import keeps the module import graph acyclic.
    from data_models import Study

    if not _data_root.is_dir():
        return {}
    return {
        entry.name: Study(name=entry.name, path=entry)
        for entry in _data_root.iterdir()
        if entry.is_dir()
    }
83
+
84
+
85
@st.cache_data(ttl=3600)
def discover_trials_from_path(_study_path: Path) -> Dict[str, Path]:
    """Return a mapping of trial name -> trial directory under *_study_path*.

    Returns an empty dict when the study path is not a directory.
    """
    if not _study_path.is_dir():
        return {}
    return {child.name: child for child in _study_path.iterdir() if child.is_dir()}
95
+
96
+
97
@st.cache_data(ttl=3600)
def load_input_variables_from_path(_trial_path: Path) -> Dict[str, Any]:
    """Load ``input_variables.toml`` from a trial directory.

    Returns an empty dict when the file is absent, unreadable, or malformed.
    Errors are printed rather than surfaced via st.error so this cached
    function stays free of UI side effects.

    Args:
        _trial_path: Trial directory (underscore keeps st.cache_data from
            hashing the Path).
    """
    input_vars_file = _trial_path / "input_variables.toml"
    # EAFP: open directly instead of the original exists() pre-check,
    # which was race-prone and let OSError (e.g. permission denied) escape.
    try:
        with open(input_vars_file, "rb") as f:
            return tomli.load(f)
    except FileNotFoundError:
        # A trial without input variables is a normal, silent case.
        return {}
    except (tomli.TOMLDecodeError, OSError):
        print(f"Error parsing input_variables.toml at {_trial_path}")
        return {}
110
+
111
+
112
def _load_single_metric_toml(_toml_file_path: Path) -> pd.DataFrame:
    """Read one metrics TOML file and return its records as a DataFrame.

    An empty DataFrame is returned for missing, malformed, or empty files;
    errors are reported via print (no UI side effects).
    """
    if not _toml_file_path.exists():
        return pd.DataFrame()
    try:
        with open(_toml_file_path, "rb") as f:
            parsed = tomli.load(f)
        records = parsed.get("metrics", [])
        return pd.DataFrame(records) if records else pd.DataFrame()
    except tomli.TOMLDecodeError:
        print(f"Error parsing TOML file: {_toml_file_path.name}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Error loading {_toml_file_path.name}: {e}")
        return pd.DataFrame()
129
+
130
+
131
@st.cache_data(ttl=300)  # Cache metric data for 5 minutes
def load_all_metrics_for_trial_path(_trial_path: Path) -> Dict[str, pd.DataFrame]:
    """
    Load all metrics from all tracks in a trial.

    Scans ``<trial>/logs/scalar/metrics_<track>.toml`` files and returns a
    dict keyed by metric name (e.g. 'loss', 'accuracy'); each value is a
    DataFrame with columns 'global_step', 'value' and 'track', sorted by
    (track, global_step) for proper line plotting.
    """
    scalar_dir = _trial_path / "logs" / "scalar"
    if not scalar_dir.is_dir():
        return {}

    # Fix: collect per-metric frames in lists and concatenate ONCE at the
    # end. The original called pd.concat inside the loop, which re-copies
    # all accumulated rows on every file — O(n^2) in total rows.
    frames_by_metric: Dict[str, List[pd.DataFrame]] = {}

    for toml_file in scalar_dir.glob("metrics_*.toml"):
        track_name = toml_file.stem.replace("metrics_", "")
        df_track = _load_single_metric_toml(toml_file)

        # Skip files that failed to load or lack the step column.
        if df_track.empty or "global_step" not in df_track.columns:
            continue

        metric_columns = [c for c in df_track.columns if c != "global_step"]
        for metric_col_name in metric_columns:
            try:
                frame = df_track[["global_step", metric_col_name]].copy()
                frame.rename(columns={metric_col_name: "value"}, inplace=True)
                frame["track"] = track_name
                # Non-numeric entries become NaN and are dropped.
                frame["value"] = pd.to_numeric(frame["value"], errors="coerce")
                frame.dropna(subset=["value"], inplace=True)
                if frame.empty:
                    continue
                frames_by_metric.setdefault(metric_col_name, []).append(frame)
            except Exception as e:
                print(f"Error processing metric '{metric_col_name}' from file '{toml_file.name}': {e}")
                continue

    # Single concat per metric, then sort for plotting.
    return {
        metric_name: pd.concat(frames, ignore_index=True)
        .sort_values(by=["track", "global_step"])
        .reset_index(drop=True)
        for metric_name, frames in frames_by_metric.items()
    }
src/data_models.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_models.py
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional, Any
5
+ import pandas as pd
6
+ import streamlit as st # For caching
7
+
8
+ # Import from data_loader, assuming it's in the same directory
9
+ # We'll define these functions in data_loader.py
10
+ # To avoid circular imports, data_loader functions won't import data_models directly for type hints if possible,
11
+ # or use string type hints.
12
+
13
+ # Forward declaration for type hint if data_loader needs Study/Trial
14
+ # class Study: pass
15
+ # class Trial: pass
16
+
17
+ from data_loader import (
18
+ load_input_variables_from_path,
19
+ load_all_metrics_for_trial_path,
20
+ discover_trials_from_path,
21
+ clear_trial_cache as clear_trial_loader_cache,
22
+ clear_study_cache as clear_study_loader_cache,
23
+ clear_specific_trial_metric_cache,
24
+ clear_specific_trial_input_vars_cache,
25
+ clear_specific_study_trial_discovery_cache
26
+ )
27
+
28
+
29
@dataclass
class Trial:
    """A single experiment run inside a Study.

    Attributes:
        name: Directory name of the trial.
        path: Filesystem path of the trial directory.
        study_name: Name of the parent study.
        input_variables: Parsed input_variables.toml contents (lazy-loaded).
        metrics_data: metric name -> DataFrame('global_step','value','track')
            (lazy-loaded).
    """
    name: str
    path: Path
    study_name: str  # To know its parent study
    input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
    metrics_data: Dict[str, pd.DataFrame] = field(default_factory=dict, repr=False)

    # Caching lives on the loader functions (st.cache_data); the methods
    # below only populate instance attributes from those cached loaders.
    # (The original no-op __post_init__ was removed.)

    def load_input_variables_cached(self):
        """Load (or return already-loaded) input variables for this trial."""
        if not self.input_variables:  # Load only if not already populated
            self.input_variables = load_input_variables_from_path(self.path)
        return self.input_variables

    def load_metrics_cached(self):
        """Load (or return already-loaded) metrics data for this trial."""
        if not self.metrics_data:  # Load only if not already populated
            self.metrics_data = load_all_metrics_for_trial_path(self.path)
        return self.metrics_data

    def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
        """Return the all-tracks DataFrame for *metric_name*, or None.

        Fix: lazily loads metrics first, so callers no longer get None just
        because load_metrics_cached() was not invoked beforehand.
        """
        self.load_metrics_cached()
        return self.metrics_data.get(metric_name)

    def clear_cache(self):
        """Clear loader caches for this trial and reset in-memory data."""
        clear_specific_trial_metric_cache(self.path)
        clear_specific_trial_input_vars_cache(self.path)
        self.input_variables = {}
        self.metrics_data = {}
        st.success(f"Trial '{self.name}' 的缓存已清除。")
69
+
70
+
71
@dataclass
class Study:
    """A named collection of Trials stored under one directory."""
    name: str
    path: Path
    trials: Dict[str, Trial] = field(default_factory=dict, repr=False)

    def discover_trials_cached(self):
        """Populate (once) and return this study's name -> Trial mapping."""
        if not self.trials:
            # discover_trials_from_path is itself cached by st.cache_data.
            for trial_name, trial_path in discover_trials_from_path(self.path).items():
                self.trials[trial_name] = Trial(
                    name=trial_name, path=trial_path, study_name=self.name
                )
        return self.trials

    def get_trial(self, trial_name: str) -> Optional[Trial]:
        """Return the Trial called *trial_name*, or None when unknown."""
        return self.trials.get(trial_name)

    def clear_cache(self):
        """Clear discovery cache plus every contained trial's caches."""
        clear_specific_study_trial_discovery_cache(self.path)
        for trial in self.trials.values():
            trial.clear_cache()
        self.trials = {}  # Force re-discovery on next access
        st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")
src/streamlit_app.py CHANGED
@@ -1,97 +1,238 @@
1
  import streamlit as st
2
  from pathlib import Path
3
- import tomli
4
- import plotly.graph_objects as go
5
- from typing import List, Dict
6
-
7
- # === 配置路径 ===
8
- DATA_ROOT = Path(__file__).parent.parent / "data"
9
-
10
- # === 页面设置 ===
11
- st.set_page_config(page_title="Flowillower 🌸", layout="wide")
12
- st.markdown("<h1 style='text-align: center; color: #7b68ee;'>Flowillower 🌸</h1>", unsafe_allow_html=True)
13
- st.markdown("<p style='text-align: center;'>Elegant and Transparent Experiment Tracking</p>", unsafe_allow_html=True)
14
- st.divider()
15
-
16
- # === 加载所有 Study ===
17
- def get_all_studies(data_root: Path) -> List[Path]:
18
- return [p for p in data_root.iterdir() if p.is_dir()]
19
-
20
- studies = get_all_studies(DATA_ROOT)
21
- study_names = []
22
-
23
- for s in studies:
24
- meta_file = s / ".study_meta.toml"
25
- if meta_file.exists():
26
- try:
27
- meta = tomli.loads(meta_file.read_text())
28
- name = meta.get("display_name", s.name)
29
- except:
30
- name = s.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  else:
32
- name = s.name
33
- study_names.append((name, s))
34
-
35
- # === Study 选择器 ===
36
- study_name_list = [name for name, _ in study_names]
37
- selected_study_name = st.sidebar.selectbox("Select a Study", study_name_list)
38
- selected_study_path = dict(study_names)[selected_study_name]
39
-
40
- # === Study 概览区 ===
41
- st.subheader(f"📁 Study Overview: {selected_study_name}")
42
- st.markdown("Metadata, description, config summary can be shown here (to be implemented).")
43
- st.divider()
44
-
45
- # === 加载所有 Trial ===
46
- def get_all_trials(study_path: Path) -> List[Path]:
47
- return [p for p in study_path.iterdir() if p.is_dir()]
48
-
49
- trials = get_all_trials(selected_study_path)
50
- trial_names = [t.name for t in trials]
51
- selected_trial_name = st.sidebar.selectbox("🧪 Select Trial", trial_names)
52
- selected_trial_path = selected_study_path / selected_trial_name
53
-
54
- # === 加载 track 的 metric 数据 ===
55
- def load_track_data(trial_path: Path) -> Dict[str, List[Dict]]:
56
- track_dir = trial_path / "logs/scalar"
57
- result = {}
58
- if not track_dir.exists(): return result
59
-
60
- for f in track_dir.glob("metrics_*.toml"):
61
- track_name = f.stem.replace("metrics_", "")
62
- with open(f, "rb") as fp:
63
- try:
64
- data = tomli.load(fp).get("metrics", [])
65
- result[track_name] = data
66
- except:
67
- continue
68
- return result
69
-
70
- # === 展示 Trial 详细信息 ===
71
- st.subheader(f"📊 Trial: {selected_trial_name}")
72
- track_data = load_track_data(selected_trial_path)
73
-
74
- if not track_data:
75
- st.warning("No metrics data found for this trial.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  else:
77
- for track_name, records in track_data.items():
78
- st.markdown(f"### Track: {track_name}")
79
- if not records:
80
- st.info("No data in this track.")
81
- continue
82
-
83
- # 自动获取所有标量键(除 global_step)
84
- keys = [k for k in records[0].keys() if k != "global_step"]
85
- for key in keys:
86
- fig = go.Figure()
87
- fig.add_trace(go.Scatter(
88
- x=[r["global_step"] for r in records],
89
- y=[r.get(key, None) for r in records],
90
- mode='lines+markers',
91
- name=key
92
- ))
93
- fig.update_layout(title=f"{track_name.upper()} - {key}", template="plotly_dark", height=350)
94
- st.plotly_chart(fig, use_container_width=True)
95
-
96
- st.divider()
97
- st.markdown("*Flowillower 🌸 — Minimalist, Poetic and Open Source.*")
 
1
  import streamlit as st
2
  from pathlib import Path
3
+ import altair as alt
4
+
5
+ # 导入重构后的模块
6
+ # 这些文件需要您在与此 app.py 同级的目录下或Python路径中创建
7
+ try:
8
+ from utils import DATA_ROOT_PATH, AppMode
9
+ from data_models import Study, Trial
10
+ from data_loader import discover_studies_cached, ensure_data_directory_exists
11
+ except ImportError as e:
12
+ st.error(
13
+ f"导入模块失败,请确保 utils.py, data_models.py, data_loader.py 文件存在于正确的位置: {e}"
14
+ )
15
+ st.stop()
16
+
17
+
18
+ # --- 应用状态管理 ---
19
+ if "selected_study_name" not in st.session_state:
20
+ st.session_state.selected_study_name = None
21
+ if "selected_trial_name" not in st.session_state:
22
+ st.session_state.selected_trial_name = None
23
+ if "studies_data" not in st.session_state: # To store loaded Study objects
24
+ st.session_state.studies_data = {}
25
+ if "app_mode" not in st.session_state:
26
+ st.session_state.app_mode = AppMode.VIEWING # Default mode
27
+
28
+ # --- Page Configuration ---
29
+ st.set_page_config(layout="wide", page_title="柳暗花明 (flowillower)")
30
+
31
+
32
+ # --- UI Rendering ---
33
+
34
+ # --- Header ---
35
+ header_cols = st.columns([2, 3, 1.5, 0.5, 0.5, 0.5])
36
+ with header_cols[0]:
37
+ st.markdown("## 柳暗花明")
38
+ st.caption("flowillower")
39
+
40
+ # 确保数据目录存在
41
+ ensure_data_directory_exists(DATA_ROOT_PATH)
42
+
43
+ # 加载 Studies
44
+ all_study_objects = discover_studies_cached(DATA_ROOT_PATH) # Returns Dict[str, Study]
45
+ study_names = list(all_study_objects.keys())
46
+
47
+ if not study_names:
48
+ st.warning(f"在 {DATA_ROOT_PATH} 未找到任何 Study。请确保您的数据结构正确或使用 flowillower API 开始记录实验。")
49
+
50
+ # Study 选择
51
+ if study_names:
52
+ with header_cols[1]:
53
+ # 如果 session_state 中的 study_name 不在当前发现的 study_names 中,重置它
54
+ if st.session_state.selected_study_name not in study_names:
55
+ st.session_state.selected_study_name = study_names[0] if study_names else None
56
+
57
+ selected_study_name_from_ui = st.selectbox(
58
+ "选择 Study (Select Study)",
59
+ study_names,
60
+ index=study_names.index(st.session_state.selected_study_name) if st.session_state.selected_study_name in study_names else 0,
61
+ label_visibility="collapsed",
62
+ key="study_selector_ui" # Use a different key to avoid conflict if direct assignment is used
63
+ )
64
+ # Update session state if selection changes
65
+ if selected_study_name_from_ui != st.session_state.selected_study_name:
66
+ st.session_state.selected_study_name = selected_study_name_from_ui
67
+ st.session_state.selected_trial_name = None # Reset trial when study changes
68
+ st.rerun()
69
+
70
+
71
+ with header_cols[2]:
72
+ if st.session_state.selected_study_name:
73
+ st.write(f"当前 Study: **{st.session_state.selected_study_name}**")
74
+ else:
75
+ with header_cols[1]:
76
+ st.info("没有可用的 Study。")
77
+
78
+
79
+ # Placeholder for right-side icons
80
+ with header_cols[3]:
81
+ st.button("➕", help="添加 (Add)", disabled=True)
82
+ with header_cols[4]:
83
+ st.button("⚙️", help="设置 (Settings)", disabled=True)
84
+ with header_cols[5]:
85
+ st.button("👤", help="用户 (User)", disabled=True)
86
+
87
+ st.markdown("---")
88
+
89
+
90
+ # --- Sidebar ---
91
+ current_study: Study | None = None
92
+ if st.session_state.selected_study_name and st.session_state.selected_study_name in all_study_objects:
93
+ current_study = all_study_objects[st.session_state.selected_study_name]
94
+ # Ensure trials are discovered for the current study
95
+ if not current_study.trials: # Discover trials if not already done
96
+ current_study.discover_trials_cached()
97
+
98
+
99
+ trial_names = []
100
+ if current_study:
101
+ trial_names = list(current_study.trials.keys())
102
+
103
+
104
+ with st.sidebar:
105
+ st.markdown("### Study")
106
+ if current_study:
107
+ st.markdown(f"##### {current_study.name}")
108
+ if st.button("刷新 Study 数据 (Refresh Study Data)", use_container_width=True):
109
+ current_study.clear_cache() # Clear specific study cache
110
+ st.rerun()
111
+
112
+ # These buttons can be linked to specific views or functionalities later
113
+ if st.button("概览 (Overview)", use_container_width=True, disabled=True):
114
+ st.toast("功能待实现 (Feature to be implemented)")
115
+ if st.button("图表对比视图 (Chart Comparison View)", use_container_width=True, disabled=True):
116
+ st.toast("功能待实现 (Feature to be implemented)")
117
  else:
118
+ st.markdown("未选择 Study (No Study Selected)")
119
+
120
+ st.markdown("---")
121
+ st.markdown("### Trial")
122
+
123
+ if current_study and trial_names:
124
+ # 如果 session_state 中的 trial_name 不在当前发现的 trial_names 中,重置它
125
+ if st.session_state.selected_trial_name not in trial_names:
126
+ st.session_state.selected_trial_name = trial_names[0] if trial_names else None
127
+
128
+ selected_trial_name_from_ui = st.radio(
129
+ "选择 Trial (Select Trial)",
130
+ trial_names,
131
+ index=trial_names.index(st.session_state.selected_trial_name) if st.session_state.selected_trial_name in trial_names else 0,
132
+ label_visibility="collapsed",
133
+ key="trial_selector_ui"
134
+ )
135
+ if selected_trial_name_from_ui != st.session_state.selected_trial_name:
136
+ st.session_state.selected_trial_name = selected_trial_name_from_ui
137
+ st.rerun()
138
+
139
+ if st.session_state.selected_trial_name:
140
+ st.markdown(f"当前选择: **{st.session_state.selected_trial_name}**")
141
+
142
+ elif current_study:
143
+ st.info(f"Study '{current_study.name}' 中没有 Trial。")
144
+ else:
145
+ st.info("请先选择一个 Study。")
146
+
147
+ st.markdown("---")
148
+ if st.button("⚙️ App 设置 (App Settings)", use_container_width=True, disabled=True):
149
+ st.toast("功能待实现 (Feature to be implemented)")
150
+
151
+
152
+ # --- Main Content Area ---
153
+ current_trial: Trial | None = None
154
+ if current_study and st.session_state.selected_trial_name and st.session_state.selected_trial_name in current_study.trials:
155
+ current_trial = current_study.trials[st.session_state.selected_trial_name]
156
+ # Load data for the current trial if not already loaded (methods are cached)
157
+ current_trial.load_input_variables_cached()
158
+ current_trial.load_metrics_cached()
159
+
160
+
161
+ if current_study and current_trial:
162
+ main_title_cols = st.columns([3,1, 0.5])
163
+ with main_title_cols[0]:
164
+ st.markdown(f"## {current_trial.name}")
165
+ st.caption(f"属于 Study: {current_study.name}")
166
+ with main_title_cols[1]:
167
+ if st.button("刷新 Trial 数据 (Refresh Trial Data)", type="secondary"):
168
+ current_trial.clear_cache()
169
+ st.rerun()
170
+ with main_title_cols[2]:
171
+ st.button("...", help="更多选项 (More Options)", disabled=True)
172
+
173
+ tab_titles = ["图表 (Charts)", "参数 (Parameters)", "系统 (System)", "日志 (Logs)", "环境 (Environment)"]
174
+ tab_charts, tab_params, tab_system, tab_logs, tab_env = st.tabs(tab_titles)
175
+
176
+ with tab_charts:
177
+ st.header("指标图表 (Metrics Charts)")
178
+ st.markdown("---")
179
+
180
+ if not current_trial.metrics_data:
181
+ st.info("当前 Trial 没有可显示的指标数据。请检查 `logs/scalar` 文件夹和 TOML 文件。")
182
+ else:
183
+ num_metrics = len(current_trial.metrics_data)
184
+ cols_per_row = st.slider("每行图表数量 (Charts per row)", 1, 4, min(2, num_metrics) if num_metrics > 0 else 1, key=f"cols_slider_{current_trial.name}")
185
+
186
+ metric_names = sorted(list(current_trial.metrics_data.keys()))
187
+
188
+ for i in range(0, num_metrics, cols_per_row):
189
+ metric_chunk = metric_names[i : i + cols_per_row]
190
+ chart_cols = st.columns(cols_per_row)
191
+ for j, metric_name in enumerate(metric_chunk):
192
+ with chart_cols[j]:
193
+ df_metric = current_trial.get_metric_dataframe(metric_name)
194
+ if df_metric is None or df_metric.empty:
195
+ st.warning(f"指标 '{metric_name}' 数据不完整或缺失。")
196
+ continue
197
+
198
+ with st.container(border=True):
199
+ st.subheader(metric_name)
200
+ try:
201
+ chart = alt.Chart(df_metric).mark_line(point=alt.MarkDef(size=20)).encode(
202
+ x=alt.X('global_step:Q', title='全局步骤 (Global Step)'),
203
+ y=alt.Y('value:Q', title=metric_name, scale=alt.Scale(zero=False)),
204
+ color='track:N',
205
+ tooltip=['global_step', 'value', 'track']
206
+ ).interactive()
207
+ st.altair_chart(chart, use_container_width=True)
208
+ except Exception as e:
209
+ st.error(f"为指标 '{metric_name}' 生成图表时出错: {e}")
210
+ st.dataframe(df_metric)
211
+
212
+ with tab_params:
213
+ st.header("输入参数 (Input Parameters)")
214
+ if current_trial.input_variables:
215
+ st.json(current_trial.input_variables)
216
+ else:
217
+ st.info("未找到 `input_variables.toml` 或文件为空。")
218
+
219
+ # Placeholder tabs
220
+ for tab_content, name in [(tab_system, "系统监控 (System Monitoring)"),
221
+ (tab_logs, "日志 (Logs)"),
222
+ (tab_env, "环境 (Environment)")]:
223
+ with tab_content:
224
+ st.header(name)
225
+ st.info("此功能待您的 `flowillower` API 提供相关数据后实现。")
226
+
227
+ elif not st.session_state.selected_study_name:
228
+ st.info("👈 请从顶部选择一个 Study 开始。(Please select a Study from the top to begin.)")
229
+ elif not st.session_state.selected_trial_name:
230
+ st.info("👈 请从侧边栏选择一个 Trial。(Please select a Trial from the sidebar.)")
231
  else:
232
+ st.info("请选择 Study Trial 以查看数据。")
233
+
234
+
235
+ # --- Footer (Optional) ---
236
+ st.markdown("---")
237
+ st.caption("柳暗花明 (flowillower) - 数据可视化App (Data Visualization App)")
238
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ from pathlib import Path
3
+ from enum import Enum, auto
4
+
5
+ # Base path for studies and trials.
6
+ # Streamlit apps are typically run from their root directory.
7
+ # If your app.py is in 'src/', and 'data/' is at the same level as 'src/',
8
+ # then Path("./data") from app.py's perspective would be Path("../data").
9
+ # For simplicity, assuming data is relative to where streamlit run is executed,
10
+ # or you adjust this path accordingly.
11
+ DATA_ROOT_PATH = Path("./data").resolve()
12
+
13
+
14
class AppMode(Enum):
    """Top-level UI modes the application can be in."""
    VIEWING = auto()
    # Future candidates: COMPARISON, EDITING, etc.