# data_loader.py
from functools import lru_cache  # For non-Streamlit specific caching if needed
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
import streamlit as st
import tomli

# Assuming infra.py is in the same directory
from infra import (
    DATA_ROOT_PATH,
)  # Used for ensuring directory exists


# --- Cache Clearing Functions ---
# These are more specific cache clearing functions that can be called by model methods.
def clear_study_cache():
    """Clears the study discovery cache."""
    discover_studies_cached.clear()
    st.toast("All study discovery caches cleared.")


def clear_trial_cache():
    """Clears all trial-related data loading caches."""
    # This is a bit broad. Ideally, clear caches for specific trials/studies.
    load_input_variables_from_path.clear()
    load_all_metrics_for_trial_path.clear()
    discover_trials_from_path.clear()
    st.toast("All trial data loading caches cleared.")


def clear_specific_trial_metric_cache(trial_path: Path):
    load_all_metrics_for_trial_path.clear()  # This clears the whole cache for this func
    # For more granular control with @st.cache_data, you'd typically rely on Streamlit's
    # automatic cache invalidation based on input args, or rerun.
    # If using lru_cache, you could do: load_all_metrics_for_trial_path.cache_clear(),
    # but for st.cache_data, clearing for specific args is not direct.
    # The common pattern is to clear the entire function's cache.
    st.toast(f"Metric cache for trial '{trial_path.name}' cleared (function level).")


def clear_specific_trial_input_vars_cache(trial_path: Path):
    load_input_variables_from_path.clear()
    st.toast(f"Input-variable cache for trial '{trial_path.name}' cleared (function level).")


def clear_specific_study_trial_discovery_cache(study_path: Path):
    discover_trials_from_path.clear()
    st.toast(f"Trial discovery cache for study '{study_path.name}' cleared (function level).")


# --- Data Discovery and Loading Functions (Cached) ---
def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
    """Ensures the root data directory exists."""
    if not data_path.exists():
        try:
            data_path.mkdir(parents=True, exist_ok=True)
            st.info(f"Data directory {data_path} created.")
        except Exception as e:
            st.error(f"Failed to create data directory {data_path}: {e}")
            st.stop()
    elif not data_path.is_dir():
        st.error(f"Path {data_path} exists but is not a directory.")
        st.stop()
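

# Expected on-disk layout, as implied by the loaders below ("logs/scalar",
# "input_variables.toml", and the "metrics_*.toml" pattern are fixed in code;
# study and trial names are whatever directories exist):
#
#   <DATA_ROOT_PATH>/
#       <study_name>/
#           <trial_name>/
#               input_variables.toml
#               logs/
#                   scalar/
#                       metrics_<track>.toml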


@st.cache_data(ttl=3600)  # Cache for 1 hour, or adjust as needed
def discover_studies_cached(
    _data_root: Path,
) -> Dict[str, Any]:  # Return type is Any to avoid a circular dep with data_models.Study
    """
    Scans the data_root for study directories and returns a dictionary
    mapping study names to Study objects (or just their paths initially).
    The actual Study object creation happens in the main app for now.
    """
    # To avoid issues with caching complex objects directly, or circular dependencies,
    # this function can return simpler structures like Dict[str, Path]
    # and the main app or model can instantiate Study objects.
    # For this iteration, we'll import Study here for convenience, assuming careful structure.
    from data_models import (
        Study,
    )  # Local import to help with potential circularity if models grow complex

    if not _data_root.is_dir():
        return {}
    studies = {}
    for d in _data_root.iterdir():
        if d.is_dir():
            studies[d.name] = Study(name=d.name, path=d)
    return studies


@st.cache_data(ttl=3600)
def discover_trials_from_path(_study_path: Path) -> Dict[str, Path]:
    """Scans a study_path for trial directories."""
    if not _study_path.is_dir():
        return {}
    trials = {}
    for d in _study_path.iterdir():
        if d.is_dir():
            trials[d.name] = d
    return trials


@st.cache_data(ttl=3600)
def load_input_variables_from_path(_trial_path: Path) -> Dict[str, Any]:
    """Loads input_variables.toml from a trial directory."""
    input_vars_file = _trial_path / "input_variables.toml"
    if input_vars_file.exists():
        try:
            with open(input_vars_file, "rb") as f:
                return tomli.load(f)
        except tomli.TOMLDecodeError:
            # st.error(f"Error: failed to parse input_variables.toml at {_trial_path}")  # Avoid st.error in cached funcs if possible
            print(f"Error parsing input_variables.toml at {_trial_path}")
            return {}
    return {}
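

# A hypothetical input_variables.toml, for illustration only (the loader returns
# whatever valid TOML it finds; none of these keys are assumed by this module):
#
#   learning_rate = 0.001
#   batch_size = 32
#   optimizer = "adam"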


def _load_single_metric_toml(_toml_file_path: Path) -> pd.DataFrame:
    """Loads metrics from a single TOML file into a DataFrame."""
    if not _toml_file_path.exists():
        return pd.DataFrame()
    try:
        with open(_toml_file_path, "rb") as f:
            data = tomli.load(f)
        metrics_list = data.get("metrics", [])
        if not metrics_list:
            return pd.DataFrame()
        return pd.DataFrame(metrics_list)
    except tomli.TOMLDecodeError:
        print(f"Error parsing TOML file: {_toml_file_path.name}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Error loading {_toml_file_path.name}: {e}")
        return pd.DataFrame()
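

# Expected shape of a metrics_*.toml file: a top-level "metrics" array of tables,
# one row per logging step, with a "global_step" column plus one column per metric.
# The metric names here ("loss", "accuracy") are illustrative only:
#
#   [[metrics]]
#   global_step = 0
#   loss = 2.31
#   accuracy = 0.12
#
#   [[metrics]]
#   global_step = 100
#   loss = 1.87
#   accuracy = 0.34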


@st.cache_data(ttl=300)  # Cache metric data for 5 minutes
def load_all_metrics_for_trial_path(_trial_path: Path) -> Dict[str, pd.DataFrame]:
    """
    Loads all metrics from all tracks in a trial.
    Returns a dictionary where keys are metric names (e.g., 'loss', 'accuracy')
    and values are DataFrames containing 'global_step', 'value', and 'track'.
    """
    scalar_dir = _trial_path / "logs" / "scalar"
    if not scalar_dir.is_dir():
        return {}
    all_metrics_data_combined: Dict[str, pd.DataFrame] = {}
    for toml_file in scalar_dir.glob("metrics_*.toml"):
        track_name = toml_file.stem.replace("metrics_", "")
        df_track = _load_single_metric_toml(toml_file)
        if df_track.empty or "global_step" not in df_track.columns:
            continue
        id_vars = ["global_step"]
        value_vars = [col for col in df_track.columns if col not in id_vars]
        if not value_vars:
            continue
        # Process each metric column individually to build up the combined DataFrame
        for metric_col_name in value_vars:
            try:
                # Create a DataFrame for the current metric and track
                current_metric_df = df_track[["global_step", metric_col_name]].copy()
                current_metric_df.rename(
                    columns={metric_col_name: "value"}, inplace=True
                )
                current_metric_df["track"] = track_name
                current_metric_df["value"] = pd.to_numeric(
                    current_metric_df["value"], errors="coerce"
                )
                current_metric_df.dropna(subset=["value"], inplace=True)
                if current_metric_df.empty:
                    continue
                # Append to the combined DataFrame for this metric_col_name
                if metric_col_name not in all_metrics_data_combined:
                    all_metrics_data_combined[metric_col_name] = current_metric_df
                else:
                    all_metrics_data_combined[metric_col_name] = pd.concat(
                        [all_metrics_data_combined[metric_col_name], current_metric_df],
                        ignore_index=True,
                    )
            except Exception as e:
                print(
                    f"Error processing metric '{metric_col_name}' from file '{toml_file.name}': {e}"
                )
                continue
    # Sort data by global_step for proper line plotting
    for metric_name in all_metrics_data_combined:
        all_metrics_data_combined[metric_name] = (
            all_metrics_data_combined[metric_name]
            .sort_values(by=["track", "global_step"])
            .reset_index(drop=True)
        )
    return all_metrics_data_combined
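

# A minimal usage sketch, assuming a Streamlit app context ("my_study" and
# "trial_01" are hypothetical names):
#
#   ensure_data_directory_exists()
#   studies = discover_studies_cached(DATA_ROOT_PATH)
#   trials = discover_trials_from_path(studies["my_study"].path)
#   params = load_input_variables_from_path(trials["trial_01"])
#   metrics = load_all_metrics_for_trial_path(trials["trial_01"])
#   loss_df = metrics.get("loss")  # columns: global_step, value, track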