| import math |
| import os |
| import time |
| from pathlib import Path |
| from typing import Literal, Optional |
|
|
| import numpy as np |
| import torch |
| import tyro |
| from PIL import Image |
| from torch import Tensor, optim |
|
|
| from gsplat import rasterization, rasterization_2dgs |
|
|
|
|
class SimpleTrainer:
    """Trains random gaussians to fit an image.

    The trainer owns a fixed-size set of randomly initialised gaussians
    (means, scales, rotations, colours, opacities) together with a single
    fixed camera, and optimises the gaussian parameters with Adam so that
    the rasterized image matches ``gt_image`` under an MSE loss.

    All tensors are placed on ``cuda:0``; a CUDA device is therefore
    required to construct and train this class.
    """

    def __init__(
        self,
        gt_image: Tensor,
        num_points: int = 2000,
    ):
        """Store the target image, derive the camera and create the gaussians.

        Args:
            gt_image: Target image. Its first two dimensions are read as
                ``(H, W)``; it is compared directly against the rendered
                image in :meth:`train` (presumably ``(H, W, 3)`` in [0, 1] --
                confirm against the caller).
            num_points: Number of gaussians to optimise.
        """
        self.device = torch.device("cuda:0")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # 90 degree horizontal field of view, so focal == W / 2 pixels.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        # Not read anywhere in this class; kept as a public attribute for
        # backward compatibility.
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Randomly initialise the gaussian parameters and the fixed camera."""
        bd = 2  # side length of the cube the means are sampled from

        # Means are uniform in [-bd / 2, bd / 2)^3; scales and raw colours are
        # uniform in [0, 1). Colours and opacities are passed through a
        # sigmoid in `train`, so the values stored here are pre-activation.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        d = 3  # number of colour channels
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)

        # Random unit quaternions built from three uniform samples: the
        # squared components sum to (1 - u) + u == 1.
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        self.opacities = torch.ones((self.num_points), device=self.device)

        # View matrix handed to the rasterizer: identity rotation with a
        # translation of 8 along z.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        # Not read anywhere in this class (it is not passed to the
        # rasterizer); kept as a public attribute.
        self.background = torch.zeros(d, device=self.device)

        # Everything describing the gaussians is trainable; the camera is not.
        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    def train(
        self,
        iterations: int = 1000,
        lr: float = 0.01,
        save_imgs: bool = False,
        model_type: Literal["3dgs", "2dgs"] = "3dgs",
    ):
        """Optimise the gaussians against the target image.

        Prints the loss at every iteration and, at the end, the accumulated
        and per-step time spent in rasterization and in the backward pass.

        Args:
            iterations: Number of optimisation steps.
            lr: Learning rate for the Adam optimiser.
            save_imgs: If true, every 5th rendered frame is collected and
                written to ``results/training.gif`` under the current
                working directory.
            model_type: ``"3dgs"`` uses ``rasterization``; ``"2dgs"`` uses
                ``rasterization_2dgs``.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        # Validate before touching any state so that a bad value fails
        # immediately with a clear message instead of leaving the rasterizer
        # unbound until the first iteration.
        if model_type not in ("3dgs", "2dgs"):
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected '3dgs' or '2dgs'."
            )
        rasterize_fnc = rasterization if model_type == "3dgs" else rasterization_2dgs

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []
        raster_time = 0.0
        backward_time = 0.0
        # Pinhole intrinsics with the principal point at the image centre.
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        for step in range(iterations):
            start = time.perf_counter()

            # The first element of the rasterizer's return value holds the
            # rendered colours; [None] adds the leading camera dimension
            # and renders[0] removes it again.
            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]
            # CUDA kernels launch asynchronously; wait for them so the
            # measured interval covers the actual GPU work.
            torch.cuda.synchronize()
            raster_time += time.perf_counter() - start

            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            start = time.perf_counter()
            loss.backward()
            torch.cuda.synchronize()
            backward_time += time.perf_counter() - start
            optimizer.step()
            print(f"Iteration {step + 1}/{iterations}, Loss: {loss.item()}")

            if save_imgs and step % 5 == 0:
                frames.append((out_img.detach().cpu().numpy() * 255).astype(np.uint8))

        # `frames` is empty when no iteration ran; there is nothing to save.
        if save_imgs and frames:
            self._save_training_gif(frames)

        print(f"Total(s):\nRasterization: {raster_time:.3f}, Backward: {backward_time:.3f}")
        # Per-step averages are undefined when no step was executed.
        if iterations > 0:
            print(
                f"Per step(s):\nRasterization: {raster_time/iterations:.5f}, Backward: {backward_time/iterations:.5f}"
            )

    @staticmethod
    def _save_training_gif(frames: list) -> None:
        """Write the non-empty list of uint8 frames to ``results/training.gif``."""
        images = [Image.fromarray(frame) for frame in frames]
        out_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(out_dir, exist_ok=True)
        images[0].save(
            os.path.join(out_dir, "training.gif"),
            save_all=True,
            append_images=images[1:],
            optimize=False,
            duration=5,
            loop=0,
        )
|
|
|
|
def image_path_to_tensor(image_path: Path) -> Tensor:
    """Load an image file as a float RGB tensor in height-width-channel order.

    The image is converted to RGB before conversion, so grayscale, palette
    and alpha-carrying images all yield exactly three channels (any alpha
    channel is dropped).

    Args:
        image_path: Path of the image file to read.

    Returns:
        The output of torchvision's ``ToTensor`` permuted from
        channel-first to ``(H, W, 3)``.
    """
    import torchvision.transforms as transforms

    to_tensor = transforms.ToTensor()
    # Use a context manager so the underlying file handle is released
    # deterministically once the pixel data has been converted.
    with Image.open(image_path) as img:
        # ToTensor yields (C, H, W); the trainer expects (H, W, C).
        return to_tensor(img.convert("RGB")).permute(1, 2, 0)
|
|
|
|
def main(
    height: int = 256,
    width: int = 256,
    num_points: int = 100000,
    save_imgs: bool = True,
    img_path: Optional[Path] = None,
    iterations: int = 1000,
    lr: float = 0.01,
    model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
    """Fit random gaussians to an image file or to a synthetic test image.

    Args:
        height: Height of the synthetic target; only used without img_path.
        width: Width of the synthetic target; only used without img_path.
        num_points: Number of gaussians to optimise.
        save_imgs: Forwarded to the trainer's ``train`` call.
        img_path: Optional image file to use as the target.
        iterations: Number of optimisation steps.
        lr: Learning rate forwarded to the trainer.
        model_type: Rasterizer variant forwarded to the trainer.
    """
    if not img_path:
        # Synthetic target: white canvas with a red top-left quadrant and a
        # blue bottom-right quadrant.
        half_h, half_w = height // 2, width // 2
        target = torch.ones((height, width, 3))
        target[:half_h, :half_w, :] = torch.tensor([1.0, 0.0, 0.0])
        target[half_h:, half_w:, :] = torch.tensor([0.0, 0.0, 1.0])
    else:
        target = image_path_to_tensor(img_path)

    SimpleTrainer(gt_image=target, num_points=num_points).train(
        iterations=iterations,
        lr=lr,
        save_imgs=save_imgs,
        model_type=model_type,
    )
|
|
|
|
# Script entry point: build the command-line interface from main's signature.
if __name__ == "__main__":
    tyro.cli(main)
|
|