ckc99u commited on
Commit
d038f33
·
verified ·
1 Parent(s): fa1e916

Upload 3 files

Browse files
Files changed (3) hide show
  1. download_models.py +73 -178
  2. inference.py +176 -0
  3. packages.txt +6 -0
download_models.py CHANGED
@@ -1,178 +1,73 @@
1
- #!/usr/bin/env python3
2
- """
3
- Download large model checkpoints at runtime for Hugging Face Spaces
4
- This avoids the 1GB Git LFS repository limit
5
- """
6
-
7
- import os
8
- import sys
9
- from pathlib import Path
10
- from huggingface_hub import hf_hub_download, snapshot_download
11
- import requests
12
- from tqdm import tqdm
13
-
14
- def download_file_with_progress(url, destination):
15
- """Download a file with progress bar"""
16
- response = requests.get(url, stream=True)
17
- total_size = int(response.headers.get('content-length', 0))
18
-
19
- destination = Path(destination)
20
- destination.parent.mkdir(parents=True, exist_ok=True)
21
-
22
- with open(destination, 'wb') as f, tqdm(
23
- desc=destination.name,
24
- total=total_size,
25
- unit='iB',
26
- unit_scale=True,
27
- unit_divisor=1024,
28
- ) as pbar:
29
- for data in response.iter_content(chunk_size=1024):
30
- size = f.write(data)
31
- pbar.update(size)
32
-
33
- def download_from_huggingface_hub():
34
- """Download models from Hugging Face Hub"""
35
-
36
- print("=" * 70)
37
- print("Downloading MagicArticulate checkpoints from Hugging Face Hub...")
38
- print("=" * 70)
39
-
40
- # Create checkpoints directory
41
- checkpoints_dir = Path("checkpoints")
42
- checkpoints_dir.mkdir(exist_ok=True)
43
-
44
- try:
45
- # Download MagicArticulate checkpoints
46
- print("\n📥 Downloading MagicArticulate spatial_order checkpoint...")
47
- spatial_order_path = hf_hub_download(
48
- repo_id="Seed3D/MagicArticulate",
49
- filename="checkpoints/spatial_order.pt",
50
- local_dir=".",
51
- local_dir_use_symlinks=False
52
- )
53
- print(f" Downloaded: {spatial_order_path}")
54
-
55
- print("\n📥 Downloading MagicArticulate hier_order checkpoint...")
56
- hier_order_path = hf_hub_download(
57
- repo_id="Seed3D/MagicArticulate",
58
- filename="checkpoints/hier_order.pt",
59
- local_dir=".",
60
- local_dir_use_symlinks=False
61
- )
62
- print(f"✅ Downloaded: {hier_order_path}")
63
-
64
- except Exception as e:
65
- print(f"❌ Error downloading MagicArticulate checkpoints: {e}")
66
- print("Falling back to direct download...")
67
- download_from_direct_links()
68
- return
69
-
70
- # Download Michelangelo checkpoints
71
- print("\n" + "=" * 70)
72
- print("Downloading Michelangelo checkpoints...")
73
- print("=" * 70)
74
-
75
- try:
76
- # Download aligned_shape_latents
77
- print("\n📥 Downloading Michelangelo aligned_shape_latents...")
78
- michelangelo_dir = Path("checkpoints/michelangelo")
79
- michelangelo_dir.mkdir(parents=True, exist_ok=True)
80
-
81
- # Download the full Michelangelo repo or specific files
82
- snapshot_download(
83
- repo_id="Maikou/Michelangelo",
84
- allow_patterns=["checkpoints/aligned_shape_latents/*"],
85
- local_dir=".",
86
- local_dir_use_symlinks=False
87
- )
88
- print("✅ Michelangelo checkpoints downloaded")
89
-
90
- except Exception as e:
91
- print(f"⚠️ Warning: Could not download Michelangelo checkpoints: {e}")
92
- print("You may need to download them manually or adjust the code.")
93
-
94
- def download_from_direct_links():
95
- """Fallback: Download from direct links"""
96
-
97
- print("\n📥 Using direct download links...")
98
-
99
- # You can add direct download links here as fallback
100
- # For example, from Google Drive, OneDrive, or other cloud storage
101
-
102
- direct_links = {
103
- # Add your direct links here if available
104
- # "checkpoints/spatial_order.pt": "https://your-direct-link/spatial_order.pt",
105
- # "checkpoints/hier_order.pt": "https://your-direct-link/hier_order.pt",
106
- }
107
-
108
- for filepath, url in direct_links.items():
109
- print(f"Downloading {filepath}...")
110
- download_file_with_progress(url, filepath)
111
-
112
- def verify_downloads():
113
- """Verify that all required files are downloaded"""
114
-
115
- required_files = [
116
- "checkpoints/spatial_order.pt",
117
- "checkpoints/hier_order.pt",
118
- ]
119
-
120
- optional_files = [
121
- "checkpoints/aligned_shape_latents/model.safetensors",
122
- ]
123
-
124
- print("\n" + "=" * 70)
125
- print("Verifying downloads...")
126
- print("=" * 70)
127
-
128
- all_good = True
129
- for filepath in required_files:
130
- if Path(filepath).exists():
131
- size = Path(filepath).stat().st_size / (1024 * 1024) # MB
132
- print(f"✅ {filepath} ({size:.2f} MB)")
133
- else:
134
- print(f"❌ {filepath} - MISSING!")
135
- all_good = False
136
-
137
- for filepath in optional_files:
138
- if Path(filepath).exists():
139
- size = Path(filepath).stat().st_size / (1024 * 1024) # MB
140
- print(f"✅ {filepath} ({size:.2f} MB) [optional]")
141
- else:
142
- print(f"⚠️ {filepath} - Not found [optional]")
143
-
144
- if all_good:
145
- print("\n✅ All required checkpoints downloaded successfully!")
146
- else:
147
- print("\n❌ Some required files are missing!")
148
- sys.exit(1)
149
-
150
- # Calculate total size
151
- total_size = 0
152
- for filepath in required_files + optional_files:
153
- if Path(filepath).exists():
154
- total_size += Path(filepath).stat().st_size
155
-
156
- print(f"\n📊 Total downloaded: {total_size / (1024 * 1024 * 1024):.2f} GB")
157
-
158
- def main():
159
- """Main download function"""
160
-
161
- print("\n" + "=" * 70)
162
- print("🚀 MagicArticulate Model Downloader for Hugging Face Spaces")
163
- print("=" * 70)
164
- print("\nThis script downloads large model checkpoints at runtime")
165
- print("to avoid Hugging Face Spaces 1GB Git LFS limit.\n")
166
-
167
- # Download from Hugging Face Hub
168
- download_from_huggingface_hub()
169
-
170
- # Verify all downloads
171
- verify_downloads()
172
-
173
- print("\n" + "=" * 70)
174
- print("✅ Setup complete! Models are ready to use.")
175
- print("=" * 70)
176
-
177
- if __name__ == "__main__":
178
- main()
 
