import os
import re
import sys
|
|
# One-shot patch for wrapper_2dgs.py: replace the fixed `chunk_size = 2000`
# with a chunk size derived from the number of Gaussians, so the
# (chunk x N_gaussians) distance matrix built by torch.cdist stays roughly
# bounded at ~30M entries instead of growing with N_gaussians.

TARGET_FILE = "/root/autodl-tmp/SplatAtlas/methods/wrapper_2dgs.py"

# Matches the fixed-size chunking line, and uses a lookahead to require the
# loop it configures directly after it. The indentation of the chunk_size
# line is captured so the replacement reuses whatever indentation the target
# file has (spaces or tabs, any width). Only the chunk_size line itself is
# consumed; the loop lines are checked but left untouched.
_OLD_PATTERN = re.compile(
    r"^(?P<indent>[ \t]*)chunk_size = 2000\n"
    r"(?="
    r"(?P=indent)for i in range\(0, V, chunk_size\):\n"
    r"(?P=indent)[ \t]+end = min\(i \+ chunk_size, V\)\n"
    r"(?P=indent)[ \t]+pts_chunk = query_points\[i:end\]\n"
    r"(?P=indent)[ \t]+dist_sq = torch\.cdist\(pts_chunk, xyz, p=2\)\.pow\(2\)"
    r")",
    re.MULTILINE,
)

# Present in the target once the patch has been applied.
_PATCHED_PATTERN = re.compile(
    r"^[ \t]*chunk_size = max\(1, 30_000_000 // N_gaussians\)$",
    re.MULTILINE,
)


def patch_source(code):
    """Apply the adaptive-chunk-size patch to *code*.

    Args:
        code: Full text of the target Python source file.

    Returns:
        A ``(new_code, count)`` tuple. ``count`` is the number of
        ``chunk_size = 2000`` sites that were rewritten; it is 0 (and
        ``new_code`` equals ``code``) when the block was not found.
    """

    def _replacement(match):
        indent = match.group("indent")
        return (
            f"{indent}N_gaussians = xyz.shape[0]\n"
            f"{indent}chunk_size = max(1, 30_000_000 // N_gaussians)\n"
        )

    return _OLD_PATTERN.subn(_replacement, code)


def main(target_file=TARGET_FILE):
    """Patch *target_file* in place.

    Returns:
        Process exit status: 0 when the patch was applied or was already
        present, 1 when the expected block could not be found (the file is
        left untouched in that case).
    """
    with open(target_file, "r", encoding="utf-8") as f:
        code = f.read()

    new_code, count = patch_source(code)
    if count:
        with open(target_file, "w", encoding="utf-8") as f:
            f.write(new_code)
        print("Patch applied successfully.")
        return 0

    if _PATCHED_PATTERN.search(code):
        print("Already patched; nothing to do.")
        return 0

    print("Block not found; target file left unchanged.", file=sys.stderr)
    return 1


if __name__ == "__main__":
    # Optional first CLI argument overrides the default target path.
    sys.exit(main(*sys.argv[1:2]))
|
|