Spaces: Running on Zero
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import torch.nn.functional as F | |
| # from util import util | |
| from lib.models.networks.audio_network import ResNetSE, SEBasicBlock | |
| from lib.models.networks.FAN_feature_extractor import FAN_use | |
| from lib.models.networks.vision_network import ResNeXt50 | |
| from torchvision.models.vgg import vgg19_bn | |
| from lib.models.networks.swin_transformer import SwinTransformer | |
| from lib.models.networks.wavlm.wavlm import WavLM, WavLMConfig | |
class ResSEAudioEncoder(nn.Module):
    """Audio encoder built on a squeeze-and-excitation ResNet.

    Maps a spectrogram batch to an ``nOut``-dim feature; with
    ``_type == "to_mouth_embed"`` it additionally projects the feature
    into the mouth subspace.
    """

    def __init__(self, opt, nOut=2048, n_mel_T=None):
        super(ResSEAudioEncoder, self).__init__()
        self.nOut = nOut
        self.opt = opt
        pose_dim = self.opt.model.net_nonidentity.pose_dim
        eye_dim = self.opt.model.net_nonidentity.eye_dim
        # Per-stage channel widths of the SE-ResNet trunk.
        num_filters = [32, 64, 128, 256]
        # n_mel_T is read from the config only for the audio-identity use
        # case; subclasses (e.g. the sync encoder) pass it explicitly.
        if n_mel_T is None:
            n_mel_T = opt.model.net_audio.n_mel_T
        self.model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, self.nOut, n_mel_T=n_mel_T)
        # Audio-only models carve out only the pose dims from the 512-d
        # embedding; otherwise both pose and eye dims are reserved.
        embed_out = 512 - pose_dim if opt.audio_only else 512 - pose_dim - eye_dim
        self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, embed_out))

    def forward(self, x, _type=None):
        """Encode ``x``; return (feature, mouth_embed) when ``_type == "to_mouth_embed"``."""
        shape = x.size()
        # Fold (batch, clip, C, F, T) clips into a single batch axis.
        if len(shape) == 5:
            bz, clip_len, c, f, t = shape
            x = x.view(bz * clip_len, c, f, t)
        out = self.model(x)
        if _type != "to_mouth_embed":
            return out
        flat = out.view(-1, out.shape[-1])
        return flat, self.mouth_embed(flat)
class ResSESyncEncoder(ResSEAudioEncoder):
    """ResSEAudioEncoder preset for sync use: 512-d output, single mel window."""

    def __init__(self, opt):
        super(ResSESyncEncoder, self).__init__(opt, nOut=512, n_mel_T=1)
class ResNeXtEncoder(ResNeXt50):
    """Thin alias over the ResNeXt-50 vision backbone (no added behaviour)."""

    def __init__(self, opt):
        super(ResNeXtEncoder, self).__init__(opt)
class VGGEncoder(nn.Module):
    """VGG-19 (batch-norm) classifier wrapper; class count comes from the config."""

    def __init__(self, opt):
        super(VGGEncoder, self).__init__()
        self.model = vgg19_bn(num_classes=opt.data.num_classes)

    def forward(self, x):
        """Run the wrapped VGG on ``x`` and return its logits."""
        return self.model(x)
