"""One-off source patch for SplatAtlas ``methods/wrapper_2dgs.py``.

Replaces the fixed ``chunk_size = 2000`` used to chunk query points before
``torch.cdist`` with a chunk size derived from the number of Gaussians, so the
number of entries in each per-chunk distance matrix (chunk_size x N_gaussians)
stays near a fixed budget.

The block to replace is matched statement by statement with its indentation
and line ending captured from the target file, instead of relying on an exact
whitespace-sensitive string literal. Re-running the script is a no-op because
the patched text no longer contains ``chunk_size = 2000``.

Usage: python <this script> [TARGET_FILE]
"""
import os
import re
import shutil
import sys
import tempfile
from typing import Tuple

DEFAULT_TARGET_FILE = "/root/autodl-tmp/SplatAtlas/methods/wrapper_2dgs.py"

# Upper bound on chunk_size * N_gaussians, i.e. on the number of entries of
# each cdist result (about 120 MB per chunk if the distances are float32 --
# the dtype is not visible here, confirm against wrapper_2dgs.py).
ELEMENT_BUDGET = 30_000_000

# Statements of the original block, without indentation. The head lines sit
# at the enclosing indentation level; the body lines are the ``for`` body.
_OLD_HEAD = (
    "chunk_size = 2000",
    "for i in range(0, V, chunk_size):",
)
_LOOP_BODY = (
    "end = min(i + chunk_size, V)",
    "pts_chunk = query_points[i:end]",
    "dist_sq = torch.cdist(pts_chunk, xyz, p=2).pow(2)",
)

# Replacement head. The inner max(1, ...) avoids a ZeroDivisionError when xyz
# holds no Gaussians. The outer max(1, ...) keeps chunk_size >= 1 when
# N_gaussians exceeds the budget, because range() rejects a step of 0.
_NEW_HEAD = (
    "N_gaussians = xyz.shape[0]",
    f"chunk_size = max(1, {ELEMENT_BUDGET:_} // max(1, N_gaussians))",
    "for i in range(0, V, chunk_size):",
)


def _build_old_pattern() -> "re.Pattern[str]":
    """Compile a regex matching the original block at any indentation.

    Captures ``indent`` (enclosing level), ``body`` (extra indentation of the
    loop body) and ``nl`` (line ending) so the replacement can reuse them.
    Trailing spaces/tabs on the matched lines are tolerated.
    """
    eol_first = r"[ \t]*(?P<nl>\r?\n)"
    eol = r"[ \t]*(?P=nl)"
    head = [re.escape(stmt) for stmt in _OLD_HEAD]
    body = [re.escape(stmt) for stmt in _LOOP_BODY]
    pattern = (
        r"^(?P<indent>[ \t]*)" + head[0] + eol_first
        + "".join(r"(?P=indent)" + stmt + eol for stmt in head[1:])
        + r"(?P=indent)(?P<body>[ \t]+)" + body[0]
        + "".join(eol + r"(?P=indent)(?P=body)" + stmt for stmt in body[1:])
    )
    return re.compile(pattern, re.MULTILINE)


_OLD_PATTERN = _build_old_pattern()


def _render_replacement(match: "re.Match[str]") -> str:
    """Build the new block using the indentation/newline captured by ``match``."""
    indent = match.group("indent")
    body_indent = indent + match.group("body")
    lines = [indent + stmt for stmt in _NEW_HEAD]
    lines += [body_indent + stmt for stmt in _LOOP_BODY]
    return match.group("nl").join(lines)


def patch_source(code: str) -> Tuple[str, int]:
    """Return ``(patched_code, replacements)`` for the given source text.

    Every occurrence of the original block is replaced, as ``str.replace``
    did. ``replacements`` is 0 when the block is absent or already patched.
    """
    return _OLD_PATTERN.subn(_render_replacement, code)


def _atomic_write(path: str, text: str) -> None:
    """Write ``text`` to ``path`` through a temp file and ``os.replace``.

    If the write fails midway, the original file is left intact rather than
    truncated. The temp file is given the original file's permission bits,
    because mkstemp creates it readable/writable by the owner only.
    """
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".patch-", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8", newline="") as tmp:
            tmp.write(text)
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise


def main(target_file: str = DEFAULT_TARGET_FILE) -> int:
    """Patch ``target_file`` in place, print the outcome, and return 0."""
    # newline="" keeps the file's own line endings (no CRLF -> LF translation).
    with open(target_file, "r", encoding="utf-8", newline="") as f:
        code = f.read()

    patched, replacements = patch_source(code)
    if not replacements:
        print("Block not found or already patched.")
        return 0

    _atomic_write(target_file, patched)
    print("Patch applied successfully.")
    return 0


if __name__ == "__main__":
    sys.exit(main(*sys.argv[1:2]))