import math
import os
import time
import traceback
from pathlib import Path

# Kept ahead of the torch imports, as in the original ordering: on Hugging Face
# ZeroGPU the `spaces` package is expected to load before CUDA is initialised.
import spaces

import numpy as np
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

# Import the rasterization functions from gsplat.
from gsplat import rasterization, rasterization_2dgs


def _float_to_uint8(img: np.ndarray) -> np.ndarray:
    """Clip a float image to [0, 1] and scale it to 8-bit (fractional part truncated)."""
    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


class SimpleTrainer:
    """Trains random Gaussians to fit an image and saves a GIF of the training process."""

    def __init__(self, gt_image: torch.Tensor, num_points: int = 2000):
        """Set up the camera and random Gaussians for fitting ``gt_image``.

        Args:
            gt_image: Target image. Its first two dimensions are read as (H, W);
                ``train`` compares it with an RGB render via MSE, so an
                (H, W, 3) float tensor is presumably expected.
            num_points: Number of Gaussians to optimise.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # Pinhole camera with a 90-degree horizontal field of view.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Initialize random Gaussians and the fixed camera pose."""
        bd = 2
        # Means uniformly distributed in the cube [-bd/2, bd/2)^3.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        # Scales are fed to the rasterizer as-is (no activation); start in [0, 1).
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        d = 3
        # Colour logits; train() applies a sigmoid before rendering.
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        # Uniformly random unit quaternions (Shoemake's method).
        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        # Opacity logits; train() applies a sigmoid, so every Gaussian starts at sigmoid(1).
        self.opacities = torch.ones((self.num_points), device=self.device)

        # Fixed view matrix: identity rotation, translation of 8 along z.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        # Not passed to the rasterizer by train().
        self.background = torch.zeros(d, device=self.device)

        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    def train(self, iterations: int = 100, lr: float = 0.01, model_type: str = "3dgs", save_imgs: bool = False):
        """Optimise the Gaussians so their render matches ``self.gt_image``.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate shared by all Gaussian parameters.
            model_type: ``"3dgs"`` (gsplat ``rasterization``) or ``"2dgs"``
                (gsplat ``rasterization_2dgs``).
            save_imgs: If True, keep every 5th render and write them to
                ``results/training.gif`` under the current working directory.

        Returns:
            ``(final_out, gif_path)``. ``final_out`` is the last render as a
            float numpy array, or ``None`` when ``iterations`` is 0.
            ``gif_path`` is ``None`` when no GIF was written.

        Raises:
            ValueError: If ``model_type`` is not ``"3dgs"`` or ``"2dgs"``.
        """
        # Validate before allocating anything.
        if model_type == "3dgs":
            rasterize_fnc = rasterization
        elif model_type == "2dgs":
            rasterize_fnc = rasterization_2dgs
        else:
            raise ValueError("Invalid model type. Choose '3dgs' or '2dgs'.")

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []  # uint8 frames for the GIF

        # Define camera intrinsics
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        out_img = None
        for step in range(iterations):
            # Render the current Gaussians for a batch of one camera
            # (viewmat[None], K[None]). The first item of gsplat's return tuple
            # is the colour render; renders[0] selects the single camera.
            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]

            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log progress less frequently to reduce verbosity.
            if (step + 1) % 10 == 0 or step == 0 or (step + 1) == iterations:
                print(f"Iteration {step+1}/{iterations}, Loss: {loss.item()}")

            # Collect frames for the GIF every 5 iterations. Clip first, like
            # the final image, so values slightly outside [0, 1] cannot
            # overflow the uint8 conversion.
            if save_imgs and step % 5 == 0:
                frames.append(_float_to_uint8(out_img.detach().cpu().numpy()))

        # Copy only the last render to the host rather than every iteration.
        final_out = None if out_img is None else out_img.detach().cpu().numpy()

        gif_path = None
        # Create the GIF synchronously (waits for completion)
        if save_imgs and frames:
            frames_pil = [Image.fromarray(frame) for frame in frames]
            out_dir = os.path.join(os.getcwd(), "results")
            os.makedirs(out_dir, exist_ok=True)
            gif_path = os.path.join(out_dir, "training.gif")
            frames_pil[0].save(
                gif_path,
                save_all=True,
                append_images=frames_pil[1:],
                optimize=False,
                duration=50,  # duration in ms between frames
                loop=0,
            )
            print(f"GIF saved at {gif_path}")

        return final_out, gif_path


