# data_loader.py
from pathlib import Path
from typing import Any, Dict

import pandas as pd
import streamlit as st
import tomli

from infra import DATA_ROOT_PATH  # Used for ensuring the data directory exists

# --- Cache Clearing Functions ---
# Finer-grained cache clearing helpers that model methods can call.


def clear_study_cache():
    """Clears the study discovery cache."""
    discover_studies_cached.clear()
    st.toast("All study discovery caches cleared.")


def clear_trial_cache():
    """Clears all trial-related data loading caches."""
    # This is broad; ideally we would clear caches per trial/study, but
    # st.cache_data only supports clearing a function's entire cache.
    load_input_variables_from_path.clear()
    load_all_metrics_for_trial_path.clear()
    discover_trials_from_path.clear()
    st.toast("All trial data loading caches cleared.")


def clear_specific_trial_metric_cache(trial_path: Path):
    # st.cache_data cannot evict entries for specific arguments; the common
    # pattern is to clear the whole function cache and let Streamlit
    # recompute per-argument entries on the next rerun.
    load_all_metrics_for_trial_path.clear()
    st.toast(f"Metric cache cleared for trial '{trial_path.name}' (function level).")


def clear_specific_trial_input_vars_cache(trial_path: Path):
    load_input_variables_from_path.clear()
    st.toast(f"Parameter cache cleared for trial '{trial_path.name}' (function level).")


def clear_specific_study_trial_discovery_cache(study_path: Path):
    discover_trials_from_path.clear()
    st.toast(f"Trial discovery cache cleared for study '{study_path.name}' (function level).")

# --- Data Discovery and Loading Functions (Cached) ---


def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
    """Ensures the root data directory exists."""
    if not data_path.exists():
        try:
            data_path.mkdir(parents=True, exist_ok=True)
            st.info(f"Data directory {data_path} has been created.")
        except Exception as e:
            st.error(f"Failed to create data directory {data_path}: {e}")
            st.stop()
    elif not data_path.is_dir():
        st.error(f"Path {data_path} exists but is not a directory.")
        st.stop()

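# Assumed on-disk layout, inferred from the loaders below (directory names
# are illustrative, not prescribed):
#
#   DATA_ROOT_PATH/
#       <study_name>/
#           <trial_name>/
#               input_variables.toml
#               logs/
#                   scalar/
#                       metrics_<track>.toml
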
@st.cache_data(ttl=3600)  # Cache for 1 hour, or adjust as needed
def discover_studies_cached(data_root: Path) -> Dict[str, Any]:
    """
    Scans data_root for study directories and returns a dictionary mapping
    study names to Study objects.

    The return type is hinted as Any to avoid a module-level circular
    dependency on data_models.Study. To avoid caching complex objects
    entirely, this function could instead return Dict[str, Path] and let
    the main app or model instantiate the Study objects.
    """
    # Local import to contain the potential circularity as the models grow.
    from data_models import Study

    if not data_root.is_dir():
        return {}
    studies = {}
    for d in data_root.iterdir():
        if d.is_dir():
            studies[d.name] = Study(name=d.name, path=d)
    return studies

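# A leaner variant hinted at in the docstring above (a sketch, not wired up
# anywhere in this module): return plain paths so no custom objects are
# cached and the local data_models import disappears.
#
#   @st.cache_data(ttl=3600)
#   def discover_study_paths_cached(data_root: Path) -> Dict[str, Path]:
#       if not data_root.is_dir():
#           return {}
#       return {d.name: d for d in data_root.iterdir() if d.is_dir()}
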
@st.cache_data
def discover_trials_from_path(study_path: Path) -> Dict[str, Path]:
    """Scans a study directory for trial directories."""
    if not study_path.is_dir():
        return {}
    trials = {}
    for d in study_path.iterdir():
        if d.is_dir():
            trials[d.name] = d
    return trials

@st.cache_data
def load_input_variables_from_path(trial_path: Path) -> Dict[str, Any]:
    """Loads input_variables.toml from a trial directory."""
    input_vars_file = trial_path / "input_variables.toml"
    if input_vars_file.exists():
        try:
            with open(input_vars_file, "rb") as f:
                return tomli.load(f)
        except tomli.TOMLDecodeError:
            # Prefer print over st.error inside cached functions, so a bad
            # file does not poison the cached UI state.
            print(f"Error parsing input_variables.toml at {trial_path}")
            return {}
    return {}

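# Example of the metrics TOML layout the parsers below expect (an assumed
# sample; metric names other than "global_step" are illustrative), e.g.
# logs/scalar/metrics_train.toml:
#
#   [[metrics]]
#   global_step = 0
#   loss = 2.31
#   accuracy = 0.12
#
#   [[metrics]]
#   global_step = 100
#   loss = 1.87
#   accuracy = 0.35
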
def _load_single_metric_toml(toml_file_path: Path) -> pd.DataFrame:
    """Loads metrics from a single TOML file into a DataFrame."""
    if not toml_file_path.exists():
        return pd.DataFrame()
    try:
        with open(toml_file_path, "rb") as f:
            data = tomli.load(f)
        metrics_list = data.get("metrics", [])
        if not metrics_list:
            return pd.DataFrame()
        return pd.DataFrame(metrics_list)
    except tomli.TOMLDecodeError:
        print(f"Error parsing TOML file: {toml_file_path.name}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Error loading {toml_file_path.name}: {e}")
        return pd.DataFrame()

@st.cache_data(ttl=300)  # Cache metric data for 5 minutes
def load_all_metrics_for_trial_path(trial_path: Path) -> Dict[str, pd.DataFrame]:
    """
    Loads all metrics from all tracks in a trial.

    Returns a dictionary where keys are metric names (e.g., 'loss',
    'accuracy') and values are DataFrames with 'global_step', 'value',
    and 'track' columns.
    """
    scalar_dir = trial_path / "logs" / "scalar"
    if not scalar_dir.is_dir():
        return {}
    all_metrics_data_combined: Dict[str, pd.DataFrame] = {}
    for toml_file in scalar_dir.glob("metrics_*.toml"):
        track_name = toml_file.stem.replace("metrics_", "")
        df_track = _load_single_metric_toml(toml_file)
        if df_track.empty or "global_step" not in df_track.columns:
            continue
        id_vars = ["global_step"]
        value_vars = [col for col in df_track.columns if col not in id_vars]
        if not value_vars:
            continue
        # Process each metric column individually so non-numeric values can
        # be coerced and dropped per column while building the combined
        # DataFrame for each metric across tracks.
        for metric_col_name in value_vars:
            try:
                # Long-format frame for the current metric and track
                current_metric_df = df_track[["global_step", metric_col_name]].copy()
                current_metric_df.rename(columns={metric_col_name: "value"}, inplace=True)
                current_metric_df["track"] = track_name
                current_metric_df["value"] = pd.to_numeric(
                    current_metric_df["value"], errors="coerce"
                )
                current_metric_df.dropna(subset=["value"], inplace=True)
                if current_metric_df.empty:
                    continue
                # Append to the combined DataFrame for this metric
                if metric_col_name not in all_metrics_data_combined:
                    all_metrics_data_combined[metric_col_name] = current_metric_df
                else:
                    all_metrics_data_combined[metric_col_name] = pd.concat(
                        [all_metrics_data_combined[metric_col_name], current_metric_df],
                        ignore_index=True,
                    )
            except Exception as e:
                print(
                    f"Error processing metric '{metric_col_name}' "
                    f"from file '{toml_file.name}': {e}"
                )
                continue
    # Sort by track and global_step for proper line plotting
    for metric_name in all_metrics_data_combined:
        all_metrics_data_combined[metric_name] = (
            all_metrics_data_combined[metric_name]
            .sort_values(by=["track", "global_step"])
            .reset_index(drop=True)
        )
    return all_metrics_data_combined
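

# --- Usage sketch ---
# A minimal smoke test for the loaders, assuming the on-disk layout sketched
# near the top of this module; run standalone with `streamlit run data_loader.py`.
if __name__ == "__main__":
    ensure_data_directory_exists()
    studies = discover_studies_cached(DATA_ROOT_PATH)
    st.write(f"Discovered {len(studies)} studies under {DATA_ROOT_PATH}")
    for study_name, study in studies.items():
        trials = discover_trials_from_path(study.path)
        st.write(f"Study '{study_name}': {len(trials)} trial(s)")
        for trial_name, trial_dir in trials.items():
            metrics = load_all_metrics_for_trial_path(trial_dir)
            st.write(f"Trial '{trial_name}': metrics {sorted(metrics)}")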