Spaces: Running on Zero

Commit · 8f04a1a
Parent(s): 133857a

Initial commit

Browse files
- gslrm/model/gslrm.py +3 -917
- gslrm/model/utils_losses.py +0 -309
- splat_viewer.html +0 -277
gslrm/model/gslrm.py CHANGED
@@ -22,11 +22,8 @@ Classes:
 """
 
 import copy
-import os
-import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple
 
-import cv2
 import lpips
 import numpy as np
 import torch
@@ -35,17 +32,13 @@ import torch.nn.functional as F
 from easydict import EasyDict as edict
 from einops import rearrange
 from einops.layers.torch import Rearrange
-from PIL import Image
 
 # Local imports
 from .utils_losses import PerceptualLoss, SsimLoss
 from .gaussians_renderer import (
     GaussianModel,
-    RGB2SH,
     deferred_gaussian_render,
-    imageseq2video,
     render_opencv_cam,
-    render_turntable,
 )
 from .transform_data import SplitData, TransformInput, TransformTarget
 from .utils_transformer import (
@@ -225,238 +218,6 @@ class GaussiansUpsampler(nn.Module):
 
         return xyz, features, scaling, rotation, opacity
 
-
-class LossComputer(nn.Module):
-    """
-    Computes various loss functions for training the GSLRM model.
-
-    Supports multiple loss types:
-    - L2 (MSE) loss
-    - LPIPS perceptual loss
-    - Custom perceptual loss
-    - SSIM loss
-    - Pixel alignment loss
-    - Point distance regularization loss
-    """
-
-    def __init__(self, config: edict):
-        super().__init__()
-        self.config = config
-
-        # Initialize loss modules based on config
-        self._init_loss_modules()
-
-    def _init_loss_modules(self):
-        """Initialize the various loss computation modules."""
-        # LPIPS loss
-        if self.config.training.losses.lpips_loss_weight > 0.0:
-            self.lpips_loss_module = lpips.LPIPS(net="vgg")
-            self.lpips_loss_module.eval()
-            # Freeze LPIPS parameters
-            for param in self.lpips_loss_module.parameters():
-                param.requires_grad = False
-
-        # Perceptual loss
-        if self.config.training.losses.perceptual_loss_weight > 0.0:
-            self.perceptual_loss_module = PerceptualLoss()
-            self.perceptual_loss_module.eval()
-            # Freeze perceptual loss parameters
-            for param in self.perceptual_loss_module.parameters():
-                param.requires_grad = False
-
-        # SSIM loss
-        if self.config.training.losses.ssim_loss_weight > 0.0:
-            self.ssim_loss_module = SsimLoss()
-            self.ssim_loss_module.eval()
-            # Freeze SSIM parameters
-            for param in self.ssim_loss_module.parameters():
-                param.requires_grad = False
-
-    def forward(
-        self,
-        rendering: torch.Tensor,  # [b, v, 3, h, w]
-        target: torch.Tensor,  # [b, v, 3, h, w]
-        img_aligned_xyz: torch.Tensor,  # [b, v, 3, h, w]
-        input: edict,
-        result_softpa: Optional[edict] = None,
-        create_visual: bool = False,
-    ) -> edict:
-        """
-        Compute all losses between rendered and target images.
-
-        Args:
-            rendering: Rendered images in range [0, 1]
-            target: Target images in range [0, 1]
-            img_aligned_xyz: Image-aligned 3D positions
-            input: Input data containing ray information
-            result_softpa: Additional results (unused)
-            create_visual: Whether to create visualization images
-
-        Returns:
-            Dictionary containing all loss values and metrics
-        """
-        b, v, _, h, w = rendering.size()
-        rendering_flat = rendering.reshape(b * v, -1, h, w)
-        target_flat = target.reshape(b * v, -1, h, w)
-
-        # Handle alpha channel if present
-        mask = None
-        if target_flat.size(1) == 4:
-            target_flat, mask = target_flat.split([3, 1], dim=1)
-
-        # Compute individual losses
-        losses = self._compute_all_losses(
-            rendering_flat, target_flat, img_aligned_xyz, input, mask, b, v, h, w
-        )
-
-        # Compute total weighted loss
-        total_loss = self._compute_total_loss(losses)
-
-        # Create visualization if requested
-        visual = self._create_visual(rendering_flat, target_flat, v) if create_visual else None
-
-        # Compile loss metrics
-        return self._compile_loss_metrics(losses, total_loss, visual)
-
-    def _compute_all_losses(self, rendering, target, img_aligned_xyz, input, mask, b, v, h, w):
-        """Compute all individual loss components."""
-        losses = {}
-
-        # L2 (MSE) loss
-        losses['l2'] = self._compute_l2_loss(rendering, target)
-        losses['psnr'] = -10.0 * torch.log10(losses['l2'])
-
-        # LPIPS loss
-        losses['lpips'] = self._compute_lpips_loss(rendering, target)
-
-        # Perceptual loss
-        losses['perceptual'] = self._compute_perceptual_loss(rendering, target)
-
-        # SSIM loss
-        losses['ssim'] = self._compute_ssim_loss(rendering, target)
-
-        # Pixel alignment loss
-        losses['pixelalign'] = self._compute_pixelalign_loss(
-            img_aligned_xyz, input, mask, b, v, h, w
-        )
-
-        # Point distance loss
-        losses['pointsdist'] = self._compute_pointsdist_loss(
-            img_aligned_xyz, input, b, v, h, w
-        )
-
-        return losses
-
-    def _compute_l2_loss(self, rendering, target):
-        """Compute L2 (MSE) loss."""
-        if self.config.training.losses.l2_loss_weight > 0.0:
-            return F.mse_loss(rendering, target)
-        return torch.tensor(1e-8, device=rendering.device)
-
-    def _compute_lpips_loss(self, rendering, target):
-        """Compute LPIPS perceptual loss."""
-        if self.config.training.losses.lpips_loss_weight > 0.0:
-            # LPIPS expects inputs in range [-1, 1]
-            return self.lpips_loss_module(
-                rendering * 2.0 - 1.0, target * 2.0 - 1.0
-            ).mean()
-        return torch.tensor(0.0, device=rendering.device)
-
-    def _compute_perceptual_loss(self, rendering, target):
-        """Compute custom perceptual loss."""
-        if self.config.training.losses.perceptual_loss_weight > 0.0:
-            return self.perceptual_loss_module(rendering, target)
-        return torch.tensor(0.0, device=rendering.device)
-
-    def _compute_ssim_loss(self, rendering, target):
-        """Compute SSIM loss."""
-        if self.config.training.losses.ssim_loss_weight > 0.0:
-            return self.ssim_loss_module(rendering, target)
-        return torch.tensor(0.0, device=rendering.device)
-
-    def _compute_pixelalign_loss(self, img_aligned_xyz, input, mask, b, v, h, w):
-        """Compute pixel alignment loss."""
-        if self.config.training.losses.pixelalign_loss_weight > 0.0:
-            # Compute orthogonal component to ray direction
-            xyz_vec = img_aligned_xyz - input.ray_o
-            ortho_vec = (
-                xyz_vec
-                - torch.sum(xyz_vec.detach() * input.ray_d, dim=2, keepdim=True)
-                * input.ray_d
-            )
-
-            # Apply mask if enabled
-            if self.config.training.losses.get("masked_pixelalign_loss", False):
-                assert mask is not None, "mask is None but masked_pixelalign_loss is enabled"
-                mask_reshaped = mask.view(b, v, 1, h, w)
-                ortho_vec = ortho_vec * mask_reshaped
-
-            return torch.mean(ortho_vec.norm(dim=2, p=2))
-
-        return torch.tensor(0.0, device=img_aligned_xyz.device)
-
-    def _compute_pointsdist_loss(self, img_aligned_xyz, input, b, v, h, w):
-        """Compute point distance regularization loss."""
-        if self.config.training.losses.pointsdist_loss_weight > 0.0:
-            # Target mean distance (distance from origin to ray origin)
-            target_mean_dist = torch.norm(input.ray_o, dim=2, p=2, keepdim=True)
-            target_std_dist = 0.5
-
-            # Predicted distance
-            pred_dist = (img_aligned_xyz - input.ray_o).norm(dim=2, p=2, keepdim=True)
-
-            # Normalize to target distribution
-            pred_dist_detach = pred_dist.detach()
-            pred_mean = pred_dist_detach.mean(dim=(2, 3, 4), keepdim=True)
-            pred_std = pred_dist_detach.std(dim=(2, 3, 4), keepdim=True)
-
-            target_dist = (pred_dist_detach - pred_mean) / (pred_std + 1e-8) * target_std_dist + target_mean_dist
-
-            return torch.mean((pred_dist - target_dist) ** 2)
-
-        return torch.tensor(0.0, device=img_aligned_xyz.device)
-
-    def _compute_total_loss(self, losses):
-        """Compute weighted sum of all losses."""
-        weights = self.config.training.losses
-        return (
-            weights.l2_loss_weight * losses['l2']
-            + weights.lpips_loss_weight * losses['lpips']
-            + weights.perceptual_loss_weight * losses['perceptual']
-            + weights.ssim_loss_weight * losses['ssim']
-            + weights.pixelalign_loss_weight * losses['pixelalign']
-            + weights.pointsdist_loss_weight * losses['pointsdist']
-        )
-
-    def _create_visual(self, rendering, target, v):
-        """Create visualization by concatenating target and rendering."""
-        visual = torch.cat((target, rendering), dim=3).detach().cpu()  # [b*v, c, h, w*2]
-        visual = rearrange(visual, "(b v) c h (m w) -> (b h) (v m w) c", v=v, m=2)
-        return (visual.numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-
-    def _compile_loss_metrics(self, losses, total_loss, visual):
-        """Compile all loss metrics into a dictionary."""
-        l2_loss = losses['l2']
-
-        return edict(
-            loss=total_loss,
-            l2_loss=l2_loss,
-            psnr=losses['psnr'],
-            lpips_loss=losses['lpips'],
-            perceptual_loss=losses['perceptual'],
-            ssim_loss=losses['ssim'],
-            pixelalign_loss=losses['pixelalign'],
-            pointsdist_loss=losses['pointsdist'],
-            visual=visual,
-            # Normalized losses for logging
-            norm_perceptual_loss=losses['perceptual'] / l2_loss,
-            norm_lpips_loss=losses['lpips'] / l2_loss,
-            norm_ssim_loss=losses['ssim'] / l2_loss,
-            norm_pixelalign_loss=losses['pixelalign'] / l2_loss,
-            norm_pointsdist_loss=losses['pointsdist'] / l2_loss,
-        )
-
-
 class GSLRM(nn.Module):
     """
     Gaussian Splatting Large Reconstruction Model.
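For readers tracing what this deleted class expected, here is a minimal sketch (an illustration, not part of the commit) of the config the removed LossComputer consumed. The weight field names are copied from the diff above; the tensor shapes follow the `[b, v, 3, h, w]` comments on the forward signature.

```python
# Hypothetical usage sketch for the removed LossComputer; not part of this commit.
import torch
from easydict import EasyDict as edict

config = edict(
    training=edict(
        losses=edict(
            l2_loss_weight=1.0,
            lpips_loss_weight=0.0,        # 0.0 skips building the LPIPS VGG module
            perceptual_loss_weight=0.0,
            ssim_loss_weight=0.0,
            pixelalign_loss_weight=0.0,
            pointsdist_loss_weight=0.0,
        )
    )
)
# loss_computer = LossComputer(config)  # the class this commit removes

# With only the L2 weight active, the total loss reduces to the MSE term,
# and the logged PSNR is -10 * log10(mse), as in _compute_all_losses above.
b, v, h, w = 1, 2, 8, 8
rendering = torch.rand(b, v, 3, h, w)
target = torch.rand(b, v, 3, h, w)
mse = torch.nn.functional.mse_loss(rendering, target)
psnr = -10.0 * torch.log10(mse)
print(f"mse={mse.item():.4f}, psnr={psnr.item():.2f} dB")
```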
@@ -575,7 +336,6 @@ class GSLRM(nn.Module):
     def _init_rendering_modules(self, config: edict) -> None:
         """Initialize rendering and loss computation modules."""
         self.gaussian_renderer = Renderer(config)
-        self.loss_calculator = LossComputer(config)
 
     def _init_training_state(self, config: edict) -> None:
         """Initialize training state management variables."""
@@ -584,101 +344,6 @@
         self.training_max_step = None
         self.original_config = copy.deepcopy(config)
 
-    def set_training_step(self, current_step: int, start_step: int, max_step: int) -> None:
-        """
-        Update training step and dynamically adjust configuration based on training phase.
-
-        Args:
-            current_step: Current training step
-            start_step: Starting step of training
-            max_step: Maximum training steps
-        """
-        self.training_step = current_step
-        self.training_start_step = start_step
-        self.training_max_step = max_step
-
-        # Determine if config modification is needed based on warmup settings
-        needs_config_modification = self._should_modify_config_for_warmup(current_step)
-
-        if needs_config_modification:
-            # Always use original config as base for modifications
-            self.config = copy.deepcopy(self.original_config)
-            self._apply_warmup_modifications(current_step)
-        else:
-            # Restore original configuration
-            self.config = self.original_config
-
-        # Update loss calculator with current config
-        self.loss_calculator.config = self.config
-
-    def _should_modify_config_for_warmup(self, current_step: int) -> bool:
-        """Check if configuration should be modified for warmup phases."""
-        pointsdist_warmup = (
-            self.config.training.losses.get("warmup_pointsdist", False)
-            and current_step < 1000
-        )
-        l2_warmup = (
-            self.config.training.schedule.get("l2_warmup_steps", 0) > 0
-            and current_step < self.config.training.schedule.l2_warmup_steps
-        )
-        return pointsdist_warmup or l2_warmup
-
-    def _apply_warmup_modifications(self, current_step: int) -> None:
-        """Apply configuration modifications for warmup phases."""
-        # Point distance warmup phase
-        if (self.config.training.losses.get("warmup_pointsdist", False)
-                and current_step < 1000):
-            self.config.training.losses.l2_loss_weight = 0.0
-            self.config.training.losses.perceptual_loss_weight = 0.0
-            self.config.training.losses.pointsdist_loss_weight = 0.1
-            self.config.model.clip_xyz = False  # Disable xyz clipping during warmup
-
-        # L2 loss warmup phase
-        if (self.config.training.schedule.get("l2_warmup_steps", 0) > 0
-                and current_step < self.config.training.schedule.l2_warmup_steps):
-            self.config.training.losses.perceptual_loss_weight = 0.0
-            self.config.training.losses.lpips_loss_weight = 0.0
-
-    def set_current_step(self, current_step: int, start_step: int, max_step: int) -> None:
-        """Backward compatibility wrapper for set_training_step."""
-        self.set_training_step(current_step, start_step, max_step)
-
-    def train(self, mode: bool = True) -> None:
-        """
-        Override train method to keep frozen modules in eval mode.
-
-        Args:
-            mode: Whether to set training mode (True) or evaluation mode (False)
-        """
-        super().train(mode)
-        # Keep loss calculator in eval mode to prevent training of frozen components
-        if self.loss_calculator is not None:
-            self.loss_calculator.eval()
-
-    def get_parameter_overview(self) -> edict:
-        """
-        Get overview of trainable parameters in each module.
-
-        Returns:
-            Dictionary containing parameter counts for each major component
-        """
-        def count_trainable_params(module: nn.Module) -> int:
-            return sum(p.numel() for p in module.parameters() if p.requires_grad)
-
-        return edict(
-            patch_embedder=count_trainable_params(self.patch_embedder),
-            gaussian_position_embeddings=self.gaussian_position_embeddings.data.numel(),
-            transformer_total=(
-                count_trainable_params(self.transformer_layers) +
-                count_trainable_params(self.input_layer_norm)
-            ),
-            gaussian_upsampler=count_trainable_params(self.gaussian_upsampler),
-            pixel_gaussian_decoder=count_trainable_params(self.pixel_gaussian_decoder),
-        )
-
-    def get_overview(self) -> edict:
-        """Backward compatibility wrapper for get_parameter_overview."""
-        return self.get_parameter_overview()
 
     def _create_transformer_layer_runner(self, start_layer: int, end_layer: int):
         """
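To make the warmup gating removed above easier to follow, here is a standalone sketch (an illustration, not code from the repository) of the same decision logic; the 1000-step point-distance threshold and the `l2_warmup_steps` key come straight from the deleted methods.

```python
# Standalone re-statement of the deleted _should_modify_config_for_warmup logic.
def should_modify_config_for_warmup(losses: dict, schedule: dict, current_step: int) -> bool:
    # Point-distance warmup runs for the first 1000 steps when enabled.
    pointsdist_warmup = losses.get("warmup_pointsdist", False) and current_step < 1000
    # L2 warmup runs until l2_warmup_steps when that key is set and positive.
    l2_warmup = (
        schedule.get("l2_warmup_steps", 0) > 0
        and current_step < schedule.get("l2_warmup_steps", 0)
    )
    return pointsdist_warmup or l2_warmup

assert should_modify_config_for_warmup({"warmup_pointsdist": True}, {}, 500)
assert not should_modify_config_for_warmup({"warmup_pointsdist": True}, {}, 1500)
assert should_modify_config_for_warmup({}, {"l2_warmup_steps": 2000}, 1999)
assert not should_modify_config_for_warmup({}, {"l2_warmup_steps": 2000}, 2500)
```

Note that during either warmup phase the deleted code rebuilt `self.config` from a deep copy of `original_config` before applying the weight overrides, so modifications never accumulated across steps.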
@@ -843,149 +508,6 @@
 
         return aligned_positions
 
-    @staticmethod
-    def translate_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-        """
-        Translate legacy model parameter names to new parameter names.
-
-        This function allows loading models saved with the old variable names
-        by mapping them to the new, cleaner variable names.
-
-        Args:
-            state_dict: Dictionary containing model parameters with old names
-
-        Returns:
-            Dictionary with parameters mapped to new names
-        """
-        # Define the mapping from old names to new names
-        name_mapping = {
-            # Data processors
-            'split_data.': 'data_splitter.',
-            'transform_input.': 'input_transformer.',
-            'transform_target.': 'target_transformer.',
-
-            # Tokenizer
-            'image_tokenizer.': 'patch_embedder.',
-
-            # Positional embeddings
-            'refsrc_marker': 'view_type_embeddings',
-            'gaussians_pos_embedding': 'gaussian_position_embeddings',
-
-            # Transformer
-            'transformer_input_layernorm.': 'input_layer_norm.',
-            'transformer.': 'transformer_layers.',
-
-            # Gaussian modules
-            'upsampler.': 'gaussian_upsampler.',
-            'image_token_decoder.': 'pixel_gaussian_decoder.',
-
-            # Rendering modules
-            'renderer.': 'gaussian_renderer.',
-            'loss_computer.': 'loss_calculator.',
-        }
-
-        # Create new state dict with translated names
-        new_state_dict = {}
-
-        for old_key, value in state_dict.items():
-            new_key = old_key
-
-            # Apply name mappings
-            for old_pattern, new_pattern in name_mapping.items():
-                if old_key.startswith(old_pattern):
-                    new_key = old_key.replace(old_pattern, new_pattern, 1)
-                    break
-
-            # Fix specific key naming issues
-            # Change loss_computer.perceptual_loss_module.Net to loss_computer.perceptual_loss_module.net
-            if "loss_computer.perceptual_loss_module.Net" in new_key:
-                old_net_key = new_key
-                new_key = new_key.replace("loss_computer.perceptual_loss_module.Net", "loss_computer.perceptual_loss_module.net")
-                print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
-            # Also handle the new naming convention
-            elif "loss_calculator.perceptual_loss_module.Net" in new_key:
-                old_net_key = new_key
-                new_key = new_key.replace("loss_calculator.perceptual_loss_module.Net", "loss_calculator.perceptual_loss_module.net")
-                print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
-
-            new_state_dict[new_key] = value
-
-        return new_state_dict
-
-    def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
-        """
-        Load model state dict with automatic legacy name translation.
-
-        Args:
-            state_dict: Model state dictionary (potentially with old parameter names)
-            strict: Whether to strictly enforce parameter name matching
-        """
-        # Check if this is a legacy state dict by looking for old parameter names
-        legacy_indicators = [
-            'image_tokenizer.',
-            'refsrc_marker',
-            'gaussians_pos_embedding',
-            'transformer_input_layernorm.',
-            'upsampler.',
-            'image_token_decoder.',
-            'renderer.',
-            'loss_computer.'
-        ]
-
-        is_legacy = any(
-            any(key.startswith(indicator) for key in state_dict.keys())
-            for indicator in legacy_indicators
-        )
-
-        if is_legacy:
-            print("Detected legacy model format. Translating parameter names...")
-            state_dict = self.translate_legacy_state_dict(state_dict)
-            print("Parameter name translation completed.")
-
-        # Load the (potentially translated) state dict
-        return super().load_state_dict(state_dict, strict=strict)
-
-    @classmethod
-    def load_from_checkpoint(
-        cls,
-        checkpoint_path: str,
-        config: edict,
-        map_location: Optional[str] = None
-    ) -> 'GSLRM':
-        """
-        Load model from checkpoint with automatic legacy name translation.
-
-        Args:
-            checkpoint_path: Path to the checkpoint file
-            config: Model configuration
-            map_location: Device to map tensors to (e.g., 'cpu', 'cuda:0')
-
-        Returns:
-            Loaded GSLRM model
-        """
-        # Create model instance
-        model = cls(config)
-
-        # Load checkpoint
-        checkpoint = torch.load(checkpoint_path, map_location=map_location)
-
-        # Extract state dict (handle different checkpoint formats)
-        if isinstance(checkpoint, dict):
-            if 'model_state_dict' in checkpoint:
-                state_dict = checkpoint['model_state_dict']
-            elif 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-        else:
-            state_dict = checkpoint
-
-        # Load state dict with automatic translation
-        model.load_state_dict(state_dict)
-
-        print(f"Successfully loaded model from {checkpoint_path}")
-        return model
-
     def _create_gaussian_models_and_stats(
         self,
         xyz: torch.Tensor,
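The deleted checkpoint-translation path is essentially a first-match prefix rewrite. Below is a trimmed, self-contained sketch of that mechanism; only three of the mapping entries are reproduced, all copied from the diff above.

```python
# Trimmed sketch of the deleted translate_legacy_state_dict mechanism.
# Order matters: Python dicts preserve insertion order, and the original
# listed 'transformer_input_layernorm.' before 'transformer.', so the more
# specific prefix wins.
NAME_MAPPING = {
    "image_tokenizer.": "patch_embedder.",
    "transformer_input_layernorm.": "input_layer_norm.",
    "transformer.": "transformer_layers.",
}

def translate(state_dict: dict) -> dict:
    translated = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for old_prefix, new_prefix in NAME_MAPPING.items():
            if old_key.startswith(old_prefix):
                new_key = old_key.replace(old_prefix, new_prefix, 1)
                break  # first matching prefix only, as in the deleted code
        translated[new_key] = value
    return translated

print(translate({"image_tokenizer.proj.weight": "w", "transformer.0.attn.bias": "b"}))
# {'patch_embedder.proj.weight': 'w', 'transformer_layers.0.attn.bias': 'b'}
```

After the prefix pass, the deleted implementation also special-cased the `perceptual_loss_module.Net` to `perceptual_loss_module.net` rename for the perceptual-loss submodule.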
@@ -1180,7 +702,6 @@
         )
 
         # Perform rendering and loss computation if target data is available
-        loss_metrics = None
         rendered_images = None
 
         if target_data is not None:
@@ -1193,17 +714,6 @@
                 C2W=target_data.c2w,
                 fxfycxcy=target_data.fxfycxcy,
             )
-
-            # Compute losses if rendered and target have matching dimensions
-            if rendered_images.shape[1] == target_data.image.shape[1]:
-                loss_metrics = self.loss_calculator(
-                    rendered_images,
-                    target_data.image,
-                    pixel_aligned_xyz,
-                    input_data,
-                    create_visual=create_visual,
-                    result_softpa=gaussian_splat_result,
-                )
 
         # Create Gaussian models for each batch item and compute usage statistics
         gaussian_models, pixel_aligned_positions, usage_statistics = self._create_gaussian_models_and_stats(
@@ -1211,12 +721,6 @@
             num_pixel_aligned_gaussians, num_views, height, width, patch_size
        )
 
-        # Add usage statistics to loss metrics for logging
-        if loss_metrics is not None:
-            loss_metrics.gaussians_usage = torch.tensor(
-                np.mean(np.array(usage_statistics))
-            ).float()
-
         # Compile final results
         return edict(
             input=input_data,
@@ -1224,424 +728,6 @@
             gaussians=gaussian_models,
             pixelalign_xyz=pixel_aligned_positions,
             img_tokens=image_patch_tokens,
-            loss_metrics=loss_metrics,
+            loss_metrics=None,
             render=rendered_images,
-        )
-
-    @torch.no_grad()
-    def save_visualization_outputs(
-        self,
-        output_directory: str,
-        model_results: edict,
-        batch_data: edict,
-        save_all_items: bool = False
-    ) -> None:
-        """
-        Save visualization outputs including rendered images and Gaussian models.
-
-        Args:
-            output_directory: Directory to save outputs
-            model_results: Results from model forward pass
-            batch_data: Original batch data
-            save_all_items: Whether to save all batch items or just the first
-        """
-        os.makedirs(output_directory, exist_ok=True)
-
-        input_data, target_data = model_results.input, model_results.target
-
-        # Save supervision visualization if available
-        if (model_results.loss_metrics is not None and
-                model_results.loss_metrics.visual is not None):
-
-            batch_uids = [
-                target_data.index[b, 0, -1].item()
-                for b in range(target_data.index.size(0))
-            ]
-
-            uid_range = f"{batch_uids[0]:08}_{batch_uids[-1]:08}"
-
-            # Save supervision comparison image
-            Image.fromarray(model_results.loss_metrics.visual).save(
-                os.path.join(output_directory, f"supervision_{uid_range}.jpg")
-            )
-
-            # Save UIDs for reference
-            with open(os.path.join(output_directory, "uids.txt"), "w") as f:
-                uid_string = "_".join([f"{uid:08}" for uid in batch_uids])
-                f.write(uid_string)
-
-            # Save input images
-            input_visualization = rearrange(
-                input_data.image, "batch views channels height width -> (batch height) (views width) channels"
-            )
-            input_visualization = (
-                (input_visualization.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-            )
-            Image.fromarray(input_visualization[..., :3]).save(
-                os.path.join(output_directory, f"input_{uid_range}.jpg")
-            )
-
-        # Process each batch item individually
-        batch_size = input_data.image.size(0)
-        for batch_idx in range(batch_size):
-            item_uid = input_data.index[batch_idx, 0, -1].item()
-
-            # Render turntable visualization
-            turntable_image = render_turntable(model_results.gaussians[batch_idx])
-            Image.fromarray(turntable_image).save(
-                os.path.join(output_directory, f"turntable_{item_uid}.jpg")
-            )
-
-            # Save individual input images during inference
-            if self.config.inference:
-                individual_input = rearrange(
-                    input_data.image[batch_idx], "views channels height width -> height (views width) channels"
-                )
-                individual_input = (
-                    (individual_input.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-                )
-                Image.fromarray(individual_input[..., :3]).save(
-                    os.path.join(output_directory, f"input_{item_uid}.jpg")
-                )
-
-            # Extract image dimensions and create opacity/depth visualizations
-            _, num_views, _, img_height, img_width = input_data.image.size()
-            patch_size = self.config.model.image_tokenizer.patch_size
-
-            # Get opacity values for pixel-aligned Gaussians
-            gaussian_opacity = model_results.gaussians[batch_idx].get_opacity
-            pixel_opacity = gaussian_opacity[-num_views * img_height * img_width:]
-
-            # Reshape opacity to image format
-            opacity_visualization = rearrange(
-                pixel_opacity,
-                "(views height width patch_h patch_w) channels -> (height patch_h) (views width patch_w) channels",
-                views=num_views,
-                height=img_height // patch_size,
-                width=img_width // patch_size,
-                patch_h=patch_size,
-                patch_w=patch_size,
-            ).squeeze(-1).cpu().numpy()
-            opacity_visualization = (opacity_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
-
-            # Get 3D positions and compute depth visualization
-            gaussian_positions = model_results.gaussians[batch_idx].get_xyz
-            pixel_positions = gaussian_positions[-num_views * img_height * img_width:]
-
-            # Reshape positions to image format
-            pixel_positions_reshaped = rearrange(
-                pixel_positions,
-                "(views height width patch_h patch_w) coords -> views coords (height patch_h) (width patch_w)",
-                views=num_views,
-                height=img_height // patch_size,
-                width=img_width // patch_size,
-                patch_h=patch_size,
-                patch_w=patch_size,
-            )
-
-            # Compute distances from ray origins
-            ray_distances = (pixel_positions_reshaped - input_data.ray_o[batch_idx]).norm(dim=1, p=2)
-            distance_visualization = rearrange(ray_distances, "views height width -> height (views width)")
-            distance_visualization = distance_visualization.cpu().numpy()
-
-            # Normalize distances for visualization
-            dist_min, dist_max = distance_visualization.min(), distance_visualization.max()
-            distance_visualization = (distance_visualization - dist_min) / (dist_max - dist_min)
-            distance_visualization = (distance_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
-
-            # Combine opacity and depth visualizations
-            combined_visualization = np.concatenate([opacity_visualization, distance_visualization], axis=0)
-            Image.fromarray(combined_visualization).save(
-                os.path.join(output_directory, f"aligned_gs_opacity_depth_{item_uid}.jpg")
-            )
-
-            # Save unfiltered Gaussian model for small images during early training
-            if (self.config.model.image_tokenizer.image_size <= 256 and
-                    self.training_step is not None and self.training_step <= 5000):
-                model_results.gaussians[batch_idx].save_ply(
-                    os.path.join(output_directory, f"gaussians_{item_uid}_unfiltered.ply")
-                )
-
-            # Save filtered Gaussian model
-            camera_origins = None  # Could use input_data.ray_o[batch_idx, :, :, 0, 0] if needed
-            default_crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
-
-            model_results.gaussians[batch_idx].apply_all_filters(
-                opacity_thres=0.02,
-                crop_bbx=default_crop_box,
-                cam_origins=camera_origins,
-                nearfar_percent=(0.0001, 1.0),
-            ).save_ply(os.path.join(output_directory, f"gaussians_{item_uid}.ply"))
-
-            print(f"Saved visualization for UID: {item_uid}")
-
-            # Break after first item unless saving all
-            if not save_all_items:
-                break
-
-    @torch.no_grad()
-    def save_visuals(self, out_dir: str, result: edict, batch: edict, save_all: bool = False) -> None:
-        """Backward compatibility wrapper for save_visualization_outputs."""
-        self.save_visualization_outputs(out_dir, result, batch, save_all)
-
-    @torch.no_grad()
-    def save_evaluation_results(
-        self,
-        output_directory: str,
-        model_results: edict,
-        batch_data: edict,
-        dataset
-    ) -> None:
-        """Save comprehensive evaluation results including metrics, visualizations, and 3D models."""
-        from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
-
-        os.makedirs(output_directory, exist_ok=True)
-        input_data, target_data = model_results.input, model_results.target
-
-        for batch_idx in range(input_data.image.size(0)):
-            item_uid = input_data.index[batch_idx, 0, -1].item()
-            item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
-            os.makedirs(item_output_dir, exist_ok=True)
-
-            # Save input image
-            input_image = rearrange(
-                input_data.image[batch_idx], "views channels height width -> height (views width) channels"
-            )
-            input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-            Image.fromarray(input_image[..., :3]).save(os.path.join(item_output_dir, "input.png"))
-
-            # Save ground truth vs prediction comparison
-            comparison_image = torch.stack((target_data.image[batch_idx], model_results.render[batch_idx]), dim=0)
-            num_views = comparison_image.size(1)
-            if num_views > 10:
-                comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
-            comparison_image = rearrange(
-                comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
-            )
-            comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-            Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
-
-            # Compute and save metrics
-            per_view_psnr = compute_psnr(target_data.image[batch_idx], model_results.render[batch_idx])
-            per_view_lpips = compute_lpips(target_data.image[batch_idx], model_results.render[batch_idx])
-            per_view_ssim = compute_ssim(target_data.image[batch_idx], model_results.render[batch_idx])
-
-            # Save per-view metrics
-            view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
-            with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
-                for i in range(per_view_psnr.size(0)):
-                    f.write(
-                        f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
-                        f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
-                    )
-
-            # Save average metrics
-            avg_psnr = per_view_psnr.mean().item()
-            avg_lpips = per_view_lpips.mean().item()
-            avg_ssim = per_view_ssim.mean().item()
-
-            with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
-                f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
-
-            print(f"UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
-
-            # Save Gaussian model
-            crop_box = None
-            if self.config.model.get("clip_xyz", False):
-                if self.config.model.get("half_bbx_size", None) is not None:
-                    half_size = self.config.model.half_bbx_size
-                    crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
-                else:
-                    crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
-
-            model_results.gaussians[batch_idx].apply_all_filters(
-                opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
-            ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
-
-            # Create turntable visualization
-            num_turntable_views = 150
-            render_resolution = input_image.shape[0]
-
-            turntable_frames = render_turntable(
-                model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
-            )
-            turntable_frames = rearrange(
-                turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
-            )
-            turntable_frames = np.ascontiguousarray(turntable_frames)
-
-            # Save basic turntable video
-            imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
-
-            # Save description and preview if available
-            try:
-                description = dataset.get_description(item_uid)["prompt"]
-                if len(description) > 0:
-                    with open(os.path.join(item_output_dir, "description.txt"), "w") as f:
-                        f.write(description)
-
-                    # Create preview image (subsample to 10 views)
-                    preview_frames = turntable_frames[::num_turntable_views // 10]
-                    preview_image = rearrange(preview_frames, "views height width channels -> height (views width) channels")
-                    Image.fromarray(preview_image).save(os.path.join(item_output_dir, "turntable_preview.png"))
-            except (AttributeError, KeyError):
-                pass
-
-            # Create turntable with input overlay
-            border_width = 2
-            target_width = render_resolution
-            target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
-
-            resized_input = cv2.resize(
-                input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
-            )
-            bordered_input = np.pad(
-                resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
-                mode="constant", constant_values=200
-            )
-
-            input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
-            combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
-
-            imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
-
-    @torch.no_grad()
-    def save_evaluations(self, out_dir: str, result: edict, batch: edict, dataset) -> None:
-        """Backward compatibility wrapper for save_evaluation_results."""
-        self.save_evaluation_results(out_dir, result, batch, dataset)
-
-    @torch.no_grad()
-    def save_validation_results(
-        self,
-        output_directory: str,
-        model_results: edict,
-        batch_data: edict,
-        dataset,
-        save_visualizations: bool = False
-    ) -> Dict[str, float]:
-        """Save validation results and compute aggregated metrics."""
-        from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
-
-        os.makedirs(output_directory, exist_ok=True)
-        input_data, target_data = model_results.input, model_results.target
-        validation_metrics = {"psnr": [], "lpips": [], "ssim": []}
-
-        for batch_idx in range(input_data.image.size(0)):
-            item_uid = input_data.index[batch_idx, 0, -1].item()
-            should_save_visuals = (batch_idx == 0) and save_visualizations
-
-            # Compute metrics (RGB only)
-            target_image = target_data.image[batch_idx][:, :3, ...]
-            per_view_psnr = compute_psnr(target_image, model_results.render[batch_idx])
-            per_view_lpips = compute_lpips(target_image, model_results.render[batch_idx])
-            per_view_ssim = compute_ssim(target_image, model_results.render[batch_idx])
-
-            avg_psnr = per_view_psnr.mean().item()
-            avg_lpips = per_view_lpips.mean().item()
-            avg_ssim = per_view_ssim.mean().item()
-
-            validation_metrics["psnr"].append(avg_psnr)
-            validation_metrics["lpips"].append(avg_lpips)
-            validation_metrics["ssim"].append(avg_ssim)
-
-            # Save visualizations only for first item if requested
-            if should_save_visuals:
-                item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
-                os.makedirs(item_output_dir, exist_ok=True)
-
-                # Save input image
-                input_image = rearrange(
-                    input_data.image[batch_idx][:, :3, ...], "views channels height width -> height (views width) channels"
-                )
-                input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-                Image.fromarray(input_image).save(os.path.join(item_output_dir, "input.png"))
-
-                # Save ground truth vs prediction comparison
-                comparison_image = torch.stack((target_image, model_results.render[batch_idx]), dim=0)
-                num_views = comparison_image.size(1)
-                if num_views > 10:
-                    comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
-                comparison_image = rearrange(
-                    comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
-                )
-                comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
-                Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
-
-                # Save per-view metrics
-                view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
-                with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
-                    for i in range(per_view_psnr.size(0)):
-                        f.write(
-                            f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
-                            f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
-                        )
-
-                # Save averaged metrics
-                with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
-                    f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
-
-                print(f"Validation UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
-
-                # Save Gaussian model
-                crop_box = None
-                if self.config.model.get("clip_xyz", False):
-                    if self.config.model.get("half_bbx_size", None) is not None:
-                        half_size = self.config.model.half_bbx_size
-                        crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
-                    else:
-                        crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
-
-                model_results.gaussians[batch_idx].apply_all_filters(
-                    opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
-                ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
-
-                # Create turntable visualization
-                num_turntable_views = 150
-                render_resolution = input_image.shape[0]
-
-                turntable_frames = render_turntable(
-                    model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
-                )
-                turntable_frames = rearrange(
-                    turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
-                )
-                turntable_frames = np.ascontiguousarray(turntable_frames)
-
-                imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
-
-                # Create turntable with input overlay
-                border_width = 2
-                target_width = render_resolution
-                target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
-
-                resized_input = cv2.resize(
-                    input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
-                )
-                bordered_input = np.pad(
-                    resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
-                    mode="constant", constant_values=200
-                )
-
-                input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
-                combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
-
-                imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
-
-        # Return averaged metrics
-        return {
-            "psnr": torch.tensor(validation_metrics["psnr"]).mean().item(),
-            "lpips": torch.tensor(validation_metrics["lpips"]).mean().item(),
-            "ssim": torch.tensor(validation_metrics["ssim"]).mean().item(),
-        }
-
-    @torch.no_grad()
-    def save_validations(
-        self,
-        out_dir: str,
-        result: edict,
-        batch: edict,
-        dataset,
-        save_img: bool = False
-    ) -> Dict[str, float]:
-        """Backward compatibility wrapper for save_validation_results."""
-        return self.save_validation_results(out_dir, result, batch, dataset, save_img)
+        )
gslrm/model/utils_losses.py
DELETED
@@ -1,309 +0,0 @@
# Copyright (C) 2025, FaceLift Research Group
# https://github.com/weijielyu/FaceLift
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact: wlyu3@ucmerced.edu

"""
Perceptual Loss Implementation using VGG19 and SSIM Loss Implementation.

Adapted from https://github.com/zhengqili/Crowdsampling-the-Plenoptic-Function/blob/f5216f312cf82d77f8d20454b5eeb3930324630a/models/networks.py#L1478
"""
import os
from typing import List, Optional, Tuple, Union

import scipy.io
import torch
import torch.nn as nn
from pytorch_msssim import SSIM

# VGG19 ImageNet normalization constants (per-channel mean on the [0, 255] scale)
IMAGENET_MEAN = [123.6800, 116.7790, 103.9390]

# VGG19 layer configuration
VGG19_LAYER_INDICES = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
VGG19_LAYER_NAMES = [
    "conv1", "conv2", "conv3", "conv4", "conv5", "conv6", "conv7", "conv8",
    "conv9", "conv10", "conv11", "conv12", "conv13", "conv14", "conv15", "conv16",
]
VGG19_CHANNEL_SIZES = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512]

# Perceptual loss weighting factors (raw image term first, then one per feature scale)
LAYER_WEIGHTS = [1.0, 1 / 2.6, 1 / 4.8, 1 / 3.7, 1 / 5.6, 10 / 1.5]


class VGG19(nn.Module):
    """
    VGG19 network implementation for perceptual loss computation.

    This class implements the VGG19 architecture with specific layer outputs
    used for computing perceptual losses at different scales.
    """

    def __init__(self) -> None:
        """Initialize VGG19 network layers."""
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)

        # NOTE: the "max*" layers intentionally use average pooling,
        # following the reference implementation this file was adapted from.
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
        self.relu3 = nn.ReLU(inplace=True)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
        self.relu4 = nn.ReLU(inplace=True)
        self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True)
        self.relu5 = nn.ReLU(inplace=True)

        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu6 = nn.ReLU(inplace=True)

        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu7 = nn.ReLU(inplace=True)

        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.relu8 = nn.ReLU(inplace=True)
        self.max3 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True)
        self.relu9 = nn.ReLU(inplace=True)

        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu10 = nn.ReLU(inplace=True)

        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu11 = nn.ReLU(inplace=True)

        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu12 = nn.ReLU(inplace=True)
        self.max4 = nn.AvgPool2d(kernel_size=2, stride=2)

        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu13 = nn.ReLU(inplace=True)

        self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu14 = nn.ReLU(inplace=True)

        self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu15 = nn.ReLU(inplace=True)

        self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.relu16 = nn.ReLU(inplace=True)
        self.max5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor, return_style: int) -> Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """
        Forward pass through VGG19 network.

        Args:
            x: Input tensor of shape [B, 3, H, W]
            return_style: If > 0, return style features as a list; otherwise return content features as a tuple

        Returns:
            Either a list of style features or a tuple of content features from different layers
        """
        out1 = self.conv1(x)
        out2 = self.relu1(out1)

        out3 = self.conv2(out2)
        out4 = self.relu2(out3)
        out5 = self.max1(out4)

        out6 = self.conv3(out5)
        out7 = self.relu3(out6)
        out8 = self.conv4(out7)
        out9 = self.relu4(out8)
        out10 = self.max2(out9)
        out11 = self.conv5(out10)
        out12 = self.relu5(out11)
        out13 = self.conv6(out12)
        out14 = self.relu6(out13)
        out15 = self.conv7(out14)
        out16 = self.relu7(out15)
        out17 = self.conv8(out16)
        out18 = self.relu8(out17)
        out19 = self.max3(out18)
        out20 = self.conv9(out19)
        out21 = self.relu9(out20)
        out22 = self.conv10(out21)
        out23 = self.relu10(out22)
        out24 = self.conv11(out23)
        out25 = self.relu11(out24)
        out26 = self.conv12(out25)
        out27 = self.relu12(out26)
        out28 = self.max4(out27)
        out29 = self.conv13(out28)
        out30 = self.relu13(out29)
        out31 = self.conv14(out30)
        out32 = self.relu14(out31)

        if return_style > 0:
            return [out2, out7, out12, out21, out30]
        else:
            return out4, out9, out14, out23, out32


class PerceptualLoss(nn.Module):
    """
    Perceptual Loss module using pre-trained VGG19.

    This class implements perceptual loss by comparing features extracted from
    different layers of a pre-trained VGG19 network. It computes weighted
    differences across multiple scales to capture both low-level and high-level
    visual differences between images.
    """

    def __init__(self, device: str = "cpu", weight_file: Optional[str] = None) -> None:
        """
        Initialize PerceptualLoss module.

        Args:
            device: Device to run computations on ('cpu' or 'cuda')
            weight_file: Path to the VGG19 weight file. If None, falls back to the
                VGG19_WEIGHTS_PATH environment variable, then to a default path.

        Raises:
            FileNotFoundError: If the weight file is not found
            RuntimeError: If the weight file cannot be loaded
        """
        super().__init__()
        self.device = device
        self.net = VGG19()

        # Determine weight file path
        if weight_file is None:
            # Check environment variable first
            weight_file = os.environ.get("VGG19_WEIGHTS_PATH")
            if weight_file is None:
                # Fall back to the default path
                weight_file = "/sensei-fs/users/kaiz/repos/weight-collections/imagenet-vgg-verydeep-19.mat"

        # Load VGG19 weights
        if not os.path.isfile(weight_file):
            raise FileNotFoundError(
                f"VGG19 weight file not found: {weight_file}\n"
                f"Download it from: https://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat\n"
                f"Expected MD5: 106118b7cf60435e6d8e04f6a6dc3657"
            )

        try:
            vgg_rawnet = scipy.io.loadmat(weight_file)
            vgg_layers = vgg_rawnet["layers"][0]
        except Exception as e:
            raise RuntimeError(f"Failed to load VGG19 weights from {weight_file}: {e}")

        # Load pre-trained weights into the network
        self._load_pretrained_weights(vgg_layers)

        # Set network to evaluation mode and freeze parameters
        self.net = self.net.eval().to(device)
        for param in self.net.parameters():
            param.requires_grad = False

    def _load_pretrained_weights(self, vgg_layers) -> None:
        """Load pre-trained VGG19 weights into the network."""
        for layer_idx in range(len(VGG19_LAYER_NAMES)):
            layer_name = VGG19_LAYER_NAMES[layer_idx]
            mat_layer_idx = VGG19_LAYER_INDICES[layer_idx]
            channel_size = VGG19_CHANNEL_SIZES[layer_idx]

            # Extract weights and biases from the MATLAB (MatConvNet) format
            layer_weights = torch.from_numpy(
                vgg_layers[mat_layer_idx][0][0][2][0][0]
            ).permute(3, 2, 0, 1)
            layer_biases = torch.from_numpy(
                vgg_layers[mat_layer_idx][0][0][2][0][1]
            ).view(channel_size)

            # Assign to network
            getattr(self.net, layer_name).weight = nn.Parameter(layer_weights)
            getattr(self.net, layer_name).bias = nn.Parameter(layer_biases)

    def _compute_l1_error(self, truth: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Compute L1 (Mean Absolute Error) between two tensors.

        Args:
            truth: Ground truth tensor
            pred: Predicted tensor

        Returns:
            L1 error as a scalar tensor
        """
        return torch.mean(torch.abs(truth - pred))

    def forward(self, pred_img: torch.Tensor, real_img: torch.Tensor) -> torch.Tensor:
        """
        Compute perceptual loss between predicted and real images.

        Args:
            pred_img: Predicted image tensor of shape [B, 3, H, W] in range [0, 1]
            real_img: Real image tensor of shape [B, 3, H, W] in range [0, 1]

        Returns:
            Perceptual loss as a scalar tensor
        """
        # Normalize inputs the way the MatConvNet VGG19 expects: scale to
        # [0, 255] and subtract the per-channel ImageNet mean. (Channels are
        # kept in RGB order; no BGR reordering is applied here.)
        imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32, device=pred_img.device)
        imagenet_mean = imagenet_mean.view(1, 3, 1, 1)

        real_img_normalized = real_img * 255.0 - imagenet_mean
        pred_img_normalized = pred_img * 255.0 - imagenet_mean

        # Extract features from both images
        real_features = self.net(real_img_normalized, return_style=0)
        pred_features = self.net(pred_img_normalized, return_style=0)

        # Compute weighted L1 losses at different scales
        losses = []

        # Raw image loss
        raw_loss = self._compute_l1_error(real_img_normalized, pred_img_normalized)
        losses.append(raw_loss * LAYER_WEIGHTS[0])

        # Feature losses at different VGG layers
        for i, (real_feat, pred_feat) in enumerate(zip(real_features, pred_features)):
            feature_loss = self._compute_l1_error(real_feat, pred_feat)
            losses.append(feature_loss * LAYER_WEIGHTS[i + 1])

        # Combine all losses and normalize
        total_loss = sum(losses) / 255.0
        return total_loss


class SsimLoss(nn.Module):
    """
    SSIM Loss module that computes 1 - SSIM for image similarity.

    Args:
        data_range: Range of input data (default: 1.0 for [0, 1] range)
    """

    def __init__(self, data_range: float = 1.0) -> None:
        super().__init__()
        self.data_range = data_range
        self.ssim_module = SSIM(
            win_size=11,
            win_sigma=1.5,
            data_range=self.data_range,
            size_average=True,
            channel=3,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Compute SSIM loss between two image tensors.

        Args:
            x: Image tensor of shape (N, C, H, W)
            y: Image tensor of shape (N, C, H, W)

        Returns:
            SSIM loss (1 - SSIM similarity)
        """
        return 1.0 - self.ssim_module(x, y)
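Both removed modules are self-contained, so they can still be exercised on their own. The sketch below is a hypothetical usage example, not part of the commit: the input tensors and weight path are made up, and the import path reflects the pre-deletion file location. It shows the intended call pattern: RGB images in [0, 1] with NCHW layout go in, scalar losses come out. SsimLoss only needs pytorch_msssim; PerceptualLoss additionally needs the MatConvNet .mat weights referenced in the error message above.

import torch

from gslrm.model.utils_losses import PerceptualLoss, SsimLoss  # pre-deletion import path

# Hypothetical inputs: batches of RGB images in [0, 1], shape [B, 3, H, W].
pred = torch.rand(2, 3, 256, 256)
target = torch.rand(2, 3, 256, 256)

ssim = SsimLoss()
print(float(ssim(pred, target)))  # 1 - SSIM; close to 1.0 for uncorrelated noise

# Hypothetical local path; alternatively set the VGG19_WEIGHTS_PATH env var.
perceptual = PerceptualLoss(weight_file="imagenet-vgg-verydeep-19.mat")
print(float(perceptual(pred, target)))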
splat_viewer.html
DELETED
@@ -1,277 +0,0 @@
<!DOCTYPE html>
<html lang="en" dir="ltr">
  <head>
    <title>WebGL Gaussian Splat Viewer</title>
    <meta charset="utf-8" />
    <meta
      name="viewport"
      content="width=device-width, initial-scale=1, minimum-scale=1, maximum-scale=1, user-scalable=no"
    />
    <meta name="apple-mobile-web-app-capable" content="yes" />
    <meta
      name="apple-mobile-web-app-status-bar-style"
      content="black-translucent"
    />
    <style>
      body {
        overflow: hidden;
        margin: 0;
        height: 100vh;
        width: 100vw;
        font-family: sans-serif;
        background: black;
        text-shadow: 0 0 3px black;
      }
      a, body {
        color: white;
      }
      #info {
        z-index: 100;
        position: absolute;
        top: 10px;
        left: 15px;
      }
      h3 {
        margin: 5px 0;
      }
      p {
        margin: 5px 0;
        font-size: small;
      }

      .cube-wrapper {
        transform-style: preserve-3d;
      }

      .cube {
        transform-style: preserve-3d;
        transform: rotateX(45deg) rotateZ(45deg);
        animation: rotation 2s infinite;
      }

      .cube-faces {
        transform-style: preserve-3d;
        height: 80px;
        width: 80px;
        position: relative;
        transform-origin: 0 0;
        transform: translateX(0) translateY(0) translateZ(-40px);
      }

      .cube-face {
        position: absolute;
        inset: 0;
        background: #0017ff;
        border: solid 1px #ffffff;
      }
      .cube-face.top {
        transform: translateZ(80px);
      }
      .cube-face.front {
        transform-origin: 0 50%;
        transform: rotateY(-90deg);
      }
      .cube-face.back {
        transform-origin: 0 50%;
        transform: rotateY(-90deg) translateZ(-80px);
      }
      .cube-face.right {
        transform-origin: 50% 0;
        transform: rotateX(-90deg) translateY(-80px);
      }
      .cube-face.left {
        transform-origin: 50% 0;
        transform: rotateX(-90deg) translateY(-80px) translateZ(80px);
      }

      @keyframes rotation {
        0% {
          transform: rotateX(45deg) rotateY(0) rotateZ(45deg);
          animation-timing-function: cubic-bezier(0.17, 0.84, 0.44, 1);
        }
        50% {
          transform: rotateX(45deg) rotateY(0) rotateZ(225deg);
          animation-timing-function: cubic-bezier(0.76, 0.05, 0.86, 0.06);
        }
        100% {
          transform: rotateX(45deg) rotateY(0) rotateZ(405deg);
          animation-timing-function: cubic-bezier(0.17, 0.84, 0.44, 1);
        }
      }

      .scene,
      #message {
        position: absolute;
        display: flex;
        top: 0;
        right: 0;
        left: 0;
        bottom: 0;
        z-index: 2;
        height: 100%;
        width: 100%;
        align-items: center;
        justify-content: center;
      }
      #message {
        font-weight: bold;
        font-size: large;
        color: red;
        pointer-events: none;
      }

      details {
        font-size: small;
      }

      #progress {
        position: absolute;
        top: 0;
        height: 5px;
        background: blue;
        z-index: 99;
        transition: width 0.1s ease-in-out;
      }

      #quality {
        position: absolute;
        bottom: 10px;
        z-index: 999;
        right: 10px;
      }

      #caminfo {
        position: absolute;
        top: 10px;
        z-index: 999;
        right: 10px;
      }
      #canvas {
        display: block;
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        touch-action: none;
      }

      #instructions {
        background: rgba(0, 0, 0, 0.6);
        white-space: pre-wrap;
        padding: 10px;
        border-radius: 10px;
        font-size: x-small;
      }
      body.nohf .nohf {
        display: none;
      }
      body.nohf #progress, body.nohf .cube-face {
        background: #ff9d0d;
      }
    </style>
  </head>
  <body>
    <script>
      if (location.host.includes('hf.space')) document.body.classList.add('nohf');
    </script>
    <div id="info">
      <h3 class="nohf">WebGL 3D Gaussian Splat Viewer</h3>
      <p>
        <small class="nohf">
          By <a href="https://twitter.com/antimatter15">Kevin Kwok</a>.
          Code on
          <a href="https://github.com/antimatter15/splat">Github</a>.
        </small>
      </p>

      <details>
        <summary>Use mouse or arrow keys to navigate.</summary>

        <div id="instructions">movement (arrow keys)
- left/right arrow keys to strafe side to side
- up/down arrow keys to move forward/back
- space to jump

camera angle (wasd)
- a/d to turn camera left/right
- w/s to tilt camera up/down
- q/e to roll camera counterclockwise/clockwise
- i/k and j/l to orbit

trackpad
- scroll up/down/left/right to orbit
- pinch to move forward/back
- ctrl key + scroll to move forward/back
- shift + scroll to move up/down or strafe

mouse
- click and drag to orbit
- right click (or ctrl/cmd key) and drag up/down to move

touch (mobile)
- one finger to orbit
- two finger pinch to move forward/back
- two finger rotate to rotate camera clockwise/counterclockwise
- two finger pan to move side-to-side and up-down

gamepad
- if you have a game controller connected it should work

other
- press 0-9 to switch to one of the pre-loaded camera views
- press '-' or '+' key to cycle loaded cameras
- press p to resume default animation
- drag and drop .ply file to convert to .splat
- drag and drop cameras.json to load cameras
        </div>

      </details>

    </div>

    <div id="progress"></div>

    <div id="message"></div>
    <div class="scene" id="spinner">
      <div class="cube-wrapper">
        <div class="cube">
          <div class="cube-faces">
            <div class="cube-face bottom"></div>
            <div class="cube-face top"></div>
            <div class="cube-face left"></div>
            <div class="cube-face right"></div>
            <div class="cube-face back"></div>
            <div class="cube-face front"></div>
          </div>
        </div>
      </div>
    </div>
    <canvas id="canvas"></canvas>

    <div id="quality">
      <span id="fps"></span>
    </div>
    <div id="caminfo">
      <span id="camid"></span>
    </div>
    <script src="main.js"></script>
  </body>
</html>
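The deleted page is Kevin Kwok's antimatter15/splat viewer. It is not self-contained: it loads a sibling main.js, and browsers typically restrict fetch() on file:// pages, so a local copy is best served over HTTP rather than opened directly. The following is a minimal sketch (not part of the commit) using only the Python standard library; the directory name is hypothetical and assumes it holds splat_viewer.html plus main.js.

import functools
import http.server

# Hypothetical folder holding splat_viewer.html, main.js, and any .splat scenes.
VIEWER_DIR = "viewer"

Handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=VIEWER_DIR)
with http.server.ThreadingHTTPServer(("127.0.0.1", 8000), Handler) as httpd:
    print("Serving the splat viewer at http://127.0.0.1:8000/splat_viewer.html")
    httpd.serve_forever()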