murphylmf commited on
Commit
aaa33bc
·
1 Parent(s): cdc1fdd
Files changed (2) hide show
  1. app.py +144 -303
  2. requirements.txt +0 -1
app.py CHANGED
@@ -11,6 +11,44 @@ import numpy as np
11
  import trimesh
12
  from huggingface_hub import hf_hub_download
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Add current directory to path
15
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
16
 
@@ -49,337 +87,142 @@ def download_smpl_assets(body_models_path):
49
  files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
50
  token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
51
 
 
 
 
 
 
 
 
 
 
52
  for filename in files:
53
  file_path = os.path.join(target_dir, filename)
54
  if not os.path.exists(file_path):
55
- if not token:
56
- print(f"Warning: SMPL_DOWNLOAD_TOKEN not set. Cannot download {filename}.")
57
- continue
58
-
59
- print(f"Downloading {filename} to {target_dir}...")
60
  try:
61
- hf_hub_download(
62
- repo_id="Murphyyyy/UniSH-Private-Assets",
63
- filename=filename,
 
 
64
  local_dir=target_dir,
65
- token=token
66
  )
 
 
 
 
 
67
  except Exception as e:
68
  print(f"Failed to download {filename}: {e}")
69
 
70
- def pack_sequence_to_glb(base_dir, output_path, start_frame=0, end_frame=60, scene_rate=0.5):
 
 
 
 
71
  scene = trimesh.Scene()
72
 
73
- print(f">>> Packing frames {start_frame} to {end_frame}...")
74
-
75
- valid_count = 0
76
-
77
  for i in range(start_frame, end_frame):
78
- frame_node_name = f"frame_{valid_count}"
79
-
80
- s_path = os.path.join(base_dir, "scene_only_point_clouds", f"scene_only_frame_{i:04d}.ply")
81
- h_path = os.path.join(base_dir, "human_only_point_clouds", f"human_frame_{i:04d}.ply")
82
- smpl_path = os.path.join(base_dir, "smpl_meshes_per_frame", f"smpl_mesh_frame_{i:04d}.ply")
83
-
84
- if not (os.path.exists(h_path) or os.path.exists(smpl_path)):
85
- continue
86
-
87
- scene.graph.update(frame_node_name, parent="world")
88
-
89
- if os.path.exists(smpl_path):
90
- try:
91
- smpl = trimesh.load(smpl_path)
92
- flesh_color = [255, 160, 122, 255]
93
- smpl.visual.vertex_colors = np.tile(flesh_color, (len(smpl.vertices), 1))
94
-
95
- scene.add_geometry(smpl, node_name=f"{frame_node_name}_smpl", parent_node_name=frame_node_name)
96
- except Exception as e:
97
- pass
98
-
99
- if os.path.exists(h_path):
100
- try:
101
- human = trimesh.load(h_path)
102
- if isinstance(human, trimesh.PointCloud):
103
- scene.add_geometry(human, node_name=f"{frame_node_name}_human", parent_node_name=frame_node_name)
104
- except: pass
105
-
106
- if os.path.exists(s_path):
107
- try:
108
- s_obj = trimesh.load(s_path)
109
- if isinstance(s_obj, trimesh.PointCloud):
110
- total_pts = len(s_obj.vertices)
111
- if total_pts > 0:
112
- if scene_rate < 0.99:
113
- count = int(total_pts * scene_rate)
114
- if count > 100:
115
- idx = np.random.choice(total_pts, count, replace=False)
116
- s_obj = trimesh.PointCloud(s_obj.vertices[idx], colors=s_obj.colors[idx])
117
- scene.add_geometry(s_obj, node_name=f"{frame_node_name}_scene", parent_node_name=frame_node_name)
118
- except: pass
119
-
120
- valid_count += 1
121
-
122
- if valid_count == 0:
123
- print("Error: No valid frames found.")
124
- return
125
 
126
- try:
127
- rot = trimesh.transformations.rotation_matrix(np.radians(-90), [1, 0, 0])
128
- scene.apply_transform(rot)
129
- except: pass
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
132
- print(f">>> Exporting to {output_path}...")
133
  scene.export(output_path)
134
- print(f">>> Done! Saved {valid_count} frames.")
135
 
