File size: 7,780 Bytes
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
0a58567
c4b87d2
 
 
 
 
 
0a58567
 
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
0a58567
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""Utilities for persisting and aggregating GIFT-Eval results."""

import argparse
import csv
import glob
import logging
from pathlib import Path

import pandas as pd

from src.gift_eval.constants import (
    ALL_DATASETS,
    DATASET_PROPERTIES,
    MED_LONG_DATASETS,
    PRETTY_NAMES,
    STANDARD_METRIC_NAMES,
)
from src.gift_eval.core import DatasetMetadata, EvaluationItem

logger = logging.getLogger(__name__)


def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create the results CSV (with its header row) if it does not already exist.

    The header layout is: dataset, model, one ``eval_metrics/<name>`` column per
    standard metric, then domain and num_variates.
    """
    if csv_file_path.exists():
        return
    csv_file_path.parent.mkdir(parents=True, exist_ok=True)
    header_columns = ["dataset", "model"]
    header_columns.extend(f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES)
    header_columns.extend(["domain", "num_variates"])
    with open(csv_file_path, "w", newline="") as csvfile:
        csv.writer(csvfile).writerow(header_columns)


def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append per-item evaluation metrics to ``<output_dir>/<dataset_name>/results.csv``.

    One CSV row is written per evaluation item: the dataset's full name, the
    model name, each standard metric value (``None`` when missing), and the
    dataset's domain / number of variates looked up from ``DATASET_PROPERTIES``.
    When ``create_plots`` is true, any figures attached to an item are saved
    under ``plots/<dataset_key>/<term>/`` and closed afterwards.

    Args:
        items: Evaluation results to persist (metrics plus optional figures).
        dataset_name: Subdirectory name under ``output_dir`` for this dataset.
        output_dir: Root output directory; created if needed.
        model_name: Model identifier recorded in each CSV row.
        create_plots: Whether to save the items' figures to disk.
    """
    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    # Import matplotlib lazily and only when plotting was requested, so it
    # remains a truly optional dependency for metric-only runs.
    plt = None
    if create_plots:
        try:
            import matplotlib.pyplot as plt
        except ImportError:  # pragma: no cover - guard for optional dependency
            logger.warning(
                "create_plots is enabled but matplotlib is not installed; skipping plot output"
            )

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            metric_values: list[float | None] = []
            for metric_name in STANDARD_METRIC_NAMES:
                value = item.metrics.get(metric_name, None)
                if value is None:
                    metric_values.append(None)
                else:
                    # Unwrap single-element sequences (e.g. length-1 arrays) ...
                    if hasattr(value, "__len__") and not isinstance(value, (str, bytes)) and len(value) == 1:
                        value = value[0]
                    # ... then convert any remaining numpy-style scalar to a
                    # plain Python number. This is a plain `if` (not `elif`) so
                    # an unwrapped numpy scalar is also converted.
                    if hasattr(value, "item"):
                        value = value.item()
                    metric_values.append(value)

            ds_key = md.key.lower()
            props = DATASET_PROPERTIES.get(ds_key, {})
            domain = props.get("domain", "unknown")
            # Fall back to the metadata's own dimensionality when the dataset
            # is not in the properties table.
            num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)

            row = [md.full_name, model_name] + metric_values + [domain, num_variates]
            writer.writerow(row)

            if create_plots and item.figures and plt is not None:
                plots_dir = output_dir / "plots" / md.key / md.term
                plots_dir.mkdir(parents=True, exist_ok=True)
                for fig, filename in item.figures:
                    filepath = plots_dir / filename
                    fig.savefig(filepath, dpi=300, bbox_inches="tight")
                    # Close each figure to release memory after saving.
                    plt.close(fig)

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    if create_plots:
        logger.info("Plots saved under %s", output_dir / "plots")


def get_all_datasets_full_name() -> list[str]:
    """Get all possible dataset full names for validation.

    Returns every valid ``key/frequency/term`` identifier, where the medium and
    long terms are emitted only for datasets listed in ``MED_LONG_DATASETS``.
    """
    full_names: list[str] = []

    for raw_name in ALL_DATASETS:
        # Resolve the canonical key and frequency once per dataset rather than
        # once per term; the result is identical for every term.
        if "/" in raw_name:
            key_part, ds_freq = raw_name.split("/")
            lowered = key_part.lower()
            ds_key = PRETTY_NAMES.get(lowered, lowered)
        else:
            lowered = raw_name.lower()
            ds_key = PRETTY_NAMES.get(lowered, lowered)
            ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

        for term in ("short", "medium", "long"):
            # Only a subset of datasets supports the longer horizons.
            if term != "short" and raw_name not in MED_LONG_DATASETS:
                continue
            full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")

    return full_names


def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Aggregate results from multiple CSV files into a single dataframe.

    Recursively collects every ``results.csv`` under ``result_root_dir``,
    concatenates them, drops duplicate dataset rows (keeping the first), logs a
    completion summary against the full expected dataset list, and writes the
    combined table to ``<result_root_dir>/all_results.csv``.

    Args:
        result_root_dir: Root directory to search for ``results.csv`` files.

    Returns:
        The combined dataframe, or ``None`` when no usable result files exist.
    """
    result_root = Path(result_root_dir)

    logger.info("Aggregating results in: %s", result_root)

    result_files = glob.glob(f"{result_root}/**/results.csv", recursive=True)

    if not result_files:
        logger.error("No result files found!")
        return None

    dataframes: list[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
            if len(df) > 0:
                dataframes.append(df)
            else:
                logger.warning("Empty file: %s", file)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", file)
        except Exception as exc:
            # Best-effort aggregation: a single unreadable file should not
            # abort the whole run.
            logger.error("Error reading %s: %s", file, exc)

    if not dataframes:
        logger.warning("No valid CSV files found to combine")
        return None

    combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")

    if len(combined_df) != len(set(combined_df.dataset)):
        duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
        logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))

    logger.info("Combined results: %s datasets", len(combined_df))

    all_datasets_full_name = get_all_datasets_full_name()
    completed_experiments = combined_df.dataset.tolist()

    # Use sets for membership tests (O(1) per lookup instead of scanning the
    # full list for every experiment) while preserving the original ordering.
    expected_set = set(all_datasets_full_name)
    completed_experiments_clean = [exp for exp in completed_experiments if exp in expected_set]
    completed_set = set(completed_experiments_clean)
    missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_set]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(all_datasets_full_name))
    logger.info("Completed experiments: %s", len(completed_experiments_clean))
    logger.info("Missing/failed experiments: %s", len(missing_or_failed_experiments))

    logger.info("Completed experiments:")
    for idx, exp in enumerate(completed_experiments_clean, start=1):
        logger.info("  %3d: %s", idx, exp)

    if missing_or_failed_experiments:
        logger.info("Missing or failed experiments:")
        for idx, exp in enumerate(missing_or_failed_experiments, start=1):
            logger.info("  %3d: %s", idx, exp)

    completion_rate = (
        len(completed_experiments_clean) / len(all_datasets_full_name) * 100 if all_datasets_full_name else 0.0
    )
    logger.info("Completion rate: %.1f%%", completion_rate)

    output_file = result_root / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)

    return combined_df


# Explicit public API for `from ... import *` consumers; the CLI `main`
# below is deliberately not re-exported here.
__all__ = [
    "aggregate_results",
    "get_all_datasets_full_name",
    "write_results_to_disk",
]


def main() -> None:
    """CLI entry point for aggregating results from disk."""

    arg_parser = argparse.ArgumentParser(description="Aggregate GIFT-Eval results from multiple CSV files")
    arg_parser.add_argument(
        "--result_root_dir",
        type=str,
        required=True,
        help="Root directory containing result subdirectories",
    )
    parsed_args = arg_parser.parse_args()

    # Configure root logging before emitting anything.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

    root_path = Path(parsed_args.result_root_dir)
    logger.info("Searching in directory: %s", root_path)
    aggregate_results(result_root_dir=root_path)


if __name__ == "__main__":
    main()