# data_loader.py
from pathlib import Path
from typing import Dict, List, Any
import pandas as pd
import tomli
import streamlit as st
from functools import lru_cache  # For non-Streamlit caching if needed (currently unused)

# Assuming infra.py is in the same directory
from infra import DATA_ROOT_PATH  # Used for ensuring the data directory exists

# --- Cache Clearing Functions ---
# These are more specific cache clearing functions that can be called by model methods.


def clear_study_cache():
    """Clears the study discovery cache."""
    discover_studies_cached.clear()
    st.toast("Study discovery cache cleared.")


def clear_trial_cache():
    """Clears all trial-related data loading caches."""
    # This is a bit broad. Ideally, clear caches for specific trials/studies.
    load_input_variables_from_path.clear()
    load_all_metrics_for_trial_path.clear()
    discover_trials_from_path.clear()
    st.toast("所有 Trial 数据加载缓存已清除。")


def clear_specific_trial_metric_cache(trial_path: Path):
    """Clears the metrics cache a given trial touched (whole function-level cache)."""
    load_all_metrics_for_trial_path.clear()  # This clears the whole cache for this func
    # For more granular control with @st.cache_data, you'd typically rely on Streamlit's
    # automatic cache invalidation based on input args, or rerun.
    # If using lru_cache, you could do: load_all_metrics_for_trial_path.cache_clear()
    # but for st.cache_data, clearing for specific args is not direct.
    # The common pattern is to clear the entire function's cache; a granular
    # registry is sketched at the end of this section.
    st.toast(f"Metrics cache cleared for trial '{trial_path.name}' (function level).")


def clear_specific_trial_input_vars_cache(trial_path: Path):
    """Clears the input-variables cache (whole function-level cache)."""
    load_input_variables_from_path.clear()
    st.toast(f"Input variables cache cleared for trial '{trial_path.name}' (function level).")


def clear_specific_study_trial_discovery_cache(study_path: Path):
    """Clears the trial-discovery cache (whole function-level cache)."""
    discover_trials_from_path.clear()
    st.toast(f"Trial discovery cache cleared for study '{study_path.name}' (function level).")


# --- Data Discovery and Loading Functions (Cached) ---


def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
    """Ensures the root data directory exists."""
    if not data_path.exists():
        try:
            data_path.mkdir(parents=True, exist_ok=True)
            st.info(f"数据目录 {data_path} 已创建。")
        except Exception as e:
            st.error(f"创建数据目录 {data_path} 失败: {e}")
            st.stop()
    elif not data_path.is_dir():
        st.error(f"路径 {data_path} 已存在但不是一个目录。")
        st.stop()


@st.cache_data(ttl=3600)  # Cache for 1 hour, or adjust as needed
def discover_studies_cached(_data_root: Path) -> Dict[str, Any]:
    """
    Scans the data root for study directories and returns a dictionary
    mapping study names to Study objects. The return type is hinted as Any
    to avoid a circular dependency with data_models.Study.

    Note: the leading underscore tells st.cache_data not to hash this
    argument, so the result is keyed by the function alone and is shared
    until the cache is cleared or the TTL expires; the same caveat applies
    to the other cached loaders below.
    """
    # To avoid issues with caching complex objects directly, or circular dependencies,
    # this function could instead return simpler structures like Dict[str, Path] and
    # let the main app or model instantiate Study objects (a variant is sketched
    # after this function). For this iteration, Study is imported here for convenience.
    from data_models import Study  # Local import to reduce circular-import risk

    if not _data_root.is_dir():
        return {}
    studies = {}
    for d in _data_root.iterdir():
        if d.is_dir():
            studies[d.name] = Study(name=d.name, path=d)
    return studies
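

# A minimal sketch of the simpler variant discussed above; an assumption, not
# used elsewhere in this module. Returning plain paths keeps the cached value
# trivially serializable and avoids the local data_models import entirely:
@st.cache_data(ttl=3600)
def discover_study_paths_cached(_data_root: Path) -> Dict[str, Path]:
    """Like discover_studies_cached, but maps study names to raw paths."""
    if not _data_root.is_dir():
        return {}
    return {d.name: d for d in _data_root.iterdir() if d.is_dir()}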


@st.cache_data(ttl=3600)
def discover_trials_from_path(_study_path: Path) -> Dict[str, Path]:
    """Scans a study_path for trial directories."""
    if not _study_path.is_dir():
        return {}
    trials = {}
    for d in _study_path.iterdir():
        if d.is_dir():
            trials[d.name] = d
    return trials


@st.cache_data(ttl=3600)
def load_input_variables_from_path(_trial_path: Path) -> Dict[str, Any]:
    """Loads input_variables.toml from a trial directory."""
    input_vars_file = _trial_path / "input_variables.toml"
    if input_vars_file.exists():
        try:
            with open(input_vars_file, "rb") as f:
                return tomli.load(f)
        except tomli.TOMLDecodeError:
            # st.error(f"错误:无法解析 input_variables.toml 文件于 {_trial_path}") # Avoid st.error in cached funcs if possible
            print(f"Error parsing input_variables.toml at {_trial_path}")
            return {}
    return {}
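

# Illustrative input_variables.toml layout; an assumption, since the loader
# above accepts any valid TOML document and returns it as a plain dict:
#
#   learning_rate = 0.001
#   batch_size = 32
#   optimizer = "adam"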


def _load_single_metric_toml(toml_file_path: Path) -> pd.DataFrame:
    """Loads metrics from a single TOML file into a DataFrame (uncached helper)."""
    if not toml_file_path.exists():
        return pd.DataFrame()
    try:
        with open(toml_file_path, "rb") as f:
            data = tomli.load(f)
        metrics_list = data.get("metrics", [])
        if not metrics_list:
            return pd.DataFrame()
        return pd.DataFrame(metrics_list)
    except tomli.TOMLDecodeError:
        print(f"Error parsing TOML file: {toml_file_path.name}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Error loading {toml_file_path.name}: {e}")
        return pd.DataFrame()
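

# Illustrative metrics_<track>.toml layout, inferred from the parser above:
# each [[metrics]] entry becomes one DataFrame row, and downstream code
# expects a "global_step" key alongside the metric values:
#
#   [[metrics]]
#   global_step = 0
#   loss = 2.31
#
#   [[metrics]]
#   global_step = 100
#   loss = 1.87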


@st.cache_data(ttl=300)  # Cache metric data for 5 minutes
def load_all_metrics_for_trial_path(_trial_path: Path) -> Dict[str, pd.DataFrame]:
    """
    Loads all metrics from all tracks in a trial.
    Returns a dictionary where keys are metric names (e.g., 'loss', 'accuracy')
    and values are DataFrames containing 'global_step', 'value', and 'track'.
    """
    scalar_dir = _trial_path / "logs" / "scalar"
    if not scalar_dir.is_dir():
        return {}

    all_metrics_data_combined: Dict[str, pd.DataFrame] = {}

    for toml_file in scalar_dir.glob("metrics_*.toml"):
        track_name = toml_file.stem.removeprefix("metrics_")  # e.g. "metrics_train" -> "train"
        df_track = _load_single_metric_toml(toml_file)

        if df_track.empty or "global_step" not in df_track.columns:
            continue

        id_vars = ["global_step"]
        value_vars = [col for col in df_track.columns if col not in id_vars]

        if not value_vars:
            continue

        # Process each metric column individually to build up the combined DataFrame
        for metric_col_name in value_vars:
            try:
                # Create a DataFrame for the current metric and track
                current_metric_df = df_track[["global_step", metric_col_name]].copy()
                current_metric_df.rename(
                    columns={metric_col_name: "value"}, inplace=True
                )
                current_metric_df["track"] = track_name
                current_metric_df["value"] = pd.to_numeric(
                    current_metric_df["value"], errors="coerce"
                )
                current_metric_df.dropna(subset=["value"], inplace=True)

                if current_metric_df.empty:
                    continue

                # Append to the combined DataFrame for this metric_col_name
                if metric_col_name not in all_metrics_data_combined:
                    all_metrics_data_combined[metric_col_name] = current_metric_df
                else:
                    all_metrics_data_combined[metric_col_name] = pd.concat(
                        [all_metrics_data_combined[metric_col_name], current_metric_df],
                        ignore_index=True,
                    )
            except Exception as e:
                print(
                    f"Error processing metric '{metric_col_name}' from file '{toml_file.name}': {e}"
                )
                continue

    # Sort data by global_step for proper line plotting
    for metric_name in all_metrics_data_combined:
        all_metrics_data_combined[metric_name] = (
            all_metrics_data_combined[metric_name]
            .sort_values(by=["track", "global_step"])
            .reset_index(drop=True)
        )

    return all_metrics_data_combined
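

# A hedged usage sketch (assumed caller, not part of this module): a Streamlit
# view could render one chart per metric with one line per track. The study
# and trial names below are hypothetical.
#
#   ensure_data_directory_exists()
#   metrics = load_all_metrics_for_trial_path(DATA_ROOT_PATH / "study_a" / "trial_001")
#   for name, df in metrics.items():
#       st.subheader(name)
#       wide = df.pivot_table(index="global_step", columns="track", values="value")
#       st.line_chart(wide)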