|
|
| import argparse |
| import os |
| import json |
| import torch |
| import torchvision |
| import numpy as np |
| from tqdm import tqdm |
| from torch.utils.tensorboard import SummaryWriter |
| import sys |
| import time |
| import concurrent.futures |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
| from ufd_evalkit.geometric import depth_to_normal |
| import methods |
| from core.registry import METHOD_REGISTRY |
|
|
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``method``, ``source_path``,
        ``model_path``, ``iterations``, ``save_iterations``, ``resolution``,
        ``track_decoupling``, ``cap_gaussians``, ``seed`` and ``skip_render``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, required=True,
                        help="Name of the method to load and train.")
    parser.add_argument("--source_path", type=str, required=True,
                        help="Dataset location forwarded to the method.")
    parser.add_argument("--model_path", type=str, required=True,
                        help="Output directory for logs, checkpoints and renders.")
    parser.add_argument("--iterations", type=int, default=30000,
                        help="Number of training steps to run.")
    parser.add_argument("--save_iterations", nargs="+", type=int,
                        default=[5000, 10000, 20000, 30000],
                        help="Training steps at which the model is saved.")
    parser.add_argument("--resolution", type=int, default=-1,
                        help="Resolution setting forwarded to the method.")
    parser.add_argument("--track_decoupling", action="store_true",
                        help="Forwarded to the method as 'track_decoupling'.")
    parser.add_argument("--cap_gaussians", type=int, default=None,
                        help="Forwarded to the method as 'cap_gaussians'.")
    parser.add_argument("--seed", type=int, default=0,
                        help="Seed for the Python, NumPy and torch RNGs.")
    parser.add_argument("--skip_render", action="store_true",
                        help="Skip the rendering pass after training.")
    return parser.parse_args(argv)
|
|
def save_np(data, path):
    """Store ``data`` at ``path`` via ``np.save``.

    The argument order is (data, path) — the reverse of ``np.save`` — so the
    function can be submitted to an executor with the payload first.
    """
    np.save(file=path, arr=data)
|
|
def _rendered_image(render_pkg):
    """Return the image tensor of a render package ("render", else "image")."""
    image = render_pkg.get("render")
    if image is None:
        image = render_pkg.get("image")
    if image is None:
        raise KeyError('render package has neither a "render" nor an "image" entry')
    return image


def in_memory_render(model, model_path, iteration):
    """Render the test and train cameras of ``model`` and write the outputs.

    Rendering runs in the calling thread under ``torch.no_grad()``; file
    writes are handed to a thread pool so they overlap with rendering.

    Output directories (created under ``model_path``):
        renders_test_<it>, gt_test_<it>, depths_test_<it>, normals_test_<it>,
        renders_train_<it>, depths_train_<it>

    If every write succeeds, ``render_complete_<it>.flag`` is written.
    Otherwise the failures are appended to ``render_error.log`` and no flag
    is written.

    Args:
        model: Object exposing ``render(view)`` (returning a dict-like render
            package) and ``scene.getTrainCameras()``; ``scene.getTestCameras``
            is optional.
        model_path: Directory that receives all outputs.
        iteration: Value embedded in the output directory and flag names.

    Raises:
        KeyError: If a render package has neither a "render" nor an "image"
            entry.
    """
    print("\n[Memory Fusion] Flushing VRAM fragments...")
    torch.cuda.empty_cache()

    def make_dir(kind, split):
        path = os.path.join(model_path, f"{kind}_{split}_{iteration}")
        os.makedirs(path, exist_ok=True)
        return path

    # main() treats getTestCameras as optional, so do the same here.
    get_test_cameras = getattr(model.scene, "getTestCameras", None)
    test_views = get_test_cameras() if callable(get_test_cameras) else []

    save_image = torchvision.utils.save_image
    futures = []

    # The context manager guarantees the pool is shut down (after draining
    # already-queued writes) even when a render call raises.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor, torch.no_grad():
        test_render_dir = make_dir("renders", "test")
        test_depth_dir = make_dir("depths", "test")
        test_normal_dir = make_dir("normals", "test")
        test_gt_dir = make_dir("gt", "test")

        print(f"Launching async full-dimensional rendering | Test Set")
        for idx, view in enumerate(tqdm(test_views, desc="Rendering Test")):
            render_pkg = model.render(view)
            rendering = _rendered_image(render_pkg).detach().cpu()
            stem = f'{idx:05d}'

            futures.append(executor.submit(
                save_image, rendering, os.path.join(test_render_dir, f'{stem}.png')))
            futures.append(executor.submit(
                save_image, view.original_image.cpu(), os.path.join(test_gt_dir, f'{stem}.png')))

            depth = render_pkg.get("depth", None)
            if depth is None:
                continue

            # Only derive normals from depth when the renderer did not
            # provide them; the fallback is not free.
            normal = render_pkg.get("normal")
            if normal is None:
                normal = depth_to_normal(view, depth)
            futures.append(executor.submit(
                save_np, depth.detach().cpu().numpy(), os.path.join(test_depth_dir, f'{stem}.npy')))
            futures.append(executor.submit(
                save_image, normal.detach().cpu(), os.path.join(test_normal_dir, f'{stem}.png')))

        train_render_dir = make_dir("renders", "train")
        train_depth_dir = make_dir("depths", "train")

        print(f"Launching async full-dimensional rendering | Train Set")
        for idx, view in enumerate(tqdm(model.scene.getTrainCameras(), desc="Rendering Train")):
            render_pkg = model.render(view)
            rendering = _rendered_image(render_pkg).detach().cpu()
            futures.append(executor.submit(
                save_image, rendering, os.path.join(train_render_dir, f'{idx:05d}.png')))

            depth = render_pkg.get("depth", None)
            if depth is not None:
                futures.append(executor.submit(
                    save_np, depth.detach().cpu().numpy(), os.path.join(train_depth_dir, f'{idx:05d}.npy')))

        print("[Async I/O] GPU rendering complete. Waiting for CPU flush...")
        concurrent.futures.wait(futures)

    # Every future is finished here, so exception() returns immediately.
    errors = [exc for exc in (future.exception() for future in futures) if exc is not None]

    if errors:
        error_log_path = os.path.join(model_path, "render_error.log")
        print(f"\n[CRITICAL ERROR] Async render thread crashed. See {error_log_path}")
        with open(error_log_path, "a") as err_f:
            # Include the exception type: str(exc) alone is often empty.
            err_f.writelines(f"Exception: {type(exc).__name__}: {exc}\n" for exc in errors)
        print("[WARNING] Render flag not written due to errors.")
        return

    flag_path = os.path.join(model_path, f"render_complete_{iteration}.flag")
    with open(flag_path, "w") as f:
        f.write("Completed by Async Engine")
    print("Flush complete. Flag written.")
