"""Offline renderer: load a trained method and dump its renders to disk.

For every test camera the script saves the rendered image, the ground-truth
image and, when the renderer returns a depth map, the depth (``.npy``) and a
normal map (``.png``).  For every train camera only the rendered image is
saved.  File writes are handed to a thread pool so that they overlap with
rendering.  A ``render_complete_<iteration>.flag`` file is written only after
every queued write has finished successfully.
"""

import argparse
import concurrent.futures
import os
import sys

import numpy as np
import torch
import torchvision
from tqdm import tqdm

# Make the repository root importable before pulling in project packages.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from core.registry import METHOD_REGISTRY  # noqa: E402,F401
import methods  # noqa: E402
from ufd_evalkit.geometric import depth_to_normal  # noqa: E402


def parse_args():
    """Parse the command-line arguments of the render script.

    Returns:
        argparse.Namespace with ``method``, ``source_path``, ``model_path``
        (all required strings), ``iteration`` (int, default 30000) and
        ``resolution`` (int, default -1).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, required=True)
    parser.add_argument("--source_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--iteration", type=int, default=30000)
    parser.add_argument("--resolution", type=int, default=-1)
    return parser.parse_args()


def save_np(data, path):
    """Save ``data`` to ``path`` with ``numpy.save``."""
    np.save(path, data)


def _flatten_cameras(cams):
    """Return ``cams`` flattened by one level if its first item is a list."""
    if cams and isinstance(cams[0], list):
        return [cam for group in cams for cam in group]
    return cams


def _rendered_image(render_pkg):
    """Return the rendered image of ``render_pkg`` detached and on the CPU.

    The ``"render"`` key is preferred and ``"image"`` is the fallback.

    Raises:
        KeyError: if the package provides neither key (or the value is None).
    """
    image = render_pkg.get("render", render_pkg.get("image"))
    if image is None:
        raise KeyError("render package contains neither 'render' nor 'image'")
    return image.detach().cpu()


def _wait_for_writes(futures):
    """Block until every queued write is done, re-raising the first failure.

    ``concurrent.futures.wait`` alone does not surface worker exceptions;
    calling ``result()`` does, so a failed write can no longer go unnoticed.
    """
    for future in concurrent.futures.as_completed(futures):
        future.result()


def main():
    """Load the requested method, render all cameras and write the results.

    Output directories are created under ``--model_path``.  If loading the
    model fails, the error is printed and the function returns without
    rendering.  Any failed file write raises before the completion flag is
    written.
    """
    args = parse_args()

    # Resolve the model class dynamically from the registry instead of
    # hard-coding an import for each method.
    model_class = methods.load_method(args.method)
    dataset_config = {
        "source_path": args.source_path,
        "model_path": args.model_path,
        "resolution": args.resolution,
    }
    model = model_class(dataset_config, hyperparams={})

    print(f"📥 [SplatAtlas] 正在请求变体自行加载资产... (Iteration: {args.iteration})")
    try:
        # Honor the contract: the wrapper decides how to load the .ply file
        # or the MLP weights.
        model.load(args.model_path, args.iteration)
    except Exception as e:
        print(f"❌ 资产加载失败,请检查 [{args.method}] 的 load() 方法实现: {e}")
        return

    print("📷 [SplatAtlas] 启动离线资产渲染 (异步 I/O 开启)")

    test_render_dir = os.path.join(args.model_path, f"renders_test_{args.iteration}")
    test_depth_dir = os.path.join(args.model_path, f"depths_test_{args.iteration}")
    test_normal_dir = os.path.join(args.model_path, f"normals_test_{args.iteration}")
    test_gt_dir = os.path.join(args.model_path, f"gt_test_{args.iteration}")
    train_render_dir = os.path.join(args.model_path, f"renders_train_{args.iteration}")
    for directory in (
        test_render_dir,
        test_depth_dir,
        test_normal_dir,
        test_gt_dir,
        train_render_dir,
    ):
        os.makedirs(directory, exist_ok=True)

    save_image = torchvision.utils.save_image
    futures = []

    # The context manager guarantees the pool is shut down (and pending
    # writes are drained) even if rendering raises part-way through.
    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
        with torch.no_grad():
            test_cams = _flatten_cameras(model.scene.getTestCameras())
            for idx, view in enumerate(tqdm(test_cams, desc="Rendering Test")):
                render_pkg = model.render(view)
                name = f'{idx:05d}'

                futures.append(executor.submit(
                    save_image,
                    _rendered_image(render_pkg),
                    os.path.join(test_render_dir, name + '.png'),
                ))
                futures.append(executor.submit(
                    save_image,
                    view.original_image.cpu(),
                    os.path.join(test_gt_dir, name + '.png'),
                ))

                depth = render_pkg.get("depth", None)
                if depth is None:
                    continue

                # Derive normals from depth only when the renderer did not
                # provide them; the fallback is not free.
                normal = render_pkg.get("normal")
                if normal is None:
                    normal = depth_to_normal(depth)

                futures.append(executor.submit(
                    save_np,
                    depth.detach().cpu().numpy(),
                    os.path.join(test_depth_dir, name + '.npy'),
                ))
                futures.append(executor.submit(
                    save_image,
                    normal.detach().cpu(),
                    os.path.join(test_normal_dir, name + '.png'),
                ))

            train_cams = model.scene.getTrainCameras()
            for idx, view in enumerate(tqdm(train_cams, desc="Rendering Train")):
                render_pkg = model.render(view)
                futures.append(executor.submit(
                    save_image,
                    _rendered_image(render_pkg),
                    os.path.join(train_render_dir, f'{idx:05d}.png'),
                ))

        print("⏳ [Async I/O] GPU 渲染完毕!正在等待 CPU 将所有数据落盘...")
        _wait_for_writes(futures)

    # Reached only when every write succeeded.
    flag_path = os.path.join(args.model_path, f"render_complete_{args.iteration}.flag")
    with open(flag_path, "w", encoding="utf-8") as f:
        f.write("Completed by Async Engine")
    print("✅ 完美落盘,标记成功!")


if __name__ == "__main__":
    main()