# SplatAtlas / scripts/main_train.py
# Hosting-page metadata: uploader KCBtheone; commit 23e73f9 (verified),
# "Upload SplatAtlas benchmark pipeline code"; file size 11 kB.
import argparse
import os
import json
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import sys
import time
import concurrent.futures
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from ufd_evalkit.geometric import depth_to_normal
import methods
from core.registry import METHOD_REGISTRY
def parse_args():
    """Parse the training script's command-line options from ``sys.argv``.

    Required options are ``--method``, ``--source_path`` and ``--model_path``.
    Everything else is optional and falls back to the defaults listed below.

    Returns:
        argparse.Namespace: the parsed options, with attributes ``method``,
        ``source_path``, ``model_path``, ``iterations``, ``save_iterations``,
        ``resolution``, ``track_decoupling``, ``cap_gaussians``, ``seed`` and
        ``skip_render``.
    """
    cli = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in declaration order.
    option_table = (
        ("--method", {"type": str, "required": True}),
        ("--source_path", {"type": str, "required": True}),
        ("--model_path", {"type": str, "required": True}),
        ("--iterations", {"type": int, "default": 30000}),
        ("--save_iterations", {"nargs": "+", "type": int, "default": [5000, 10000, 20000, 30000]}),
        ("--resolution", {"type": int, "default": -1}),
        ("--track_decoupling", {"action": "store_true"}),
        ("--cap_gaussians", {"type": int, "default": None}),
        ("--seed", {"type": int, "default": 0}),
        ("--skip_render", {"action": "store_true"}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def save_np(data, path):
    """Save ``data`` to ``path`` with ``np.save``.

    The parameter order (array first, destination second) is the reverse of
    ``np.save``'s own, which lets the function be submitted to an executor
    with the payload as the leading argument.

    Args:
        data: array to store.
        path: destination file path.
    """
    np.save(file=path, arr=data)
def in_memory_render(model, model_path, iteration):
    """Render the test and train splits of an in-memory model and save the results.

    Rendering runs on the calling thread under ``torch.no_grad()``. Every file
    write is handed to a thread pool so disk I/O overlaps with rendering.

    Outputs, all created under ``model_path`` and suffixed with ``iteration``:
        renders_test_* / renders_train_*   PNG renders, one per camera
        gt_test_*                          PNG of each test view's ``original_image``
        depths_test_* / depths_train_*     ``.npy`` depth, only when the render
                                           package contains ``"depth"``
        normals_test_*                     PNG normals, only when depth is present
        render_error.log                   one appended line per failed write
        render_complete_*.flag             written only if every write succeeded

    Args:
        model: object exposing ``render(view)`` (whose result is queried with
            ``.get``) and ``scene.getTestCameras()`` / ``scene.getTrainCameras()``.
        model_path: output root directory.
        iteration: value embedded in the output directory and flag names.
    """
    print("\n[Memory Fusion] Flushing VRAM fragments...")
    torch.cuda.empty_cache()

    futures = []

    # Using the executor as a context manager guarantees the worker threads are
    # shut down (after draining queued writes) even if a render call raises
    # part-way through; previously the pool was leaked on that path.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:

        def enqueue(fn, *fn_args):
            # Queue one disk write and remember its future for the error check below.
            futures.append(executor.submit(fn, *fn_args))

        with torch.no_grad():
            test_render_dir = os.path.join(model_path, f"renders_test_{iteration}")
            test_depth_dir = os.path.join(model_path, f"depths_test_{iteration}")
            test_normal_dir = os.path.join(model_path, f"normals_test_{iteration}")
            test_gt_dir = os.path.join(model_path, f"gt_test_{iteration}")
            for d in [test_render_dir, test_depth_dir, test_normal_dir, test_gt_dir]:
                os.makedirs(d, exist_ok=True)

            print("Launching async full-dimensional rendering | Test Set")
            for idx, view in enumerate(tqdm(model.scene.getTestCameras(), desc="Rendering Test")):
                render_pkg = model.render(view)
                # Methods expose the colour image under either "render" or "image".
                rendering = render_pkg.get("render", render_pkg.get("image")).detach().cpu()
                depth = render_pkg.get("depth", None)
                enqueue(torchvision.utils.save_image, rendering, os.path.join(test_render_dir, f'{idx:05d}.png'))
                enqueue(torchvision.utils.save_image, view.original_image.cpu(), os.path.join(test_gt_dir, f'{idx:05d}.png'))
                if depth is not None:
                    depth_cpu = depth.detach().cpu()
                    # Derive normals from depth only when the method did not
                    # render them itself; dict.get(key, default) would evaluate
                    # the fallback on every view regardless.
                    normal = render_pkg.get("normal")
                    if normal is None:
                        normal = depth_to_normal(view, depth)
                    normal = normal.detach().cpu()
                    enqueue(save_np, depth_cpu.numpy(), os.path.join(test_depth_dir, f'{idx:05d}.npy'))
                    enqueue(torchvision.utils.save_image, normal, os.path.join(test_normal_dir, f'{idx:05d}.png'))

            train_render_dir = os.path.join(model_path, f"renders_train_{iteration}")
            train_depth_dir = os.path.join(model_path, f"depths_train_{iteration}")
            for d in [train_render_dir, train_depth_dir]:
                os.makedirs(d, exist_ok=True)

            print("Launching async full-dimensional rendering | Train Set")
            # NOTE(review): this iterates getTrainCameras() directly, so it
            # presumably expects a flat sequence of cameras - confirm for
            # methods that return a dict or nested lists.
            for idx, view in enumerate(tqdm(model.scene.getTrainCameras(), desc="Rendering Train")):
                render_pkg = model.render(view)
                rendering = render_pkg.get("render", render_pkg.get("image")).detach().cpu()
                enqueue(torchvision.utils.save_image, rendering, os.path.join(train_render_dir, f'{idx:05d}.png'))
                depth = render_pkg.get("depth", None)
                if depth is not None:
                    enqueue(save_np, depth.detach().cpu().numpy(), os.path.join(train_depth_dir, f'{idx:05d}.npy'))

        print("[Async I/O] GPU rendering complete. Waiting for CPU flush...")
        concurrent.futures.wait(futures)

    # Collect every failed write; a single failure suppresses the completion flag.
    error_log_path = os.path.join(model_path, "render_error.log")
    failures = []
    for future in futures:
        try:
            future.result()
        except Exception as e:
            failures.append(e)

    if failures:
        print(f"\n[CRITICAL ERROR] Async render thread crashed. See {error_log_path}")
        # Append mode keeps errors from earlier runs in the same log.
        with open(error_log_path, "a") as err_f:
            for e in failures:
                err_f.write(f"Exception: {str(e)}\n")
        print("[WARNING] Render flag not written due to errors.")
        return

    flag_path = os.path.join(model_path, f"render_complete_{iteration}.flag")
    with open(flag_path, "w") as f:
        f.write("Completed by Async Engine")
    print("Flush complete. Flag written.")
def _flatten_cameras(cameras):
    """Return a flat list of cameras from a sequence, a sequence of lists, or a dict of either."""
    groups = cameras.values() if isinstance(cameras, dict) else cameras
    flat = []
    for item in groups:
        if isinstance(item, list):
            flat.extend(item)
        else:
            flat.append(item)
    return flat


def _write_camera_manifest(cameras, json_path):
    """Write the id, centre and view direction of every camera to ``json_path``."""
    records = [
        {
            "id": i,
            "center": cam.camera_center.detach().cpu().numpy().tolist(),
            "view_dir": cam.world_view_transform[:3, 2].detach().cpu().numpy().tolist(),
        }
        for i, cam in enumerate(_flatten_cameras(cameras))
    ]
    with open(json_path, "w") as f:
        json.dump(records, f, indent=4)


def _log_histogram(tb_writer, dist_name, tensor_data, step):
    """Write one tensor to TensorBoard as a raw histogram and print its summary stats."""
    # min()/max() raise on an empty tensor, so skip it instead of aborting training.
    if tensor_data.numel() == 0:
        print(f" |- {dist_name}: empty, skipped.")
        return
    if torch.isnan(tensor_data).any():
        print(f" |- {dist_name}: contains NaN, skipped.")
        return

    try:
        arr = tensor_data.detach().float().cpu().numpy().flatten()
        # Drop +/-inf so np.histogram gets a finite range.
        arr = arr[np.isfinite(arr)]
        if len(arr) > 0:
            arr64 = arr.astype(np.float64)
            counts, bin_edges = np.histogram(arr64, bins=100)
            tb_writer.add_histogram_raw(
                f'Distributions/{dist_name}',
                min=float(arr.min()), max=float(arr.max()),
                num=len(arr),
                sum=float(arr.sum()),
                sum_squares=float((arr64**2).sum()),
                # The upper edge of each bucket is passed as its limit.
                bucket_limits=bin_edges[1:].tolist(),
                bucket_counts=counts.tolist(),
                global_step=step
            )
    except Exception as e:
        # Histogram logging is best-effort; never let it stop training.
        print(f" |- histogram write failed: {e}")

    t_min = tensor_data.min().item()
    t_max = tensor_data.max().item()
    # mean() is undefined for integer tensors, so compute it in float.
    t_mean = tensor_data.float().mean().item()
    shape_str = "x".join(map(str, tensor_data.shape))
    print(f" |- {dist_name}:")
    print(f" | - Shape: [{shape_str}], Mean: {t_mean:.4f}, Min: {t_min:.4f}, Max: {t_max:.4f}")


def _run_training(args, tb_writer):
    """Build the method, run the training loop, write logs, and return the trained model."""
    dataset_config = {"source_path": args.source_path, "model_path": args.model_path, "resolution": args.resolution}
    method_options = {"track_decoupling": args.track_decoupling, "cap_gaussians": args.cap_gaussians}
    model = methods.load_method(args.method)(dataset_config, method_options)
    print(f"[SplatAtlas] Starting training | Method: {args.method} | Resolution: {args.resolution}")

    train_cameras = model.scene.getTrainCameras()
    num_train = len(train_cameras)
    num_test = len(model.scene.getTestCameras()) if hasattr(model.scene, "getTestCameras") else 0
    split_info = (
        f"========================================\n"
        f"[Dataset Topology Checker]\n"
        f"Method: {args.method}\n"
        f"Train Cameras: {num_train}\n"
        f"Test Cameras: {num_test}\n"
        f"Leakage Status: {'SAFE' if num_test > 0 else 'DANGER (NO TEST SET DETECTED)'}\n"
        f"========================================"
    )
    print(split_info)
    with open(os.path.join(args.model_path, "dataset_split.log"), "w") as f:
        f.write(split_info + "\n")

    _write_camera_manifest(train_cameras, os.path.join(args.model_path, "cameras.json"))

    # Record the start immediately so an interrupted run leaves status "running".
    t_train_start = time.time()
    timing_path = os.path.join(args.model_path, "timing.json")
    with open(timing_path, "w") as f:
        json.dump({"train_start_unix": t_train_start, "status": "running"}, f)

    save_steps = set(args.save_iterations)
    cumulative_gpu_time = 0.0
    for step in range(1, args.iterations + 1):
        # perf_counter is monotonic, so intervals are immune to wall-clock adjustments.
        # NOTE(review): this is host-side time around train_iteration; it equals
        # GPU time only if the method synchronizes the device - confirm.
        t_step_start = time.perf_counter()
        stats_or_tuple = model.train_iteration(step)
        if isinstance(stats_or_tuple, tuple):
            stats, histograms = stats_or_tuple
        else:
            stats, histograms = stats_or_tuple, {}
        step_elapsed = time.perf_counter() - t_step_start
        cumulative_gpu_time += step_elapsed
        stats['cumulative_gpu_time_sec'] = cumulative_gpu_time
        stats['step_time_ms'] = step_elapsed * 1000.0

        if step % 100 == 0:
            print(f"[TIMING] step={step} step_time_ms={stats.get('step_time_ms',0):.2f} "
                  f"iter_time_ms={stats.get('iter_time_ms',0):.2f} "
                  f"peak_vram_GB={stats.get('peak_vram_GB',0):.3f}", flush=True)
            print(f"\n{'='*40}")
            print(f"Step {step:05d} Runtime Monitor")
            print(f"{'='*40}")
            print("[Scalars]")
            for key, val in stats.items():
                if val is not None:
                    tb_writer.add_scalar(f"train/{key}", val, step)
                    if isinstance(val, float):
                        print(f" |- {key}: {val:.6f}")
                    else:
                        print(f" |- {key}: {val}")
            if histograms:
                print("\n[Histograms]")
                for dist_name, tensor_data in histograms.items():
                    _log_histogram(tb_writer, dist_name, tensor_data, step)

        if step in save_steps:
            model.save(args.model_path, step)
            print(f"\n[CheckPoint] Assets saved to: {args.model_path} (Step {step})")

    t_train_end = time.time()
    total_training_seconds = t_train_end - t_train_start
    with open(timing_path, "w") as f:
        json.dump({
            "train_start_unix": t_train_start,
            "train_end_unix": t_train_end,
            "total_training_seconds": total_training_seconds,
            "total_training_minutes": total_training_seconds / 60.0,
            "cumulative_gpu_time_seconds": cumulative_gpu_time,
            "iterations": args.iterations,
            "status": "complete"
        }, f, indent=4)
    print(f"\n[Timing] Training complete. Total: {total_training_seconds:.1f}s / {total_training_seconds/60.0:.2f}min")
    print(f"[Timing] timing.json written to {timing_path}")
    return model


def main():
    """Train the selected method end to end, then optionally render all views.

    Steps: parse the CLI, seed Python/NumPy/PyTorch RNGs, build the method via
    ``methods.load_method``, write ``dataset_split.log``, ``cameras.json`` and
    ``timing.json`` into ``--model_path``, run ``--iterations`` training steps
    (logging scalars/histograms to TensorBoard every 100 steps and saving at
    ``--save_iterations``), and finally call ``in_memory_render`` unless
    ``--skip_render`` was given.
    """
    import random

    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.model_path, exist_ok=True)
    tb_writer = SummaryWriter(args.model_path)
    try:
        model = _run_training(args, tb_writer)
    finally:
        # Close even when training raises so buffered events are flushed to disk.
        tb_writer.close()

    if not args.skip_render:
        in_memory_render(model, args.model_path, args.iterations)
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()