class FanEncoder(nn.Module):
    """FAN-based face encoder with head-pose, eye and emotion projection heads."""

    def __init__(self, opt):
        super(FanEncoder, self).__init__()
        self.opt = opt
        pose_dim = self.opt.model.net_nonidentity.pose_dim
        eye_dim = self.opt.model.net_nonidentity.eye_dim
        self.model = FAN_use()

        def mapper():
            # 512 -> 512 subspace mapper shared shape across heads.
            return nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.BatchNorm1d(512), nn.Linear(512, 512))

        def embed(dim):
            # Final projection of a mapped feature into a small embedding.
            return nn.Sequential(nn.ReLU(), nn.Linear(512, dim))

        # Head-pose, eye and emotion heads (attribute names are part of the
        # checkpoint/state-dict interface — do not rename).
        self.to_headpose = mapper()
        self.headpose_embed = embed(pose_dim)
        self.to_eye = mapper()
        self.eye_embed = embed(eye_dim)
        self.to_emo = mapper()
        self.emo_embed = embed(30)

    def forward_feature(self, x):
        """Return the raw FAN backbone feature for ``x``."""
        return self.model(x)

    def forward(self, x, _type="feature"):
        """Dispatch on ``_type``: "feature", "feature_embed" or "to_headpose".

        An unrecognised ``_type`` falls through and returns None (original
        behaviour, preserved).
        """
        if _type == "feature":
            return self.forward_feature(x)
        if _type == "feature_embed":
            feat = self.model(x)
            headpose_emb = self.headpose_embed(self.to_headpose(feat))
            eye_embed = self.eye_embed(self.to_eye(feat))
            emo_embed = self.emo_embed(self.to_emo(feat))
            return headpose_emb, eye_embed, emo_embed
        if _type == "to_headpose":
            feat = self.model(x)
            return self.headpose_embed(self.to_headpose(feat))
| # class WavlmEncoder(nn.Module): | |
| # def __init__(self, opt): | |
| # super(WavlmEncoder, self).__init__() | |
| # wavlm_checkpoint = torch.load(opt.model.net_audio.official_pretrain) | |
| # wavlm_cfg = WavLMConfig(wavlm_checkpoint['cfg']) | |
| # # pose_dim = opt.model.net_nonidentity.pose_dim | |
| # self.model = WavLM(wavlm_cfg) | |
| # # self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(768, 512-pose_dim)) | |
| # def forward(self, x): | |
| # feature = self.model.extract_features(x)[0] | |
| # # audio_feat = self.mouth_embed(feature.mean(1)) | |
| # return feature.mean(1) | |
class WavlmEncoder(nn.Module):
    """WavLM-based audio encoder yielding a mouth feature and its embedding."""

    def __init__(self, opt):
        super(WavlmEncoder, self).__init__()
        # Centre a 5-window (x2 WavLM frames) span inside the input clip;
        # [s, e) selects that span on the frame axis.
        self.input_wins = opt.audio.num_frames_per_clip
        self.s = (self.input_wins - 5) // 2 * 2
        self.e = self.s + 5 * 2 - 1
        ckpt = torch.load(opt.model.net_audio.official_pretrain)
        self.model = WavLM(WavLMConfig(ckpt['cfg']))
        pose_dim = opt.model.net_nonidentity.pose_dim
        self.mouth_feat = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 512))
        self.mouth_embed = nn.Sequential(nn.ReLU(), nn.Linear(512, 512 - pose_dim))

    def forward(self, x):
        """Return ``(feature, audio_feat)`` for a waveform batch ``x``."""
        frames = self.model.extract_features(x)[0]
        # Mean-pool the centred temporal window, then project.
        feature = self.mouth_feat(frames[:, self.s:self.e].mean(1))
        return feature, self.mouth_embed(feature)
class SwinEncoder(nn.Module):
    """Swin-Transformer backbone returning a pooled feature (no classification head)."""

    def __init__(self, cfg):
        super(SwinEncoder, self).__init__()
        net_cfg = cfg.model.net_nonidentity
        self.encoder = SwinTransformer(
            num_classes=0,  # headless: feature extraction only
            img_size=net_cfg.img_size,
            patch_size=net_cfg.patch_size,
            in_chans=net_cfg.in_chans,
            embed_dim=net_cfg.embed_dim,
            depths=net_cfg.depths,
            num_heads=net_cfg.num_heads,
            window_size=net_cfg.window_size,
            mlp_ratio=net_cfg.mlp_ratio,
            qkv_bias=net_cfg.qkv_bias,
            # 0.1 when the config flag is truthy, else defer to the model default.
            qk_scale=0.1 if net_cfg.qk_scale else None,
            drop_rate=net_cfg.drop_rate,
            drop_path_rate=net_cfg.drop_path_rate,
            ape=net_cfg.ape,
            patch_norm=net_cfg.patch_norm,
            use_checkpoint=False,
        )

    def forward(self, img):
        """Encode ``img`` with the Swin backbone."""
        return self.encoder(img)
