| | import math |
| | import os |
| | import time |
| | import spaces |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.optim as optim |
| | import torchvision.transforms as transforms |
| | from PIL import Image |
| | import gradio as gr |
| |
|
| | |
| | from gsplat import rasterization, rasterization_2dgs |
| |
|
class SimpleTrainer:
    """Trains random Gaussians to fit an image and saves a GIF of the training process.

    The Gaussians are rendered through a fixed camera (90 degree horizontal
    field of view; ``viewmat`` has identity rotation and a z-translation of 8)
    and optimized with Adam against an MSE loss on the target image.
    """

    def __init__(self, gt_image: torch.Tensor, num_points: int = 2000):
        """Store the target image and create ``num_points`` random Gaussians.

        Args:
            gt_image: target image, indexed as ``shape[0] == H`` and
                ``shape[1] == W`` (presumably a (H, W, 3) float tensor in
                [0, 1] such as ``image_to_tensor`` produces — confirm).
            num_points: number of Gaussians. Coerced to ``int`` because UI
                widgets may deliver numeric values as floats, which
                ``torch.rand`` rejects as a size.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = int(num_points)
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        # Pinhole focal length (in pixels) that yields a horizontal FOV of fov_x.
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)
        self._init_gaussians()

    def _init_gaussians(self):
        """Initialize random Gaussian parameters and mark them as trainable."""
        bd = 2
        # Means uniform in [-bd/2, bd/2)^3; scales uniform in [0, 1).
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        d = 3
        # Color logits; train() applies a sigmoid before rendering.
        self.rgbs = torch.rand(self.num_points, d, device=self.device)
        # Uniformly distributed random unit quaternions (Shoemake's method).
        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        # Opacity logits; train() applies a sigmoid before rendering.
        self.opacities = torch.ones((self.num_points), device=self.device)
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        self.background = torch.zeros(d, device=self.device)
        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        # The camera is fixed; only the Gaussians are optimized.
        self.viewmat.requires_grad = False

    def train(self, iterations: int = 100, lr: float = 0.01, model_type: str = "3dgs", save_imgs: bool = False):
        """Optimize the Gaussians so their rendering matches ``self.gt_image``.

        Args:
            iterations: number of optimization steps (coerced to ``int``).
            lr: Adam learning rate shared by all Gaussian parameters.
            model_type: ``"3dgs"`` (gsplat ``rasterization``) or ``"2dgs"``
                (gsplat ``rasterization_2dgs``).
            save_imgs: when True, record every 5th render and write the
                frames as a looping GIF under ``./results``.

        Returns:
            ``(final_out, gif_path)``: the render from the last iteration as a
            NumPy array (``None`` if no iteration ran), and the GIF path
            (``None`` unless ``save_imgs`` is set and a frame was recorded).

        Raises:
            ValueError: if ``model_type`` is not ``"3dgs"`` or ``"2dgs"``.
        """
        # Validate the model type first so a bad value fails before any setup.
        if model_type == "3dgs":
            rasterize_fnc = rasterization
        elif model_type == "2dgs":
            rasterize_fnc = rasterization_2dgs
        else:
            raise ValueError("Invalid model type. Choose '3dgs' or '2dgs'.")

        iterations = int(iterations)
        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        # Intrinsics: principal point at the image center.
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        frames = []
        out_img = None
        for step in range(iterations):
            # First returned tensor holds the rendered colors for the batch of
            # cameras; [None] adds that batch dim, so renders[0] is our view.
            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]
            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step == 0 or (step + 1) % 10 == 0 or (step + 1) == iterations:
                print(f"Iteration {step+1}/{iterations}, Loss: {loss.item()}")

            if save_imgs and step % 5 == 0:
                frames.append(self._to_uint8(out_img))

        # Copy to host once, after the loop, rather than on every iteration.
        final_out = out_img.detach().cpu().numpy() if out_img is not None else None
        gif_path = self._save_gif(frames) if save_imgs and frames else None
        return final_out, gif_path

    @staticmethod
    def _to_uint8(img: torch.Tensor) -> np.ndarray:
        """Convert a float render to a uint8 array, clipping to [0, 1] first."""
        arr = img.detach().cpu().numpy()
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)

    @staticmethod
    def _save_gif(frames) -> str:
        """Write uint8 ``frames`` as a looping GIF under ./results; return its path."""
        import uuid

        out_dir = os.path.join(os.getcwd(), "results")
        os.makedirs(out_dir, exist_ok=True)
        # Unique name per run so overlapping requests cannot overwrite each
        # other's GIF. NOTE: files are not cleaned up here.
        gif_path = os.path.join(out_dir, f"training_{uuid.uuid4().hex}.gif")
        images = [Image.fromarray(frame) for frame in frames]
        images[0].save(
            gif_path,
            save_all=True,
            append_images=images[1:],
            optimize=False,
            duration=50,
            loop=0,
        )
        print(f"GIF saved at {gif_path}")
        return gif_path
| |
|
@spaces.GPU
def image_to_tensor(pil_image: Image.Image) -> torch.Tensor:
    """Return ``pil_image`` as a float tensor laid out (H, W, 3) with values in [0, 1].

    The image is forced to RGB, converted with torchvision's ``ToTensor``
    (channels-first), and then permuted to channels-last.
    """
    rgb_image = pil_image.convert("RGB")
    to_tensor = transforms.ToTensor()
    channels_first = to_tensor(rgb_image)
    channels_last = channels_first.permute(1, 2, 0)
    # Keep only the first three channels.
    return channels_last[..., :3]
| |
|
@spaces.GPU
def fit_image(pil_image: Image.Image, num_points: int, iterations: int, lr: float, model_type: str):
    """Fit ``pil_image`` with random Gaussians and return the result.

    Args:
        pil_image: the image to fit.
        num_points: number of Gaussians (coerced to ``int``).
        iterations: number of optimization steps (coerced to ``int``).
        lr: Adam learning rate (coerced to ``float``).
        model_type: ``"3dgs"`` or ``"2dgs"``.

    Returns:
        ``(final_image, gif_path)``: the last rendered image as a PIL image,
        clipped to [0, 1] and scaled to 8 bits, and the path of the training
        GIF. On any failure, the error and its traceback are printed and
        ``(None, None)`` is returned.
    """
    import traceback

    try:
        gt_image = image_to_tensor(pil_image)
        # UI sliders may deliver floats; torch.rand and range() need ints.
        trainer = SimpleTrainer(gt_image=gt_image, num_points=int(num_points))
        out_img, gif_path = trainer.train(
            iterations=int(iterations),
            lr=float(lr),
            model_type=model_type,
            save_imgs=True,
        )
        if out_img is None:
            raise ValueError("Training produced no image; iterations must be >= 1.")
        out_img = (np.clip(out_img, 0, 1) * 255).astype(np.uint8)
        return Image.fromarray(out_img), gif_path
    except Exception as e:
        # UI boundary: report the failure with its traceback instead of
        # crashing the request handler.
        print(f"Error during image fitting: {e}")
        traceback.print_exc()
        return None, None
| |
|
@spaces.GPU
def gradio_fit(pil_image, num_points, iterations, lr, model_type):
    """Gradio wrapper: runs image fitting and returns the final image and training GIF.

    Returns:
        ``(final_image, gif_path)`` as produced by ``fit_image``.

    Raises:
        gr.Error: if no image was uploaded or fitting failed. Raising
            ``gr.Error`` surfaces the message in the Gradio UI; returning a
            plain message string would instead be handed to the
            ``gr.Image`` output component, which cannot display it.
    """
    if pil_image is None:
        raise gr.Error("Please upload an input image.")
    final_image, gif_path = fit_image(pil_image, num_points, iterations, lr, model_type)
    if final_image is None:
        raise gr.Error("An error occurred during image fitting.")
    return final_image, gif_path
| |
|
| | |
| | |
# Gradio UI: an input image plus training hyperparameters go in; the fitted
# image and the recorded training GIF come out.
iface = gr.Interface(
    title="Image Fitting with gsplat",
    description="Fit an image using random Gaussians with gsplat. The training process is recorded as a GIF.",
    fn=gradio_fit,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(
            minimum=1000,
            maximum=1000000,
            step=1000,
            value=5000,
            label="Number of Points",
        ),
        gr.Slider(
            minimum=10,
            maximum=500,
            step=10,
            value=100,
            label="Iterations",
        ),
        gr.Slider(
            minimum=0.001,
            maximum=0.1,
            step=0.001,
            value=0.01,
            label="Learning Rate",
        ),
        gr.Dropdown(
            choices=["3dgs", "2dgs"],
            value="3dgs",
            label="Model Type",
        ),
    ],
    outputs=[
        gr.Image(type="pil", label="Fitted Image"),
        gr.Image(type="filepath", label="Training GIF"),
    ],
)

if __name__ == "__main__":
    # Launch the UI only when this file is executed as a script.
    iface.launch()