import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .FAN_feature_extractor import FAN_SA
from einops import rearrange
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid
from diffusers.models.modeling_utils import ModelMixin


def zero_module(module):
    # Zero out the parameters of a module and return it.
    assert isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear), type(module)
    for p in module.parameters():
        p.detach().zero_()
    return module


class MotEncoder(ModelMixin):
    def __init__(self, out_ch=16):
        super().__init__()
        self.model = FAN_SA()
        self.out_drop = None  # nn.Dropout(p=0.4)
        self.out_ch = out_ch
        expr_dim = 512
        # 1D sin-cos positional embedding over the expr_dim // out_ch token positions.
        extra_pos_embed = get_1d_sincos_pos_embed_from_grid(out_ch, np.arange(expr_dim // out_ch))
        self.register_buffer("pe", torch.from_numpy(extra_pos_embed).float().unsqueeze(0))
        self.final_proj = nn.Linear(expr_dim, expr_dim)
        self.out_bn = None

    def change_out_dim(self, out_ch):
        self.out_proj = nn.Linear(self.out_ch, out_ch)

    def set_attn_processor(self, processor):
        self.model.set_attn_processor(processor)

    def forward(self, x):
        # x: (batch, channels, frames, height, width)
        x = x.to(self.dtype)
        # Fold frames into the batch dimension and extract one feature vector per frame.
        latent = self.model(rearrange(x, "b c f h w -> (b f) c h w"))
        latent = self.final_proj(latent)
        # Split each per-frame feature into out_ch-dim tokens and add the positional embedding.
        latent = rearrange(latent, "b (l c) -> b l c", c=self.out_ch) + self.pe
        # Restore the frame dimension: (batch, frames, tokens, out_ch).
        latent = rearrange(latent, "(b f) l c -> b f l c", f=x.shape[2])
        return latent
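

# Illustrative usage sketch. It assumes FAN_SA maps each frame to a 512-dim
# feature vector; the 256x256 resolution, batch size, and frame count below
# are placeholder values, not requirements from this file.
#
#   encoder = MotEncoder(out_ch=16)
#   video = torch.randn(2, 3, 8, 256, 256)  # (batch, channels, frames, H, W)
#   tokens = encoder(video)                 # expected shape: (2, 8, 32, 16)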