class ResEncoder(nn.Module):
    """ResNet-50 image encoder producing a 512-d feature."""

    def __init__(self, opt):
        super(ResEncoder, self).__init__()
        self.opt = opt
        # NOTE(review): `resnet50` is not imported anywhere in this file, so
        # constructing this class raises NameError. It is presumably a
        # project-local builder (torchvision's resnet50 has no `include_top`
        # keyword) — confirm and add the intended import.
        self.model = resnet50(num_classes=512, include_top=True)

    def forward(self, x):
        """Encode ``x`` with the wrapped ResNet."""
        return self.model(x)
class FansEncoder(nn.Module):
    """FAN backbone configured for a 768-d output feature."""

    def __init__(self, cfg):
        super(FansEncoder, self).__init__()
        self.encoder = FAN_use(out_dim=768)

    def forward(self, img):
        """Encode ``img`` with the FAN backbone."""
        return self.encoder(img)
# NOTE(review): exact duplicate of the SwinEncoder defined earlier in this
# file; at import time this later definition rebinds the name, making the
# earlier one dead code. Consider deleting one copy.
class SwinEncoder(nn.Module):
    """Swin-Transformer backbone returning a pooled feature (no classification head)."""
    def __init__(self, cfg):
        super(SwinEncoder, self).__init__()
        self.encoder = SwinTransformer(
            num_classes = 0,  # headless: feature extraction only
            img_size = cfg.model.net_nonidentity.img_size,
            patch_size = cfg.model.net_nonidentity.patch_size,
            in_chans = cfg.model.net_nonidentity.in_chans,
            embed_dim = cfg.model.net_nonidentity.embed_dim,
            depths = cfg.model.net_nonidentity.depths,
            num_heads = cfg.model.net_nonidentity.num_heads,
            window_size = cfg.model.net_nonidentity.window_size,
            mlp_ratio = cfg.model.net_nonidentity.mlp_ratio,
            qkv_bias = cfg.model.net_nonidentity.qkv_bias,
            # 0.1 when the config flag is truthy, else defer to the model default.
            qk_scale = None if not cfg.model.net_nonidentity.qk_scale else 0.1,
            drop_rate = cfg.model.net_nonidentity.drop_rate,
            drop_path_rate = cfg.model.net_nonidentity.drop_path_rate,
            ape = cfg.model.net_nonidentity.ape,
            patch_norm = cfg.model.net_nonidentity.patch_norm,
            use_checkpoint = False
        )
    def forward(self, img):
        """Encode ``img`` with the Swin backbone."""
        feature = self.encoder(img)
        return feature
# NOTE(review): exact duplicate of the ResEncoder defined earlier in this
# file; this later definition is the one in effect. Also, `resnet50` is not
# imported anywhere here (torchvision's resnet50 has no `include_top`
# keyword), so constructing this class raises NameError — confirm the
# intended import.
class ResEncoder(nn.Module):
    """ResNet-50 image encoder producing a 512-d feature."""
    def __init__(self, opt):
        super(ResEncoder, self).__init__()
        self.opt = opt
        self.model = resnet50(num_classes=512, include_top=True)
    def forward(self, x):
        """Encode ``x`` with the wrapped ResNet."""
        feature = self.model(x)
        return feature
# NOTE(review): exact duplicate of the FansEncoder defined earlier in this
# file; this later definition rebinds the name. Consider deleting one copy.
class FansEncoder(nn.Module):
    """FAN backbone configured for a 768-d output feature."""
    def __init__(self, cfg):
        super(FansEncoder, self).__init__()
        self.encoder = FAN_use(out_dim=768)
    def forward(self, img):
        """Encode ``img`` with the FAN backbone."""
        feature = self.encoder(img)
        return feature