prithivMLmods commited on
Commit
3ff2b9e
·
verified ·
1 Parent(s): f530214

Delete scripts

Browse files
scripts/gradio_app.py DELETED
@@ -1,218 +0,0 @@
1
- import argparse
2
- import gc
3
- import os
4
- import sys
5
-
6
- import gradio as gr
7
- import torch
8
- from omegaconf import OmegaConf
9
-
10
- # Add project root to path
11
- sys.path.append(os.getcwd())
12
-
13
- from ultrashape.rembg import BackgroundRemover
14
- from ultrashape.utils.misc import instantiate_from_config
15
- from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
16
- from ultrashape.utils import voxelize_from_point
17
- from ultrashape.pipelines import UltraShapePipeline
18
-
19
- # Global variables to cache the model
20
- MODEL_CACHE = {}
21
-
22
-
23
- def get_pipeline_cached(config_path, ckpt_path, device='cuda', low_vram=False):
24
- # Check if we have a valid cached pipeline for this checkpoint
25
- if "pipeline" in MODEL_CACHE and MODEL_CACHE.get("ckpt_path") == ckpt_path:
26
- print("Using cached pipeline...")
27
- return MODEL_CACHE["pipeline"], MODEL_CACHE["config"]
28
-
29
- # Clear old cache if it exists (e.g. different checkpoint)
30
- if MODEL_CACHE:
31
- print("Clearing old model cache...")
32
- MODEL_CACHE.clear()
33
- gc.collect()
34
- torch.cuda.empty_cache()
35
-
36
- print(f"Loading config from {config_path}...")
37
- config = OmegaConf.load(config_path)
38
-
39
- print("Instantiating VAE...")
40
- vae = instantiate_from_config(config.model.params.vae_config)
41
-
42
- print("Instantiating DiT...")
43
- dit = instantiate_from_config(config.model.params.dit_cfg)
44
-
45
- print("Instantiating Conditioner...")
46
- conditioner = instantiate_from_config(config.model.params.conditioner_config)
47
-
48
- print("Instantiating Scheduler & Processor...")
49
- scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
50
- image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
51
-
52
- print(f"Loading weights from {ckpt_path}...")
53
- weights = torch.load(ckpt_path, map_location='cpu')
54
-
55
- vae.load_state_dict(weights['vae'], strict=True)
56
- dit.load_state_dict(weights['dit'], strict=True)
57
- conditioner.load_state_dict(weights['conditioner'], strict=True)
58
-
59
- vae.eval().to(device)
60
- dit.eval().to(device)
61
- conditioner.eval().to(device)
62
-
63
- if hasattr(vae, 'enable_flashvdm_decoder'):
64
- vae.enable_flashvdm_decoder()
65
-
66
- print("Creating Pipeline...")
67
- pipeline = UltraShapePipeline(
68
- vae=vae,
69
- model=dit,
70
- scheduler=scheduler,
71
- conditioner=conditioner,
72
- image_processor=image_processor
73
- )
74
-
75
- if low_vram:
76
- pipeline.enable_model_cpu_offload()
77
-
78
- MODEL_CACHE["pipeline"] = pipeline
79
- MODEL_CACHE["config"] = config
80
- MODEL_CACHE["ckpt_path"] = ckpt_path
81
-
82
- return pipeline, config
83
-
84
-
85
- def predict(
86
- image_input,
87
- mesh_input,
88
- steps,
89
- scale,
90
- octree_res,
91
- num_latents,
92
- chunk_size,
93
- seed,
94
- remove_bg,
95
- ckpt_path,
96
- low_vram
97
- ):
98
- # Aggressive memory cleanup at start
99
- gc.collect()
100
- torch.cuda.empty_cache()
101
-
102
- try:
103
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
- config_path = "configs/infer_dit_refine.yaml"
105
-
106
- if not os.path.exists(config_path):
107
- raise FileNotFoundError(f"Config not found at {config_path}")
108
-
109
- pipeline, config = get_pipeline_cached(config_path, ckpt_path, device, low_vram)
110
-
111
- voxel_res = config.model.params.vae_config.params.voxel_query_res
112
-
113
- print(f"Initializing Surface Loader (Token Num: {num_latents})...")
114
- loader = SharpEdgeSurfaceLoader(
115
- num_sharp_points=204800,
116
- num_uniform_points=204800,
117
- )
118
-
119
- print(f"Processing inputs...")
120
- if image_input is None:
121
- raise ValueError("Image input is required")
122
- if mesh_input is None:
123
- raise ValueError("Mesh input is required")
124
-
125
- # Handle image input
126
- if isinstance(image_input, dict):
127
- # In newer gradio versions Image component might return a dict for mask etc, but usually just PIL/numpy
128
- # if type='pil' it is PIL.Image
129
- pass
130
-
131
- image = image_input.convert("RGBA")
132
-
133
- if remove_bg or image.mode != 'RGBA':
134
- rembg = BackgroundRemover()
135
- image = rembg(image)
136
-
137
- # Handle mesh input - Gradio Model3D returns path to file
138
- surface = loader(mesh_input, normalize_scale=scale).to(device, dtype=torch.float16)
139
- pc = surface[:, :, :3] # [B, N, 3]
140
-
141
- # Voxelize
142
- _, voxel_idx = voxelize_from_point(pc, num_latents, resolution=voxel_res)
143
-
144
- print("Running diffusion process...")
145
- gen_device = "cpu" if low_vram else device
146
- generator = torch.Generator(gen_device).manual_seed(int(seed))
147
-
148
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
149
- mesh_out_list, _ = pipeline(
150
- image=image,
151
- voxel_cond=voxel_idx,
152
- generator=generator,
153
- box_v=1.0,
154
- mc_level=0.0,
155
- octree_resolution=int(octree_res),
156
- num_chunks=int(chunk_size),
157
- num_inference_steps=int(steps)
158
- )
159
-
160
- # Save output
161
- output_dir = "outputs_gradio"
162
- os.makedirs(output_dir, exist_ok=True)
163
- base_name = "output"
164
- save_path = os.path.join(output_dir, f"{base_name}_refined.glb")
165
-
166
- mesh_out = mesh_out_list[0]
167
- mesh_out.export(save_path)
168
- print(f"Successfully saved to {save_path}")
169
-
170
- return save_path
171
-
172
- finally:
173
- # Aggressive memory cleanup at end
174
- gc.collect()
175
- torch.cuda.empty_cache()
176
-
177
-
178
- if __name__ == "__main__":
179
- parser = argparse.ArgumentParser(description="UltraShape Gradio App")
180
- parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
181
- parser.add_argument("--share", action="store_true", help="Share the gradio app")
182
- parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
183
-
184
- args = parser.parse_args()
185
-
186
- # Define Gradio Interface
187
- with gr.Blocks(title="UltraShape Inference") as demo:
188
- gr.Markdown("# UltraShape Inference: Mesh & Image Refinement")
189
-
190
- with gr.Row():
191
- with gr.Column():
192
- image_input = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
193
- mesh_input = gr.Model3D(label="Input Coarse Mesh (.glb, .obj)")
194
-
195
- with gr.Accordion("Advanced Parameters", open=True):
196
- steps = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Inference Steps")
197
- scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.99, label="Mesh Normalization Scale")
198
- octree_res = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Octree Resolution")
199
- num_latents = gr.Slider(minimum=1024, maximum=32768, value=32768, step=128,
200
- label="Number of Latent Tokens (Use 8192 if OOM)")
201
- chunk_size = gr.Slider(minimum=512, maximum=10000, value=2048, step=512,
202
- label="Chunk Size (Use 2000 if OOM)")
203
- seed = gr.Number(value=42, label="Random Seed")
204
- remove_bg = gr.Checkbox(label="Remove Background", value=False)
205
-
206
- run_btn = gr.Button("Run Inference", variant="primary")
207
-
208
- with gr.Column():
209
- output_model = gr.Model3D(label="Refined Output Mesh")
210
-
211
- run_btn.click(
212
- fn=lambda img, mesh, s, sc, oct, nml, chk, sd, rm: predict(img, mesh, s, sc, oct, nml, chk, sd, rm, args.ckpt,
213
- args.low_vram),
214
- inputs=[image_input, mesh_input, steps, scale, octree_res, num_latents, chunk_size, seed, remove_bg],
215
- outputs=[output_model]
216
- )
217
-
218
- demo.launch(share=args.share, server_name='0.0.0.0', server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/infer_dit_refine.py DELETED
@@ -1,142 +0,0 @@
1
- import os
2
- import sys
3
- import argparse
4
- import torch
5
- import numpy as np
6
- from PIL import Image
7
- from omegaconf import OmegaConf
8
-
9
- # project_root = '[your_project_root_path]' # Replace with your project root path
10
- # sys.path.insert(0, project_root)
11
-
12
- from ultrashape.rembg import BackgroundRemover
13
- from ultrashape.utils.misc import instantiate_from_config
14
- from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
15
- from ultrashape.utils import voxelize_from_point
16
- from ultrashape.pipelines import UltraShapePipeline
17
-
18
- def load_models(config_path, ckpt_path, device='cuda'):
19
-
20
- print(f"Loading config from {config_path}...")
21
- config = OmegaConf.load(config_path)
22
-
23
- print("Instantiating VAE...")
24
- vae = instantiate_from_config(config.model.params.vae_config)
25
-
26
- print("Instantiating DiT...")
27
- dit = instantiate_from_config(config.model.params.dit_cfg)
28
-
29
- print("Instantiating Conditioner...")
30
- conditioner = instantiate_from_config(config.model.params.conditioner_config)
31
-
32
- print("Instantiating Scheduler & Processor...")
33
- scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
34
- image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
35
-
36
- print(f"Loading weights from {ckpt_path}...")
37
- weights = torch.load(ckpt_path, map_location='cpu')
38
-
39
- vae.load_state_dict(weights['vae'], strict=True)
40
- dit.load_state_dict(weights['dit'], strict=True)
41
- conditioner.load_state_dict(weights['conditioner'], strict=True)
42
-
43
- vae.eval().to(device)
44
- dit.eval().to(device)
45
- conditioner.eval().to(device)
46
-
47
- if hasattr(vae, 'enable_flashvdm_decoder'):
48
- vae.enable_flashvdm_decoder()
49
-
50
- components = {
51
- "vae": vae,
52
- "dit": dit,
53
- "conditioner": conditioner,
54
- "scheduler": scheduler,
55
- "image_processor": image_processor,
56
- }
57
-
58
- return components, config
59
-
60
- def run_inference(args):
61
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
-
63
- components, config = load_models(args.config, args.ckpt, device)
64
-
65
- pipeline = UltraShapePipeline(
66
- vae=components['vae'],
67
- model=components['dit'],
68
- scheduler=components['scheduler'],
69
- conditioner=components['conditioner'],
70
- image_processor=components['image_processor']
71
- )
72
-
73
- if args.low_vram:
74
- pipeline.enable_model_cpu_offload()
75
-
76
- token_num = args.num_latents
77
- voxel_res = config.model.params.vae_config.params.voxel_query_res
78
-
79
- print(f"Initializing Surface Loader (Token Num: {token_num})...")
80
- loader = SharpEdgeSurfaceLoader(
81
- num_sharp_points=204800,
82
- num_uniform_points=204800,
83
- )
84
-
85
- print(f"Processing inputs: {args.image} & {args.mesh}")
86
- image = Image.open(args.image)
87
-
88
- if args.remove_bg or image.mode != 'RGBA':
89
- rembg = BackgroundRemover()
90
- image = rembg(image)
91
-
92
- surface = loader(args.mesh, normalize_scale=args.scale).to(device, dtype=torch.float16)
93
- pc = surface[:, :, :3] # [B, N, 3]
94
-
95
- # Voxelize
96
- _, voxel_idx = voxelize_from_point(pc, token_num, resolution=voxel_res)
97
-
98
- print("Running diffusion process...")
99
- generator = torch.Generator(device).manual_seed(args.seed)
100
-
101
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
102
- mesh, _ = pipeline(
103
- image=image,
104
- voxel_cond=voxel_idx,
105
- generator=generator,
106
- box_v=1.0,
107
- mc_level=0.0,
108
- octree_resolution=args.octree_res,
109
- num_inference_steps=args.steps,
110
- num_chunks=args.chunk_size,
111
- )
112
-
113
- os.makedirs(args.output_dir, exist_ok=True)
114
- base_name = os.path.splitext(os.path.basename(args.image))[0]
115
- save_path = os.path.join(args.output_dir, f"{base_name}_refined.glb")
116
-
117
- mesh = mesh[0]
118
- mesh.export(save_path)
119
- print(f"Successfully saved to {save_path}")
120
-
121
- if __name__ == "__main__":
122
- parser = argparse.ArgumentParser(description="UltraShape Inference Script")
123
-
124
- parser.add_argument("--config", type=str, default="configs/infer_dit2.yaml", help="Path to inference config")
125
- parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
126
- parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
127
-
128
- parser.add_argument("--image", type=str, required=True, help="Input image path")
129
- parser.add_argument("--mesh", type=str, required=True, help="Input coarse mesh (.glb/.obj)")
130
- parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
131
-
132
- parser.add_argument("--steps", type=int, default=50, help="Inference steps")
133
- parser.add_argument("--scale", type=float, default=0.99, help="Mesh normalization scale")
134
- parser.add_argument("--num_latents", type=int, default=32768, help="Number of latents")
135
- parser.add_argument("--chunk_size", type=int, default=8000, help="Chunk size for inference")
136
- parser.add_argument("--octree_res", type=int, default=1024, help="Marching Cubes resolution")
137
- parser.add_argument("--seed", type=int, default=42, help="Random seed")
138
- parser.add_argument("--remove_bg", action="store_true", help="Force remove background")
139
-
140
- args = parser.parse_args()
141
-
142
- run_inference(args)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/install_env.sh DELETED
@@ -1,8 +0,0 @@
1
- conda create -n ultrashape python=3.10
2
- conda activate ultrashape
3
- pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
4
- pip install -r requirements.txt
5
- pip install git+https://github.com/ashawkey/cubvh --no-build-isolation
6
-
7
- pip install --no-build-isolation "git+https://github.com/facebookresearch/pytorch3d.git@stable"
8
- pip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl
 
 
 
 
 
 
 
 
 
scripts/run.sh DELETED
@@ -1,12 +0,0 @@
1
- # sampling
2
- # python scripts/sampling.py \
3
- # --mesh_json data/mesh_paths.json \
4
- # --output_dir data/sample
5
-
6
- # inference refine_dit
7
- python scripts/infer_dit_refine.py \
8
- --ckpt checkpoints/ultrashape_v1.pt \
9
- --image inputs/image/1.png \
10
- --mesh inputs/coarse_mesh/1.glb \
11
- --config configs/infer_dit_refine.yaml
12
- # --steps 12
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/sampling.py DELETED
@@ -1,586 +0,0 @@
1
- import os
2
- import trimesh
3
- import numpy as np
4
- from typing import List, Optional, Any, Tuple, Union
5
- import pytorch_lightning as pl
6
- from pytorch_lightning.utilities.types import STEP_OUTPUT
7
- import torch
8
- from torch.utils.data import Dataset, DataLoader
9
- import pytorch3d.structures
10
- import pytorch3d.ops
11
- from scipy.stats import truncnorm
12
- import json
13
- import argparse
14
- import cubvh
15
-
16
- # import logging
17
- # from tools.logger import init_log, set_all_log
18
- # sys_logger = init_log("sampler", logging.DEBUG)
19
- # set_all_log(level=logging.DEBUG, path='./debug/logs')
20
-
21
- def load_mesh(mesh_path: str, device: str = "cuda") -> Tuple[torch.Tensor, torch.Tensor]:
22
- if mesh_path.endswith(".npz"):
23
- mesh_np = np.load(mesh_path)
24
- vertices, faces = torch.tensor(mesh_np["vertices"], device=device), torch.tensor(mesh_np["faces"].astype('i8'), device=device)
25
- else:
26
- mesh = trimesh.load(mesh_path, force='mesh')
27
- vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
28
- faces = torch.tensor(mesh.faces, dtype=torch.long, device=device)
29
- if faces.shape[0] > 2 * 1e8:
30
- raise ValueError(f"too many faces {faces.shape}")
31
- return vertices, faces
32
-
33
- def compute_mesh_features(vertices: torch.Tensor, faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
34
- device = vertices.device
35
-
36
- v0 = vertices[faces[:, 0]]
37
- v1 = vertices[faces[:, 1]]
38
- v2 = vertices[faces[:, 2]]
39
- face_normals = torch.cross(v1 - v0, v2 - v0)
40
- face_areas = torch.norm(face_normals, dim=1) * 0.5
41
- face_normals = face_normals / (face_areas.unsqueeze(1) * 2 + 1e-12)
42
-
43
- vertex_normals = torch.zeros_like(vertices)
44
- face_normals_weighted = face_normals * face_areas.unsqueeze(1)
45
-
46
- vertex_normals.scatter_add_(0, faces[:, 0:1].expand(-1, 3), face_normals_weighted)
47
- vertex_normals.scatter_add_(0, faces[:, 1:2].expand(-1, 3), face_normals_weighted)
48
- vertex_normals.scatter_add_(0, faces[:, 2:3].expand(-1, 3), face_normals_weighted)
49
-
50
- vertex_normals = vertex_normals / (torch.norm(vertex_normals, dim=1, keepdim=True) + 1e-12)
51
-
52
- edges = torch.cat([
53
- faces[:, [0, 1]],
54
- faces[:, [1, 2]],
55
- faces[:, [2, 0]]
56
- ], dim=0)
57
-
58
- edges_unique, edges_inverse = torch.unique(torch.sort(edges, dim=1)[0], dim=0, return_inverse=True)
59
- edge_normals_diff = torch.norm(
60
- vertex_normals[edges[:, 0]] - vertex_normals[edges[:, 1]],
61
- dim=1
62
- )
63
-
64
- vertex_curvatures = torch.zeros(len(vertices), device=device)
65
- vertex_curvatures.scatter_add_(0, edges[:, 0], edge_normals_diff)
66
- vertex_curvatures.scatter_add_(0, edges[:, 1], edge_normals_diff)
67
-
68
- vertex_degrees = torch.zeros(len(vertices), device=device)
69
- vertex_degrees.scatter_add_(0, edges[:, 0], torch.ones_like(edge_normals_diff))
70
- vertex_degrees.scatter_add_(0, edges[:, 1], torch.ones_like(edge_normals_diff))
71
-
72
- vertex_curvatures = vertex_curvatures / (vertex_degrees + 1e-12)
73
- vertex_curvatures = (vertex_curvatures - vertex_curvatures.min()) / (
74
- vertex_curvatures.max() - vertex_curvatures.min() + 1e-12)
75
-
76
- return face_areas, vertex_curvatures
77
-
78
- def sample_uniform_points(
79
- vertices: torch.Tensor,
80
- faces: torch.Tensor,
81
- num_samples: int,
82
- random_seed: Optional[int] = None
83
- ) -> Tuple[torch.Tensor, torch.Tensor]:
84
-
85
- if random_seed is not None:
86
- torch.manual_seed(random_seed)
87
- mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces])
88
-
89
- points, normals = pytorch3d.ops.sample_points_from_meshes(
90
- mesh, num_samples=num_samples, return_normals=True)
91
-
92
- return points[0], normals[0]
93
-
94
- def sample_surface_points(
95
- vertices: torch.Tensor,
96
- faces: torch.Tensor,
97
- num_samples: int,
98
- min_samples_per_face: int = 0,
99
- use_curvature: bool = True,
100
- random_seed: Optional[int] = None
101
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
102
- """Curvature-based surface sampling"""
103
- device = vertices.device
104
- if random_seed is not None:
105
- torch.manual_seed(random_seed)
106
-
107
- # Compute face areas and vertex curvatures
108
- face_areas, vertex_curvatures = compute_mesh_features(vertices, faces)
109
-
110
- # Compute average curvature of faces
111
- face_curvatures = torch.mean(vertex_curvatures[faces], dim=1)
112
- sampling_weights = face_curvatures # Use only curvature as weights
113
- # Calculate number of sample points per face
114
- num_faces = len(faces)
115
-
116
- # Chunk forward
117
- if min_samples_per_face > 0:
118
- base_samples = torch.full((num_faces,), min_samples_per_face, device=device)
119
- remaining_samples = num_samples - torch.sum(base_samples).item()
120
-
121
- if remaining_samples > 0:
122
- # Block sampling to avoid large mesh issues
123
- if num_faces > 2**24:
124
- chunk_size = 1000000 # Process 1 million faces at a time
125
- additional_counts = torch.zeros(num_faces, device=device)
126
-
127
- for start in range(0, num_faces, chunk_size):
128
- end = min(start + chunk_size, num_faces)
129
- chunk_weights = sampling_weights[start:end]
130
- chunk_probs = chunk_weights / chunk_weights.sum()
131
-
132
- # Proportinally allocate remaining samples
133
- chunk_samples = int(remaining_samples * (end - start) / num_faces)
134
- samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
135
- chunk_counts = torch.bincount(samples, minlength=chunk_size)
136
- additional_counts[start:end] += chunk_counts[:end-start]
137
-
138
- sample_counts = additional_counts + base_samples
139
- else:
140
- probs = sampling_weights / sampling_weights.sum()
141
- additional_samples = torch.multinomial(probs, remaining_samples, replacement=True)
142
- sample_counts = torch.bincount(additional_samples, minlength=num_faces) + base_samples
143
- else:
144
- sample_counts = base_samples
145
- else:
146
- if num_faces > 2**24:
147
- # Chunk sampling strategy
148
- sample_counts = torch.zeros(num_faces, device=device)
149
- chunk_size = 1000000 # Process 1 million faces at a time
150
- chunk_samples = num_samples // ((num_faces + chunk_size - 1) // chunk_size)
151
-
152
- for start in range(0, num_faces, chunk_size):
153
- end = min(start + chunk_size, num_faces)
154
- chunk_weights = sampling_weights[start:end]
155
- chunk_probs = chunk_weights / chunk_weights.sum()
156
-
157
- samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
158
- chunk_counts = torch.bincount(samples, minlength=chunk_size)
159
- sample_counts[start:end] += chunk_counts[:end-start]
160
- else:
161
- probs = sampling_weights / sampling_weights.sum()
162
- samples = torch.multinomial(probs, num_samples, replacement=True)
163
- sample_counts = torch.bincount(samples, minlength=num_faces)
164
-
165
- # Generate barycentric coordinates for sampled points
166
- total_samples = sample_counts.sum().item()
167
- r1 = torch.sqrt(torch.rand(total_samples, device=device))
168
- r2 = torch.rand(total_samples, device=device)
169
-
170
- barycentric_coords = torch.stack([
171
- 1 - r1,
172
- r1 * (1 - r2),
173
- r1 * r2
174
- ], dim=1)
175
-
176
- # Generate face indices
177
- face_indices = torch.repeat_interleave(
178
- torch.arange(num_faces, device=device),
179
- sample_counts
180
- )
181
-
182
- # Get vertices of corresponding faces
183
- face_vertices = vertices[faces[face_indices]]
184
-
185
- # Compute 3D coordinates of sampled points
186
- points = (barycentric_coords.unsqueeze(1) @ face_vertices).squeeze(1)
187
-
188
- # Compute normal vectors of sampled points
189
- v0, v1, v2 = face_vertices[:, 0], face_vertices[:, 1], face_vertices[:, 2]
190
- face_normals = torch.cross(v1 - v0, v2 - v0)
191
- normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-12)
192
-
193
- return points, face_indices, normals
194
-
195
- def normalize_points_and_mesh(vertices: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
196
- """Normalize mesh and point cloud to unit cube"""
197
- device = vertices.device
198
- vmin = vertices.min(dim=0)[0]
199
- vmax = vertices.max(dim=0)[0]
200
- center = (vmax + vmin) / 2
201
- scale = (vmax - vmin).max()
202
- margin = 0.01
203
- scale = scale * (1 + 2 * margin)
204
-
205
- vertices_normalized = (vertices - center) / scale + 0.5
206
- points_normalized = (points - center) / scale + 0.5
207
-
208
- return vertices_normalized, points_normalized, center, scale
209
-
210
- def add_gaussian_noise(uniform_surface_points: torch.Tensor, curvature_surface_points: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
211
- """Add Gaussian noise to point cloud"""
212
- # noise = torch.randn_like(points) * sigma
213
- # print("u_num:",uniform_surface_points.shape)
214
- # print("c_num:",curvature_surface_points.shape)
215
-
216
- idx1 = torch.randperm(uniform_surface_points.shape[0])
217
- idx2 = torch.randperm(curvature_surface_points.shape[0])
218
- uniform_surface_points = uniform_surface_points[idx1]
219
- curvature_surface_points = curvature_surface_points[idx2]
220
-
221
- a, b = -0.25, 0.25
222
- mu = 0
223
-
224
- # get near points (add offset on surface points)
225
- offset1 = torch.tensor(truncnorm.rvs((a - mu) / 0.005, (b - mu) / 0.005, loc=mu, scale=0.005, size=(len(uniform_surface_points), 3)),
226
- dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
227
- offset2 = torch.tensor(truncnorm.rvs((a - mu) / 0.05, (b - mu) / 0.05, loc=mu, scale=0.05, size=(len(uniform_surface_points), 3)),
228
- dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
229
- uniform_near_points = torch.cat([
230
- uniform_surface_points + offset1,
231
- uniform_surface_points + offset2
232
- ], dim=0)
233
-
234
- # Generate multi-scale noise for curvature sample points
235
- unit_num = curvature_surface_points.shape[0] // 6
236
- scales = [0.001, 0.003, 0.006, 0.01, 0.02, 0.04]
237
-
238
- curvature_near_points = []
239
- for i in range(6):
240
- start = i * unit_num
241
- end = (i + 1) * unit_num if i < 5 else curvature_surface_points.shape[0]
242
- noise = torch.randn((end - start, 3), dtype=curvature_surface_points.dtype,
243
- device=curvature_surface_points.device) * scales[i]
244
- curvature_near_points.append(curvature_surface_points[start:end] + noise)
245
-
246
- curvature_near_points = torch.cat(curvature_near_points, dim=0)
247
-
248
- return uniform_near_points, curvature_near_points
249
-
250
- def compute_points_value_bvh(
251
- vertices: torch.Tensor,
252
- faces: torch.Tensor,
253
- points: torch.Tensor,
254
- use_sdf: bool = True,
255
- batch_size: int = 100_00000
256
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
257
- """Compute SDF or occupancy values for sampled points"""
258
- device = vertices.device
259
-
260
- # Normalize mesh and point cloud
261
- vertices_norm, points_norm, center, scale = normalize_points_and_mesh(vertices, points)
262
-
263
- BVH = cubvh.cuBVH(vertices_norm, faces)
264
- distances, face_id, uvw = BVH.signed_distance(points, return_uvw=True, mode='watertight')
265
- values = distances
266
-
267
- return values, points_norm, center, scale
268
-
269
- def save_point_cloud(
270
- points: torch.Tensor,
271
- output_path: str,
272
- normals: Optional[torch.Tensor] = None,
273
- colors: Optional[torch.Tensor] = None
274
- ) -> None:
275
- """Save point cloud to file"""
276
- points_np = points.cpu().numpy()
277
- normals_np = normals.cpu().numpy() if normals is not None else None
278
- colors_np = None
279
-
280
- if colors is not None:
281
- colors_np = colors.cpu().numpy()
282
- if colors_np.max() <= 1.0:
283
- colors_np = (colors_np * 255).astype(np.uint8)
284
-
285
- ext = os.path.splitext(output_path)[1].lower()
286
-
287
- if ext == '.txt':
288
- data_list = [points_np]
289
- if normals_np is not None:
290
- data_list.append(normals_np)
291
- if colors_np is not None:
292
- data_list.append(colors_np)
293
-
294
- combined_data = np.hstack(data_list)
295
- np.savetxt(output_path, combined_data, fmt='%.6f')
296
-
297
- elif ext == '.ply':
298
- cloud = trimesh.PointCloud(points_np, colors=colors_np)
299
- if normals_np is not None:
300
- cloud.metadata['normals'] = normals_np
301
- cloud.export(output_path)
302
-
303
- else:
304
- raise ValueError(f"Unsupported file format: {ext}. Please use .txt or .ply")
305
-
306
- def sample_points_in_bbox(
307
- bbox_min: torch.Tensor,
308
- bbox_max: torch.Tensor,
309
- num_samples: int,
310
- device: str = "cuda"
311
- ) -> torch.Tensor:
312
- """Uniformly sample points within bounding box"""
313
- points = torch.rand(num_samples, 3, device=device)
314
- points = points * (bbox_max - bbox_min) + bbox_min
315
- return points
316
-
317
- def process_single_mesh(
318
- mesh_name:str,
319
- mesh_path: str,
320
- output_dir: str,
321
- data_type:str = 'mesh',
322
- surface_uniform_samples: int = 100000, # surface上均匀采样点数
323
- surface_curvature_samples: int = 200000, # surface上曲率采样点数
324
- space_samples: int = 300000, # 空间中采样点数
325
- noise_sigma: float = 0.01,
326
- device: str = "cuda"
327
- ) -> None:
328
- """Process a single mesh file
329
- Args:
330
- mesh_path: Input mesh path
331
- output_dir: Output directory
332
- surface_uniform_samples: Number of uniform sample points on surface
333
- surface_curvature_samples: Number of curvature-based sample points on surface
334
- space_samples: Number of sample points in space
335
- noise_sigma: Gaussian noise standard deviation
336
- device: Computation device
337
- """
338
- os.makedirs(output_dir, exist_ok=True)
339
-
340
- if data_type == "mesh":
341
- vertices, faces = load_mesh(mesh_path, device)
342
- elif data_type == "sparse_voxel":
343
- pass
344
- vertices_normalized, _, center, scale = normalize_points_and_mesh(vertices, vertices)
345
-
346
- space_points = torch.rand(space_samples, 3, device=device)
347
-
348
- uniform_surface_points, uniform_surface_normals = sample_uniform_points(
349
- vertices=vertices_normalized,
350
- faces=faces,
351
- num_samples=surface_uniform_samples
352
- )
353
-
354
- curvature_surface_points, _, curvature_surface_normals = sample_surface_points(
355
- vertices=vertices_normalized,
356
- faces=faces,
357
- num_samples=surface_curvature_samples,
358
- use_curvature=True
359
- )
360
-
361
- clean_surface_points = torch.cat([uniform_surface_points, curvature_surface_points], dim=0)
362
- clean_surface_normals = torch.cat([uniform_surface_normals, curvature_surface_normals], dim=0)
363
-
364
- surface_uni_save_path = os.path.join(output_dir, f"{mesh_name}_uni_surface")
365
- save_point_cloud(
366
- points=uniform_surface_points,
367
- output_path=f"{surface_uni_save_path}.ply",
368
- normals=uniform_surface_normals
369
- )
370
-
371
- surface_cur_save_path = os.path.join(output_dir, f"{mesh_name}_cur_surface")
372
- save_point_cloud(
373
- points=curvature_surface_points,
374
- output_path=f"{surface_cur_save_path}.ply",
375
- normals=curvature_surface_normals
376
- )
377
-
378
- uniform_near_points, curvature_near_points = add_gaussian_noise(uniform_surface_points = uniform_surface_points.clone(),
379
- curvature_surface_points = curvature_surface_points.clone(), sigma=noise_sigma)
380
-
381
- space_sdf, _, _, _ = compute_points_value_bvh(
382
- vertices=vertices_normalized,
383
- faces=faces,
384
- points=space_points,
385
- use_sdf=True,
386
- batch_size=1000_00000
387
- )
388
-
389
- # clean_surface_sdf = torch.zeros(len(clean_surface_points), device=device)
390
- uniform_near_sdf, _, _, _ = compute_points_value_bvh(
391
- vertices=vertices_normalized,
392
- faces=faces,
393
- points=uniform_near_points,
394
- use_sdf=True,
395
- batch_size=1000_00000
396
- )
397
-
398
- curvature_near_sdf, _, _, _ = compute_points_value_bvh(
399
- vertices=vertices_normalized,
400
- faces=faces,
401
- points=curvature_near_points,
402
- use_sdf=True,
403
- batch_size=1000_00000
404
- )
405
-
406
- print("sdf:",uniform_near_sdf.shape, curvature_near_sdf.shape)
407
-
408
- base_save_path = os.path.join(output_dir, mesh_name)
409
-
410
- np.savez(f"{base_save_path}.npz",
411
- space_points=space_points.cpu().numpy(),
412
- space_sdf=space_sdf.cpu().numpy(),
413
- clean_surface_points=clean_surface_points.cpu().numpy(),
414
- clean_surface_normals=clean_surface_normals.cpu().numpy(),
415
- uniform_near_points=uniform_near_points.cpu().numpy(),
416
- curvature_near_points=curvature_near_points.cpu().numpy(),
417
- uniform_near_sdf=uniform_near_sdf.cpu().numpy(),
418
- curvature_near_sdf=curvature_near_sdf.cpu().numpy(),
419
- center=center.cpu().numpy(),
420
- scale=scale.cpu().numpy())
421
-
422
- class MeshDataset(Dataset):
423
- def __init__(self, mesh_json: str):
424
- with open(mesh_json, "r") as f:
425
- self.mesh_paths = json.load(f)
426
- # print(len(self.mesh_paths))
427
-
428
- def __len__(self) -> int:
429
- return len(self.mesh_paths)
430
- def __getitem__(self, idx: int) -> dict:
431
- mesh_path = self.mesh_paths[idx]
432
- mesh_name = os.path.basename(mesh_path)[:-4]
433
- mesh = {
434
- "mesh_path": mesh_path,
435
- "mesh_name": mesh_name,
436
- }
437
- return mesh
438
-
439
- class MeshProcessor(pl.LightningModule):
440
- def __init__(
441
- self,
442
- mesh_json: str,
443
- output_dir: str,
444
- data_type:str,
445
- surface_uniform_samples: int = 20000,
446
- surface_curvature_samples: int = 40000,
447
- space_samples: int = 300000,
448
- noise_sigma: float = 0.01,
449
- batch_size: int = 1,
450
- num_workers: int = 4
451
- ):
452
- super().__init__()
453
- self.save_hyperparameters()
454
- os.makedirs(output_dir, exist_ok=True)
455
-
456
- def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
457
- mesh_path = batch["mesh_path"][0]
458
- mesh_name = batch["mesh_name"][0]
459
-
460
- # sys_logger.info(f"Processing {batch_idx}/{len(self.trainer.predict_dataloaders)}: {mesh_name} from {mesh_path}")
461
-
462
- output_subdir = self.hparams.output_dir
463
-
464
- try:
465
- filename = os.path.splitext(os.path.basename(mesh_path))[0]
466
- if os.path.exists(os.path.join(output_subdir, f"{filename}.npz")):
467
- # sys_logger.info(f"Skipping {mesh_name} as it already exists.")
468
- return {
469
- "status": "success",
470
- "mesh_name": mesh_name
471
- }
472
- process_single_mesh(
473
- mesh_name=mesh_name,
474
- mesh_path=mesh_path,
475
- output_dir=output_subdir,
476
- data_type = self.hparams.data_type,
477
- surface_uniform_samples=self.hparams.surface_uniform_samples,
478
- surface_curvature_samples=self.hparams.surface_curvature_samples,
479
- space_samples=self.hparams.space_samples,
480
- noise_sigma=self.hparams.noise_sigma,
481
- device=self.device
482
- )
483
-
484
- return {
485
- "status": "success",
486
- "mesh_name": mesh_name
487
- }
488
-
489
- except Exception as e:
490
- print(f"Error processing {mesh_name}: {str(e)}")
491
- return {
492
- "status": "error",
493
- "mesh_name": mesh_name,
494
- "error": str(e)
495
- }
496
-
497
- def predict_dataloader(self) -> DataLoader:
498
- dataset = MeshDataset(
499
- self.hparams.mesh_json)
500
- return DataLoader(
501
- dataset,
502
- batch_size=self.hparams.batch_size,
503
- num_workers=self.hparams.num_workers,
504
- persistent_workers=True,
505
- shuffle=False
506
- )
507
-
508
- def process_mesh_directory(
509
- mesh_json: str,
510
- output_dir: str,
511
- data_type: str,
512
- surface_uniform_samples: int = 100000,
513
- surface_curvature_samples: int = 200000,
514
- space_samples: int = 300000,
515
- noise_sigma: float = 0.01,
516
- num_gpus: int = -1,
517
- batch_size: int = 1,
518
- num_workers: int = 4
519
- ) -> None:
520
- model = MeshProcessor(
521
- mesh_json=mesh_json,
522
- output_dir=output_dir,
523
- data_type=data_type,
524
- surface_uniform_samples=surface_uniform_samples,
525
- surface_curvature_samples=surface_curvature_samples,
526
- space_samples=space_samples,
527
- noise_sigma=noise_sigma,
528
- batch_size=batch_size,
529
- num_workers=num_workers
530
- )
531
-
532
- trainer = pl.Trainer(
533
- accelerator="gpu",
534
- devices=num_gpus,
535
- strategy="ddp",
536
- precision=32,
537
- logger=False,
538
- enable_progress_bar=True
539
- )
540
-
541
- predictions = trainer.predict(model)
542
-
543
- success_count = sum(1 for p in predictions if p["status"] == "success")
544
- error_count = sum(1 for p in predictions if p["status"] == "error")
545
-
546
- print(f"\nProcessing completed:")
547
- print(f"Successfully processed: {success_count} files")
548
- print(f"Failed to process: {error_count} files")
549
-
550
- if error_count > 0:
551
- print("\nFailed files:")
552
- for p in predictions:
553
- if p["status"] == "error":
554
- print(f"- {p['mesh_name']}: {p['error']}")
555
-
556
- if __name__ == "__main__":
557
-
558
- parser = argparse.ArgumentParser(description="Process Mesh Directory for Sampling")
559
-
560
- parser.add_argument("--mesh_json", type=str, default="test_mesh.json", help="Path to the mesh json file")
561
- parser.add_argument("--output_dir", type=str, default="ultrashape_test1", help="Directory to save outputs")
562
-
563
- parser.add_argument("--surface_uniform_samples", type=int, default=300000, help="Number of uniform samples on surface")
564
- parser.add_argument("--surface_curvature_samples", type=int, default=300000, help="Number of curvature-based samples on surface")
565
- parser.add_argument("--space_samples", type=int, default=400000, help="Number of samples in space")
566
-
567
- parser.add_argument("--noise_sigma", type=float, default=0.01, help="Sigma for Gaussian noise")
568
- parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
569
- parser.add_argument("--num_workers", type=int, default=16, help="Number of data loading workers")
570
- parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU")
571
-
572
- args = parser.parse_args()
573
- # print(f"Arguments: {args}")
574
-
575
- process_mesh_directory(
576
- mesh_json=args.mesh_json,
577
- output_dir=args.output_dir,
578
- data_type='mesh',
579
- surface_uniform_samples=args.surface_uniform_samples,
580
- surface_curvature_samples=args.surface_curvature_samples,
581
- space_samples=args.space_samples,
582
- noise_sigma=args.noise_sigma,
583
- num_gpus=args.num_gpus,
584
- num_workers=args.num_workers,
585
- batch_size=args.batch_size
586
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/train_deepspeed.sh DELETED
@@ -1,64 +0,0 @@
1
-
2
- export NCCL_IB_TIMEOUT=24
3
- export NCCL_NVLS_ENABLE=0
4
- NET_TYPE="high"
5
- if [[ "${NET_TYPE}" = "low" ]]; then
6
- export NCCL_SOCKET_IFNAME=eth1
7
- export NCCL_IB_GID_INDEX=3
8
- export NCCL_IB_HCA=mlx5_2:1,mlx5_2:1
9
- export NCCL_IB_SL=3
10
- export NCCL_CHECKS_DISABLE=1
11
- export NCCL_P2P_DISABLE=0
12
- export NCCL_LL_THRESHOLD=16384
13
- export NCCL_IB_CUDA_SUPPORT=1
14
- else
15
- export NCCL_IB_GID_INDEX=3
16
- export NCCL_IB_SL=3
17
- export NCCL_CHECKS_DISABLE=1
18
- export NCCL_P2P_DISABLE=0
19
- export NCCL_IB_DISABLE=0
20
- export NCCL_LL_THRESHOLD=16384
21
- export NCCL_IB_CUDA_SUPPORT=1
22
- export NCCL_SOCKET_IFNAME=bond1
23
- export NCCL_COLLNET_ENABLE=0
24
- export SHARP_COLL_ENABLE_SAT=0
25
- export NCCL_NET_GDR_LEVEL=2
26
- export NCCL_IB_QPS_PER_CONNECTION=4
27
- export NCCL_IB_TC=160
28
- export NCCL_PXN_DISABLE=1
29
- fi
30
- # export NCCL_DEBUG=INFO
31
-
32
- node_num=$1
33
- node_rank=$2
34
- num_gpu_per_node=$3
35
- master_ip=$4
36
- config=$5
37
- output_dir=$6
38
-
39
- echo node_num $node_num
40
- echo node_rank $node_rank
41
- echo master_ip $master_ip
42
- echo config $config
43
- echo output_dir $output_dir
44
-
45
- if test -d "$output_dir"; then
46
- cp $config $output_dir
47
- else
48
- mkdir -p "$output_dir"
49
- cp $config $output_dir
50
- fi
51
-
52
- NODE_RANK=$node_rank \
53
- HF_HUB_OFFLINE=0 \
54
- MASTER_PORT=12348 \
55
- MASTER_ADDR=$master_ip \
56
- NCCL_SOCKET_IFNAME=bond1 \
57
- NCCL_IB_GID_INDEX=3 \
58
- NCCL_NVLS_ENABLE=0 \
59
- python3 main.py \
60
- --num_nodes $node_num \
61
- --num_gpus $num_gpu_per_node \
62
- --config $config \
63
- --output_dir $output_dir \
64
- --deepspeed