import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
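# Allow `model.dit` to resolve when this file is run from the repository root.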
sys.path.append('./')

from model.dit import CogVideoXTransformer3DModel

class PointEmbed(nn.Module):
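    """Fourier-feature point embedding.

    Projects xyz coordinates onto a fixed sinusoidal basis, concatenates the
    raw coordinates, and maps the result to `dim` with a linear layer.
    """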
    def __init__(self, hidden_dim=96, dim=512):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                        torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                        torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                        torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x (hidden_dim // 2), i.e. 3 x 48 for the default hidden_dim=96

        self.mlp = nn.Linear(self.embedding_dim+3, dim)

    @staticmethod
    def embed(input, basis):
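        # input: B x N x 3, basis: 3 x (hidden_dim // 2)
        # -> B x N x hidden_dim after concatenating sin and cos projections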
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings
    
    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
        return embed

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)  # fixed table, not a trainable parameter

    def forward(self, x):
        # not used in the final model
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

class MDM_DiT(nn.Module):
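    """Point-cloud motion diffusion model on a CogVideoX-style DiT backbone.

    Conditioning signals (force, material parameters E and nu, drag point) enter
    as extra sequence tokens; the initial point cloud can additionally be
    prepended as a conditioning frame (`frame_cond`).
    """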
    
    def __init__(self, n_points, n_frame, n_feats, model_config):
        super().__init__()

        self.n_points = n_points
        self.n_feats = n_feats
        self.latent_dim = model_config.latent_dim
        self.cond_seq_length = 4  # one token each for force, E, nu, drag point
        self.cond_frame = 1 if model_config.frame_cond else 0  # extra frame slot for the initial point cloud

        # Backbone DiT; latent_dim // 64 yields 64-dim attention heads.
        self.dit = CogVideoXTransformer3DModel(sample_points=n_points, sample_frames=n_frame+self.cond_frame, in_channels=n_feats,
            num_layers=model_config.n_layers, num_attention_heads=self.latent_dim // 64, cond_seq_length=self.cond_seq_length)
        
        self.input_encoder = PointEmbed(dim=self.latent_dim)
        # self.init_cond_encoder = PointEmbed(dim=self.latent_dim)
        # Each conditioning signal gets its own linear projection to the latent width.
        self.E_cond_encoder = nn.Linear(1, self.latent_dim)      # Young's modulus
        self.nu_cond_encoder = nn.Linear(1, self.latent_dim)     # Poisson's ratio
        self.force_cond_encoder = nn.Linear(3, self.latent_dim)  # external force vector
        self.drag_point_encoder = nn.Linear(3, self.latent_dim)  # drag-point location
    
    def enable_gradient_checkpointing(self):
        self.dit._set_gradient_checkpointing(True)

    def forward(self, x, timesteps, init_pc, force, E, nu, drag_mask, drag_point, floor_height=None, coeff=None, y=None, null_emb=0):
        """
        x: [batch_size, n_frame, n_points, n_feats], the noisy sample, denoted x_t in the paper
        timesteps: [batch_size] (int) diffusion timesteps
        init_pc: [batch_size, 1, n_points, n_feats], initial point cloud (reshaped internally)
        force: [batch_size, 3], external force vector
        E: [batch_size, 1], Young's modulus
        nu: [batch_size, 1], Poisson's ratio
        drag_point: [batch_size, 3], dragged-point location
        drag_mask, floor_height, coeff, y, null_emb: accepted but unused in this forward pass
        """

        bs, n_frame, n_points, n_feats = x.shape

        init_pc = init_pc.reshape(bs, n_points, n_feats)
        force = force.unsqueeze(1)
        E = E.unsqueeze(1)
        nu = nu.unsqueeze(1)
        drag_point = drag_point.unsqueeze(1)
        # Prepend the initial point cloud as an extra conditioning frame.
        x = torch.cat([init_pc.unsqueeze(1), x], dim=1)
        n_frame += 1
        # One conditioning token each for force, E, nu, and the drag point.
        encoder_hidden_states = torch.cat([self.force_cond_encoder(force), self.E_cond_encoder(E),
                self.nu_cond_encoder(nu), self.drag_point_encoder(drag_point)], dim=1)
        hidden_states = self.input_encoder(x.reshape(bs * n_frame, n_points,
            n_feats)).reshape(bs, n_frame, n_points, self.latent_dim)
        full_seq = torch.cat([encoder_hidden_states, hidden_states.reshape(bs, n_frame * n_points, self.latent_dim)], dim=1)
        # Drop the conditioning frame from the output, then predict positions as
        # displacements relative to the initial point cloud.
        output = self.dit(full_seq, timesteps).reshape(bs, n_frame, n_points, 3)[:, self.cond_frame:]
        output = output + init_pc.unsqueeze(1)

        return output

if __name__ == "__main__":

    # Smoke test. The config below is a stand-in: `latent_dim`, `n_layers`, and
    # `frame_cond` are assumed values for illustration; substitute the real
    # model_config used in training.
    from types import SimpleNamespace

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # fp16 linear layers are only reliably supported on CUDA; fall back to fp32 on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    point_num = 512
    frame_num = 6

    x = torch.randn(2, frame_num, point_num, 3, device=device, dtype=dtype)
    timesteps = torch.tensor([999, 999], device=device)  # integer diffusion timesteps
    init_pc = torch.randn(2, 1, point_num, 3, device=device, dtype=dtype)
    force = torch.randn(2, 3, device=device, dtype=dtype)
    E = torch.randn(2, 1, device=device, dtype=dtype)
    nu = torch.randn(2, 1, device=device, dtype=dtype)
    drag_mask = torch.zeros(2, point_num, device=device, dtype=dtype)  # unused by forward; shape assumed
    drag_point = torch.randn(2, 3, device=device, dtype=dtype)

    model_config = SimpleNamespace(latent_dim=512, n_layers=2, frame_cond=True)
    model = MDM_DiT(point_num, frame_num, 3, model_config).to(device).to(dtype)
    output = model(x, timesteps, init_pc, force, E, nu, drag_mask, drag_point)
    print(output.shape)  # expected: torch.Size([2, 6, 512, 3])