| import os |
| import numpy as np |
| import torch |
| from einops import rearrange |
| from imageio import imwrite |
| from pydantic import validator |
| import imageio |
| import tempfile |
| import gradio as gr |
|
|
| from PIL import Image |
|
|
| from my.utils import ( |
| tqdm, EventStorage, HeartBeat, EarlyLoopBreak, |
| get_event_storage, get_heartbeat, read_stats |
| ) |
| from my.config import BaseConf, dispatch, optional_load_config |
| from my.utils.seed import seed_everything |
|
|
| from adapt import ScoreAdapter |
| from run_img_sampling import SD |
| from misc import torch_samps_to_imgs |
| from pose import PoseConfig |
|
|
| from run_nerf import VoxConfig |
| from voxnerf.utils import every |
| from voxnerf.render import ( |
| as_torch_tsrs, rays_from_img, ray_box_intersect, render_ray_bundle |
| ) |
| from voxnerf.vis import stitch_vis, bad_vis as nerf_vis |
|
|
| from pytorch3d.renderer import PointsRasterizationSettings |
|
|
| from semantic_coding import semantic_coding, semantic_karlo, semantic_sd |
| from pc_project import point_e, render_depth_from_cloud |
# Module-wide torch device: the voxel model is moved onto it, and it is passed
# to the pose sampler and to the point-cloud depth renderer below.
device_glb = torch.device("cuda")
|
|
def tsr_stats(tsr):
    """Summarise a tensor as plain Python scalars.

    Args:
        tsr: Object exposing ``mean()``, ``std()`` and ``max()`` reductions
            whose results support ``.item()`` (e.g. a ``torch.Tensor``).

    Returns:
        dict with the keys ``"mean"``, ``"std"`` and ``"max"``, in that order.
    """
    reductions = ("mean", "std", "max")
    return {name: getattr(tsr, name)().item() for name in reductions}
