"""Utility functions for MuseTalk training and inference (model loading, data batching, checkpointing, visualization)."""
import os
import cv2
import numpy as np
import torch
from typing import Union, List
import torch.nn.functional as F
from einops import rearrange
import shutil
import os.path as osp
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding
def load_all_model(
    unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
    vae_type="sd-vae",
    unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
    device=None,
):
    """Instantiate the three MuseTalk modules: VAE, UNet and positional encoding.

    Args:
        unet_model_path: path to the UNet weights file.
        vae_type: subdirectory under ``models/`` holding the VAE weights.
        unet_config: path to the UNet JSON config.
        device: forwarded to the UNet constructor (None = its default).

    Returns:
        Tuple ``(vae, unet, pe)``.
    """
    vae_model = VAE(model_path=os.path.join("models", vae_type))
    print(f"load unet model from {unet_model_path}")
    unet_model = UNet(
        unet_config=unet_config,
        model_path=unet_model_path,
        device=device,
    )
    # d_model=384 matches the whisper feature dimension used elsewhere — TODO confirm
    pos_enc = PositionalEncoding(d_model=384)
    return vae_model, unet_model, pos_enc
def get_file_type(video_path):
    """Classify a path as 'image', 'video' or 'unsupported' by its extension (case-insensitive)."""
    image_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    video_exts = {'.avi', '.mp4', '.mov', '.flv', '.mkv'}
    ext = os.path.splitext(video_path)[1].lower()
    if ext in image_exts:
        return 'image'
    if ext in video_exts:
        return 'video'
    return 'unsupported'
def get_video_fps(video_path):
    """Return the frames-per-second reported by OpenCV for *video_path*."""
    cap = cv2.VideoCapture(video_path)
    try:
        # CAP_PROP_FPS is 0.0 if the container reports no FPS — caller's concern
        return cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()
def datagen(
    whisper_chunks,
    vae_encode_latents,
    batch_size=8,
    delay_frame=0,
    device="cuda:0",
):
    """Yield (whisper, latent) batches pairing audio chunks with frame latents.

    Latents are cycled (modulo their count), optionally shifted by
    *delay_frame*, so any number of audio chunks can be served.

    BUGFIX: previously only the final partial batch was moved to *device*;
    full batches were yielded on their original device. Now every yielded
    pair is moved to *device*.

    Args:
        whisper_chunks: iterable of per-frame audio feature tensors.
        vae_encode_latents: list of latent tensors, each with a leading batch dim
            (they are concatenated along dim 0).
        batch_size: number of frames per yielded batch.
        delay_frame: offset applied when indexing into the latents.
        device: device every yielded batch is moved to.

    Yields:
        Tuples ``(whisper_batch, latent_batch)`` on *device*.
    """
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        # cycle through the available latents, shifted by delay_frame
        idx = (i + delay_frame) % len(vae_encode_latents)
        whisper_batch.append(w)
        latent_batch.append(vae_encode_latents[idx])
        if len(latent_batch) >= batch_size:
            yield (torch.stack(whisper_batch).to(device),
                   torch.cat(latent_batch, dim=0).to(device))
            whisper_batch, latent_batch = [], []
    # the last batch may be smaller than batch_size
    if latent_batch:
        yield (torch.stack(whisper_batch).to(device),
               torch.cat(latent_batch, dim=0).to(device))
def cast_training_params(
    model: Union[torch.nn.Module, List[torch.nn.Module]],
    dtype=torch.float32,
):
    """In-place cast of all *trainable* parameters of the given module(s) to *dtype*.

    Frozen parameters (``requires_grad == False``) are left untouched.

    Args:
        model: a module or a list of modules.
        dtype: target dtype for trainable parameters (default fp32).
    """
    modules = model if isinstance(model, list) else [model]
    for module in modules:
        for p in module.parameters():
            # only upcast trainable parameters
            if p.requires_grad:
                p.data = p.to(dtype)