1
+ """
2
+ Download model checkpoints from Hugging Face Hub
3
+ Optimized for Spaces environment with error handling
4
+ """
5
+ import os
6
+ from huggingface_hub import hf_hub_download
7
+ from pathlib import Path
8
+
9
+ def download_with_retry(repo_id, filename, local_dir, max_retries=3):
10
+ """Download file with retry logic"""
11
+ for attempt in range(max_retries):
12
+ try:
13
+ print(f"Downloading {filename} (attempt {attempt + 1}/{max_retries})...")
14
+ file_path = hf_hub_download(
15
+ repo_id=repo_id,
16
+ filename=filename,
17
+ local_dir=local_dir,
18
+ local_dir_use_symlinks=False # Important for Spaces
19
+ )
20
+ print(f"✓ Successfully downloaded: {filename}")
21
+ return file_path
22
+ except Exception as e:
23
+ if attempt == max_retries - 1:
24
+ print(f"✗ Failed to download {filename}: {e}")
25
+ raise
26
+ print(f"Retry {attempt + 1} failed, trying again...")
27
+ return None
28
+
29
+ def main():
30
+ print("=" * 50)
31
+ print("Downloading MagicArticulate Model Checkpoints")
32
+ print("=" * 50)
33
+
34
+ # Create directories
35
+ Path("skeleton_ckpt").mkdir(exist_ok=True)
36
+ Path("third_partys/Michelangelo/checkpoints/aligned_shape_latents").mkdir(
37
+ parents=True, exist_ok=True
38
+ )
39
+
40
+ # Download Michelangelo checkpoint (required dependency)
41
+ print("\n[1/3] Downloading Michelangelo checkpoint...")
42
+ try:
43
+ download_with_retry(
44
+ repo_id="Maikou/Michelangelo",
45
+ filename="checkpoints/aligned_shape_latents/shapevae-256.ckpt",
46
+ local_dir="third_partys/Michelangelo"
47
+ )
48
+ except Exception as e:
49
+ print(f"Warning: Michelangelo download failed: {e}")
50
+ print("This may affect some features.")
51
+
52
+ # Download MagicArticulate spatial checkpoint
53
+ print("\n[2/3] Downloading MagicArticulate spatial checkpoint...")
54
+ download_with_retry(
55
+ repo_id="Seed3D/MagicArticulate",
56
+ filename="skeleton_ckpt/checkpoint_trainonv2_spatial.pth",
57
+ local_dir="."
58
+ )
59
+
60
+ # Download MagicArticulate hierarchical checkpoint
61
+ print("\n[3/3] Downloading MagicArticulate hierarchical checkpoint...")
62
+ download_with_retry(
63
+ repo_id="Seed3D/MagicArticulate",
64
+ filename="skeleton_ckpt/checkpoint_trainonv2_hier.pth",
65
+ local_dir="."
66
+ )
67
+
68
+ print("\n" + "=" * 50)
69
+ print("✓ All downloads completed successfully!")
70
+ print("=" * 50)
71
+
72
+ if __name__ == "__main__":
73
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import trimesh
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import time
7
+
8
+ # Import from MagicArticulate
9
+ from skeleton_models.skeletongen import SkeletonGPT
10
+ from data_utils.save_npz import normalize_to_unit_cube
11
+ from utils.mesh_to_pc import MeshProcessor
12
+ from utils.save_utils import (
13
+ pred_joints_and_bones,
14
+ save_skeleton_to_txt,
15
+ merge_duplicate_joints_and_fix_bones,
16
+ save_skeleton_obj,
17
+ save_mesh
18
+ )
19
+
20
+ class SkeletonInferencer:
21
+ """Wrapper class for skeleton generation inference"""
22
+
23
+ def __init__(self, pretrained_weights, device="cuda", precision="fp16"):
24
+ self.device = device
25
+ self.precision = precision
26
+
27
+ # Create args object
28
+ class Args:
29
+ def __init__(self):
30
+ self.llm = "facebook/opt-350m"
31
+ self.pad_id = -1
32
+ self.n_discrete_size = 128
33
+ self.n_max_bones = 100
34
+ self.num_beams = 1
35
+ self.seed = 0
36
+
37
+ self.args = Args()
38
+
39
+ # Load model
40
+ print(f"Loading model from {pretrained_weights}...")
41
+ self.model = SkeletonGPT(self.args).to(device)
42
+
43
+ pkg = torch.load(pretrained_weights, map_location=torch.device("cpu"))
44
+ self.model.load_state_dict(pkg["model"])
45
+ self.model.eval()
46
+
47
+ # Set precision
48
+ if precision == "fp16" and device == "cuda":
49
+ self.model = self.model.half()
50
+
51
+ print("Model loaded successfully!")
52
+
53
+ @torch.no_grad()
54
+ def infer(
55
+ self,
56
+ input_path,
57
+ output_dir,
58
+ input_pc_num=8192,
59
+ apply_marching_cubes=False,
60
+ octree_depth=7,
61
+ sequence_type="spatial"
62
+ ):
63
+ """
64
+ Run inference on a single mesh file
65
+
66
+ Returns:
67
+ dict: Results including paths and statistics
68
+ """
69
+ start_time = time.time()
70
+
71
+ output_dir = Path(output_dir)
72
+ output_dir.mkdir(parents=True, exist_ok=True)
73
+
74
+ # Load mesh
75
+ mesh = trimesh.load(input_path, force='mesh')
76
+
77
+ # Convert to point cloud
78
+ if apply_marching_cubes:
79
+ pc_list = MeshProcessor.convert_meshes_to_point_clouds(
80
+ [mesh], input_pc_num,
81
+ apply_marching_cubes=True,
82
+ octree_depth=octree_depth
83
+ )
84
+ pc_normal = pc_list[0]
85
+ else:
86
+ # Simple sampling
87
+ points, face_indices = trimesh.sample.sample_surface(mesh, input_pc_num)
88
+ normals = mesh.face_normals[face_indices]
89
+ pc_normal = np.concatenate([points, normals], axis=-1)
90
+
91
+ # Normalize point cloud
92
+ pc_coor = pc_normal[:, :3]
93
+ normals = pc_normal[:, 3:]
94
+ pc_coor, center, scale = normalize_to_unit_cube(pc_coor, scale_factor=0.9995)
95
+
96
+ # Prepare transform parameters
97
+ bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])
98
+ pc_center = (bounds[0] + bounds[1]) / 2
99
+ pc_scale = (bounds[1] - bounds[0]).max() + 1e-5
100
+
101
+ transform_params = torch.tensor([
102
+ center[0], center[1], center[2], scale,
103
+ pc_center[0], pc_center[1], pc_center[2], pc_scale
104
+ ], dtype=torch.float32)
105
+
106
+ # Prepare batch data
107
+ pc_normal_normalized = np.concatenate([pc_coor, normals], axis=-1)
108
+ batch_data = {
109
+ 'pc_normal': torch.from_numpy(pc_normal_normalized).half().unsqueeze(0).to(self.device),
110
+ 'transform_params': transform_params.unsqueeze(0),
111
+ 'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0),
112
+ 'faces': torch.from_numpy(mesh.faces).unsqueeze(0),
113
+ 'file_name': [Path(input_path).stem]
114
+ }
115
+
116
+ # Generate skeleton
117
+ pred_bone_coords = self.model.generate(batch_data)
118
+
119
+ # Process results
120
+ file_name = Path(input_path).stem
121
+ skeleton = pred_bone_coords[0].cpu().numpy()
122
+ pred_joints, pred_bones = pred_joints_and_bones(skeleton.squeeze())
123
+
124
+ # Post-process
125
+ hier_order = (sequence_type == "hierarchical")
126
+ if hier_order and len(pred_bones) > 0:
127
+ pred_root_index = pred_bones[0][0]
128
+ pred_joints, pred_bones, pred_root_index = merge_duplicate_joints_and_fix_bones(
129
+ pred_joints, pred_bones, root_index=pred_root_index
130
+ )
131
+ else:
132
+ pred_joints, pred_bones = merge_duplicate_joints_and_fix_bones(
133
+ pred_joints, pred_bones
134
+ )
135
+ pred_root_index = None
136
+
137
+ # Denormalize for saving
138
+ trans = transform_params[:3].numpy()
139
+ scale_val = transform_params[3].item()
140
+ pc_trans = transform_params[4:7].numpy()
141
+ pc_scale_val = transform_params[7].item()
142
+
143
+ pred_joints_denorm = pred_joints * pc_scale_val + pc_trans
144
+ pred_joints_denorm = pred_joints_denorm / scale_val + trans
145
+
146
+ # Save files
147
+ pred_rig_filename = output_dir / f"{file_name}_pred.txt"
148
+ pred_skel_filename = output_dir / f"{file_name}_skel.obj"
149
+ mesh_filename = output_dir / f"{file_name}_mesh.obj"
150
+
151
+ save_skeleton_to_txt(
152
+ pred_joints_denorm, pred_bones, pred_root_index,
153
+ hier_order, mesh.vertices, str(pred_rig_filename)
154
+ )
155
+
156
+ save_skeleton_obj(
157
+ pred_joints, pred_bones, str(pred_skel_filename),
158
+ pred_root_index if hier_order else None,
159
+ use_cone=hier_order
160
+ )
161
+
162
+ # Save normalized mesh
163
+ vertices_norm = (mesh.vertices - trans) * scale_val
164
+ vertices_norm = (vertices_norm - pc_trans) / pc_scale_val
165
+ save_mesh(vertices_norm, mesh.faces, str(mesh_filename))
166
+
167
+ elapsed_time = time.time() - start_time
168
+
169
+ return {
170
+ 'skeleton_file': str(pred_skel_filename),
171
+ 'rig_file': str(pred_rig_filename),
172
+ 'mesh_file': str(mesh_filename),
173
+ 'num_joints': len(pred_joints),
174
+ 'num_bones': len(pred_bones),
175
+ 'time': elapsed_time
176
+ }
packages.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ libgl1-mesa-glx
2
+ libglib2.0-0
3
+ libsm6
4
+ libxext6
5
+ libxrender-dev
6
+ libgomp1