Spaces:
Running on Zero
Running on Zero
Delete scripts
Browse files- scripts/gradio_app.py +0 -218
- scripts/infer_dit_refine.py +0 -142
- scripts/install_env.sh +0 -8
- scripts/run.sh +0 -12
- scripts/sampling.py +0 -586
- scripts/train_deepspeed.sh +0 -64
scripts/gradio_app.py
DELETED
|
@@ -1,218 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import gc
|
| 3 |
-
import os
|
| 4 |
-
import sys
|
| 5 |
-
|
| 6 |
-
import gradio as gr
|
| 7 |
-
import torch
|
| 8 |
-
from omegaconf import OmegaConf
|
| 9 |
-
|
| 10 |
-
# Add project root to path
|
| 11 |
-
sys.path.append(os.getcwd())
|
| 12 |
-
|
| 13 |
-
from ultrashape.rembg import BackgroundRemover
|
| 14 |
-
from ultrashape.utils.misc import instantiate_from_config
|
| 15 |
-
from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
|
| 16 |
-
from ultrashape.utils import voxelize_from_point
|
| 17 |
-
from ultrashape.pipelines import UltraShapePipeline
|
| 18 |
-
|
| 19 |
-
# Global variables to cache the model
|
| 20 |
-
MODEL_CACHE = {}
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def get_pipeline_cached(config_path, ckpt_path, device='cuda', low_vram=False):
|
| 24 |
-
# Check if we have a valid cached pipeline for this checkpoint
|
| 25 |
-
if "pipeline" in MODEL_CACHE and MODEL_CACHE.get("ckpt_path") == ckpt_path:
|
| 26 |
-
print("Using cached pipeline...")
|
| 27 |
-
return MODEL_CACHE["pipeline"], MODEL_CACHE["config"]
|
| 28 |
-
|
| 29 |
-
# Clear old cache if it exists (e.g. different checkpoint)
|
| 30 |
-
if MODEL_CACHE:
|
| 31 |
-
print("Clearing old model cache...")
|
| 32 |
-
MODEL_CACHE.clear()
|
| 33 |
-
gc.collect()
|
| 34 |
-
torch.cuda.empty_cache()
|
| 35 |
-
|
| 36 |
-
print(f"Loading config from {config_path}...")
|
| 37 |
-
config = OmegaConf.load(config_path)
|
| 38 |
-
|
| 39 |
-
print("Instantiating VAE...")
|
| 40 |
-
vae = instantiate_from_config(config.model.params.vae_config)
|
| 41 |
-
|
| 42 |
-
print("Instantiating DiT...")
|
| 43 |
-
dit = instantiate_from_config(config.model.params.dit_cfg)
|
| 44 |
-
|
| 45 |
-
print("Instantiating Conditioner...")
|
| 46 |
-
conditioner = instantiate_from_config(config.model.params.conditioner_config)
|
| 47 |
-
|
| 48 |
-
print("Instantiating Scheduler & Processor...")
|
| 49 |
-
scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
|
| 50 |
-
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
|
| 51 |
-
|
| 52 |
-
print(f"Loading weights from {ckpt_path}...")
|
| 53 |
-
weights = torch.load(ckpt_path, map_location='cpu')
|
| 54 |
-
|
| 55 |
-
vae.load_state_dict(weights['vae'], strict=True)
|
| 56 |
-
dit.load_state_dict(weights['dit'], strict=True)
|
| 57 |
-
conditioner.load_state_dict(weights['conditioner'], strict=True)
|
| 58 |
-
|
| 59 |
-
vae.eval().to(device)
|
| 60 |
-
dit.eval().to(device)
|
| 61 |
-
conditioner.eval().to(device)
|
| 62 |
-
|
| 63 |
-
if hasattr(vae, 'enable_flashvdm_decoder'):
|
| 64 |
-
vae.enable_flashvdm_decoder()
|
| 65 |
-
|
| 66 |
-
print("Creating Pipeline...")
|
| 67 |
-
pipeline = UltraShapePipeline(
|
| 68 |
-
vae=vae,
|
| 69 |
-
model=dit,
|
| 70 |
-
scheduler=scheduler,
|
| 71 |
-
conditioner=conditioner,
|
| 72 |
-
image_processor=image_processor
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
if low_vram:
|
| 76 |
-
pipeline.enable_model_cpu_offload()
|
| 77 |
-
|
| 78 |
-
MODEL_CACHE["pipeline"] = pipeline
|
| 79 |
-
MODEL_CACHE["config"] = config
|
| 80 |
-
MODEL_CACHE["ckpt_path"] = ckpt_path
|
| 81 |
-
|
| 82 |
-
return pipeline, config
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def predict(
|
| 86 |
-
image_input,
|
| 87 |
-
mesh_input,
|
| 88 |
-
steps,
|
| 89 |
-
scale,
|
| 90 |
-
octree_res,
|
| 91 |
-
num_latents,
|
| 92 |
-
chunk_size,
|
| 93 |
-
seed,
|
| 94 |
-
remove_bg,
|
| 95 |
-
ckpt_path,
|
| 96 |
-
low_vram
|
| 97 |
-
):
|
| 98 |
-
# Aggressive memory cleanup at start
|
| 99 |
-
gc.collect()
|
| 100 |
-
torch.cuda.empty_cache()
|
| 101 |
-
|
| 102 |
-
try:
|
| 103 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 104 |
-
config_path = "configs/infer_dit_refine.yaml"
|
| 105 |
-
|
| 106 |
-
if not os.path.exists(config_path):
|
| 107 |
-
raise FileNotFoundError(f"Config not found at {config_path}")
|
| 108 |
-
|
| 109 |
-
pipeline, config = get_pipeline_cached(config_path, ckpt_path, device, low_vram)
|
| 110 |
-
|
| 111 |
-
voxel_res = config.model.params.vae_config.params.voxel_query_res
|
| 112 |
-
|
| 113 |
-
print(f"Initializing Surface Loader (Token Num: {num_latents})...")
|
| 114 |
-
loader = SharpEdgeSurfaceLoader(
|
| 115 |
-
num_sharp_points=204800,
|
| 116 |
-
num_uniform_points=204800,
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
print(f"Processing inputs...")
|
| 120 |
-
if image_input is None:
|
| 121 |
-
raise ValueError("Image input is required")
|
| 122 |
-
if mesh_input is None:
|
| 123 |
-
raise ValueError("Mesh input is required")
|
| 124 |
-
|
| 125 |
-
# Handle image input
|
| 126 |
-
if isinstance(image_input, dict):
|
| 127 |
-
# In newer gradio versions Image component might return a dict for mask etc, but usually just PIL/numpy
|
| 128 |
-
# if type='pil' it is PIL.Image
|
| 129 |
-
pass
|
| 130 |
-
|
| 131 |
-
image = image_input.convert("RGBA")
|
| 132 |
-
|
| 133 |
-
if remove_bg or image.mode != 'RGBA':
|
| 134 |
-
rembg = BackgroundRemover()
|
| 135 |
-
image = rembg(image)
|
| 136 |
-
|
| 137 |
-
# Handle mesh input - Gradio Model3D returns path to file
|
| 138 |
-
surface = loader(mesh_input, normalize_scale=scale).to(device, dtype=torch.float16)
|
| 139 |
-
pc = surface[:, :, :3] # [B, N, 3]
|
| 140 |
-
|
| 141 |
-
# Voxelize
|
| 142 |
-
_, voxel_idx = voxelize_from_point(pc, num_latents, resolution=voxel_res)
|
| 143 |
-
|
| 144 |
-
print("Running diffusion process...")
|
| 145 |
-
gen_device = "cpu" if low_vram else device
|
| 146 |
-
generator = torch.Generator(gen_device).manual_seed(int(seed))
|
| 147 |
-
|
| 148 |
-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 149 |
-
mesh_out_list, _ = pipeline(
|
| 150 |
-
image=image,
|
| 151 |
-
voxel_cond=voxel_idx,
|
| 152 |
-
generator=generator,
|
| 153 |
-
box_v=1.0,
|
| 154 |
-
mc_level=0.0,
|
| 155 |
-
octree_resolution=int(octree_res),
|
| 156 |
-
num_chunks=int(chunk_size),
|
| 157 |
-
num_inference_steps=int(steps)
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
# Save output
|
| 161 |
-
output_dir = "outputs_gradio"
|
| 162 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 163 |
-
base_name = "output"
|
| 164 |
-
save_path = os.path.join(output_dir, f"{base_name}_refined.glb")
|
| 165 |
-
|
| 166 |
-
mesh_out = mesh_out_list[0]
|
| 167 |
-
mesh_out.export(save_path)
|
| 168 |
-
print(f"Successfully saved to {save_path}")
|
| 169 |
-
|
| 170 |
-
return save_path
|
| 171 |
-
|
| 172 |
-
finally:
|
| 173 |
-
# Aggressive memory cleanup at end
|
| 174 |
-
gc.collect()
|
| 175 |
-
torch.cuda.empty_cache()
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
if __name__ == "__main__":
|
| 179 |
-
parser = argparse.ArgumentParser(description="UltraShape Gradio App")
|
| 180 |
-
parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
|
| 181 |
-
parser.add_argument("--share", action="store_true", help="Share the gradio app")
|
| 182 |
-
parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
|
| 183 |
-
|
| 184 |
-
args = parser.parse_args()
|
| 185 |
-
|
| 186 |
-
# Define Gradio Interface
|
| 187 |
-
with gr.Blocks(title="UltraShape Inference") as demo:
|
| 188 |
-
gr.Markdown("# UltraShape Inference: Mesh & Image Refinement")
|
| 189 |
-
|
| 190 |
-
with gr.Row():
|
| 191 |
-
with gr.Column():
|
| 192 |
-
image_input = gr.Image(type="pil", label="Input Image", image_mode="RGBA")
|
| 193 |
-
mesh_input = gr.Model3D(label="Input Coarse Mesh (.glb, .obj)")
|
| 194 |
-
|
| 195 |
-
with gr.Accordion("Advanced Parameters", open=True):
|
| 196 |
-
steps = gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Inference Steps")
|
| 197 |
-
scale = gr.Slider(minimum=0.1, maximum=2.0, value=0.99, label="Mesh Normalization Scale")
|
| 198 |
-
octree_res = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Octree Resolution")
|
| 199 |
-
num_latents = gr.Slider(minimum=1024, maximum=32768, value=32768, step=128,
|
| 200 |
-
label="Number of Latent Tokens (Use 8192 if OOM)")
|
| 201 |
-
chunk_size = gr.Slider(minimum=512, maximum=10000, value=2048, step=512,
|
| 202 |
-
label="Chunk Size (Use 2000 if OOM)")
|
| 203 |
-
seed = gr.Number(value=42, label="Random Seed")
|
| 204 |
-
remove_bg = gr.Checkbox(label="Remove Background", value=False)
|
| 205 |
-
|
| 206 |
-
run_btn = gr.Button("Run Inference", variant="primary")
|
| 207 |
-
|
| 208 |
-
with gr.Column():
|
| 209 |
-
output_model = gr.Model3D(label="Refined Output Mesh")
|
| 210 |
-
|
| 211 |
-
run_btn.click(
|
| 212 |
-
fn=lambda img, mesh, s, sc, oct, nml, chk, sd, rm: predict(img, mesh, s, sc, oct, nml, chk, sd, rm, args.ckpt,
|
| 213 |
-
args.low_vram),
|
| 214 |
-
inputs=[image_input, mesh_input, steps, scale, octree_res, num_latents, chunk_size, seed, remove_bg],
|
| 215 |
-
outputs=[output_model]
|
| 216 |
-
)
|
| 217 |
-
|
| 218 |
-
demo.launch(share=args.share, server_name='0.0.0.0', server_port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/infer_dit_refine.py
DELETED
|
@@ -1,142 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import sys
|
| 3 |
-
import argparse
|
| 4 |
-
import torch
|
| 5 |
-
import numpy as np
|
| 6 |
-
from PIL import Image
|
| 7 |
-
from omegaconf import OmegaConf
|
| 8 |
-
|
| 9 |
-
# project_root = '[your_project_root_path]' # Replace with your project root path
|
| 10 |
-
# sys.path.insert(0, project_root)
|
| 11 |
-
|
| 12 |
-
from ultrashape.rembg import BackgroundRemover
|
| 13 |
-
from ultrashape.utils.misc import instantiate_from_config
|
| 14 |
-
from ultrashape.surface_loaders import SharpEdgeSurfaceLoader
|
| 15 |
-
from ultrashape.utils import voxelize_from_point
|
| 16 |
-
from ultrashape.pipelines import UltraShapePipeline
|
| 17 |
-
|
| 18 |
-
def load_models(config_path, ckpt_path, device='cuda'):
|
| 19 |
-
|
| 20 |
-
print(f"Loading config from {config_path}...")
|
| 21 |
-
config = OmegaConf.load(config_path)
|
| 22 |
-
|
| 23 |
-
print("Instantiating VAE...")
|
| 24 |
-
vae = instantiate_from_config(config.model.params.vae_config)
|
| 25 |
-
|
| 26 |
-
print("Instantiating DiT...")
|
| 27 |
-
dit = instantiate_from_config(config.model.params.dit_cfg)
|
| 28 |
-
|
| 29 |
-
print("Instantiating Conditioner...")
|
| 30 |
-
conditioner = instantiate_from_config(config.model.params.conditioner_config)
|
| 31 |
-
|
| 32 |
-
print("Instantiating Scheduler & Processor...")
|
| 33 |
-
scheduler = instantiate_from_config(config.model.params.scheduler_cfg)
|
| 34 |
-
image_processor = instantiate_from_config(config.model.params.image_processor_cfg)
|
| 35 |
-
|
| 36 |
-
print(f"Loading weights from {ckpt_path}...")
|
| 37 |
-
weights = torch.load(ckpt_path, map_location='cpu')
|
| 38 |
-
|
| 39 |
-
vae.load_state_dict(weights['vae'], strict=True)
|
| 40 |
-
dit.load_state_dict(weights['dit'], strict=True)
|
| 41 |
-
conditioner.load_state_dict(weights['conditioner'], strict=True)
|
| 42 |
-
|
| 43 |
-
vae.eval().to(device)
|
| 44 |
-
dit.eval().to(device)
|
| 45 |
-
conditioner.eval().to(device)
|
| 46 |
-
|
| 47 |
-
if hasattr(vae, 'enable_flashvdm_decoder'):
|
| 48 |
-
vae.enable_flashvdm_decoder()
|
| 49 |
-
|
| 50 |
-
components = {
|
| 51 |
-
"vae": vae,
|
| 52 |
-
"dit": dit,
|
| 53 |
-
"conditioner": conditioner,
|
| 54 |
-
"scheduler": scheduler,
|
| 55 |
-
"image_processor": image_processor,
|
| 56 |
-
}
|
| 57 |
-
|
| 58 |
-
return components, config
|
| 59 |
-
|
| 60 |
-
def run_inference(args):
|
| 61 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 62 |
-
|
| 63 |
-
components, config = load_models(args.config, args.ckpt, device)
|
| 64 |
-
|
| 65 |
-
pipeline = UltraShapePipeline(
|
| 66 |
-
vae=components['vae'],
|
| 67 |
-
model=components['dit'],
|
| 68 |
-
scheduler=components['scheduler'],
|
| 69 |
-
conditioner=components['conditioner'],
|
| 70 |
-
image_processor=components['image_processor']
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
-
if args.low_vram:
|
| 74 |
-
pipeline.enable_model_cpu_offload()
|
| 75 |
-
|
| 76 |
-
token_num = args.num_latents
|
| 77 |
-
voxel_res = config.model.params.vae_config.params.voxel_query_res
|
| 78 |
-
|
| 79 |
-
print(f"Initializing Surface Loader (Token Num: {token_num})...")
|
| 80 |
-
loader = SharpEdgeSurfaceLoader(
|
| 81 |
-
num_sharp_points=204800,
|
| 82 |
-
num_uniform_points=204800,
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
print(f"Processing inputs: {args.image} & {args.mesh}")
|
| 86 |
-
image = Image.open(args.image)
|
| 87 |
-
|
| 88 |
-
if args.remove_bg or image.mode != 'RGBA':
|
| 89 |
-
rembg = BackgroundRemover()
|
| 90 |
-
image = rembg(image)
|
| 91 |
-
|
| 92 |
-
surface = loader(args.mesh, normalize_scale=args.scale).to(device, dtype=torch.float16)
|
| 93 |
-
pc = surface[:, :, :3] # [B, N, 3]
|
| 94 |
-
|
| 95 |
-
# Voxelize
|
| 96 |
-
_, voxel_idx = voxelize_from_point(pc, token_num, resolution=voxel_res)
|
| 97 |
-
|
| 98 |
-
print("Running diffusion process...")
|
| 99 |
-
generator = torch.Generator(device).manual_seed(args.seed)
|
| 100 |
-
|
| 101 |
-
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 102 |
-
mesh, _ = pipeline(
|
| 103 |
-
image=image,
|
| 104 |
-
voxel_cond=voxel_idx,
|
| 105 |
-
generator=generator,
|
| 106 |
-
box_v=1.0,
|
| 107 |
-
mc_level=0.0,
|
| 108 |
-
octree_resolution=args.octree_res,
|
| 109 |
-
num_inference_steps=args.steps,
|
| 110 |
-
num_chunks=args.chunk_size,
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 114 |
-
base_name = os.path.splitext(os.path.basename(args.image))[0]
|
| 115 |
-
save_path = os.path.join(args.output_dir, f"{base_name}_refined.glb")
|
| 116 |
-
|
| 117 |
-
mesh = mesh[0]
|
| 118 |
-
mesh.export(save_path)
|
| 119 |
-
print(f"Successfully saved to {save_path}")
|
| 120 |
-
|
| 121 |
-
if __name__ == "__main__":
|
| 122 |
-
parser = argparse.ArgumentParser(description="UltraShape Inference Script")
|
| 123 |
-
|
| 124 |
-
parser.add_argument("--config", type=str, default="configs/infer_dit2.yaml", help="Path to inference config")
|
| 125 |
-
parser.add_argument("--ckpt", type=str, required=True, help="Path to split checkpoint (.pt)")
|
| 126 |
-
parser.add_argument("--low_vram", action="store_true", help="Optimize for low VRAM usage")
|
| 127 |
-
|
| 128 |
-
parser.add_argument("--image", type=str, required=True, help="Input image path")
|
| 129 |
-
parser.add_argument("--mesh", type=str, required=True, help="Input coarse mesh (.glb/.obj)")
|
| 130 |
-
parser.add_argument("--output_dir", type=str, default="outputs", help="Output directory")
|
| 131 |
-
|
| 132 |
-
parser.add_argument("--steps", type=int, default=50, help="Inference steps")
|
| 133 |
-
parser.add_argument("--scale", type=float, default=0.99, help="Mesh normalization scale")
|
| 134 |
-
parser.add_argument("--num_latents", type=int, default=32768, help="Number of latents")
|
| 135 |
-
parser.add_argument("--chunk_size", type=int, default=8000, help="Chunk size for inference")
|
| 136 |
-
parser.add_argument("--octree_res", type=int, default=1024, help="Marching Cubes resolution")
|
| 137 |
-
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
| 138 |
-
parser.add_argument("--remove_bg", action="store_true", help="Force remove background")
|
| 139 |
-
|
| 140 |
-
args = parser.parse_args()
|
| 141 |
-
|
| 142 |
-
run_inference(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/install_env.sh
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
conda create -n ultrashape python=3.10
|
| 2 |
-
conda activate ultrashape
|
| 3 |
-
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
|
| 4 |
-
pip install -r requirements.txt
|
| 5 |
-
pip install git+https://github.com/ashawkey/cubvh --no-build-isolation
|
| 6 |
-
|
| 7 |
-
pip install --no-build-isolation "git+https://github.com/facebookresearch/pytorch3d.git@stable"
|
| 8 |
-
pip install https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_cluster-1.6.3%2Bpt25cu121-cp310-cp310-linux_x86_64.whl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/run.sh
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
# sampling
|
| 2 |
-
# python scripts/sampling.py \
|
| 3 |
-
# --mesh_json data/mesh_paths.json \
|
| 4 |
-
# --output_dir data/sample
|
| 5 |
-
|
| 6 |
-
# inference refine_dit
|
| 7 |
-
python scripts/infer_dit_refine.py \
|
| 8 |
-
--ckpt checkpoints/ultrashape_v1.pt \
|
| 9 |
-
--image inputs/image/1.png \
|
| 10 |
-
--mesh inputs/coarse_mesh/1.glb \
|
| 11 |
-
--config configs/infer_dit_refine.yaml
|
| 12 |
-
# --steps 12
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/sampling.py
DELETED
|
@@ -1,586 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import trimesh
|
| 3 |
-
import numpy as np
|
| 4 |
-
from typing import List, Optional, Any, Tuple, Union
|
| 5 |
-
import pytorch_lightning as pl
|
| 6 |
-
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 7 |
-
import torch
|
| 8 |
-
from torch.utils.data import Dataset, DataLoader
|
| 9 |
-
import pytorch3d.structures
|
| 10 |
-
import pytorch3d.ops
|
| 11 |
-
from scipy.stats import truncnorm
|
| 12 |
-
import json
|
| 13 |
-
import argparse
|
| 14 |
-
import cubvh
|
| 15 |
-
|
| 16 |
-
# import logging
|
| 17 |
-
# from tools.logger import init_log, set_all_log
|
| 18 |
-
# sys_logger = init_log("sampler", logging.DEBUG)
|
| 19 |
-
# set_all_log(level=logging.DEBUG, path='./debug/logs')
|
| 20 |
-
|
| 21 |
-
def load_mesh(mesh_path: str, device: str = "cuda") -> Tuple[torch.Tensor, torch.Tensor]:
|
| 22 |
-
if mesh_path.endswith(".npz"):
|
| 23 |
-
mesh_np = np.load(mesh_path)
|
| 24 |
-
vertices, faces = torch.tensor(mesh_np["vertices"], device=device), torch.tensor(mesh_np["faces"].astype('i8'), device=device)
|
| 25 |
-
else:
|
| 26 |
-
mesh = trimesh.load(mesh_path, force='mesh')
|
| 27 |
-
vertices = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
|
| 28 |
-
faces = torch.tensor(mesh.faces, dtype=torch.long, device=device)
|
| 29 |
-
if faces.shape[0] > 2 * 1e8:
|
| 30 |
-
raise ValueError(f"too many faces {faces.shape}")
|
| 31 |
-
return vertices, faces
|
| 32 |
-
|
| 33 |
-
def compute_mesh_features(vertices: torch.Tensor, faces: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 34 |
-
device = vertices.device
|
| 35 |
-
|
| 36 |
-
v0 = vertices[faces[:, 0]]
|
| 37 |
-
v1 = vertices[faces[:, 1]]
|
| 38 |
-
v2 = vertices[faces[:, 2]]
|
| 39 |
-
face_normals = torch.cross(v1 - v0, v2 - v0)
|
| 40 |
-
face_areas = torch.norm(face_normals, dim=1) * 0.5
|
| 41 |
-
face_normals = face_normals / (face_areas.unsqueeze(1) * 2 + 1e-12)
|
| 42 |
-
|
| 43 |
-
vertex_normals = torch.zeros_like(vertices)
|
| 44 |
-
face_normals_weighted = face_normals * face_areas.unsqueeze(1)
|
| 45 |
-
|
| 46 |
-
vertex_normals.scatter_add_(0, faces[:, 0:1].expand(-1, 3), face_normals_weighted)
|
| 47 |
-
vertex_normals.scatter_add_(0, faces[:, 1:2].expand(-1, 3), face_normals_weighted)
|
| 48 |
-
vertex_normals.scatter_add_(0, faces[:, 2:3].expand(-1, 3), face_normals_weighted)
|
| 49 |
-
|
| 50 |
-
vertex_normals = vertex_normals / (torch.norm(vertex_normals, dim=1, keepdim=True) + 1e-12)
|
| 51 |
-
|
| 52 |
-
edges = torch.cat([
|
| 53 |
-
faces[:, [0, 1]],
|
| 54 |
-
faces[:, [1, 2]],
|
| 55 |
-
faces[:, [2, 0]]
|
| 56 |
-
], dim=0)
|
| 57 |
-
|
| 58 |
-
edges_unique, edges_inverse = torch.unique(torch.sort(edges, dim=1)[0], dim=0, return_inverse=True)
|
| 59 |
-
edge_normals_diff = torch.norm(
|
| 60 |
-
vertex_normals[edges[:, 0]] - vertex_normals[edges[:, 1]],
|
| 61 |
-
dim=1
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
vertex_curvatures = torch.zeros(len(vertices), device=device)
|
| 65 |
-
vertex_curvatures.scatter_add_(0, edges[:, 0], edge_normals_diff)
|
| 66 |
-
vertex_curvatures.scatter_add_(0, edges[:, 1], edge_normals_diff)
|
| 67 |
-
|
| 68 |
-
vertex_degrees = torch.zeros(len(vertices), device=device)
|
| 69 |
-
vertex_degrees.scatter_add_(0, edges[:, 0], torch.ones_like(edge_normals_diff))
|
| 70 |
-
vertex_degrees.scatter_add_(0, edges[:, 1], torch.ones_like(edge_normals_diff))
|
| 71 |
-
|
| 72 |
-
vertex_curvatures = vertex_curvatures / (vertex_degrees + 1e-12)
|
| 73 |
-
vertex_curvatures = (vertex_curvatures - vertex_curvatures.min()) / (
|
| 74 |
-
vertex_curvatures.max() - vertex_curvatures.min() + 1e-12)
|
| 75 |
-
|
| 76 |
-
return face_areas, vertex_curvatures
|
| 77 |
-
|
| 78 |
-
def sample_uniform_points(
|
| 79 |
-
vertices: torch.Tensor,
|
| 80 |
-
faces: torch.Tensor,
|
| 81 |
-
num_samples: int,
|
| 82 |
-
random_seed: Optional[int] = None
|
| 83 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 84 |
-
|
| 85 |
-
if random_seed is not None:
|
| 86 |
-
torch.manual_seed(random_seed)
|
| 87 |
-
mesh = pytorch3d.structures.Meshes(verts=[vertices], faces=[faces])
|
| 88 |
-
|
| 89 |
-
points, normals = pytorch3d.ops.sample_points_from_meshes(
|
| 90 |
-
mesh, num_samples=num_samples, return_normals=True)
|
| 91 |
-
|
| 92 |
-
return points[0], normals[0]
|
| 93 |
-
|
| 94 |
-
def sample_surface_points(
|
| 95 |
-
vertices: torch.Tensor,
|
| 96 |
-
faces: torch.Tensor,
|
| 97 |
-
num_samples: int,
|
| 98 |
-
min_samples_per_face: int = 0,
|
| 99 |
-
use_curvature: bool = True,
|
| 100 |
-
random_seed: Optional[int] = None
|
| 101 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 102 |
-
"""Curvature-based surface sampling"""
|
| 103 |
-
device = vertices.device
|
| 104 |
-
if random_seed is not None:
|
| 105 |
-
torch.manual_seed(random_seed)
|
| 106 |
-
|
| 107 |
-
# Compute face areas and vertex curvatures
|
| 108 |
-
face_areas, vertex_curvatures = compute_mesh_features(vertices, faces)
|
| 109 |
-
|
| 110 |
-
# Compute average curvature of faces
|
| 111 |
-
face_curvatures = torch.mean(vertex_curvatures[faces], dim=1)
|
| 112 |
-
sampling_weights = face_curvatures # Use only curvature as weights
|
| 113 |
-
# Calculate number of sample points per face
|
| 114 |
-
num_faces = len(faces)
|
| 115 |
-
|
| 116 |
-
# Chunk forward
|
| 117 |
-
if min_samples_per_face > 0:
|
| 118 |
-
base_samples = torch.full((num_faces,), min_samples_per_face, device=device)
|
| 119 |
-
remaining_samples = num_samples - torch.sum(base_samples).item()
|
| 120 |
-
|
| 121 |
-
if remaining_samples > 0:
|
| 122 |
-
# Block sampling to avoid large mesh issues
|
| 123 |
-
if num_faces > 2**24:
|
| 124 |
-
chunk_size = 1000000 # Process 1 million faces at a time
|
| 125 |
-
additional_counts = torch.zeros(num_faces, device=device)
|
| 126 |
-
|
| 127 |
-
for start in range(0, num_faces, chunk_size):
|
| 128 |
-
end = min(start + chunk_size, num_faces)
|
| 129 |
-
chunk_weights = sampling_weights[start:end]
|
| 130 |
-
chunk_probs = chunk_weights / chunk_weights.sum()
|
| 131 |
-
|
| 132 |
-
# Proportinally allocate remaining samples
|
| 133 |
-
chunk_samples = int(remaining_samples * (end - start) / num_faces)
|
| 134 |
-
samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
|
| 135 |
-
chunk_counts = torch.bincount(samples, minlength=chunk_size)
|
| 136 |
-
additional_counts[start:end] += chunk_counts[:end-start]
|
| 137 |
-
|
| 138 |
-
sample_counts = additional_counts + base_samples
|
| 139 |
-
else:
|
| 140 |
-
probs = sampling_weights / sampling_weights.sum()
|
| 141 |
-
additional_samples = torch.multinomial(probs, remaining_samples, replacement=True)
|
| 142 |
-
sample_counts = torch.bincount(additional_samples, minlength=num_faces) + base_samples
|
| 143 |
-
else:
|
| 144 |
-
sample_counts = base_samples
|
| 145 |
-
else:
|
| 146 |
-
if num_faces > 2**24:
|
| 147 |
-
# Chunk sampling strategy
|
| 148 |
-
sample_counts = torch.zeros(num_faces, device=device)
|
| 149 |
-
chunk_size = 1000000 # Process 1 million faces at a time
|
| 150 |
-
chunk_samples = num_samples // ((num_faces + chunk_size - 1) // chunk_size)
|
| 151 |
-
|
| 152 |
-
for start in range(0, num_faces, chunk_size):
|
| 153 |
-
end = min(start + chunk_size, num_faces)
|
| 154 |
-
chunk_weights = sampling_weights[start:end]
|
| 155 |
-
chunk_probs = chunk_weights / chunk_weights.sum()
|
| 156 |
-
|
| 157 |
-
samples = torch.multinomial(chunk_probs, chunk_samples, replacement=True)
|
| 158 |
-
chunk_counts = torch.bincount(samples, minlength=chunk_size)
|
| 159 |
-
sample_counts[start:end] += chunk_counts[:end-start]
|
| 160 |
-
else:
|
| 161 |
-
probs = sampling_weights / sampling_weights.sum()
|
| 162 |
-
samples = torch.multinomial(probs, num_samples, replacement=True)
|
| 163 |
-
sample_counts = torch.bincount(samples, minlength=num_faces)
|
| 164 |
-
|
| 165 |
-
# Generate barycentric coordinates for sampled points
|
| 166 |
-
total_samples = sample_counts.sum().item()
|
| 167 |
-
r1 = torch.sqrt(torch.rand(total_samples, device=device))
|
| 168 |
-
r2 = torch.rand(total_samples, device=device)
|
| 169 |
-
|
| 170 |
-
barycentric_coords = torch.stack([
|
| 171 |
-
1 - r1,
|
| 172 |
-
r1 * (1 - r2),
|
| 173 |
-
r1 * r2
|
| 174 |
-
], dim=1)
|
| 175 |
-
|
| 176 |
-
# Generate face indices
|
| 177 |
-
face_indices = torch.repeat_interleave(
|
| 178 |
-
torch.arange(num_faces, device=device),
|
| 179 |
-
sample_counts
|
| 180 |
-
)
|
| 181 |
-
|
| 182 |
-
# Get vertices of corresponding faces
|
| 183 |
-
face_vertices = vertices[faces[face_indices]]
|
| 184 |
-
|
| 185 |
-
# Compute 3D coordinates of sampled points
|
| 186 |
-
points = (barycentric_coords.unsqueeze(1) @ face_vertices).squeeze(1)
|
| 187 |
-
|
| 188 |
-
# Compute normal vectors of sampled points
|
| 189 |
-
v0, v1, v2 = face_vertices[:, 0], face_vertices[:, 1], face_vertices[:, 2]
|
| 190 |
-
face_normals = torch.cross(v1 - v0, v2 - v0)
|
| 191 |
-
normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-12)
|
| 192 |
-
|
| 193 |
-
return points, face_indices, normals
|
| 194 |
-
|
| 195 |
-
def normalize_points_and_mesh(vertices: torch.Tensor, points: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 196 |
-
"""Normalize mesh and point cloud to unit cube"""
|
| 197 |
-
device = vertices.device
|
| 198 |
-
vmin = vertices.min(dim=0)[0]
|
| 199 |
-
vmax = vertices.max(dim=0)[0]
|
| 200 |
-
center = (vmax + vmin) / 2
|
| 201 |
-
scale = (vmax - vmin).max()
|
| 202 |
-
margin = 0.01
|
| 203 |
-
scale = scale * (1 + 2 * margin)
|
| 204 |
-
|
| 205 |
-
vertices_normalized = (vertices - center) / scale + 0.5
|
| 206 |
-
points_normalized = (points - center) / scale + 0.5
|
| 207 |
-
|
| 208 |
-
return vertices_normalized, points_normalized, center, scale
|
| 209 |
-
|
| 210 |
-
def add_gaussian_noise(uniform_surface_points: torch.Tensor, curvature_surface_points: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
|
| 211 |
-
"""Add Gaussian noise to point cloud"""
|
| 212 |
-
# noise = torch.randn_like(points) * sigma
|
| 213 |
-
# print("u_num:",uniform_surface_points.shape)
|
| 214 |
-
# print("c_num:",curvature_surface_points.shape)
|
| 215 |
-
|
| 216 |
-
idx1 = torch.randperm(uniform_surface_points.shape[0])
|
| 217 |
-
idx2 = torch.randperm(curvature_surface_points.shape[0])
|
| 218 |
-
uniform_surface_points = uniform_surface_points[idx1]
|
| 219 |
-
curvature_surface_points = curvature_surface_points[idx2]
|
| 220 |
-
|
| 221 |
-
a, b = -0.25, 0.25
|
| 222 |
-
mu = 0
|
| 223 |
-
|
| 224 |
-
# get near points (add offset on surface points)
|
| 225 |
-
offset1 = torch.tensor(truncnorm.rvs((a - mu) / 0.005, (b - mu) / 0.005, loc=mu, scale=0.005, size=(len(uniform_surface_points), 3)),
|
| 226 |
-
dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
|
| 227 |
-
offset2 = torch.tensor(truncnorm.rvs((a - mu) / 0.05, (b - mu) / 0.05, loc=mu, scale=0.05, size=(len(uniform_surface_points), 3)),
|
| 228 |
-
dtype=uniform_surface_points.dtype, device=uniform_surface_points.device)
|
| 229 |
-
uniform_near_points = torch.cat([
|
| 230 |
-
uniform_surface_points + offset1,
|
| 231 |
-
uniform_surface_points + offset2
|
| 232 |
-
], dim=0)
|
| 233 |
-
|
| 234 |
-
# Generate multi-scale noise for curvature sample points
|
| 235 |
-
unit_num = curvature_surface_points.shape[0] // 6
|
| 236 |
-
scales = [0.001, 0.003, 0.006, 0.01, 0.02, 0.04]
|
| 237 |
-
|
| 238 |
-
curvature_near_points = []
|
| 239 |
-
for i in range(6):
|
| 240 |
-
start = i * unit_num
|
| 241 |
-
end = (i + 1) * unit_num if i < 5 else curvature_surface_points.shape[0]
|
| 242 |
-
noise = torch.randn((end - start, 3), dtype=curvature_surface_points.dtype,
|
| 243 |
-
device=curvature_surface_points.device) * scales[i]
|
| 244 |
-
curvature_near_points.append(curvature_surface_points[start:end] + noise)
|
| 245 |
-
|
| 246 |
-
curvature_near_points = torch.cat(curvature_near_points, dim=0)
|
| 247 |
-
|
| 248 |
-
return uniform_near_points, curvature_near_points
|
| 249 |
-
|
| 250 |
-
def compute_points_value_bvh(
|
| 251 |
-
vertices: torch.Tensor,
|
| 252 |
-
faces: torch.Tensor,
|
| 253 |
-
points: torch.Tensor,
|
| 254 |
-
use_sdf: bool = True,
|
| 255 |
-
batch_size: int = 100_00000
|
| 256 |
-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 257 |
-
"""Compute SDF or occupancy values for sampled points"""
|
| 258 |
-
device = vertices.device
|
| 259 |
-
|
| 260 |
-
# Normalize mesh and point cloud
|
| 261 |
-
vertices_norm, points_norm, center, scale = normalize_points_and_mesh(vertices, points)
|
| 262 |
-
|
| 263 |
-
BVH = cubvh.cuBVH(vertices_norm, faces)
|
| 264 |
-
distances, face_id, uvw = BVH.signed_distance(points, return_uvw=True, mode='watertight')
|
| 265 |
-
values = distances
|
| 266 |
-
|
| 267 |
-
return values, points_norm, center, scale
|
| 268 |
-
|
| 269 |
-
def save_point_cloud(
|
| 270 |
-
points: torch.Tensor,
|
| 271 |
-
output_path: str,
|
| 272 |
-
normals: Optional[torch.Tensor] = None,
|
| 273 |
-
colors: Optional[torch.Tensor] = None
|
| 274 |
-
) -> None:
|
| 275 |
-
"""Save point cloud to file"""
|
| 276 |
-
points_np = points.cpu().numpy()
|
| 277 |
-
normals_np = normals.cpu().numpy() if normals is not None else None
|
| 278 |
-
colors_np = None
|
| 279 |
-
|
| 280 |
-
if colors is not None:
|
| 281 |
-
colors_np = colors.cpu().numpy()
|
| 282 |
-
if colors_np.max() <= 1.0:
|
| 283 |
-
colors_np = (colors_np * 255).astype(np.uint8)
|
| 284 |
-
|
| 285 |
-
ext = os.path.splitext(output_path)[1].lower()
|
| 286 |
-
|
| 287 |
-
if ext == '.txt':
|
| 288 |
-
data_list = [points_np]
|
| 289 |
-
if normals_np is not None:
|
| 290 |
-
data_list.append(normals_np)
|
| 291 |
-
if colors_np is not None:
|
| 292 |
-
data_list.append(colors_np)
|
| 293 |
-
|
| 294 |
-
combined_data = np.hstack(data_list)
|
| 295 |
-
np.savetxt(output_path, combined_data, fmt='%.6f')
|
| 296 |
-
|
| 297 |
-
elif ext == '.ply':
|
| 298 |
-
cloud = trimesh.PointCloud(points_np, colors=colors_np)
|
| 299 |
-
if normals_np is not None:
|
| 300 |
-
cloud.metadata['normals'] = normals_np
|
| 301 |
-
cloud.export(output_path)
|
| 302 |
-
|
| 303 |
-
else:
|
| 304 |
-
raise ValueError(f"Unsupported file format: {ext}. Please use .txt or .ply")
|
| 305 |
-
|
| 306 |
-
def sample_points_in_bbox(
|
| 307 |
-
bbox_min: torch.Tensor,
|
| 308 |
-
bbox_max: torch.Tensor,
|
| 309 |
-
num_samples: int,
|
| 310 |
-
device: str = "cuda"
|
| 311 |
-
) -> torch.Tensor:
|
| 312 |
-
"""Uniformly sample points within bounding box"""
|
| 313 |
-
points = torch.rand(num_samples, 3, device=device)
|
| 314 |
-
points = points * (bbox_max - bbox_min) + bbox_min
|
| 315 |
-
return points
|
| 316 |
-
|
| 317 |
-
def process_single_mesh(
|
| 318 |
-
mesh_name:str,
|
| 319 |
-
mesh_path: str,
|
| 320 |
-
output_dir: str,
|
| 321 |
-
data_type:str = 'mesh',
|
| 322 |
-
surface_uniform_samples: int = 100000, # surface上均匀采样点数
|
| 323 |
-
surface_curvature_samples: int = 200000, # surface上曲率采样点数
|
| 324 |
-
space_samples: int = 300000, # 空间中采样点数
|
| 325 |
-
noise_sigma: float = 0.01,
|
| 326 |
-
device: str = "cuda"
|
| 327 |
-
) -> None:
|
| 328 |
-
"""Process a single mesh file
|
| 329 |
-
Args:
|
| 330 |
-
mesh_path: Input mesh path
|
| 331 |
-
output_dir: Output directory
|
| 332 |
-
surface_uniform_samples: Number of uniform sample points on surface
|
| 333 |
-
surface_curvature_samples: Number of curvature-based sample points on surface
|
| 334 |
-
space_samples: Number of sample points in space
|
| 335 |
-
noise_sigma: Gaussian noise standard deviation
|
| 336 |
-
device: Computation device
|
| 337 |
-
"""
|
| 338 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 339 |
-
|
| 340 |
-
if data_type == "mesh":
|
| 341 |
-
vertices, faces = load_mesh(mesh_path, device)
|
| 342 |
-
elif data_type == "sparse_voxel":
|
| 343 |
-
pass
|
| 344 |
-
vertices_normalized, _, center, scale = normalize_points_and_mesh(vertices, vertices)
|
| 345 |
-
|
| 346 |
-
space_points = torch.rand(space_samples, 3, device=device)
|
| 347 |
-
|
| 348 |
-
uniform_surface_points, uniform_surface_normals = sample_uniform_points(
|
| 349 |
-
vertices=vertices_normalized,
|
| 350 |
-
faces=faces,
|
| 351 |
-
num_samples=surface_uniform_samples
|
| 352 |
-
)
|
| 353 |
-
|
| 354 |
-
curvature_surface_points, _, curvature_surface_normals = sample_surface_points(
|
| 355 |
-
vertices=vertices_normalized,
|
| 356 |
-
faces=faces,
|
| 357 |
-
num_samples=surface_curvature_samples,
|
| 358 |
-
use_curvature=True
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
clean_surface_points = torch.cat([uniform_surface_points, curvature_surface_points], dim=0)
|
| 362 |
-
clean_surface_normals = torch.cat([uniform_surface_normals, curvature_surface_normals], dim=0)
|
| 363 |
-
|
| 364 |
-
surface_uni_save_path = os.path.join(output_dir, f"{mesh_name}_uni_surface")
|
| 365 |
-
save_point_cloud(
|
| 366 |
-
points=uniform_surface_points,
|
| 367 |
-
output_path=f"{surface_uni_save_path}.ply",
|
| 368 |
-
normals=uniform_surface_normals
|
| 369 |
-
)
|
| 370 |
-
|
| 371 |
-
surface_cur_save_path = os.path.join(output_dir, f"{mesh_name}_cur_surface")
|
| 372 |
-
save_point_cloud(
|
| 373 |
-
points=curvature_surface_points,
|
| 374 |
-
output_path=f"{surface_cur_save_path}.ply",
|
| 375 |
-
normals=curvature_surface_normals
|
| 376 |
-
)
|
| 377 |
-
|
| 378 |
-
uniform_near_points, curvature_near_points = add_gaussian_noise(uniform_surface_points = uniform_surface_points.clone(),
|
| 379 |
-
curvature_surface_points = curvature_surface_points.clone(), sigma=noise_sigma)
|
| 380 |
-
|
| 381 |
-
space_sdf, _, _, _ = compute_points_value_bvh(
|
| 382 |
-
vertices=vertices_normalized,
|
| 383 |
-
faces=faces,
|
| 384 |
-
points=space_points,
|
| 385 |
-
use_sdf=True,
|
| 386 |
-
batch_size=1000_00000
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
# clean_surface_sdf = torch.zeros(len(clean_surface_points), device=device)
|
| 390 |
-
uniform_near_sdf, _, _, _ = compute_points_value_bvh(
|
| 391 |
-
vertices=vertices_normalized,
|
| 392 |
-
faces=faces,
|
| 393 |
-
points=uniform_near_points,
|
| 394 |
-
use_sdf=True,
|
| 395 |
-
batch_size=1000_00000
|
| 396 |
-
)
|
| 397 |
-
|
| 398 |
-
curvature_near_sdf, _, _, _ = compute_points_value_bvh(
|
| 399 |
-
vertices=vertices_normalized,
|
| 400 |
-
faces=faces,
|
| 401 |
-
points=curvature_near_points,
|
| 402 |
-
use_sdf=True,
|
| 403 |
-
batch_size=1000_00000
|
| 404 |
-
)
|
| 405 |
-
|
| 406 |
-
print("sdf:",uniform_near_sdf.shape, curvature_near_sdf.shape)
|
| 407 |
-
|
| 408 |
-
base_save_path = os.path.join(output_dir, mesh_name)
|
| 409 |
-
|
| 410 |
-
np.savez(f"{base_save_path}.npz",
|
| 411 |
-
space_points=space_points.cpu().numpy(),
|
| 412 |
-
space_sdf=space_sdf.cpu().numpy(),
|
| 413 |
-
clean_surface_points=clean_surface_points.cpu().numpy(),
|
| 414 |
-
clean_surface_normals=clean_surface_normals.cpu().numpy(),
|
| 415 |
-
uniform_near_points=uniform_near_points.cpu().numpy(),
|
| 416 |
-
curvature_near_points=curvature_near_points.cpu().numpy(),
|
| 417 |
-
uniform_near_sdf=uniform_near_sdf.cpu().numpy(),
|
| 418 |
-
curvature_near_sdf=curvature_near_sdf.cpu().numpy(),
|
| 419 |
-
center=center.cpu().numpy(),
|
| 420 |
-
scale=scale.cpu().numpy())
|
| 421 |
-
|
| 422 |
-
class MeshDataset(Dataset):
|
| 423 |
-
def __init__(self, mesh_json: str):
|
| 424 |
-
with open(mesh_json, "r") as f:
|
| 425 |
-
self.mesh_paths = json.load(f)
|
| 426 |
-
# print(len(self.mesh_paths))
|
| 427 |
-
|
| 428 |
-
def __len__(self) -> int:
|
| 429 |
-
return len(self.mesh_paths)
|
| 430 |
-
def __getitem__(self, idx: int) -> dict:
|
| 431 |
-
mesh_path = self.mesh_paths[idx]
|
| 432 |
-
mesh_name = os.path.basename(mesh_path)[:-4]
|
| 433 |
-
mesh = {
|
| 434 |
-
"mesh_path": mesh_path,
|
| 435 |
-
"mesh_name": mesh_name,
|
| 436 |
-
}
|
| 437 |
-
return mesh
|
| 438 |
-
|
| 439 |
-
class MeshProcessor(pl.LightningModule):
|
| 440 |
-
def __init__(
|
| 441 |
-
self,
|
| 442 |
-
mesh_json: str,
|
| 443 |
-
output_dir: str,
|
| 444 |
-
data_type:str,
|
| 445 |
-
surface_uniform_samples: int = 20000,
|
| 446 |
-
surface_curvature_samples: int = 40000,
|
| 447 |
-
space_samples: int = 300000,
|
| 448 |
-
noise_sigma: float = 0.01,
|
| 449 |
-
batch_size: int = 1,
|
| 450 |
-
num_workers: int = 4
|
| 451 |
-
):
|
| 452 |
-
super().__init__()
|
| 453 |
-
self.save_hyperparameters()
|
| 454 |
-
os.makedirs(output_dir, exist_ok=True)
|
| 455 |
-
|
| 456 |
-
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> STEP_OUTPUT:
|
| 457 |
-
mesh_path = batch["mesh_path"][0]
|
| 458 |
-
mesh_name = batch["mesh_name"][0]
|
| 459 |
-
|
| 460 |
-
# sys_logger.info(f"Processing {batch_idx}/{len(self.trainer.predict_dataloaders)}: {mesh_name} from {mesh_path}")
|
| 461 |
-
|
| 462 |
-
output_subdir = self.hparams.output_dir
|
| 463 |
-
|
| 464 |
-
try:
|
| 465 |
-
filename = os.path.splitext(os.path.basename(mesh_path))[0]
|
| 466 |
-
if os.path.exists(os.path.join(output_subdir, f"{filename}.npz")):
|
| 467 |
-
# sys_logger.info(f"Skipping {mesh_name} as it already exists.")
|
| 468 |
-
return {
|
| 469 |
-
"status": "success",
|
| 470 |
-
"mesh_name": mesh_name
|
| 471 |
-
}
|
| 472 |
-
process_single_mesh(
|
| 473 |
-
mesh_name=mesh_name,
|
| 474 |
-
mesh_path=mesh_path,
|
| 475 |
-
output_dir=output_subdir,
|
| 476 |
-
data_type = self.hparams.data_type,
|
| 477 |
-
surface_uniform_samples=self.hparams.surface_uniform_samples,
|
| 478 |
-
surface_curvature_samples=self.hparams.surface_curvature_samples,
|
| 479 |
-
space_samples=self.hparams.space_samples,
|
| 480 |
-
noise_sigma=self.hparams.noise_sigma,
|
| 481 |
-
device=self.device
|
| 482 |
-
)
|
| 483 |
-
|
| 484 |
-
return {
|
| 485 |
-
"status": "success",
|
| 486 |
-
"mesh_name": mesh_name
|
| 487 |
-
}
|
| 488 |
-
|
| 489 |
-
except Exception as e:
|
| 490 |
-
print(f"Error processing {mesh_name}: {str(e)}")
|
| 491 |
-
return {
|
| 492 |
-
"status": "error",
|
| 493 |
-
"mesh_name": mesh_name,
|
| 494 |
-
"error": str(e)
|
| 495 |
-
}
|
| 496 |
-
|
| 497 |
-
def predict_dataloader(self) -> DataLoader:
|
| 498 |
-
dataset = MeshDataset(
|
| 499 |
-
self.hparams.mesh_json)
|
| 500 |
-
return DataLoader(
|
| 501 |
-
dataset,
|
| 502 |
-
batch_size=self.hparams.batch_size,
|
| 503 |
-
num_workers=self.hparams.num_workers,
|
| 504 |
-
persistent_workers=True,
|
| 505 |
-
shuffle=False
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
-
def process_mesh_directory(
|
| 509 |
-
mesh_json: str,
|
| 510 |
-
output_dir: str,
|
| 511 |
-
data_type: str,
|
| 512 |
-
surface_uniform_samples: int = 100000,
|
| 513 |
-
surface_curvature_samples: int = 200000,
|
| 514 |
-
space_samples: int = 300000,
|
| 515 |
-
noise_sigma: float = 0.01,
|
| 516 |
-
num_gpus: int = -1,
|
| 517 |
-
batch_size: int = 1,
|
| 518 |
-
num_workers: int = 4
|
| 519 |
-
) -> None:
|
| 520 |
-
model = MeshProcessor(
|
| 521 |
-
mesh_json=mesh_json,
|
| 522 |
-
output_dir=output_dir,
|
| 523 |
-
data_type=data_type,
|
| 524 |
-
surface_uniform_samples=surface_uniform_samples,
|
| 525 |
-
surface_curvature_samples=surface_curvature_samples,
|
| 526 |
-
space_samples=space_samples,
|
| 527 |
-
noise_sigma=noise_sigma,
|
| 528 |
-
batch_size=batch_size,
|
| 529 |
-
num_workers=num_workers
|
| 530 |
-
)
|
| 531 |
-
|
| 532 |
-
trainer = pl.Trainer(
|
| 533 |
-
accelerator="gpu",
|
| 534 |
-
devices=num_gpus,
|
| 535 |
-
strategy="ddp",
|
| 536 |
-
precision=32,
|
| 537 |
-
logger=False,
|
| 538 |
-
enable_progress_bar=True
|
| 539 |
-
)
|
| 540 |
-
|
| 541 |
-
predictions = trainer.predict(model)
|
| 542 |
-
|
| 543 |
-
success_count = sum(1 for p in predictions if p["status"] == "success")
|
| 544 |
-
error_count = sum(1 for p in predictions if p["status"] == "error")
|
| 545 |
-
|
| 546 |
-
print(f"\nProcessing completed:")
|
| 547 |
-
print(f"Successfully processed: {success_count} files")
|
| 548 |
-
print(f"Failed to process: {error_count} files")
|
| 549 |
-
|
| 550 |
-
if error_count > 0:
|
| 551 |
-
print("\nFailed files:")
|
| 552 |
-
for p in predictions:
|
| 553 |
-
if p["status"] == "error":
|
| 554 |
-
print(f"- {p['mesh_name']}: {p['error']}")
|
| 555 |
-
|
| 556 |
-
if __name__ == "__main__":
|
| 557 |
-
|
| 558 |
-
parser = argparse.ArgumentParser(description="Process Mesh Directory for Sampling")
|
| 559 |
-
|
| 560 |
-
parser.add_argument("--mesh_json", type=str, default="test_mesh.json", help="Path to the mesh json file")
|
| 561 |
-
parser.add_argument("--output_dir", type=str, default="ultrashape_test1", help="Directory to save outputs")
|
| 562 |
-
|
| 563 |
-
parser.add_argument("--surface_uniform_samples", type=int, default=300000, help="Number of uniform samples on surface")
|
| 564 |
-
parser.add_argument("--surface_curvature_samples", type=int, default=300000, help="Number of curvature-based samples on surface")
|
| 565 |
-
parser.add_argument("--space_samples", type=int, default=400000, help="Number of samples in space")
|
| 566 |
-
|
| 567 |
-
parser.add_argument("--noise_sigma", type=float, default=0.01, help="Sigma for Gaussian noise")
|
| 568 |
-
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use")
|
| 569 |
-
parser.add_argument("--num_workers", type=int, default=16, help="Number of data loading workers")
|
| 570 |
-
parser.add_argument("--batch_size", type=int, default=1, help="Batch size per GPU")
|
| 571 |
-
|
| 572 |
-
args = parser.parse_args()
|
| 573 |
-
# print(f"Arguments: {args}")
|
| 574 |
-
|
| 575 |
-
process_mesh_directory(
|
| 576 |
-
mesh_json=args.mesh_json,
|
| 577 |
-
output_dir=args.output_dir,
|
| 578 |
-
data_type='mesh',
|
| 579 |
-
surface_uniform_samples=args.surface_uniform_samples,
|
| 580 |
-
surface_curvature_samples=args.surface_curvature_samples,
|
| 581 |
-
space_samples=args.space_samples,
|
| 582 |
-
noise_sigma=args.noise_sigma,
|
| 583 |
-
num_gpus=args.num_gpus,
|
| 584 |
-
num_workers=args.num_workers,
|
| 585 |
-
batch_size=args.batch_size
|
| 586 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
scripts/train_deepspeed.sh
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
export NCCL_IB_TIMEOUT=24
|
| 3 |
-
export NCCL_NVLS_ENABLE=0
|
| 4 |
-
NET_TYPE="high"
|
| 5 |
-
if [[ "${NET_TYPE}" = "low" ]]; then
|
| 6 |
-
export NCCL_SOCKET_IFNAME=eth1
|
| 7 |
-
export NCCL_IB_GID_INDEX=3
|
| 8 |
-
export NCCL_IB_HCA=mlx5_2:1,mlx5_2:1
|
| 9 |
-
export NCCL_IB_SL=3
|
| 10 |
-
export NCCL_CHECKS_DISABLE=1
|
| 11 |
-
export NCCL_P2P_DISABLE=0
|
| 12 |
-
export NCCL_LL_THRESHOLD=16384
|
| 13 |
-
export NCCL_IB_CUDA_SUPPORT=1
|
| 14 |
-
else
|
| 15 |
-
export NCCL_IB_GID_INDEX=3
|
| 16 |
-
export NCCL_IB_SL=3
|
| 17 |
-
export NCCL_CHECKS_DISABLE=1
|
| 18 |
-
export NCCL_P2P_DISABLE=0
|
| 19 |
-
export NCCL_IB_DISABLE=0
|
| 20 |
-
export NCCL_LL_THRESHOLD=16384
|
| 21 |
-
export NCCL_IB_CUDA_SUPPORT=1
|
| 22 |
-
export NCCL_SOCKET_IFNAME=bond1
|
| 23 |
-
export NCCL_COLLNET_ENABLE=0
|
| 24 |
-
export SHARP_COLL_ENABLE_SAT=0
|
| 25 |
-
export NCCL_NET_GDR_LEVEL=2
|
| 26 |
-
export NCCL_IB_QPS_PER_CONNECTION=4
|
| 27 |
-
export NCCL_IB_TC=160
|
| 28 |
-
export NCCL_PXN_DISABLE=1
|
| 29 |
-
fi
|
| 30 |
-
# export NCCL_DEBUG=INFO
|
| 31 |
-
|
| 32 |
-
node_num=$1
|
| 33 |
-
node_rank=$2
|
| 34 |
-
num_gpu_per_node=$3
|
| 35 |
-
master_ip=$4
|
| 36 |
-
config=$5
|
| 37 |
-
output_dir=$6
|
| 38 |
-
|
| 39 |
-
echo node_num $node_num
|
| 40 |
-
echo node_rank $node_rank
|
| 41 |
-
echo master_ip $master_ip
|
| 42 |
-
echo config $config
|
| 43 |
-
echo output_dir $output_dir
|
| 44 |
-
|
| 45 |
-
if test -d "$output_dir"; then
|
| 46 |
-
cp $config $output_dir
|
| 47 |
-
else
|
| 48 |
-
mkdir -p "$output_dir"
|
| 49 |
-
cp $config $output_dir
|
| 50 |
-
fi
|
| 51 |
-
|
| 52 |
-
NODE_RANK=$node_rank \
|
| 53 |
-
HF_HUB_OFFLINE=0 \
|
| 54 |
-
MASTER_PORT=12348 \
|
| 55 |
-
MASTER_ADDR=$master_ip \
|
| 56 |
-
NCCL_SOCKET_IFNAME=bond1 \
|
| 57 |
-
NCCL_IB_GID_INDEX=3 \
|
| 58 |
-
NCCL_NVLS_ENABLE=0 \
|
| 59 |
-
python3 main.py \
|
| 60 |
-
--num_nodes $node_num \
|
| 61 |
-
--num_gpus $num_gpu_per_node \
|
| 62 |
-
--config $config \
|
| 63 |
-
--output_dir $output_dir \
|
| 64 |
-
--deepspeed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|