def rand_log_normal(
    shape,
    loc=0.,
    scale=1.,
    device='cpu',
    dtype=torch.float32,
    generator=None
):
    """Draw samples from a log-normal distribution.

    Samples ``exp(N(loc, scale^2))`` element-wise.

    Args:
        shape: output tensor shape.
        loc: mean of the underlying normal.
        scale: std-dev of the underlying normal.
        device, dtype: placement of the output tensor.
        generator: optional ``torch.Generator`` for reproducibility.

    Returns:
        Tensor of strictly positive samples with the given *shape*.
    """
    gaussian = torch.randn(shape, device=device, dtype=dtype, generator=generator)
    return torch.exp(gaussian * scale + loc)
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
    """Crop the masked (mouth) region from real and generated frames, resized to 256x256.

    For each batch element the bounding box of the non-zero entries of its
    face mask is used to crop both *frames* and *image_pred*; crops are
    bilinearly resized to 256x256 and stacked along the batch dim.

    Args:
        frames: real images, shape ``(B, C, H, W)``.
        image_pred: generated images, same shape as *frames*.
        pixel_values_face_mask: per-sample masks; dims 1 and 2 of each
            ``pixel_values_face_mask[b]`` are treated as (y, x).

    Returns:
        ``(mouth_real, mouth_generated)`` — each ``(B', C, 256, 256)`` where
        ``B'`` counts samples with a non-empty mask, or ``None`` if none.
    """
    real_crops, fake_crops = [], []
    for idx in range(frames.shape[0]):
        nz = torch.nonzero(pixel_values_face_mask[idx])
        # samples with an all-zero mask are skipped entirely
        if nz.numel() == 0:
            continue
        y0, y1 = nz[:, 1].min(), nz[:, 1].max()
        x0, x1 = nz[:, 2].min(), nz[:, 2].max()
        # NOTE: upper bounds are exclusive, matching the original slicing
        real_patch = frames[idx, :, y0:y1, x0:x1].unsqueeze(0)
        fake_patch = image_pred[idx, :, y0:y1, x0:x1].unsqueeze(0)
        real_crops.append(F.interpolate(
            real_patch, size=(256, 256), mode='bilinear', align_corners=False))
        fake_crops.append(F.interpolate(
            fake_patch, size=(256, 256), mode='bilinear', align_corners=False))
    mouth_real = torch.cat(real_crops, dim=0) if real_crops else None
    mouth_generated = torch.cat(fake_crops, dim=0) if fake_crops else None
    return mouth_real, mouth_generated
