File size: 7,780 Bytes
c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
"""Utilities for persisting and aggregating GIFT-Eval results."""
import argparse
import csv
import glob
import logging
from pathlib import Path
import pandas as pd
from src.gift_eval.constants import (
ALL_DATASETS,
DATASET_PROPERTIES,
MED_LONG_DATASETS,
PRETTY_NAMES,
STANDARD_METRIC_NAMES,
)
from src.gift_eval.core import DatasetMetadata, EvaluationItem
logger = logging.getLogger(__name__)
def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create the results CSV with its header row if it does not already exist."""
    if csv_file_path.exists():
        return
    csv_file_path.parent.mkdir(parents=True, exist_ok=True)
    # Header layout: identifying columns, one column per standard metric,
    # then the dataset descriptors.
    header_columns = ["dataset", "model"]
    header_columns += [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
    header_columns += ["domain", "num_variates"]
    with open(csv_file_path, "w", newline="") as handle:
        csv.writer(handle).writerow(header_columns)
def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append one CSV row per evaluation item and optionally save its figures.

    Args:
        items: Evaluation results; each carries metrics, dataset metadata and
            (optionally) matplotlib figures.
        dataset_name: Subdirectory name under ``output_dir`` for this dataset.
        output_dir: Root output directory; rows go to ``<dataset>/results.csv``.
        model_name: Model identifier written into each row.
        create_plots: When True, save each item's figures under ``plots/``.
    """
    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    # Import matplotlib lazily and only when plots were requested; it is an
    # optional dependency. Previously a missing install skipped plots silently
    # and the final log still claimed plots were saved — now we warn instead.
    plt = None
    if create_plots:
        try:
            import matplotlib.pyplot as plt
        except ImportError:  # pragma: no cover - guard for optional dependency
            logger.warning(
                "create_plots=True but matplotlib is not installed; skipping plots"
            )

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            # Metric columns in the fixed standard order; missing metrics -> None.
            metric_values = [
                _scalarize_metric(item.metrics.get(metric_name, None))
                for metric_name in STANDARD_METRIC_NAMES
            ]
            ds_key = md.key.lower()
            props = DATASET_PROPERTIES.get(ds_key, {})
            domain = props.get("domain", "unknown")
            # Fall back to the metadata when the dataset is not in the table.
            num_variates = props.get(
                "num_variates", 1 if md.to_univariate else md.target_dim
            )
            writer.writerow(
                [md.full_name, model_name] + metric_values + [domain, num_variates]
            )
            if create_plots and item.figures and plt is not None:
                plots_dir = output_dir / "plots" / md.key / md.term
                plots_dir.mkdir(parents=True, exist_ok=True)
                for fig, filename in item.figures:
                    fig.savefig(plots_dir / filename, dpi=300, bbox_inches="tight")
                    plt.close(fig)  # release figure memory as we go

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    # Only claim plots were written when plotting was actually possible.
    if create_plots and plt is not None:
        logger.info("Plots saved under %s", output_dir / "plots")


def _scalarize_metric(value):
    """Collapse a length-1 sequence or numpy-style scalar to a plain value.

    Mirrors the CSV-row normalization: ``None`` passes through, a length-1
    sequence yields its sole element, and objects exposing ``.item()``
    (e.g. numpy scalars) are unwrapped.
    """
    if value is None:
        return None
    if hasattr(value, "__len__") and not isinstance(value, (str, bytes)) and len(value) == 1:
        return value[0]
    if hasattr(value, "item"):
        return value.item()
    return value
