# gsplat2d / app.py
# fr1ll's picture
# modify for zero gpu by adding spaces decorator
# efef35e
import math
import os
import time
import spaces
from pathlib import Path
import numpy as np
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# Import the rasterization functions from gsplat.
from gsplat import rasterization, rasterization_2dgs
class SimpleTrainer:
    """Trains random Gaussians to fit an image and saves a GIF of the training process.

    The target image is moved to ``cuda:0`` when CUDA is available, otherwise
    it stays on the CPU.  All Gaussian parameters (means, scales, colours,
    rotations, opacities) are created on that same device and optimised
    directly with Adam in :meth:`train`.
    """

    def __init__(self, gt_image: torch.Tensor, num_points: int = 2000):
        """Store the target image and create the initial random Gaussians.

        Args:
            gt_image: Target image; its first two dimensions are read as
                height and width.
            num_points: Number of Gaussians to create.
        """
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.gt_image = gt_image.to(device=self.device)
        self.num_points = num_points

        # Focal length (in pixels) derived from a fixed 90 degree horizontal FOV.
        fov_x = math.pi / 2.0
        self.H, self.W = gt_image.shape[0], gt_image.shape[1]
        self.focal = 0.5 * float(self.W) / math.tan(0.5 * fov_x)
        self.img_size = torch.tensor([self.W, self.H, 1], device=self.device)

        self._init_gaussians()

    def _init_gaussians(self):
        """Initialize random Gaussians."""
        bd = 2  # means are drawn uniformly from [-bd/2, bd/2) on every axis
        d = 3  # number of colour channels

        # NOTE: the order of the random draws is kept stable so that a seeded
        # run produces the same initialisation as before.
        self.means = bd * (torch.rand(self.num_points, 3, device=self.device) - 0.5)
        self.scales = torch.rand(self.num_points, 3, device=self.device)
        self.rgbs = torch.rand(self.num_points, d, device=self.device)

        # Build unit-length quaternions from three uniform samples.
        u = torch.rand(self.num_points, 1, device=self.device)
        v = torch.rand(self.num_points, 1, device=self.device)
        w = torch.rand(self.num_points, 1, device=self.device)
        self.quats = torch.cat(
            [
                torch.sqrt(1.0 - u) * torch.sin(2.0 * math.pi * v),
                torch.sqrt(1.0 - u) * torch.cos(2.0 * math.pi * v),
                torch.sqrt(u) * torch.sin(2.0 * math.pi * w),
                torch.sqrt(u) * torch.cos(2.0 * math.pi * w),
            ],
            -1,
        )
        self.opacities = torch.ones((self.num_points), device=self.device)

        # Fixed view matrix: identity rotation with a translation of 8 along z.
        self.viewmat = torch.tensor(
            [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 8.0],
                [0.0, 0.0, 0.0, 1.0],
            ],
            device=self.device,
        )
        self.background = torch.zeros(d, device=self.device)

        # Everything except the camera is a trainable leaf tensor.
        self.means.requires_grad = True
        self.scales.requires_grad = True
        self.quats.requires_grad = True
        self.rgbs.requires_grad = True
        self.opacities.requires_grad = True
        self.viewmat.requires_grad = False

    @staticmethod
    def _to_uint8_frame(img: torch.Tensor) -> np.ndarray:
        """Convert a float image tensor to a ``uint8`` array for the GIF.

        Values are clamped to [0, 1] before scaling; without the clamp any
        value outside that range wraps around when cast to ``uint8``.
        """
        arr = img.detach().cpu().numpy()
        return (np.clip(arr, 0.0, 1.0) * 255).astype(np.uint8)

    def train(self, iterations: int = 100, lr: float = 0.01, model_type: str = "3dgs", save_imgs: bool = False):
        """Optimise the Gaussians against the target image.

        Args:
            iterations: Number of optimisation steps.
            lr: Adam learning rate.
            model_type: ``"3dgs"`` or ``"2dgs"``; selects the rasterizer.
            save_imgs: When true, every 5th render is kept and written out
                as ``results/training.gif`` under the current working
                directory.

        Returns:
            ``(final_out, gif_path)``: the last rendered image as a NumPy
            array (``None`` if no iteration ran) and the path of the saved
            GIF (``None`` if no GIF was written).

        Raises:
            ValueError: If ``model_type`` is not ``"3dgs"`` or ``"2dgs"``.
        """
        # Validate the request before allocating the optimizer.
        if model_type == "3dgs":
            rasterize_fnc = rasterization
        elif model_type == "2dgs":
            rasterize_fnc = rasterization_2dgs
        else:
            raise ValueError("Invalid model type. Choose '3dgs' or '2dgs'.")

        optimizer = optim.Adam(
            [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
        )
        mse_loss = torch.nn.MSELoss()
        frames = []  # to store frames for the GIF

        # Define camera intrinsics
        K = torch.tensor(
            [
                [self.focal, 0, self.W / 2],
                [0, self.focal, self.H / 2],
                [0, 0, 1],
            ],
            device=self.device,
        )

        out_img = None
        for step in range(iterations):
            # Render current Gaussians
            renders = rasterize_fnc(
                self.means,
                self.quats / self.quats.norm(dim=-1, keepdim=True),
                self.scales,
                torch.sigmoid(self.opacities),
                torch.sigmoid(self.rgbs),
                self.viewmat[None],
                K[None],
                self.W,
                self.H,
                packed=False,
            )[0]
            out_img = renders[0]

            loss = mse_loss(out_img, self.gt_image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Log progress less frequently to reduce verbosity.
            if (step + 1) % 10 == 0 or step == 0 or (step + 1) == iterations:
                print(f"Iteration {step+1}/{iterations}, Loss: {loss.item()}")

            # Collect frames for the GIF every 5 iterations.
            if save_imgs and (step % 5 == 0):
                frames.append(self._to_uint8_frame(out_img))

        # Copy the last render to the host once, instead of on every step.
        final_out = None if out_img is None else out_img.detach().cpu().numpy()

        gif_path = None
        # Create the GIF synchronously (waits for completion)
        if save_imgs and frames:
            frames_pil = [Image.fromarray(frame) for frame in frames]
            out_dir = os.path.join(os.getcwd(), "results")
            os.makedirs(out_dir, exist_ok=True)
            gif_path = os.path.join(out_dir, "training.gif")
            frames_pil[0].save(
                gif_path,
                save_all=True,
                append_images=frames_pil[1:],
                optimize=False,
                duration=50,  # duration in ms between frames
                loop=0,
            )
            print(f"GIF saved at {gif_path}")

        return final_out, gif_path
@spaces.GPU
def image_to_tensor(pil_image: Image.Image) -> torch.Tensor:
    """Turn a PIL image into an (H, W, 3) float tensor with values in [0, 1].

    The image is first converted to RGB, run through torchvision's
    ``ToTensor`` and then reordered from channel-first to channel-last.
    """
    rgb = pil_image.convert("RGB")
    channels_first = transforms.ToTensor()(rgb)
    channels_last = channels_first.permute(1, 2, 0)
    # Keep at most the first three channels.
    return channels_last[..., :3]
@spaces.GPU
def fit_image(pil_image: Image.Image, num_points: int, iterations: int, lr: float, model_type: str):
    """Fit Gaussians to ``pil_image`` and return ``(fitted PIL image, GIF path)``.

    Training always runs with frame capture enabled.  Any exception raised
    along the way is printed and reported to the caller as ``(None, None)``.
    """
    try:
        target = image_to_tensor(pil_image)
        fitter = SimpleTrainer(gt_image=target, num_points=num_points)
        rendered, gif_path = fitter.train(
            iterations=iterations,
            lr=lr,
            model_type=model_type,
            save_imgs=True,
        )
        # Clamp to [0, 1] before scaling to 8-bit.
        pixels = (np.clip(rendered, 0, 1) * 255).astype(np.uint8)
        return Image.fromarray(pixels), gif_path
    except Exception as e:
        print(f"Error during image fitting: {e}")
        return None, None
@spaces.GPU
def gradio_fit(pil_image, num_points, iterations, lr, model_type):
    """Gradio wrapper: runs image fitting and returns the final image and training GIF.

    Returns:
        ``(final_image, gif_path)`` as produced by ``fit_image``.

    Raises:
        gr.Error: If no input image was provided or the fitting failed.  The
            first output is an image component, so a plain error string
            cannot be returned there; raising ``gr.Error`` lets Gradio show
            the message to the user instead.
    """
    if pil_image is None:
        raise gr.Error("Please provide an input image.")

    final_image, gif_path = fit_image(pil_image, num_points, iterations, lr, model_type)
    if final_image is None:
        raise gr.Error("An error occurred during image fitting.")
    return final_image, gif_path
# Build the Gradio UI: five inputs (image plus training settings) feeding
# gradio_fit, and two outputs (the fitted image and the training GIF).
iface = gr.Interface(
    title="Image Fitting with gsplat",
    description="Fit an image using random Gaussians with gsplat. The training process is recorded as a GIF.",
    fn=gradio_fit,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Slider(label="Number of Points", value=5_000, minimum=1_000, maximum=1_000_000, step=1_000),
        gr.Slider(label="Iterations", value=100, minimum=10, maximum=500, step=10),
        gr.Slider(label="Learning Rate", value=0.01, minimum=0.001, maximum=0.1, step=0.001),
        gr.Dropdown(label="Model Type", value="3dgs", choices=["3dgs", "2dgs"]),
    ],
    outputs=[
        gr.Image(label="Fitted Image", type="pil"),
        gr.Image(label="Training GIF", type="filepath"),
    ],
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()