# hmr-dataset/genmo/network/genmo_cfg_sampler.py
# Hugging Face page header (extraction residue): zirobtc's picture;
# "Upload folder using huggingface_hub"; commit fbb20ff (verified)
from copy import deepcopy
import torch
# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel:
    """Sampling-time wrapper that applies classifier-free guidance.

    Every call evaluates the wrapped model twice -- once with the conditioning
    dict exactly as given and once with an "unconditional" variant of it --
    and then extrapolates from the unconditional output towards the
    conditional one by the factor ``y["scale"]``.
    """

    def __init__(self, model):
        """Store ``model``, the callable evaluated on every ``__call__``."""
        self.model = model

    def __call__(self, x, timesteps, y=None, **kwargs):
        """Run the wrapped model with and without conditioning and blend.

        Args:
            x: Passed unchanged to the wrapped model as its first argument.
            timesteps: Passed unchanged to the wrapped model as its second
                argument.
            y: Conditioning dict. ``"encoded_text"``, ``"f_uncond"`` and
                ``"scale"`` are read from it; ``"multi_text_data"`` (a dict
                holding ``"text_embed"``) is handled when present. The wrapper
                itself only edits a deep copy, never ``y``.
            **kwargs: Forwarded to both model evaluations.

        Returns:
            A dict with the same keys as the conditional model output; each
            value is ``uncond + y["scale"] * (cond - uncond)``.

        Raises:
            TypeError: If ``y`` is ``None``. The signature keeps the ``None``
                default, but guidance cannot be computed without the dict.
        """
        if y is None:
            # Without this guard the failure is an opaque
            # "'NoneType' object does not support item assignment".
            raise TypeError(
                "ClassifierFreeSampleModel requires a conditioning dict `y`; "
                "got None"
            )

        # Deep copy so that the edits below (including the nested
        # "multi_text_data" dict) never leak into the caller's `y`.
        y_uncond = deepcopy(y)
        y_uncond["encoded_text"] = torch.zeros_like(y["encoded_text"])
        y_uncond["f_cond"] = y["f_uncond"]
        if "multi_text_data" in y:
            y_uncond["multi_text_data"]["text_embed"] = torch.zeros_like(
                y["multi_text_data"]["text_embed"]
            )

        out = self.model(x, timesteps, y, **kwargs)
        out_uncond = self.model(x, timesteps, y_uncond, **kwargs)

        # Read the scale after both forward passes, as the original loop did.
        scale = y["scale"]
        return {
            key: out_uncond[key] + scale * (out[key] - out_uncond[key])
            for key in out
        }

    def parameters(self):
        """Return ``self.model.parameters()``."""
        return self.model.parameters()

    def named_parameters(self):
        """Return ``self.model.named_parameters()``."""
        return self.model.named_parameters()