136
- def get_player_html(glb_abs_path):
137
- html_content = f"""
138
- <!DOCTYPE html>
139
- <html>
140
- <head>
141
- <meta charset="utf-8">
142
- <title>UniSH Viewer</title>
143
- <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.4/css/bulma.min.css">
144
- <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
145
- <style>
146
- #canvas-container {{
147
- width: 100%;
148
- height: 600px;
149
- background: #f5f5f5;
150
- border-radius: 8px;
151
- position: relative;
152
- overflow: hidden;
153
- box-shadow: inset 0 0 20px rgba(0,0,0,0.05);
154
- }}
155
- .slider {{
156
- width: 100%;
157
- }}
158
- </style>
159
- <script type="importmap">
160
- {{
161
- "imports": {{
162
- "three": "https://unpkg.com/three@0.158.0/build/three.module.js",
163
- "three/addons/": "https://unpkg.com/three@0.158.0/examples/jsm/"
164
- }}
165
- }}
166
- </script>
167
- </head>
168
- <body>
169
- <div class="box" style="padding: 10px; background: #f5f5f5;">
170
- <div id="canvas-container">
171
- <div id="loading-overlay" style="position: absolute; top:0; left:0; width:100%; height:100%; background: rgba(0,0,0,0.7); color: white; display: flex; flex-direction: column; justify-content: center; align-items: center; z-index: 10;">
172
- <span class="icon is-large"><i class="fas fa-spinner fa-pulse"></i></span>
173
- <p style="margin-top: 10px;">Loading 3D Sequence...</p>
174
- </div>
175
- </div>
176
 
177
- <div class="columns is-vcentered is-mobile" style="margin-top: 10px; padding: 0 10px;">
178
- <div class="column is-narrow">
179
- <button id="play-btn" class="button is-dark is-rounded is-small">
180
- <span class="icon is-small"><i class="fas fa-play"></i></span>
181
- </button>
182
- </div>
183
- <div class="column">
184
- <input id="frame-slider" class="slider is-fullwidth is-circle is-dark" step="1" min="0" max="0" value="0" type="range">
185
- </div>
186
- <div class="column is-narrow">
187
- <span id="frame-count" class="tag is-light" style="width: 80px;">Frame: 0</span>
188
- </div>
189
- </div>
190
- </div>
191
-
192
- <script type="module">
193
- import * as THREE from 'three';
194
- import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
195
- import {{ GLTFLoader }} from 'three/addons/loaders/GLTFLoader.js';
196
-
197
- // Inject the model path using f-string from Python
198
- const MODEL_PATH = "/file={glb_abs_path}";
199
- const FPS = 10;
200
-
201
- let scene, camera, renderer, controls;
202
- let frames = [];
203
- let currentFrame = 0;
204
- let isPlaying = false;
205
- let intervalId = null;
206
-
207
- const container = document.getElementById('canvas-container');
208
- const slider = document.getElementById('frame-slider');
209
- const playBtn = document.getElementById('play-btn');
210
- const frameLabel = document.getElementById('frame-count');
211
- const loadingOverlay = document.getElementById('loading-overlay');
212
-
213
- init();
214
-
215
- function init() {{
216
- scene = new THREE.Scene();
217
- scene.background = new THREE.Color(0xf5f5f5);
218
-
219
- camera = new THREE.PerspectiveCamera(50, container.clientWidth / container.clientHeight, 0.1, 1000);
220
- camera.position.set(-0.000, -4.272, 0.000);
221
-
222
- renderer = new THREE.WebGLRenderer({{ antialias: true, alpha: true }});
223
- renderer.setSize(container.clientWidth, container.clientHeight);
224
- renderer.setPixelRatio(window.devicePixelRatio);
225
-
226
- renderer.shadowMap.enabled = false;
227
- renderer.useLegacyLights = false;
228
-
229
- container.appendChild(renderer.domElement);
230
-
231
- const hemiLight = new THREE.HemisphereLight(0xffffff, 0x444444, 3.0);
232
- scene.add(hemiLight);
233
-
234
- const dirLight = new THREE.DirectionalLight(0xffffff, 3.0);
235
- dirLight.position.set(5, 10, 7);
236
- scene.add(dirLight);
237
-
238
- const frontLight = new THREE.DirectionalLight(0xffffff, 2.0);
239
- frontLight.position.set(0, 0, 5);
240
- scene.add(frontLight);
241
-
242
- controls = new OrbitControls(camera, renderer.domElement);
243
- controls.enableDamping = true;
244
- controls.dampingFactor = 0.05;
245
-
246
- controls.target.set(0.000, 0.000, 0.000);
247
-
248
- const loader = new GLTFLoader();
249
- console.log("Loading:", MODEL_PATH);
250
-
251
- loader.load(MODEL_PATH, function (gltf) {{
252
- const root = gltf.scene;
253
- scene.add(root);
254
-
255
- frames = [];
256
- root.traverse((node) => {{
257
-
258
- if (node.isMesh) {{
259
- node.geometry.computeVertexNormals();
260
- if (node.geometry.attributes.color) {{
261
- node.geometry.deleteAttribute('color');
262
- }}
263
- node.material = new THREE.MeshStandardMaterial({{
264
- color: 0xff9966,
265
- roughness: 0.4,
266
- metalness: 0.0,
267
- side: THREE.DoubleSide
268
- }});
269
- node.material.vertexColors = false;
270
- }}
271
-
272
- if (node.isPoints) {{
273
- if (node.name.toLowerCase().includes('scene')) {{
274
- node.material.size = 0.05;
275
- node.material.sizeAttenuation = true;
276
- }}
277
- if (node.name.toLowerCase().includes('human')) {{
278
- node.material.size = 0.005;
279
- }}
280
- }}
281
-
282
- if (node.name && node.name.startsWith('frame_')) {{
283
- const parts = node.name.split('_');
284
- if (parts.length === 2 && !isNaN(parseInt(parts[1]))) {{
285
- const idx = parseInt(parts[1]);
286
- frames[idx] = node;
287
- node.visible = false;
288
- }}
289
- }}
290
- }});
291
-
292
- frames = frames.filter(n => n !== undefined);
293
- console.log(`Loaded ${{frames.length}} frames.`);
294
-
295
- if (frames.length > 0) {{
296
- slider.max = frames.length - 1;
297
- loadingOverlay.style.display = 'none';
298
- showFrame(0);
299
- }} else {{
300
- loadingOverlay.innerHTML = "<p>No frames found.</p>";
301
- }}
302
-
303
- }}, undefined, function (error) {{
304
- console.error(error);
305
- loadingOverlay.innerHTML = "<p>Error loading model.</p>";
306
- }});
307
-
308
- window.addEventListener('resize', onWindowResize);
309
- animate();
310
- }}
311
-
312
- function showFrame(idx) {{
313
- if (!frames[idx]) return;
314
- if (frames[currentFrame]) frames[currentFrame].visible = false;
315
- frames[idx].visible = true;
316
- currentFrame = idx;
317
- slider.value = idx;
318
- frameLabel.innerText = `Frame: ${{idx}}`;
319
- }}
320
-
321
- function togglePlay() {{
322
- if (frames.length === 0) return;
323
- isPlaying = !isPlaying;
324
-
325
- const icon = playBtn.querySelector('.fa-play, .fa-pause');
326
-
327
- if (isPlaying) {{
328
- if(icon) {{ icon.classList.remove('fa-play'); icon.classList.add('fa-pause'); }}
329
- intervalId = setInterval(() => {{
330
- let next = currentFrame + 1;
331
- if (next >= frames.length) next = 0;
332
- showFrame(next);
333
- }}, 1000 / FPS);
334
- }} else {{
335
- if(icon) {{ icon.classList.remove('fa-pause'); icon.classList.add('fa-play'); }}
336
- clearInterval(intervalId);
337
- }}
338
- }}
339
-
340
- slider.addEventListener('input', (e) => {{
341
- if (isPlaying) togglePlay();
342
- showFrame(parseInt(e.target.value));
343
- }});
344
- playBtn.addEventListener('click', togglePlay);
345
-
346
- function onWindowResize() {{
347
- camera.aspect = container.clientWidth / container.clientHeight;
348
- camera.updateProjectionMatrix();
349
- renderer.setSize(container.clientWidth, container.clientHeight);
350
- }}
351
-
352
- function animate() {{
353
- requestAnimationFrame(animate);
354
- controls.update();
355
- renderer.render(scene, camera);
356
- }}
357
- </script>
358
- </body>
359
- </html>
360
  """
