from torchmetrics import MetricCollection
from svd_pipeline import StableVideoDiffusionPipeline
from accelerate.logging import get_logger
import os
from utils import load_image
import torch
import numpy as np
import videoio
import torchmetrics.image
import matplotlib.image
from PIL import Image

logger = get_logger(__name__, log_level="INFO")


def valid_net(args, val_dataset, val_dataloader, unet, image_encoder, vae, zero, accelerator, global_step, weight_dtype):
    """Run validation: generate videos with a StableVideoDiffusionPipeline and save them to disk.

    Builds a pipeline around the (possibly fine-tuned) ``unet`` / ``image_encoder`` /
    ``vae``, runs it over every batch in ``val_dataloader``, and writes each generated
    clip as an .mp4 (plus per-frame .png images when ``args.test`` is set) under
    ``<args.output_dir>/validation_images/position_<focal_stack_num>/``.

    Args:
        args: parsed config namespace; fields read here include
            ``pretrained_model_name_or_path``, ``revision``, ``output_dir``,
            ``num_frames``, ``num_validation_images``, ``conditioning``,
            ``reconstruction_guidance``, ``test``, and ``photos``.
        val_dataset: not used in this function body — NOTE(review): kept for
            interface compatibility; confirm callers still need to pass it.
        val_dataloader: yields dicts with keys ``"pixel_values"``,
            ``"original_pixel_values"``, ``"idx"`` and, optionally,
            ``"focal_stack_num"`` and ``"icc_profile"``.
        unet: denoising UNet; switched to eval mode here (NOTE(review): it is
            never switched back to train mode in this function — confirm the
            caller restores training mode after validation).
        image_encoder: image encoder handed to the pipeline.
        vae: VAE handed to the pipeline.
        zero: opaque value forwarded to the pipeline call; semantics not
            visible from this file.
        accelerator: HF Accelerate object used for device placement, autocast
            device naming, and cross-process barriers.
        global_step: current training step, used only in output file names.
        weight_dtype: torch dtype for the pipeline weights.

    Returns:
        None; all output is written to disk.
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} videos."
    )

    # The models need unwrapping because for compatibility in distributed training mode.

    # Assemble an inference pipeline that reuses the in-memory training modules
    # rather than reloading weights from disk (only config/scheduler come from
    # the pretrained path).
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=unet,
        image_encoder=image_encoder,
        vae=vae,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )

    pipeline.set_progress_bar_config(disable=True)

    # run inference
    val_save_dir = os.path.join(
        args.output_dir, "validation_images")

    print("Validation images will be saved to ", val_save_dir)

    os.makedirs(val_save_dir, exist_ok=True)


    num_frames = args.num_frames
    unet.eval()
    # autocast wants a bare device type ("cuda"), while accelerator.device may be
    # e.g. "cuda:0" — hence the ":0" strip. Mixed precision only under fp16.
    with torch.autocast(
        str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
    ):
        for batch in val_dataloader:
            #clear gradients (the torch no grad is the magic that makes this work)
            with torch.no_grad():
                torch.cuda.empty_cache()

            pixel_values = batch["pixel_values"].to(accelerator.device)
            original_pixel_values = batch['original_pixel_values'].to(accelerator.device)
            idx = batch["idx"].to(accelerator.device)
            # focal_stack_num selects the output subfolder and (for the
            # "ablate_time" conditioning mode) the motion bucket id.
            if "focal_stack_num" in batch:
                focal_stack_num = batch["focal_stack_num"][0].item()
            else:
                focal_stack_num = None

            # Project pipeline call — returns (svd_output, gt_frames); the extra
            # kwargs (accelerator, zero, conditioning, reconstruction guidance)
            # are custom extensions of this repo's pipeline, semantics not
            # visible from this file.
            svd_output, gt_frames = pipeline(
                pixel_values,
                height=pixel_values.shape[3],
                width=pixel_values.shape[4],
                num_frames=args.num_frames,
                decode_chunk_size=8,
                motion_bucket_id=0 if args.conditioning != "ablate_time" else focal_stack_num,
                min_guidance_scale=1.5,
                max_guidance_scale=1.5,
                reconstruction_guidance_scale=args.reconstruction_guidance,
                fps=7,
                noise_aug_strength=0,
                accelerator=accelerator,
                weight_dtype=weight_dtype,
                conditioning = args.conditioning,
                focal_stack_num = focal_stack_num,
                zero=zero
                # generator=generator,
            )
            # Batch size is assumed to be 1 throughout the post-processing below.
            video_frames = svd_output.frames[0]
            gt_frames = gt_frames[0]


            with torch.no_grad():

                if args.num_frames == 10:
                    #remove a frame at end from video_frames and gt_frames
                    # NOTE(review): presumably the 10th frame is padding in this
                    # configuration — confirm against the dataset code.
                    video_frames = video_frames[:, :-1]
                    gt_frames = gt_frames[:, :-1]
                    original_pixel_values = original_pixel_values[:, :-1]
    
                if len(original_pixel_values.shape) == 5:
                    pixel_values = original_pixel_values[0] #assuming batch size is 1
                else:
                    # Single conditioning image: tile it so it matches the
                    # per-frame layout of the generated video.
                    pixel_values = original_pixel_values.repeat(num_frames, 1, 1, 1)
                # Map from [-1, 1] model space back to [0, 1] image space.
                pixel_values_normalized = pixel_values*0.5 + 0.5
                pixel_values_normalized = torch.clamp(pixel_values_normalized,0,1)




                video_frames_normalized = video_frames*0.5 + 0.5
                video_frames_normalized = torch.clamp(video_frames_normalized,0,1)
                # (C, T, H, W) -> (T, C, H, W) so frames are the leading dim.
                video_frames_normalized = video_frames_normalized.permute(1,0,2,3)


                # NOTE(review): gt_frames is clamped but NOT rescaled with
                # *0.5 + 0.5 like the other tensors — confirm the pipeline
                # already returns ground truth in [0, 1].
                gt_frames = torch.clamp(gt_frames,0,1)
                gt_frames = gt_frames.permute(1,0,2,3)

                #RESIZE images 
                # Target dims are rounded down to even numbers ((x//2)*2) —
                # presumably because the mp4 encoder requires even H/W.
                video_frames_normalized = torch.nn.functional.interpolate(video_frames_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
                gt_frames = torch.nn.functional.interpolate(gt_frames, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')
                pixel_values_normalized = torch.nn.functional.interpolate(pixel_values_normalized, ((pixel_values.shape[2]//2)*2, (pixel_values.shape[3]//2)*2), mode='bilinear')

                # Save the generated clip as mp4; videoio expects (T, H, W, C)
                # on CPU, hence the permute + .cpu().numpy().
                os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/videos"), exist_ok=True)
                videoio.videosave(os.path.join(
                    val_save_dir,
                    f"position_{focal_stack_num}/videos/step_{global_step}_val_img_{idx[0].item()}.mp4",
                ), video_frames_normalized.permute(0,2,3,1).cpu().numpy(), fps=5)

                if args.test:
                    #save images
                    os.makedirs(os.path.join(val_save_dir, f"position_{focal_stack_num}/images"), exist_ok=True)
                    if not args.photos:
                        # matplotlib path: saves float [0,1] HWC arrays directly.
                        for i in range(num_frames):
                            matplotlib.image.imsave(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/img_{idx[0].item()}_frame_{i}.png"), video_frames_normalized[i].permute(1,2,0).cpu().numpy())
                    else:
                        for i in range(num_frames):
                            #use Pillow to save images
                            img = Image.fromarray((video_frames_normalized[i].permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
                            #use index to assign icc profile to img
                            # Attach the source ICC color profile when the
                            # dataset provides one ("none" is the sentinel).
                            if batch['icc_profile'][0] != "none":
                                img.info['icc_profile'] = batch['icc_profile'][0]
                            img.save(os.path.join(val_save_dir, f"position_{focal_stack_num}/images/img_{idx[0].item()}_frame_{i}.png"))
            # Drop the reference before the next batch to reduce peak GPU memory.
            del video_frames

    # Synchronize all processes before tearing down the pipeline.
    accelerator.wait_for_everyone()

    #clear gradients (the torch no grad is the magic that makes this work)
    with torch.no_grad():
        torch.cuda.empty_cache()
        
    del pipeline

    accelerator.wait_for_everyone() #this is really important and we need to make sure everyone is leaving at the same time