# Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
# (video/image saving utilities; content recovered from a garbled page export)
import base64
import binascii
import copy
import glob
import gzip
import hashlib
import json
import math
import os
import os.path as osp
import pickle
import sys
import time
import urllib.request
import zipfile
from io import BytesIO
from multiprocessing.pool import ThreadPool as Pool

import cv2
import imageio
import numpy as np
import requests
import torch
import torch.nn.functional as F
import torchvision.utils as tvutils
from einops import rearrange
from PIL import Image, ImageDraw, ImageFont
# import skvideo.io
def gen_text_image(captions, text_size):
    """Render each caption as black text on a white square image.

    Args:
        captions: iterable of caption strings.
        text_size: side length in pixels of each rendered text image.

    Returns:
        uint8 torch tensor of shape (len(captions), text_size, text_size, 3).
    """
    # Characters per line scale with the image width; 38 chars fit at the
    # 256-px baseline. (The original computed 38 * (text_size / text_size),
    # which is always 38 for every text_size -- an evident typo.)
    num_char = int(38 * (text_size / 256))
    font_size = int(text_size / 20)
    font = ImageFont.truetype('data/font/DejaVuSans.ttf', size=font_size)
    text_image_list = []
    for text in captions:
        txt_img = Image.new("RGB", (text_size, text_size), color="white")
        draw = ImageDraw.Draw(txt_img)
        # Hard-wrap the caption every num_char characters.
        lines = "\n".join(text[start:start + num_char] for start in range(0, len(text), num_char))
        draw.text((0, 0), lines, fill="black", font=font)
        text_image_list.append(np.array(txt_img))
    return torch.from_numpy(np.stack(text_image_list, axis=0))
def save_video_refimg_and_text(
        local_path,
        ref_frame,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        nrow=4,
        save_fps=8,
        retry=5):
    '''
    Save the generated video tiled next to its reference frame and rendered
    caption text.

    gen_video: BxCxFxHxW normalized video; denormalized IN PLACE here.
    ref_frame: BxCxHxW normalized reference image; also denormalized in place.
    A single-frame result is written as `<local_path>.png`, otherwise frames
    are dumped to a temp dir and encoded with ffmpeg to `<local_path>.mp4`.
    Raises the last exception if all `retry` attempts fail.
    '''
    # NOTE: the `nrow` argument is deliberately overridden -- the grid always
    # uses half the batch size (at least 1) rows.
    nrow = max(int(gen_video.size(0) / 2), 1)
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    text_images = gen_text_image(captions, text_size)  # Tensor Bx256x256x3
    text_images = text_images.unsqueeze(1)  # Bx1x256x256x3
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)  # BxFx256x256x3
    ref_frame = ref_frame.unsqueeze(2)
    ref_frame = ref_frame.mul_(vid_std).add_(vid_mean)  # denormalize
    ref_frame = ref_frame.repeat_interleave(repeats=gen_video.size(2), dim=2)  # tile across frames
    ref_frame.clamp_(0, 1)
    ref_frame = ref_frame * 255.0
    ref_frame = rearrange(ref_frame, 'b c f h w -> b f h w c')
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # denormalize in place
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = torch.cat([ref_frame, images, text_images], dim=3)  # concat along width
    images = rearrange(images, '(r j) f h w c -> f (r h) (j w) c', r=nrow)
    images = [(img.numpy()).astype('uint8') for img in images]
    # Pick the output path ONCE, outside the retry loop. Appending the suffix
    # inside the loop (as before) produced 'x.png.png' on the second attempt.
    if len(images) == 1:
        local_path = local_path + '.png'
    else:
        local_path = local_path + '.mp4'
    exception = None
    for _ in range(retry):
        try:
            if len(images) == 1:
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
                os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
                cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
                os.system(cmd); os.system(f'rm -rf {frame_dir}')
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    # Surface the failure instead of silently swallowing it (consistent with
    # save_i2vgen_video).
    if exception is not None:
        raise exception
