Spaces:
Runtime error
Runtime error
| import math | |
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Literal, Optional | |
| import numpy as np | |
| import torch | |
| import tyro | |
| from PIL import Image | |
| from torch import Tensor, optim | |
| from gsplat import rasterization, rasterization_2dgs | |
class SimpleTrainer:
    """Trains random gaussians to fit an image.

    The trainer owns a fixed number of gaussians (means, scales, rotations,
    colours and opacities), a single fixed camera, and the target image. All
    tensors are placed on ``cuda:0``.
    """

    def __init__(
        self,
        gt_image: Tensor,
        num_points: int = 2000,
    ):
        """Store the target image and create the random gaussians.

        Args:
            gt_image: Target image; its first two dimensions are read as
                height and width.
            num_points: Number of gaussians to optimise.
        """
        self.device = torch.device("cuda:0")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # Focal length (in pixels) for a horizontal field of view of pi/2.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Random gaussians"""
        bd = 2  # means are drawn uniformly from [-bd / 2, bd / 2) per axis
        d = 3  # number of colour channels

        # NOTE: the order of the torch.rand calls is kept stable so that a
        # seeded run produces the same initialisation as before.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        # Three uniform samples per point are combined into a quaternion whose
        # squared components sum to (1 - u) + u = 1, i.e. a unit quaternion.
        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        self.opacities = torch.ones((self.num_points), device=self.device)

        # Fixed view matrix handed to the rasterizer: identity rotation with a
        # translation of 8 along z.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        self.background = torch.zeros(d, device=self.device)

        # Everything describing the gaussians is optimised; the camera is not.
        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    def train(
        self,
        iterations: int = 1000,
        lr: float = 0.01,
        save_imgs: bool = False,
        model_type: Literal["3dgs", "2dgs"] = "3dgs",
    ):
        """Optimise the gaussians against the target image with Adam and MSE.

        Args:
            iterations: Number of optimisation steps.
            lr: Learning rate passed to Adam.
            save_imgs: When true, every fifth rendered frame is collected and
                written to ``results/training.gif`` under the current working
                directory.
            model_type: ``"3dgs"`` selects ``rasterization``; ``"2dgs"``
                selects ``rasterization_2dgs``.

        Raises:
            ValueError: If ``model_type`` is not one of the supported values.
        """
        # Validate first so a typo fails with a clear message instead of an
        # unbound-variable error at the first rasterization call.
        if model_type not in ("3dgs", "2dgs"):
            raise ValueError(
                f"Unsupported model_type {model_type!r}; expected '3dgs' or '2dgs'"
            )

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []
        raster_time = 0.0
        backward_time = 0.0

        # Pinhole intrinsics with the principal point at the image centre.
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        rasterize_fn = rasterization if model_type == "3dgs" else rasterization_2dgs

        for step in range(iterations):
            start = time.time()
            # Element 0 of the returned tuple holds the rendered colours.
            renders = rasterize_fn(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]  # the single camera that was passed in
            # Wait for the GPU so the wall-clock measurement is meaningful.
            torch.cuda.synchronize()
            raster_time += time.time() - start

            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()

            start = time.time()
            loss.backward()
            torch.cuda.synchronize()
            backward_time += time.time() - start

            optimizer.step()
            print(f"Iteration {step + 1}/{iterations}, Loss: {loss.item()}")

            if save_imgs and step % 5 == 0:
                frames.append((out_img.detach().cpu().numpy() * 255).astype(np.uint8))

        # Only write the gif when at least one frame was captured; with zero
        # iterations there is nothing to save.
        if save_imgs and frames:
            # save them as a gif with PIL
            images = [Image.fromarray(frame) for frame in frames]
            out_dir = os.path.join(os.getcwd(), "results")
            os.makedirs(out_dir, exist_ok=True)
            images[0].save(
                f"{out_dir}/training.gif",
                save_all=True,
                append_images=images[1:],
                optimize=False,
                duration=5,
                loop=0,
            )

        print(f"Total(s):\nRasterization: {raster_time:.3f}, Backward: {backward_time:.3f}")
        # A per-step average is undefined when no step was run.
        if iterations > 0:
            print(
                f"Per step(s):\nRasterization: {raster_time/iterations:.5f}, Backward: {backward_time/iterations:.5f}"
            )
def image_path_to_tensor(image_path: Path):
    """Load an image file as a height x width x 3 float tensor.

    The image is converted to RGB before conversion, so grayscale and
    palette images yield three channels and an alpha channel is dropped.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The tensor produced by torchvision's ``ToTensor`` with the channel
        axis moved last.
    """
    import torchvision.transforms as transforms

    # The context manager closes the underlying file; convert() returns an
    # independent in-memory copy that stays valid afterwards.
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")

    to_tensor = transforms.ToTensor()
    # ToTensor yields channels-first; move channels to the last axis.
    return to_tensor(rgb_img).permute(1, 2, 0)
def main(
    height: int = 256,
    width: int = 256,
    num_points: int = 100000,
    save_imgs: bool = True,
    img_path: Optional[Path] = None,
    iterations: int = 1000,
    lr: float = 0.01,
    model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
    """Fit gaussians to an image file or to a built-in synthetic target.

    Args:
        height: Height of the synthetic target (used when no image is given).
        width: Width of the synthetic target (used when no image is given).
        num_points: Number of gaussians to optimise.
        save_imgs: Forwarded to ``SimpleTrainer.train``.
        img_path: Optional image file to use as the target.
        iterations: Number of optimisation steps.
        lr: Learning rate.
        model_type: Which rasterizer variant the trainer should use.
    """
    if not img_path:
        # Synthetic target: white canvas with a red top-left quadrant and a
        # blue bottom-right quadrant.
        half_h = height // 2
        half_w = width // 2
        target = torch.ones((height, width, 3))
        target[:half_h, :half_w, :] = torch.tensor([1.0, 0.0, 0.0])
        target[half_h:, half_w:, :] = torch.tensor([0.0, 0.0, 1.0])
    else:
        target = image_path_to_tensor(img_path)

    fitter = SimpleTrainer(gt_image=target, num_points=num_points)
    fitter.train(
        iterations=iterations,
        lr=lr,
        save_imgs=save_imgs,
        model_type=model_type,
    )
if __name__ == "__main__":
    # Hand main() to tyro.cli so the script is driven from the command line.
    tyro.cli(main)