File size: 7,639 Bytes
55f3ab3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import torch
import os
import json
import cv2
import numpy as np
from einops import rearrange
from PIL import Image
import imageio
import cv2
import einops
from matplotlib import pyplot as plt
import re

import torch
import torchvision

from safetensors.torch import load_model, save_model,safe_open
import einops
def cat_video(amd_model,z_video:torch.Tensor,ref_img:torch.Tensor,motion_seq_len:int=15):
    '''
    Extract motion features for a latent video plus a reference-frame motion.

    Args:
        amd_model: model exposing extract_motion(clip, cond) -> motion tensor.
        z_video (torch.Tensor): latent video of shape (B, F, C, H, W); F must
            equal k * motion_seq_len + 1 so consecutive windows of
            (motion_seq_len + 1) frames overlap by exactly one frame.
        ref_img (torch.Tensor): reference latents of shape (B, C, H, W).
        motion_seq_len (int): motion tokens produced per window.

    Returns:
        tuple: (motion, ref_motion) — per-window motions concatenated along
        dim 1, and the reference frame's self-motion with its seq dim squeezed.
    '''
    batch, num_frames, _, _, _ = z_video.shape
    assert (num_frames - 1) % motion_seq_len == 0, f"no. frames miss match"

    # One overlapping window per motion_seq_len frames, each starting on the
    # last frame of the previous window.
    windows = [
        amd_model.extract_motion(z_video[:, start - 1:start + motion_seq_len], None)
        for start in range(1, num_frames, motion_seq_len)
    ]

    # Reference motion: duplicate the reference frame along time so the model
    # sees a two-frame clip of the frame paired with itself.
    paired_ref = ref_img.unsqueeze(1).repeat(1, 2, 1, 1, 1)
    ref_motion = amd_model.extract_motion(paired_ref, None).squeeze(1)

    return torch.concat(windows, dim=1), ref_motion


def save_cfg(path, args):
    """Persist the experiment arguments to ``<path>/args.txt`` as JSON.

    Args:
        path: experiment directory; created if missing.
        args: argparse.Namespace-like object whose ``__dict__`` is serialized.
    """
    os.makedirs(path, exist_ok=True)
    cfg_file = f'{path}/args.txt'
    # Fix: the original printed this warning unconditionally, which was
    # misleading on a fresh run. Only warn when a config was already written.
    if os.path.exists(cfg_file):
        print(f'Experiment of the same name already exists. Are you trying to resume training?')
    with open(cfg_file, 'w') as f:
        json.dump(args.__dict__, f, indent=2)
        
def _freeze_parameters(model):
    for param in model.parameters():
        param.requires_grad = False
    model._requires_grad = False
    return model

def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=8, fps=8):
    """Write a batched video tensor to disk as a single grid video.

    Args:
        videos: tensor of shape (b, c, t, h, w); values in [0, 1] by default.
        path: output file path; parent directory is created as needed.
        rescale: set True when values are in [-1, 1] to map them to [0, 1].
        n_rows: images per row of each grid frame.
        fps: frame rate of the written video.
    """
    frames = []
    # (b, c, t, h, w) -> iterate over time, gridding the batch per frame
    for frame in videos.permute(2, 0, 1, 3, 4):
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        grid = grid.permute(1, 2, 0).squeeze(-1)  # c,h,w -> h,w,c
        if rescale:
            grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        frames.append((grid * 255).clamp(0, 255).numpy().astype(np.uint8))

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, frames, fps=fps)


