File size: 19,942 Bytes
3d1c0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Copyright (c) 2025 FoundationVision
# SPDX-License-Identifier: MIT
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from infinity.models.videovae.utils.misc import is_torch_optim_sch


def _resize_last_dim(tensor, size):
    """Resample the last dimension of a 1-D or 2-D tensor to ``size`` entries.

    ``F.interpolate`` is called with its default mode on a 3-D view of the
    tensor (leading singleton dimensions are added, then removed again).
    """
    extra_dims = 3 - tensor.dim()
    resized = tensor
    for _ in range(extra_dims):
        resized = resized.unsqueeze(0)
    resized = F.interpolate(resized, size=size)
    for _ in range(extra_dims):
        resized = resized.squeeze(0)
    return resized


def inflate_gen(state_dict, temporal_patch_size, spatial_patch_size, strategy="average", inflation_pe=False):
    """Adapt the patch-embedding / to-pixels weights of a checkpoint to a new patch size.

    The current patch sizes are derived from the checkpoint itself, assuming
    3 image channels (the code divides the flattened patch length by 3).

    * If the derived spatial or temporal patch size differs from the requested
      one and ``inflation_pe`` is false, the size-dependent entries are removed
      from the returned dict.
    * If they differ and ``inflation_pe`` is true, the size-dependent entries
      are resampled along their flattened-patch axis with ``F.interpolate``.
    * If they match, the video entries are rebuilt from the first-frame (image)
      entries according to ``strategy``.

    Args:
        state_dict: checkpoint mapping; it is not modified (a shallow copy is
            edited and returned). Must contain the
            ``encoder.to_patch_emb_first_frame.{1,2,3}.*``,
            ``decoder.to_pixels_first_frame.0.*`` and
            ``encoder.to_patch_emb.1.weight`` entries read below.
        temporal_patch_size: requested temporal patch size.
        spatial_patch_size: requested spatial patch size.
        strategy: ``"average"`` tiles the image weights ``temporal_patch_size``
            times, each scaled by ``1 / temporal_patch_size``; ``"first"``
            keeps the image weights in the first slot and zero-fills the rest.
            Only consulted when the patch sizes already match.
        inflation_pe: whether to resample (true) or drop (false) mismatching
            entries.

    Returns:
        The edited shallow copy of ``state_dict``.

    Raises:
        KeyError: if an expected entry is absent.
        NotImplementedError: for an unknown ``strategy`` (matching-size path only).
    """
    new_state_dict = state_dict.copy()

    # Inline shape notes from the original author: the ".1" / ".3" tensors and
    # all biases are 1-D, ".2.weight" and the decoder weight are 2-D.
    pe_image0_w = state_dict["encoder.to_patch_emb_first_frame.1.weight"]
    pe_image0_b = state_dict["encoder.to_patch_emb_first_frame.1.bias"]
    pe_image1_w = state_dict["encoder.to_patch_emb_first_frame.2.weight"]
    pe_image1_b = state_dict["encoder.to_patch_emb_first_frame.2.bias"]
    pe_image2_w = state_dict["encoder.to_patch_emb_first_frame.3.weight"]
    pe_image2_b = state_dict["encoder.to_patch_emb_first_frame.3.bias"]

    pd_image0_w = state_dict["decoder.to_pixels_first_frame.0.weight"]
    pd_image0_b = state_dict["decoder.to_pixels_first_frame.0.bias"]

    pe_video0_w = state_dict["encoder.to_patch_emb.1.weight"]

    old_patch_size = int(math.sqrt(pe_image0_w.shape[0] // 3))
    old_patch_size_temporal = pe_video0_w.shape[0] // (3 * old_patch_size * old_patch_size)

    if old_patch_size != spatial_patch_size or old_patch_size_temporal != temporal_patch_size:
        if not inflation_pe:
            # Drop every entry whose shape depends on the patch size.
            for stale_key in (
                "encoder.to_patch_emb_first_frame.1.weight",
                "encoder.to_patch_emb_first_frame.1.bias",
                "encoder.to_patch_emb_first_frame.2.weight",
                "decoder.to_pixels_first_frame.0.weight",
                "decoder.to_pixels_first_frame.0.bias",
                "encoder.to_patch_emb.1.weight",
                "encoder.to_patch_emb.1.bias",
                "encoder.to_patch_emb.2.weight",
                "decoder.to_pixels.0.weight",
                "decoder.to_pixels.0.bias",
            ):
                del new_state_dict[stale_key]
            return new_state_dict

        print(f"Inflate the patch embedding size from {old_patch_size_temporal}x{old_patch_size}x{old_patch_size} to {temporal_patch_size}x{spatial_patch_size}x{spatial_patch_size}.")
        image_len = 3 * spatial_patch_size * spatial_patch_size
        video_len = 3 * temporal_patch_size * spatial_patch_size * spatial_patch_size

        # First-frame (image) encoder entries.
        new_state_dict["encoder.to_patch_emb_first_frame.1.weight"] = _resize_last_dim(pe_image0_w, image_len)
        new_state_dict["encoder.to_patch_emb_first_frame.1.bias"] = _resize_last_dim(pe_image0_b, image_len)
        new_state_dict["encoder.to_patch_emb_first_frame.2.weight"] = _resize_last_dim(pe_image1_w, image_len)

        # The decoder weight carries the patch axis first, so it is transposed
        # around the resize.
        new_state_dict["decoder.to_pixels_first_frame.0.weight"] = _resize_last_dim(
            pd_image0_w.permute(1, 0), image_len
        ).permute(1, 0)
        new_state_dict["decoder.to_pixels_first_frame.0.bias"] = _resize_last_dim(pd_image0_b, image_len)

        # Video encoder entries.
        pe_video0_b = state_dict["encoder.to_patch_emb.1.bias"]
        pe_video1_w = state_dict["encoder.to_patch_emb.2.weight"]
        new_state_dict["encoder.to_patch_emb.1.weight"] = _resize_last_dim(pe_video0_w, video_len)
        new_state_dict["encoder.to_patch_emb.1.bias"] = _resize_last_dim(pe_video0_b, video_len)
        new_state_dict["encoder.to_patch_emb.2.weight"] = _resize_last_dim(pe_video1_w, video_len)

        # Video decoder entries: resample the *video* tensors. The previous
        # code fed the already-resized image tensors here and discarded the
        # video tensors it had just read.
        pd_video0_w = state_dict["decoder.to_pixels.0.weight"]
        pd_video0_b = state_dict["decoder.to_pixels.0.bias"]
        new_state_dict["decoder.to_pixels.0.weight"] = _resize_last_dim(
            pd_video0_w.permute(1, 0), video_len
        ).permute(1, 0)
        new_state_dict["decoder.to_pixels.0.bias"] = _resize_last_dim(pd_video0_b, video_len)

        return new_state_dict

    # Patch sizes already match: rebuild the video entries from the image ones.
    if strategy == "average":
        def _spread(tensor, dim=0):
            return torch.cat([tensor / temporal_patch_size] * temporal_patch_size, dim=dim)
    elif strategy == "first":
        def _spread(tensor, dim=0):
            padding = [torch.zeros_like(tensor)] * (temporal_patch_size - 1)
            return torch.cat([tensor] + padding, dim=dim)
    else:
        raise NotImplementedError(f"unknown inflation strategy: {strategy!r}")

    new_state_dict["encoder.to_patch_emb.1.weight"] = _spread(pe_image0_w)
    new_state_dict["encoder.to_patch_emb.1.bias"] = _spread(pe_image0_b)

    # Only the input (last) axis of ".2.weight" grows; its bias and the ".3"
    # entries are copied unchanged.
    new_state_dict["encoder.to_patch_emb.2.weight"] = _spread(pe_image1_w, dim=-1)
    new_state_dict["encoder.to_patch_emb.2.bias"] = pe_image1_b

    new_state_dict["encoder.to_patch_emb.3.weight"] = pe_image2_w
    new_state_dict["encoder.to_patch_emb.3.bias"] = pe_image2_b

    new_state_dict["decoder.to_pixels.0.weight"] = _spread(pd_image0_w)
    new_state_dict["decoder.to_pixels.0.bias"] = _spread(pd_image0_b)

    return new_state_dict


def inflate_dis(state_dict, strategy="center", temporal_kernel_size=4):
    """Build video-discriminator weights from image-discriminator weights.

    Every entry whose key contains ``"image_discriminator"`` is duplicated
    under the corresponding ``"video_discriminator"`` key. 4-D tensors whose
    key contains ``"weight"`` gain a new axis of length
    ``temporal_kernel_size`` at position 2; all other entries are reused as is.
    Existing ``"video_discriminator"`` entries are removed first.

    Args:
        state_dict: source mapping; it is not modified.
        strategy: how a 4-D kernel is placed along the new axis.
            ``"average"`` repeats it in every slot scaled by
            ``1 / temporal_kernel_size``; ``"center"``, ``"first"`` and
            ``"last"`` put it in one slot (index
            ``(temporal_kernel_size - 1) // 2``, ``0`` and ``-1``) and leave
            the other slots zero.
        temporal_kernel_size: length of the new axis (default 4, the value
            that used to be hard-coded).

    Returns:
        A shallow copy of ``state_dict`` with the video entries replaced.

    Raises:
        NotImplementedError: if a 4-D weight is found and ``strategy`` is unknown.
    """
    print("#" * 50)
    print(f"Initialize the video discriminator with {strategy}.")
    print("#" * 50)

    # Slot that receives the 2-D kernel for the single-slot strategies.
    slot_for = {
        "center": (temporal_kernel_size - 1) // 2,
        "first": 0,
        "last": -1,
    }

    new_state_dict = state_dict.copy()
    for stale_key in [k for k in state_dict if "video_discriminator" in k]:
        del new_state_dict[stale_key]

    for k, old_weight in state_dict.items():
        if "image_discriminator" not in k:
            continue
        # Rename in place so keys that carry the marker after some prefix keep
        # the rest of their path intact.
        new_k = k.replace("image_discriminator", "video_discriminator", 1)

        if "weight" in k and old_weight.ndim == 4:
            if strategy == "average":
                new_weight = old_weight.unsqueeze(2).repeat(1, 1, temporal_kernel_size, 1, 1) / temporal_kernel_size
            elif strategy in slot_for:
                out_ch, in_ch, k_h, k_w = old_weight.shape
                # new_zeros keeps dtype *and* device of the source kernel.
                new_weight = old_weight.new_zeros((out_ch, in_ch, temporal_kernel_size, k_h, k_w))
                new_weight[:, :, slot_for[strategy]] = old_weight
            else:
                raise NotImplementedError(f"unknown inflation strategy: {strategy!r}")
            new_state_dict[new_k] = new_weight
        else:
            # Biases and every non-4-D tensor are shared unchanged.
            new_state_dict[new_k] = old_weight

    return new_state_dict

def load_unstrictly(state_dict, model, loaded_keys=None):
    """Copy matching entries of ``state_dict`` into ``model``'s parameters, best effort.

    Only ``model.named_parameters()`` is walked; entries of ``state_dict``
    that match no parameter are ignored.

    Args:
        state_dict: mapping from parameter name to the value to copy in.
        model: module whose parameters are overwritten in place.
        loaded_keys: optional collection of parameter names that were already
            initialised elsewhere; such names are not reported as missing
            when they are absent from ``state_dict``.

    Returns:
        ``(model, missing_keys)`` where ``missing_keys`` lists, in parameter
        order, every parameter that was absent (and not in ``loaded_keys``)
        or whose copy failed (e.g. incompatible shape).
    """
    already_loaded = () if loaded_keys is None else loaded_keys
    missing_keys = []
    for name, param in model.named_parameters():
        if name in state_dict:
            try:
                param.data.copy_(state_dict[name])
            except Exception:
                # Deliberately tolerant: a failed copy is reported as a
                # missing key instead of aborting the whole load.
                missing_keys.append(name)
        elif name not in already_loaded:
            missing_keys.append(name)
    return model, missing_keys

def init_vae_only(state_dict, vae):
    """Load ``state_dict`` into ``vae`` via ``load_unstrictly`` and report what was missed.

    Missing parameter names that start with ``flux`` are left out of the
    printed report. Returns the module handed back by ``load_unstrictly``.
    """
    load_result = load_unstrictly(state_dict, vae)
    vae = load_result[0]
    reported = [key for key in load_result[1] if not key.startswith('flux')]
    print(f"missing keys in loading vae: {reported}")
    return vae

def init_image_disc(state_dict, image_disc, args):
    """Initialise the image discriminator from ``state_dict["image_disc"]``.

    Checkpoint entries named ``<x>.weight`` / ``<x>.bias`` whose counterpart
    ``<x>.norm.weight`` / ``<x>.norm.bias`` exists in the model are copied
    into that counterpart first (the original note reads "load nn.GroupNorm
    to Normalize class"). Everything else goes through a non-strict
    ``load_state_dict``.

    Args:
        state_dict: full checkpoint; only ``state_dict["image_disc"]`` is
            read, and it is left unmodified.
        image_disc: module to initialise in place.
        args: needs ``no_init_idis`` and ``init_idis``; initialisation is
            skipped when ``no_init_idis`` is truthy or ``init_idis == "no"``.

    Returns:
        ``image_disc``.
    """
    if args.no_init_idis or args.init_idis == "no":
        pending = {}
    else:
        # Work on a shallow copy so the caller's checkpoint keeps all entries.
        pending = state_dict["image_disc"].copy()

    # Built once; the tensors reference the module's own storage, so copy_
    # below writes straight into the model.
    target = image_disc.state_dict()

    loaded_keys = []
    for key in list(pending):
        for suffix in (".weight", ".bias"):
            if not key.endswith(suffix):
                continue
            norm_key = key.replace(suffix, ".norm" + suffix)
            if norm_key in target:
                target[norm_key].copy_(pending.pop(key))
                loaded_keys.append(norm_key)

    msg = image_disc.load_state_dict(pending, strict=False)
    print(f"image disc missing: {[key for key in msg.missing_keys if key not in loaded_keys]}")
    print(f"image disc unexpected: {msg.unexpected_keys}")
    return image_disc

def init_video_disc(state_dict, video_disc, args):
    """Initialise the video discriminator according to ``args.init_vdis``.

    ``"no"`` loads nothing, ``"keep"`` loads ``state_dict["video_disc"]`` as
    is, and any other value is forwarded to ``inflate_dis`` as its strategy.
    The load is non-strict; missing and unexpected keys are printed.
    Returns ``video_disc``.
    """
    mode = args.init_vdis
    if mode == "no":
        weights = {}
    elif mode == "keep":
        weights = state_dict["video_disc"]
    else:
        weights = inflate_dis(state_dict["video_disc"], strategy=mode)

    load_report = video_disc.load_state_dict(weights, strict=False)
    print(f"video disc missing: {load_report.missing_keys}")
    print(f"video disc unexpected: {load_report.unexpected_keys}")
    return video_disc

def init_vit_from_image(state_dict, vae, image_disc, video_disc, args):
    """Initialise the VAE and the image discriminator from an existing checkpoint.

    ``args.init_vgen`` selects how ``state_dict["vae"]`` is prepared:
    ``"no"`` removes the video patch-embedding / to-pixels entries in place,
    ``"keep"`` uses the dict unchanged, and any other value is passed to
    ``inflate_gen`` as the strategy. When ``args.vq_to_vae`` is truthy the
    ``pre_vq_conv.1`` entries are removed as well. The VAE load is non-strict
    and its missing / unexpected keys are printed.

    ``video_disc`` is returned exactly as it was passed in.

    Returns:
        ``(vae, image_disc, video_disc)``.
    """
    init_mode = args.init_vgen
    if init_mode in ("no", "keep"):
        vae_state_dict = state_dict["vae"]
        if init_mode == "no":
            # Entries removed here stay at the model's own initialisation.
            for dropped in (
                "encoder.to_patch_emb.1.weight",
                "encoder.to_patch_emb.1.bias",
                "encoder.to_patch_emb.2.weight",
                "encoder.to_patch_emb.2.bias",
                "encoder.to_patch_emb.3.weight",
                "encoder.to_patch_emb.3.bias",
                "decoder.to_pixels.0.weight",
                "decoder.to_pixels.0.bias",
            ):
                del vae_state_dict[dropped]
    else:
        vae_state_dict = inflate_gen(
            state_dict["vae"],
            temporal_patch_size=args.temporal_patch_size,
            spatial_patch_size=args.patch_size,
            strategy=init_mode,
            inflation_pe=args.inflation_pe,
        )

    if args.vq_to_vae:
        for dropped in ("pre_vq_conv.1.weight", "pre_vq_conv.1.bias"):
            del vae_state_dict[dropped]

    load_report = vae.load_state_dict(vae_state_dict, strict=False)
    print(f"vae missing: {load_report.missing_keys}")
    print(f"vae unexpected: {load_report.unexpected_keys}")

    image_disc = init_image_disc(state_dict, image_disc, args)
    # The video discriminator is intentionally not initialised here.

    return vae, image_disc, video_disc

def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy checkpoint entries under ``prefix`` into ``model``, bridging renamed layers.

    For every key starting with ``prefix`` the remainder of the key is tried
    against the model in three ways (all that match are applied):

    1. Same name. With ``use_linear`` the ``.q/.k/.v/.proj_out`` weights are
       squeezed first; with ``expand`` a ``.conv.weight`` whose shape differs
       from the model's is inflated along a new axis 2.
    2. Wrapped conv: ``<x>.weight`` / ``<x>.bias`` to ``<x>.conv.weight`` /
       ``<x>.conv.bias`` for keys containing one of the conv markers. Weights
       with a differing shape are inflated as above.
    3. Wrapped norm: ``<x>.weight`` / ``<x>.bias`` to ``<x>.norm.weight`` /
       ``<x>.norm.bias`` for keys containing ``"norm"``.

    Inflation repeats the tensor along the new axis to the model's size on
    that axis and divides by that size.

    Args:
        model: module whose parameters/buffers are overwritten in place.
        state_dict: checkpoint mapping; consumed entries are deleted from it.
        prefix: key prefix selecting the entries that belong to ``model``.
        expand: enable inflation for same-name ``.conv.weight`` entries.
        use_linear: treat attention projections as plain (squeezed) weights
            and restrict the conv markers to ``"conv"``.

    Returns:
        ``(model, state_dict, loaded_keys)`` where ``loaded_keys`` holds the
        prefixed model keys that were written.
    """
    # Built once instead of once per lookup; the tensors reference the
    # module's storage, so copy_ below updates the model itself.
    target = model.state_dict()

    attention_weights = (".q.weight", ".k.weight", ".v.weight", ".proj_out.weight")
    conv_list = ["conv"] if use_linear else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]

    delete_keys = []
    loaded_keys = []

    def _match_depth(weight, reference):
        # Same shape: use as is. Otherwise insert axis 2 and repeat/normalise.
        if reference.shape == weight.shape:
            return weight
        depth = reference.shape[2]
        return weight.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth

    def _assign(target_key, value, source_key):
        target[target_key].copy_(value)
        delete_keys.append(source_key)
        loaded_keys.append(prefix + target_key)

    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix):]
        is_weight = _key.endswith(".weight")
        is_bias = _key.endswith(".bias")

        # 1. Same name in checkpoint and model.
        if _key in target:
            if use_linear and any(marker in key for marker in attention_weights):
                _assign(_key, value.squeeze(), key)
            elif expand and _key.endswith(".conv.weight"):
                _assign(_key, _match_depth(value, target[_key]), key)
            else:
                _assign(_key, value, key)

        # 2. Plain conv in the checkpoint, wrapped conv in the model.
        if any(marker in _key for marker in conv_list):
            if is_weight:
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key in target:
                    _assign(conv_key, _match_depth(value, target[conv_key]), key)
            if is_bias:
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key in target:
                    _assign(conv_key, value, key)

        # 3. Plain norm in the checkpoint, wrapped norm in the model.
        if "norm" in _key:
            if is_weight:
                norm_key = _key.replace(".weight", ".norm.weight")
                if norm_key in target:
                    _assign(norm_key, value, key)
            if is_bias:
                norm_key = _key.replace(".bias", ".norm.bias")
                if norm_key in target:
                    _assign(norm_key, value, key)

    # A key consumed by more than one rule is listed repeatedly; delete once.
    for key in dict.fromkeys(delete_keys):
        del state_dict[key]

    return model, state_dict, loaded_keys

def init_cnn_from_image(state_dict, vae, image_disc, video_disc, args, expand=False):
    """Initialise a CNN VAE (and optionally the image discriminator) from a checkpoint.

    The encoder and decoder are filled by ``load_cnn`` with the prefixes
    ``"encoder."`` and ``"decoder."``; ``state_dict["vae"]`` is replaced by
    what ``load_cnn`` hands back. The remaining entries are then loaded into
    the whole VAE with ``load_unstrictly``, which is told about the keys
    ``load_cnn`` already wrote. ``image_disc`` is initialised only when it is
    truthy; ``video_disc`` is returned exactly as it was passed in.

    Returns:
        ``(vae, image_disc, video_disc)``.
    """
    encoder, remaining, encoder_keys = load_cnn(
        vae.encoder, state_dict["vae"], prefix="encoder.", expand=expand
    )
    vae.encoder = encoder
    state_dict["vae"] = remaining

    decoder, remaining, decoder_keys = load_cnn(
        vae.decoder, state_dict["vae"], prefix="decoder.", expand=expand
    )
    vae.decoder = decoder
    state_dict["vae"] = remaining

    # The list of missing keys returned here is not used.
    vae, _ = load_unstrictly(state_dict["vae"], vae, encoder_keys + decoder_keys)

    if image_disc:
        image_disc = init_image_disc(state_dict, image_disc, args)
    # The video discriminator is intentionally not initialised here.
    return vae, image_disc, video_disc

def resume_from_ckpt(state_dict, model_optims, load_optims=True):
    """Restore models, and optionally optimizer-like objects, from a checkpoint.

    Entries of ``model_optims`` for which ``is_torch_optim_sch`` is false are
    loaded first through ``load_unstrictly``. Entries for which it is true
    are restored with ``load_state_dict`` afterwards, but only when no weight
    was reported missing and ``load_optims`` is true. An entry is skipped
    when its object is falsy, or when the checkpoint has no (or a falsy)
    value under the same name.

    Args:
        state_dict: checkpoint mapping; must contain ``"step"``.
        model_optims: mapping from name to object; model entries are
            reassigned with what ``load_unstrictly`` returns.
        load_optims: set to false to skip the second phase.

    Returns:
        ``(model_optims, state_dict["step"])``.
    """

    def _restorable(name):
        # Membership is tested before indexing so entries that the checkpoint
        # does not contain are skipped rather than raising KeyError.
        return bool(model_optims[name]) and name in state_dict and bool(state_dict[name])

    all_missing_keys = []
    # load weights first
    for k in model_optims:
        if _restorable(k) and not is_torch_optim_sch(model_optims[k]):
            model_optims[k], missing_keys = load_unstrictly(state_dict[k], model_optims[k])
            all_missing_keys += missing_keys

    if not all_missing_keys and load_optims:
        print("Loading optimizer states")
        for k in model_optims:
            if _restorable(k) and is_torch_optim_sch(model_optims[k]):
                model_optims[k].load_state_dict(state_dict[k])
    else:
        print(f"missing weights: {all_missing_keys}, load_optims={load_optims}, do not load optimzer states")
    return model_optims, state_dict["step"]

### old version
# def get_last_ckpt(root_dir):
#     if not os.path.exists(root_dir): return None, None
#     ckpt_files = {}
#     for dirpath, dirnames, filenames in os.walk(root_dir):
#         for filename in filenames:
#             if filename.endswith('.ckpt'):
#                 num_iter = int(filename.split('-')[1].split('=')[1])
#                 ckpt_files[num_iter]=os.path.join(dirpath, filename)
#     iter_list = list(ckpt_files.keys())
#     if len(iter_list) == 0: return None, None
#     max_iter = max(iter_list)
#     return ckpt_files[max_iter], max_iter