# Source: physctrl/libs/sv3d/infer.py -- commit 4724018 ("update"), uploaded by chenwang
import os
import argparse
import rembg
import numpy as np
import math
import torch
import json
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from diffusers.utils import export_to_gif
from diffusers_sv3d import SV3DUNetSpatioTemporalConditionModel, StableVideo3DDiffusionPipeline
from kiui.cam import orbit_camera
# Hugging Face Hub repository id passed to every ``from_pretrained`` call below.
SV3D_DIFFUSERS = "chenguolin/sv3d-diffusers"
# os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
# Environment variables are passed around verbatim (no shell expansion), so
# resolve "~" here instead of relying on every consumer to expand it.
# NOTE(review): this assignment runs after transformers/diffusers have been
# imported -- confirm those libraries read HF_HOME late enough for it to matter.
os.environ["HF_HOME"] = os.path.expanduser("~/.cache/huggingface")
def construct_camera(azimuths_rad, elevation_rad, output_dir, res=576, radius=2, fov=33.8):
    """Write ``transforms_{train,val,test}.json`` camera files for an orbit.

    One frame entry is emitted per ``(elevation, azimuth)`` pair, in order.
    The frame count follows the length of the inputs instead of a fixed 21,
    so shorter orbits no longer raise ``IndexError``. If the two sequences
    differ in length, the extra entries of the longer one are ignored.

    Args:
        azimuths_rad: Per-frame azimuth angles, in radians.
        elevation_rad: Per-frame elevation angles, in radians.
        output_dir: Frame directory. The JSON files are written next to it,
            i.e. into ``{output_dir}/..``, which must already exist.
        res: Unused; kept so existing callers keep working.
        radius: Orbit radius forwarded to ``orbit_camera``.
        fov: Field of view in degrees; stored in radians as
            ``camera_angle_x``.
    """
    frames = [
        {
            "file_path": f"data/{index:03d}",
            "transform_matrix": orbit_camera(elevation, azimuth, radius, is_degree=False).tolist(),
        }
        for index, (elevation, azimuth) in enumerate(zip(elevation_rad, azimuths_rad))
    ]
    transforms = {"camera_angle_x": math.radians(fov), "frames": frames}

    # The same camera set is used for all three splits.
    for split in ("train", "val", "test"):
        with open(f"{output_dir}/../transforms_{split}.json", "w") as f:
            json.dump(transforms, f, indent=4)
def recenter(image, h_begin=100, w_begin=220, res=256):
    """Cut a ``res`` x ``res`` window out of ``image``, zero-filling the rest.

    The window's top-left corner sits at ``(h_begin, w_begin)`` in image
    coordinates. Offsets may be negative and the window may extend past the
    image border; every output pixel that falls outside the image is left
    at zero. The output always has shape ``(res, res, ...)`` and keeps the
    channel layout and dtype of the input.

    This replaces two separate branches that disagreed with each other: the
    paste branch used a wrong column slice (shape mismatch for negative
    ``w_begin``), ignored a positive offset when the other one was <= 0, and
    assumed a 4-channel image; the crop branch returned a smaller-than-``res``
    image when the window ran past the border.

    Args:
        image: Anything ``np.array`` can convert to an ``(H, W[, C])`` array,
            e.g. a PIL image.
        h_begin: Row of the window's top edge (may be negative).
        w_begin: Column of the window's left edge (may be negative).
        res: Side length of the square output.

    Returns:
        A PIL image built from the cropped/padded array.
    """
    pixels = np.array(image)
    h_image, w_image = pixels.shape[:2]

    canvas = np.zeros((res, res) + pixels.shape[2:], dtype=pixels.dtype)

    # Part of the requested window that actually overlaps the image.
    src_top, src_left = max(h_begin, 0), max(w_begin, 0)
    src_bottom = min(h_begin + res, h_image)
    src_right = min(w_begin + res, w_image)

    if src_bottom > src_top and src_right > src_left:
        # Same rectangle expressed in canvas (window-relative) coordinates.
        dst_top, dst_left = src_top - h_begin, src_left - w_begin
        dst_bottom = dst_top + (src_bottom - src_top)
        dst_right = dst_left + (src_right - src_left)
        canvas[dst_top:dst_bottom, dst_left:dst_right] = pixels[src_top:src_bottom, src_left:src_right]

    return Image.fromarray(canvas)
