File size: 3,998 Bytes
3c45764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Exponential Moving Average (EMA) for model parameters.

EMA maintains a smoothed copy of model parameters that updates more slowly
than the training model, leading to more stable and better-performing models.
"""

import torch
from collections import OrderedDict
from copy import deepcopy


class EMA:
    """
    Exponential Moving Average for model parameters.

    Maintains a separate copy of the model's ``state_dict``. Floating-point
    entries are updated with an exponential moving average::

        ema = ema * rate + model * (1 - rate)

    Some entries are copied directly from the model instead of averaged:
    keys listed in ``ignore_keys`` (by default keys containing ``running_``
    or ``num_batches_tracked``), and any tensor that is not floating-point or
    complex (e.g. integer index buffers, which cannot be scaled in place by a
    float rate).

    Args:
        model: The model to create EMA for
        ema_rate: EMA decay rate (default: 0.999)
        device: Device to store EMA parameters on (defaults to the device of
            the model's first parameter, or CPU if the model has none)
    """

    def __init__(self, model, ema_rate=0.999, device=None):
        """
        Initialize EMA with an independent copy of the model's state.

        Args:
            model: PyTorch model to create EMA for
            ema_rate: Decay rate for EMA (0.999 means 99.9% old, 0.1% new)
            device: Device to store EMA parameters (defaults to the model's
                device; CPU for a model without parameters)
        """
        self.ema_rate = ema_rate
        if device is None:
            # A parameterless model (e.g. nn.Identity) would make a bare
            # next() raise StopIteration, so fall back to CPU.
            first_param = next(model.parameters(), None)
            device = first_param.device if first_param is not None else torch.device('cpu')
        self.device = device

        # copy=True guarantees the EMA never shares storage with the live
        # model, even when both already live on self.device.
        self.ema_state = OrderedDict(
            (key, value.detach().to(self.device, copy=True))
            for key, value in model.state_dict().items()
        )

        # Entries that are copied directly instead of averaged
        # (e.g. BatchNorm running statistics and batch counters).
        self.ignore_keys = [
            key for key in self.ema_state
            if 'running_' in key or 'num_batches_tracked' in key
        ]

    @staticmethod
    def _is_averageable(tensor):
        """Return True if ``tensor`` can be blended in place with a float rate."""
        return tensor.is_floating_point() or tensor.is_complex()

    def update(self, model):
        """
        Update EMA state with current model parameters.

        Should be called after optimizer.step() to update EMA with the
        newly optimized model weights.

        Args:
            model: The model to read parameters from. Its state_dict must
                contain every key present in the EMA state.

        Raises:
            KeyError: if the model's state_dict lacks a key held by the EMA.
        """
        with torch.no_grad():
            source_state = model.state_dict()
            # Set lookup keeps the per-key membership test O(1).
            ignore = set(self.ignore_keys)
            decay = self.ema_rate

            # Snapshot the items because ignored entries are reassigned below.
            for key, ema_value in list(self.ema_state.items()):
                source_value = source_state[key].detach()
                if key in ignore or not self._is_averageable(ema_value):
                    # Copy (never alias): state_dict() tensors share storage
                    # with the model's buffers, so a plain .to() on the same
                    # device would make the EMA track later in-place changes.
                    self.ema_state[key] = source_value.to(self.device, copy=True)
                else:
                    # EMA update: ema = ema * rate + model * (1 - rate)
                    ema_value.mul_(decay).add_(source_value.to(self.device), alpha=1 - decay)

    def apply_to_model(self, model):
        """
        Load EMA state into model.

        This replaces model parameters with EMA parameters. Useful for
        validation or inference using the EMA model.

        Args:
            model: Model to load EMA state into
        """
        model.load_state_dict(self.ema_state)

    def state_dict(self):
        """
        Get EMA state dict for saving.

        Returns:
            OrderedDict: The live EMA state dictionary (not a copy); later
            calls to ``update`` modify these tensors.
        """
        return self.ema_state

    def load_state_dict(self, state_dict):
        """
        Load EMA state from saved checkpoint.

        Tensors are copied onto ``self.device``. This way a checkpoint loaded
        on a different device (e.g. with ``map_location='cpu'``) still works
        with ``update``, and the EMA never shares storage with the caller's
        tensors.

        Args:
            state_dict: EMA state dictionary to load
        """
        self.ema_state = OrderedDict(
            (key, value.detach().to(self.device, copy=True))
            for key, value in state_dict.items()
        )

    def add_ignore_key(self, key_pattern):
        """
        Add a key pattern to ignore list.

        Parameters matching this pattern will be copied directly instead
        of using EMA update.

        Args:
            key_pattern: Substring to match against EMA state keys
                (e.g., 'relative_position_index')
        """
        matching_keys = [key for key in self.ema_state if key_pattern in key]
        # dict.fromkeys removes duplicates while keeping first-seen order.
        self.ignore_keys = list(dict.fromkeys(self.ignore_keys + matching_keys))