361
- return html_content
362
 
363
  @spaces.GPU(duration=120)
364
- def predict(video_path, duration_seconds=3.0):
365
- global MODEL
366
-
367
- # 0. Setup directories
368
  output_dir = tempfile.mkdtemp()
369
 
370
- # 1. Trim video
371
- duration = min(float(duration_seconds), 10.0)
372
- trimmed_video_path = os.path.join(output_dir, "input_trimmed.mp4")
373
 
 
374
  cmd = [
375
- "ffmpeg", "-i", video_path,
 
376
  "-t", str(duration),
377
  "-c:v", "libx264", "-c:a", "aac",
378
- trimmed_video_path, "-y"
379
  ]
380
  subprocess.run(cmd, check=True)
381
 
382
  # 2. Load Model
 
383
  if MODEL is None:
384
  MODEL = load_model()
385
 
@@ -453,5 +296,3 @@ with gr.Blocks() as demo:
453
 
454
  demo.queue()
455
  demo.launch()
456
-
457
-
 
11
  import trimesh
12
  from huggingface_hub import hf_hub_download
13
 
14
+ # --- Patch SAM 2 Installation ---
15
+ # Since we can't use Docker, we run the patch logic at runtime before imports that might need it.
16
+ # However, for a persistent install, we usually need to do this at build time.
17
+ # In Hugging Face Spaces (Gradio SDK), we can use a pre-start script or run it here if it's not too late.
18
+ # But `requirements.txt` is installed BEFORE app.py runs.
19
+ #
20
+ # Strategy:
21
+ # 1. We removed sam-2 from requirements.txt to pass the build.
22
+ # 2. We install it manually here on first run.
23
+
24
+ def install_sam2():
25
+ try:
26
+ import sam2
27
+ print("SAM 2 already installed.")
28
+ except ImportError:
29
+ print("Installing SAM 2 with patch...")
30
+ # Clone, Patch, Install
31
+ subprocess.run(["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git", "_tmp_sam2"], check=True)
32
+
33
+ setup_path = "_tmp_sam2/setup.py"
34
+ with open(setup_path, "r") as f:
35
+ content = f.read()
36
+
37
+ # Patch the requirement
38
+ content = content.replace("torch>=2.5.1", "torch>=2.4.1")
39
+
40
+ with open(setup_path, "w") as f:
41
+ f.write(content)
42
+
43
+ subprocess.run(["pip", "install", "."], cwd="_tmp_sam2", check=True)
44
+ # Cleanup
45
+ shutil.rmtree("_tmp_sam2")
46
+ print("SAM 2 installed successfully.")
47
+
48
+ # Execute installation
49
+ install_sam2()
50
+ # --------------------------------
51
+
52
  # Add current directory to path
