| import numpy as np |
| import torch |
| import logging |
| from copy import deepcopy |
|
|
|
|
def unwrap_model(model):
    """Strip one level of wrapping from ``model``.

    Any object that exposes a ``module`` attribute is treated as a wrapper
    and its ``module`` is returned; everything else is returned unchanged.
    """
    if not hasattr(model, "module"):
        return model
    return model.module
|
|
|
|
class ModelAverager(object):
    """Drive several weight averagers over the same model.

    ``methods`` is a comma-separated list of averager specs. Each spec is
    split on ``"_"`` and passed to ``Averager``; the resulting averager is
    stored in ``avgs_dict`` keyed by the full, unmodified spec string (a
    repeated spec keeps only the last averager built for it).
    """

    def __init__(self, model, methods: str):
        self.model = model
        self.avgs_dict = {
            spec: Averager(model, spec.split("_")) for spec in methods.split(",")
        }

    def step(self):
        """Advance every managed averager by one step, in insertion order."""
        for averager in self.avgs_dict.values():
            averager.step()
|
|
|
|
class Averager(object):
    """Maintain a running average of a model's weights.

    ``args`` is an averager spec already split on ``"_"``:

    * ``args[0]`` -- method name: ``"none"``, ``"poly"``, ``"ema"`` or
      ``"cosine"``. Any other name logs a warning and is otherwise treated
      like ``"cosine"``.
    * ``args[1]`` -- method parameter: ``eta`` for ``"poly"`` (default 0.0),
      ``gamma`` for ``"ema"`` (default 0.99).
    * ``args[2]`` -- update frequency: the average is updated once every
      ``freq`` calls to :meth:`step` (default 1).

    For ``"none"`` the averaged model *is* the live model. Otherwise it is a
    deep copy of the unwrapped live model, updated in place by :meth:`update`.
    ``"cosine"`` (and unknown methods) currently perform no floating-point
    averaging; only non-floating entries are copied on update.
    """

    # Optional hyper-parameters persisted by get/load_state_dict_avg.
    # __init__ only ever sets ``eta`` / ``gamma``; the others are presumably
    # set by subclasses or external code -- NOTE(review): confirm.
    _HYPERPARAMS = ("eta", "gamma", "suffix_steps", "power", "start")

    def __init__(self, model, args):
        self.model = model
        self.method = args[0]
        # Position inside the current ``freq``-call window (1..freq).
        self.update_counter = 1
        # 1-based count of averaging updates; used as ``t`` in update().
        self.step_counter = 1
        self.freq = 1 if len(args) <= 2 else int(args[2])
        if self.method == "none":
            self.av_model = model
            return
        self.av_model = deepcopy(unwrap_model(model))

        if self.method == "poly":
            self.eta = 0.0 if len(args) <= 1 else float(args[1])
        elif self.method == "ema":
            self.gamma = 0.99 if len(args) <= 1 else float(args[1])
        elif self.method != "cosine":
            logging.getLogger(__name__).warning(
                "Unknown averaging method %s", self.method
            )

    def step(self):
        """Count one call; run :meth:`update` on every ``freq``-th call."""
        if self.update_counter == self.freq:
            self.update()
        self.update_counter += 1
        if self.update_counter > self.freq:
            self.update_counter = 1

    def update(self):
        """Fold the live model's current state into the averaged model.

        Floating-point entries are blended in place:

        * ``"poly"``: ``av <- (1 - a) * av + a * w`` with
          ``a = (eta + 1) / (eta + t)`` and ``t = step_counter``; on the
          first update ``a == 1``, so the average starts as a copy of ``w``.
        * ``"ema"``: ``av <- gamma * av + (1 - gamma) * w``.

        Non-floating entries (integer counters, masks, ...) cannot be
        averaged and are copied verbatim. Does nothing for ``"none"``.
        """
        method = self.method
        if method == "none":
            return
        t = self.step_counter
        if method == "poly":
            alpha = (self.eta + 1) / (self.eta + t)

        model_sd = self.model.state_dict()
        av_sd = self.av_model.state_dict()
        # The live model may be wrapped (keys prefixed "module.") while the
        # averaged copy was taken from the unwrapped model; strip the prefix
        # only in that case. Compare against "module." (with the dot) so
        # submodules merely *named* like "module_x" are left alone.
        prefix = "module."
        av_is_prefixed = next(iter(av_sd), "").startswith(prefix)
        for key, live in model_sd.items():
            av_key = key
            if not av_is_prefixed and key.startswith(prefix):
                av_key = key[len(prefix):]
            avg = av_sd[av_key]
            if not (avg.is_floating_point() or avg.is_complex()):
                # In-place float blending would fail on integer/bool dtypes.
                avg.copy_(live)
                continue
            if method == "poly":
                avg.mul_(1 - alpha).add_(live, alpha=alpha)
            elif method == "ema":
                avg.mul_(self.gamma).add_(live, alpha=1 - self.gamma)
        self.step_counter += 1

    def reset(self):
        """Set :attr:`step_counter` to 2.

        NOTE(review): 2 rather than 1 means the next ``"poly"`` update blends
        with the existing average instead of overwriting it -- confirm intent.
        """
        self.step_counter = 2

    @property
    def averaged_model(self):
        """The model holding the averaged weights (the live model for ``"none"``)."""
        return self.av_model

    def get_state_dict_avg(self):
        """Return a checkpointable dict of counters, config and averaged weights.

        Hyper-parameters that were never set on this instance are stored as
        ``None``.
        """
        state_dict = {
            "update_counter": self.update_counter,
            "step_counter": self.step_counter,
            "freq": self.freq,
            "av_model_sd": unwrap_model(self.av_model).state_dict(),
            "method": self.method,
        }
        for name in self._HYPERPARAMS:
            state_dict[name] = getattr(self, name, None)
        return state_dict

    def load_state_dict_avg(self, state_dict):
        """Restore state produced by :meth:`get_state_dict_avg`.

        Weights are loaded into the unwrapped averaged model, mirroring how
        they were saved. A hyper-parameter is restored only if this instance
        already has that attribute.

        Raises:
            KeyError: if ``state_dict`` lacks a required entry.
        """
        self.update_counter = state_dict["update_counter"]
        self.step_counter = state_dict["step_counter"]
        self.freq = state_dict["freq"]
        self.method = state_dict["method"]
        unwrap_model(self.av_model).load_state_dict(state_dict["av_model_sd"])
        for name in self._HYPERPARAMS:
            if hasattr(self, name):
                setattr(self, name, state_dict[name])
|
|