# SplatAtlas / scripts / main_render.py
# Uploaded by KCBtheone — "Upload SplatAtlas benchmark pipeline code"
# (commit 23e73f9, verified; 4.3 kB).
import argparse
import os
import torch
import torchvision
import numpy as np
from tqdm import tqdm
import sys
import concurrent.futures
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from core.registry import METHOD_REGISTRY
import methods
from ufd_evalkit.geometric import depth_to_normal
def parse_args():
    """Parse the command-line options of the offline render script.

    Returns:
        argparse.Namespace with attributes ``method``, ``source_path`` and
        ``model_path`` (required strings), ``iteration`` (int, default 30000)
        and ``resolution`` (int, default -1).
    """
    cli = argparse.ArgumentParser()
    # The three path/name options are all mandatory strings.
    for flag in ("--method", "--source_path", "--model_path"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument("--iteration", type=int, default=30000)
    cli.add_argument("--resolution", type=int, default=-1)
    namespace = cli.parse_args()
    return namespace
def save_np(data, path):
    """Write ``data`` to ``path`` in NumPy ``.npy`` format.

    Thin wrapper around ``np.save`` used as a task callable for the writer
    thread pool. ``np.save`` appends ``.npy`` when ``path`` lacks it.
    """
    np.save(file=path, arr=data)
def _flatten_cameras(cams):
    """Return ``cams`` as one flat sequence of cameras.

    If the first element is itself a list, ``cams`` is treated as a list of
    camera groups and concatenated in order; otherwise it is returned as-is.
    """
    if cams and isinstance(cams[0], list):
        return [cam for group in cams for cam in group]
    return cams


def _frame_path(directory, idx, ext):
    """Return ``<directory>/<idx zero-padded to 5 digits>.<ext>``."""
    return os.path.join(directory, f"{idx:05d}.{ext}")


def _rendered_image(render_pkg):
    """Return the colour image of a render package as a detached CPU tensor.

    Looks up ``"render"`` first and falls back to ``"image"``.

    Raises:
        KeyError: if neither entry is present (or both are ``None``).
    """
    image = render_pkg.get("render")
    if image is None:
        image = render_pkg.get("image")
    if image is None:
        raise KeyError("render package contains neither a 'render' nor an 'image' entry")
    return image.detach().cpu()


def _raise_on_failed_writes(futures):
    """Block until every write future has finished; raise if any of them failed.

    ``concurrent.futures.wait`` does not re-raise task exceptions, so without
    this check a failed save would go unnoticed and the completion flag would
    still be written.

    Raises:
        RuntimeError: chained from the first failure, if any write failed.
    """
    # Future.exception() blocks until the future is done.
    errors = [exc for exc in (fut.exception() for fut in futures) if exc is not None]
    if errors:
        raise RuntimeError(
            f"{len(errors)} of {len(futures)} asynchronous writes failed; "
            f"first error: {errors[0]!r}"
        ) from errors[0]


def main():
    """Render test/train views of a trained model and write them to disk.

    Resolves the model class with ``methods.load_method(--method)``, lets the
    wrapper load its own assets for ``--iteration``, then renders:

    * every test camera: colour image, ground-truth image and, when the render
      package contains a ``"depth"`` entry, a depth ``.npy`` plus a normal map
      (taken from ``"normal"`` if present, else derived with ``depth_to_normal``);
    * every train camera: colour image only.

    Files are saved on a thread pool while rendering continues. The
    ``render_complete_<iteration>.flag`` file is written only after every
    write has succeeded.

    Returns:
        0 on success, 1 if the wrapper's ``load()`` raised.

    Raises:
        RuntimeError: if any asynchronous file write failed (no flag is written).
    """
    args = parse_args()

    # Resolve the model class dynamically by method name instead of a hard-coded import.
    model_class = methods.load_method(args.method)
    dataset_config = {
        "source_path": args.source_path,
        "model_path": args.model_path,
        "resolution": args.resolution,
    }
    model = model_class(dataset_config, hyperparams={})

    print(f"📥 [SplatAtlas] 正在请求变体自行加载资产... (Iteration: {args.iteration})")
    try:
        # Contract: the wrapper decides how to load its assets (.ply, MLP weights, ...).
        model.load(args.model_path, args.iteration)
    except Exception as e:  # plugin boundary: report any wrapper failure and stop
        print(f"❌ 资产加载失败,请检查 [{args.method}] 的 load() 方法实现: {e}")
        return 1

    print("📷 [SplatAtlas] 启动离线资产渲染 (异步 I/O 开启)")

    test_render_dir = os.path.join(args.model_path, f"renders_test_{args.iteration}")
    test_depth_dir = os.path.join(args.model_path, f"depths_test_{args.iteration}")
    test_normal_dir = os.path.join(args.model_path, f"normals_test_{args.iteration}")
    test_gt_dir = os.path.join(args.model_path, f"gt_test_{args.iteration}")
    train_render_dir = os.path.join(args.model_path, f"renders_train_{args.iteration}")
    for d in (test_render_dir, test_depth_dir, test_normal_dir, test_gt_dir, train_render_dir):
        os.makedirs(d, exist_ok=True)

    futures = []
    # The context manager guarantees the pool is shut down (after draining
    # pending writes) even if rendering raises part-way through.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:

        def submit(fn, *fn_args):
            futures.append(executor.submit(fn, *fn_args))

        with torch.no_grad():
            test_cams = _flatten_cameras(model.scene.getTestCameras())
            for idx, view in enumerate(tqdm(test_cams, desc="Rendering Test")):
                render_pkg = model.render(view)
                submit(torchvision.utils.save_image, _rendered_image(render_pkg),
                       _frame_path(test_render_dir, idx, "png"))
                submit(torchvision.utils.save_image, view.original_image.cpu(),
                       _frame_path(test_gt_dir, idx, "png"))

                depth = render_pkg.get("depth")
                if depth is not None:
                    normal = render_pkg.get("normal")
                    if normal is None:
                        # Derive normals only when the wrapper did not supply them.
                        normal = depth_to_normal(depth)
                    submit(save_np, depth.detach().cpu().numpy(),
                           _frame_path(test_depth_dir, idx, "npy"))
                    # NOTE(review): save_image clamps values to [0, 1]; if normals
                    # are in [-1, 1], negative components are clipped — confirm the
                    # range returned by the wrapper / depth_to_normal.
                    submit(torchvision.utils.save_image, normal.detach().cpu(),
                           _frame_path(test_normal_dir, idx, "png"))

            for idx, view in enumerate(tqdm(model.scene.getTrainCameras(), desc="Rendering Train")):
                render_pkg = model.render(view)
                submit(torchvision.utils.save_image, _rendered_image(render_pkg),
                       _frame_path(train_render_dir, idx, "png"))

        print("⏳ [Async I/O] GPU 渲染完毕!正在等待 CPU 将所有数据落盘...")
        _raise_on_failed_writes(futures)

    # Only reached when every write succeeded.
    flag_path = os.path.join(args.model_path, f"render_complete_{args.iteration}.flag")
    with open(flag_path, "w", encoding="utf-8") as f:
        f.write("Completed by Async Engine")
    print("✅ 完美落盘,标记成功!")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # (no explicit status) exits with 0.
    sys.exit(main())