|
|
""" |
|
|
A 'snapshot' object used for experiment tracking. |
|
|
Contains an experiment_name, the config used, the predictions produced, |
|
|
the resulting metrics, the model parameters and optional metadata. |
|
|
|
|
|
Benefits: |
|
|
- Single, self-contained function call to persist an experiment run. |
|
|
- Clean and automatic organization of experimental results facilitating model improvements. |
|
|
""" |
|
|
|
|
|
import time, json, torch |
|
|
from pathlib import Path |
|
|
from dataclasses import dataclass, asdict, is_dataclass |
|
|
from typing import Dict, Any, Optional |
|
|
from src.config.model_configs import BaseModelConfig |
|
|
from src.etl.types import Prediction |
|
|
from src.evaluation.metrics import Metrics |
|
|
from src.models.bert_based_model import BertBasedQAModel |
|
|
from src.models.base_qa_model import QAModel |
|
|
|
|
|
# Text encoding used for every JSON artifact written by ExperimentSnapshot.save().
DEFAULT_ENCODING = "utf-8"
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ExperimentSnapshot:
    """Immutable snapshot of a single experiment run.

    Bundles everything needed to reproduce and inspect a run — name,
    config, predictions, metrics, optional metadata and (optionally) the
    trained model — and persists it all under one timestamped directory
    via :meth:`save`.

    Attributes:
        experiment_name: Human-readable run name; embedded in the run
            directory name.
        config: Model configuration (a dataclass, or a mapping convertible
            via ``dict(...)``).
        predictions: Mapping of example id -> Prediction produced by the run.
        metrics: Evaluation results for the run.
        metadata: Optional free-form extra info merged into the manifest.
        model: Optional trained model; persisted only when provided.
    """

    experiment_name: str
    config: BaseModelConfig
    predictions: Dict[str, Prediction]
    metrics: Metrics
    metadata: Optional[Dict[str, Any]] = None
    model: Optional[QAModel] = None

    def _timestamped_dir(self, root: Path) -> Path:
        """Return ``root/<YYYYmmdd_HHMMSS>_<experiment_name>`` (local time)."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        return root / f"{ts}_{self.experiment_name}"

    def _as_config_dict(self) -> Dict[str, Any]:
        """Serialize the config to a plain dict suitable for JSON dumping."""
        return asdict(self.config) if is_dataclass(self.config) else dict(self.config)

    def _manifest(self, run_id: str) -> Dict[str, Any]:
        """Build the manifest dict describing the run and its artifact files.

        Args:
            run_id: Name of the run directory (used as the unique run id).

        Returns:
            A JSON-serializable manifest with run identity, model type and
            the relative paths of every artifact.

        Raises:
            ValueError: If the config does not define a ``MODEL_TYPE``.
        """
        model_type = getattr(self.config, "MODEL_TYPE", None)
        # Raise, don't assert: asserts are stripped under `python -O`, and
        # this is genuine input validation rather than an internal invariant.
        if model_type is None:
            raise ValueError("Unexpected empty model type.")
        mani: Dict[str, Any] = {
            "run_id": run_id,
            "experiment_name": self.experiment_name,
            "model_type": model_type,
            # NOTE(review): "model/" is listed even when self.model is None
            # and no model directory is written — confirm manifest consumers
            # tolerate a missing model directory.
            "artifacts": {
                "config": "config.json",
                "predictions": "predictions.json",
                "metrics": "metrics.json",
                "model": "model/",
            },
        }
        if self.metadata:
            mani["metadata"] = self.metadata
        return mani

    def save(self, experiments_root: Path = Path("experiments")) -> Path:
        """Persist config, predictions, metrics, manifest (and model) to disk.

        Creates a fresh timestamped directory under ``experiments_root`` and
        writes one JSON artifact per component. The model is saved only when
        one was attached to the snapshot.

        Args:
            experiments_root: Parent directory for all experiment runs.

        Returns:
            The newly created run directory.

        Raises:
            FileExistsError: If the run directory already exists (two saves
                of the same experiment within the same second).
            ValueError: If the config does not define a ``MODEL_TYPE``.
            TypeError: If a model is attached but is not a BertBasedQAModel.
        """
        run_dir = self._timestamped_dir(experiments_root)
        # exist_ok=False guards against silently overwriting a previous run.
        run_dir.mkdir(parents=True, exist_ok=False)

        (run_dir / "config.json").write_text(
            json.dumps(self._as_config_dict(), indent=2), encoding=DEFAULT_ENCODING
        )
        (run_dir / "predictions.json").write_text(
            json.dumps(
                Prediction.flatten_predicted_answers(predictions=self.predictions),
                ensure_ascii=False,
                indent=2,
            ),
            encoding=DEFAULT_ENCODING,
        )
        (run_dir / "metrics.json").write_text(
            json.dumps(self.metrics.export_for_exp_tracking(), indent=2),
            encoding=DEFAULT_ENCODING,
        )

        if self.model is not None:
            self._save_model(run_dir / "model")

        manifest = self._manifest(run_dir.name)
        (run_dir / "manifest.json").write_text(
            json.dumps(manifest, indent=2), encoding=DEFAULT_ENCODING
        )
        return run_dir

    def _save_model(self, model_path: Path) -> None:
        """Save model weights and tokenizer under ``model_path``.

        Args:
            model_path: Directory to create and populate with the weights
                file (``pytorch_model.bin``) and the tokenizer files.

        Raises:
            TypeError: If the attached model is not a ``BertBasedQAModel``.
        """
        # Raise, don't assert: this check must survive `python -O`.
        if not isinstance(self.model, BertBasedQAModel):
            raise TypeError(
                "Currently model saving is only supported for the BertBasedQAModel type."
            )
        model_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.qa_module.state_dict(), model_path / "pytorch_model.bin")

        self.model.tokenizer.save_pretrained(model_path / "tokenizer")
        print(f"Model saved to {model_path}")
|
|
|