File size: 1,560 Bytes
7428365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .FAN_feature_extractor import FAN_SA
from einops import rearrange
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
from diffusers.models.modeling_utils import ModelMixin

def zero_module(module):
    """Zero out all parameters of *module* in place and return it.

    Commonly used to initialize the final projection of a conditioning
    branch so it starts as a no-op (e.g. ControlNet-style zero init).

    Args:
        module: an ``nn.Conv2d`` or ``nn.Linear`` whose weights/biases
            are zeroed. Other module types are rejected.

    Returns:
        The same module instance, with every parameter set to zero.
    """
    # Tuple form of isinstance replaces the chained `or` checks.
    assert isinstance(module, (nn.Conv2d, nn.Linear)), type(module)
    for p in module.parameters():
        # detach() so the in-place zero_() is not recorded by autograd.
        p.detach().zero_()
    return module

class MotEncoder(ModelMixin):
    """Motion encoder: maps a video clip to per-frame token sequences.

    A FAN-based backbone produces one 512-d feature per frame; that
    vector is linearly projected, reshaped into ``out_ch``-wide tokens,
    and offset with a fixed 1-D sinusoidal positional embedding.
    """

    def __init__(self, out_ch=16):
        super().__init__()
        feature_dim = 512  # output width of the FAN backbone / projection
        self.model = FAN_SA()
        self.out_ch = out_ch
        self.final_proj = nn.Linear(feature_dim, feature_dim)
        # Sinusoidal positions for the feature_dim // out_ch token slots;
        # stored as a non-trainable buffer so it moves with the module.
        grid = np.arange(feature_dim // out_ch)
        sincos = get_1d_sincos_pos_embed_from_grid(out_ch, grid)
        self.register_buffer("pe", torch.from_numpy(sincos).float().unsqueeze(0))
        # Optional post-processing hooks, disabled by default.
        self.out_drop = None #nn.Dropout(p=0.4)
        self.out_bn = None

    def change_out_dim(self, out_ch):
        # Attach a projection from the current token width to a new one.
        self.out_proj = nn.Linear(self.out_ch, out_ch)

    def set_attn_processor(self, processor):
        # Forward the attention-processor override to the backbone.
        self.model.set_attn_processor(processor)

    def forward(self, x):
        """Encode a clip.

        Args:
            x: video tensor laid out as (batch, channels, frames, h, w)
               — presumably face crops; confirm against the caller.

        Returns:
            Tokens of shape (batch, frames, tokens, out_ch).
        """
        num_frames = x.shape[2]
        x = x.to(self.dtype)
        # Fold frames into the batch so the 2-D backbone sees images.
        frames = rearrange(x, "b c f h w -> (b f) c h w")
        feats = self.final_proj(self.model(frames))
        # Split each 512-d vector into tokens and add positions.
        tokens = rearrange(feats, "b (l c) -> b l c", c=self.out_ch)
        tokens = tokens + self.pe
        # Un-fold the frame axis back out of the batch.
        return rearrange(tokens, "(b f) l c -> b f l c", f=num_frames)