# data_loader.py from pathlib import Path from typing import Dict, List, Any import pandas as pd import tomli import streamlit as st from functools import lru_cache # For non-Streamlit specific caching if needed # Assuming utils.py is in the same directory from infra import ( DATA_ROOT_PATH, ) # Used for ensuring directory exists # --- Cache Clearing Functions --- # These are more specific cache clearing functions that can be called by model methods. def clear_study_cache(): """Clears all study discovery cache.""" discover_studies_cached.clear() st.toast("所有 Study 发现缓存已清除。") def clear_trial_cache(): """Clears all trial-related data loading caches.""" # This is a bit broad. Ideally, clear caches for specific trials/studies. load_input_variables_from_path.clear() load_all_metrics_for_trial_path.clear() discover_trials_from_path.clear() st.toast("所有 Trial 数据加载缓存已清除。") def clear_specific_trial_metric_cache(trial_path: Path): load_all_metrics_for_trial_path.clear() # This clears the whole cache for this func # For more granular control with @st.cache_data, you'd typically rely on Streamlit's # automatic cache invalidation based on input args, or rerun. # If using lru_cache, you could do: load_all_metrics_for_trial_path.cache_clear() # but for st.cache_data, clearing for specific args is not direct. # The common pattern is to clear the entire function's cache. st.toast(f"Trial '{trial_path.name}' 的指标缓存已清除 (函数级别)。") def clear_specific_trial_input_vars_cache(trial_path: Path): load_input_variables_from_path.clear() st.toast(f"Trial '{trial_path.name}' 的参数缓存已清除 (函数级别)。") def clear_specific_study_trial_discovery_cache(study_path: Path): discover_trials_from_path.clear() st.toast(f"Study '{study_path.name}' 的 Trial 发现缓存已清除 (函数级别)。") # --- Data Discovery and Loading Functions (Cached) --- def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH): """Ensures the root data directory exists.""" if not data_path.exists(): try: data_path.mkdir(parents=True, exist_ok=True) st.info(f"数据目录 {data_path} 已创建。") except Exception as e: st.error(f"创建数据目录 {data_path} 失败: {e}") st.stop() elif not data_path.is_dir(): st.error(f"路径 {data_path} 已存在但不是一个目录。") st.stop() @st.cache_data(ttl=3600) # Cache for 1 hour, or adjust as needed def discover_studies_cached( _data_root: Path, ) -> Dict[ str, Any ]: # Return type hint as Any to avoid circular dep with data_models.Study """ Scans the data_root for study directories and returns a dictionary mapping study names to Study objects (or just their paths initially). The actual Study object creation happens in the main app for now. """ # To avoid issues with caching complex objects directly, or circular dependencies, # this function can return simpler structures like Dict[str, Path] # and the main app or model can instantiate Study objects. # For this iteration, we'll import Study here for convenience, assuming careful structure. from data_models import ( Study, ) # Local import to help with potential circularity if models grow complex if not _data_root.is_dir(): return {} studies = {} for d in _data_root.iterdir(): if d.is_dir(): studies[d.name] = Study(name=d.name, path=d) return studies @st.cache_data(ttl=3600) def discover_trials_from_path(_study_path: Path) -> Dict[str, Path]: """Scans a study_path for trial directories.""" if not _study_path.is_dir(): return {} trials = {} for d in _study_path.iterdir(): if d.is_dir(): trials[d.name] = d return trials @st.cache_data(ttl=3600) def load_input_variables_from_path(_trial_path: Path) -> Dict[str, Any]: """Loads input_variables.toml from a trial directory.""" input_vars_file = _trial_path / "input_variables.toml" if input_vars_file.exists(): try: with open(input_vars_file, "rb") as f: return tomli.load(f) except tomli.TOMLDecodeError: # st.error(f"错误:无法解析 input_variables.toml 文件于 {_trial_path}") # Avoid st.error in cached funcs if possible print(f"Error parsing input_variables.toml at {_trial_path}") return {} return {} def _load_single_metric_toml(_toml_file_path: Path) -> pd.DataFrame: """Loads metrics from a single TOML file into a DataFrame.""" if not _toml_file_path.exists(): return pd.DataFrame() try: with open(_toml_file_path, "rb") as f: data = tomli.load(f) metrics_list = data.get("metrics", []) if not metrics_list: return pd.DataFrame() return pd.DataFrame(metrics_list) except tomli.TOMLDecodeError: print(f"Error parsing TOML file: {_toml_file_path.name}") return pd.DataFrame() except Exception as e: print(f"Error loading {_toml_file_path.name}: {e}") return pd.DataFrame() @st.cache_data(ttl=300) # Cache metric data for 5 minutes def load_all_metrics_for_trial_path(_trial_path: Path) -> Dict[str, pd.DataFrame]: """ Loads all metrics from all tracks in a trial. Returns a dictionary where keys are metric names (e.g., 'loss', 'accuracy') and values are DataFrames containing 'global_step', 'value', and 'track'. """ scalar_dir = _trial_path / "logs" / "scalar" if not scalar_dir.is_dir(): return {} all_metrics_data_combined: Dict[str, pd.DataFrame] = {} for toml_file in scalar_dir.glob("metrics_*.toml"): track_name = toml_file.stem.replace("metrics_", "") df_track = _load_single_metric_toml(toml_file) if df_track.empty or "global_step" not in df_track.columns: continue id_vars = ["global_step"] value_vars = [col for col in df_track.columns if col not in id_vars] if not value_vars: continue # Process each metric column individually to build up the combined DataFrame for metric_col_name in value_vars: try: # Create a DataFrame for the current metric and track current_metric_df = df_track[["global_step", metric_col_name]].copy() current_metric_df.rename( columns={metric_col_name: "value"}, inplace=True ) current_metric_df["track"] = track_name current_metric_df["value"] = pd.to_numeric( current_metric_df["value"], errors="coerce" ) current_metric_df.dropna(subset=["value"], inplace=True) if current_metric_df.empty: continue # Append to the combined DataFrame for this metric_col_name if metric_col_name not in all_metrics_data_combined: all_metrics_data_combined[metric_col_name] = current_metric_df else: all_metrics_data_combined[metric_col_name] = pd.concat( [all_metrics_data_combined[metric_col_name], current_metric_df], ignore_index=True, ) except Exception as e: print( f"Error processing metric '{metric_col_name}' from file '{toml_file.name}': {e}" ) continue # Sort data by global_step for proper line plotting for metric_name in all_metrics_data_combined: all_metrics_data_combined[metric_name] = ( all_metrics_data_combined[metric_name] .sort_values(by=["track", "global_step"]) .reset_index(drop=True) ) return all_metrics_data_combined