| import numpy as np |
| import torch |
| import imageio |
|
|
| from my.utils.tqdm import tqdm |
| from my.utils.event import EventStorage, read_stats, get_event_storage |
| from my.utils.heartbeat import HeartBeat, get_heartbeat |
| from my.utils.debug import EarlyLoopBreak |
|
|
| from .utils import PSNR, Scrambler, every, at |
| from .data import load_blender |
| from .render import ( |
| as_torch_tsrs, scene_box_filter, render_ray_bundle, render_one_view, rays_from_img |
| ) |
| from .vis import vis, stitch_vis |
|
|
|
|
| device_glb = torch.device("cuda") |
|
|
|
|
def all_train_rays(scene):
    """Collect rays and rgb targets for every image in the training split.

    Returns (ro, rd, rgbs): three np arrays with one row per pixel,
    flattened across all training views.
    """
    imgs, K, poses = load_blender("train", scene)

    origins, directions, colors = [], [], []
    for img, pose in tqdm(list(zip(imgs, poses))):
        H, W = img.shape[:2]
        o, d = rays_from_img(H, W, K, pose)
        origins.append(o)
        directions.append(d)
        colors.append(img.reshape(-1, 3))

    def _flatten(chunks):
        # stack per-image pixel arrays into one big (num_pixels, ...) array
        return np.concatenate(chunks, axis=0)

    return _flatten(origins), _flatten(directions), _flatten(colors)
|
|
|
|
class OneTestView():
    """Round-robin renderer over the test split: each call to render()
    evaluates the model on the current held-out view, then advances to
    the next one (wrapping around)."""

    def __init__(self, scene):
        # K is a single intrinsics matrix shared by all views;
        # imgs and poses are per-view.
        self.imgs, self.K, self.poses = load_blender("test", scene)
        self.i = 0

    def render(self, model):
        """Render the current test view.

        Returns (gt_img, rendered_rgbs, depth, psnr).
        """
        idx = self.i
        img = self.imgs[idx]
        pose = self.poses[idx]
        with torch.no_grad():
            box = model.aabb.T.cpu().numpy()
            H, W = img.shape[:2]
            rgbs, depth = render_one_view(model, box, H, W, self.K, pose)
            psnr = PSNR.psnr(img, rgbs)

        # advance the cursor for the next call
        self.i = (idx + 1) % len(self.imgs)

        return img, rgbs, depth, psnr
|
|
|
|
def train(
    model, n_epoch=2, bs=4096, lr=0.02, scene="lego"
):
    """Optimize a voxel radiance-field model on a blender scene.

    Args:
        model: project-defined voxel model exposing .aabb, .device, .d_scale,
            .grid_size, .make_alpha_mask() and .resample().
        n_epoch: number of passes over the full set of training rays.
        bs: ray batch size per optimizer step.
        lr: initial Adam learning rate; decayed each step via update_lr.
        scene: blender scene name, used for both train and test splits.
    """
    fuse = EarlyLoopBreak(500)  # debug hook that can cut the batch loop short

    # NOTE(review): .numpy() here requires the model (hence aabb) to still be
    # on CPU at call time — the .to(device) happens on the next line. Confirm
    # callers always pass a CPU model.
    aabb = model.aabb.T.numpy()
    model = model.to(device_glb)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    test_view = OneTestView(scene)
    all_ro, all_rd, all_rgbs = all_train_rays(scene)
    print(n_epoch, len(all_ro), bs)
    with tqdm(total=(n_epoch * len(all_ro) // bs)) as pbar, \
        HeartBeat(pbar) as hbeat, EventStorage() as metric:

        # keep only the rays that intersect the scene bounding box,
        # and select the matching rgb targets
        ro, rd, t_min, t_max, intsct_inds = scene_box_filter(all_ro, all_rd, aabb)
        rgbs = all_rgbs[intsct_inds]
        print(len(ro))
        for epc in range(n_epoch):
            n = len(ro)
            # reshuffle all surviving rays at the start of each epoch
            scrambler = Scrambler(n)
            ro, rd, t_min, t_max, rgbs = scrambler.apply(ro, rd, t_min, t_max, rgbs)

            num_batch = int(np.ceil(n / bs))
            for i in range(num_batch):
                if fuse.on_break():
                    break
                s = i * bs
                e = min(n, s + bs)  # last batch may be smaller than bs

                optim.zero_grad()
                _ro, _rd, _t_min, _t_max, _rgbs = as_torch_tsrs(
                    model.device, ro[s:e], rd[s:e], t_min[s:e], t_max[s:e], rgbs[s:e]
                )
                pred, _, _ = render_ray_bundle(model, _ro, _rd, _t_min, _t_max)
                loss = ((pred - _rgbs) ** 2).mean()  # plain MSE in rgb space
                loss.backward()
                optim.step()

                pbar.update()

                psnr = PSNR.psnr_from_mse(loss.item())
                metric.put_scalars(psnr=psnr, d_scale=model.d_scale.item())

                if every(pbar, step=50):
                    pbar.set_description(f"TRAIN: psnr {psnr:.2f}")

                # every ~1% of total steps: render a held-out view for logging
                if every(pbar, percent=1):
                    gimg, rimg, depth, psnr = test_view.render(model)
                    pane = vis(
                        gimg, rimg, depth,
                        msg=f"psnr: {psnr:.2f}", return_buffer=True
                    )
                    # NOTE(review): the lambda closes over `pane` — safe only
                    # if put_artifact invokes the callback immediately; confirm.
                    metric.put_artifact(
                        "vis", ".png", lambda fn: imageio.imwrite(fn, pane)
                    )

                # one-shot at 30% of training: mask out empty space
                if at(pbar, percent=30):
                    model.make_alpha_mask()

                # periodically upsample the voxel grid; the optimizer must be
                # rebuilt because resample() replaces the parameter tensors
                if every(pbar, percent=35):
                    target_xyz = (model.grid_size * 1.328).int().tolist()
                    model.resample(target_xyz)
                    optim = torch.optim.Adam(model.parameters(), lr=lr)
                    print(f"resamp the voxel to {model.grid_size}")

                curr_lr = update_lr(pbar, optim, lr)
                metric.put_scalars(lr=curr_lr)

                metric.step()
                hbeat.beat()

            # end of epoch: checkpoint the model weights
            metric.put_artifact(
                "ckpt", ".pt", lambda fn: torch.save(model.state_dict(), fn)
            )

        # stitch the periodic vis frames into a training-progress video
        metric.put_artifact(
            "train_seq", ".mp4",
            lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "vis")[1])
        )

        # final evaluation runs under a nested "test" storage; the summary
        # scalar is written to the outer (train) storage `metric`
        with EventStorage("test"):
            final_psnr = test(model, scene)
            metric.put("test_psnr", final_psnr)

        metric.step()

    hbeat.done()