|
|
class SJC_3DFuse(BaseConf):
    """Configuration and Gradio entry point for the 3DFuse demo.

    The scalar fields are forwarded as keyword arguments to ``fuse_3d``;
    ``sd``, ``vox`` and ``pose`` are nested configs whose ``make()`` builds the
    score model, the voxel model and the pose sampler respectively.
    """

    # Name of the attribute holding the score-model config; resolved with
    # getattr(self, family) in run_gradio.
    family: str = "sd"
    sd: SD = SD(
        variant="v1",
        prompt="a comfortable bed",
        scale=100.0,
        dir="./results",
        alpha=0.3,
    )
    lr: float = 0.05
    n_steps: int = 10000
    vox: VoxConfig = VoxConfig(
        model_type="V_SD", grid_size=100, density_shift=-1.0, c=3,
        blend_bg_texture=False, bg_texture_hw=4,
        bbox_len=1.0,
    )
    pose: PoseConfig = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)

    emptiness_scale: int = 10
    # Annotated as float so the annotation agrees with the float default (1e4).
    emptiness_weight: float = 1e4
    emptiness_step: float = 0.5
    emptiness_multiplier: float = 20.0

    depth_weight: int = 0

    var_red: bool = True
    exp_dir: str = "./results"
    ti_step: int = 800
    pt_step: int = 800
    initial: str = ""
    random_seed: int = 0
    semantic_model: str = "Karlo"
    bg_preprocess: bool = True
    num_initial_image: int = 4

    @validator("vox")
    def check_vox(cls, vox_cfg, values):
        """Set the voxel channel count to 4 when ``family`` is ``"sd"``."""
        # .get() instead of indexing: if ``family`` is absent from ``values``
        # the validator should not fail with an unrelated KeyError.
        if values.get("family") == "sd":
            vox_cfg.c = 4
        return vox_cfg

    def run(self):
        """CLI entry point; not supported in this demo build.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("This version is for huggingface demo, which doesn't support CLI. Please visit https://github.com/KU-CVLAB/3DFuse")

    def run_gradio(self, points, images):
        """Run the demo pipeline as a generator of Gradio output updates.

        Args:
            points: Forwarded to ``fuse_3d`` as ``points`` (used there for
                point-cloud depth rendering).
            images: Passed to ``semantic_coding`` together with the config.

        Yields:
            First a 3-tuple announcing the LoRA tuning, then everything yielded
            by ``fuse_3d``.
        """
        cfgs = self.dict()
        initial = cfgs.pop("initial")
        exp_dir = os.path.join(cfgs.pop("exp_dir"), initial)

        yield gr.update(value=None), "Tuning for the LoRA layer is starting now. It will take approximately ~10 mins.", gr.update(value=None)
        # semantic_coding sees the config dict before family/vox/pose are popped.
        state = semantic_coding(images, cfgs, self.sd, initial)
        # The value returned by semantic_coding becomes the SD config's ``dir``
        # before the score model is built.
        self.sd.dir = state

        family = cfgs.pop("family")
        model = getattr(self, family).make()
        print(model.prompt)

        # The nested configs are replaced by the objects they build.
        cfgs.pop("vox")
        vox = self.vox.make()

        cfgs.pop("pose")
        poser = self.pose.make()

        yield from fuse_3d(
            **cfgs, poser=poser, model=model, vox=vox,
            exp_dir=exp_dir, points=points, is_gradio=True,
        )
|
|
|
|
def fuse_3d(
    poser, vox, model: ScoreAdapter,
    lr, n_steps, emptiness_scale, emptiness_weight, emptiness_step, emptiness_multiplier,
    depth_weight, var_red, exp_dir, points, is_gradio, **kwargs
):
    """Optimise the voxel model against the score model, yielding progress.

    This is a generator in both modes; it must be iterated for the
    optimisation to run.

    Args:
        poser: Pose sampler exposing ``H``, ``W`` and
            ``sample_train(n_steps, device)``.
        vox: Voxel model exposing ``aabb``, ``to()`` and ``opt_params()``.
        model: Score adapter used for ``prompts_emb``, ``denoise`` and
            ``decode``.
        lr: Learning rate of the Adamax optimiser.
        n_steps: Number of training poses requested; also the progress-bar
            total.
        emptiness_scale: Scale applied to the ray weights inside the log of the
            emptiness loss.
        emptiness_weight: Multiplier of the emptiness loss.
        emptiness_step: Fraction of ``n_steps`` after which the emptiness loss
            is additionally multiplied by ``emptiness_multiplier``.
        emptiness_multiplier: See ``emptiness_step``.
        depth_weight: Weight of the centre-vs-border depth loss; the loss is
            skipped unless this is positive.
        var_red: If true the update direction is ``(Ds - y) / sigma``,
            otherwise ``(Ds - zs) / sigma``.
        exp_dir: Accepted for interface compatibility; not used here.
        points: Passed to ``render_depth_from_cloud`` for every view.
        is_gradio: When true, Gradio output 3-tuples are yielded.
        **kwargs: Extra configuration entries; ignored.

    Yields:
        With ``is_gradio``: 3-tuples (start notice, periodic preview image with
        a progress message, and finally the video path wrapped in
        ``gr.update``). Otherwise a single ``None`` after evaluation.
    """
    del kwargs

    if is_gradio:
        yield gr.update(visible=True), "LoRA layers tuning has just finished. \nScore distillation has started.", gr.update(visible=True)
    assert model.samps_centered()
    _, target_H, target_W = model.data_shape()
    bs = 1
    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)
    opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

    H, W = poser.H, poser.W
    Ks_, poses_, prompt_prefixes_, angles_list = poser.sample_train(n_steps, device_glb)

    # Candidate noise levels: model.us without its first 30 and last 10 entries.
    ts = model.us[30:-10]

    fuse = EarlyLoopBreak(5)

    raster_settings = PointsRasterizationSettings(
        image_size=800,
        radius=0.02,
        points_per_pixel=10,
    )

    calibration_value = 0.0

    with tqdm(total=n_steps) as pbar:
        for i in range(len(poses_)):
            if fuse.on_break():
                break

            depth_map = render_depth_from_cloud(points, angles_list[i], raster_settings, device_glb, calibration_value)

            y, depth, ws = render_one_view(vox, aabb, H, W, Ks_[i], poses_[i], return_w=True)

            p = f"{prompt_prefixes_[i]} {model.prompt}"
            score_conds = model.prompts_emb([p])

            score_conds['c'] = score_conds['c'].repeat(bs, 1, 1)
            score_conds['uc'] = score_conds['uc'].repeat(bs, 1, 1)

            opt.zero_grad()

            # The update direction is computed without building a graph; it is
            # injected into the render graph through y.backward() below.
            with torch.no_grad():
                sigmas = np.random.choice(ts, bs, replace=False)
                sigmas = sigmas.reshape(-1, 1, 1, 1)
                sigmas = torch.as_tensor(sigmas, device=model.device, dtype=torch.float32)

                noise = torch.randn(bs, *y.shape[1:], device=model.device)

                zs = y + sigmas * noise

                Ds = model.denoise(zs, sigmas, depth_map.unsqueeze(dim=0), **score_conds)

                if var_red:
                    grad = (Ds - y) / sigmas
                else:
                    grad = (Ds - zs) / sigmas

                grad = grad.mean(0, keepdim=True)

            y.backward(-grad, retain_graph=True)

            if depth_weight > 0:
                center_depth = depth[7:-7, 7:-7]
                # Border pixel count derived from the tensors rather than the
                # former hard-coded 64*64-50*50 (identical for a 64x64 map).
                n_border = depth.numel() - center_depth.numel()
                border_depth_mean = (depth.sum() - center_depth.sum()) / n_border
                center_depth_mean = center_depth.mean()
                depth_diff = center_depth_mean - border_depth_mean
                depth_loss = - torch.log(depth_diff + 1e-12)
                depth_loss = depth_weight * depth_loss
                depth_loss.backward(retain_graph=True)

            emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
            emptiness_loss = emptiness_weight * emptiness_loss
            if emptiness_step * n_steps <= i:
                emptiness_loss *= emptiness_multiplier
            emptiness_loss.backward()

            opt.step()

            if every(pbar, percent=2):
                preview = None
                with torch.no_grad():
                    decoded = model.decode(y)
                    if is_gradio:
                        preview = torch_samps_to_imgs(decoded)[0]
                # Yield only after leaving no_grad, so the consumer never runs
                # with autograd disabled while this generator is suspended.
                if is_gradio:
                    yield preview, f"Progress: {pbar.n}/{pbar.total} \nAfter the generation is complete, the video results will be displayed below.", gr.update(value=None)

            pbar.update()
            pbar.set_description(p)

    out = evaluate(model, vox, poser)

    if is_gradio:
        yield gr.update(visible=False), "Generation complete. Please check the video below.", gr.update(value=out)
    else:
        yield None
| |
| |
|
|
| |
|
|
@torch.no_grad()
def evaluate(score_model, vox, poser):
    """Render the test poses and write the frames to a temporary MP4 file.

    Args:
        score_model: Its ``decode`` is applied to every rendered view.
        vox: Voxel model; put into eval mode and moved to ``device_glb``.
        poser: Provides ``H``, ``W`` and ``sample_test(n)``.

    Returns:
        Path of the ``.mp4`` file. It is created with ``delete=False``, so the
        caller owns its clean-up.
    """
    H, W = poser.H, poser.W
    vox.eval()
    K, poses = poser.sample_test(100)

    fuse = EarlyLoopBreak(5)

    aabb = vox.aabb.T.cpu().numpy()
    vox = vox.to(device_glb)

    frames = []
    for i in tqdm(range(len(poses))):
        if fuse.on_break():
            break

        y, _ = render_one_view(vox, aabb, H, W, K, poses[i])
        y = score_model.decode(y)
        frames.append(torch_samps_to_imgs(y)[0])

    # Reserve a file name and close the handle right away: the handle itself is
    # never written to, and leaving it open leaks a file descriptor.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as out_file:
        out_path = out_file.name

    # The context manager closes the writer even if appending a frame fails.
    with imageio.get_writer(out_path, fps=10) as writer:
        for img in frames:
            writer.append_data(img)

    return out_path
|
|
def render_one_view(vox, aabb, H, W, K, pose, return_w=False):
    """Render a single H x W view of ``vox`` from ``pose``.

    Args:
        vox: Voxel model handed to ``render_ray_bundle``; its ``device`` is
            where the ray tensors are placed.
        aabb: Scene bounding box used for the ray/box intersection.
        H, W: Image height and width in pixels.
        K: Camera intrinsics passed to ``rays_from_img``.
        pose: Camera pose passed to ``rays_from_img``.
        return_w: When true, the ray weights are returned as a third element.

    Returns:
        ``(rgbs, depth)`` or ``(rgbs, depth, weights)``; ``rgbs`` is reshaped to
        ``1 x c x H x W`` and ``depth`` to ``H x W``.
    """
    origins, directions = rays_from_img(H, W, K, pose)
    origins, directions, near, far = scene_box_filter_(origins, directions, aabb)

    n_pixels = H * W
    assert len(origins) == n_pixels, "for now all pixels must be in"

    origins, directions, near, far = as_torch_tsrs(
        vox.device, origins, directions, near, far
    )
    rgbs, depth, weights = render_ray_bundle(vox, origins, directions, near, far)

    # Rays are laid out row-major, one per pixel; fold them back into images.
    rgbs = rearrange(rgbs, "(h w) c -> 1 c h w", h=H, w=W)
    depth = rearrange(depth, "(h w) 1 -> h w", h=H, w=W)

    return (rgbs, depth, weights) if return_w else (rgbs, depth)
|
|
|
|
def scene_box_filter_(ro, rd, aabb):
    """Intersect rays with the scene box and clamp the hit distances at zero.

    Args:
        ro: Ray origins.
        rd: Ray directions.
        aabb: Bounding box passed to ``ray_box_intersect``.

    Returns:
        ``(ro, rd, t_min, t_max)``; the rays are returned unchanged and both
        distances are floored at 0.
    """
    _, t_near, t_far = ray_box_intersect(ro, rd, aabb)
    t_near = np.maximum(t_near, 0)
    t_far = np.maximum(t_far, 0)
    return ro, rd, t_near, t_far
|
|
|
|
def vis_routine(metric, y, depth, prompt, depth_map):
    """Register visualisation artifacts for one rendered view on ``metric``.

    Args:
        metric: Object exposing ``put_artifact(name, suffix, text, save_fn)``.
        y: Rendered view, visualised via ``nerf_vis`` and
            ``torch_samps_to_imgs``.
        depth: Depth tensor; moved to CPU and saved as ``.npy``.
        prompt: Text attached to the "img" and "PC_depth" artifacts.
        depth_map: Optional point-cloud depth tensor; skipped when ``None``.
    """
    pane = nerf_vis(y, depth, final_H=256)
    im = torch_samps_to_imgs(y)[0]

    depth = depth.cpu().numpy()
    metric.put_artifact("view", ".png", "", lambda fn: imwrite(fn, pane))
    metric.put_artifact("img", ".png", prompt, lambda fn: imwrite(fn, im))
    # Identity test: comparing an array-like to None with != may be evaluated
    # element-wise instead of answering "was a depth map supplied?".
    if depth_map is not None:
        metric.put_artifact("PC_depth", ".png", prompt, lambda fn: imwrite(fn, depth_map.cpu().squeeze()))
    metric.put_artifact("depth", ".npy", "", lambda fn: np.save(fn, depth))
|
|
|
|
# Script entry point: hand the config class to the `dispatch` helper imported
# from my.config (what dispatch does with it is not visible in this file).
if __name__ == "__main__":
    dispatch(SJC_3DFuse)