|
|
def _flatten_cameras(cameras):
    """Flatten a camera container (dict or iterable, possibly of lists) into a list."""
    groups = cameras.values() if isinstance(cameras, dict) else cameras
    flat = []
    for item in groups:
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


def _log_histogram(tb_writer, dist_name, tensor_data, step):
    """Write one tensor as a TensorBoard histogram and print its summary stats."""
    # Work on a detached float copy: .numpy() rejects tensors that require
    # grad, and mean() is undefined for integer dtypes.
    values = tensor_data.detach().float()
    if values.numel() == 0:
        # min()/max() raise on empty tensors; there is nothing to report.
        print(f" |- {dist_name}: empty, skipped.")
        return
    if torch.isnan(values).any():
        print(f" |- {dist_name}: contains NaN, skipped.")
        return

    try:
        arr = values.cpu().numpy().astype(np.float64).ravel()
        arr = arr[np.isfinite(arr)]
        if len(arr) > 0:
            counts, bin_edges = np.histogram(arr, bins=100)
            tb_writer.add_histogram_raw(
                f'Distributions/{dist_name}',
                min=float(arr.min()), max=float(arr.max()),
                num=len(arr),
                sum=float(arr.sum()),
                sum_squares=float((arr ** 2).sum()),
                # Upper edge of each bucket, one per count.
                bucket_limits=bin_edges[1:].tolist(),
                bucket_counts=counts.tolist(),
                global_step=step
            )
    except Exception as e:
        # Histogram logging is best-effort and must not abort training.
        print(f" |- histogram write failed: {e}")

    t_min = values.min().item()
    t_max = values.max().item()
    t_mean = values.mean().item()
    shape_str = "x".join(map(str, tensor_data.shape))
    print(f" |- {dist_name}:")
    print(f" | - Shape: [{shape_str}], Mean: {t_mean:.4f}, Min: {t_min:.4f}, Max: {t_max:.4f}")