def save_i2vgen_video(
        local_path,
        image_id,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save both the generated video and the input conditions.

    image_id: BxCxHxW normalized conditioning image, tiled across all frames.
    gen_video: BxCxFxHxW normalized video; denormalized IN PLACE here.
    Only the first sample of the batch is written (frames dumped to a temp
    dir and encoded with ffmpeg into `local_path`).
    Raises the last exception if all `retry` attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    text_images = gen_text_image(captions, text_size)  # Tensor 1x256x256x3
    text_images = text_images.unsqueeze(1)  # Tensor 1x1x256x256x3
    text_images = text_images.repeat_interleave(repeats=gen_video.size(2), dim=1)  # tile per frame
    image_id = image_id.unsqueeze(2)  # B, C, F, H, W
    image_id = image_id.repeat_interleave(repeats=gen_video.size(2), dim=2)  # tile across frames
    image_id = image_id.mul_(vid_std).add_(vid_mean)  # denormalize
    image_id.clamp_(0, 1)
    image_id = image_id * 255.0
    image_id = rearrange(image_id, 'b c f h w -> b f h w c')
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # denormalize in place
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = torch.cat([image_id, images, text_images], dim=3)  # concat along width
    images = images[0]  # only the first batch element is saved
    images = [(img.numpy()).astype('uint8') for img in images]
    exception = None
    for _ in range(retry):
        try:
            frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
            os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
            for fid, frame in enumerate(images):
                tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
            os.system(cmd); os.system(f'rm -rf {frame_dir}')
            # Clear any exception recorded by an earlier failed attempt: a
            # late success must not re-raise a stale error after the loop
            # (the original kept the old exception and raised it anyway).
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_i2vgen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save only the generated video, do not save the related reference
    conditions, and at the same time perform anomaly detection on the last
    frame.

    gen_video: BxCxFxHxW normalized video; denormalized IN PLACE here. Only
    the first batch element is saved. `captions` and `text_size` are accepted
    for signature parity with the other savers but are unused here.
    Raises the last exception if all `retry` attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # denormalize in place
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]  # only the first batch element is saved
    images = [(img.numpy()).astype('uint8') for img in images]
    num_image = len(images)
    # Pick the output path ONCE, outside the retry loop; appending '.png'
    # inside the loop (as before) produced 'x.png.png' on the second attempt.
    if num_image == 1:
        local_path = local_path + '.png'
    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                # Context manager guarantees the writer is closed even when a
                # write raises mid-loop (the old code leaked it on error).
                with imageio.get_writer(local_path, fps=save_fps, codec='libx264', quality=8) as writer:
                    for fid, frame in enumerate(images):
                        if fid == num_image - 1:  # Fix known bugs: drop a mostly-gray final frame.
                            ratio = (np.sum((frame >= 117) & (frame <= 137))) / (frame.size)
                            if ratio > 0.4:
                                continue
                        writer.append_data(frame)
            # Clear any exception from earlier failed attempts so a late
            # success is not reported as a failure.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_t2vhigen_video_safe(
        local_path,
        gen_video,
        captions,
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
        text_size=256,
        retry=5,
        save_fps=8
):
    '''
    Save only the generated video, do not save the related reference
    conditions, and at the same time perform anomaly detection on the last
    frame.

    Like save_i2vgen_video_safe but encodes via per-frame PNG dump + ffmpeg
    instead of imageio. gen_video: BxCxFxHxW normalized; denormalized IN
    PLACE. Only the first batch element is saved; `captions`/`text_size`
    are unused, kept for signature parity.
    Raises the last exception if all `retry` attempts fail.
    '''
    vid_mean = torch.tensor(mean, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    vid_std = torch.tensor(std, device=gen_video.device).view(1, -1, 1, 1, 1)  # ncfhw
    gen_video = gen_video.mul_(vid_std).add_(vid_mean)  # denormalize in place
    gen_video.clamp_(0, 1)
    gen_video = gen_video * 255.0
    images = rearrange(gen_video, 'b c f h w -> b f h w c')
    images = images[0]  # only the first batch element is saved
    images = [(img.numpy()).astype('uint8') for img in images]
    num_image = len(images)
    # Pick the output path ONCE, outside the retry loop; appending '.png'
    # inside the loop (as before) produced 'x.png.png' on the second attempt.
    if num_image == 1:
        local_path = local_path + '.png'
    exception = None
    for _ in range(retry):
        try:
            if num_image == 1:
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                frame_dir = os.path.join(os.path.dirname(local_path), '%s_frames' % (os.path.basename(local_path)))
                os.system(f'rm -rf {frame_dir}'); os.makedirs(frame_dir, exist_ok=True)
                for fid, frame in enumerate(images):
                    if fid == num_image - 1:  # Fix known bugs: drop a mostly-gray final frame.
                        ratio = (np.sum((frame >= 117) & (frame <= 137))) / (frame.size)
                        if ratio > 0.4:
                            continue
                    tpth = os.path.join(frame_dir, '%04d.png' % (fid + 1))
                    cv2.imwrite(tpth, frame[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
                cmd = f'ffmpeg -y -f image2 -loglevel quiet -framerate {save_fps} -i {frame_dir}/%04d.png -vcodec libx264 -crf 17 -pix_fmt yuv420p {local_path}'
                os.system(cmd)
                os.system(f'rm -rf {frame_dir}')
            # Clear any exception from earlier failed attempts so a late
            # success is not reported as a failure.
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        raise exception
def save_video_multiple_conditions_not_gif_horizontal_3col(local_path, video_tensor, model_kwargs, source_imgs,
                                                           mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], nrow=8, retry=5, save_fps=8):
    """Save the generated video tiled horizontally next to every condition map.

    video_tensor: BxCxFxHxW normalized video; denormalized IN PLACE.
    model_kwargs: sequence whose first element maps condition names to
        BxCxFxHxW tensors (1/2/3 channels, or 4 = RGB + alpha mask).
    source_imgs: source clip, pooled to the video's F/H/W (pooled value is
        computed but not written out -- kept for interface compatibility).
    Output format is chosen by `local_path`'s suffix: .mp4 or .gif.
    On persistent failure the error is printed, not raised.
    """
    mean = torch.tensor(mean, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    std = torch.tensor(std, device=video_tensor.device).view(1, -1, 1, 1, 1)  # ncfhw
    video_tensor = video_tensor.mul_(std).add_(mean)  # unnormalize back to [0,1]
    video_tensor.clamp_(0, 1)
    b, c, n, h, w = video_tensor.shape
    source_imgs = F.adaptive_avg_pool3d(source_imgs, (n, h, w))
    source_imgs = source_imgs.cpu()
    model_kwargs_channel3 = {}
    for key, conditions in model_kwargs[0].items():
        # Normalize every condition to 3 channels at the video's F/H/W.
        # Single elif-chain: the original used a separate `if` for the
        # 1-channel case, so those tensors fell through into the `== 3`
        # branch and were pooled a second time (a redundant identity pool).
        if conditions.size(1) == 1:
            conditions = torch.cat([conditions, conditions, conditions], dim=1)
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        elif conditions.size(1) == 2:
            conditions = torch.cat([conditions, conditions[:, :1, ]], dim=1)
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        elif conditions.size(1) == 3:
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        elif conditions.size(1) == 4:  # means it is a mask.
            color = ((conditions[:, 0:3] + 1.) / 2.)  # map [-1,1] -> [0,1]
            alpha = conditions[:, 3:4]
            conditions = color * alpha + 1.0 * (1.0 - alpha)  # composite on white
            conditions = F.adaptive_avg_pool3d(conditions, (n, h, w))
        model_kwargs_channel3[key] = conditions.cpu() if conditions.is_cuda else conditions
    # Import once, not on every retry iteration.
    from pathlib import Path
    exception = None
    for _ in range(retry):
        try:
            vid_gif = rearrange(video_tensor, '(i j) c f h w -> c f (i h) (j w)', i=nrow)
            cons_list = [rearrange(con, '(i j) c f h w -> c f (i h) (j w)', i=nrow) for _, con in model_kwargs_channel3.items()]
            # Conditions on the left, generated video on the right.
            vid_gif = torch.cat(cons_list + [vid_gif, ], dim=3)
            vid_gif = vid_gif.permute(1, 2, 3, 0)  # c f h w -> f h w c
            images = vid_gif * 255.0
            images = [(img.numpy()).astype('uint8') for img in images]
            if len(images) == 1:
                local_path = local_path.replace('.mp4', '.png')
                cv2.imwrite(local_path, images[0][:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
            else:
                outputs = [Image.fromarray(frame) for frame in images]
                save_fmt = Path(local_path).suffix
                if save_fmt == ".mp4":
                    with imageio.get_writer(local_path, fps=save_fps) as writer:
                        for img in outputs:
                            writer.append_data(np.array(img))  # PIL Image -> numpy array
                elif save_fmt == ".gif":
                    outputs[0].save(
                        fp=local_path,
                        format="GIF",
                        append_images=outputs[1:],
                        save_all=True,
                        duration=(1 / save_fps * 1000),
                        loop=0,
                    )
                else:
                    raise ValueError("Unsupported file type. Use .mp4 or .gif.")
            exception = None
            break
        except Exception as e:
            exception = e
            continue
    if exception is not None:
        print('save video to {} failed, error: {}'.format(local_path, exception), flush=True)