def get_image_pred(pixel_values,
                   ref_pixel_values,
                   audio_prompts,
                   vae,
                   net,
                   weight_dtype):
    """Run one inference pass: mask the lower half, encode, predict, decode.

    Args:
        pixel_values: ground-truth frames, shape ``(B, F, C, H, W)``.
        ref_pixel_values: reference frames, same layout.
        audio_prompts: audio conditioning passed straight to *net*.
        vae: VAE with ``encode``/``decode`` and ``config.scaling_factor``.
        net: denoising network called as ``net(latents, timesteps, audio)``.
        weight_dtype: dtype the network input is cast to.

    Returns:
        Decoded predicted frames as a float tensor.
    """
    with torch.no_grad():
        _, _, _, height, _ = pixel_values.shape
        # occlude the lower half of every frame with -1 (the mouth region)
        masked = pixel_values.clone()
        masked[:, :, :, height // 2:, :] = -1
        masked_flat = rearrange(masked, 'b f c h w -> (b f) c h w')
        masked_latents = vae.encode(masked_flat).latent_dist.mode()
        masked_latents = (masked_latents * vae.config.scaling_factor).float()
        ref_flat = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
        ref_latents = vae.encode(ref_flat).latent_dist.mode()
        ref_latents = (ref_latents * vae.config.scaling_factor).float()
        # condition on the reference by channel-concatenating its latents
        net_input = torch.cat([masked_latents, ref_latents], dim=1).to(weight_dtype)
        # single fixed timestep 0: the net runs as a one-step predictor here
        timesteps = torch.tensor([0], device=net_input.device)
        latents_pred = net(
            net_input,
            timesteps,
            audio_prompts,
        )
        latents_pred = (1 / vae.config.scaling_factor) * latents_pred
        image_pred = vae.decode(latents_pred).sample
        return image_pred.float()
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
    """Build per-frame audio prompt windows from wav2vec hidden states.

    Encodes the batch audio features, zero-pads along time by the configured
    left/right padding, then for every (sample, frame) slices a window of
    ``2 * (pad_left + pad_right + 1)`` encoder steps centred on that frame.

    Args:
        cfg: config providing ``data.audio_padding_length_left/right``.
        batch: dict with 'audio_feature', 'audio_offset', 'audio_step'.
        wav2vec: model whose ``encoder`` returns ``hidden_states``.
        bsz: batch size.
        num_frames: frames per sample.
        weight_dtype: dtype for features.

    Returns:
        Tensor of audio prompts — shape noted in source as [B, T, 10, 5, 384];
        presumably (batch, frames, window, layers, dim) — TODO confirm.
    """
    with torch.no_grad():
        pad_left = cfg.data.audio_padding_length_left
        pad_right = cfg.data.audio_padding_length_right
        # encoder runs at 2 steps per frame, hence the factor of 2 throughout
        window_len = 2 * (pad_left + pad_right + 1)
        feats = batch['audio_feature'].to(weight_dtype)
        feats = wav2vec.encoder(feats, output_hidden_states=True).hidden_states
        feats = torch.stack(feats, dim=2).to(weight_dtype)  # stack encoder layers
        offsets = batch['audio_offset']
        steps = batch['audio_step']
        # zero-pad the time axis on both sides so edge frames get full windows
        feats = torch.cat([
            torch.zeros_like(feats[:, :2 * pad_left]),
            feats,
            torch.zeros_like(feats[:, :2 * pad_right]),
        ], 1)
        audio_prompts = []
        for b in range(bsz):
            windows = []
            for f in range(num_frames):
                start = (offsets[b] + f * steps[b]) * 2
                windows.append(feats[b:b + 1, start: start + window_len])
            audio_prompts.append(torch.stack(windows, 1))
        audio_prompts = torch.cat(audio_prompts)  # B, T, 10, 5, 384
    return audio_prompts
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
    """Save *model*'s state_dict to ``<save_dir>/<name>-<ckpt_num>.pth``.

    When *total_limit* is set, older ``<name>-*.pth`` checkpoints are pruned
    so at most *total_limit* files remain after this save.

    Args:
        model: torch.nn.Module whose ``state_dict()`` is serialized.
        save_dir: existing directory the checkpoint is written into.
        ckpt_num: step number embedded in the file name.
        name: checkpoint family name, used for both naming and pruning.
        total_limit: max number of checkpoints to keep (None = unlimited).
        logger: optional logger for pruning messages.
    """
    save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
    if total_limit is not None:
        checkpoints = [
            d for d in os.listdir(save_dir)
            if d.endswith(".pth") and name in d
        ]
        # sort ascending by the numeric step in "<name>-<step>.pth"
        checkpoints.sort(key=lambda x: int(x.split("-")[1].split(".")[0]))
        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[:num_to_remove]
            # BUGFIX: the original called logger.info unconditionally and
            # crashed with AttributeError when logger was None (its default)
            if logger is not None:
                logger.info(
                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                )
                logger.info(
                    f"removing checkpoints: {', '.join(removing_checkpoints)}")
            for removing_checkpoint in removing_checkpoints:
                os.remove(os.path.join(save_dir, removing_checkpoint))
    torch.save(model.state_dict(), save_path)
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
    """Unwrap the accelerate-wrapped network and checkpoint its ``unet`` submodule.

    Delegates to :func:`save_checkpoint` with ``cfg.total_limit`` pruning.
    """
    unwrapped = accelerator.unwrap_model(net)
    save_checkpoint(
        unwrapped.unet,
        save_dir,
        global_step,
        name="unet",
        total_limit=cfg.total_limit,
        logger=logger,
    )
def delete_additional_ckpt(base_path, num_keep):
    """Keep only the *num_keep* newest ``checkpoint-<step>`` dirs under *base_path*.

    Directories are ranked by the numeric suffix after the last '-'; the
    oldest (lowest step) are removed with ``shutil.rmtree``.
    """
    ckpt_dirs = [d for d in os.listdir(base_path) if d.startswith("checkpoint-")]
    excess = len(ckpt_dirs) - num_keep
    if excess <= 0:
        return
    # ascending by step number so the earliest checkpoints are deleted first
    ckpt_dirs.sort(key=lambda name: int(name.split("-")[-1]))
    for stale in ckpt_dirs[:excess]:
        stale_path = osp.join(base_path, stale)
        if osp.exists(stale_path):
            shutil.rmtree(stale_path)
def seed_everything(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU + all CUDA devices).

    Args:
        seed: integer seed; NumPy receives ``seed % 2**32`` since it only
            accepts 32-bit seeds.
    """
    import random
    import numpy as np
    random.seed(seed)
    np.random.seed(seed % (2 ** 32))
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
def process_and_save_images(
    batch,
    image_pred,
    image_pred_infer,
    save_dir,
    global_step,
    accelerator,
    num_images_to_keep=10,
    syncnet_score=1
):
    """Save a comparison grid of masked/ref/predicted/GT/inferred frames as a JPEG.

    Rows (stacked along height): masked input, reference, training prediction,
    ground truth, inference prediction; up to *num_images_to_keep* samples are
    tiled horizontally. Pixels are assumed to be in [-1, 1] and are rescaled
    to [0, 1] — TODO confirm against the data pipeline. Save failures are
    logged and swallowed so training is never interrupted.
    """
    print("image_pred.shape: ", image_pred.shape)
    ref_imgs = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
    gt_imgs = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
    # reproduce the model input: lower half of each frame blanked with -1
    vid = batch["pixel_values_vid"]
    masked = vid.clone()
    frame_h = vid.shape[3]
    masked[:, :, :, frame_h // 2:, :] = -1
    masked = rearrange(masked, 'b f c h w -> (b f) c h w')
    n = num_images_to_keep
    gt_imgs = gt_imgs[:n]
    masked = masked[:n]
    ref_imgs = ref_imgs[:n]
    pred = image_pred.detach()[:n]
    pred_infer = image_pred_infer.detach()[:n]
    # map [-1, 1] -> [0, 1] and stack the five rows along the height axis
    rows = [masked, ref_imgs, pred, gt_imgs, pred_infer]
    concat = torch.cat([r * 0.5 + 0.5 for r in rows], dim=2)
    print("concat.shape: ", concat.shape)
    os.makedirs(f'{save_dir}/samples/', exist_ok=True)
    try:
        # tile samples horizontally, CHW -> HWC, RGB -> BGR for OpenCV, scale to 0-255
        tiles = [concat[i] for i in range(concat.shape[0])]
        final_image = torch.cat(tiles, dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
        out_path = f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg'
        cv2.imwrite(out_path, final_image)
        print(f"Image saved successfully: {out_path}")
    except Exception as e:
        # best-effort: report and continue rather than abort training
        print(f"Failed to save image: {e}")