import torch
import torch.nn as nn
from copy import deepcopy
from utils.misc import wrapped_getattr
# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    """Sampling-time wrapper that applies classifier-free guidance.

    Runs the wrapped model both conditionally and unconditionally and blends
    the two outputs as ``uncond + scale * (cond - uncond)``. The wrapped model
    must have been trained with condition dropout (``cond_mask_prob > 0``) so
    that an unconditional pass is meaningful.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # the underlying diffusion model to run
        assert self.model.cond_mask_prob > 0, (
            "Cannot run a guided diffusion on a model that has not been trained with no conditions"
        )
        # Mirror inner-model attributes so this wrapper is a drop-in
        # replacement for the wrapped model at sampling time.
        self.rot2xyz = self.model.rot2xyz
        self.translation = self.model.translation
        self.njoints = self.model.njoints
        self.nfeats = self.model.nfeats
        self.data_rep = self.model.data_rep
        self.cond_mode = self.model.cond_mode
        self.encode_text = self.model.encode_text
        self.dataset = self.model.dataset if hasattr(self.model, "dataset") else None

    def forward(self, x, timesteps, y=None, **kwargs):
        """Forward pass with classifier-free guidance.

        Args:
            x: Input (noised) tensor.
            timesteps: Diffusion timesteps.
            y: Text/action conditioning dictionary; ``y["scale"]`` carries the
                per-sample guidance scale (viewed to ``(-1, 1, 1, 1)``, so the
                output is assumed 4-D -- TODO confirm against callers).
            **kwargs: Additional conditioning arguments, possibly including
                ``cond_images``.

        Returns:
            The guided model output (same shape as a plain model call).
        """
        cond_mode = self.model.cond_mode
        assert cond_mode in ["text", "action"]
        if "cond_images" in kwargs:
            # 1. Fully conditional pass (text/action from `y`, image from `kwargs`).
            out_text_and_image = self.model(x, timesteps, y, **kwargs)
            # 2. Text-conditional but image-unconditional pass.
            kwargs_no_image = kwargs.copy()
            # The branch condition guarantees the key exists; no guard needed.
            kwargs_no_image.pop("cond_images")
            out_text_no_image = self.model(x, timesteps, y, **kwargs_no_image)
            # 3. Apply the classifier-free guidance formula over the image condition.
            if y is not None and "scale" in y:
                guidance_scale = y["scale"].view(-1, 1, 1, 1)
                guided_output = out_text_no_image + guidance_scale * (
                    out_text_and_image - out_text_no_image
                )
            else:
                # Bug fix: the old message claimed the *unconditional* output was
                # used, but the fully conditional output is what is returned.
                print(
                    "Warning: No scale provided for guidance. Using the fully conditional output."
                )
                guided_output = out_text_and_image
            return guided_output
        else:
            # Old text/action conditioning without images: build an
            # unconditional copy of the conditioning dict.
            y_uncond = deepcopy(y)
            y_uncond["uncond"] = True
            # Run both conditional and unconditional passes, forwarding any
            # extra kwargs unchanged.
            out = self.model(x, timesteps, y, **kwargs)
            out_uncond = self.model(x, timesteps, y_uncond, **kwargs)
            # Classifier-free guidance formula.
            return out_uncond + (y["scale"].view(-1, 1, 1, 1) * (out - out_uncond))

    def __getattr__(self, name, default=None):
        # Reached only when `name` is not found through normal lookup.
        # Bug fix: forward the caller-supplied default instead of hard-coding None.
        return wrapped_getattr(self, name, default=default)
class AutoRegressiveSampler:
    """Generates a long sequence by chaining fixed-length window predictions.

    Each iteration predicts ``args.pred_len`` frames conditioned on the last
    ``args.context_len`` frames of the previous window (the "prefix"), until
    at least ``required_frames`` frames have been produced; the overshoot is
    trimmed at the end.
    """

    def __init__(self, args, sample_fn, required_frames=196):
        self.sample_fn = sample_fn  # single-window diffusion sampling function
        self.args = args
        self.required_frames = required_frames

    def sample(self, model, shape, **kargs):
        """Run autoregressive sampling.

        Args:
            model: Model forwarded verbatim to ``sample_fn``.
            shape: Full output shape; the last dimension is the frame axis.
            **kargs: Sampling kwargs; ``model_kwargs["y"]`` must contain
                ``prefix`` (context frames) and ``text`` (prompts).

        Returns:
            Tensor of the sampled frames, truncated to ``required_frames``
            along the last dimension.
        """
        pred_len = self.args.pred_len
        # Ceiling division: number of windows needed to cover required_frames.
        n_iterations = -(-self.required_frames // pred_len)
        samples_buf = []
        # Initialize the rolling context from the data-provided prefix.
        cur_prefix = deepcopy(kargs["model_kwargs"]["y"]["prefix"])
        # Dynamic text mode: each batch element carries a list of prompts
        # (one per window) instead of a single prompt.
        dynamic_text_mode = isinstance(kargs["model_kwargs"]["y"]["text"][0], list)
        if self.args.autoregressive_include_prefix:
            samples_buf.append(cur_prefix)
        autoregressive_shape = list(deepcopy(shape))
        autoregressive_shape[-1] = pred_len

        # Autoregressive sampling loop.
        for i in range(n_iterations):
            # Deep-copy so per-window mutations don't leak into later windows.
            cur_kargs = deepcopy(kargs)
            cur_kargs["model_kwargs"]["y"]["prefix"] = cur_prefix
            if dynamic_text_mode:
                # Select the i-th prompt for every batch element.
                cur_kargs["model_kwargs"]["y"]["text"] = [
                    s[i] for s in kargs["model_kwargs"]["y"]["text"]
                ]
                if model.text_encoder_type == "bert":
                    # text_embed is assumed to be an (embeddings, mask) pair with
                    # a per-window axis at these positions -- TODO confirm against
                    # the text-encoder output layout.
                    cur_kargs["model_kwargs"]["y"]["text_embed"] = (
                        cur_kargs["model_kwargs"]["y"]["text_embed"][0][:, :, i],
                        cur_kargs["model_kwargs"]["y"]["text_embed"][1][:, i],
                    )
                else:
                    raise NotImplementedError(
                        "DiP model only supports BERT text encoder at the moment. If you implement this, please send a PR!"
                    )
            # Sample the next window.
            sample = self.sample_fn(model, autoregressive_shape, **cur_kargs)
            # Keep only the newly predicted frames.
            samples_buf.append(sample.clone()[..., -pred_len:])
            # The tail of this window becomes the next window's context.
            cur_prefix = sample.clone()[..., -self.args.context_len :]

        # Concatenate along the frame axis and trim overshoot (e.g. 200 -> 196).
        full_batch = torch.cat(samples_buf, dim=-1)[..., : self.required_frames]
        return full_batch