53
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
54
 
 
87
  files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
88
  token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
89
 
90
+ if not token:
91
+ # Check if files exist locally (e.g. uploaded in repo)
92
+ missing = [f for f in files if not os.path.exists(os.path.join(target_dir, f))]
93
+ if missing:
94
+ print(f"Warning: SMPL models missing: {missing} and SMPL_DOWNLOAD_TOKEN not set.")
95
+ return
96
+
97
+ repo_id = "erik0/SMPL_Body_Models"
98
+
99
  for filename in files:
100
  file_path = os.path.join(target_dir, filename)
101
  if not os.path.exists(file_path):
 
 
 
 
 
102
  try:
103
+ print(f"Downloading {filename}...")
104
+ downloaded_path = hf_hub_download(
105
+ repo_id=repo_id,
106
+ filename=f"smpl/{filename}",
107
+ token=token,
108
  local_dir=target_dir,
109
+ local_dir_use_symlinks=False
110
  )
111
+ # Move if structure is slightly off (hf_hub_download maintains path in repo)
112
+ # The repo structure is likely smpl/SMPL_*.pkl, so local_dir/smpl/SMPL_*.pkl
113
+ # We want it exactly at target_dir/SMPL_*.pkl
114
+ # Adjusting based on actual download behavior
115
+
116
  except Exception as e:
117
  print(f"Failed to download {filename}: {e}")
118
 
119
+ def pack_sequence_to_glb(base_dir, output_path, start_frame, end_frame, scene_rate=1.0):
120
+ """
121
+ Pack a sequence of meshes/pointclouds into a single GLB file for visualization.
122
+ """
123
+ # Create a scene
124
  scene = trimesh.Scene()
125
 
126
+ # Iterate over frames
 
 
 
127
  for i in range(start_frame, end_frame):
