File size: 5,047 Bytes
31112ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
import torchvision
from PIL import Image, ImageDraw
import imageio
import cv2
import torch
import torch.nn.functional as F
import numpy as np 
import zipfile

# Process-wide cache of loaded masks, keyed by absolute file path.
_gauss_mask_cache = {}


def load_gauss_mask(mask_path):
    """Load a Gaussian mask tensor from disk, memoizing by absolute path.

    Returns None for an empty/falsy path. Objects that are not already
    tensors are wrapped with torch.tensor(...). Repeated calls with the
    same path return the exact cached tensor object.
    """
    if not mask_path:
        return None
    key = os.path.abspath(mask_path)
    cached = _gauss_mask_cache.get(key)
    if cached is not None:
        return cached
    # NOTE(review): weights_only=False deserializes arbitrary pickled
    # objects — only load mask files from trusted sources.
    loaded = torch.load(key, weights_only=False, map_location="cpu")
    if not torch.is_tensor(loaded):
        loaded = torch.tensor(loaded)
    _gauss_mask_cache[key] = loaded
    return loaded


def apply_alpha_shift(latents, gauss_mask, shift_mean):
    """Shift latent values outside the (soft) mask by `shift_mean`.

    Args:
        latents: 5D tensor, presumably (N, C, T, H, W) — spatial-temporal
            target shape is taken from latents.shape[2:].
        gauss_mask: soft mask in [0, 1]; 3D (T, H, W), 4D, or 5D tensor.
            None (or an unsupported rank) is a no-op.
        shift_mean: scalar (or tensor broadcastable to latents) added
            where the mask is 0, scaled by (1 - mask).

    Returns:
        latents + (1 - mask) * shift_mean, with the mask moved to the
        latents' device/dtype and trilinearly resized when its last three
        dims do not match latents.shape[2:].
    """
    if gauss_mask is None:
        return latents

    mask = gauss_mask
    if mask.ndim == 3:
        # (T, H, W) -> (1, 1, T, H, W)
        mask = mask.unsqueeze(0).unsqueeze(0)
    elif mask.ndim == 4:
        # Add exactly ONE leading dim to reach 5D. The previous version
        # could unsqueeze twice when both leading dims were != 1, yielding
        # a 6D tensor that breaks trilinear F.interpolate (needs 5D input)
        # and silently broadcast the result to 6D; it also left a
        # (1, 1, H, W) mask at 4D.
        mask = mask.unsqueeze(0) if mask.shape[0] != 1 else mask.unsqueeze(1)
    elif mask.ndim != 5:
        # Unsupported rank: leave the latents untouched, matching the
        # original best-effort behavior.
        return latents

    mask = mask.to(device=latents.device, dtype=latents.dtype)
    target_shape = latents.shape[2:]
    if mask.shape[-3:] != target_shape:
        mask = F.interpolate(mask, size=target_shape, mode="trilinear", align_corners=False)
    shift_mean = torch.as_tensor(shift_mean, dtype=latents.dtype, device=latents.device)
    return latents + (1.0 - mask) * shift_mean

