| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass, field |
| | from datetime import timedelta |
| | from io import BytesIO |
| | from typing import Any, Dict, List |
| |
|
| | import torch |
| | from torch.distributed.checkpoint.stateful import Stateful |
| |
|
| |
|
| | @dataclass |
| | class TrainState(Stateful): |
| | step: int = 0 |
| | skipped_step: int = 0 |
| | token: int = 0 |
| | elapsed: timedelta = timedelta(0) |
| | global_avg_losses: List[float] = field(default_factory=list) |
| | global_max_losses: List[float] = field(default_factory=list) |
| | log_steps: List[int] = field(default_factory=list) |
| |
|
| | def state_dict(self) -> Dict[str, Any]: |
| | |
| | |
| | global_avg_losses_bytes = BytesIO() |
| | torch.save(self.global_avg_losses, global_avg_losses_bytes) |
| | global_max_losses_bytes = BytesIO() |
| | torch.save(self.global_max_losses, global_max_losses_bytes) |
| | log_steps_bytes = BytesIO() |
| | torch.save(self.log_steps, log_steps_bytes) |
| | return { |
| | "step": torch.tensor(self.step, dtype=torch.int32), |
| | "skipped_step": torch.tensor(self.skipped_step, dtype=torch.int32), |
| | "token": torch.tensor(self.token, dtype=torch.int64), |
| | "elapsed": self.elapsed, |
| | "global_avg_losses": global_avg_losses_bytes, |
| | "global_max_losses": global_max_losses_bytes, |
| | "log_steps": log_steps_bytes, |
| | } |
| |
|
| | def load_state_dict(self, state_dict) -> None: |
| | self.step = state_dict["step"].item() |
| | self.skipped_step = state_dict.get("skipped_step", 0).item() |
| | self.token = state_dict["token"].item() |
| | self.elapsed = state_dict["elapsed"] |
| | state_dict["global_avg_losses"].seek(0) |
| | self.global_avg_losses = torch.load( |
| | state_dict["global_avg_losses"], weights_only=False |
| | ) |
| | state_dict["global_max_losses"].seek(0) |
| | self.global_max_losses = torch.load( |
| | state_dict["global_max_losses"], weights_only=False |
| | ) |
| | state_dict["log_steps"].seek(0) |
| | self.log_steps = torch.load(state_dict["log_steps"], weights_only=False) |
| |
|