import os
import torchvision
from PIL import Image, ImageDraw
import imageio
import cv2
import torch
import torch.nn.functional as F
import numpy as np
import zipfile
_gauss_mask_cache = {}
def load_gauss_mask(mask_path):
if not mask_path:
return None
abs_path = os.path.abspath(mask_path)
mask = _gauss_mask_cache.get(abs_path)
if mask is None:
mask = torch.load(abs_path, weights_only=False, map_location="cpu")
if not torch.is_tensor(mask):
mask = torch.tensor(mask)
_gauss_mask_cache[abs_path] = mask
return mask
def apply_alpha_shift(latents, gauss_mask, shift_mean):
    """Shift latent values in the background region defined by a Gaussian mask.

    Args:
        latents: 5-D latent tensor, assumed (N, C, T, H, W).
        gauss_mask: mask in [0, 1]; 1 leaves the latent untouched, 0 applies
            the full shift. Accepted ranks: 3-D (T, H, W), 4-D, or 5-D.
            ``None`` (or an unsupported rank) returns ``latents`` unchanged.
        shift_mean: scalar (or tensor broadcastable to ``latents``) added to
            the background region.

    Returns:
        ``latents + (1 - mask) * shift_mean``, same shape as ``latents``.
    """
    if gauss_mask is None:
        return latents
    mask = gauss_mask
    if mask.ndim == 3:
        # (T, H, W) -> (1, 1, T, H, W)
        mask = mask.unsqueeze(0).unsqueeze(0)
    elif mask.ndim == 4:
        # Promote to exactly 5-D. BUGFIX: the original ran both unsqueezes for
        # a (C, T, H, W) mask with C != 1, producing a 6-D tensor that broke
        # trilinear interpolation and broadcast the output to the wrong shape.
        if mask.shape[0] != 1:
            mask = mask.unsqueeze(0)   # (C, T, H, W) -> (1, C, T, H, W)
        elif mask.shape[1] != 1:
            mask = mask.unsqueeze(1)   # (1, T, H, W) -> (1, 1, T, H, W)
        else:
            # (1, 1, H, W): treat as a single-frame mask; interpolation below
            # can then stretch it across the temporal axis.
            mask = mask.unsqueeze(0)   # -> (1, 1, 1, H, W)
    elif mask.ndim != 5:
        # Unsupported rank: leave latents untouched (matches original behavior).
        return latents
    mask = mask.to(device=latents.device, dtype=latents.dtype)
    target_shape = latents.shape[2:]
    if mask.shape[-3:] != target_shape:
        # Resize the mask's (T, H, W) extent to match the latents.
        mask = F.interpolate(mask, size=target_shape, mode="trilinear", align_corners=False)
    shift_mean = torch.as_tensor(shift_mean, dtype=latents.dtype, device=latents.device)
    # Background (mask -> 0) receives the full shift; foreground (mask -> 1) none.
    return latents + (1.0 - mask) * shift_mean
