File size: 3,931 Bytes
1e63386
 
 
 
 
d70d84b
1e63386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d70d84b
1e63386
 
 
 
 
 
 
d70d84b
1e63386
d70d84b
 
 
1e63386
 
 
 
 
 
 
 
 
 
d70d84b
1e63386
 
 
 
 
d70d84b
1e63386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d70d84b
 
 
 
1e63386
d70d84b
 
 
1e63386
 
 
 
 
 
 
 
 
d70d84b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# data_models.py
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any
import pandas as pd
import streamlit as st  # For caching

# The loader functions below live in data_loader.py, which is assumed to sit in the same directory.
# To avoid a circular import, data_loader should not import data_models just for type hints;
# where it needs to refer to Study/Trial it should use string type hints instead.

# Forward declaration for type hint if data_loader needs Study/Trial
# class Study: pass
# class Trial: pass

from data_loader import (
    load_input_variables_from_path,
    load_all_metrics_for_trial_path,
    discover_trials_from_path,
    clear_trial_cache as clear_trial_loader_cache,
    clear_study_cache as clear_study_loader_cache,
    clear_specific_trial_metric_cache,
    clear_specific_trial_input_vars_cache,
    clear_specific_study_trial_discovery_cache,
)


@dataclass
class Trial:
    """One trial of a study, with lazily populated input variables and metrics.

    The heavy data is not read at construction time; callers request it
    through the ``load_*_cached`` methods, which delegate to the loader
    functions imported from ``data_loader``.
    """

    name: str
    path: Path
    study_name: str  # Name of the study this trial belongs to
    input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
    # Key: metric_name, Value: DataFrame with global_step, value, track
    metrics_data: Dict[str, pd.DataFrame] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        """Do nothing: loading is left to explicit calls from the UI for clarity."""

    # Caching is meant to be applied (via st.cache_data) to the loader
    # functions rather than to these methods, which only call the loaders.

    def load_input_variables_cached(self):
        """Return the trial's input variables, loading them on first use.

        A non-empty ``input_variables`` is returned as-is; otherwise the
        loader is called with ``self.path`` and its result is stored.
        """
        if self.input_variables:
            return self.input_variables
        self.input_variables = load_input_variables_from_path(self.path)
        return self.input_variables

    def load_metrics_cached(self):
        """Return the trial's metrics mapping, loading it on first use.

        A non-empty ``metrics_data`` is returned as-is; otherwise the loader
        is called with ``self.path`` and its result is stored.
        """
        if self.metrics_data:
            return self.metrics_data
        self.metrics_data = load_all_metrics_for_trial_path(self.path)
        return self.metrics_data

    def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
        """Look up ``metric_name`` in the already-loaded metrics.

        Returns ``None`` when the metric is absent. This method does not
        trigger loading; it only reads ``metrics_data`` as it currently is.
        """
        loaded_metrics = self.metrics_data
        return loaded_metrics.get(metric_name, None)

    def clear_cache(self):
        """Drop cached loader results for this trial and reset its in-memory data.

        Also reports completion to the user through a Streamlit success message.
        """
        # Invalidate the loader-side caches keyed by this trial's path.
        for invalidate in (
            clear_specific_trial_metric_cache,
            clear_specific_trial_input_vars_cache,
        ):
            invalidate(self.path)
        # Forget what this instance holds so the next load goes through the loaders.
        self.input_variables, self.metrics_data = {}, {}
        st.success(f"Trial '{self.name}' 的缓存已清除。")


@dataclass
class Study:
    """A study directory together with the trials discovered beneath it.

    ``trials`` maps trial name to its ``Trial`` object and starts empty; it
    is filled by ``discover_trials_cached``.
    """

    name: str
    path: Path
    trials: Dict[str, Trial] = field(default_factory=dict, repr=False)

    def discover_trials_cached(self):
        """Return the study's trials, discovering them on first use.

        A non-empty ``trials`` mapping is returned as-is. Otherwise the
        discovery loader is called with ``self.path`` and one ``Trial`` is
        created per returned (name, path) entry.
        """
        if self.trials:
            return self.trials
        # The discovery loader is expected to be cached on its side.
        discovered = discover_trials_from_path(self.path)
        self.trials.update(
            (label, Trial(name=label, path=location, study_name=self.name))
            for label, location in discovered.items()
        )
        return self.trials

    def get_trial(self, trial_name: str) -> Optional[Trial]:
        """Return the known trial called ``trial_name``, or ``None`` if absent."""
        known_trials = self.trials
        return known_trials.get(trial_name, None)

    def clear_cache(self):
        """Drop cached data for this study and for every trial it currently holds.

        Also reports completion to the user through a Streamlit success message.
        """
        clear_specific_study_trial_discovery_cache(self.path)
        # Each trial clears its own loader caches and in-memory data.
        for member in self.trials.values():
            member.clear_cache()
        # Forget the trials so the next discovery rebuilds them.
        self.trials = {}
        st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")