def main():
    """Train the selected method, log progress, and optionally render.

    Steps:
        1. Parse arguments and seed the Python, NumPy and torch RNGs.
        2. Build the method from ``methods.load_method(args.method)``.
        3. Write ``dataset_split.log``, ``cameras.json`` and an initial
           ``timing.json`` (status "running") into ``args.model_path``.
        4. Run ``args.iterations`` training steps; every 100 steps print the
           stats and log scalars/histograms to TensorBoard; save the model at
           the steps in ``args.save_iterations`` and at the final step.
        5. Rewrite ``timing.json`` with status "complete" and, unless
           ``--skip_render`` is given, call ``in_memory_render``.
    """
    import random

    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.model_path, exist_ok=True)
    tb_writer = SummaryWriter(args.model_path)

    log_interval = 100
    save_steps = set(args.save_iterations)
    # Always checkpoint the final step; otherwise a run whose --iterations is
    # not listed in --save_iterations would discard the trained model.
    save_steps.add(args.iterations)

    # Close the writer even if setup or training raises.
    try:
        dataset_config = {"source_path": args.source_path, "model_path": args.model_path, "resolution": args.resolution}
        method_options = {"track_decoupling": args.track_decoupling, "cap_gaussians": args.cap_gaussians}
        model = methods.load_method(args.method)(dataset_config, method_options)

        print(f"[SplatAtlas] Starting training | Method: {args.method} | Resolution: {args.resolution}")

        # Count cameras on the flattened container so the reported numbers
        # agree with the entries written to cameras.json.
        train_cameras = _flatten_cameras(model.scene.getTrainCameras())
        num_train = len(train_cameras)
        if hasattr(model.scene, "getTestCameras"):
            num_test = len(_flatten_cameras(model.scene.getTestCameras()))
        else:
            num_test = 0
        split_info = (
            f"========================================\n"
            f"[Dataset Topology Checker]\n"
            f"Method: {args.method}\n"
            f"Train Cameras: {num_train}\n"
            f"Test Cameras: {num_test}\n"
            f"Leakage Status: {'SAFE' if num_test > 0 else 'DANGER (NO TEST SET DETECTED)'}\n"
            f"========================================"
        )
        print(split_info)

        with open(os.path.join(args.model_path, "dataset_split.log"), "w") as f:
            f.write(split_info + "\n")

        # Build the records before opening the file so a failing camera
        # attribute does not leave a truncated cameras.json behind.
        camera_records = [
            {
                "id": i,
                "center": cam.camera_center.detach().cpu().numpy().tolist(),
                "view_dir": cam.world_view_transform[:3, 2].detach().cpu().numpy().tolist(),
            }
            for i, cam in enumerate(train_cameras)
        ]
        cam_json_path = os.path.join(args.model_path, "cameras.json")
        with open(cam_json_path, "w") as f:
            json.dump(camera_records, f, indent=4)

        t_train_start = time.time()
        timing_path = os.path.join(args.model_path, "timing.json")
        with open(timing_path, "w") as f:
            json.dump({"train_start_unix": t_train_start, "status": "running"}, f)

        cumulative_gpu_time = 0.0
        for step in range(1, args.iterations + 1):
            t_step_start = time.time()

            # train_iteration returns either stats or (stats, histograms).
            stats_or_tuple = model.train_iteration(step)
            if isinstance(stats_or_tuple, tuple):
                stats, histograms = stats_or_tuple
            else:
                stats, histograms = stats_or_tuple, {}

            # NOTE(review): this is wall-clock time around train_iteration with
            # no CUDA synchronisation visible here, so despite the "gpu_time"
            # key it may under-count asynchronous GPU work — confirm whether
            # train_iteration synchronises internally.
            step_elapsed = time.time() - t_step_start
            cumulative_gpu_time += step_elapsed
            stats['cumulative_gpu_time_sec'] = cumulative_gpu_time
            stats['step_time_ms'] = step_elapsed * 1000.0

            if step % log_interval == 0:
                print(f"[TIMING] step={step} step_time_ms={stats.get('step_time_ms',0):.2f} "
                      f"iter_time_ms={stats.get('iter_time_ms',0):.2f} "
                      f"peak_vram_GB={stats.get('peak_vram_GB',0):.3f}", flush=True)

                print(f"\n{'='*40}")
                print(f"Step {step:05d} Runtime Monitor")
                print(f"{'='*40}")

                print("[Scalars]")
                for key, val in stats.items():
                    if val is None:
                        continue
                    tb_writer.add_scalar(f"train/{key}", val, step)
                    if isinstance(val, float):
                        print(f" |- {key}: {val:.6f}")
                    else:
                        print(f" |- {key}: {val}")

                if histograms:
                    print("\n[Histograms]")
                    for dist_name, tensor_data in histograms.items():
                        _log_histogram(tb_writer, dist_name, tensor_data, step)

            if step in save_steps:
                model.save(args.model_path, step)
                print(f"\n[CheckPoint] Assets saved to: {args.model_path} (Step {step})")

        t_train_end = time.time()
        total_training_seconds = t_train_end - t_train_start

        with open(timing_path, "w") as f:
            json.dump({
                "train_start_unix": t_train_start,
                "train_end_unix": t_train_end,
                "total_training_seconds": total_training_seconds,
                "total_training_minutes": total_training_seconds / 60.0,
                "cumulative_gpu_time_seconds": cumulative_gpu_time,
                "iterations": args.iterations,
                "status": "complete"
            }, f, indent=4)

        print(f"\n[Timing] Training complete. Total: {total_training_seconds:.1f}s / {total_training_seconds/60.0:.2f}min")
        print(f"[Timing] timing.json written to {timing_path}")
    finally:
        tb_writer.close()

    if not args.skip_render:
        in_memory_render(model, args.model_path, args.iterations)
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|