|
|
|
|
def update_lr(pbar, optimizer, init_lr):
    """Exponentially decay the learning rate over the whole run.

    The per-step factor is 0.1 ** (1 / pbar.total), so after the final step
    the rate has dropped by exactly one decade from init_lr. Every param
    group is updated in place; the new rate is returned for logging.
    """
    step, total = pbar.n, pbar.total
    new_lr = init_lr * (0.1 ** (1 / total)) ** step
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
|
|
|
|
def last_ckpt():
    """Fetch the newest "ckpt" artifact from the current working directory.

    Returns the state dict loaded onto CPU, or None when no checkpoint
    has been written yet.
    """
    ts, ckpts = read_stats("./", "ckpt")
    if not ckpts:
        return None
    state = torch.load(ckpts[-1], map_location="cpu")
    print(f"loaded ckpt from iter {ts[-1]}")
    return state
|
|
|
|
def __evaluate_ckpt(model, scene):
    """Restore the newest checkpoint (if any) into `model`, then score it on
    the test split, logging the summary to the active event storage."""
    metric = get_event_storage()

    state = last_ckpt()
    if state is not None:
        model.load_state_dict(state)
    model.to(device_glb)

    # evaluation runs under a nested "test" storage; the summary scalar is
    # written to the outer storage obtained above
    with EventStorage("test"):
        metric.put("test_psnr", test(model, scene))
|
|
|
|
def test(model, scene):
    """Evaluate `model` on every view of the blender test split.

    Renders each held-out pose, logs per-view psnr and a comparison panel,
    stitches all panels into an mp4 artifact, and returns the mean psnr.
    Must be called with an active EventStorage and HeartBeat context.
    """
    fuse = EarlyLoopBreak(5)  # debug hook that can truncate the view loop
    metric = get_event_storage()
    hbeat = get_heartbeat()

    aabb = model.aabb.T.cpu().numpy()
    model = model.to(device_glb)

    imgs, K, poses = load_blender("test", scene)
    num_imgs = len(imgs)

    stats = []  # per-view psnr values

    for i in (pbar := tqdm(range(num_imgs))):
        if fuse.on_break():
            break

        img, pose = imgs[i], poses[i]
        H, W = img.shape[:2]
        rgbs, depth = render_one_view(model, aabb, H, W, K, pose)
        psnr = PSNR.psnr(img, rgbs)

        stats.append(psnr)
        metric.put_scalars(psnr=psnr)
        pbar.set_description(f"TEST: mean psnr {np.mean(stats):.2f}")

        # NOTE(review): the lambda closes over `plot`, which is rebound each
        # iteration — safe only if put_artifact invokes the callback
        # immediately; confirm against EventStorage's implementation.
        plot = vis(img, rgbs, depth, msg=f"PSNR: {psnr:.2f}", return_buffer=True)
        metric.put_artifact("test_vis", ".png", lambda fn: imageio.imwrite(fn, plot))
        metric.step()
        hbeat.beat()

    # assemble all per-view panels into one evaluation video
    metric.put_artifact(
        "test_seq", ".mp4",
        lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "test_vis")[1])
    )

    final_psnr = np.mean(stats)
    metric.put("final_psnr", final_psnr)
    metric.step()

    return final_psnr
|
|