Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from .FAN_feature_extractor import FAN_SA | |
| from einops import rearrange | |
| from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid | |
| from diffusers.models.modeling_utils import ModelMixin | |
def zero_module(module):
    """Zero out all parameters of *module* in-place and return it.

    Commonly used to initialize a projection layer so it starts as an
    identity-free (no-op) contribution. Zeroing happens through
    ``p.detach().zero_()`` so the parameters stay registered and
    trainable — only their current values are cleared.

    Args:
        module: an ``nn.Conv2d`` or ``nn.Linear`` instance.

    Returns:
        The same ``module``, with every parameter set to zero.

    Raises:
        TypeError: if ``module`` is not an ``nn.Conv2d`` or ``nn.Linear``.
            (Was an ``assert`` before — asserts are stripped under
            ``python -O``, silently disabling the guard.)
    """
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        raise TypeError(
            f"zero_module expects nn.Conv2d or nn.Linear, got {type(module)}"
        )
    for p in module.parameters():
        p.detach().zero_()
    return module
class MotEncoder(ModelMixin):
    """Encode a video clip into per-frame sequences of motion tokens.

    A ``FAN_SA`` backbone produces one feature vector per frame; the
    vector is linearly projected, split into ``out_ch``-dimensional
    tokens, and offset by a fixed 1-D sinusoidal positional embedding.
    """

    def __init__(self, out_ch: int = 16):
        super().__init__()
        # Per-frame feature extractor (project-local backbone). Assumed to
        # emit a 512-dim vector per image — required by final_proj below
        # and by the "(l c)" split in forward(). TODO confirm.
        self.model = FAN_SA()
        self.out_drop = None #nn.Dropout(p=0.4)
        self.out_ch = out_ch
        # Width of the backbone feature / final projection.
        expr_dim = 512
        # 1-D sinusoidal embedding over the expr_dim // out_ch token
        # positions; shape (expr_dim // out_ch, out_ch) before the leading
        # broadcast dim is added. Note: expr_dim must be divisible by
        # out_ch for the rearrange in forward() to succeed.
        extra_pos_embed = get_1d_sincos_pos_embed_from_grid(out_ch, np.arange(expr_dim//out_ch))
        # Buffer (not a Parameter): moves with .to(device/dtype) and is
        # stored in the state dict, but is never trained.
        self.register_buffer("pe", torch.from_numpy(extra_pos_embed).float().unsqueeze(0))
        self.final_proj = nn.Linear(expr_dim, expr_dim)
        self.out_bn = None

    def change_out_dim(self, out_ch: int):
        # NOTE(review): this creates self.out_proj, but forward() below
        # never applies it — confirm whether out_proj is used externally
        # by callers, otherwise this method has no effect on the output.
        self.out_proj = nn.Linear(self.out_ch, out_ch)

    def set_attn_processor(self, processor):
        # Forward the attention-processor override to the backbone.
        self.model.set_attn_processor(processor)

    def forward(self, x):
        """Return motion tokens for the video tensor ``x``.

        Args:
            x: video tensor laid out as (batch, channels, frames, H, W);
               the layout is fixed by the rearrange pattern below.

        Returns:
            Tensor of shape (batch, frames, 512 // out_ch, out_ch):
            per-frame token sequences with positional embedding added.
        """
        x = x.to(self.dtype)
        # Fold frames into the batch dim so the backbone sees plain images.
        latent = self.model(rearrange(x, "b c f h w -> (b f) c h w"))
        latent = self.final_proj(latent)
        # Split each 512-dim feature into out_ch-wide tokens, then add the
        # sinusoidal position embedding (broadcast over the batch dim).
        latent = rearrange(latent, "b (l c) -> b l c", c=self.out_ch) + self.pe
        # Un-fold frames back out of the batch dimension.
        latent = rearrange(latent, "(b f) l c -> b f l c", f=x.shape[2])
        return latent