# data_loader.py
from pathlib import Path
from typing import Dict, Any

import pandas as pd
import tomli
import streamlit as st
from functools import lru_cache  # For non-Streamlit-specific caching, if needed

# Assuming infra.py is in the same directory
from infra import DATA_ROOT_PATH  # Used for ensuring the data directory exists

# --- Cache Clearing Functions ---
# These are more specific cache-clearing functions that can be called by model methods.


def clear_study_cache():
    """Clears the study discovery cache."""
    discover_studies_cached.clear()
    st.toast("All study discovery caches cleared.")


def clear_trial_cache():
    """Clears all trial-related data loading caches."""
    # This is a bit broad. Ideally, clear caches for specific trials/studies.
    load_input_variables_from_path.clear()
    load_all_metrics_for_trial_path.clear()
    discover_trials_from_path.clear()
    st.toast("All trial data loading caches cleared.")


def clear_specific_trial_metric_cache(trial_path: Path):
    # This clears the whole cache for the function. @st.cache_data keys entries
    # by input args and invalidates them automatically on a rerun with new args,
    # but it offers no direct way to clear the entry for one specific argument;
    # the common pattern is to clear the entire function's cache.
    # (With lru_cache, the equivalent would be .cache_clear().)
    load_all_metrics_for_trial_path.clear()
    st.toast(f"Metric cache for trial '{trial_path.name}' cleared (function level).")


def clear_specific_trial_input_vars_cache(trial_path: Path):
    load_input_variables_from_path.clear()
    st.toast(f"Input-variable cache for trial '{trial_path.name}' cleared (function level).")


def clear_specific_study_trial_discovery_cache(study_path: Path):
    discover_trials_from_path.clear()
    st.toast(f"Trial discovery cache for study '{study_path.name}' cleared (function level).")
# --- Data Discovery and Loading Functions (Cached) ---


def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
    """Ensures the root data directory exists."""
    if not data_path.exists():
        try:
            data_path.mkdir(parents=True, exist_ok=True)
            st.info(f"Data directory {data_path} created.")
        except Exception as e:
            st.error(f"Failed to create data directory {data_path}: {e}")
            st.stop()
    elif not data_path.is_dir():
        st.error(f"Path {data_path} exists but is not a directory.")
        st.stop()
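
# Expected on-disk layout under DATA_ROOT_PATH, as implied by the loaders below
# (the study/trial names are hypothetical):
#
#   DATA_ROOT_PATH/
#       my_study/                       # one directory per study
#           trial_001/                  # one directory per trial
#               input_variables.toml    # trial parameters
#               logs/scalar/
#                   metrics_train.toml  # one metrics_<track>.toml per track
#                   metrics_val.toml
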

@st.cache_data(ttl=3600)  # Cache for 1 hour, or adjust as needed
def discover_studies_cached(data_root: Path) -> Dict[str, Any]:
    """
    Scans data_root for study directories and returns a dictionary mapping
    study names to Study objects. The return type is hinted as Any to avoid
    a circular dependency on data_models.Study at module import time.
    """
    # Note on caching: data_root is part of the cache key. Prefixing the
    # parameter with an underscore would tell st.cache_data to skip hashing it,
    # making every call share a single cache entry regardless of the path.
    #
    # To avoid caching complex objects, this function could instead return
    # Dict[str, Path] and let the main app or model instantiate Study objects.
    # For this iteration, Study is imported here for convenience.
    from data_models import (
        Study,
    )  # Local import to help with potential circularity if models grow complex

    if not data_root.is_dir():
        return {}
    studies = {}
    for d in data_root.iterdir():
        if d.is_dir():
            studies[d.name] = Study(name=d.name, path=d)
    return studies

@st.cache_data(ttl=3600)
def discover_trials_from_path(study_path: Path) -> Dict[str, Path]:
    """Scans a study directory for trial subdirectories."""
    if not study_path.is_dir():
        return {}
    trials = {}
    for d in study_path.iterdir():
        if d.is_dir():
            trials[d.name] = d
    return trials

@st.cache_data(ttl=3600)
def load_input_variables_from_path(trial_path: Path) -> Dict[str, Any]:
    """Loads input_variables.toml from a trial directory."""
    input_vars_file = trial_path / "input_variables.toml"
    if input_vars_file.exists():
        try:
            with open(input_vars_file, "rb") as f:
                return tomli.load(f)
        except tomli.TOMLDecodeError:
            # Avoid st.error inside cached functions if possible; log instead.
            print(f"Error parsing input_variables.toml at {trial_path}")
            return {}
    return {}
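
# A minimal input_variables.toml the loader above would accept (hypothetical
# keys; any valid TOML document works, since the parsed result is returned
# as a plain dict):
#
#   learning_rate = 0.001
#   batch_size = 32
#   optimizer = "adam"
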

def _load_single_metric_toml(toml_file_path: Path) -> pd.DataFrame:
    """Loads metrics from a single TOML file into a DataFrame."""
    if not toml_file_path.exists():
        return pd.DataFrame()
    try:
        with open(toml_file_path, "rb") as f:
            data = tomli.load(f)
        metrics_list = data.get("metrics", [])
        if not metrics_list:
            return pd.DataFrame()
        return pd.DataFrame(metrics_list)
    except tomli.TOMLDecodeError:
        print(f"Error parsing TOML file: {toml_file_path.name}")
        return pd.DataFrame()
    except Exception as e:
        print(f"Error loading {toml_file_path.name}: {e}")
        return pd.DataFrame()
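
# Sketch of a metrics_<track>.toml file this helper parses (hypothetical
# values). Each [[metrics]] entry becomes one DataFrame row; the caller below
# additionally requires a "global_step" column:
#
#   [[metrics]]
#   global_step = 0
#   loss = 2.31
#   accuracy = 0.12
#
#   [[metrics]]
#   global_step = 100
#   loss = 1.87
#   accuracy = 0.35
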

@st.cache_data(ttl=300)  # Cache metric data for 5 minutes
def load_all_metrics_for_trial_path(trial_path: Path) -> Dict[str, pd.DataFrame]:
    """
    Loads all metrics from all tracks in a trial.
    Returns a dictionary where keys are metric names (e.g., 'loss', 'accuracy')
    and values are DataFrames containing 'global_step', 'value', and 'track'.
    """
    scalar_dir = trial_path / "logs" / "scalar"
    if not scalar_dir.is_dir():
        return {}
    all_metrics_data_combined: Dict[str, pd.DataFrame] = {}
    for toml_file in scalar_dir.glob("metrics_*.toml"):
        track_name = toml_file.stem.replace("metrics_", "")
        df_track = _load_single_metric_toml(toml_file)
        if df_track.empty or "global_step" not in df_track.columns:
            continue
        id_vars = ["global_step"]
        value_vars = [col for col in df_track.columns if col not in id_vars]
        if not value_vars:
            continue
        # Process each metric column individually to build up the combined DataFrame.
        for metric_col_name in value_vars:
            try:
                # Create a DataFrame for the current metric and track.
                current_metric_df = df_track[["global_step", metric_col_name]].copy()
                current_metric_df.rename(columns={metric_col_name: "value"}, inplace=True)
                current_metric_df["track"] = track_name
                current_metric_df["value"] = pd.to_numeric(
                    current_metric_df["value"], errors="coerce"
                )
                current_metric_df.dropna(subset=["value"], inplace=True)
                if current_metric_df.empty:
                    continue
                # Append to the combined DataFrame for this metric name.
                if metric_col_name not in all_metrics_data_combined:
                    all_metrics_data_combined[metric_col_name] = current_metric_df
                else:
                    all_metrics_data_combined[metric_col_name] = pd.concat(
                        [all_metrics_data_combined[metric_col_name], current_metric_df],
                        ignore_index=True,
                    )
            except Exception as e:
                print(
                    f"Error processing metric '{metric_col_name}' "
                    f"from file '{toml_file.name}': {e}"
                )
                continue
    # Sort each metric's rows by track and global_step for proper line plotting.
    for metric_name in all_metrics_data_combined:
        all_metrics_data_combined[metric_name] = (
            all_metrics_data_combined[metric_name]
            .sort_values(by=["track", "global_step"])
            .reset_index(drop=True)
        )
    return all_metrics_data_combined
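
# Typical read path from the app, sketched under the assumption that the UI
# selects a study and then a trial (names here are hypothetical):
#
#   studies = discover_studies_cached(DATA_ROOT_PATH)
#   trials = discover_trials_from_path(studies["my_study"].path)
#   metrics = load_all_metrics_for_trial_path(trials["trial_001"])
#   if "loss" in metrics:
#       st.line_chart(metrics["loss"], x="global_step", y="value", color="track")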