Spaces:
Sleeping
Sleeping
| # data_models.py | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Any | |
| import pandas as pd | |
| import streamlit as st # For caching | |
| # Import from data_loader, assuming it's in the same directory | |
| # We'll define these functions in data_loader.py | |
| # To avoid circular imports, data_loader functions won't import data_models directly for type hints if possible, | |
| # or use string type hints. | |
| # Forward declaration for type hint if data_loader needs Study/Trial | |
| # class Study: pass | |
| # class Trial: pass | |
| from data_loader import ( | |
| load_input_variables_from_path, | |
| load_all_metrics_for_trial_path, | |
| discover_trials_from_path, | |
| clear_trial_cache as clear_trial_loader_cache, | |
| clear_study_cache as clear_study_loader_cache, | |
| clear_specific_trial_metric_cache, | |
| clear_specific_trial_input_vars_cache, | |
| clear_specific_study_trial_discovery_cache, | |
| ) | |
@dataclass
class Trial:
    """One trial (run) inside a study, with lazily populated data.

    Nothing is loaded on construction; the UI is expected to call
    ``load_input_variables_cached`` / ``load_metrics_cached`` explicitly.
    The per-instance dicts act as a first-level cache; the loader functions
    from ``data_loader`` are intended to provide the cross-rerun caching
    (e.g. via Streamlit) — confirm in data_loader.
    """

    name: str
    path: Path
    study_name: str  # Name of the parent study this trial belongs to.
    input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
    # Key: metric name. Value: a DataFrame which, per the original note, holds
    # global_step / value / track columns — NOTE(review): confirm in data_loader.
    metrics_data: Dict[str, pd.DataFrame] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        """Deliberately performs no loading; data is loaded on explicit request."""

    def load_input_variables_cached(self) -> Dict[str, Any]:
        """Return this trial's input variables, loading them on first use.

        The loader is called only while ``input_variables`` is empty, so an
        empty loader result causes another loader call on the next invocation.

        Returns:
            The (possibly freshly loaded) ``input_variables`` dict.
        """
        if not self.input_variables:
            self.input_variables = load_input_variables_from_path(self.path)
        return self.input_variables

    def load_metrics_cached(self) -> Dict[str, pd.DataFrame]:
        """Return this trial's metrics, loading them on first use.

        The loader is called only while ``metrics_data`` is empty.

        Returns:
            The (possibly freshly loaded) ``metrics_data`` dict.
        """
        if not self.metrics_data:
            self.metrics_data = load_all_metrics_for_trial_path(self.path)
        return self.metrics_data

    def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
        """Return the stored DataFrame for ``metric_name``.

        This is a plain lookup: it does not trigger loading. It returns None
        when the metric is unknown or ``load_metrics_cached`` has not run yet.
        """
        return self.metrics_data.get(metric_name)

    def clear_cache(self):
        """Drop all cached data for this trial and notify the user.

        Invalidates the loader-side caches keyed by this trial's path, then
        resets the in-memory dicts so the next load call re-reads them.
        """
        clear_specific_trial_metric_cache(self.path)
        clear_specific_trial_input_vars_cache(self.path)
        self.input_variables = {}
        self.metrics_data = {}
        st.success(f"Trial '{self.name}' 的缓存已清除。")
@dataclass
class Study:
    """A study rooted at ``path``, holding the trials discovered under it.

    Trials are discovered lazily by ``discover_trials_cached``; construction
    performs no I/O.
    """

    name: str
    path: Path
    # Key: trial name. Populated by discover_trials_cached(); reset by clear_cache().
    trials: Dict[str, "Trial"] = field(default_factory=dict, repr=False)

    def discover_trials_cached(self) -> Dict[str, "Trial"]:
        """Discover this study's trials on first use and return them.

        Discovery runs only while ``trials`` is empty. Each discovered
        (name, path) pair becomes a ``Trial`` whose ``study_name`` is this
        study's name. The loader itself is intended to be cached on the
        data_loader side.

        Returns:
            The ``trials`` dict (trial name -> Trial).
        """
        if not self.trials:
            trial_paths = discover_trials_from_path(self.path)
            for trial_name, trial_path in trial_paths.items():
                self.trials[trial_name] = Trial(
                    name=trial_name, path=trial_path, study_name=self.name
                )
        return self.trials

    def get_trial(self, trial_name: str) -> Optional["Trial"]:
        """Return the already-discovered trial named ``trial_name``, or None."""
        return self.trials.get(trial_name)

    def clear_cache(self):
        """Drop cached discovery results and every trial's cached data.

        Invalidates the loader-side discovery cache for this study, clears each
        known trial's cache (each of which shows its own notification), then
        empties ``trials`` so the next discovery call re-scans.
        """
        clear_specific_study_trial_discovery_cache(self.path)
        for trial in self.trials.values():
            trial.clear_cache()
        self.trials = {}
        st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")