# SplatAtlas / scripts / run_absgssgf.py
# Uploaded by KCBtheone: "Upload SplatAtlas benchmark pipeline code" (commit 23e73f9, verified)
# File size: 4.46 kB
import os
import yaml
import subprocess
import time
import shutil
import glob
import re
def run_step(name, cmd, env):
    """Run one pipeline step as a subprocess and report whether it succeeded.

    Args:
        name: Label printed in the progress messages.
        cmd: Command to execute. A list or tuple of arguments is executed
            directly without a shell; anything else (normally a string) is
            handed to the shell, exactly as before.
        env: Environment mapping passed to ``subprocess.run``.

    Returns:
        True when the process exits with return code 0. False when the
        return code is non-zero or when launching the process raises.
    """
    print(f" [Step] {name}...")
    # perf_counter is monotonic, so wall-clock adjustments during a long
    # training run cannot distort the reported duration.
    start_time = time.perf_counter()
    # Only argv-style sequences bypass the shell; string commands keep the
    # original shell=True behaviour so existing callers are unaffected.
    use_shell = not isinstance(cmd, (list, tuple))
    try:
        result = subprocess.run(cmd, shell=use_shell, env=env)
    except Exception as e:
        # Best-effort boundary: a step that cannot even be launched is
        # reported as failed instead of aborting the whole benchmark.
        print(f" [Exception] {name}: {e}")
        return False
    elapsed = time.perf_counter() - start_time
    if result.returncode != 0:
        print(f" [Failed] {name}. Elapsed: {elapsed:.1f}s")
        return False
    print(f" [Completed] {name}. Elapsed: {elapsed:.1f}s")
    return True
def _shell_join(parts):
    """Join *parts* into one shell command line, quoting every argument."""
    import shlex  # local import: the file's import header does not include shlex

    return " ".join(shlex.quote(str(part)) for part in parts)


def _checkpoint_iteration(path):
    """Return the number embedded in a ``chkpnt<N>.pth`` file name, or 0."""
    match = re.search(r"chkpnt(\d+)\.pth", os.path.basename(path))
    return int(match.group(1)) if match else 0


def main(config_path="/root/autodl-tmp/SplatAtlas/configs/absgssgf_benchmark.yaml"):
    """Drive the benchmark for every dataset/scene listed in the YAML config.

    For each scene whose source directory exists:

    * if the final ``point_cloud.ply`` is present, re-run offline rendering
      unless the ``render_complete_<iters>.flag`` file exists;
    * otherwise train, resuming from the highest-numbered ``chkpnt*.pth``
      in the model directory when one is found;
    * compute metrics if the metrics JSON is missing, then run the offline
      physics script and the grand-table builder.

    A failed rendering or training step skips the rest of that scene.

    Args:
        config_path: Path of the benchmark YAML file. Defaults to the
            location that was previously hard-coded.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    method = cfg["method_name"]
    iters = cfg["global_settings"]["iterations"]

    root = "/root/autodl-tmp/SplatAtlas"
    curr_env = os.environ.copy()
    curr_env["PYTHONPATH"] = f"/root/autodl-tmp/SplatAtlas:/root/autodl-tmp/absgs_official:{curr_env.get('PYTHONPATH', '')}"

    for ds in cfg["datasets"]:
        print(f"\n>>>> Dataset Domain: {ds['name']} <<<<")
        for scene in ds["scenes"]:
            source = os.path.join(ds["base_path"], scene)
            model_out = f"/root/autodl-tmp/SplatAtlas/outputs/{method}_{scene}"
            res = ds.get("resolution", 1)
            if not os.path.exists(source):
                # Previously skipped without a trace; say so, so that a
                # mistyped scene name in the config is noticeable.
                print(f" [Skip] Source path not found: {source}")
                continue
            print(f"\n[{method.upper()}] Processing Scene: {scene} (r={res})")

            ply_path = os.path.join(model_out, f"point_cloud/iteration_{iters}/point_cloud.ply")
            metrics_json = os.path.join(model_out, f"metrics_test_iter{iters}.json")
            render_flag = os.path.join(model_out, f"render_complete_{iters}.flag")
            renders_dir = os.path.join(model_out, f"renders_test_{iters}")
            depths_dir = os.path.join(model_out, f"depths_test_{iters}")

            # Arguments shared by the render/train/physics commands. Every
            # command is assembled with _shell_join so that paths containing
            # spaces or shell metacharacters stay single arguments.
            common_args = ["--method", method, "--source_path", source, "--model_path", model_out]

            if os.path.exists(ply_path):
                print(f" [Checkpoint] Ply file found.")
                if not os.path.exists(render_flag):
                    # Remove output directories of an earlier render before
                    # rendering again.
                    for d in ["renders_test", "renders_train", "depths_test", "depths_train", "normals_test", "gt_test"]:
                        d_path = os.path.join(model_out, f"{d}_{iters}")
                        if os.path.exists(d_path):
                            shutil.rmtree(d_path)
                    render_cmd = _shell_join(
                        ["python", f"{root}/scripts/main_render.py"]
                        + common_args
                        + ["--iteration", iters, "--resolution", res]
                    )
                    if not run_step("Offline Rendering", render_cmd, curr_env):
                        continue
            else:
                train_args = (
                    ["python", f"{root}/scripts/main_train.py"]
                    + common_args
                    + ["--iterations", iters, "--resolution", res, "--track_decoupling"]
                )
                ckpt_list = glob.glob(os.path.join(model_out, "chkpnt*.pth"))
                if ckpt_list:
                    latest_ckpt = max(ckpt_list, key=_checkpoint_iteration)
                    print(f" [Resume] 发现断点,从 {latest_ckpt} 续训...")
                    train_args += ["--start_checkpoint", latest_ckpt]
                train_cmd = _shell_join(train_args)
                if not run_step("Training & In-situ Rendering", train_cmd, curr_env):
                    continue

            if not os.path.exists(metrics_json):
                eval_cmd = _shell_join([
                    "python", f"{root}/ufd_evalkit/run_eval.py",
                    "--method", method,
                    "--scene", scene,
                    "--render_dir", renders_dir,
                    "--gt_dir", f"{model_out}/gt_test_{iters}",
                    "--ply_path", ply_path,
                    "--output_json", metrics_json,
                    "--colmap_dir", source,
                    "--depth_dir", depths_dir,
                ])
                run_step("Metrics Calculation", eval_cmd, curr_env)

            # NOTE(review): the source this was recovered from had lost its
            # indentation, so the nesting of the two steps below is a
            # reconstruction (run once per scene, regardless of whether the
            # metrics already existed) -- confirm against the original file.
            physics_cmd = _shell_join(
                ["python", f"{root}/scripts/compute_offline_physics.py"]
                + common_args
                + ["--iteration", iters]
            )
            run_step("Offline Physics Diagnosis", physics_cmd, curr_env)
            run_step("Updating Grand Table", "python /root/autodl-tmp/SplatAtlas/scripts/build_grand_table.py", curr_env)
# Run the benchmark pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()