# flowillower_app_template / src / data_models.py
# Author: YeCanming
# Commit: d70d84b ("feat: format")
# data_models.py
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any
import pandas as pd
import streamlit as st # For caching
# The loader functions imported below are defined in data_loader.py, which is
# assumed to live in the same directory as this module.
# To avoid a circular import, data_loader should not import data_models just
# for type hints; where it needs to mention Study/Trial it can use string type
# hints instead.
# Forward declarations, in case data_loader ever needs Study/Trial for hints:
# class Study: pass
# class Trial: pass
from data_loader import (
load_input_variables_from_path,
load_all_metrics_for_trial_path,
discover_trials_from_path,
clear_trial_cache as clear_trial_loader_cache,
clear_study_cache as clear_study_loader_cache,
clear_specific_trial_metric_cache,
clear_specific_trial_input_vars_cache,
clear_specific_study_trial_discovery_cache,
)
@dataclass
class Trial:
    """A single trial belonging to a study, with lazily populated data.

    Attributes are filled on demand by the ``load_*_cached`` methods, which
    delegate to the loader functions imported from ``data_loader``. Caching
    (e.g. via ``st.cache_data``) is meant to live on those loader functions
    rather than on this class.
    """

    name: str
    path: Path
    # Name of the study this trial belongs to.
    study_name: str
    input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
    # Maps metric name -> DataFrame (per the original note: columns
    # global_step, value, track).
    metrics_data: Dict[str, pd.DataFrame] = field(
        default_factory=dict, repr=False
    )

    def __post_init__(self):
        """Do nothing on construction; loading is left to explicit calls."""

    def load_input_variables_cached(self):
        """Return the input variables, loading them from ``path`` if empty."""
        if self.input_variables:
            # Already populated on this instance; skip the loader entirely.
            return self.input_variables
        loaded_vars = load_input_variables_from_path(self.path)
        self.input_variables = loaded_vars
        return loaded_vars

    def load_metrics_cached(self):
        """Return the metrics mapping, loading it from ``path`` if empty."""
        if self.metrics_data:
            # Already populated on this instance; skip the loader entirely.
            return self.metrics_data
        loaded_metrics = load_all_metrics_for_trial_path(self.path)
        self.metrics_data = loaded_metrics
        return loaded_metrics

    def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
        """Return the stored DataFrame for ``metric_name``, or ``None``.

        Only consults what is currently held in ``metrics_data``; it does not
        trigger a load.
        """
        frame = self.metrics_data.get(metric_name, None)
        return frame

    def clear_cache(self):
        """Drop cached loader results for this trial and reset its data."""
        trial_path = self.path
        # Invalidate the loader-side caches keyed on this trial's path.
        clear_specific_trial_metric_cache(trial_path)
        clear_specific_trial_input_vars_cache(trial_path)
        # Forget whatever this instance was holding.
        self.input_variables, self.metrics_data = {}, {}
        st.success(f"Trial '{self.name}' 的缓存已清除。")
@dataclass
class Study:
    """A study: a named directory together with the trials found under it."""

    name: str
    path: Path
    trials: Dict[str, Trial] = field(default_factory=dict, repr=False)

    def discover_trials_cached(self):
        """Return the trials mapping, discovering trials first if it is empty.

        Discovery is delegated to ``discover_trials_from_path`` (intended to
        be a cached loader); each discovered name/path pair is wrapped in a
        ``Trial`` tied to this study's name.
        """
        if self.trials:
            # Already populated on this instance; nothing to discover.
            return self.trials
        discovered = discover_trials_from_path(self.path)
        # Feed update() a generator so trials are inserted one at a time.
        self.trials.update(
            (found_name, Trial(name=found_name, path=found_path, study_name=self.name))
            for found_name, found_path in discovered.items()
        )
        return self.trials

    def get_trial(self, trial_name: str) -> Optional[Trial]:
        """Return the trial registered under ``trial_name``, or ``None``."""
        match = self.trials.get(trial_name, None)
        return match

    def clear_cache(self):
        """Drop cached data for this study and for every trial it holds."""
        clear_specific_study_trial_discovery_cache(self.path)
        # Each trial clears its own loader caches and in-memory data.
        for member in self.trials.values():
            member.clear_cache()
        # Start over with a fresh, empty trials mapping.
        self.trials = {}
        st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")