olm-chat-7b / open_lm /utils /averaging_utils.py
henhenhahi111112's picture
Upload folder using huggingface_hub
af6e330 verified
import numpy as np
import torch
import logging
from copy import deepcopy
def unwrap_model(model):
    """Return ``model.module`` when ``model`` has a ``module`` attribute, else ``model``.

    Used to reach the underlying network when it is held by a wrapper object
    that exposes it as ``.module``.
    """
    if not hasattr(model, "module"):
        return model
    return model.module
class ModelAverager(object):
    """Hold one ``Averager`` per averaging spec and step them together.

    ``methods`` is a comma-separated list of specs. Each spec is split on
    ``"_"`` and the resulting parts are handed to ``Averager`` unchanged; the
    full spec string is the key under which that averager is stored in
    ``avgs_dict``.
    """

    def __init__(self, model, methods: str):
        self.model = model
        # One averager per spec, keyed by the unsplit spec string.
        self.avgs_dict = {
            spec: Averager(model, spec.split("_")) for spec in methods.split(",")
        }

    def step(self):
        """Forward one training step to every registered averager."""
        for averager in self.avgs_dict.values():
            averager.step()
class Averager(object):
    """Maintain a running average of a model's weights.

    ``args`` is an averaging spec already split into parts:

    * ``args[0]`` -- method name: ``"none"``, ``"poly"``, ``"ema"`` or
      ``"cosine"``.
    * ``args[1]`` (optional) -- method parameter: ``eta`` for ``"poly"``
      (default ``0.0``) or ``gamma`` for ``"ema"`` (default ``0.99``).
    * ``args[2]`` (optional) -- update frequency ``freq`` (default ``1``);
      the average is refreshed on every ``freq``-th call to :meth:`step`.

    With ``"none"`` the averaged model is the live model itself and no copy
    is made. For every other method a deep copy of the unwrapped model is
    kept and updated in place.
    """

    # Optional hyper-parameters that are checkpointed when the instance has them.
    _OPTIONAL_STATE = ("eta", "gamma", "suffix_steps", "power", "start")

    def __init__(self, model, args):
        self.model = model
        self.method = args[0]
        self.update_counter = 1
        self.step_counter = 1
        self.freq = int(args[2]) if len(args) > 2 else 1

        if self.method == "none":
            # No averaging requested: expose the live model as the "average".
            self.av_model = model
            return

        self.av_model = deepcopy(unwrap_model(model))
        if self.method == "poly":
            self.eta = float(args[1]) if len(args) > 1 else 0.0
        elif self.method == "ema":
            self.gamma = float(args[1]) if len(args) > 1 else 0.99
        elif self.method == "cosine":
            # Accepted, but `update` has no rule for it: only non-floating
            # buffers are synced for this method.
            pass
        else:
            # Kept non-fatal on purpose; the copy will simply never be averaged.
            logging.warning("Unknown averaging method %s", self.method)

    def step(self):
        """Count one training step and refresh the average every ``freq`` steps."""
        if self.update_counter == self.freq:
            self.update()
        self.update_counter += 1
        if self.update_counter > self.freq:
            self.update_counter = 1

    @staticmethod
    def _resolve_avg_key(key, av_sd):
        """Map a key of the live model's state dict to the averaged model's key.

        The live model may be wrapped (keys prefixed with ``"module."``) while
        the averaged copy is unwrapped. The prefix is stripped only when the
        key is not present as-is and the stripped key exists, so names that
        merely begin with ``"module"`` (e.g. ``"module_b.weight"``) are left
        alone. An unresolvable key is returned unchanged so the caller's
        lookup raises ``KeyError`` naming the original key.
        """
        if key in av_sd:
            return key
        prefix = "module."
        if key.startswith(prefix) and key[len(prefix):] in av_sd:
            return key[len(prefix):]
        return key

    def update(self):
        """Fold the current model weights into the averaged model, in place.

        Does nothing for method ``"none"``. Otherwise every entry of the live
        model's state dict is visited: non-floating entries are copied over
        verbatim, floating entries are blended according to the method, and
        ``step_counter`` is incremented afterwards.

        Raises:
            KeyError: if a key of the live model has no counterpart in the
                averaged model.
        """
        method = self.method
        if method == "none":
            return

        t = self.step_counter
        # model_sd is the current model state dict
        # av_sd is the averaged model state dict
        model_sd = self.model.state_dict()
        av_sd = self.av_model.state_dict()

        if method == "poly":
            # the update rule is:
            #   new_average = (1 - w) * old_average + w * current_model
            # with w = (eta + 1) / (eta + t),
            # which is eq(10) in https://arxiv.org/pdf/1212.1824.pdf
            poly_weight = (self.eta + 1) / (self.eta + t)

        for k, current in model_sd.items():
            averaged = av_sd[self._resolve_avg_key(k, av_sd)]

            if not (averaged.is_floating_point() or averaged.is_complex()):
                # Integer/bool entries (e.g. the batch counter kept by batch
                # norm) cannot be scaled in place by a float, so they are
                # copied rather than averaged, whatever their integer width.
                averaged.copy_(current)
                continue

            if method == "poly":
                averaged.mul_(1 - poly_weight).add_(current, alpha=poly_weight)
            elif method == "ema":
                # the update rule is:
                #   new_average = gamma * old_average + (1 - gamma) * current_model
                averaged.mul_(self.gamma).add_(current, alpha=1 - self.gamma)

        self.step_counter += 1

    def reset(self):
        """Set ``step_counter`` back to 2.

        NOTE(review): with ``"poly"`` a counter of 1 and ``eta == 0`` gives a
        weight of 1 (the average is overwritten by the current model);
        restarting at 2 presumably avoids that -- confirm with the caller.
        """
        self.step_counter = 2

    @property
    def averaged_model(self):
        """The model holding the averaged weights."""
        return self.av_model

    def get_state_dict_avg(self):
        """Return a checkpointable dict of counters, method, weights and hyper-parameters.

        Hyper-parameters the instance does not have are stored as ``None``.
        The weights are taken from the unwrapped averaged model.
        """
        state_dict = {
            "update_counter": self.update_counter,
            "step_counter": self.step_counter,
            "freq": self.freq,
            "av_model_sd": unwrap_model(self.av_model).state_dict(),
            "method": self.method,
        }
        for name in self._OPTIONAL_STATE:
            state_dict[name] = getattr(self, name, None)
        return state_dict

    def load_state_dict_avg(self, state_dict):
        """Restore the state produced by :meth:`get_state_dict_avg`.

        Only hyper-parameters that already exist on this instance are
        overwritten; others present in ``state_dict`` are ignored.
        """
        self.update_counter = state_dict["update_counter"]
        self.step_counter = state_dict["step_counter"]
        self.freq = state_dict["freq"]
        self.method = state_dict["method"]
        # Load into the unwrapped model so the keys match what
        # `get_state_dict_avg` saved, even when `av_model` is a wrapper
        # (which happens with method "none" and a wrapped live model).
        unwrap_model(self.av_model).load_state_dict(state_dict["av_model_sd"])
        for name in self._OPTIONAL_STATE:
            if hasattr(self, name):
                setattr(self, name, state_dict[name])