wlyu-adobe committed
Commit 8f04a1a · 1 Parent(s): 133857a

Initial commit

Files changed (3)
  1. gslrm/model/gslrm.py +3 -917
  2. gslrm/model/utils_losses.py +0 -309
  3. splat_viewer.html +0 -277
gslrm/model/gslrm.py CHANGED
@@ -22,11 +22,8 @@ Classes:
22
  """
23
 
24
  import copy
25
- import os
26
- import time
27
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import List, Optional, Tuple
28
 
29
- import cv2
30
  import lpips
31
  import numpy as np
32
  import torch
@@ -35,17 +32,13 @@ import torch.nn.functional as F
35
  from easydict import EasyDict as edict
36
  from einops import rearrange
37
  from einops.layers.torch import Rearrange
38
- from PIL import Image
39
 
40
  # Local imports
41
  from .utils_losses import PerceptualLoss, SsimLoss
42
  from .gaussians_renderer import (
43
  GaussianModel,
44
- RGB2SH,
45
  deferred_gaussian_render,
46
- imageseq2video,
47
  render_opencv_cam,
48
- render_turntable,
49
  )
50
  from .transform_data import SplitData, TransformInput, TransformTarget
51
  from .utils_transformer import (
@@ -225,238 +218,6 @@ class GaussiansUpsampler(nn.Module):
225
 
226
  return xyz, features, scaling, rotation, opacity
227
 
228
-
229
- class LossComputer(nn.Module):
230
- """
231
- Computes various loss functions for training the GSLRM model.
232
-
233
- Supports multiple loss types:
234
- - L2 (MSE) loss
235
- - LPIPS perceptual loss
236
- - Custom perceptual loss
237
- - SSIM loss
238
- - Pixel alignment loss
239
- - Point distance regularization loss
240
- """
241
-
242
- def __init__(self, config: edict):
243
- super().__init__()
244
- self.config = config
245
-
246
- # Initialize loss modules based on config
247
- self._init_loss_modules()
248
-
249
- def _init_loss_modules(self):
250
- """Initialize the various loss computation modules."""
251
- # LPIPS loss
252
- if self.config.training.losses.lpips_loss_weight > 0.0:
253
- self.lpips_loss_module = lpips.LPIPS(net="vgg")
254
- self.lpips_loss_module.eval()
255
- # Freeze LPIPS parameters
256
- for param in self.lpips_loss_module.parameters():
257
- param.requires_grad = False
258
-
259
- # Perceptual loss
260
- if self.config.training.losses.perceptual_loss_weight > 0.0:
261
- self.perceptual_loss_module = PerceptualLoss()
262
- self.perceptual_loss_module.eval()
263
- # Freeze perceptual loss parameters
264
- for param in self.perceptual_loss_module.parameters():
265
- param.requires_grad = False
266
-
267
- # SSIM loss
268
- if self.config.training.losses.ssim_loss_weight > 0.0:
269
- self.ssim_loss_module = SsimLoss()
270
- self.ssim_loss_module.eval()
271
- # Freeze SSIM parameters
272
- for param in self.ssim_loss_module.parameters():
273
- param.requires_grad = False
274
-
275
- def forward(
276
- self,
277
- rendering: torch.Tensor, # [b, v, 3, h, w]
278
- target: torch.Tensor, # [b, v, 3, h, w]
279
- img_aligned_xyz: torch.Tensor, # [b, v, 3, h, w]
280
- input: edict,
281
- result_softpa: Optional[edict] = None,
282
- create_visual: bool = False,
283
- ) -> edict:
284
- """
285
- Compute all losses between rendered and target images.
286
-
287
- Args:
288
- rendering: Rendered images in range [0, 1]
289
- target: Target images in range [0, 1]
290
- img_aligned_xyz: Image-aligned 3D positions
291
- input: Input data containing ray information
292
- result_softpa: Additional results (unused)
293
- create_visual: Whether to create visualization images
294
-
295
- Returns:
296
- Dictionary containing all loss values and metrics
297
- """
298
- b, v, _, h, w = rendering.size()
299
- rendering_flat = rendering.reshape(b * v, -1, h, w)
300
- target_flat = target.reshape(b * v, -1, h, w)
301
-
302
- # Handle alpha channel if present
303
- mask = None
304
- if target_flat.size(1) == 4:
305
- target_flat, mask = target_flat.split([3, 1], dim=1)
306
-
307
- # Compute individual losses
308
- losses = self._compute_all_losses(
309
- rendering_flat, target_flat, img_aligned_xyz, input, mask, b, v, h, w
310
- )
311
-
312
- # Compute total weighted loss
313
- total_loss = self._compute_total_loss(losses)
314
-
315
- # Create visualization if requested
316
- visual = self._create_visual(rendering_flat, target_flat, v) if create_visual else None
317
-
318
- # Compile loss metrics
319
- return self._compile_loss_metrics(losses, total_loss, visual)
320
-
321
- def _compute_all_losses(self, rendering, target, img_aligned_xyz, input, mask, b, v, h, w):
322
- """Compute all individual loss components."""
323
- losses = {}
324
-
325
- # L2 (MSE) loss
326
- losses['l2'] = self._compute_l2_loss(rendering, target)
327
- losses['psnr'] = -10.0 * torch.log10(losses['l2'])
328
-
329
- # LPIPS loss
330
- losses['lpips'] = self._compute_lpips_loss(rendering, target)
331
-
332
- # Perceptual loss
333
- losses['perceptual'] = self._compute_perceptual_loss(rendering, target)
334
-
335
- # SSIM loss
336
- losses['ssim'] = self._compute_ssim_loss(rendering, target)
337
-
338
- # Pixel alignment loss
339
- losses['pixelalign'] = self._compute_pixelalign_loss(
340
- img_aligned_xyz, input, mask, b, v, h, w
341
- )
342
-
343
- # Point distance loss
344
- losses['pointsdist'] = self._compute_pointsdist_loss(
345
- img_aligned_xyz, input, b, v, h, w
346
- )
347
-
348
- return losses
349
-
350
- def _compute_l2_loss(self, rendering, target):
351
- """Compute L2 (MSE) loss."""
352
- if self.config.training.losses.l2_loss_weight > 0.0:
353
- return F.mse_loss(rendering, target)
354
- return torch.tensor(1e-8, device=rendering.device)
355
-
356
- def _compute_lpips_loss(self, rendering, target):
357
- """Compute LPIPS perceptual loss."""
358
- if self.config.training.losses.lpips_loss_weight > 0.0:
359
- # LPIPS expects inputs in range [-1, 1]
360
- return self.lpips_loss_module(
361
- rendering * 2.0 - 1.0, target * 2.0 - 1.0
362
- ).mean()
363
- return torch.tensor(0.0, device=rendering.device)
364
-
365
- def _compute_perceptual_loss(self, rendering, target):
366
- """Compute custom perceptual loss."""
367
- if self.config.training.losses.perceptual_loss_weight > 0.0:
368
- return self.perceptual_loss_module(rendering, target)
369
- return torch.tensor(0.0, device=rendering.device)
370
-
371
- def _compute_ssim_loss(self, rendering, target):
372
- """Compute SSIM loss."""
373
- if self.config.training.losses.ssim_loss_weight > 0.0:
374
- return self.ssim_loss_module(rendering, target)
375
- return torch.tensor(0.0, device=rendering.device)
376
-
377
- def _compute_pixelalign_loss(self, img_aligned_xyz, input, mask, b, v, h, w):
378
- """Compute pixel alignment loss."""
379
- if self.config.training.losses.pixelalign_loss_weight > 0.0:
380
- # Compute orthogonal component to ray direction
381
- xyz_vec = img_aligned_xyz - input.ray_o
382
- ortho_vec = (
383
- xyz_vec
384
- - torch.sum(xyz_vec.detach() * input.ray_d, dim=2, keepdim=True)
385
- * input.ray_d
386
- )
387
-
388
- # Apply mask if enabled
389
- if self.config.training.losses.get("masked_pixelalign_loss", False):
390
- assert mask is not None, "mask is None but masked_pixelalign_loss is enabled"
391
- mask_reshaped = mask.view(b, v, 1, h, w)
392
- ortho_vec = ortho_vec * mask_reshaped
393
-
394
- return torch.mean(ortho_vec.norm(dim=2, p=2))
395
-
396
- return torch.tensor(0.0, device=img_aligned_xyz.device)
397
-
398
- def _compute_pointsdist_loss(self, img_aligned_xyz, input, b, v, h, w):
399
- """Compute point distance regularization loss."""
400
- if self.config.training.losses.pointsdist_loss_weight > 0.0:
401
- # Target mean distance (distance from origin to ray origin)
402
- target_mean_dist = torch.norm(input.ray_o, dim=2, p=2, keepdim=True)
403
- target_std_dist = 0.5
404
-
405
- # Predicted distance
406
- pred_dist = (img_aligned_xyz - input.ray_o).norm(dim=2, p=2, keepdim=True)
407
-
408
- # Normalize to target distribution
409
- pred_dist_detach = pred_dist.detach()
410
- pred_mean = pred_dist_detach.mean(dim=(2, 3, 4), keepdim=True)
411
- pred_std = pred_dist_detach.std(dim=(2, 3, 4), keepdim=True)
412
-
413
- target_dist = (pred_dist_detach - pred_mean) / (pred_std + 1e-8) * target_std_dist + target_mean_dist
414
-
415
- return torch.mean((pred_dist - target_dist) ** 2)
416
-
417
- return torch.tensor(0.0, device=img_aligned_xyz.device)
418
-
419
- def _compute_total_loss(self, losses):
420
- """Compute weighted sum of all losses."""
421
- weights = self.config.training.losses
422
- return (
423
- weights.l2_loss_weight * losses['l2']
424
- + weights.lpips_loss_weight * losses['lpips']
425
- + weights.perceptual_loss_weight * losses['perceptual']
426
- + weights.ssim_loss_weight * losses['ssim']
427
- + weights.pixelalign_loss_weight * losses['pixelalign']
428
- + weights.pointsdist_loss_weight * losses['pointsdist']
429
- )
430
-
431
- def _create_visual(self, rendering, target, v):
432
- """Create visualization by concatenating target and rendering."""
433
- visual = torch.cat((target, rendering), dim=3).detach().cpu() # [b*v, c, h, w*2]
434
- visual = rearrange(visual, "(b v) c h (m w) -> (b h) (v m w) c", v=v, m=2)
435
- return (visual.numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
436
-
437
- def _compile_loss_metrics(self, losses, total_loss, visual):
438
- """Compile all loss metrics into a dictionary."""
439
- l2_loss = losses['l2']
440
-
441
- return edict(
442
- loss=total_loss,
443
- l2_loss=l2_loss,
444
- psnr=losses['psnr'],
445
- lpips_loss=losses['lpips'],
446
- perceptual_loss=losses['perceptual'],
447
- ssim_loss=losses['ssim'],
448
- pixelalign_loss=losses['pixelalign'],
449
- pointsdist_loss=losses['pointsdist'],
450
- visual=visual,
451
- # Normalized losses for logging
452
- norm_perceptual_loss=losses['perceptual'] / l2_loss,
453
- norm_lpips_loss=losses['lpips'] / l2_loss,
454
- norm_ssim_loss=losses['ssim'] / l2_loss,
455
- norm_pixelalign_loss=losses['pixelalign'] / l2_loss,
456
- norm_pointsdist_loss=losses['pointsdist'] / l2_loss,
457
- )
458
-
459
-
460
  class GSLRM(nn.Module):
461
  """
462
  Gaussian Splatting Large Reconstruction Model.
@@ -575,7 +336,6 @@ class GSLRM(nn.Module):
575
  def _init_rendering_modules(self, config: edict) -> None:
576
  """Initialize rendering and loss computation modules."""
577
  self.gaussian_renderer = Renderer(config)
578
- self.loss_calculator = LossComputer(config)
579
 
580
  def _init_training_state(self, config: edict) -> None:
581
  """Initialize training state management variables."""
@@ -584,101 +344,6 @@ class GSLRM(nn.Module):
584
  self.training_max_step = None
585
  self.original_config = copy.deepcopy(config)
586
 
587
- def set_training_step(self, current_step: int, start_step: int, max_step: int) -> None:
588
- """
589
- Update training step and dynamically adjust configuration based on training phase.
590
-
591
- Args:
592
- current_step: Current training step
593
- start_step: Starting step of training
594
- max_step: Maximum training steps
595
- """
596
- self.training_step = current_step
597
- self.training_start_step = start_step
598
- self.training_max_step = max_step
599
-
600
- # Determine if config modification is needed based on warmup settings
601
- needs_config_modification = self._should_modify_config_for_warmup(current_step)
602
-
603
- if needs_config_modification:
604
- # Always use original config as base for modifications
605
- self.config = copy.deepcopy(self.original_config)
606
- self._apply_warmup_modifications(current_step)
607
- else:
608
- # Restore original configuration
609
- self.config = self.original_config
610
-
611
- # Update loss calculator with current config
612
- self.loss_calculator.config = self.config
613
-
614
- def _should_modify_config_for_warmup(self, current_step: int) -> bool:
615
- """Check if configuration should be modified for warmup phases."""
616
- pointsdist_warmup = (
617
- self.config.training.losses.get("warmup_pointsdist", False)
618
- and current_step < 1000
619
- )
620
- l2_warmup = (
621
- self.config.training.schedule.get("l2_warmup_steps", 0) > 0
622
- and current_step < self.config.training.schedule.l2_warmup_steps
623
- )
624
- return pointsdist_warmup or l2_warmup
625
-
626
- def _apply_warmup_modifications(self, current_step: int) -> None:
627
- """Apply configuration modifications for warmup phases."""
628
- # Point distance warmup phase
629
- if (self.config.training.losses.get("warmup_pointsdist", False)
630
- and current_step < 1000):
631
- self.config.training.losses.l2_loss_weight = 0.0
632
- self.config.training.losses.perceptual_loss_weight = 0.0
633
- self.config.training.losses.pointsdist_loss_weight = 0.1
634
- self.config.model.clip_xyz = False # Disable xyz clipping during warmup
635
-
636
- # L2 loss warmup phase
637
- if (self.config.training.schedule.get("l2_warmup_steps", 0) > 0
638
- and current_step < self.config.training.schedule.l2_warmup_steps):
639
- self.config.training.losses.perceptual_loss_weight = 0.0
640
- self.config.training.losses.lpips_loss_weight = 0.0
641
-
642
- def set_current_step(self, current_step: int, start_step: int, max_step: int) -> None:
643
- """Backward compatibility wrapper for set_training_step."""
644
- self.set_training_step(current_step, start_step, max_step)
645
-
646
- def train(self, mode: bool = True) -> None:
647
- """
648
- Override train method to keep frozen modules in eval mode.
649
-
650
- Args:
651
- mode: Whether to set training mode (True) or evaluation mode (False)
652
- """
653
- super().train(mode)
654
- # Keep loss calculator in eval mode to prevent training of frozen components
655
- if self.loss_calculator is not None:
656
- self.loss_calculator.eval()
657
-
658
- def get_parameter_overview(self) -> edict:
659
- """
660
- Get overview of trainable parameters in each module.
661
-
662
- Returns:
663
- Dictionary containing parameter counts for each major component
664
- """
665
- def count_trainable_params(module: nn.Module) -> int:
666
- return sum(p.numel() for p in module.parameters() if p.requires_grad)
667
-
668
- return edict(
669
- patch_embedder=count_trainable_params(self.patch_embedder),
670
- gaussian_position_embeddings=self.gaussian_position_embeddings.data.numel(),
671
- transformer_total=(
672
- count_trainable_params(self.transformer_layers) +
673
- count_trainable_params(self.input_layer_norm)
674
- ),
675
- gaussian_upsampler=count_trainable_params(self.gaussian_upsampler),
676
- pixel_gaussian_decoder=count_trainable_params(self.pixel_gaussian_decoder),
677
- )
678
-
679
- def get_overview(self) -> edict:
680
- """Backward compatibility wrapper for get_parameter_overview."""
681
- return self.get_parameter_overview()
682
 
683
  def _create_transformer_layer_runner(self, start_layer: int, end_layer: int):
684
  """
@@ -843,149 +508,6 @@ class GSLRM(nn.Module):
843
 
844
  return aligned_positions
845
 
846
- @staticmethod
847
- def translate_legacy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
848
- """
849
- Translate legacy model parameter names to new parameter names.
850
-
851
- This function allows loading models saved with the old variable names
852
- by mapping them to the new, cleaner variable names.
853
-
854
- Args:
855
- state_dict: Dictionary containing model parameters with old names
856
-
857
- Returns:
858
- Dictionary with parameters mapped to new names
859
- """
860
- # Define the mapping from old names to new names
861
- name_mapping = {
862
- # Data processors
863
- 'split_data.': 'data_splitter.',
864
- 'transform_input.': 'input_transformer.',
865
- 'transform_target.': 'target_transformer.',
866
-
867
- # Tokenizer
868
- 'image_tokenizer.': 'patch_embedder.',
869
-
870
- # Positional embeddings
871
- 'refsrc_marker': 'view_type_embeddings',
872
- 'gaussians_pos_embedding': 'gaussian_position_embeddings',
873
-
874
- # Transformer
875
- 'transformer_input_layernorm.': 'input_layer_norm.',
876
- 'transformer.': 'transformer_layers.',
877
-
878
- # Gaussian modules
879
- 'upsampler.': 'gaussian_upsampler.',
880
- 'image_token_decoder.': 'pixel_gaussian_decoder.',
881
-
882
- # Rendering modules
883
- 'renderer.': 'gaussian_renderer.',
884
- 'loss_computer.': 'loss_calculator.',
885
- }
886
-
887
- # Create new state dict with translated names
888
- new_state_dict = {}
889
-
890
- for old_key, value in state_dict.items():
891
- new_key = old_key
892
-
893
- # Apply name mappings
894
- for old_pattern, new_pattern in name_mapping.items():
895
- if old_key.startswith(old_pattern):
896
- new_key = old_key.replace(old_pattern, new_pattern, 1)
897
- break
898
-
899
- # Fix specific key naming issues
900
- # Change loss_computer.perceptual_loss_module.Net to loss_computer.perceptual_loss_module.net
901
- if "loss_computer.perceptual_loss_module.Net" in new_key:
902
- old_net_key = new_key
903
- new_key = new_key.replace("loss_computer.perceptual_loss_module.Net", "loss_computer.perceptual_loss_module.net")
904
- print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
905
- # Also handle the new naming convention
906
- elif "loss_calculator.perceptual_loss_module.Net" in new_key:
907
- old_net_key = new_key
908
- new_key = new_key.replace("loss_calculator.perceptual_loss_module.Net", "loss_calculator.perceptual_loss_module.net")
909
- print(f"Renamed checkpoint key: {old_net_key} -> {new_key}")
910
-
911
- new_state_dict[new_key] = value
912
-
913
- return new_state_dict
914
-
915
- def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True):
916
- """
917
- Load model state dict with automatic legacy name translation.
918
-
919
- Args:
920
- state_dict: Model state dictionary (potentially with old parameter names)
921
- strict: Whether to strictly enforce parameter name matching
922
- """
923
- # Check if this is a legacy state dict by looking for old parameter names
924
- legacy_indicators = [
925
- 'image_tokenizer.',
926
- 'refsrc_marker',
927
- 'gaussians_pos_embedding',
928
- 'transformer_input_layernorm.',
929
- 'upsampler.',
930
- 'image_token_decoder.',
931
- 'renderer.',
932
- 'loss_computer.'
933
- ]
934
-
935
- is_legacy = any(
936
- any(key.startswith(indicator) for key in state_dict.keys())
937
- for indicator in legacy_indicators
938
- )
939
-
940
- if is_legacy:
941
- print("Detected legacy model format. Translating parameter names...")
942
- state_dict = self.translate_legacy_state_dict(state_dict)
943
- print("Parameter name translation completed.")
944
-
945
- # Load the (potentially translated) state dict
946
- return super().load_state_dict(state_dict, strict=strict)
947
-
948
- @classmethod
949
- def load_from_checkpoint(
950
- cls,
951
- checkpoint_path: str,
952
- config: edict,
953
- map_location: Optional[str] = None
954
- ) -> 'GSLRM':
955
- """
956
- Load model from checkpoint with automatic legacy name translation.
957
-
958
- Args:
959
- checkpoint_path: Path to the checkpoint file
960
- config: Model configuration
961
- map_location: Device to map tensors to (e.g., 'cpu', 'cuda:0')
962
-
963
- Returns:
964
- Loaded GSLRM model
965
- """
966
- # Create model instance
967
- model = cls(config)
968
-
969
- # Load checkpoint
970
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
971
-
972
- # Extract state dict (handle different checkpoint formats)
973
- if isinstance(checkpoint, dict):
974
- if 'model_state_dict' in checkpoint:
975
- state_dict = checkpoint['model_state_dict']
976
- elif 'state_dict' in checkpoint:
977
- state_dict = checkpoint['state_dict']
978
- else:
979
- state_dict = checkpoint
980
- else:
981
- state_dict = checkpoint
982
-
983
- # Load state dict with automatic translation
984
- model.load_state_dict(state_dict)
985
-
986
- print(f"Successfully loaded model from {checkpoint_path}")
987
- return model
988
-
989
  def _create_gaussian_models_and_stats(
990
  self,
991
  xyz: torch.Tensor,
@@ -1180,7 +702,6 @@ class GSLRM(nn.Module):
1180
  )
1181
 
1182
  # Perform rendering and loss computation if target data is available
1183
- loss_metrics = None
1184
  rendered_images = None
1185
 
1186
  if target_data is not None:
@@ -1193,17 +714,6 @@ class GSLRM(nn.Module):
1193
  C2W=target_data.c2w,
1194
  fxfycxcy=target_data.fxfycxcy,
1195
  )
1196
-
1197
- # Compute losses if rendered and target have matching dimensions
1198
- if rendered_images.shape[1] == target_data.image.shape[1]:
1199
- loss_metrics = self.loss_calculator(
1200
- rendered_images,
1201
- target_data.image,
1202
- pixel_aligned_xyz,
1203
- input_data,
1204
- create_visual=create_visual,
1205
- result_softpa=gaussian_splat_result,
1206
- )
1207
 
1208
  # Create Gaussian models for each batch item and compute usage statistics
1209
  gaussian_models, pixel_aligned_positions, usage_statistics = self._create_gaussian_models_and_stats(
@@ -1211,12 +721,6 @@ class GSLRM(nn.Module):
1211
  num_pixel_aligned_gaussians, num_views, height, width, patch_size
1212
  )
1213
 
1214
- # Add usage statistics to loss metrics for logging
1215
- if loss_metrics is not None:
1216
- loss_metrics.gaussians_usage = torch.tensor(
1217
- np.mean(np.array(usage_statistics))
1218
- ).float()
1219
-
1220
  # Compile final results
1221
  return edict(
1222
  input=input_data,
@@ -1224,424 +728,6 @@ class GSLRM(nn.Module):
1224
  gaussians=gaussian_models,
1225
  pixelalign_xyz=pixel_aligned_positions,
1226
  img_tokens=image_patch_tokens,
1227
- loss_metrics=loss_metrics,
+ loss_metrics=None,
1228
  render=rendered_images,
1229
- )
+ )
1230
-
1231
- @torch.no_grad()
1232
- def save_visualization_outputs(
1233
- self,
1234
- output_directory: str,
1235
- model_results: edict,
1236
- batch_data: edict,
1237
- save_all_items: bool = False
1238
- ) -> None:
1239
- """
1240
- Save visualization outputs including rendered images and Gaussian models.
1241
-
1242
- Args:
1243
- output_directory: Directory to save outputs
1244
- model_results: Results from model forward pass
1245
- batch_data: Original batch data
1246
- save_all_items: Whether to save all batch items or just the first
1247
- """
1248
- os.makedirs(output_directory, exist_ok=True)
1249
-
1250
- input_data, target_data = model_results.input, model_results.target
1251
-
1252
- # Save supervision visualization if available
1253
- if (model_results.loss_metrics is not None and
1254
- model_results.loss_metrics.visual is not None):
1255
-
1256
- batch_uids = [
1257
- target_data.index[b, 0, -1].item()
1258
- for b in range(target_data.index.size(0))
1259
- ]
1260
-
1261
- uid_range = f"{batch_uids[0]:08}_{batch_uids[-1]:08}"
1262
-
1263
- # Save supervision comparison image
1264
- Image.fromarray(model_results.loss_metrics.visual).save(
1265
- os.path.join(output_directory, f"supervision_{uid_range}.jpg")
1266
- )
1267
-
1268
- # Save UIDs for reference
1269
- with open(os.path.join(output_directory, "uids.txt"), "w") as f:
1270
- uid_string = "_".join([f"{uid:08}" for uid in batch_uids])
1271
- f.write(uid_string)
1272
-
1273
- # Save input images
1274
- input_visualization = rearrange(
1275
- input_data.image, "batch views channels height width -> (batch height) (views width) channels"
1276
- )
1277
- input_visualization = (
1278
- (input_visualization.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1279
- )
1280
- Image.fromarray(input_visualization[..., :3]).save(
1281
- os.path.join(output_directory, f"input_{uid_range}.jpg")
1282
- )
1283
-
1284
- # Process each batch item individually
1285
- batch_size = input_data.image.size(0)
1286
- for batch_idx in range(batch_size):
1287
- item_uid = input_data.index[batch_idx, 0, -1].item()
1288
-
1289
- # Render turntable visualization
1290
- turntable_image = render_turntable(model_results.gaussians[batch_idx])
1291
- Image.fromarray(turntable_image).save(
1292
- os.path.join(output_directory, f"turntable_{item_uid}.jpg")
1293
- )
1294
-
1295
- # Save individual input images during inference
1296
- if self.config.inference:
1297
- individual_input = rearrange(
1298
- input_data.image[batch_idx], "views channels height width -> height (views width) channels"
1299
- )
1300
- individual_input = (
1301
- (individual_input.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1302
- )
1303
- Image.fromarray(individual_input[..., :3]).save(
1304
- os.path.join(output_directory, f"input_{item_uid}.jpg")
1305
- )
1306
-
1307
- # Extract image dimensions and create opacity/depth visualizations
1308
- _, num_views, _, img_height, img_width = input_data.image.size()
1309
- patch_size = self.config.model.image_tokenizer.patch_size
1310
-
1311
- # Get opacity values for pixel-aligned Gaussians
1312
- gaussian_opacity = model_results.gaussians[batch_idx].get_opacity
1313
- pixel_opacity = gaussian_opacity[-num_views * img_height * img_width:]
1314
-
1315
- # Reshape opacity to image format
1316
- opacity_visualization = rearrange(
1317
- pixel_opacity,
1318
- "(views height width patch_h patch_w) channels -> (height patch_h) (views width patch_w) channels",
1319
- views=num_views,
1320
- height=img_height // patch_size,
1321
- width=img_width // patch_size,
1322
- patch_h=patch_size,
1323
- patch_w=patch_size,
1324
- ).squeeze(-1).cpu().numpy()
1325
- opacity_visualization = (opacity_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
1326
-
1327
- # Get 3D positions and compute depth visualization
1328
- gaussian_positions = model_results.gaussians[batch_idx].get_xyz
1329
- pixel_positions = gaussian_positions[-num_views * img_height * img_width:]
1330
-
1331
- # Reshape positions to image format
1332
- pixel_positions_reshaped = rearrange(
1333
- pixel_positions,
1334
- "(views height width patch_h patch_w) coords -> views coords (height patch_h) (width patch_w)",
1335
- views=num_views,
1336
- height=img_height // patch_size,
1337
- width=img_width // patch_size,
1338
- patch_h=patch_size,
1339
- patch_w=patch_size,
1340
- )
1341
-
1342
- # Compute distances from ray origins
1343
- ray_distances = (pixel_positions_reshaped - input_data.ray_o[batch_idx]).norm(dim=1, p=2)
1344
- distance_visualization = rearrange(ray_distances, "views height width -> height (views width)")
1345
- distance_visualization = distance_visualization.cpu().numpy()
1346
-
1347
- # Normalize distances for visualization
1348
- dist_min, dist_max = distance_visualization.min(), distance_visualization.max()
1349
- distance_visualization = (distance_visualization - dist_min) / (dist_max - dist_min)
1350
- distance_visualization = (distance_visualization * 255.0).clip(0.0, 255.0).astype(np.uint8)
1351
-
1352
- # Combine opacity and depth visualizations
1353
- combined_visualization = np.concatenate([opacity_visualization, distance_visualization], axis=0)
1354
- Image.fromarray(combined_visualization).save(
1355
- os.path.join(output_directory, f"aligned_gs_opacity_depth_{item_uid}.jpg")
1356
- )
1357
-
1358
- # Save unfiltered Gaussian model for small images during early training
1359
- if (self.config.model.image_tokenizer.image_size <= 256 and
1360
- self.training_step is not None and self.training_step <= 5000):
1361
- model_results.gaussians[batch_idx].save_ply(
1362
- os.path.join(output_directory, f"gaussians_{item_uid}_unfiltered.ply")
1363
- )
1364
-
1365
- # Save filtered Gaussian model
1366
- camera_origins = None # Could use input_data.ray_o[batch_idx, :, :, 0, 0] if needed
1367
- default_crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1368
-
1369
- model_results.gaussians[batch_idx].apply_all_filters(
1370
- opacity_thres=0.02,
1371
- crop_bbx=default_crop_box,
1372
- cam_origins=camera_origins,
1373
- nearfar_percent=(0.0001, 1.0),
1374
- ).save_ply(os.path.join(output_directory, f"gaussians_{item_uid}.ply"))
1375
-
1376
- print(f"Saved visualization for UID: {item_uid}")
1377
-
1378
- # Break after first item unless saving all
1379
- if not save_all_items:
1380
- break
1381
-
1382
- @torch.no_grad()
1383
- def save_visuals(self, out_dir: str, result: edict, batch: edict, save_all: bool = False) -> None:
1384
- """Backward compatibility wrapper for save_visualization_outputs."""
1385
- self.save_visualization_outputs(out_dir, result, batch, save_all)
1386
-
1387
- @torch.no_grad()
1388
- def save_evaluation_results(
1389
- self,
1390
- output_directory: str,
1391
- model_results: edict,
1392
- batch_data: edict,
1393
- dataset
1394
- ) -> None:
1395
- """Save comprehensive evaluation results including metrics, visualizations, and 3D models."""
1396
- from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
1397
-
1398
- os.makedirs(output_directory, exist_ok=True)
1399
- input_data, target_data = model_results.input, model_results.target
1400
-
1401
- for batch_idx in range(input_data.image.size(0)):
1402
- item_uid = input_data.index[batch_idx, 0, -1].item()
1403
- item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
1404
- os.makedirs(item_output_dir, exist_ok=True)
1405
-
1406
- # Save input image
1407
- input_image = rearrange(
1408
- input_data.image[batch_idx], "views channels height width -> height (views width) channels"
1409
- )
1410
- input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1411
- Image.fromarray(input_image[..., :3]).save(os.path.join(item_output_dir, "input.png"))
1412
-
1413
- # Save ground truth vs prediction comparison
1414
- comparison_image = torch.stack((target_data.image[batch_idx], model_results.render[batch_idx]), dim=0)
1415
- num_views = comparison_image.size(1)
1416
- if num_views > 10:
1417
- comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
1418
- comparison_image = rearrange(
1419
- comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
1420
- )
1421
- comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1422
- Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
1423
-
1424
- # Compute and save metrics
1425
- per_view_psnr = compute_psnr(target_data.image[batch_idx], model_results.render[batch_idx])
1426
- per_view_lpips = compute_lpips(target_data.image[batch_idx], model_results.render[batch_idx])
1427
- per_view_ssim = compute_ssim(target_data.image[batch_idx], model_results.render[batch_idx])
1428
-
1429
- # Save per-view metrics
1430
- view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
1431
- with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
1432
- for i in range(per_view_psnr.size(0)):
1433
- f.write(
1434
- f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
1435
- f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
1436
- )
1437
-
1438
- # Save average metrics
1439
- avg_psnr = per_view_psnr.mean().item()
1440
- avg_lpips = per_view_lpips.mean().item()
1441
- avg_ssim = per_view_ssim.mean().item()
1442
-
1443
- with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
1444
- f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
1445
-
1446
- print(f"UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
1447
-
1448
- # Save Gaussian model
1449
- crop_box = None
1450
- if self.config.model.get("clip_xyz", False):
1451
- if self.config.model.get("half_bbx_size", None) is not None:
1452
- half_size = self.config.model.half_bbx_size
1453
- crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
1454
- else:
1455
- crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1456
-
1457
- model_results.gaussians[batch_idx].apply_all_filters(
1458
- opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
1459
- ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
1460
-
1461
- # Create turntable visualization
1462
- num_turntable_views = 150
1463
- render_resolution = input_image.shape[0]
1464
-
1465
- turntable_frames = render_turntable(
1466
- model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
1467
- )
1468
- turntable_frames = rearrange(
1469
- turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
1470
- )
1471
- turntable_frames = np.ascontiguousarray(turntable_frames)
1472
-
1473
- # Save basic turntable video
1474
- imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
1475
-
1476
- # Save description and preview if available
1477
- try:
1478
- description = dataset.get_description(item_uid)["prompt"]
1479
- if len(description) > 0:
1480
- with open(os.path.join(item_output_dir, "description.txt"), "w") as f:
1481
- f.write(description)
1482
-
1483
- # Create preview image (subsample to 10 views)
1484
- preview_frames = turntable_frames[::num_turntable_views // 10]
1485
- preview_image = rearrange(preview_frames, "views height width channels -> height (views width) channels")
1486
- Image.fromarray(preview_image).save(os.path.join(item_output_dir, "turntable_preview.png"))
1487
- except (AttributeError, KeyError):
1488
- pass
1489
-
1490
- # Create turntable with input overlay
1491
- border_width = 2
1492
- target_width = render_resolution
1493
- target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
1494
-
1495
- resized_input = cv2.resize(
1496
- input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
1497
- )
1498
- bordered_input = np.pad(
1499
- resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
1500
- mode="constant", constant_values=200
1501
- )
1502
-
1503
- input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
1504
- combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
1505
-
1506
- imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
1507
-
1508
- @torch.no_grad()
1509
- def save_evaluations(self, out_dir: str, result: edict, batch: edict, dataset) -> None:
1510
- """Backward compatibility wrapper for save_evaluation_results."""
1511
- self.save_evaluation_results(out_dir, result, batch, dataset)
1512
-
1513
- @torch.no_grad()
1514
- def save_validation_results(
1515
- self,
1516
- output_directory: str,
1517
- model_results: edict,
1518
- batch_data: edict,
1519
- dataset,
1520
- save_visualizations: bool = False
1521
- ) -> Dict[str, float]:
1522
- """Save validation results and compute aggregated metrics."""
1523
- from .utils_metrics import compute_psnr, compute_lpips, compute_ssim
1524
-
1525
- os.makedirs(output_directory, exist_ok=True)
1526
- input_data, target_data = model_results.input, model_results.target
1527
- validation_metrics = {"psnr": [], "lpips": [], "ssim": []}
1528
-
1529
- for batch_idx in range(input_data.image.size(0)):
1530
- item_uid = input_data.index[batch_idx, 0, -1].item()
1531
- should_save_visuals = (batch_idx == 0) and save_visualizations
1532
-
1533
- # Compute metrics (RGB only)
1534
- target_image = target_data.image[batch_idx][:, :3, ...]
1535
- per_view_psnr = compute_psnr(target_image, model_results.render[batch_idx])
1536
- per_view_lpips = compute_lpips(target_image, model_results.render[batch_idx])
1537
- per_view_ssim = compute_ssim(target_image, model_results.render[batch_idx])
1538
-
1539
- avg_psnr = per_view_psnr.mean().item()
1540
- avg_lpips = per_view_lpips.mean().item()
1541
- avg_ssim = per_view_ssim.mean().item()
1542
-
1543
- validation_metrics["psnr"].append(avg_psnr)
1544
- validation_metrics["lpips"].append(avg_lpips)
1545
- validation_metrics["ssim"].append(avg_ssim)
1546
-
1547
- # Save visualizations only for first item if requested
1548
- if should_save_visuals:
1549
- item_output_dir = os.path.join(output_directory, f"{item_uid:08d}")
1550
- os.makedirs(item_output_dir, exist_ok=True)
1551
-
1552
- # Save input image
1553
- input_image = rearrange(
1554
- input_data.image[batch_idx][:, :3, ...], "views channels height width -> height (views width) channels"
1555
- )
1556
- input_image = (input_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1557
- Image.fromarray(input_image).save(os.path.join(item_output_dir, "input.png"))
1558
-
1559
- # Save ground truth vs prediction comparison
1560
- comparison_image = torch.stack((target_image, model_results.render[batch_idx]), dim=0)
1561
- num_views = comparison_image.size(1)
1562
- if num_views > 10:
1563
- comparison_image = comparison_image[:, ::num_views // 10, :, :, :]
1564
- comparison_image = rearrange(
1565
- comparison_image, "comparison_type views channels height width -> (comparison_type height) (views width) channels"
1566
- )
1567
- comparison_image = (comparison_image.cpu().numpy() * 255.0).clip(0.0, 255.0).astype(np.uint8)
1568
- Image.fromarray(comparison_image).save(os.path.join(item_output_dir, "gt_vs_pred.png"))
1569
-
1570
- # Save per-view metrics
1571
- view_ids = target_data.index[batch_idx, :, 0].cpu().numpy()
1572
- with open(os.path.join(item_output_dir, "perview_metrics.txt"), "w") as f:
1573
- for i in range(per_view_psnr.size(0)):
1574
- f.write(
1575
- f"view {view_ids[i]:0>6}, psnr: {per_view_psnr[i].item():.4f}, "
1576
- f"lpips: {per_view_lpips[i].item():.4f}, ssim: {per_view_ssim[i].item():.4f}\n"
1577
- )
1578
-
1579
- # Save averaged metrics
1580
- with open(os.path.join(item_output_dir, "metrics.txt"), "w") as f:
1581
- f.write(f"psnr: {avg_psnr:.4f}\nlpips: {avg_lpips:.4f}\nssim: {avg_ssim:.4f}\n")
1582
-
1583
- print(f"Validation UID {item_uid}: PSNR={avg_psnr:.4f}, LPIPS={avg_lpips:.4f}, SSIM={avg_ssim:.4f}")
1584
-
1585
- # Save Gaussian model
1586
- crop_box = None
1587
- if self.config.model.get("clip_xyz", False):
1588
- if self.config.model.get("half_bbx_size", None) is not None:
1589
- half_size = self.config.model.half_bbx_size
1590
- crop_box = [-half_size, half_size, -half_size, half_size, -half_size, half_size]
1591
- else:
1592
- crop_box = [-0.91, 0.91, -0.91, 0.91, -0.91, 0.91]
1593
-
1594
- model_results.gaussians[batch_idx].apply_all_filters(
1595
- opacity_thres=0.02, crop_bbx=crop_box, cam_origins=None, nearfar_percent=(0.0001, 1.0)
1596
- ).save_ply(os.path.join(item_output_dir, "gaussians.ply"))
1597
-
1598
- # Create turntable visualization
1599
- num_turntable_views = 150
1600
- render_resolution = input_image.shape[0]
1601
-
1602
- turntable_frames = render_turntable(
1603
- model_results.gaussians[batch_idx], rendering_resolution=render_resolution, num_views=num_turntable_views
1604
- )
1605
- turntable_frames = rearrange(
1606
- turntable_frames, "height (views width) channels -> views height width channels", views=num_turntable_views
1607
- )
1608
- turntable_frames = np.ascontiguousarray(turntable_frames)
1609
-
1610
- imageseq2video(turntable_frames, os.path.join(item_output_dir, "turntable.mp4"), fps=30)
1611
-
1612
- # Create turntable with input overlay
1613
- border_width = 2
1614
- target_width = render_resolution
1615
- target_height = int(input_image.shape[0] / input_image.shape[1] * target_width)
1616
-
1617
- resized_input = cv2.resize(
1618
- input_image, (target_width - border_width * 2, target_height - border_width * 2), interpolation=cv2.INTER_AREA
1619
- )
1620
- bordered_input = np.pad(
1621
- resized_input, ((border_width, border_width), (border_width, border_width), (0, 0)),
1622
- mode="constant", constant_values=200
1623
- )
1624
-
1625
- input_sequence = np.tile(bordered_input[None], (turntable_frames.shape[0], 1, 1, 1))
1626
- combined_frames = np.concatenate((turntable_frames, input_sequence), axis=1)
1627
-
1628
- imageseq2video(combined_frames, os.path.join(item_output_dir, "turntable_with_input.mp4"), fps=30)
1629
-
1630
- # Return averaged metrics
1631
- return {
1632
- "psnr": torch.tensor(validation_metrics["psnr"]).mean().item(),
1633
- "lpips": torch.tensor(validation_metrics["lpips"]).mean().item(),
1634
- "ssim": torch.tensor(validation_metrics["ssim"]).mean().item(),
1635
- }
1636
-
1637
- @torch.no_grad()
1638
- def save_validations(
1639
- self,
1640
- out_dir: str,
1641
- result: edict,
1642
- batch: edict,
1643
- dataset,
1644
- save_img: bool = False
1645
- ) -> Dict[str, float]:
1646
- """Backward compatibility wrapper for save_validation_results."""
1647
- return self.save_validation_results(out_dir, result, batch, dataset, save_img)
 
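For reference, the LossComputer removed above combined its terms as a plain weighted sum, with LPIPS evaluated on inputs rescaled from [0, 1] to [-1, 1]. A minimal sketch of that combination in isolation (the weight values here are placeholders, not the project's config defaults):

import lpips
import torch
import torch.nn.functional as F

lpips_fn = lpips.LPIPS(net="vgg").eval()  # same backbone as the deleted module
for p in lpips_fn.parameters():
    p.requires_grad = False

def total_image_loss(rendering, target, w_l2=1.0, w_lpips=0.5):
    # rendering, target: [N, 3, H, W] tensors in [0, 1]
    l2 = F.mse_loss(rendering, target)
    psnr = -10.0 * torch.log10(l2)           # logged alongside the loss above
    lp = lpips_fn(rendering * 2.0 - 1.0,     # LPIPS expects [-1, 1] inputs
                  target * 2.0 - 1.0).mean()
    return w_l2 * l2 + w_lpips * lp, psnr
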
gslrm/model/utils_losses.py DELETED
@@ -1,309 +0,0 @@
1
- # Copyright (C) 2025, FaceLift Research Group
2
- # https://github.com/weijielyu/FaceLift
3
- #
4
- # This software is free for non-commercial, research and evaluation use
5
- # under the terms of the LICENSE.md file.
6
- #
7
- # For inquiries contact: wlyu3@ucmerced.edu
8
-
9
- """
10
- Perceptual Loss Implementation using VGG19 and SSIM Loss Implementation.
11
-
12
- Adapted from https://github.com/zhengqili/Crowdsampling-the-Plenoptic-Function/blob/f5216f312cf82d77f8d20454b5eeb3930324630a/models/networks.py#L1478
13
- """
14
- import os
15
- from typing import List, Tuple, Union, Optional
16
-
17
- import scipy.io
18
- import torch
19
- import torch.nn as nn
20
- from pytorch_msssim import SSIM
21
-
22
- # VGG19 ImageNet normalization constants
23
- IMAGENET_MEAN = [123.6800, 116.7790, 103.9390]
24
-
25
- # VGG19 layer configuration
26
- VGG19_LAYER_INDICES = [0, 2, 5, 7, 10, 12, 14, 16, 19, 21, 23, 25, 28, 30, 32, 34]
27
- VGG19_LAYER_NAMES = [
28
- "conv1", "conv2", "conv3", "conv4", "conv5", "conv6", "conv7", "conv8",
29
- "conv9", "conv10", "conv11", "conv12", "conv13", "conv14", "conv15", "conv16"
30
- ]
31
- VGG19_CHANNEL_SIZES = [64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512]
32
-
33
- # Perceptual loss weighting factors
34
- LAYER_WEIGHTS = [1.0, 1/2.6, 1/4.8, 1/3.7, 1/5.6, 10/1.5]
35
-
36
- class VGG19(nn.Module):
37
- """
38
- VGG19 network implementation for perceptual loss computation.
39
-
40
- This class implements the VGG19 architecture with specific layer outputs
41
- used for computing perceptual losses at different scales.
42
- """
43
-
44
- def __init__(self) -> None:
45
- """Initialize VGG19 network layers."""
46
- super(VGG19, self).__init__()
47
- self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=True)
48
- self.relu1 = nn.ReLU(inplace=True)
49
-
50
- self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
51
- self.relu2 = nn.ReLU(inplace=True)
52
- self.max1 = nn.AvgPool2d(kernel_size=2, stride=2)
53
-
54
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
55
- self.relu3 = nn.ReLU(inplace=True)
56
-
57
- self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
58
- self.relu4 = nn.ReLU(inplace=True)
59
- self.max2 = nn.AvgPool2d(kernel_size=2, stride=2)
60
-
61
- self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True)
62
- self.relu5 = nn.ReLU(inplace=True)
63
-
64
- self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
65
- self.relu6 = nn.ReLU(inplace=True)
66
-
67
- self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
68
- self.relu7 = nn.ReLU(inplace=True)
69
-
70
- self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
71
- self.relu8 = nn.ReLU(inplace=True)
72
- self.max3 = nn.AvgPool2d(kernel_size=2, stride=2)
73
-
74
- self.conv9 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True)
75
- self.relu9 = nn.ReLU(inplace=True)
76
-
77
- self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
78
- self.relu10 = nn.ReLU(inplace=True)
79
-
80
- self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
81
- self.relu11 = nn.ReLU(inplace=True)
82
-
83
- self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
84
- self.relu12 = nn.ReLU(inplace=True)
85
- self.max4 = nn.AvgPool2d(kernel_size=2, stride=2)
86
-
87
- self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
88
- self.relu13 = nn.ReLU(inplace=True)
89
-
90
- self.conv14 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
91
- self.relu14 = nn.ReLU(inplace=True)
92
-
93
- self.conv15 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
94
- self.relu15 = nn.ReLU(inplace=True)
95
-
96
- self.conv16 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
97
- self.relu16 = nn.ReLU(inplace=True)
98
- self.max5 = nn.AvgPool2d(kernel_size=2, stride=2)
99
-
100
- def forward(self, x: torch.Tensor, return_style: int) -> Union[List[torch.Tensor], Tuple[torch.Tensor, ...]]:
101
- """
102
- Forward pass through VGG19 network.
103
-
104
- Args:
105
- x: Input tensor of shape [B, 3, H, W]
106
- return_style: If > 0, return style features as list; otherwise return content features as tuple
107
-
108
- Returns:
109
- Either a list of style features or tuple of content features from different layers
110
- """
111
- out1 = self.conv1(x)
112
- out2 = self.relu1(out1)
113
-
114
- out3 = self.conv2(out2)
115
- out4 = self.relu2(out3)
116
- out5 = self.max1(out4)
117
-
118
- out6 = self.conv3(out5)
119
- out7 = self.relu3(out6)
120
- out8 = self.conv4(out7)
121
- out9 = self.relu4(out8)
122
- out10 = self.max2(out9)
123
- out11 = self.conv5(out10)
124
- out12 = self.relu5(out11)
125
- out13 = self.conv6(out12)
126
- out14 = self.relu6(out13)
127
- out15 = self.conv7(out14)
128
- out16 = self.relu7(out15)
129
- out17 = self.conv8(out16)
130
- out18 = self.relu8(out17)
131
- out19 = self.max3(out18)
132
- out20 = self.conv9(out19)
133
- out21 = self.relu9(out20)
134
- out22 = self.conv10(out21)
135
- out23 = self.relu10(out22)
136
- out24 = self.conv11(out23)
137
- out25 = self.relu11(out24)
138
- out26 = self.conv12(out25)
139
- out27 = self.relu12(out26)
140
- out28 = self.max4(out27)
141
- out29 = self.conv13(out28)
142
- out30 = self.relu13(out29)
143
- out31 = self.conv14(out30)
144
- out32 = self.relu14(out31)
145
-
146
- if return_style > 0:
147
- return [out2, out7, out12, out21, out30]
148
- else:
149
- return out4, out9, out14, out23, out32
150
-
151
-
152
- class PerceptualLoss(nn.Module):
153
- """
154
- Perceptual Loss module using pre-trained VGG19.
155
-
156
- This class implements perceptual loss by comparing features extracted from
157
- different layers of a pre-trained VGG19 network. It computes weighted
158
- differences across multiple scales to capture both low-level and high-level
159
- visual differences between images.
160
- """
161
-
162
- def __init__(self, device: str = "cpu", weight_file: Optional[str] = None) -> None:
163
- """
164
- Initialize PerceptualLoss module.
165
-
166
- Args:
167
- device: Device to run computations on ('cpu' or 'cuda')
168
- weight_file: Path to VGG19 weight file. If None, uses default path or environment variable.
169
-
170
- Raises:
171
- FileNotFoundError: If weight file is not found
172
- RuntimeError: If weight file cannot be loaded
173
- """
174
- super().__init__()
175
- self.device = device
176
- self.net = VGG19()
177
-
178
- # Determine weight file path
179
- if weight_file is None:
180
- # Check environment variable first
181
- weight_file = os.environ.get('VGG19_WEIGHTS_PATH')
182
- if weight_file is None:
183
- # Fallback to default path
184
- weight_file = "/sensei-fs/users/kaiz/repos/weight-collections/imagenet-vgg-verydeep-19.mat"
185
-
186
- # Load VGG19 weights
187
- if not os.path.isfile(weight_file):
188
- raise FileNotFoundError(
189
- f"VGG19 weight file not found: {weight_file}\n"
190
- f"Download it from: https://www.vlfeat.org/matconvnet/models/imagenet-vgg-verydeep-19.mat\n"
191
- f"Expected MD5: 106118b7cf60435e6d8e04f6a6dc3657"
192
- )
193
-
194
- try:
195
- vgg_rawnet = scipy.io.loadmat(weight_file)
196
- vgg_layers = vgg_rawnet["layers"][0]
197
- except Exception as e:
198
- raise RuntimeError(f"Failed to load VGG19 weights from {weight_file}: {e}")
199
-
200
- # Load pre-trained weights into the network
201
- self._load_pretrained_weights(vgg_layers)
202
-
203
- # Set network to evaluation mode and freeze parameters
204
- self.net = self.net.eval().to(device)
205
- for param in self.net.parameters():
206
- param.requires_grad = False
207
-
208
- def _load_pretrained_weights(self, vgg_layers) -> None:
209
- """Load pre-trained VGG19 weights into the network."""
210
- for layer_idx in range(len(VGG19_LAYER_NAMES)):
211
- layer_name = VGG19_LAYER_NAMES[layer_idx]
212
- mat_layer_idx = VGG19_LAYER_INDICES[layer_idx]
213
- channel_size = VGG19_CHANNEL_SIZES[layer_idx]
214
-
215
- # Extract weights and biases from MATLAB format
216
- layer_weights = torch.from_numpy(
217
- vgg_layers[mat_layer_idx][0][0][2][0][0]
218
- ).permute(3, 2, 0, 1)
219
- layer_biases = torch.from_numpy(
220
- vgg_layers[mat_layer_idx][0][0][2][0][1]
221
- ).view(channel_size)
222
-
223
- # Assign to network
224
- getattr(self.net, layer_name).weight = nn.Parameter(layer_weights)
225
- getattr(self.net, layer_name).bias = nn.Parameter(layer_biases)
226
-
227
- def _compute_l1_error(self, truth: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
228
- """
229
- Compute L1 (Mean Absolute Error) between two tensors.
230
-
231
- Args:
232
- truth: Ground truth tensor
233
- pred: Predicted tensor
234
-
235
- Returns:
236
- L1 error as a scalar tensor
237
- """
238
- return torch.mean(torch.abs(truth - pred))
239
-
240
- def forward(self, pred_img: torch.Tensor, real_img: torch.Tensor) -> torch.Tensor:
241
- """
242
- Compute perceptual loss between predicted and real images.
243
-
244
- Args:
245
- pred_img: Predicted image tensor of shape [B, 3, H, W] in range [0, 1]
246
- real_img: Real image tensor of shape [B, 3, H, W] in range [0, 1]
247
-
248
- Returns:
249
- Perceptual loss as a scalar tensor
250
- """
251
- # Convert to ImageNet normalization (RGB -> BGR and subtract mean)
252
- imagenet_mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32, device=pred_img.device)
253
- imagenet_mean = imagenet_mean.view(1, 3, 1, 1)
254
-
255
- # Scale to [0, 255] and apply ImageNet normalization
256
- real_img_normalized = real_img * 255.0 - imagenet_mean
257
- pred_img_normalized = pred_img * 255.0 - imagenet_mean
258
-
259
- # Extract features from both images
260
- real_features = self.net(real_img_normalized, return_style=0)
261
- pred_features = self.net(pred_img_normalized, return_style=0)
262
-
263
- # Compute weighted L1 losses at different scales
264
- losses = []
265
-
266
- # Raw image loss
267
- raw_loss = self._compute_l1_error(real_img_normalized, pred_img_normalized)
268
- losses.append(raw_loss * LAYER_WEIGHTS[0])
269
-
270
- # Feature losses at different VGG layers
271
- for i, (real_feat, pred_feat) in enumerate(zip(real_features, pred_features)):
272
- feature_loss = self._compute_l1_error(real_feat, pred_feat)
273
- losses.append(feature_loss * LAYER_WEIGHTS[i + 1])
274
-
275
- # Combine all losses and normalize
276
- total_loss = sum(losses) / 255.0
277
- return total_loss
278
-
279
- class SsimLoss(nn.Module):
280
- """
281
- SSIM Loss module that computes 1 - SSIM for image similarity.
282
-
283
- Args:
284
- data_range: Range of input data (default: 1.0 for [0,1] range)
285
- """
286
-
287
- def __init__(self, data_range: float = 1.0) -> None:
288
- super().__init__()
289
- self.data_range = data_range
290
- self.ssim_module = SSIM(
291
- win_size=11,
292
- win_sigma=1.5,
293
- data_range=self.data_range,
294
- size_average=True,
295
- channel=3,
296
- )
297
-
298
- def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
299
- """
300
- Compute SSIM loss between two image tensors.
301
-
302
- Args:
303
- x: Image tensor of shape (N, C, H, W)
304
- y: Image tensor of shape (N, C, H, W)
305
-
306
- Returns:
307
- SSIM loss (1 - SSIM similarity)
308
- """
309
- return 1.0 - self.ssim_module(x, y)
 
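The deleted SsimLoss wraps pytorch_msssim.SSIM and returns 1 - SSIM over images in [0, 1]. A minimal usage sketch with the same constructor arguments (the random tensors are placeholders for rendered and target images):

import torch
from pytorch_msssim import SSIM

ssim_module = SSIM(win_size=11, win_sigma=1.5, data_range=1.0,
                   size_average=True, channel=3)

pred = torch.rand(2, 3, 128, 128)  # placeholder renderings in [0, 1]
gt = torch.rand(2, 3, 128, 128)    # placeholder targets in [0, 1]

ssim_loss = 1.0 - ssim_module(pred, gt)  # scalar; 0 when the images match exactly
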
splat_viewer.html DELETED
@@ -1,277 +0,0 @@
1
- <!DOCTYPE html>
2
- <html lang="en" dir="ltr">
3
- <head>
4
- <title>WebGL Gaussian Splat Viewer</title>
5
- <meta charset="utf-8" />
6
- <meta
7
- name="viewport"
8
- content="width=device-width, initial-scale=1, minimum-scale=1, maximum-scale=1, user-scalable=no"
9
- />
10
- <meta name="apple-mobile-web-app-capable" content="yes" />
11
- <meta
12
- name="apple-mobile-web-app-status-bar-style"
13
- content="black-translucent"
14
- />
15
- <style>
16
- body {
17
- overflow: hidden;
18
- margin: 0;
19
- height: 100vh;
20
- width: 100vw;
21
- font-family: sans-serif;
22
- background: black;
23
- text-shadow: 0 0 3px black;
24
- }
25
- a, body {
26
- color: white;
27
- }
28
- #info {
29
- z-index: 100;
30
- position: absolute;
31
- top: 10px;
32
- left: 15px;
33
- }
34
- h3 {
35
- margin: 5px 0;
36
- }
37
- p {
38
- margin: 5px 0;
39
- font-size: small;
40
- }
41
-
42
- .cube-wrapper {
43
- transform-style: preserve-3d;
44
- }
45
-
46
- .cube {
47
- transform-style: preserve-3d;
48
- transform: rotateX(45deg) rotateZ(45deg);
49
- animation: rotation 2s infinite;
50
- }
51
-
52
- .cube-faces {
53
- transform-style: preserve-3d;
54
- height: 80px;
55
- width: 80px;
56
- position: relative;
57
- transform-origin: 0 0;
58
- transform: translateX(0) translateY(0) translateZ(-40px);
59
- }
60
-
61
- .cube-face {
62
- position: absolute;
63
- inset: 0;
64
- background: #0017ff;
65
- border: solid 1px #ffffff;
66
- }
67
- .cube-face.top {
68
- transform: translateZ(80px);
69
- }
70
- .cube-face.front {
71
- transform-origin: 0 50%;
72
- transform: rotateY(-90deg);
73
- }
74
- .cube-face.back {
75
- transform-origin: 0 50%;
76
- transform: rotateY(-90deg) translateZ(-80px);
77
- }
78
- .cube-face.right {
79
- transform-origin: 50% 0;
80
- transform: rotateX(-90deg) translateY(-80px);
81
- }
82
- .cube-face.left {
83
- transform-origin: 50% 0;
84
- transform: rotateX(-90deg) translateY(-80px) translateZ(80px);
85
- }
86
-
87
- @keyframes rotation {
88
- 0% {
89
- transform: rotateX(45deg) rotateY(0) rotateZ(45deg);
90
- animation-timing-function: cubic-bezier(
91
- 0.17,
92
- 0.84,
93
- 0.44,
94
- 1
95
- );
96
- }
97
- 50% {
98
- transform: rotateX(45deg) rotateY(0) rotateZ(225deg);
99
- animation-timing-function: cubic-bezier(
100
- 0.76,
101
- 0.05,
102
- 0.86,
103
- 0.06
104
- );
105
- }
106
- 100% {
107
- transform: rotateX(45deg) rotateY(0) rotateZ(405deg);
108
- animation-timing-function: cubic-bezier(
109
- 0.17,
110
- 0.84,
111
- 0.44,
112
- 1
113
- );
114
- }
115
- }
116
-
117
- .scene,
118
- #message {
119
- position: absolute;
120
- display: flex;
121
- top: 0;
122
- right: 0;
123
- left: 0;
124
- bottom: 0;
125
- z-index: 2;
126
- height: 100%;
127
- width: 100%;
128
- align-items: center;
129
- justify-content: center;
130
- }
131
- #message {
132
- font-weight: bold;
133
- font-size: large;
134
- color: red;
135
- pointer-events: none;
136
- }
137
-
138
- details {
139
- font-size: small;
140
-
141
- }
142
-
143
- #progress {
144
- position: absolute;
145
- top: 0;
146
- height: 5px;
147
- background: blue;
148
- z-index: 99;
149
- transition: width 0.1s ease-in-out;
150
- }
151
-
152
- #quality {
153
- position: absolute;
154
- bottom: 10px;
155
- z-index: 999;
156
- right: 10px;
157
- }
158
-
159
- #caminfo {
160
- position: absolute;
161
- top: 10px;
162
- z-index: 999;
163
- right: 10px;
164
- }
165
- #canvas {
166
- display: block;
167
- position: absolute;
168
- top: 0;
169
- left: 0;
170
- width: 100%;
171
- height: 100%;
172
- touch-action: none;
173
- }
174
-
175
- #instructions {
176
- background: rgba(0,0,0,0.6);
177
- white-space: pre-wrap;
178
- padding: 10px;
179
- border-radius: 10px;
180
- font-size: x-small;
181
- }
182
- body.nohf .nohf {
183
- display: none;
184
- }
185
- body.nohf #progress, body.nohf .cube-face {
186
- background: #ff9d0d;
187
- }
188
- </style>
189
- </head>
190
- <body>
191
- <script>
192
- if(location.host.includes('hf.space')) document.body.classList.add('nohf');
193
- </script>
194
- <div id="info">
195
- <h3 class="nohf">WebGL 3D Gaussian Splat Viewer</h3>
196
- <p>
197
- <small class="nohf">
198
- By <a href="https://twitter.com/antimatter15">Kevin Kwok</a>.
199
- Code on
200
- <a href="https://github.com/antimatter15/splat">Github</a
201
- >.
202
- </small>
203
- </p>
204
-
205
- <details>
206
- <summary>Use mouse or arrow keys to navigate.</summary>
207
-
208
- <div id="instructions">movement (arrow keys)
209
- - left/right arrow keys to strafe side to side
210
- - up/down arrow keys to move forward/back
211
- - space to jump
212
-
213
- camera angle (wasd)
214
- - a/d to turn camera left/right
215
- - w/s to tilt camera up/down
216
- - q/e to roll camera counterclockwise/clockwise
217
- - i/k and j/l to orbit
218
-
219
- trackpad
220
- - scroll up/down/left/right to orbit
221
- - pinch to move forward/back
222
- - ctrl key + scroll to move forward/back
223
- - shift + scroll to move up/down or strafe
224
-
225
- mouse
226
- - click and drag to orbit
227
- - right click (or ctrl/cmd key) and drag up/down to move
228
-
229
- touch (mobile)
230
- - one finger to orbit
231
- - two finger pinch to move forward/back
232
- - two finger rotate to rotate camera clockwise/counterclockwise
233
- - two finger pan to move side-to-side and up-down
234
-
235
- gamepad
236
- - if you have a game controller connected it should work
237
-
238
- other
239
- - press 0-9 to switch to one of the pre-loaded camera views
240
- - press '-' or '+'key to cycle loaded cameras
241
- - press p to resume default animation
242
- - drag and drop .ply file to convert to .splat
243
- - drag and drop cameras.json to load cameras
244
- </div>
245
-
246
- </details>
247
-
248
- </div>
249
-
250
- <div id="progress"></div>
251
-
252
- <div id="message"></div>
253
- <div class="scene" id="spinner">
254
- <div class="cube-wrapper">
255
- <div class="cube">
256
- <div class="cube-faces">
257
- <div class="cube-face bottom"></div>
258
- <div class="cube-face top"></div>
259
- <div class="cube-face left"></div>
260
- <div class="cube-face right"></div>
261
- <div class="cube-face back"></div>
262
- <div class="cube-face front"></div>
263
- </div>
264
- </div>
265
- </div>
266
- </div>
267
- <canvas id="canvas"></canvas>
268
-
269
- <div id="quality">
270
- <span id="fps"></span>
271
- </div>
272
- <div id="caminfo">
273
- <span id="camid"></span>
274
- </div>
275
- <script src="main.js"></script>
276
- </body>
277
- </html>