File size: 3,866 Bytes
461f64f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
A 'snapshot' object used for experiment tracking.
Contains an experiment_name, the config used, the predictions produced,
the resulting metrics, the model parameters and optional metadata.

Benefits:
    - Single, self-contained function call to persist an experiment run.
    - Clean and automatic organization of experimental results facilitating model improvements.
"""

import time, json, torch
from pathlib import Path
from dataclasses import dataclass, asdict, is_dataclass
from typing import Dict, Any, Optional
from src.config.model_configs import BaseModelConfig
from src.etl.types import Prediction
from src.evaluation.metrics import Metrics
from src.models.bert_based_model import BertBasedQAModel
from src.models.base_qa_model import QAModel

# Text encoding used for every JSON artifact written by ExperimentSnapshot.save.
DEFAULT_ENCODING: str = "utf-8"


@dataclass(frozen=True)
class ExperimentSnapshot:
    """Immutable record of one experiment run, persisted to disk by ``save``.

    Attributes:
        experiment_name: Run name; appended to the timestamped run directory name.
        config: Model configuration. Serialized to ``config.json``; its
            ``MODEL_TYPE`` attribute is recorded in the manifest and must be set.
        predictions: Mapping of str to ``Prediction``, serialized via
            ``Prediction.flatten_predicted_answers``.
        metrics: Evaluation metrics, serialized via ``export_for_exp_tracking``.
        metadata: Optional dict copied verbatim into the manifest when non-empty.
        model: Optional model reference; when set, its weights and tokenizer are
            saved under ``model/``. Only ``BertBasedQAModel`` is supported.
    """

    experiment_name: str
    config: BaseModelConfig
    predictions: Dict[str, Prediction]
    metrics: Metrics
    metadata: Optional[Dict[str, Any]] = None
    model: Optional[QAModel] = None  # stores model reference

    def _timestamped_dir(self, root: Path) -> Path:
        """Return ``root / "<YYYYmmdd_HHMMSS>_<experiment_name>"`` for the current local time."""
        ts = time.strftime("%Y%m%d_%H%M%S")
        # Path() lets callers pass a plain str as the root as well.
        return Path(root) / f"{ts}_{self.experiment_name}"

    def _as_config_dict(self) -> Dict[str, Any]:
        """Return the config as a plain dict suitable for ``json.dumps``.

        Dataclass *instances* are converted (recursively) with ``asdict``;
        anything else is passed to ``dict()`` and must therefore be a mapping
        or an iterable of key/value pairs.
        """
        # is_dataclass() is also True for dataclass *classes*, which asdict()
        # rejects, so only take that path for instances.
        if is_dataclass(self.config) and not isinstance(self.config, type):
            return asdict(self.config)
        return dict(self.config)

    def _manifest(self, run_id: str) -> Dict[str, Any]:
        """Build the manifest describing a run and the artifacts it contains.

        Args:
            run_id: Identifier for the run; ``save`` passes the run directory name.

        Returns:
            Dict with ``run_id``, ``experiment_name``, ``model_type``,
            ``artifacts`` (artifact name -> path relative to the run directory)
            and, when ``metadata`` is non-empty, ``metadata``.

        Raises:
            ValueError: If ``config.MODEL_TYPE`` is missing or None.
        """
        model_type = getattr(self.config, "MODEL_TYPE", None)
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if model_type is None:
            raise ValueError("Unexpected empty model type.")

        artifacts: Dict[str, str] = {
            "config": "config.json",
            "predictions": "predictions.json",
            "metrics": "metrics.json",
        }
        # Only advertise the model directory when save() actually writes it.
        if self.model is not None:
            artifacts["model"] = "model/"
        # TODO - consider adding path to model checkpoints once we have those

        mani: Dict[str, Any] = {
            "run_id": run_id,
            "experiment_name": self.experiment_name,
            "model_type": model_type,
            "artifacts": artifacts,
        }
        if self.metadata:
            mani["metadata"] = self.metadata  # pass-through, unchanged
        return mani

    def save(self, experiments_root: Path = Path("experiments")) -> Path:
        """Persist this snapshot to a new timestamped run directory.

        Directory layout::

            <experiments_root>/<YYYYmmdd_HHMMSS>_<experiment_name>/
                config.json
                predictions.json
                metrics.json
                model/          (only when ``self.model`` is set)
                manifest.json   (written last)

        Args:
            experiments_root: Parent directory for runs (``Path`` or ``str``);
                created, including parents, if missing.

        Returns:
            Path of the newly created run directory.

        Raises:
            ValueError: If ``config.MODEL_TYPE`` is missing or None.
            TypeError: If ``model`` is set but is not a ``BertBasedQAModel``.
            FileExistsError: If the run directory already exists, e.g. two saves
                of the same experiment name within the same second.
        """
        run_dir = self._timestamped_dir(Path(experiments_root))

        # Validate and serialize everything before touching the filesystem, so
        # a bad snapshot raises without leaving a half-written run directory.
        if self.model is not None:
            self._require_bert_model()
        manifest_text = json.dumps(self._manifest(run_dir.name), indent=2)
        payloads = {
            "config.json": json.dumps(self._as_config_dict(), indent=2),
            "predictions.json": json.dumps(
                Prediction.flatten_predicted_answers(predictions=self.predictions),
                ensure_ascii=False,  # preserve original characters (e.g., accented characters etc.)
                indent=2,
            ),
            "metrics.json": json.dumps(
                self.metrics.export_for_exp_tracking(), indent=2
            ),
        }

        # raise error if accidentally attempting to overwrite previous run
        run_dir.mkdir(parents=True, exist_ok=False)
        for filename, text in payloads.items():
            (run_dir / filename).write_text(text, encoding=DEFAULT_ENCODING)

        if self.model is not None:
            self._save_model(run_dir / "model")

        # Written last, so a missing manifest indicates an incomplete save.
        (run_dir / "manifest.json").write_text(manifest_text, encoding=DEFAULT_ENCODING)
        return run_dir

    def _require_bert_model(self) -> BertBasedQAModel:
        """Return ``self.model`` if it is a ``BertBasedQAModel``; otherwise raise ``TypeError``."""
        if not isinstance(self.model, BertBasedQAModel):
            raise TypeError(
                "Currently model saving is only supported for the BertBasedQAModel type."
            )
        return self.model

    def _save_model(self, model_path: Path) -> None:
        """Save model weights and tokenizer under ``model_path``.

        Writes ``pytorch_model.bin`` (``torch.save`` of ``model.qa_module``'s
        ``state_dict``) and a ``tokenizer/`` directory (via the tokenizer's
        ``save_pretrained``). ``model_path`` is created if missing.

        Raises:
            TypeError: If ``self.model`` is not a ``BertBasedQAModel``.
        """
        model = self._require_bert_model()
        model_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        torch.save(model.qa_module.state_dict(), model_path / "pytorch_model.bin")
        # Save tokenizer
        model.tokenizer.save_pretrained(model_path / "tokenizer")
        print(f"Model saved to {model_path}")