InteriorFusion / src /interiorfusion /utils /gaussian_utils.py
stevee00's picture
Upload src/interiorfusion/utils/gaussian_utils.py
4db20e4 verified
"""Gaussian Splatting utilities."""
import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
def export_gaussian_splatting(
    gaussian_cloud: Union[np.ndarray, torch.Tensor],
    output_path: Union[str, Path],
) -> str:
    """
    Export Gaussian Splatting representation to PLY format.

    Uses ``plyfile`` (ASCII encoding) when it is importable, otherwise falls
    back to ``_write_ascii_ply``. Missing parent directories of
    ``output_path`` are created.

    Args:
        gaussian_cloud: [N, 14] array/tensor with columns:
            [x, y, z, scale_x, scale_y, scale_z,
             rot_qx, rot_qy, rot_qz, rot_qw,
             r, g, b, opacity]
        output_path: Destination ``.ply`` file.

    Returns:
        ``output_path`` as a string.

    Raises:
        ValueError: If ``gaussian_cloud`` is not a 2-D array with 14 columns.
    """
    # PLY vertex property names, in the same order as the input columns.
    # NOTE(review): columns are written verbatim. The reference 3DGS PLY
    # layout usually stores rot_0 as the scalar (w) quaternion component,
    # f_dc_* as SH DC coefficients rather than RGB, scales in log space and
    # opacity as a logit -- if standard 3DGS viewers are the target, confirm
    # the caller has already converted to those conventions.
    field_names = (
        'x', 'y', 'z',
        'scale_0', 'scale_1', 'scale_2',
        'rot_0', 'rot_1', 'rot_2', 'rot_3',
        'f_dc_0', 'f_dc_1', 'f_dc_2',
        'opacity',
    )

    if isinstance(gaussian_cloud, torch.Tensor):
        # detach() first: .numpy() refuses tensors that require grad.
        gaussian_cloud = gaussian_cloud.detach().cpu().numpy()
    gaussian_cloud = np.asarray(gaussian_cloud)

    # Validate before touching the filesystem so bad input leaves no trace.
    if gaussian_cloud.ndim != 2 or gaussian_cloud.shape[1] != len(field_names):
        raise ValueError(
            f"gaussian_cloud must have shape [N, {len(field_names)}], "
            f"got {tuple(gaussian_cloud.shape)}"
        )

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Only the import is guarded: an ImportError raised while writing must
    # not silently switch to the fallback writer.
    try:
        from plyfile import PlyData, PlyElement
    except ImportError:
        # Fallback: write simple ASCII PLY
        _write_ascii_ply(gaussian_cloud, output_path)
        return str(output_path)

    vertices = np.zeros(
        len(gaussian_cloud), dtype=[(name, 'f4') for name in field_names]
    )
    for column, name in enumerate(field_names):
        vertices[name] = gaussian_cloud[:, column]

    el = PlyElement.describe(vertices, 'vertex')
    PlyData([el], text=True).write(str(output_path))
    return str(output_path)
def _write_ascii_ply(gaussian_cloud: np.ndarray, output_path: Path):
    """Write simple ASCII PLY fallback.

    Used by ``export_gaussian_splatting`` when ``plyfile`` cannot be
    imported. The header declares 14 float vertex properties; each row of
    ``gaussian_cloud`` becomes one vertex line with every value formatted as
    ``%.6f`` and separated by single spaces.

    Args:
        gaussian_cloud: [N, 14] array; columns are written verbatim in order.
        output_path: Destination file. Its parent directory must already
            exist, because ``open`` does not create it.

    Raises:
        ValueError: If ``gaussian_cloud`` is not 2-D with exactly 14 columns.
            Without this check the vertex rows would disagree with the header.
    """
    property_names = (
        'x', 'y', 'z',
        'scale_0', 'scale_1', 'scale_2',
        'rot_0', 'rot_1', 'rot_2', 'rot_3',
        'f_dc_0', 'f_dc_1', 'f_dc_2',
        'opacity',
    )
    cloud = np.asarray(gaussian_cloud)
    if cloud.ndim != 2 or cloud.shape[1] != len(property_names):
        raise ValueError(
            f"gaussian_cloud must have shape [N, {len(property_names)}], "
            f"got {tuple(cloud.shape)}"
        )

    header_lines = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(cloud)}",
        *(f"property float {name}" for name in property_names),
        "end_header",
    ]
    with open(output_path, 'w', encoding='ascii') as f:
        f.write("\n".join(header_lines) + "\n")
        # One vectorized call instead of a Python-level write per point.
        # '%.6f' produces the same text as the old f"{v:.6f}" formatting.
        np.savetxt(f, cloud, fmt="%.6f", delimiter=" ")
def render_gaussian_splatting(
    gaussian_cloud: torch.Tensor,
    camera_pose: torch.Tensor,
    image_size: Tuple[int, int] = (512, 512),
    fov: float = 50.0,
) -> torch.Tensor:
    """
    Render Gaussian Splatting from given camera pose (placeholder).

    This is a stub, not a renderer: it always returns an all-zero (black)
    image. ``camera_pose`` and ``fov`` are accepted for API compatibility
    but are not read.

    Args:
        gaussian_cloud: [N, 14] tensor; only its ``device`` is used.
        camera_pose: [4, 4] camera-to-world transform (currently unused).
        image_size: (H, W) of the returned image.
        fov: Field of view in degrees (currently unused).

    Returns:
        Zero image of shape [3, H, W] on ``gaussian_cloud.device``.
    """
    # TODO: replace with gsplat or another differentiable rasterizer.
    height, width = image_size
    target_device = gaussian_cloud.device
    return torch.zeros((3, height, width), device=target_device)
def merge_gaussian_clouds(
    clouds: List[torch.Tensor],
) -> torch.Tensor:
    """Merge multiple Gaussian clouds into one.

    Args:
        clouds: Tensors to concatenate along dim 0 (rows = Gaussians).
            Zero-length clouds are skipped.

    Returns:
        The row-wise concatenation of the non-empty clouds. If ``clouds`` is
        an empty list, a CPU ``[0, 14]`` float tensor is returned. If every
        cloud is empty, an empty ``[0, 14]`` tensor is returned with the
        dtype and device of the first cloud, so later concatenation with
        same-device clouds still works.
    """
    if not clouds:
        return torch.zeros(0, 14)
    valid_clouds = [c for c in clouds if len(c) > 0]
    if not valid_clouds:
        # Preserve dtype/device instead of silently falling back to CPU fp32.
        return clouds[0].new_zeros((0, 14))
    return torch.cat(valid_clouds, dim=0)