@spaces.GPU
def image_to_tensor(pil_image: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a torch.Tensor with shape (H, W, 3) normalized to [0, 1]."""
    pil_image = pil_image.convert("RGB")
    transform = transforms.ToTensor()
    # ToTensor yields (C, H, W); permute to channels-last.
    img_tensor = transform(pil_image).permute(1, 2, 0)[..., :3]
    return img_tensor


@spaces.GPU
def fit_image(pil_image: Image.Image, num_points: int, iterations: int, lr: float, model_type: str):
    """Fit random Gaussians to ``pil_image``.

    Args:
        pil_image: Target image (converted to RGB).
        num_points: Number of Gaussians; cast to ``int``.
        iterations: Optimisation steps; cast to ``int``.
        lr: Adam learning rate; cast to ``float``.
        model_type: ``"3dgs"`` or ``"2dgs"``.

    Returns:
        ``(final_image, gif_path)`` on success, where ``final_image`` is a PIL
        image. Returns ``(None, None)`` on any failure, after printing the error
        and its traceback.
    """
    try:
        gt_image = image_to_tensor(pil_image)
        # UI sliders may deliver floats; torch sizes and range() require ints.
        trainer = SimpleTrainer(gt_image=gt_image, num_points=int(num_points))
        out_img, gif_path = trainer.train(
            iterations=int(iterations),
            lr=float(lr),
            model_type=model_type,
            save_imgs=True,
        )
        final_image = Image.fromarray(_float_to_uint8(out_img))
        return final_image, gif_path
    except Exception as e:
        # Best-effort boundary: report, with traceback, and signal failure via (None, None).
        print(f"Error during image fitting: {e}")
        traceback.print_exc()
        return None, None


@spaces.GPU
def gradio_fit(pil_image, num_points, iterations, lr, model_type):
    """Gradio wrapper: runs image fitting and returns the final image and training GIF.

    Raises:
        gr.Error: If no image was supplied or fitting failed. Gradio shows the
            message to the user. Returning a plain string instead would be
            routed into the image output component rather than displayed.
    """
    if pil_image is None:
        raise gr.Error("Please upload an input image.")
    final_image, gif_path = fit_image(pil_image, num_points, iterations, lr, model_type)
    if final_image is None:
        raise gr.Error("An error occurred during image fitting.")
    return final_image, gif_path


# Set up the Gradio interface.
# Two outputs: one for the final fitted image, and one for the training GIF.
def _build_interface() -> gr.Interface:
    """Assemble the Gradio UI.

    Inputs are the image plus three training sliders and a model-type dropdown.
    Outputs are the fitted image and the recorded training GIF (as a file path).
    """
    input_components = [
        gr.Image(type="pil", label="Input Image"),
        gr.Slider(
            minimum=1_000,
            maximum=1_000_000,
            step=1_000,
            value=5_000,
            label="Number of Points",
        ),
        gr.Slider(minimum=10, maximum=500, step=10, value=100, label="Iterations"),
        gr.Slider(
            minimum=0.001,
            maximum=0.1,
            step=0.001,
            value=0.01,
            label="Learning Rate",
        ),
        gr.Dropdown(choices=["3dgs", "2dgs"], value="3dgs", label="Model Type"),
    ]
    output_components = [
        gr.Image(type="pil", label="Fitted Image"),
        gr.Image(type="filepath", label="Training GIF"),
    ]
    return gr.Interface(
        fn=gradio_fit,
        inputs=input_components,
        outputs=output_components,
        title="Image Fitting with gsplat",
        description="Fit an image using random Gaussians with gsplat. The training process is recorded as a GIF.",
    )


iface = _build_interface()

if __name__ == "__main__":
    iface.launch()