File size: 2,893 Bytes
4724018
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from diffusers import DiffusionPipeline


class TrajPipeline(DiffusionPipeline):
    """Diffusion sampling pipeline built from a denoising model and a scheduler.

    Calling the pipeline starts from Gaussian noise of shape
    ``(batch_size, n_frames, init_pc.shape[2], 3)`` and repeatedly applies the
    model followed by ``scheduler.step`` for every scheduler timestep.
    """

    def __init__(self, model, scheduler):
        """Register ``model`` and ``scheduler`` as modules of the pipeline."""
        super().__init__()
        self.register_modules(model=model, scheduler=scheduler)

    @torch.no_grad()
    def __call__(
        self,
        init_pc,
        force,
        E,
        nu,
        mask,
        drag_point,
        floor_height,
        gravity,
        coeff,
        generator,
        device,
        y=None,
        batch_size: int = 1,
        num_inference_steps: int = 50,
        guidance_scale=1.0,
        n_frames=20
    ):
        """Run the full denoising loop and return the final sample.

        Args:
            init_pc: Conditioning tensor; its size along dim 2 is used as the
                third dimension of the sampled noise (presumably the number of
                points -- confirm against the model).
            force, E, nu, drag_point, floor_height: Conditioning tensors that
                are moved to ``device`` and forwarded to the model. Their
                physical meaning is defined by the model, not by this class.
            mask: Conditioning tensor; moved to ``device`` and cast to the
                dtype of the noise sample before being forwarded.
            gravity: Optional tensor forwarded as ``gravity_label``; ``None``
                is passed through unchanged.
            coeff: Tensor forwarded to the model as ``coeff``.
            generator: Optional ``torch.Generator`` used to draw the initial
                noise. The noise is drawn on the generator's own device and
                then moved to ``device``, so CPU and accelerator generators
                are both accepted.
            device: Device on which the model and all tensors are placed.
            y: Optional tensor forwarded to the model as ``y``.
            batch_size: Number of samples to generate.
            num_inference_steps: Number of steps handed to
                ``scheduler.set_timesteps``.
            guidance_scale: Classifier-free guidance is enabled only when this
                is greater than 1.0.
            n_frames: Size of dim 1 of the generated sample.

        Returns:
            The ``prev_sample`` produced by the last ``scheduler.step`` call
            (or the initial noise if the scheduler yields no timesteps).
        """
        # Sample Gaussian noise to begin the loop. torch.randn requires the
        # generator and the output to live on the same device, so draw on the
        # generator's device and move afterwards; without a generator the
        # default device is used exactly as before.
        noise_shape = (batch_size, n_frames, init_pc.shape[2], 3)
        if generator is None:
            sample = torch.randn(noise_shape)
        else:
            sample = torch.randn(noise_shape, generator=generator, device=generator.device)
        sample = sample.to(device)

        self.model.to(device)
        init_pc = init_pc.to(device)
        force = force.to(device)
        E = E.to(device)
        nu = nu.to(device)
        mask = mask.to(device=device, dtype=sample.dtype)
        drag_point = drag_point.to(device)
        floor_height = floor_height.to(device)
        coeff = coeff.to(device)
        gravity = gravity.to(device) if gravity is not None else None
        y = y.to(device) if y is not None else None

        # Set step values.
        self.scheduler.set_timesteps(num_inference_steps, device=device)

        use_cfg = guidance_scale > 1.0
        # Per-sample flag handed to the model as `null_emb`: 1 for the
        # conditional half, 0 for the unconditional half.
        cond_flag = torch.ones(batch_size, dtype=sample.dtype)
        if use_cfg:
            # Duplicate the conditioning along dim 0 so one forward pass
            # evaluates the unconditional and the conditional branch together.
            init_pc, force, E, nu, mask, drag_point, floor_height = (
                torch.cat([tensor] * 2)
                for tensor in (init_pc, force, E, nu, mask, drag_point, floor_height)
            )
            # NOTE(review): coeff, gravity and y are NOT duplicated here, so
            # they only line up with the doubled batch if the model broadcasts
            # them (e.g. batch_size == 1) -- confirm against the model.
            cond_flag = torch.cat([torch.zeros(batch_size, dtype=sample.dtype), cond_flag])
        null_emb = cond_flag[:, None, None].to(device)

        # Batch size actually seen by the model (doubled under guidance).
        model_batch = 2 * batch_size if use_cfg else batch_size
        for t in self.progress_bar(self.scheduler.timesteps):
            timesteps = torch.tensor([t] * model_batch, device=device)
            model_input = torch.cat([sample] * 2) if use_cfg else sample

            # 1. Predict the model output for the current noisy sample.
            model_output = self.model(
                model_input,
                timesteps,
                init_pc,
                force,
                E,
                nu,
                mask,
                drag_point,
                floor_height=floor_height,
                gravity_label=gravity,
                coeff=coeff,
                y=y,
                null_emb=null_emb,
            )
            if use_cfg:
                # First half is unconditional, second half conditional,
                # matching the order of `cond_flag` above.
                pred_uncond, pred_cond = model_output.chunk(2)
                model_output = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # 2. Let the scheduler take the step x_t -> x_{t-1}.
            sample = self.scheduler.step(model_output, timesteps[0], sample).prev_sample
        return sample