def render_video(tensor_fgr,
                 tensor_pha,
                 nrow=8,
                 normalize=True,
                 value_range=(-1, 1)):
    """Composite foreground + alpha frames over a checkerboard background.

    Args:
        tensor_fgr: foreground video — either a torch tensor (expected
            5-D, unbound along dim 2 as time) or a list of HxWxC uint8
            numpy arrays (converted by ``to_tensor`` below).
        tensor_pha: alpha-matte video in the same two accepted formats.
        nrow: grid width forwarded to ``torchvision.utils.make_grid``.
        normalize: forwarded to ``make_grid``.
        value_range: value range for clamping and ``make_grid`` scaling.

    Returns:
        Tuple ``(video_checkerboard, frames)``:
        - ``video_checkerboard``: float tensor (3, T, H, W) in [-1, 1] of
          the foreground blended over a checkerboard via the alpha matte;
        - ``frames``: list of HxWx4 uint8 numpy arrays (BGRA: foreground
          channels reversed to BGR plus the averaged alpha channel).
    """
    def to_tensor(arr_list):
        # uint8 [0, 255] -> float [-1, 1]; stack to (T, H, W, C) then
        # permute/unsqueeze to (1, C, T, H, W) for make_grid below.
        tensor_list= [torch.from_numpy(arr).float().div_(127.5).sub_(1) for arr in arr_list]
        tensor_list = torch.stack(tensor_list, dim = 0).permute(3,0,1,2).unsqueeze(0)
        return tensor_list
    if not torch.is_tensor(tensor_fgr):
        tensor_fgr = to_tensor(tensor_fgr)
    if not torch.is_tensor(tensor_pha):
        tensor_pha = to_tensor(tensor_pha)
    # Per-frame grid of the foreground, back to uint8 HxWxC on CPU.
    tensor_fgr = tensor_fgr.clamp(min(value_range), max(value_range))
    tensor_fgr = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in tensor_fgr.unbind(2)
    ],
                             dim=1).permute(1, 2, 3, 0)
    tensor_fgr = (tensor_fgr * 255).type(torch.uint8).cpu()
    # Same grid treatment for the alpha matte.
    tensor_pha = tensor_pha.clamp(min(value_range), max(value_range))
    tensor_pha = torch.stack([
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in tensor_pha.unbind(2)
    ],
                             dim=1).permute(1, 2, 3, 0)
    tensor_pha = (tensor_pha * 255).type(torch.uint8).cpu()
    frames = []
    frames_fgr = []
    frames_pha = []
    for frame_fgr, frame_pha in zip(tensor_fgr.numpy(), tensor_pha.numpy()):
        if frame_pha.shape[-1] == 1:
            # NOTE(review): this branch drops the channel axis (H, W), so the
            # axis=2 concatenate below would fail; in practice make_grid emits
            # 3 channels, so only the else branch runs — confirm before relying
            # on single-channel input here.
            frame_pha = frame_pha[:,:,0]
        else:
            # Average the 3 replicated alpha channels into one (H, W, 1) plane.
            frame_pha = (0.0 + frame_pha[:,:,0:1] + frame_pha[:,:,1:2] + frame_pha[:,:,2:3]) / 3.
        # RGB -> BGR, then append alpha: BGRA frame for cv2-style consumers.
        frame = np.concatenate([frame_fgr[:,:,::-1], frame_pha.astype(np.uint8)], axis=2)
        frames.append(frame)
        frames_fgr.append(frame_fgr)
        frames_pha.append(frame_pha)
    def create_checkerboard(size=30, pattern_size=(830, 480), color1=(140, 140, 140), color2=(113, 113, 113)):
        # Gray checkerboard PIL image; `size` is the tile edge in pixels.
        img = Image.new('RGB', (pattern_size[0], pattern_size[1]), color1)
        draw = ImageDraw.Draw(img)
        for i in range(0, pattern_size[0], size):
            for j in range(0, pattern_size[1], size):
                if (i + j) // size % 2 == 0:
                    draw.rectangle([i, j, i+size, j+size], fill=color2)
        return img
    def blender_background(frame_rgba, checkerboard):
        # Alpha-blend the frame's color over the (resized) checkerboard,
        # then reverse channels for the final output ordering.
        alpha_channel = frame_rgba[:, :, 3:] / 255.
        checkerboard = np.array(checkerboard)
        checkerboard = cv2.resize(checkerboard, (frame_rgba.shape[1], frame_rgba.shape[0]))
        frame_rgb = frame_rgba[:, :, :3] * alpha_channel + checkerboard * (1-alpha_channel)
        return frame_rgb.astype(np.uint8)[:,:,::-1]
    checkerboard = create_checkerboard()
    # Blend every BGRA frame, renormalize to [-1, 1], stack to (3, T, H, W).
    video_checkerboard = [torch.from_numpy(blender_background(f, checkerboard).copy()).float().div_(127.5).sub_(1) for f in frames]
    video_checkerboard = torch.stack(video_checkerboard ).permute(3, 0, 1, 2)
    return video_checkerboard, frames
def from_BRGA_numpy_to_RGBA_torch(video):
    """Convert a list of BGRA uint8 numpy frames into an RGBA torch video.

    Args:
        video: list of HxWx4 uint8 arrays in BGRA channel order.

    Returns:
        Float tensor of shape (4, T, H, W) normalized to [-1, 1], with the
        blue/red channels swapped so the order is RGBA.
    """
    normalized = [torch.from_numpy(frame.copy()).float() / 127.5 - 1.0 for frame in video]
    stacked = torch.stack(normalized).permute(3, 0, 1, 2)
    # Swap channel 0 (B) with channel 2 (R): BGRA -> RGBA.
    stacked[[0, 2], ...] = stacked[[2, 0], ...]
    return stacked
def write_zip_file(zip_path, frames):
    """PNG-encode each frame and store them all in one deflated zip archive.

    Args:
        zip_path: destination path for the zip file.
        frames: iterable of image arrays in BGRA format (cv2 channel order).

    Frames that fail to encode are reported and skipped rather than
    aborting the whole archive.
    """
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for index, frame in enumerate(frames):
            ok, encoded = cv2.imencode(".png", frame)
            if not ok:
                print(f"Failed to encode image {index}, skipping...")
                continue
            archive.writestr(f"img_{index:03d}.png", encoded.tobytes())