def main():
    """Generate a 20-frame orbit of one input image with SV3D and save it.

    Reads ``{base_dir}/{data_name}/{data_name}.png``, crops it with
    ``recenter``, removes the background with rembg, composites the result
    on white and feeds it to the SV3D pipeline on CUDA. Outputs:

    * ``{output_dir}/{data_name}/{data_name}_alpha.png`` -- background-removed input,
    * ``{output_dir}/{data_name}/data/animation.gif`` and ``NNN.png`` per frame,
    * four frames copied to ``../LGM/workspace_test/{data_name}_{0..3}.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-dir", default="../../data", type=str, help="Base dir")
    parser.add_argument("--output-dir", default="../../data", type=str, help="Output filepath")
    parser.add_argument("--data-name", default="chair", type=str, help="Data Name")
    parser.add_argument("--elevation", default=0, type=float, help="Camera elevation of the input image")
    parser.add_argument("--half-precision", action="store_true", help="Use fp16 half precision")
    parser.add_argument("--seed", default=-1, type=int, help="Random seed")
    args = parser.parse_args()

    image_path = f'{args.base_dir}/{args.data_name}/{args.data_name}.png'
    output_dir = f'{args.output_dir}/{args.data_name}/data'
    os.makedirs(output_dir, exist_ok=True)

    num_frames, sv3d_res = 20, 576
    elevations_deg = [args.elevation] * num_frames
    # Only needed by the optional construct_camera export below.
    elevations_rad = [np.deg2rad(e) for e in elevations_deg]
    polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
    # Evenly spaced azimuths, expressed relative to the last frame so that
    # the final frame ends up at azimuth 0.
    azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
    azimuths_rad = [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    # Slice assignment sorts in place; `azimuths_rad[:-1].sort()` only sorted
    # a temporary copy and had no effect.
    azimuths_rad[:-1] = sorted(azimuths_rad[:-1])
    # print(f"Elevation: {elevations_rad}")
    print(f"Azimuth: {np.rad2deg(azimuths_rad)}")
    # construct_camera(azimuths_rad, elevations_rad, output_dir=output_dir)

    bg_remover = rembg.new_session()
    unet = SV3DUNetSpatioTemporalConditionModel.from_pretrained(SV3D_DIFFUSERS, subfolder="unet")
    vae = AutoencoderKL.from_pretrained(SV3D_DIFFUSERS, subfolder="vae")
    scheduler = EulerDiscreteScheduler.from_pretrained(SV3D_DIFFUSERS, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(SV3D_DIFFUSERS, subfolder="image_encoder")
    feature_extractor = CLIPImageProcessor.from_pretrained(SV3D_DIFFUSERS, subfolder="feature_extractor")
    pipeline = StableVideo3DDiffusionPipeline(
        image_encoder=image_encoder, feature_extractor=feature_extractor,
        unet=unet, vae=vae,
        scheduler=scheduler,
    )
    pipeline = pipeline.to("cuda")

    # Autocast is switched on only for fp16 runs. Requesting float32 from
    # CUDA autocast is not a supported configuration, so the full-precision
    # path simply disables it.
    with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16, enabled=args.half_precision):
        h_begin, w_begin, res = 180, 190, 280
        # recenter() copies the pixels into an array, so the file can be
        # closed right after.
        with Image.open(image_path) as source_image:
            image = recenter(source_image, h_begin, w_begin, res)
        image = rembg.remove(image, session=bg_remover)  # [H, W, 4]
        image.save(f"{output_dir}/../{args.data_name}_alpha.png")

        bands = image.split()
        if len(bands) == 4:  # RGBA
            input_image = Image.new("RGB", image.size, (255, 255, 255))  # pure white bg
            input_image.paste(image, mask=bands[3])  # 3rd is the alpha channel
        else:
            input_image = image

        video_frames = pipeline(
            input_image.resize((sv3d_res, sv3d_res)),
            height=sv3d_res,
            width=sv3d_res,
            num_frames=num_frames,
            decode_chunk_size=8,  # smaller to save memory
            polars_rad=polars_rad,
            azimuths_rad=azimuths_rad,
            generator=torch.manual_seed(args.seed) if args.seed >= 0 else None,
        ).frames[0]

    export_to_gif(video_frames, f"{output_dir}/animation.gif", fps=7)
    for i, frame in enumerate(video_frames):
        # frame = frame.resize((res, res))
        frame.save(f"{output_dir}/{i:03d}.png")

    # Four views a quarter turn apart (frames at azimuth 0, 90, 180, 270
    # degrees for the 20-frame orbit above). Create the directory first so
    # a fresh checkout does not fail on save.
    lgm_dir = "../LGM/workspace_test"
    os.makedirs(lgm_dir, exist_ok=True)
    for view_index, frame_index in enumerate((19, 4, 9, 14)):
        video_frames[frame_index].save(f"{lgm_dir}/{args.data_name}_{view_index}.png")


if __name__ == "__main__":
    main()