File size: 5,542 Bytes
5007d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
import torch.nn as nn
from copy import deepcopy
from utils.misc import wrapped_getattr


# A wrapper model for Classifier-free guidance **SAMPLING** only
# https://arxiv.org/abs/2207.12598
class ClassifierFreeSampleModel(nn.Module):
    """Sampling-time wrapper implementing classifier-free guidance.

    Wraps a diffusion model that was trained with condition dropout
    (``cond_mask_prob > 0``) and, at each denoising step, combines a
    conditional and an unconditional forward pass using the guidance
    scale supplied in ``y["scale"]``.

    Reference: https://arxiv.org/abs/2207.12598
    """

    def __init__(self, model):
        super().__init__()
        self.model = model  # model is the actual model to run

        # Guidance requires an unconditional mode, which only exists if the
        # wrapped model was trained with condition masking.
        assert self.model.cond_mask_prob > 0, (
            "Cannot run a guided diffusion on a model that has not been trained with no conditions"
        )

        # pointers to inner model
        self.rot2xyz = self.model.rot2xyz
        self.translation = self.model.translation
        self.njoints = self.model.njoints
        self.nfeats = self.model.nfeats
        self.data_rep = self.model.data_rep
        self.cond_mode = self.model.cond_mode
        self.encode_text = self.model.encode_text
        self.dataset = self.model.dataset if hasattr(self.model, "dataset") else None

    def forward(self, x, timesteps, y=None, **kwargs):
        """
        Forward pass with classifier-free guidance.

        Args:
            x: Input tensor
            timesteps: Diffusion timesteps
            y: Text/action conditioning dictionary; ``y["scale"]`` holds the
                per-sample guidance scale.
            **kwargs: Additional conditioning arguments including cond_images

        Returns:
            The guided model output: ``uncond + scale * (cond - uncond)``.
        """
        cond_mode = self.model.cond_mode
        assert cond_mode in ["text", "action"]

        if "cond_images" in kwargs:
            # 1. Fully conditional pass (with text/action from `y` and image from `kwargs`)
            out_text_and_image = self.model(x, timesteps, y, **kwargs)

            # 2. Text-conditional but image-unconditional pass
            kwargs_no_image = {k: v for k, v in kwargs.items() if k != "cond_images"}
            out_text_no_image = self.model(x, timesteps, y, **kwargs_no_image)

            # 3. Apply classifier-free guidance formula for image conditioning
            if y is not None and "scale" in y:
                guidance_scale = y["scale"].view(-1, 1, 1, 1)
                guided_output = out_text_no_image + guidance_scale * (
                    out_text_and_image - out_text_no_image
                )
            else:
                # No scale available -> no guidance; fall back to the
                # fully-conditional pass (text + image).
                print(
                    "Warning: No scale provided for guidance. Using the fully-conditional output."
                )
                guided_output = out_text_and_image

            return guided_output
        else:
            # old text/action conditioning without images
            # Create unconditional version
            y_uncond = deepcopy(y)
            y_uncond["uncond"] = True

            # Run both conditional and unconditional passes
            # Pass through any additional kwargs (including cond_images)
            out = self.model(x, timesteps, y, **kwargs)
            out_uncond = self.model(x, timesteps, y_uncond, **kwargs)

            # Apply classifier-free guidance formula
            return out_uncond + (y["scale"].view(-1, 1, 1, 1) * (out - out_uncond))

    def __getattr__(self, name, default=None):
        # this method is reached only if name is not in self.__dict__.
        # Forward the caller's default (the original hard-coded None here,
        # silently discarding any explicitly supplied default).
        return wrapped_getattr(self, name, default=default)


class AutoRegressiveSampler:
    """Generates long sequences by repeatedly sampling fixed-size chunks.

    Each iteration conditions ``sample_fn`` on the last ``args.context_len``
    frames of the previous prediction, keeps the newly predicted
    ``args.pred_len`` frames, and finally trims the concatenation to
    ``required_frames``.
    """

    def __init__(self, args, sample_fn, required_frames=196):
        self.sample_fn = sample_fn
        self.args = args
        self.required_frames = required_frames

    def sample(self, model, shape, **kargs):
        """Run autoregressive sampling; returns a tensor of shape [..., required_frames]."""
        pred_len = self.args.pred_len
        # Ceil division: number of chunks needed to cover required_frames.
        n_iterations = (self.required_frames + pred_len - 1) // pred_len
        samples_buf = []
        cur_prefix = deepcopy(kargs["model_kwargs"]["y"]["prefix"])  # init with data
        # Text changes on the fly - a prompt per prediction is provided as a
        # list (instead of a single prompt).
        dynamic_text_mode = isinstance(kargs["model_kwargs"]["y"]["text"][0], list)
        if self.args.autoregressive_include_prefix:
            samples_buf.append(cur_prefix)
        autoregressive_shape = list(shape)
        autoregressive_shape[-1] = pred_len

        # Autoregressive sampling
        for i in range(n_iterations):
            # Build per-iteration kwargs; deepcopy so the caller's nested
            # dicts/tensors are never mutated across iterations.
            cur_kargs = deepcopy(kargs)
            cur_kargs["model_kwargs"]["y"]["prefix"] = cur_prefix
            if dynamic_text_mode:
                # Pick the i-th prompt for every sample in the batch.
                cur_kargs["model_kwargs"]["y"]["text"] = [
                    s[i] for s in kargs["model_kwargs"]["y"]["text"]
                ]
                if model.text_encoder_type == "bert":
                    # Precomputed BERT embeddings are indexed per chunk.
                    cur_kargs["model_kwargs"]["y"]["text_embed"] = (
                        cur_kargs["model_kwargs"]["y"]["text_embed"][0][:, :, i],
                        cur_kargs["model_kwargs"]["y"]["text_embed"][1][:, i],
                    )
                else:
                    raise NotImplementedError(
                        "DiP model only supports BERT text encoder at the moment. If you implement this, please send a PR!"
                    )

            # Sample the next prediction
            sample = self.sample_fn(model, autoregressive_shape, **cur_kargs)

            # Buffer only the newly predicted frames
            samples_buf.append(sample.clone()[..., -pred_len:])

            # The tail of the prediction becomes the next conditioning prefix
            cur_prefix = sample.clone()[..., -self.args.context_len :]

        # Concatenate along time and trim any overshoot (e.g. 200 -> 196)
        full_batch = torch.cat(samples_buf, dim=-1)[..., : self.required_frames]
        return full_batch