from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import get_ema_avg_fn
from lightning.pytorch.utilities.rank_zero import rank_zero_info

from src.callbacks.lightning_weight_averaging import WeightAveraging


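class WarmupEMAWeightAveraging(WeightAveraging):
    """EMA weight averaging that starts after a warmup fraction of training.

    Updates of the averaged model begin at ``update_starting_at_step`` or, if
    that is None, at ``int(total_steps * warmup_pct)``. When ``decay`` is None it
    is derived from the remaining horizon as
    ``1 - decay_numerator / (total_steps - start_step)``, clamped to
    ``[0.9, 0.99999]``. Setting ``enabled=False`` disables averaging.
    """

    def __init__(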
        self,
        warmup_pct: float,
        enabled: bool = True,
        decay: Optional[float] = None,
        decay_numerator: float = 20.0,
        update_every_n_steps: int = 1,
        update_starting_at_step: Optional[int] = None,
        device: Optional[Union[torch.device, str, int]] = None,
        use_buffers: bool = True,
    ) -> None:
        super().__init__(device=device, use_buffers=use_buffers)
        self.enabled = enabled
        self.warmup_pct = warmup_pct
        self.decay = decay
        self.decay_numerator = decay_numerator
        self.update_every_n_steps = update_every_n_steps
        self.update_starting_at_step = update_starting_at_step

        self.resolved_decay: Optional[float] = None
        self.resolved_start_step: Optional[int] = None

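    def setup(self, trainer, pl_module, stage: str) -> None:
        if not self.enabled:
            # Skip building the averaged model; should_update() also returns
            # False in this mode, so no averaging takes place.
            rank_zero_info("WarmupEMAWeightAveraging is disabled by config.")
            return

        # Only the "fit" stage needs a schedule; defer to the base class otherwise.
        if stage != "fit":
            return super().setup(trainer, pl_module, stage)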

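        # Prefer an explicit max_steps; otherwise use Lightning's estimate, which is
        # float("inf") for unbounded runs (e.g. max_epochs=-1 without max_steps).
        if trainer.max_steps and trainer.max_steps > 0:
            total_steps = trainer.max_steps
        else:
            total_steps = trainer.estimated_stepping_batches

        # Fall back to a fixed horizon when the estimate is non-positive or
        # infinite; int(float("inf")) below would otherwise raise OverflowError.
        if not 0 < total_steps < float("inf"):
            total_steps = 100000
        total_steps = int(total_steps)

        warmup_steps = int(total_steps * self.warmup_pct)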
        if self.update_starting_at_step is None:
            self.resolved_start_step = warmup_steps
        else:
            self.resolved_start_step = self.update_starting_at_step

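        if self.decay is None:
            # Derive the decay from the post-warmup horizon: before clamping, the
            # EMA time constant 1 / (1 - decay) equals active_steps / decay_numerator
            # updates. Decay is applied per update, so with update_every_n_steps > 1
            # the window measured in optimizer steps is that many times longer.
            active_steps = max(1, total_steps - self.resolved_start_step)
            computed_decay = 1.0 - (self.decay_numerator / active_steps)
            self.resolved_decay = min(0.99999, max(0.9, computed_decay))
        else:
            self.resolved_decay = self.decay

        # Must be set before super().setup(), which (as in Lightning's
        # WeightAveraging) builds the AveragedModel from self._kwargs.
        self._kwargs["avg_fn"] = get_ema_avg_fn(decay=self.resolved_decay)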

        super().setup(trainer, pl_module, stage)

        rank_zero_info(
            "WarmupEMAWeightAveraging configured: "
            f"total_steps={total_steps}, warmup_steps={warmup_steps}, "
            f"start_step={self.resolved_start_step}, decay={self.resolved_decay:.8f}"
        )

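    def should_update(
        self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None
    ) -> bool:
        """Return True for steps at or after the start step whose index is a
        multiple of ``update_every_n_steps``.

        Epoch-level calls (``step_idx is None``) never trigger an update, and a
        non-positive ``update_every_n_steps`` disables updates.
        """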
        if step_idx is None:
            return False

        if not self.enabled:
            return False

        if self.resolved_start_step is None:
            return False

        if step_idx < self.resolved_start_step:
            return False

        if self.update_every_n_steps <= 0:
            return False

        return step_idx % self.update_every_n_steps == 0

    def state_dict(self) -> dict[str, Any]:
        state = super().state_dict()
        state["resolved_decay"] = self.resolved_decay
        state["resolved_start_step"] = self.resolved_start_step
        return state

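    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        super().load_state_dict(state_dict)
        # Checkpoints written before these keys existed lack them; keep the values
        # computed in setup() instead of overwriting them with None, which would
        # make should_update() return False for the rest of training.
        if state_dict.get("resolved_decay") is not None:
            self.resolved_decay = state_dict["resolved_decay"]
            # Lightning calls setup() before restoring callback state, so the
            # averaged model was already built with the freshly computed decay.
            # Rebuild the EMA function so the restored decay actually takes effect.
            self._kwargs["avg_fn"] = get_ema_avg_fn(decay=self.resolved_decay)
            average_model = getattr(self, "_average_model", None)
            if average_model is not None:
                average_model.avg_fn = self._kwargs["avg_fn"]
        if state_dict.get("resolved_start_step") is not None:
            self.resolved_start_step = state_dict["resolved_start_step"]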
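

# Illustrative sketch (hypothetical helper, not used by the callback above): the
# same start-step / decay arithmetic that WarmupEMAWeightAveraging.setup()
# performs, factored into a pure function so the schedule can be checked without
# a Trainer. It assumes a finite, positive total_steps has already been resolved.
def example_resolve_ema_schedule(
    total_steps: int,
    warmup_pct: float,
    decay_numerator: float = 20.0,
    update_starting_at_step: "Optional[int]" = None,
) -> tuple[int, float]:
    """Return (start_step, decay) for a run of ``total_steps`` optimizer steps."""
    warmup_steps = int(total_steps * warmup_pct)
    start_step = warmup_steps if update_starting_at_step is None else update_starting_at_step
    # decay = 1 - numerator / active_steps, clamped to [0.9, 0.99999].
    active_steps = max(1, total_steps - start_step)
    decay = min(0.99999, max(0.9, 1.0 - decay_numerator / active_steps))
    return start_step, decay


# Example: 100k steps with 10% warmup gives start_step=10000 and
# decay = 1 - 20 / 90000; very short runs clamp to 0.9, very long ones to 0.99999.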