cbsjtu01 committed
Commit 8471f73 · 1 Parent(s): 5c960f0

update models

.gitattributes CHANGED
File without changes
README.md CHANGED
File without changes
app.py CHANGED
@@ -82,7 +82,7 @@ class AppConfig:
         self.renderer_path = "./checkpoints/renderer.ckpt"
         self.generator_path = "./checkpoints/generator.ckpt"
         self.wav2vec_model_path = "./checkpoints/wav2vec2-base-960h"
-        self.input_size = 256
+        self.input_size = 512
         self.input_nc = 3
         self.fps = 25.0
         self.rank = "cuda"
@@ -136,7 +136,7 @@ class DataProcessor:
         else:
             print("Local wav2vec model not found, downloading from 'facebook/wav2vec2-base-960h'...")
             self.wav2vec_preprocessor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
-        self.transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
+        self.transform = transforms.Compose([transforms.Resize((512, 512)), transforms.ToTensor()])
 
     def process_img(self, img: Image.Image) -> Image.Image:
         img_arr = np.array(img)
@@ -518,10 +518,10 @@ with gr.Blocks(title="IMTalker Demo") as demo:
         )
 
         with gr.Accordion("Settings", open=True):
-            a_crop = gr.Checkbox(label="Auto Crop Face", value=True)
+            a_crop = gr.Checkbox(label="Auto Crop Face", value=False)
             a_seed = gr.Number(label="Seed", value=42)
             a_nfe = gr.Slider(5, 50, value=10, step=1, label="Steps (NFE)")
-            a_cfg = gr.Slider(1.0, 5.0, value=3.0, label="CFG Scale")
+            a_cfg = gr.Slider(1.0, 5.0, value=2.0, label="CFG Scale")
 
             a_btn = gr.Button("Generate (Audio Driven)", variant="primary")
 
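Note: the lowered CFG Scale default feeds the audio classifier-free guidance implemented in generator/FMT.py below. A minimal sketch of the blend that this scale controls (the helper name cfg_blend is hypothetical; the actual computation lives in forward_with_cfg):

    import torch

    def cfg_blend(v_uncond: torch.Tensor, v_cond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # cfg_scale = 1.0 reduces to the audio-conditioned prediction;
        # larger values weight the audio-conditioned prediction more heavily
        # relative to the prediction made with the audio zeroed out.
        return v_uncond + cfg_scale * (v_cond - v_uncond)
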
generator/FMT.py CHANGED
@@ -1,402 +1,383 @@
1
- import os, math, torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
-
6
- from timm.layers import use_fused_attn
7
- from timm.models.vision_transformer import Mlp
8
-
9
-
10
- def enc_dec_mask(T, S, frame_width = 1, expansion = 2):
11
- mask = torch.ones(T, S)
12
- for i in range(T):
13
- mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
14
- return mask == 1
15
-
16
-
17
- def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
18
- """
19
- Sinusoidal position encoding table.
20
- Args:
21
- n_position (int): the length of the input sequence
22
- d_hid (int): the dimension of the hidden state
23
- """
24
- def cal_angle(position, hid_idx):
25
- return position / (10000 ** (2 * (hid_idx // 2) / d_hid))
26
-
27
- def get_posi_angle_vec(position):
28
- return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
29
-
30
- sinusoid_table = torch.Tensor([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
31
- sinusoid_table[:, 0::2] = torch.sin(sinusoid_table[:, 0::2]) # dim 2i
32
- sinusoid_table[:, 1::2] = torch.cos(sinusoid_table[:, 1::2]) # dim 2i+1
33
- if padding_idx is not None: sinusoid_table[padding_idx] = 0.
34
- return sinusoid_table
35
-
36
-
37
- class Attention(nn.Module):
38
- def __init__(
39
- self,
40
- dim: int,
41
- num_heads: int = 8,
42
- qkv_bias: bool = False,
43
- qk_norm: bool = False,
44
- attn_drop: float = 0.,
45
- proj_drop: float = 0.,
46
- norm_layer: nn.Module = nn.LayerNorm,
47
- ) -> None:
48
-
49
- super().__init__()
50
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
51
- self.num_heads = num_heads
52
- self.head_dim = dim // num_heads
53
- self.scale = self.head_dim ** -0.5
54
- self.fused_attn = use_fused_attn()
55
-
56
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
57
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
58
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
59
- self.attn_drop = nn.Dropout(attn_drop)
60
- self.proj = nn.Linear(dim, dim)
61
- self.proj_drop = nn.Dropout(proj_drop)
62
-
63
- def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
64
- B, N, C = x.shape
65
-
66
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
67
- q, k, v = qkv.unbind(0)
68
- q, k = self.q_norm(q), self.k_norm(k)
69
-
70
- if self.fused_attn:
71
- x = F.scaled_dot_product_attention(
72
- q, k, v,
73
- attn_mask = ~mask,
74
- dropout_p=self.attn_drop.p if self.training else 0.,
75
- )
76
- else:
77
- q = q * self.scale
78
- attn = q @ k.transpose(-2, -1)
79
- attn = attn.softmax(dim=-1)
80
- attn = self.attn_drop(attn)
81
- x = attn @ v
82
-
83
- x = x.transpose(1, 2).reshape(B, N, C)
84
- x = self.proj(x)
85
- x = self.proj_drop(x)
86
- return x
87
-
88
- class TimestepEmbedder(nn.Module):
89
- """
90
- Embeds scalar timesteps into vector representations.
91
- """
92
- def __init__(self, hidden_size, frequency_embedding_size = 256):
93
- super().__init__()
94
- self.mlp = nn.Sequential(
95
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
96
- nn.SiLU(),
97
- nn.Linear(hidden_size, hidden_size, bias=True),
98
- )
99
- self.frequency_embedding_size = frequency_embedding_size
100
-
101
- @staticmethod
102
- def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
103
- """
104
- Create sinusoidal timestep embeddings.
105
- :param t: a 1-D Tensor of N indices, one per batch element.
106
- These may be fractional.
107
- :param dim: the dimension of the output.
108
- :param max_period: controls the minimum frequency of the embeddings.
109
- :return: an (N, D) Tensor of positional embeddings.
110
- """
111
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
112
- half = dim // 2
113
- freqs = torch.exp(
114
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
115
- ).to(device=t.device)
116
- args = t[:, None].float() * freqs[None]
117
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
118
- if dim % 2:
119
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
120
- return embedding
121
-
122
- def forward(self, t: torch.Tensor) -> torch.Tensor:
123
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
124
- t_emb = self.mlp(t_freq)
125
- return t_emb
126
-
127
- class SequenceEmbed(nn.Module):
128
- def __init__(
129
- self,
130
- dim_w,
131
- dim_h,
132
- norm_layer=None,
133
- bias=True,
134
- ):
135
- super().__init__()
136
-
137
- self.proj = nn.Linear(dim_w, dim_h, bias=bias)
138
- self.norm = norm_layer(dim_h) if norm_layer else nn.Identity()
139
-
140
- def forward(self, x: torch.Tensor) -> torch.Tensor:
141
- return self.norm(self.proj(x))
142
-
143
-
144
- class FMTBlock(nn.Module):
145
- """
146
- An FMT block inspired by the DiT block
147
- """
148
- def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs) -> None:
149
- super().__init__()
150
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
151
- self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
152
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
154
- approx_gelu = lambda: nn.GELU(approximate="tanh")
155
- self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
156
- self.adaLN_modulation = nn.Sequential(
157
- nn.SiLU(),
158
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
159
- )
160
-
161
- def framewise_modulate(self, x, shift, scale) -> torch.Tensor:
162
- return x * (1 + scale) + shift
163
-
164
- def forward(self, x, c, mask=None) -> torch.Tensor:
165
- assert mask is not None
166
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
167
- x = x + gate_msa * self.attn(self.framewise_modulate(self.norm1(x), shift_msa, scale_msa), mask = mask)
168
- x = x + gate_mlp * self.mlp(self.framewise_modulate(self.norm2(x), shift_mlp, scale_mlp))
169
- return x
170
-
171
- class Decoder(nn.Module):
172
- """
173
- The final decoder of FlowMatchingTransformer.
174
- """
175
- def __init__(self, hidden_size, dim_w):
176
- super().__init__()
177
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
178
- self.adaLN_modulation = nn.Sequential(
179
- nn.SiLU(),
180
- nn.Linear(hidden_size, 2 * hidden_size, bias=True)
181
- )
182
- self.linear = nn.Linear(hidden_size, dim_w, bias=True)
183
-
184
- def framewise_modulate(self, x, shift, scale) -> torch.Tensor:
185
- return x * (1 + scale) + shift
186
-
187
- def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
188
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
189
- x = self.framewise_modulate(self.norm_final(x), shift, scale)
190
- return self.linear(x)
191
-
192
-
193
- class FlowMatchingTransformer(nn.Module):
194
- """
195
- Flow Matching Transformer (FMT)
196
- """
197
- def __init__(self, opt) -> None:
198
- super().__init__()
199
- self.opt = opt
200
-
201
- self.num_frames_for_clip = int(self.opt.wav2vec_sec * self.opt.fps)
202
- self.num_prev_frames = int(opt.num_prev_frames)
203
- self.num_total_frames = self.num_prev_frames + self.num_frames_for_clip
204
-
205
- self.hidden_size = opt.dim_h
206
- self.mlp_ratio = opt.mlp_ratio
207
- self.fmt_depth = opt.fmt_depth
208
- self.num_heads = opt.num_heads
209
-
210
- # input sequence embedding
211
- self.x_embedder = SequenceEmbed(2 * opt.dim_motion, self.hidden_size)
212
-
213
- # video time position encoding
214
- self.pos_embed = nn.Parameter(
215
- torch.zeros(1, self.num_total_frames, self.hidden_size),
216
- requires_grad=False
217
- )
218
-
219
- # flow trajectory time encoding
220
- self.t_embedder = TimestepEmbedder(self.hidden_size)
221
- self.c_embedder = nn.Linear(opt.dim_c, self.hidden_size)
222
-
223
- # define FMT blocks
224
- self.blocks = nn.ModuleList([
225
- FMTBlock(self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio)
226
- for _ in range(self.fmt_depth)
227
- ])
228
- self.decoder = Decoder(self.hidden_size, self.opt.dim_motion)
229
- self.initialize_weights()
230
-
231
- # define alignment mask
232
- alignment_mask = enc_dec_mask(
233
- self.num_total_frames, self.num_total_frames, 1,
234
- expansion=opt.attention_window
235
- )
236
- self.register_buffer("alignment_mask", alignment_mask)
237
-
238
- def initialize_weights(self) -> None:
239
- def _basic_init(module):
240
- if isinstance(module, nn.Linear):
241
- torch.nn.init.xavier_uniform_(module.weight)
242
- if module.bias is not None:
243
- nn.init.constant_(module.bias, 0)
244
-
245
- self.apply(_basic_init)
246
-
247
- pos_embed = get_sinusoid_encoding_table(
248
- self.num_total_frames, self.hidden_size
249
- )
250
- self.pos_embed.data.copy_(pos_embed.unsqueeze(0))
251
-
252
- w = self.x_embedder.proj.weight.data
253
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
254
- nn.init.constant_(self.x_embedder.proj.bias, 0)
255
-
256
- # Initialize timestep embedding MLP
257
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
258
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
259
-
260
- # Zero-out adaLN modulation layers in FMT blocks
261
- for block in self.blocks:
262
- nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
263
- nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
264
-
265
- # Zero-out output layers
266
- nn.init.constant_(self.decoder.adaLN_modulation[-1].weight, 0)
267
- nn.init.constant_(self.decoder.adaLN_modulation[-1].bias, 0)
268
- nn.init.constant_(self.decoder.linear.weight, 0)
269
- nn.init.constant_(self.decoder.linear.bias, 0)
270
-
271
- def sequence_embedder(
272
- self, sequence: torch.Tensor,
273
- dropout_prob: float,
274
- train: bool = False
275
- ) -> torch.Tensor:
276
- if train:
277
- batch_id_for_drop = torch.where(
278
- torch.rand(sequence.shape[0], device=sequence.device) < dropout_prob
279
- )
280
- sequence[batch_id_for_drop] = 0
281
- return sequence
282
-
283
- def forward(
284
- self,
285
- t,
286
- x,
287
- a,
288
- prev_x,
289
- prev_a,
290
- ref_x,
291
- gaze,
292
- prev_gaze,
293
- pose,
294
- prev_pose,
295
- cam,
296
- prev_cam,
297
- train: bool = True,
298
- **kwargs
299
- ) -> torch.Tensor:
300
- t = self.t_embedder(t).unsqueeze(1) # (N, D)
301
- a = self.sequence_embedder(a, dropout_prob=self.opt.audio_dropout_prob, train=train)
302
- pose = self.sequence_embedder(pose, dropout_prob=self.opt.audio_dropout_prob, train=train)
303
- cam = self.sequence_embedder(cam, dropout_prob=self.opt.audio_dropout_prob, train=train)
304
- gaze = self.sequence_embedder(gaze, dropout_prob=self.opt.audio_dropout_prob, train=train)
305
-
306
- if prev_x is not None:
307
- prev_x = self.sequence_embedder(prev_x, dropout_prob=0.5, train=train)
308
- prev_a = self.sequence_embedder(prev_a, dropout_prob=0.5, train=train)
309
- prev_pose = self.sequence_embedder(prev_pose, dropout_prob=0.5, train=train)
310
- prev_cam = self.sequence_embedder(prev_cam, dropout_prob=0.5, train=train)
311
- prev_gaze = self.sequence_embedder(prev_gaze, dropout_prob=0.5, train=train)
312
-
313
- x = torch.cat([prev_x, x], dim=1)
314
- a = torch.cat([prev_a, a], dim=1)
315
- pose = torch.cat([prev_pose, pose], dim=1)
316
- cam = torch.cat([prev_cam, cam], dim=1)
317
- gaze = torch.cat([prev_gaze, gaze], dim=1)
318
-
319
- ref_x = ref_x[:, None, ...].repeat(1, x.shape[1], 1)
320
- x = torch.cat([ref_x, x], dim=-1)
321
- x = self.x_embedder(x)
322
- x = x + self.pos_embed # (N, L + L', D)
323
-
324
- c = self.c_embedder(a + pose + cam + gaze)
325
- c = t + c
326
-
327
- # forwarding FMT Blocks
328
- for block in self.blocks:
329
- x = block(x, c, self.alignment_mask) # (N, T, D)
330
-
331
- return self.decoder(x, c)
332
-
333
- @torch.no_grad()
334
- def forward_with_cfg(
335
- self,
336
- t,
337
- x,
338
- a,
339
- prev_x,
340
- prev_a,
341
- ref_x,
342
- gaze,
343
- prev_gaze,
344
- pose,
345
- prev_pose,
346
- cam,
347
- prev_cam,
348
- a_cfg_scale: float = 1.0,
349
- **kwargs
350
- ) -> torch.Tensor:
351
- """
352
- Forward pass with Classifier-Free Guidance (CFG).
353
- """
354
- if a_cfg_scale != 1.0:
355
- null_a = torch.zeros_like(a)
356
- audio_cat = torch.cat([null_a, a], dim=0)
357
- gaze_cat = torch.cat([gaze, gaze], dim=0)
358
- pose_cat = torch.cat([pose, pose], dim=0)
359
- cam_cat = torch.cat([cam, cam], dim=0)
360
-
361
- x_cat = torch.cat([x, x], dim=0)
362
- prev_x_cat = torch.cat([prev_x, prev_x], dim=0)
363
- prev_a_cat = torch.cat([prev_a, prev_a], dim=0)
364
- prev_gaze_cat = torch.cat([prev_gaze, prev_gaze], dim=0)
365
- prev_pose_cat = torch.cat([prev_pose, prev_pose], dim=0)
366
- prev_cam_cat = torch.cat([prev_cam, prev_cam], dim=0)
367
- ref_x_cat = torch.cat([ref_x, ref_x], dim=0)
368
-
369
- model_output = self.forward(
370
- t=t,
371
- x=x_cat,
372
- a=audio_cat,
373
- prev_x=prev_x_cat,
374
- prev_a=prev_a_cat,
375
- ref_x=ref_x_cat,
376
- gaze=gaze_cat,
377
- prev_gaze=prev_gaze_cat,
378
- pose=pose_cat,
379
- prev_pose=prev_pose_cat,
380
- cam=cam_cat,
381
- prev_cam=prev_cam_cat,
382
- train=False
383
- )
384
- uncond, all_cond = torch.chunk(model_output, chunks=2, dim=0)
385
- return uncond + a_cfg_scale * (all_cond - uncond)
386
-
387
- else:
388
- return self.forward(
389
- t=t,
390
- x=x,
391
- a=a,
392
- prev_x=prev_x,
393
- prev_a=prev_a,
394
- ref_x=ref_x,
395
- gaze=gaze,
396
- prev_gaze=prev_gaze,
397
- pose=pose,
398
- prev_pose=prev_pose,
399
- cam=cam,
400
- prev_cam=prev_cam,
401
- train=False
402
- )
 
1
+ import os, math, torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from timm.layers import use_fused_attn
6
+ from timm.models.vision_transformer import Mlp
7
+
8
+ # ==========================================
9
+ # RoPE Implementation
10
+ # ==========================================
11
+
12
+ class RotaryEmbedding(nn.Module):
13
+ def __init__(self, dim, max_position_embeddings=4096, base=10000, device=None):
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.max_position_embeddings = max_position_embeddings
17
+ self.base = base
18
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
19
+ self.register_buffer("inv_freq", inv_freq)
20
+ self._set_cos_sin_cache(
21
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
22
+ )
23
+
24
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
25
+ self.max_seq_len_cached = seq_len
26
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
27
+ freqs = torch.outer(t, self.inv_freq)
28
+ emb = torch.cat((freqs, freqs), dim=-1)
29
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
30
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
31
+
32
+ def forward(self, x, seq_len=None):
33
+ if seq_len > self.max_seq_len_cached:
34
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
35
+ return (
36
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
37
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
38
+ )
39
+
40
+ def rotate_half(x):
41
+ x1 = x[..., : x.shape[-1] // 2]
42
+ x2 = x[..., x.shape[-1] // 2 :]
43
+ return torch.cat((-x2, x1), dim=-1)
44
+
45
+ def apply_rotary_pos_emb(q, k, cos, sin):
46
+ cos = cos.unsqueeze(0).unsqueeze(0)
47
+ sin = sin.unsqueeze(0).unsqueeze(0)
48
+ q_embed = (q * cos) + (rotate_half(q) * sin)
49
+ k_embed = (k * cos) + (rotate_half(k) * sin)
50
+ return q_embed, k_embed
51
+
52
+ # ==========================================
53
+ # Core Modules
54
+ # ==========================================
55
+
56
+ class Attention(nn.Module):
57
+ def __init__(
58
+ self,
59
+ dim: int,
60
+ num_heads: int = 8,
61
+ qkv_bias: bool = False,
62
+ qk_norm: bool = False,
63
+ attn_drop: float = 0.,
64
+ proj_drop: float = 0.,
65
+ norm_layer: nn.Module = nn.LayerNorm,
66
+ ) -> None:
67
+
68
+ super().__init__()
69
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
70
+ self.num_heads = num_heads
71
+ self.head_dim = dim // num_heads
72
+ self.scale = self.head_dim ** -0.5
73
+ self.fused_attn = use_fused_attn()
74
+
75
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
77
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
78
+ self.attn_drop = nn.Dropout(attn_drop)
79
+ self.proj = nn.Linear(dim, dim)
80
+ self.proj_drop = nn.Dropout(proj_drop)
81
+
82
+ def forward(self, x: torch.Tensor, rotary_pos_emb=None) -> torch.Tensor:
83
+ B, N, C = x.shape
84
+
85
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
86
+ q, k, v = qkv.unbind(0)
87
+ q, k = self.q_norm(q), self.k_norm(k)
88
+
89
+ if rotary_pos_emb is not None:
90
+ cos, sin = rotary_pos_emb
91
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
92
+
93
+ if self.fused_attn:
94
+ x = F.scaled_dot_product_attention(
95
+ q, k, v,
96
+ attn_mask=None,
97
+ dropout_p=self.attn_drop.p if self.training else 0.,
98
+ )
99
+ else:
100
+ q = q * self.scale
101
+ attn = q @ k.transpose(-2, -1)
102
+ attn = attn.softmax(dim=-1)
103
+ attn = self.attn_drop(attn)
104
+ x = attn @ v
105
+
106
+ x = x.transpose(1, 2).reshape(B, N, C)
107
+ x = self.proj(x)
108
+ x = self.proj_drop(x)
109
+ return x
110
+
111
+ class TimestepEmbedder(nn.Module):
112
+ def __init__(self, hidden_size, frequency_embedding_size = 256):
113
+ super().__init__()
114
+ self.mlp = nn.Sequential(
115
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
116
+ nn.SiLU(),
117
+ nn.Linear(hidden_size, hidden_size, bias=True),
118
+ )
119
+ self.frequency_embedding_size = frequency_embedding_size
120
+
121
+ @staticmethod
122
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
123
+ half = dim // 2
124
+ freqs = torch.exp(
125
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
126
+ ).to(device=t.device)
127
+ args = t[:, None].float() * freqs[None]
128
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
129
+ if dim % 2:
130
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
131
+ return embedding
132
+
133
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
134
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
135
+ t_emb = self.mlp(t_freq)
136
+ return t_emb
137
+
138
+ class SequenceEmbed(nn.Module):
139
+ def __init__(
140
+ self,
141
+ dim_w,
142
+ dim_h,
143
+ norm_layer=None,
144
+ bias=True,
145
+ ):
146
+ super().__init__()
147
+ self.proj = nn.Linear(dim_w, dim_h, bias=bias)
148
+ self.norm = norm_layer(dim_h) if norm_layer else nn.Identity()
149
+
150
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
151
+ return self.norm(self.proj(x))
152
+
153
+
154
+ class FMTBlock(nn.Module):
155
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs) -> None:
156
+ super().__init__()
157
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
158
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
159
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
160
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
161
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
162
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
163
+ self.adaLN_modulation = nn.Sequential(
164
+ nn.SiLU(),
165
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
166
+ )
167
+
168
+ def framewise_modulate(self, x, shift, scale) -> torch.Tensor:
169
+ return x * (1 + scale) + shift
170
+
171
+ def forward(self, x, c, rotary_pos_emb=None) -> torch.Tensor:
172
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
173
+ x = x + gate_msa * self.attn(self.framewise_modulate(self.norm1(x), shift_msa, scale_msa), rotary_pos_emb=rotary_pos_emb)
174
+ x = x + gate_mlp * self.mlp(self.framewise_modulate(self.norm2(x), shift_mlp, scale_mlp))
175
+ return x
176
+
177
+ class Decoder(nn.Module):
178
+ def __init__(self, hidden_size, dim_w):
179
+ super().__init__()
180
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
181
+ self.adaLN_modulation = nn.Sequential(
182
+ nn.SiLU(),
183
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
184
+ )
185
+ self.linear = nn.Linear(hidden_size, dim_w, bias=True)
186
+
187
+ def framewise_modulate(self, x, shift, scale) -> torch.Tensor:
188
+ return x * (1 + scale) + shift
189
+
190
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
191
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
192
+ x = self.framewise_modulate(self.norm_final(x), shift, scale)
193
+ return self.linear(x)
194
+
195
+ # ==========================================
196
+ # Main Model
197
+ # ==========================================
198
+
199
+ class FlowMatchingTransformer(nn.Module):
200
+ def __init__(self, opt) -> None:
201
+ super().__init__()
202
+ self.opt = opt
203
+
204
+ self.num_frames_for_clip = int(self.opt.wav2vec_sec * self.opt.fps)
205
+ self.num_prev_frames = int(opt.num_prev_frames)
206
+ self.num_total_frames = self.num_prev_frames + self.num_frames_for_clip
207
+
208
+ self.hidden_size = opt.dim_h
209
+ self.mlp_ratio = opt.mlp_ratio
210
+ self.fmt_depth = opt.fmt_depth
211
+ self.num_heads = opt.num_heads
212
+
213
+ self.x_embedder = SequenceEmbed(2 * opt.dim_motion, self.hidden_size)
214
+
215
+ # RoPE Setup
216
+ head_dim = self.hidden_size // self.num_heads
217
+ self.rotary_emb = RotaryEmbedding(head_dim)
218
+
219
+ self.t_embedder = TimestepEmbedder(self.hidden_size)
220
+ self.c_embedder = nn.Linear(opt.dim_c, self.hidden_size)
221
+
222
+ self.blocks = nn.ModuleList([
223
+ FMTBlock(self.hidden_size, self.num_heads, mlp_ratio=self.mlp_ratio)
224
+ for _ in range(self.fmt_depth)
225
+ ])
226
+ self.decoder = Decoder(self.hidden_size, self.opt.dim_motion)
227
+ self.initialize_weights()
228
+
229
+ def initialize_weights(self) -> None:
230
+ def _basic_init(module):
231
+ if isinstance(module, nn.Linear):
232
+ torch.nn.init.xavier_uniform_(module.weight)
233
+ if module.bias is not None:
234
+ nn.init.constant_(module.bias, 0)
235
+
236
+ self.apply(_basic_init)
237
+
238
+ w = self.x_embedder.proj.weight.data
239
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
240
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
241
+
242
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
243
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
244
+
245
+ for block in self.blocks:
246
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
247
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
248
+
249
+ nn.init.constant_(self.decoder.adaLN_modulation[-1].weight, 0)
250
+ nn.init.constant_(self.decoder.adaLN_modulation[-1].bias, 0)
251
+ nn.init.constant_(self.decoder.linear.weight, 0)
252
+ nn.init.constant_(self.decoder.linear.bias, 0)
253
+
254
+ def sequence_embedder(
255
+ self, sequence: torch.Tensor,
256
+ dropout_prob: float,
257
+ train: bool = False
258
+ ) -> torch.Tensor:
259
+ if train:
260
+ batch_id_for_drop = torch.where(
261
+ torch.rand(sequence.shape[0], device=sequence.device) < dropout_prob
262
+ )
263
+ sequence[batch_id_for_drop] = 0
264
+ return sequence
265
+
266
+ def forward(
267
+ self,
268
+ t,
269
+ x,
270
+ a,
271
+ prev_x,
272
+ prev_a,
273
+ ref_x,
274
+ gaze,
275
+ prev_gaze,
276
+ pose,
277
+ prev_pose,
278
+ cam,
279
+ prev_cam,
280
+ train: bool = True,
281
+ **kwargs
282
+ ) -> torch.Tensor:
283
+ t = self.t_embedder(t).unsqueeze(1)
284
+ a = self.sequence_embedder(a, dropout_prob=self.opt.audio_dropout_prob, train=train)
285
+ pose = self.sequence_embedder(pose, dropout_prob=self.opt.audio_dropout_prob, train=train)
286
+ cam = self.sequence_embedder(cam, dropout_prob=self.opt.audio_dropout_prob, train=train)
287
+ gaze = self.sequence_embedder(gaze, dropout_prob=self.opt.audio_dropout_prob, train=train)
288
+
289
+ if prev_x is not None:
290
+ prev_x = self.sequence_embedder(prev_x, dropout_prob=0.5, train=train)
291
+ prev_a = self.sequence_embedder(prev_a, dropout_prob=0.5, train=train)
292
+ prev_pose = self.sequence_embedder(prev_pose, dropout_prob=0.5, train=train)
293
+ prev_cam = self.sequence_embedder(prev_cam, dropout_prob=0.5, train=train)
294
+ prev_gaze = self.sequence_embedder(prev_gaze, dropout_prob=0.5, train=train)
295
+
296
+ x = torch.cat([prev_x, x], dim=1)
297
+ a = torch.cat([prev_a, a], dim=1)
298
+ pose = torch.cat([prev_pose, pose], dim=1)
299
+ cam = torch.cat([prev_cam, cam], dim=1)
300
+ gaze = torch.cat([prev_gaze, gaze], dim=1)
301
+
302
+ ref_x = ref_x[:, None, ...].repeat(1, x.shape[1], 1)
303
+ x = torch.cat([ref_x, x], dim=-1)
304
+ x = self.x_embedder(x)
305
+
306
+ # Calculate RoPE
307
+ rotary_pos_emb = self.rotary_emb(x, seq_len=x.shape[1])
308
+
309
+ c = self.c_embedder(a + pose + cam + gaze)
310
+ c = t + c
311
+
312
+ for block in self.blocks:
313
+ x = block(x, c, rotary_pos_emb=rotary_pos_emb)
314
+
315
+ return self.decoder(x, c)
316
+
317
+ @torch.no_grad()
318
+ def forward_with_cfg(
319
+ self,
320
+ t,
321
+ x,
322
+ a,
323
+ prev_x,
324
+ prev_a,
325
+ ref_x,
326
+ gaze,
327
+ prev_gaze,
328
+ pose,
329
+ prev_pose,
330
+ cam,
331
+ prev_cam,
332
+ a_cfg_scale: float = 1.0,
333
+ **kwargs
334
+ ) -> torch.Tensor:
335
+ if a_cfg_scale != 1.0:
336
+ null_a = torch.zeros_like(a)
337
+ audio_cat = torch.cat([null_a, a], dim=0)
338
+ gaze_cat = torch.cat([gaze, gaze], dim=0)
339
+ pose_cat = torch.cat([pose, pose], dim=0)
340
+ cam_cat = torch.cat([cam, cam], dim=0)
341
+
342
+ x_cat = torch.cat([x, x], dim=0)
343
+ prev_x_cat = torch.cat([prev_x, prev_x], dim=0)
344
+ prev_a_cat = torch.cat([prev_a, prev_a], dim=0)
345
+ prev_gaze_cat = torch.cat([prev_gaze, prev_gaze], dim=0)
346
+ prev_pose_cat = torch.cat([prev_pose, prev_pose], dim=0)
347
+ prev_cam_cat = torch.cat([prev_cam, prev_cam], dim=0)
348
+ ref_x_cat = torch.cat([ref_x, ref_x], dim=0)
349
+
350
+ model_output = self.forward(
351
+ t=t,
352
+ x=x_cat,
353
+ a=audio_cat,
354
+ prev_x=prev_x_cat,
355
+ prev_a=prev_a_cat,
356
+ ref_x=ref_x_cat,
357
+ gaze=gaze_cat,
358
+ prev_gaze=prev_gaze_cat,
359
+ pose=pose_cat,
360
+ prev_pose=prev_pose_cat,
361
+ cam=cam_cat,
362
+ prev_cam=prev_cam_cat,
363
+ train=False
364
+ )
365
+ uncond, all_cond = torch.chunk(model_output, chunks=2, dim=0)
366
+ return uncond + a_cfg_scale * (all_cond - uncond)
367
+
368
+ else:
369
+ return self.forward(
370
+ t=t,
371
+ x=x,
372
+ a=a,
373
+ prev_x=prev_x,
374
+ prev_a=prev_a,
375
+ ref_x=ref_x,
376
+ gaze=gaze,
377
+ prev_gaze=prev_gaze,
378
+ pose=pose,
379
+ prev_pose=prev_pose,
380
+ cam=cam,
381
+ prev_cam=prev_cam,
382
+ train=False
383
+ )
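
A brief usage sketch of the RoPE helpers added above; the import path mirrors this repository's layout, and the tensor shapes (batch 2, 8 heads, 75 frames, head_dim 64) are illustrative assumptions:

    import torch
    from generator.FMT import RotaryEmbedding, apply_rotary_pos_emb  # path assumed from this repo layout

    rope = RotaryEmbedding(dim=64)                       # dim must equal head_dim
    q = torch.randn(2, 8, 75, 64)                        # (batch, heads, frames, head_dim)
    k = torch.randn(2, 8, 75, 64)
    cos, sin = rope(q, seq_len=q.shape[2])               # cached (seq_len, head_dim) tables
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # shapes unchanged
    # Each frame's query/key is rotated by a frame-dependent angle, so attention
    # logits depend on relative frame offsets; this replaces the fixed-length
    # sinusoidal pos_embed table and the alignment mask used before this commit.
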
 
generator/generate.py CHANGED
@@ -40,7 +40,7 @@ class DataProcessor:
         )
 
         self.transform = transforms.Compose([
-            transforms.Resize((256, 256)),
+            transforms.Resize((512, 512)),
             transforms.ToTensor(),
         ])
 
@@ -50,14 +50,10 @@ class DataProcessor:
         h, w = img_arr.shape[:2]
 
         mult = 360.0 / h
-        resized_img = cv2.resize(
-            img_arr, dsize=(0, 0), fx=mult, fy=mult,
-            interpolation=cv2.INTER_AREA if mult < 1 else cv2.INTER_CUBIC
-        )
 
-        bboxes = self.fa.face_detector.detect_from_image(resized_img)
+        bboxes = self.fa.face_detector.detect_from_image(img_arr)
         valid_bboxes = [
-            (int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score)
+            (int(x1), int(y1), int(x2), int(y2), score)
             for (x1, y1, x2, y2, score) in bboxes if score > 0.95
         ]
 
renderer/inference.py CHANGED
@@ -25,7 +25,7 @@ class DataProcessor:
         self.fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False)
 
         self.transform = transforms.Compose([
-            transforms.Resize((256, 256)),
+            transforms.Resize((512, 512)),
             transforms.ToTensor(),
         ])
 
@@ -43,17 +43,13 @@ class DataProcessor:
 
         # Resize for faster detection
         h, w = img.shape[:2]
-        mult = 360. / h
-        resized_img = cv2.resize(
-            img, dsize=(0, 0), fx=mult, fy=mult,
-            interpolation=cv2.INTER_AREA if mult < 1. else cv2.INTER_CUBIC
-        )
 
-        bboxes = self.fa.face_detector.detect_from_image(resized_img)
+
+        bboxes = self.fa.face_detector.detect_from_image(img)
 
         # Filter valid faces (score > 0.95)
         valid_bboxes = [
-            (int(x1 / mult), int(y1 / mult), int(x2 / mult), int(y2 / mult), score)
+            (int(x1), int(y1), int(x2), int(y2), score)
             for (x1, y1, x2, y2, score) in bboxes if score > 0.95
         ]
 
@@ -65,7 +61,7 @@ class DataProcessor:
         x1, y1, x2, y2, _ = valid_bboxes[0]
         bsy, bsx = int((y2 - y1) / 2), int((x2 - x1) / 2)
         my, mx = int((y1 + y2) / 2), int((x1 + x2) / 2)
-        bs = int(max(bsy, bsx) * 1.3)
+        bs = int(max(bsy, bsx) * 1.6)
 
         # Pad image to allow cropping outside boundaries
         img = cv2.copyMakeBorder(img, bs, bs, bs, bs, cv2.BORDER_CONSTANT, value=0)
@@ -73,11 +69,6 @@ class DataProcessor:
         # Adjust coordinates for padding
         my, mx = my + bs, mx + bs
         crop_img = img[my - bs:my + bs, mx - bs:mx + bs]
-
-        crop_img = cv2.resize(
-            crop_img, (self.input_size, self.input_size),
-            interpolation=cv2.INTER_AREA if mult < 1. else cv2.INTER_CUBIC
-        )
         return Image.fromarray(crop_img)
 
     def load_image(self, path):
@@ -128,7 +119,6 @@ class Demo(nn.Module):
         source_img = self.processor.process_img(source_img)
 
         source_tensor = self.processor.transform(source_img).unsqueeze(0).to(self.device)
-
         # 2. Encode Source Appearance & Motion
         f_r, i_r = self.gen.app_encode(source_tensor)
         t_r = self.gen.mot_encode(source_tensor)
@@ -216,8 +206,8 @@ if __name__ == '__main__':
     parser.add_argument("--save_path", type=str, default="./results", help="Output directory")
 
     # Model Params
-    parser.add_argument("--renderer_path", type=str, required=True, help="Checkpoint path")
-    parser.add_argument("--input_size", type=int, default=256, help="Resolution")
+    parser.add_argument("--renderer_path", type=str, default="./checkpoints/renderer.ckpt", help="Checkpoint path")
+    parser.add_argument("--input_size", type=int, default=512, help="Resolution")
     parser.add_argument('--swin_res_threshold', type=int, default=128)
     parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--window_size', type=int, default=8)
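
The crop now runs on the full-resolution image (the 360-pixel pre-detection resize is gone) and uses a wider 1.6x margin; the final resize to input_size is dropped, presumably because the Resize((512, 512)) transform above now handles it. A standalone restatement of the crop logic, with square_crop a hypothetical helper name:

    import cv2
    import numpy as np

    def square_crop(img: np.ndarray, box: tuple, margin: float = 1.6) -> np.ndarray:
        # box = (x1, y1, x2, y2) from the face detector; margin was 1.3 before this commit.
        x1, y1, x2, y2 = box
        bsy, bsx = int((y2 - y1) / 2), int((x2 - x1) / 2)   # half-height, half-width of the box
        my, mx = int((y1 + y2) / 2), int((x1 + x2) / 2)     # box centre
        bs = int(max(bsy, bsx) * margin)                    # half-size of the square crop
        img = cv2.copyMakeBorder(img, bs, bs, bs, bs, cv2.BORDER_CONSTANT, value=0)
        my, mx = my + bs, mx + bs                           # shift centre into the padded frame
        return img[my - bs:my + bs, mx - bs:mx + bs]
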
renderer/models.py CHANGED
@@ -13,6 +13,7 @@ class IdentityEncoder(nn.Module):
             nn.BatchNorm2d(initial_channels),
             nn.ReLU(inplace=True)
         )
+        self.down_block_0 = DownConvResBlock(initial_channels, initial_channels)
         self.down_blocks = nn.ModuleList()
         current_channels = initial_channels
         for out_channels in output_channels:
@@ -27,6 +28,7 @@ class IdentityEncoder(nn.Module):
     def forward(self, x):
         features = []
         x = self.initial_conv(x)
+        x = self.down_block_0(x)
         features.append(x)
         for block in self.down_blocks:
             x = block(x)
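
The extra stem block pairs with the input-size change elsewhere in this commit: if DownConvResBlock halves the spatial size (an assumption), the new stage brings a 512x512 input's first stored feature map back to the resolution the 256x256 pipeline produced before. A rough sketch with stand-in layers (initial_channels=64 is also an assumption):

    import torch
    import torch.nn as nn

    stem = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=7, padding=3),             # stand-in for initial_conv
        nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),  # stand-in for down_block_0
    )
    print(stem(torch.randn(1, 3, 512, 512)).shape)  # -> torch.Size([1, 64, 256, 256])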