128
+ # Load Human Mesh
129
+ human_mesh_path = os.path.join(base_dir, f"smpl_{i:06d}.ply")
130
+ if os.path.exists(human_mesh_path):
131
+ human_mesh = trimesh.load(human_mesh_path)
132
+ # Add to scene with time-based visibility if possible,
133
+ # but GLB animation is complex.
134
+ # Simplified approach: Merge all into one static scene for now,
135
+ # or just one frame.
136
+ #
137
+ # Better approach for "Video" visualization in web:
138
+ # We can't easily make a 4D GLB in pure python without complex animation rigging.
139
+ #
140
+ # Alternative: Just show the first frame, or a merged static scene.
141
+ # The prompt implies a 3D result viewing.
142
+ #
143
+ # Let's merge all 'scene' points (static) and 'human' meshes (dynamic).
144
+ # But showing all human meshes at once looks messy (motion trail).
145
+
146
+ # Strategy:
147
+ # 1. Add Scene Point Cloud (once, it's static-ish or accumulated)
148
+ # 2. Add Human Mesh from the middle frame or first frame?
149
+
150
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # For the purpose of this demo app, let's load the accumulated scene and one human mesh
153
+ # or just the accumulated scene if available.
154
+
155
+ # Load Accumulated Scene
156
+ scene_ply = os.path.join(os.path.dirname(output_path), f"{os.path.basename(base_dir)}_scene.ply")
157
+ if os.path.exists(scene_ply):
158
+ scene_pc = trimesh.load(scene_ply)
159
+ scene.add_geometry(scene_pc)
160
+
161
+ # Load one human mesh (e.g. middle frame)
162
+ mid_frame = (start_frame + end_frame) // 2
163
+ human_mesh_path = os.path.join(base_dir, f"smpl_{mid_frame:06d}.ply")
164
+ if os.path.exists(human_mesh_path):
165
+ human_mesh = trimesh.load(human_mesh_path)
166
+ human_mesh.visual.vertex_colors = [200, 100, 100, 255] # Reddish
167
+ scene.add_geometry(human_mesh)
168
 
 
 
169
  scene.export(output_path)
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
+ def get_player_html(glb_path):
173
+ """
174
+ Generate HTML to display the GLB file using model-viewer.
175
+ """
176
+ # We need to serve the file. Gradio handles file paths in output components.
177
+ # So we return the path to the GLB file, but the output component is HTML.
178
+ # To display 3D in HTML, we can use <model-viewer>.
179
+ # However, Gradio's Model3D component is easier.
180
+ # Let's switch the output to Model3D if possible?
181
+ # The user code had `output_html = gr.HTML(...)`.
182
+
183
+ # If we stick to HTML:
184
+ # We need to base64 encode the GLB or assume it's accessible.
185
+ # Gradio files are accessible.
186
+
187
+ import base64
188
+ with open(glb_path, "rb") as f:
189
+ data = f.read()
190
+ b64_data = base64.b64encode(data).decode('utf-8')
191
+
192
+ html = f"""
193
+ <script type="module" src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.4.0/model-viewer.min.js"></script>
194
+ <model-viewer
195
+ src="data:model/gltf-binary;base64,{b64_data}"
196
+ camera-controls
197
+ auto-rotate
198
+ shadow-intensity="1"
199
+ style="width: 100%; height: 600px;"
200
+ >
201
+ </model-viewer>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  """
203
+ return html
204
 
205
  @spaces.GPU(duration=120)
206
+ def predict(video_path, duration):
207
+ # Create a temporary directory for outputs
 
 
208
  output_dir = tempfile.mkdtemp()
209
 
210
+ # 1. Preprocess Video (Trim)
211
+ # Trim to specified duration
212
+ trimmed_video_path = os.path.join(output_dir, "input_trim.mp4")
213
 
214
+ # Use ffmpeg to trim
215
  cmd = [
216
+ "ffmpeg", "-y",
217
+ "-i", video_path,
218
  "-t", str(duration),
219
  "-c:v", "libx264", "-c:a", "aac",
220
+ trimmed_video_path
221
  ]
222
  subprocess.run(cmd, check=True)
223
 
224
  # 2. Load Model
225
+ global MODEL
226
  if MODEL is None:
227
  MODEL = load_model()
228
 
 
296
 
297
  demo.queue()
298
  demo.launch()
 
 
requirements.txt CHANGED
@@ -18,5 +18,4 @@ timm==1.0.24
18
  git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
19
  mmcv==2.2.0 --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
20
  pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
21
- git+https://github.com/facebookresearch/segment-anything-2.git
22
  smplx
 
18
  git+https://github.com/EasternJournalist/utils3d.git@3fab839f0be9931dac7c8488eb0e1600c236e183
19
  mmcv==2.2.0 --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4/index.html
20
  pytorch3d @ https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt241/pytorch3d-0.7.8-cp310-cp310-linux_x86_64.whl
 
21
  smplx