"""A 'snapshot' object used for experiment tracking.

Contains an experiment_name, the config used, the predictions produced,
the resulting metrics, the model parameters and optional metadata.

Benefits:
- Single, self-contained function call to persist an experiment run.
- Clean and automatic organization of experimental results facilitating
  model improvements.
"""
import json
import time
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import torch

from src.config.model_configs import BaseModelConfig
from src.etl.types import Prediction
from src.evaluation.metrics import Metrics
from src.models.base_qa_model import QAModel
from src.models.bert_based_model import BertBasedQAModel

DEFAULT_ENCODING = "utf-8"


@dataclass(frozen=True)
class ExperimentSnapshot:
    """Immutable record of a single experiment run.

    Bundles the experiment name, config, predictions, metrics, optional
    metadata and an optional model reference, and persists them all to a
    timestamped directory via :meth:`save`.
    """

    experiment_name: str
    config: BaseModelConfig
    predictions: Dict[str, Prediction]
    metrics: Metrics
    metadata: Optional[Dict[str, Any]] = None
    model: Optional[QAModel] = None  # stores model reference

    def _timestamped_dir(self, root: Path) -> Path:
        """Return a unique run directory: <root>/<YYYYmmdd_HHMMSS>_<name>."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        return root / f"{ts}_{self.experiment_name}"

    def _as_config_dict(self) -> Dict[str, Any]:
        """Serialize the config to a plain dict (dataclass or mapping)."""
        return asdict(self.config) if is_dataclass(self.config) else dict(self.config)

    def _manifest(self, run_id: str) -> Dict[str, Any]:
        """Build the manifest describing the run and its artifact layout.

        Args:
            run_id: Name of the run directory (used as the run identifier).

        Returns:
            JSON-serializable manifest dict.

        Raises:
            ValueError: If the config carries no ``MODEL_TYPE`` attribute.
        """
        model_type = getattr(self.config, "MODEL_TYPE", None)
        # Raise (not assert) so the check survives `python -O`.
        if model_type is None:
            raise ValueError("Unexpected empty model type.")
        artifacts: Dict[str, str] = {
            "config": "config.json",
            "predictions": "predictions.json",
            "metrics": "metrics.json",
        }
        # BUGFIX: only advertise the model artifact when a model is actually
        # saved by save(); previously "model/" was listed unconditionally,
        # even for model-less runs where that directory never exists.
        if self.model is not None:
            artifacts["model"] = "model/"
        mani: Dict[str, Any] = {
            "run_id": run_id,
            "experiment_name": self.experiment_name,
            "model_type": model_type,
            "artifacts": artifacts,
        }
        # TODO - consider adding path to model checkpoints once we have those
        if self.metadata:
            mani["metadata"] = self.metadata  # pass-through, unchanged
        return mani

    def save(self, experiments_root: Path = Path("experiments")) -> Path:
        """Persist config, predictions, metrics, optional model and a
        manifest under a fresh timestamped run directory.

        Args:
            experiments_root: Parent directory for all experiment runs.

        Returns:
            Path to the newly created run directory.

        Raises:
            FileExistsError: If the run directory already exists
                (protects a previous run from being overwritten).
        """
        run_dir = self._timestamped_dir(experiments_root)
        # raise error if accidentally attempting to overwrite previous run
        run_dir.mkdir(parents=True, exist_ok=False)

        (run_dir / "config.json").write_text(
            json.dumps(self._as_config_dict(), indent=2), encoding=DEFAULT_ENCODING
        )
        (run_dir / "predictions.json").write_text(
            json.dumps(
                Prediction.flatten_predicted_answers(predictions=self.predictions),
                ensure_ascii=False,  # preserve original characters (e.g., accented characters etc.)
                indent=2,
            ),
            encoding=DEFAULT_ENCODING,
        )
        (run_dir / "metrics.json").write_text(
            json.dumps(self.metrics.export_for_exp_tracking(), indent=2),
            encoding=DEFAULT_ENCODING,
        )
        if self.model is not None:
            self._save_model(run_dir / "model")
        manifest = self._manifest(run_dir.name)
        (run_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2), encoding=DEFAULT_ENCODING
        )
        return run_dir

    def _save_model(self, model_path: Path) -> None:
        """Save model weights and tokenizer under *model_path*.

        Raises:
            TypeError: If the attached model is not a BertBasedQAModel.
        """
        # Raise (not assert) so the check survives `python -O`.
        if not isinstance(self.model, BertBasedQAModel):
            raise TypeError(
                "Currently model saving is only supported for the BertBasedQAModel type."
            )
        model_path.mkdir(parents=True, exist_ok=True)
        # Save model weights
        torch.save(self.model.qa_module.state_dict(), model_path / "pytorch_model.bin")
        # Save tokenizer
        self.model.tokenizer.save_pretrained(model_path / "tokenizer")
        print(f"Model saved to {model_path}")