def get_all_datasets_full_name() -> list[str]:
    """Get all possible dataset full names for validation."""
    full_names: list[str] = []
    for dataset_name in ALL_DATASETS:
        # Resolve the pretty key and frequency once per dataset; neither
        # depends on the evaluation term.
        if "/" in dataset_name:
            raw_key, freq = dataset_name.split("/")
            raw_key = raw_key.lower()
            key = PRETTY_NAMES.get(raw_key, raw_key)
        else:
            raw_key = dataset_name.lower()
            key = PRETTY_NAMES.get(raw_key, raw_key)
            freq = DATASET_PROPERTIES.get(key, {}).get("frequency")
        for term in ("short", "medium", "long"):
            # Medium/long horizons only exist for the med-long dataset subset.
            if term != "short" and dataset_name not in MED_LONG_DATASETS:
                continue
            full_names.append(f"{key}/{freq if freq else 'unknown'}/{term}")
    return full_names
def _load_result_frames(result_files: list[str]) -> list[pd.DataFrame]:
    """Read each results.csv into a dataframe, skipping empty/unreadable files."""
    frames: list[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", file)
        except Exception as exc:
            logger.error("Error reading %s: %s", file, exc)
        else:
            if len(df) > 0:
                frames.append(df)
            else:
                logger.warning("Empty file: %s", file)
    return frames


def _drop_duplicate_datasets(combined_df: pd.DataFrame) -> pd.DataFrame:
    """Warn about and drop duplicate dataset rows, keeping the first occurrence."""
    if len(combined_df) != len(set(combined_df.dataset)):
        duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
        logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))
    return combined_df


def _log_experiment_summary(combined_df: pd.DataFrame) -> None:
    """Log completed vs. missing experiments against the full expected set."""
    all_datasets_full_name = get_all_datasets_full_name()
    completed_experiments = combined_df.dataset.tolist()
    # Only count rows whose dataset name is a recognized expected experiment.
    completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]
    missing_or_failed_experiments = [
        exp for exp in all_datasets_full_name if exp not in completed_experiments_clean
    ]
    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(all_datasets_full_name))
    logger.info("Completed experiments: %s", len(completed_experiments_clean))
    logger.info("Missing/failed experiments: %s", len(missing_or_failed_experiments))
    logger.info("Completed experiments:")
    for idx, exp in enumerate(completed_experiments_clean, start=1):
        logger.info(" %3d: %s", idx, exp)
    if missing_or_failed_experiments:
        logger.info("Missing or failed experiments:")
        for idx, exp in enumerate(missing_or_failed_experiments, start=1):
            logger.info(" %3d: %s", idx, exp)
    completion_rate = (
        len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
    )
    logger.info("Completion rate: %.1f%%", completion_rate)


def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Aggregate results from multiple CSV files into a single dataframe.

    Recursively finds ``results.csv`` files under *result_root_dir*, combines
    them (dropping duplicate dataset rows, first occurrence wins), logs an
    experiment-completion summary, writes ``all_results.csv`` at the root and
    returns the combined dataframe. Returns ``None`` when no usable result
    files are found.
    """
    result_root = Path(result_root_dir)
    logger.info("Aggregating results in: %s", result_root)
    result_files = glob.glob(f"{result_root}/**/results.csv", recursive=True)
    if not result_files:
        logger.error("No result files found!")
        return None
    dataframes = _load_result_frames(result_files)
    if not dataframes:
        logger.warning("No valid CSV files found to combine")
        return None
    combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")
    combined_df = _drop_duplicate_datasets(combined_df)
    logger.info("Combined results: %s datasets", len(combined_df))
    _log_experiment_summary(combined_df)
    output_file = result_root / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)
    return combined_df
# Names exported via `from <module> import *`; the CLI entry point `main`
# is not listed here.
__all__ = [
    "aggregate_results",
    "get_all_datasets_full_name",
    "write_results_to_disk",
]
def main() -> None:
    """CLI entry point for aggregating results from disk."""
    arg_parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
    arg_parser.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )
    cli_args = arg_parser.parse_args()
    # Configure root logging before any aggregation output is emitted.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
    root_dir = Path(cli_args.result_root_dir)
    logger.info("Searching in directory: %s", root_dir)
    aggregate_results(result_root_dir=root_dir)
# Run the aggregation CLI when this module is executed as a script.
if __name__ == "__main__":
    main()
|