# squad2-qa/src/utils/experiment_snapshot.py
# Author: Kimis Perros
# Initial deployment (commit 461f64f)
"""
A 'snapshot' object used for experiment tracking.
Contains an experiment_name, the config used, the predictions produced,
the resulting metrics, the model parameters and optional metadata.
Benefits:
- Single, self-contained function call to persist an experiment run.
- Clean and automatic organization of experimental results facilitating model improvements.
"""
import time, json, torch
from pathlib import Path
from dataclasses import dataclass, asdict, is_dataclass
from typing import Dict, Any, Optional
from src.config.model_configs import BaseModelConfig
from src.etl.types import Prediction
from src.evaluation.metrics import Metrics
from src.models.bert_based_model import BertBasedQAModel
from src.models.base_qa_model import QAModel
DEFAULT_ENCODING = "utf-8"
@dataclass(frozen=True)
class ExperimentSnapshot:
    """Immutable record of a single experiment run, persisted with ``save()``.

    Bundles the experiment name, the config used, the predictions produced,
    the resulting metrics, an optional model reference, and free-form metadata
    into one self-contained object that can be written to disk in a single call.
    """

    experiment_name: str  # human-readable run label, embedded in the output dir name
    config: BaseModelConfig  # config used for the run (dataclass or mapping)
    predictions: Dict[str, Prediction]  # predictions keyed by example id
    metrics: Metrics  # evaluation results for this run
    metadata: Optional[Dict[str, Any]] = None  # free-form extras, passed through to the manifest
    model: Optional[QAModel] = None  # optional model reference; weights saved only if present

    def _timestamped_dir(self, root: Path) -> Path:
        """Return ``root/<YYYYmmdd_HHMMSS>_<experiment_name>`` (path only, not created)."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        return root / f"{ts}_{self.experiment_name}"

    def _as_config_dict(self) -> Dict[str, Any]:
        """Normalize the config to a plain dict, whether it is a dataclass or a mapping."""
        return asdict(self.config) if is_dataclass(self.config) else dict(self.config)

    def _manifest(self, run_id: str) -> Dict[str, Any]:
        """Build the manifest describing this run and its artifacts.

        Args:
            run_id: Name of the run directory (its timestamped basename).

        Returns:
            A JSON-serializable dict with run identity, model type, and the
            relative paths of the artifacts actually written.

        Raises:
            ValueError: If the config has no ``MODEL_TYPE`` attribute.
        """
        model_type = getattr(self.config, "MODEL_TYPE", None)
        if model_type is None:
            # Explicit raise instead of assert: asserts are stripped under `python -O`.
            raise ValueError("Unexpected empty model type.")
        artifacts: Dict[str, str] = {
            "config": "config.json",
            "predictions": "predictions.json",
            "metrics": "metrics.json",
        }
        # Only advertise the model artifact when weights are actually saved;
        # otherwise the manifest would point at a nonexistent "model/" directory.
        if self.model is not None:
            artifacts["model"] = "model/"
        mani: Dict[str, Any] = {
            "run_id": run_id,
            "experiment_name": self.experiment_name,
            "model_type": model_type,
            "artifacts": artifacts,
        }
        # TODO - consider adding path to model checkpoints once we have those
        if self.metadata:
            mani["metadata"] = self.metadata  # pass-through, unchanged
        return mani

    def save(self, experiments_root: Path = Path("experiments")) -> Path:
        """Persist the snapshot under a fresh timestamped run directory.

        Writes config.json, predictions.json, metrics.json, manifest.json and,
        when a model is attached, its weights and tokenizer under ``model/``.

        Args:
            experiments_root: Parent directory for all experiment runs.

        Returns:
            The path of the newly created run directory.

        Raises:
            FileExistsError: If the run directory already exists (guards
                against accidentally overwriting a previous run).
        """
        run_dir = self._timestamped_dir(experiments_root)
        # raise error if accidentally attempting to overwrite previous run
        run_dir.mkdir(parents=True, exist_ok=False)
        # ensure_ascii=False everywhere for consistency: preserve original
        # characters (e.g., accented characters) in all JSON artifacts.
        (run_dir / "config.json").write_text(
            json.dumps(self._as_config_dict(), ensure_ascii=False, indent=2),
            encoding=DEFAULT_ENCODING,
        )
        (run_dir / "predictions.json").write_text(
            json.dumps(
                Prediction.flatten_predicted_answers(predictions=self.predictions),
                ensure_ascii=False,
                indent=2,
            ),
            encoding=DEFAULT_ENCODING,
        )
        (run_dir / "metrics.json").write_text(
            json.dumps(
                self.metrics.export_for_exp_tracking(), ensure_ascii=False, indent=2
            ),
            encoding=DEFAULT_ENCODING,
        )
        if self.model is not None:
            self._save_model(run_dir / "model")
        manifest = self._manifest(run_dir.name)
        (run_dir / "manifest.json").write_text(
            json.dumps(manifest, ensure_ascii=False, indent=2),
            encoding=DEFAULT_ENCODING,
        )
        return run_dir

    def _save_model(self, model_path: Path) -> None:
        """Save model weights and tokenizer under ``model_path``.

        Raises:
            TypeError: If the attached model is not a ``BertBasedQAModel``
                (the only type currently supported for saving).
        """
        if not isinstance(self.model, BertBasedQAModel):
            # Explicit raise instead of assert: asserts are stripped under `python -O`.
            raise TypeError(
                "Currently model saving is only supported for the BertBasedQAModel type."
            )
        model_path.mkdir(parents=True, exist_ok=True)
        # Save model weights
        torch.save(self.model.qa_module.state_dict(), model_path / "pytorch_model.bin")
        # Save tokenizer
        self.model.tokenizer.save_pretrained(model_path / "tokenizer")
        print(f"Model saved to {model_path}")