def render_video(tensor_fgr,
                tensor_pha,
                nrow=8,
                normalize=True,
                value_range=(-1, 1)):
    """Composite a foreground video with its alpha matte over a checkerboard.

    Args:
        tensor_fgr: foreground video — either a torch tensor (assumed
            (N, C, T, H, W) in `value_range` — TODO confirm against callers)
            or a list of H x W x C uint8 numpy frames, converted via
            `to_tensor` below.
        tensor_pha: alpha matte video in the same two accepted forms.
        nrow / normalize / value_range: forwarded to
            torchvision.utils.make_grid for per-frame grid layout.

    Returns:
        video_checkerboard: float tensor (3, T, H, W) in [-1, 1] — the
            foreground alpha-blended over a checkerboard background.
        frames: list of H x W x 4 uint8-ish numpy arrays in BGR + alpha
            order (channel flip happens in the loop below).
    """
    def to_tensor(arr_list):
        # uint8 [0, 255] -> float [-1, 1]; stack (T, H, W, C) then
        # permute/unsqueeze to (1, C, T, H, W).
        tensor_list= [torch.from_numpy(arr).float().div_(127.5).sub_(1) for arr in arr_list]
        tensor_list = torch.stack(tensor_list, dim = 0).permute(3,0,1,2).unsqueeze(0)
        return tensor_list
                
    if not torch.is_tensor(tensor_fgr):
        tensor_fgr = to_tensor(tensor_fgr)
    if not torch.is_tensor(tensor_pha):
        tensor_pha = to_tensor(tensor_pha)

    # Per-frame make_grid over the time axis (dim 2), restacked to
    # (T, H', W', 3) and quantized to uint8.
    tensor_fgr = tensor_fgr.clamp(min(value_range), max(value_range))
    tensor_fgr = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in tensor_fgr.unbind(2)
    ],
                            dim=1).permute(1, 2, 3, 0)
    tensor_fgr = (tensor_fgr * 255).type(torch.uint8).cpu()

    # Same grid treatment for the alpha matte.
    tensor_pha = tensor_pha.clamp(min(value_range), max(value_range))
    tensor_pha = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in tensor_pha.unbind(2)
    ],
                            dim=1).permute(1, 2, 3, 0)
    tensor_pha = (tensor_pha * 255).type(torch.uint8).cpu()

    frames = []
    # NOTE(review): frames_fgr / frames_pha are filled but never used or
    # returned — candidates for removal, kept here to avoid behavior drift.
    frames_fgr = []
    frames_pha = []
    for frame_fgr, frame_pha in zip(tensor_fgr.numpy(), tensor_pha.numpy()):
        if frame_pha.shape[-1] == 1:
            frame_pha = frame_pha[:,:,0]
        else:
            # Collapse RGB alpha to a single channel by averaging; the
            # 0.0 + ... promotes to float before the division.
            frame_pha = (0.0 + frame_pha[:,:,0:1] + frame_pha[:,:,1:2] + frame_pha[:,:,2:3]) / 3.
        # [:,:,::-1] flips RGB -> BGR; result is an H x W x 4 BGRA frame.
        frame = np.concatenate([frame_fgr[:,:,::-1], frame_pha.astype(np.uint8)], axis=2)
        frames.append(frame)
        frames_fgr.append(frame_fgr)
        frames_pha.append(frame_pha)

    def create_checkerboard(size=30, pattern_size=(830, 480), color1=(140, 140, 140), color2=(113, 113, 113)):
        # Two-tone checkerboard PIL image; `size` is the square edge in px.
        img = Image.new('RGB', (pattern_size[0], pattern_size[1]), color1)
        draw = ImageDraw.Draw(img)
        for i in range(0, pattern_size[0], size):
            for j in range(0, pattern_size[1], size):
                if (i + j) // size % 2 == 0:
                    draw.rectangle([i, j, i+size, j+size], fill=color2)
        return img

    def blender_background(frame_rgba, checkerboard):
        # Alpha-blend a BGRA frame over the checkerboard, then flip
        # channel order on return ([:,:,::-1]).
        alpha_channel = frame_rgba[:, :, 3:] / 255. 
        checkerboard = np.array(checkerboard)
        checkerboard = cv2.resize(checkerboard, (frame_rgba.shape[1], frame_rgba.shape[0]))

        frame_rgb = frame_rgba[:, :, :3] * alpha_channel + checkerboard * (1-alpha_channel)
        return frame_rgb.astype(np.uint8)[:,:,::-1]
    
    checkerboard = create_checkerboard()
    # Re-normalize blended uint8 frames to [-1, 1] and pack as (3, T, H, W).
    video_checkerboard = [torch.from_numpy(blender_background(f, checkerboard).copy()).float().div_(127.5).sub_(1) for f in frames]
    video_checkerboard = torch.stack(video_checkerboard ).permute(3, 0, 1, 2)
    return video_checkerboard, frames

def from_BRGA_numpy_to_RGBA_torch(video):
    """Convert a sequence of BGRA uint8 numpy frames (H, W, 4) into one
    RGBA torch tensor of shape (4, T, H, W), rescaled to [-1, 1]."""
    converted = []
    for frame in video:
        # uint8 [0, 255] -> float [-1, 1]; copy() detaches from the
        # caller's buffer before torch takes ownership.
        converted.append(torch.from_numpy(frame.copy()).float().div_(127.5).sub_(1))
    # (T, H, W, C) -> (C, T, H, W)
    stacked = torch.stack(converted).permute(3, 0, 1, 2)
    # Swap the blue and red channel planes in place (BGRA -> RGBA).
    stacked[[0, 2], ...] = stacked[[2, 0], ...]
    return stacked

def write_zip_file(zip_path, frames):
    """Encode each BGRA frame as a PNG and store it in a deflated ZIP.

    Entries are named img_000.png, img_001.png, ... in frame order.
    Frames that fail to encode are skipped with a console message.
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for idx, frame in enumerate(frames):
            ok, encoded = cv2.imencode(".png", frame)
            if not ok:
                print(f"Failed to encode image {idx}, skipping...")
                continue
            archive.writestr(f"img_{idx:03d}.png", encoded.tobytes())