# data_models.py
"""In-memory models for studies and their trials.

The models hold no loading logic of their own: every read goes through a
function imported from ``data_loader``. Those loader functions are expected
to be the ones decorated with Streamlit caching (not the dataclasses here),
so the methods below only memoise the loader results on the instance.

NOTE: ``data_loader`` should not import this module at import time (use
string type hints there if needed), otherwise the import becomes circular.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any

import pandas as pd
import streamlit as st  # Used for user-facing status messages below.

from data_loader import (
    load_input_variables_from_path,
    load_all_metrics_for_trial_path,
    discover_trials_from_path,
    clear_trial_cache as clear_trial_loader_cache,
    clear_study_cache as clear_study_loader_cache,
    clear_specific_trial_metric_cache,
    clear_specific_trial_input_vars_cache,
    clear_specific_study_trial_discovery_cache,
)


@dataclass
class Trial:
    """A single trial directory belonging to a study.

    ``input_variables`` and ``metrics_data`` start out empty and are filled
    on demand by the ``load_*_cached`` methods.
    """

    name: str
    path: Path
    # Name of the parent study this trial was discovered under.
    study_name: str
    input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
    # Maps metric name -> DataFrame. Per the original notes the frame carries
    # global_step / value / track columns -- confirm against data_loader.
    metrics_data: Dict[str, pd.DataFrame] = field(
        default_factory=dict, repr=False
    )

    def __post_init__(self):
        """Intentionally a no-op: data is loaded by explicit calls only."""

    def load_input_variables_cached(self):
        """Return the trial's input variables, loading them on first use.

        The loader is called only while the instance copy is empty, so an
        empty result is requested again on the next call.
        """
        if self.input_variables:
            return self.input_variables
        self.input_variables = load_input_variables_from_path(self.path)
        return self.input_variables

    def load_metrics_cached(self):
        """Return the trial's metrics mapping, loading it on first use.

        The loader is called only while the instance copy is empty.
        """
        if self.metrics_data:
            return self.metrics_data
        self.metrics_data = load_all_metrics_for_trial_path(self.path)
        return self.metrics_data

    def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
        """Look up one metric's DataFrame in the already-loaded metrics.

        Does not trigger loading; returns ``None`` when ``metric_name`` is
        not present in ``metrics_data``.
        """
        loaded_metrics = self.metrics_data
        return loaded_metrics.get(metric_name)

    def clear_cache(self):
        """Drop loader caches and instance data for this trial.

        Invalidates the metric and input-variable loader caches for this
        trial's path, resets both instance dictionaries, and reports
        success through ``st.success``.
        """
        trial_dir = self.path
        clear_specific_trial_metric_cache(trial_dir)
        clear_specific_trial_input_vars_cache(trial_dir)
        # Fresh dict objects, so the next load_*_cached call reloads.
        self.input_variables, self.metrics_data = {}, {}
        st.success(f"Trial '{self.name}' 的缓存已清除。")


@dataclass
class Study:
    """A study directory together with the trials discovered inside it."""

    name: str
    path: Path
    trials: Dict[str, Trial] = field(default_factory=dict, repr=False)

    def discover_trials_cached(self):
        """Return the study's trials, discovering them on first use.

        Discovery runs only while ``trials`` is empty. Each discovered
        (name, path) pair becomes a ``Trial`` tagged with this study's name.
        """
        if self.trials:
            return self.trials
        discovered = discover_trials_from_path(self.path)
        # Fill the existing dict in place rather than rebinding the attribute.
        self.trials.update(
            {
                label: Trial(name=label, path=location, study_name=self.name)
                for label, location in discovered.items()
            }
        )
        return self.trials

    def get_trial(self, trial_name: str) -> Optional[Trial]:
        """Return the known trial called ``trial_name``, or ``None``."""
        known_trials = self.trials
        return known_trials.get(trial_name)

    def clear_cache(self):
        """Drop cached data for this study and every trial it holds.

        Invalidates the trial-discovery cache for this study's path, calls
        ``clear_cache`` on each known trial, resets ``trials`` and reports
        success through ``st.success``.
        """
        clear_specific_study_trial_discovery_cache(self.path)
        for member in self.trials.values():
            member.clear_cache()
        self.trials = {}
        st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")