from copy import deepcopy import torch # A wrapper model for Classifier-free guidance **SAMPLING** only # https://arxiv.org/abs/2207.12598 class ClassifierFreeSampleModel: def __init__(self, model): self.model = model # model is the actual model to run def __call__(self, x, timesteps, y=None, **kwargs): y_uncond = deepcopy(y) y_uncond["encoded_text"] = torch.zeros_like(y["encoded_text"]) y_uncond["f_cond"] = y["f_uncond"] if "multi_text_data" in y: y_uncond["multi_text_data"]["text_embed"] = torch.zeros_like( y["multi_text_data"]["text_embed"] ) out = self.model(x, timesteps, y, **kwargs) out_uncond = self.model(x, timesteps, y_uncond, **kwargs) outputs = dict() for k in out: outputs[k] = out_uncond[k] + y["scale"] * (out[k] - out_uncond[k]) return outputs def parameters(self): return self.model.parameters() def named_parameters(self): return self.model.named_parameters()