def save_images_grid(images, grid_size, save_path):
    """Arrange multiple PIL images into a rows x cols grid and save one file.

    Args:
        images: list of PIL.Image.Image objects (all assumed the same size).
        grid_size: (rows, cols) tuple describing the grid layout.
        save_path: destination path of the combined image.
    """
    rows, cols = grid_size
    assert len(images) <= rows * cols, "图像数量多于网格容量"

    # Cell size is taken from the first image; all are assumed identical.
    cell_w, cell_h = images[0].size

    # Blank canvas large enough for the full grid.
    canvas = Image.new('RGB', (cols * cell_w, rows * cell_h))

    # Paste row-major: index -> (row, col) cell.
    for idx, img in enumerate(images):
        x = (idx % cols) * cell_w
        y = (idx // cols) * cell_h
        canvas.paste(img, (x, y))

    canvas.save(save_path)


def print_param_num(model):
    """Print the model's total / trainable / frozen parameter counts in millions."""
    trainable = 0
    frozen = 0
    for p in model.parameters():
        if p.requires_grad:
            trainable += p.numel()
        else:
            frozen += p.numel()
    total = trainable + frozen

    print(f'#### #### 模型总参数数量:{total / 1_000_000:.2f}M')
    print(f'####  模型训练数量:{trainable / 1_000_000:.2f}M')
    print(f'####  模型冻结参数数量:{frozen / 1_000_000:.2f}M')


def vae_encode(vae,latents):
    """Encode pixel-space tensors into scaled VAE latents.

    Accepts a video batch (N, T, C, H, W) or an image batch (N, C, H, W);
    videos are flattened to individual frames before encoding and restored
    afterwards. The sampled latents are multiplied by the scaling factor
    0.18215 (Stable Diffusion convention).
    """
    is_video = latents.dim() == 5
    if is_video:
        N, T, C, H, W = latents.shape
        # Fold time into the batch axis: (n, t, c, h, w) -> (n*t, c, h, w).
        latents = latents.reshape(N * T, C, H, W)
    else:
        N, C, H, W = latents.shape

    with torch.no_grad():
        encoded = vae.encode(latents).latent_dist.sample() * 0.18215

    if is_video:
        # Restore the time axis; channel/spatial dims come from the encoder output.
        encoded = encoded.reshape(N, T, *encoded.shape[1:])
    return encoded

def vae_decode(vae,latents):
    """Decode scaled VAE latents back to pixel space.

    Mirrors vae_encode: accepts video latents (N, T, C, H, W) or image
    latents (N, C, H, W), undoes the 0.18215 scaling, and decodes without
    tracking gradients.
    """
    is_video = latents.dim() == 5
    if is_video:
        N, T, C, H, W = latents.shape
        # Fold time into the batch axis: (n, t, c, h, w) -> (n*t, c, h, w).
        latents = latents.reshape(N * T, C, H, W)
    else:
        N, C, H, W = latents.shape

    # Undo the Stable-Diffusion latent scaling before decoding.
    scaled = 1 / 0.18215 * latents
    with torch.no_grad():
        decoded = vae.decode(scaled).sample  # (n*t, c, h, w)

    if is_video:
        decoded = decoded.reshape(N, T, *decoded.shape[1:])
    return decoded

def latents_to_videos(latents,batch_size):
    """Convert latents in [-1, 1] to uint8 video frames in [0, 255].

    Args:
        latents: tensor of shape (B*T, C, H, W) or (B, T, C, H, W). A 4D
            input is unfolded to (B, T, C, H, W) assuming batch-major order,
            i.e. the inverse of the '(n t)' flattening used in vae_encode.
        batch_size: number of videos B packed into the flattened batch.

    Returns:
        uint8 tensor of shape (B, T, C, H, W) on CPU, contiguous.
    """
    if len(latents.shape) == 4:
        M, C, H, W = latents.shape
        T = M // batch_size
        # Fix: the original einops pattern '(bt) c h w -> b t c h w' named a
        # single axis 'bt', so the b=/t= keywords never matched and the call
        # raised an EinopsError. A plain reshape performs the intended unfold.
        latents = latents.reshape(batch_size, T, C, H, W)

    videos = ((latents / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
    return videos


def freeze(model):
    """Disable gradient computation for every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad = False


def model_load_pretrain(model, path, not_load_keyword='decoder',strict=False):
    """Load weights from a safetensors checkpoint, skipping matching keys.

    Args:
        model: module whose state dict is updated in place.
        path: path to the safetensors checkpoint file.
        not_load_keyword: any checkpoint key containing this substring is
            skipped (default 'decoder').
        strict: forwarded to ``load_state_dict``.
    """
    with safe_open(path, framework="pt") as reader:
        weights = {
            key: reader.get_tensor(key)
            for key in reader.keys()
            if not_load_keyword not in key
        }
    model.load_state_dict(weights, strict=strict)

def display_images(images, save_dir, prefix="image", vae=None, need_decode=False):
    """Save a batch of image tensors as individual PNG files.

    Args:
        images: tensor of shape (T, C, H, W) or (B, T, C, H, W); values are
            assumed to lie in [-1, 1] before the 0-255 conversion.
        save_dir: output directory, created if missing.
        prefix: filename prefix; files are named "<prefix>_<i>.png" (1-based).
        vae: VAE model used when need_decode is True.
        need_decode: decode latents through the VAE before saving.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Collapse batch and time into a single frame axis.
    if len(images.shape) == 5:
        images = images.reshape(-1, *images.shape[2:])

    if need_decode:
        images = vae_decode(vae, images)

    t, c, h, w = images.shape

    for idx in range(t):
        # (C, H, W) -> (H, W, C), then map [-1, 1] -> [0, 255] uint8.
        frame = images[idx].permute(1, 2, 0)
        arr = ((frame / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().numpy()

        out_path = os.path.join(save_dir, f"{prefix}_{idx+1}.png")
        Image.fromarray(arr).save(out_path)

def find_latest_checkpoint(checkpoint_dir):
    """Return the model.safetensors path of the highest-numbered checkpoint.

    Scans *checkpoint_dir* for entries named ``checkpoint-<step>`` and picks
    the one with the largest step.

    Raises:
        ValueError: if no entry matches the checkpoint naming pattern.
    """
    pattern = re.compile(r"checkpoint-(\d+)$")

    best_step = -1
    best_dir = None
    for entry in os.listdir(checkpoint_dir):
        match = pattern.match(entry)
        if match is None:
            continue
        step = int(match.group(1))
        if step > best_step:
            best_step = step
            best_dir = os.path.join(checkpoint_dir, entry)

    if best_dir is None:
        raise ValueError(f"No valid checkpoint found in {checkpoint_dir}")

    return os.path.join(best_dir, 'model.safetensors')