# BayesOptGPT / scripts/evaluate.py
# Last commit: 2d3dd8e (verified) by BlackWizardX11 - "Redeploying backend" (2.51 kB).
from __future__ import annotations
import argparse
from pathlib import Path
from bayes_gp_llmops.evaluation.pipeline import run_evaluation_pipeline
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the evaluation script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behaviour. Pass an explicit list to drive the CLI from
            tests or other Python code.

    Returns:
        Namespace with these attributes:

        - ``data_config``, ``model_config``, ``train_config`` and
          ``output_dir`` (``Path``).
        - ``checkpoint`` (``Path`` or ``None``).
        - ``device`` (``str`` or ``None``).
        - the boolean flags ``disable_temperature_scaling`` and ``debug``.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a trained AG News classifier checkpoint."
    )
    parser.add_argument(
        "--data-config",
        type=Path,
        default=Path("configs/data.yaml"),
        help="Path to the data configuration file (default: %(default)s).",
    )
    parser.add_argument(
        "--model-config",
        type=Path,
        default=Path("configs/model.yaml"),
        help="Path to the model configuration file (default: %(default)s).",
    )
    parser.add_argument(
        "--train-config",
        type=Path,
        default=Path("configs/train.yaml"),
        help="Path to the training configuration file (default: %(default)s).",
    )
    parser.add_argument(
        "--checkpoint",
        type=Path,
        default=None,
        help=(
            "Checkpoint to evaluate. When omitted, None is forwarded and the "
            "evaluation pipeline decides which checkpoint to use."
        ),
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Optional device override forwarded to the evaluation pipeline.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("artifacts/evaluation"),
        help="Directory for evaluation outputs (default: %(default)s).",
    )
    parser.add_argument(
        "--disable-temperature-scaling",
        action="store_true",
        help="Skip temperature-scaling calibration.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run the evaluation pipeline in debug mode.",
    )
    # parse_args(None) falls back to sys.argv[1:], so existing callers are unaffected.
    return parser.parse_args(argv)
def main() -> None:
    """Command-line entry point: evaluate a checkpoint and report artifact paths.

    Parses the CLI options and forwards them to ``run_evaluation_pipeline``.
    It then prints one ``name=value`` line per artifact on the returned
    object. The two calibrated-metrics lines are printed only when the
    pipeline returned a non-None path for them.
    """
    options = parse_args()
    result = run_evaluation_pipeline(
        data_config_path=options.data_config,
        model_config_path=options.model_config,
        train_config_path=options.train_config,
        checkpoint_path=options.checkpoint,
        device_override=options.device,
        output_dir=options.output_dir,
        # The CLI exposes an opt-out flag; the pipeline expects an opt-in bool.
        enable_temperature_scaling=not options.disable_temperature_scaling,
        debug_mode=options.debug,
    )

    # Each row is (printed label, attribute on the result, skip the line when None?).
    # Attributes are read one row at a time, so reads stay interleaved with prints.
    # NOTE(review): temperature_scaling is printed even when it is None. The
    # calibrated metrics are guarded; confirm whether this one should be too.
    report_rows = (
        ("evaluation_output_dir", "output_dir", False),
        ("metrics_validation", "metrics_validation_path", False),
        ("metrics_test", "metrics_test_path", False),
        ("metrics_validation_calibrated", "metrics_validation_calibrated_path", True),
        ("metrics_test_calibrated", "metrics_test_calibrated_path", True),
        ("predictions_validation", "predictions_validation_path", False),
        ("predictions_test", "predictions_test_path", False),
        ("temperature_scaling", "temperature_scaling_path", False),
        ("confusion_matrix_plot", "confusion_matrix_plot_path", False),
        ("reliability_diagram_plot", "reliability_diagram_plot_path", False),
        ("confidence_histogram_plot", "confidence_histogram_plot_path", False),
        ("entropy_histogram_plot", "entropy_histogram_plot_path", False),
    )
    for label, attribute, skip_when_none in report_rows:
        value = getattr(result, attribute)
        if skip_when_none and value is None:
            continue
        print(f"{label}={value}")
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when it is imported.
    main()