SplatAtlas / scripts /patch_oom.py
KCBtheone's picture
Upload SplatAtlas benchmark pipeline code
23e73f9 verified
Raw
History Blame Contribute Delete
935 Bytes
"""Patch SplatAtlas' ``wrapper_2dgs.py`` to size its ``torch.cdist`` chunks adaptively.

The unpatched code walks the query points in fixed chunks of 2000 and calls
``torch.cdist(pts_chunk, xyz, p=2)`` on each chunk. Assuming 2-D inputs
(presumably ``(V, 3)`` query points and ``(N_gaussians, 3)`` centres -- confirm
against wrapper_2dgs.py), each call builds a ``(chunk, N_gaussians)`` distance
matrix. A fixed chunk can therefore run out of memory when there are many
Gaussians.

The patch derives the chunk size from the number of Gaussians. This keeps
``chunk_size * N_gaussians`` at about 30 million elements or fewer, while never
dropping below one row per chunk.

Usage::

    python patch_oom.py [TARGET_FILE]

Running it again on an already-patched file is a no-op.
"""
import os
import re
import sys
from typing import Optional, Sequence, Tuple

DEFAULT_TARGET = "/root/autodl-tmp/SplatAtlas/methods/wrapper_2dgs.py"

# The fixed-size chunk loop to replace. Each line is matched with a
# ``[ \t]*`` prefix instead of an exact string compare, so the patch works
# whatever indentation the target file uses. The ``indent`` group captures the
# indentation of the ``chunk_size = 2000`` line. The ``loop`` group captures
# the loop lines that follow, which are kept verbatim.
_OLD_BLOCK = re.compile(
    r"^(?P<indent>[ \t]*)chunk_size = 2000[ \t]*\n"
    r"(?P<loop>[ \t]*for i in range\(0, V, chunk_size\):[ \t]*\n"
    r"[ \t]*end = min\(i \+ chunk_size, V\)[ \t]*\n"
    r"[ \t]*pts_chunk = query_points\[i:end\][ \t]*\n"
    r"[ \t]*dist_sq = torch\.cdist\(pts_chunk, xyz, p=2\)\.pow\(2\))",
    re.MULTILINE,
)

# Present in any file that already carries this patch, including the earlier
# variant that lacked the zero-Gaussian guard.
_PATCHED_MARKER = "chunk_size = max(1, 30_000_000 //"

_MESSAGES = {
    "patched": "Patch applied successfully.",
    "already_patched": "Already patched; nothing to do.",
    "not_found": "Block not found; target file left unchanged.",
}


def _adaptive_chunk(match: "re.Match") -> str:
    """Build the replacement text for one match of ``_OLD_BLOCK``."""
    indent = match.group("indent")
    loop = match.group("loop")
    return (
        f"{indent}N_gaussians = xyz.shape[0]\n"
        # The inner max(1, ...) avoids ZeroDivisionError when xyz is empty.
        # The outer max(1, ...) keeps chunk_size >= 1 for very large N_gaussians.
        f"{indent}chunk_size = max(1, 30_000_000 // max(1, N_gaussians))\n"
        f"{loop}"
    )


def patch_source(code: str) -> Tuple[str, int]:
    """Return ``(patched_code, replacements)`` for the given source text.

    Every occurrence of the fixed-size chunk loop is rewritten. Text that does
    not contain it is returned unchanged, with a count of 0.
    """
    return _OLD_BLOCK.subn(_adaptive_chunk, code)


def apply_patch(target_file: str = DEFAULT_TARGET) -> str:
    """Patch *target_file* in place and report what happened.

    Returns:
        ``"patched"`` if the file was rewritten.
        ``"already_patched"`` if it already contains the adaptive chunk size.
        ``"not_found"`` if neither the old nor the new block is present; the
        file is left untouched in this case.

    Raises:
        OSError: if the file cannot be read or written.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    with open(target_file, "r", encoding="utf-8") as f:
        code = f.read()
    new_code, count = patch_source(code)
    if count:
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(new_code)
        return "patched"
    if _PATCHED_MARKER in code:
        return "already_patched"
    return "not_found"


def main(argv: Optional[Sequence[str]] = None) -> int:
    """Patch the path given as the first argument (or DEFAULT_TARGET) and print the outcome."""
    args = sys.argv[1:] if argv is None else list(argv)
    target_file = args[0] if args else DEFAULT_TARGET
    print(_MESSAGES[apply_patch(target_file)])
    # Always 0, as in the original script: a missing block is reported, not fatal.
    return 0


if __name__ == "__main__":
    sys.exit(main())