| import torch |
| import torch.nn as nn |
| from .wan_video_dit import sinusoidal_embedding_1d |
|
|
|
|
class WanMotionControllerModel(torch.nn.Module):
    """Project a motion bucket id through a sinusoidal embedding and an MLP.

    The MLP is ``Linear(freq_dim, dim) -> SiLU -> Linear(dim, dim) -> SiLU ->
    Linear(dim, dim * 6)``, so the last dimension of the output is ``dim * 6``.

    Args:
        freq_dim: Width passed to ``sinusoidal_embedding_1d`` and the input
            width of the first linear layer.
        dim: Hidden width of the MLP; the output width is ``dim * 6``.
    """

    def __init__(self, freq_dim=256, dim=1536):
        super().__init__()
        self.freq_dim = freq_dim
        self.linear = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim * 6),
        )

    def forward(self, motion_bucket_id):
        """Return the MLP output for ``motion_bucket_id``.

        The id is multiplied by 10 before being embedded.
        NOTE(review): presumably ``motion_bucket_id`` is a 1-D tensor and
        ``sinusoidal_embedding_1d`` returns ``(len, freq_dim)`` -- confirm
        against the helper in ``wan_video_dit``.
        """
        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
        return self.linear(emb)

    def init(self):
        """Zero the weight and bias of the final linear layer in place.

        After this call the module outputs zeros for any finite input to the
        last layer, until those parameters are trained or overwritten.
        """
        # Fill with zeros rather than multiplying by zero: ``NaN * 0`` and
        # ``inf * 0`` are NaN, so scaling would leave non-finite entries
        # (e.g. from uninitialised memory) in place.
        for param in self.linear[-1].parameters():
            nn.init.zeros_(param)

    @staticmethod
    def state_dict_converter():
        """Return a new ``WanMotionControllerModelDictConverter``."""
        return WanMotionControllerModelDictConverter()
|
|
|
|
class WanMotionControllerModelDictConverter:
    """Pass-through state-dict converter for ``WanMotionControllerModel``.

    Both conversion methods hand back the state dict they were given, without
    copying or renaming anything.
    """

    def __init__(self):
        """Create the converter; it holds no state."""

    def from_diffusers(self, state_dict):
        """Return ``state_dict`` unchanged (the same object)."""
        return state_dict

    def from_civitai(self, state_dict):
        """Return ``state_dict`` unchanged (the same object)."""
        return state_dict
|
|