rdz-falcon commited on
Commit
1872ff5
·
verified ·
1 Parent(s): 625c5a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -221
app.py CHANGED
@@ -6,7 +6,6 @@ Uses PyRender for high-quality avatar visualization
6
  # IMPORTANT: Set OpenGL platform BEFORE any OpenGL imports (for headless rendering)
7
  import os
8
  os.environ["PYOPENGL_PLATFORM"] = "egl"
9
-
10
  import sys
11
  import re
12
  import json
@@ -15,12 +14,9 @@ import warnings
15
  import tempfile
16
  import uuid
17
  from pathlib import Path
18
-
19
  import torch
20
  import numpy as np
21
-
22
  warnings.filterwarnings("ignore")
23
-
24
  # =====================================================================
25
  # Configuration for HuggingFace Spaces
26
  # =====================================================================
@@ -29,19 +25,15 @@ DATA_DIR = os.path.join(WORK_DIR, "data")
29
  OUTPUT_DIR = os.path.join(WORK_DIR, "outputs")
30
  os.makedirs(DATA_DIR, exist_ok=True)
31
  os.makedirs(OUTPUT_DIR, exist_ok=True)
32
-
33
  # Path definitions
34
  DATASET_PATH = os.path.join(DATA_DIR, "motion_llm_dataset.json")
35
  VQVAE_CHECKPOINT = os.path.join(DATA_DIR, "vqvae_model.pt")
36
  STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
37
  SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
38
-
39
  # HuggingFace model config
40
  HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
41
  HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
42
-
43
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
-
45
  # Generation parameters
46
  M_START = "<M_START>"
47
  M_END = "<M_END>"
@@ -49,7 +41,6 @@ PAD_TOKEN = "<PAD>"
49
  INFERENCE_TEMPERATURE = 0.7
50
  INFERENCE_TOP_K = 50
51
  INFERENCE_REPETITION_PENALTY = 1.2
52
-
53
  # VQ-VAE parameters
54
  SMPL_DIM = 182
55
  CODEBOOK_SIZE = 512
@@ -58,18 +49,15 @@ VQ_ARGS = dict(
58
  width=512, depth=3, down_t=2, stride_t=2,
59
  dilation_growth_rate=3, activation='relu', norm=None, quantizer="ema_reset"
60
  )
61
-
62
  PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
63
- PARAM_NAMES = ["shape", "body_pose", "lhand_pose", "rhand_pose",
64
- "jaw_pose", "expression", "root_pose", "cam_trans"]
65
-
66
  # Visualization defaults
67
  AVATAR_COLOR = (0.36, 0.78, 0.36, 1.0) # Green color as RGBA
68
  VIDEO_FPS = 15
69
  VIDEO_SLOWDOWN = 2
70
  FRAME_WIDTH = 544 # Must be divisible by 16 for video codec compatibility
71
  FRAME_HEIGHT = 720
72
-
73
  # =====================================================================
74
  # Install/Import Dependencies
75
  # =====================================================================
@@ -78,13 +66,11 @@ try:
78
  except ImportError:
79
  os.system("pip install -q gradio>=4.0.0")
80
  import gradio as gr
81
-
82
  try:
83
  import smplx
84
  except ImportError:
85
  os.system("pip install -q smplx==0.1.28")
86
  import smplx
87
-
88
  # PyRender for high-quality rendering
89
  PYRENDER_AVAILABLE = False
90
  try:
@@ -94,16 +80,13 @@ try:
94
  PYRENDER_AVAILABLE = True
95
  except ImportError:
96
  pass
97
-
98
  try:
99
  import imageio
100
  except ImportError:
101
  os.system("pip install -q imageio[ffmpeg]")
102
  import imageio
103
-
104
  from transformers import AutoModelForCausalLM, AutoTokenizer
105
  import torch.nn.functional as F
106
-
107
  # =====================================================================
108
  # Import VQ-VAE architecture
109
  # =====================================================================
@@ -113,13 +96,11 @@ if parent_dir not in sys.path:
113
  sys.path.insert(0, parent_dir)
114
  if current_dir not in sys.path:
115
  sys.path.insert(0, current_dir)
116
-
117
  try:
118
  from mGPT.archs.mgpt_vq import VQVae
119
  except ImportError as e:
120
  print(f"Warning: Could not import VQVae: {e}")
121
  VQVae = None
122
-
123
  # =====================================================================
124
  # Global Cache
125
  # =====================================================================
@@ -131,10 +112,8 @@ _model_cache = {
131
  "stats": (None, None),
132
  "initialized": False
133
  }
134
-
135
  _word_pid_map = {}
136
  _example_cache = {}
137
-
138
  # =====================================================================
139
  # PyRender Setup
140
  # =====================================================================
@@ -143,12 +122,12 @@ def ensure_pyrender():
143
  global PYRENDER_AVAILABLE, trimesh, pyrender, Image, ImageDraw, ImageFont
144
  if PYRENDER_AVAILABLE:
145
  return True
146
-
147
  print("Installing pyrender dependencies...")
148
  if os.path.exists("/etc/debian_version"):
149
  os.system("apt-get update -qq && apt-get install -qq -y libegl1-mesa-dev libgles2-mesa-dev > /dev/null 2>&1")
150
  os.system("pip install -q trimesh pyrender PyOpenGL PyOpenGL_accelerate Pillow")
151
-
152
  try:
153
  import trimesh
154
  import pyrender
@@ -158,23 +137,22 @@ def ensure_pyrender():
158
  except ImportError as e:
159
  print(f"Could not install pyrender: {e}")
160
  return False
161
-
162
  # =====================================================================
163
  # Dataset Loading - Word to PID mapping
164
  # =====================================================================
165
  def load_word_pid_mapping():
166
  """Load the dataset and build word -> PIDs mapping."""
167
  global _word_pid_map
168
-
169
  if not os.path.exists(DATASET_PATH):
170
  print(f"Dataset not found: {DATASET_PATH}")
171
  return
172
-
173
  print(f"Loading dataset from: {DATASET_PATH}")
174
  try:
175
  with open(DATASET_PATH, 'r', encoding='utf-8') as f:
176
  data = json.load(f)
177
-
178
  for entry in data:
179
  word = entry.get('word', '').lower()
180
  pid = entry.get('participant_id', '')
@@ -182,21 +160,17 @@ def load_word_pid_mapping():
182
  if word not in _word_pid_map:
183
  _word_pid_map[word] = set()
184
  _word_pid_map[word].add(pid)
185
-
186
  for word in _word_pid_map:
187
  _word_pid_map[word] = sorted(list(_word_pid_map[word]))
188
-
189
  print(f"Loaded {len(_word_pid_map)} unique words from dataset")
190
  except Exception as e:
191
  print(f"Error loading dataset: {e}")
192
-
193
-
194
  def get_pids_for_word(word: str) -> list:
195
  """Get valid PIDs for a word from the dataset."""
196
  word = word.lower().strip()
197
  return _word_pid_map.get(word, [])
198
-
199
-
200
  def get_random_pids_for_word(word: str, count: int = 2) -> list:
201
  """Get random PIDs for a word. Returns up to 'count' PIDs."""
202
  pids = get_pids_for_word(word)
@@ -205,29 +179,26 @@ def get_random_pids_for_word(word: str, count: int = 2) -> list:
205
  if len(pids) <= count:
206
  return pids
207
  return random.sample(pids, count)
208
-
209
-
210
  def get_example_words_with_pids(count: int = 3) -> list:
211
  """Get example words with valid PIDs from dataset."""
212
  examples = []
213
  preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
214
-
215
  for word in preferred:
216
  pids = get_pids_for_word(word)
217
  if pids:
218
  examples.append((word, pids[0]))
219
  if len(examples) >= count:
220
  break
221
-
222
  if len(examples) < count:
223
  available = [w for w in _word_pid_map.keys() if w not in [e[0] for e in examples]]
224
  random.shuffle(available)
225
  for word in available[:count - len(examples)]:
226
  pids = _word_pid_map[word]
227
  examples.append((word, pids[0]))
228
-
229
- return examples
230
 
 
231
  # =====================================================================
232
  # VQ-VAE Wrapper
233
  # =====================================================================
@@ -240,14 +211,13 @@ class MotionGPT_VQVAE_Wrapper(torch.nn.Module):
240
  nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
241
  output_emb_width=code_dim, **kwargs
242
  )
243
-
244
  # =====================================================================
245
  # Model Loading Functions
246
  # =====================================================================
247
  def load_llm_model():
248
  print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
249
  token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
250
-
251
  tokenizer = AutoTokenizer.from_pretrained(
252
  HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token
253
  )
@@ -263,8 +233,6 @@ def load_llm_model():
263
  model.eval()
264
  print(f"LLM loaded (vocab size: {len(tokenizer)})")
265
  return model, tokenizer
266
-
267
-
268
  def load_vqvae_model():
269
  if not os.path.exists(VQVAE_CHECKPOINT):
270
  print(f"VQ-VAE checkpoint not found: {VQVAE_CHECKPOINT}")
@@ -277,8 +245,6 @@ def load_vqvae_model():
277
  model.eval()
278
  print(f"VQ-VAE loaded")
279
  return model
280
-
281
-
282
  def load_stats():
283
  if not os.path.exists(STATS_PATH):
284
  return None, None
@@ -287,8 +253,6 @@ def load_stats():
287
  if torch.is_tensor(mean): mean = mean.cpu().numpy()
288
  if torch.is_tensor(std): std = std.cpu().numpy()
289
  return mean, std
290
-
291
-
292
  def load_smplx_model():
293
  if not os.path.exists(SMPLX_MODEL_DIR):
294
  print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
@@ -302,47 +266,43 @@ def load_smplx_model():
302
  ).to(DEVICE)
303
  print(f"SMPL-X loaded")
304
  return model
305
-
306
-
307
  def initialize_models():
308
  global _model_cache
309
  if _model_cache["initialized"]:
310
  return
311
-
312
  print("\n" + "="*60)
313
  print(" Initializing SignMotionGPT Models")
314
  print("="*60)
315
-
316
  load_word_pid_mapping()
317
-
318
  _model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()
319
-
320
  try:
321
  _model_cache["vqvae_model"] = load_vqvae_model()
322
  _model_cache["stats"] = load_stats()
323
  _model_cache["smplx_model"] = load_smplx_model()
324
  except Exception as e:
325
  print(f"Could not load visualization models: {e}")
326
-
327
  # Ensure PyRender is available
328
  ensure_pyrender()
329
-
330
  _model_cache["initialized"] = True
331
  print("All models initialized")
332
  print("="*60)
333
-
334
-
335
  def precompute_examples():
336
  """Pre-compute animations for example words at startup."""
337
  global _example_cache
338
-
339
  if not _model_cache["initialized"]:
340
  return
341
-
342
  examples = get_example_words_with_pids(3)
343
-
344
  print(f"\nPre-computing {len(examples)} example animations...")
345
-
346
  for word, pid in examples:
347
  key = f"{word}_{pid}"
348
  print(f" Computing: {word} ({pid})...")
@@ -353,22 +313,21 @@ def precompute_examples():
353
  except Exception as e:
354
  print(f" Failed: {word} - {e}")
355
  _example_cache[key] = {"video_path": None, "tokens": "", "word": word, "pid": pid}
356
-
357
- print("Example pre-computation complete\n")
358
 
 
359
  # =====================================================================
360
  # Motion Generation Functions
361
  # =====================================================================
362
  def generate_motion_tokens(word: str, variant: str) -> str:
363
  model = _model_cache["llm_model"]
364
  tokenizer = _model_cache["llm_tokenizer"]
365
-
366
  if model is None or tokenizer is None:
367
  raise RuntimeError("LLM model not loaded")
368
-
369
  prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
370
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
371
-
372
  with torch.no_grad():
373
  output = model.generate(
374
  **inputs, max_new_tokens=100, do_sample=True,
@@ -378,46 +337,42 @@ def generate_motion_tokens(word: str, variant: str) -> str:
378
  eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
379
  early_stopping=True
380
  )
381
-
382
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
383
  motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
384
  return motion_part.strip()
385
-
386
-
387
  def parse_motion_tokens(token_str: str) -> list:
388
  if isinstance(token_str, (list, tuple, np.ndarray)):
389
  return [int(x) for x in token_str]
390
  if not isinstance(token_str, str):
391
  return []
392
-
393
  matches = re.findall(r'<M(\d+)>', token_str)
394
  if matches:
395
  return [int(x) for x in matches]
396
-
397
  matches = re.findall(r'<motion_(\d+)>', token_str)
398
  if matches:
399
  return [int(x) for x in matches]
400
-
401
- return []
402
-
403
 
 
404
  def decode_tokens_to_params(tokens: list) -> np.ndarray:
405
  vqvae_model = _model_cache["vqvae_model"]
406
  mean, std = _model_cache["stats"]
407
-
408
  if vqvae_model is None or not tokens:
409
  return np.zeros((0, SMPL_DIM), dtype=np.float32)
410
-
411
  idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
412
  T_q = idx.shape[1]
413
  quantizer = vqvae_model.vqvae.quantizer
414
-
415
  if hasattr(quantizer, "codebook"):
416
  codebook = quantizer.codebook.to(DEVICE)
417
  code_dim = codebook.shape[1]
418
  else:
419
  code_dim = CODE_DIM
420
-
421
  x_quantized = None
422
  if hasattr(quantizer, "dequantize"):
423
  try:
@@ -431,47 +386,55 @@ def decode_tokens_to_params(tokens: list) -> np.ndarray:
431
  x_quantized = dq.permute(0, 2, 1).contiguous()
432
  except Exception:
433
  pass
434
-
435
  if x_quantized is None:
436
  if not hasattr(quantizer, "codebook"):
437
  return np.zeros((0, SMPL_DIM), dtype=np.float32)
438
  with torch.no_grad():
439
  emb = codebook[idx]
440
  x_quantized = emb.permute(0, 2, 1).contiguous()
441
-
442
  with torch.no_grad():
443
  x_dec = vqvae_model.vqvae.decoder(x_quantized)
444
  smpl_out = vqvae_model.vqvae.postprocess(x_dec)
445
  params_np = smpl_out.squeeze(0).cpu().numpy()
446
-
447
  if (mean is not None) and (std is not None):
448
  params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)
449
-
450
- return params_np
451
-
452
 
453
- def params_to_vertices(params_seq: np.ndarray) -> tuple:
454
- smplx_model = _model_cache["smplx_model"]
455
- if smplx_model is None or params_seq.shape[0] == 0:
456
- return None, None
457
-
 
 
458
  starts = np.cumsum([0] + PARAM_DIMS[:-1])
459
  ends = starts + np.array(PARAM_DIMS)
 
460
  T = params_seq.shape[0]
461
  all_verts = []
462
- batch_size = 32
 
463
  num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
464
 
465
  with torch.no_grad():
466
  for s in range(0, T, batch_size):
467
- batch = params_seq[s:s+batch_size]
468
  B = batch.shape[0]
469
 
470
- np_parts = {name: batch[:, st:ed].astype(np.float32) for name, st, ed in zip(PARAM_NAMES, starts, ends)}
471
- tensor_parts = {name: torch.from_numpy(arr).to(DEVICE) for name, arr in np_parts.items()}
 
 
 
 
 
 
 
 
472
 
473
- # Handle body pose - in this data format, body_pose is 63 dims (21 joints * 3)
474
- # root_pose is separate as global_orient (3 dims)
475
  body_t = tensor_parts['body_pose']
476
  L_body = body_t.shape[1]
477
  expected_no_go = num_body_joints * 3
@@ -484,24 +447,67 @@ def params_to_vertices(params_seq: np.ndarray) -> tuple:
484
  global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
485
  body_pose_only = body_t
486
  else:
 
487
  if L_body > expected_no_go:
488
  global_orient = body_t[:, :3].contiguous()
489
  body_pose_only = body_t[:, 3:].contiguous()
490
  else:
491
- body_pose_only = F.pad(body_t, (0, max(0, expected_no_go - L_body)))
 
492
  global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  out = smplx_model(
495
- betas=tensor_parts['betas'], global_orient=global_orient, body_pose=body_pose_only,
496
- left_hand_pose=tensor_parts['left_hand_pose'], right_hand_pose=tensor_parts['right_hand_pose'],
497
- expression=tensor_parts['expression'], jaw_pose=tensor_parts['jaw_pose'],
498
- leye_pose=tensor_parts['eye_pose'], reye_pose=tensor_parts['eye_pose'],
499
- transl=tensor_parts['trans'], return_verts=True
 
 
 
 
 
 
500
  )
501
- all_verts.append(out.vertices.detach().cpu().numpy())
 
 
502
 
503
- return np.concatenate(all_verts, axis=0), smplx_model.faces.astype(np.int32)
504
-
 
 
505
  # =====================================================================
506
  # PyRender Visualization Functions
507
  # =====================================================================
@@ -520,20 +526,15 @@ def render_single_frame(
520
  """Render a single mesh frame using PyRender."""
521
  if not PYRENDER_AVAILABLE:
522
  raise RuntimeError("PyRender not available")
523
-
524
  # Check for invalid vertices
525
  if not np.isfinite(verts).all():
526
  blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 200
527
  return blank
528
-
529
- # IMPORTANT: Rotate mesh 180 degrees around X-axis (like visualize.py)
530
- # This fixes the coordinate system so we view from the front
531
- rot_matrix = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
532
- verts_rotated = np.dot(verts, rot_matrix[:3, :3].T)
533
-
534
  # Create scene
535
  scene = pyrender.Scene(bg_color=bg_color, ambient_light=[0.4, 0.4, 0.4])
536
-
537
  # Material
538
  material = pyrender.MetallicRoughnessMaterial(
539
  metallicFactor=0.0,
@@ -541,29 +542,31 @@ def render_single_frame(
541
  alphaMode='OPAQUE',
542
  baseColorFactor=color
543
  )
544
-
545
- # Create mesh with rotated vertices
546
- mesh = trimesh.Trimesh(vertices=verts_rotated, faces=faces)
547
  mesh_render = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=True)
548
  scene.add(mesh_render)
549
-
550
- # Compute center for camera positioning (using rotated vertices)
551
- mesh_center = verts_rotated.mean(axis=0)
552
  camera_target = fixed_center if fixed_center is not None else mesh_center
553
-
554
  # Camera setup
555
  camera = pyrender.IntrinsicsCamera(
556
  fx=focal_length, fy=focal_length,
557
  cx=frame_width / 2, cy=frame_height / 2,
558
  znear=0.1, zfar=20.0
559
  )
560
-
561
-
 
 
562
  camera_pose = np.eye(4)
563
  camera_pose[0, 3] = camera_target[0] # Center X
564
  camera_pose[1, 3] = camera_target[1] # Center Y (body center)
565
  camera_pose[2, 3] = camera_target[2] - camera_distance # In front (negative Z)
566
-
567
  # Camera orientation: flip to look at subject (SOKE-style)
568
  # This rotation makes camera look toward +Z (at the subject)
569
  camera_pose[:3, :3] = np.array([
@@ -571,49 +574,47 @@ def render_single_frame(
571
  [0, -1, 0],
572
  [0, 0, -1]
573
  ])
574
-
575
  scene.add(camera, pose=camera_pose)
576
-
577
  # Lighting
578
  key_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
579
  key_pose = np.eye(4)
580
  key_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-30), np.radians(-20), 0)[:3, :3]
581
  scene.add(key_light, pose=key_pose)
582
-
583
  fill_light = pyrender.DirectionalLight(color=[0.9, 0.9, 1.0], intensity=1.5)
584
  fill_pose = np.eye(4)
585
  fill_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-20), np.radians(30), 0)[:3, :3]
586
  scene.add(fill_light, pose=fill_pose)
587
-
588
  rim_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
589
  rim_pose = np.eye(4)
590
  rim_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(30), np.radians(180), 0)[:3, :3]
591
  scene.add(rim_light, pose=rim_pose)
592
-
593
  # Render
594
  renderer = pyrender.OffscreenRenderer(viewport_width=frame_width, viewport_height=frame_height, point_size=1.0)
595
  color_img, _ = renderer.render(scene)
596
  renderer.delete()
597
-
598
  # Add label
599
  if label:
600
  img = Image.fromarray(color_img)
601
  draw = ImageDraw.Draw(img)
602
-
603
  try:
604
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
605
  except:
606
  font = ImageFont.load_default()
607
-
608
  text_width = len(label) * 10 + 20
609
  draw.rectangle([10, 10, 10 + text_width, 35], fill=(0, 0, 0, 180))
610
  draw.text((15, 12), label, fill=(255, 255, 255), font=font)
611
-
612
- color_img = np.array(img)
613
-
614
- return color_img
615
 
 
616
 
 
617
  def render_side_by_side_frame(
618
  verts_list: list,
619
  faces: np.ndarray,
@@ -628,20 +629,20 @@ def render_side_by_side_frame(
628
  """Render multiple meshes side-by-side for comparison."""
629
  if not PYRENDER_AVAILABLE:
630
  raise RuntimeError("PyRender not available")
631
-
632
  # Colors for each avatar
633
  colors = [
634
  (0.3, 0.8, 0.4, 1.0), # Green
635
  (0.3, 0.6, 0.9, 1.0), # Blue
636
  (0.9, 0.5, 0.2, 1.0), # Orange
637
  ]
638
-
639
  frames = []
640
  for i, verts in enumerate(verts_list):
641
  fixed_center = fixed_centers[i] if fixed_centers else None
642
  color = colors[i % len(colors)]
643
  label = labels[i] if i < len(labels) else ""
644
-
645
  frame = render_single_frame(
646
  verts, faces, label=label, color=color,
647
  fixed_center=fixed_center, camera_distance=camera_distance,
@@ -649,10 +650,8 @@ def render_side_by_side_frame(
649
  frame_height=frame_height, bg_color=bg_color
650
  )
651
  frames.append(frame)
652
-
653
- return np.concatenate(frames, axis=1)
654
-
655
 
 
656
  def render_video(
657
  verts: np.ndarray,
658
  faces: np.ndarray,
@@ -668,17 +667,19 @@ def render_video(
668
  """Render single avatar animation to video."""
669
  if not ensure_pyrender():
670
  raise RuntimeError("PyRender not available")
671
-
 
 
 
 
672
  # Trim last few frames to remove end-of-sequence artifacts
673
  T_total = verts.shape[0]
674
  trim_amount = min(8, int(T_total * 0.15))
675
  T = max(5, T_total - trim_amount)
676
-
677
- # Compute fixed camera target from first frame (after rotation)
678
- rot_matrix = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
679
- verts_rotated_first = np.dot(verts[0], rot_matrix[:3, :3].T)
680
- fixed_center = verts_rotated_first.mean(axis=0)
681
-
682
  frames = []
683
  for t in range(T):
684
  frame = render_single_frame(
@@ -689,16 +690,14 @@ def render_video(
689
  )
690
  for _ in range(slowdown):
691
  frames.append(frame)
692
-
693
  # Save video
694
  Path(output_path).parent.mkdir(parents=True, exist_ok=True)
695
-
696
  if len(frames) > 0:
697
  imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
698
-
699
- return output_path
700
-
701
 
 
702
  def render_comparison_video(
703
  verts1: np.ndarray,
704
  faces1: np.ndarray,
@@ -717,24 +716,27 @@ def render_comparison_video(
717
  """Render side-by-side comparison video."""
718
  if not ensure_pyrender():
719
  raise RuntimeError("PyRender not available")
720
-
 
 
 
 
 
 
721
  # Match lengths and trim
722
  T_total = min(verts1.shape[0], verts2.shape[0])
723
  trim_amount = min(8, int(T_total * 0.15))
724
  T = max(5, T_total - trim_amount)
725
-
726
  verts1 = verts1[:T]
727
  verts2 = verts2[:T]
728
-
729
- # Compute fixed camera targets (after rotation)
730
- rot_matrix = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
731
- verts1_rotated_first = np.dot(verts1[0], rot_matrix[:3, :3].T)
732
- verts2_rotated_first = np.dot(verts2[0], rot_matrix[:3, :3].T)
733
- fixed_center1 = verts1_rotated_first.mean(axis=0)
734
- fixed_center2 = verts2_rotated_first.mean(axis=0)
735
-
736
  labels = [label1, label2]
737
-
738
  frames = []
739
  for t in range(T):
740
  frame = render_side_by_side_frame(
@@ -745,15 +747,14 @@ def render_comparison_video(
745
  )
746
  for _ in range(slowdown):
747
  frames.append(frame)
748
-
749
  # Save video
750
  Path(output_path).parent.mkdir(parents=True, exist_ok=True)
751
-
752
  if len(frames) > 0:
753
  imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
754
-
755
- return output_path
756
 
 
757
  # =====================================================================
758
  # Main Processing Functions
759
  # =====================================================================
@@ -761,80 +762,74 @@ def generate_verts_for_word(word: str, pid: str) -> tuple:
761
  """Generate vertices and faces for a word-PID pair."""
762
  generated_tokens = generate_motion_tokens(word, pid)
763
  token_ids = parse_motion_tokens(generated_tokens)
764
-
765
  if not token_ids:
766
  return None, None, generated_tokens
767
-
768
  if _model_cache["vqvae_model"] is None or _model_cache["smplx_model"] is None:
769
  return None, None, generated_tokens
770
-
771
  params = decode_tokens_to_params(token_ids)
772
  if params.shape[0] == 0:
773
  return None, None, generated_tokens
774
-
775
  verts, faces = params_to_vertices(params)
776
  return verts, faces, generated_tokens
777
-
778
-
779
  def generate_video_for_word(word: str, pid: str) -> tuple:
780
  """Generate video and tokens for a word. Returns (video_path, tokens)."""
781
  verts, faces, tokens = generate_verts_for_word(word, pid)
782
-
783
  if verts is None:
784
  return None, tokens
785
-
786
  # Generate unique filename
787
  video_filename = f"motion_{word}_{pid}_{uuid.uuid4().hex[:8]}.mp4"
788
  video_path = os.path.join(OUTPUT_DIR, video_filename)
789
-
790
  render_video(verts, faces, video_path, label=f"{pid}")
791
  return video_path, tokens
792
-
793
-
794
  def process_word(word: str):
795
  """Main processing: generate side-by-side comparison video for two random PIDs."""
796
  if not word or not word.strip():
797
  return None, ""
798
-
799
  word = word.strip().lower()
800
-
801
  pids = get_random_pids_for_word(word, 2)
802
-
803
  if not pids:
804
  return None, f"Word '{word}' not found in dataset"
805
-
806
  if len(pids) == 1:
807
  pids = [pids[0], pids[0]]
808
-
809
  try:
810
  verts1, faces1, tokens1 = generate_verts_for_word(word, pids[0])
811
  verts2, faces2, tokens2 = generate_verts_for_word(word, pids[1])
812
-
813
  if verts1 is None and verts2 is None:
814
  return None, tokens1 or tokens2 or "Failed to generate motion"
815
-
816
  # Generate unique filename
817
  video_filename = f"comparison_{word}_{uuid.uuid4().hex[:8]}.mp4"
818
  video_path = os.path.join(OUTPUT_DIR, video_filename)
819
-
820
  if verts1 is None:
821
  render_video(verts2, faces2, video_path, label=pids[1])
822
  return video_path, tokens2
823
  if verts2 is None:
824
  render_video(verts1, faces1, video_path, label=pids[0])
825
  return video_path, tokens1
826
-
827
  render_comparison_video(
828
  verts1, faces1, verts2, faces2, video_path,
829
  label1=pids[0], label2=pids[1]
830
  )
831
  combined_tokens = f"[{pids[0]}] {tokens1}\n\n[{pids[1]}] {tokens2}"
832
  return video_path, combined_tokens
833
-
834
  except Exception as e:
835
  return None, f"Error: {str(e)[:100]}"
836
-
837
-
838
  def get_example_video(word: str, pid: str):
839
  """Get pre-computed example video."""
840
  key = f"{word}_{pid}"
@@ -843,65 +838,67 @@ def get_example_video(word: str, pid: str):
843
  return cached.get("video_path"), cached.get("tokens", "")
844
  video_path, tokens = generate_video_for_word(word, pid)
845
  return video_path, tokens
846
-
847
  # =====================================================================
848
  # Gradio Interface
849
  # =====================================================================
850
  def create_gradio_interface():
851
-
852
  custom_css = """
853
  .gradio-container { max-width: 1400px !important; }
854
- .example-row { margin-top: 15px; padding: 12px; background: #f8f9fa; border-radius: 6px; }
 
855
  .example-word-label {
856
  text-align: center;
857
  font-size: 28px !important;
858
  font-weight: bold !important;
859
- color: #2c3e50 !important;
 
860
  margin: 10px 0 !important;
861
  padding: 10px !important;
862
  }
863
  .example-variant-label {
864
  text-align: center;
865
  font-size: 14px !important;
866
- color: #7f8c8d !important;
 
867
  margin-bottom: 10px !important;
868
  }
869
  """
870
-
871
  example_list = list(_example_cache.values()) if _example_cache else []
872
-
873
  with gr.Blocks(title="SignMotionGPT", css=custom_css, theme=gr.themes.Default()) as demo:
874
-
875
  gr.Markdown("# SignMotionGPT Demo")
876
  gr.Markdown("Text-to-Sign Language Motion Generation with Variant Comparison")
877
  gr.Markdown("*High-quality PyRender visualization with proper hand motion rendering*")
878
-
879
  with gr.Row():
880
  with gr.Column(scale=1, min_width=280):
881
  gr.Markdown("### Input")
882
-
883
  word_input = gr.Textbox(
884
  label="Word",
885
  placeholder="Enter a word from the dataset...",
886
  lines=1, max_lines=1
887
  )
888
-
889
  generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
890
-
891
  gr.Markdown("---")
892
  gr.Markdown("### Generated Tokens")
893
-
894
  tokens_output = gr.Textbox(
895
  label="Motion Tokens (both variants)",
896
  lines=8,
897
  interactive=False,
898
  show_copy_button=True
899
  )
900
-
901
  if _word_pid_map:
902
  sample_words = list(_word_pid_map.keys())[:10]
903
  gr.Markdown(f"**Available words:** {', '.join(sample_words)}, ...")
904
-
905
  with gr.Column(scale=2, min_width=700):
906
  gr.Markdown("### Motion Comparison (Two Signer Variants)")
907
  video_output = gr.Video(
@@ -909,11 +906,11 @@ def create_gradio_interface():
909
  autoplay=True,
910
  show_download_button=True
911
  )
912
-
913
  if example_list:
914
  gr.Markdown("---")
915
  gr.Markdown("### Pre-computed Examples")
916
-
917
  for item in example_list:
918
  word, pid = item['word'], item['pid']
919
  with gr.Row(elem_classes="example-row"):
@@ -921,37 +918,36 @@ def create_gradio_interface():
921
  gr.HTML(f'<div class="example-word-label">{word.upper()}</div>')
922
  gr.HTML(f'<div class="example-variant-label">Variant: {pid}</div>')
923
  example_btn = gr.Button("Load Example", size="sm", variant="secondary")
924
-
925
  with gr.Column(scale=3, min_width=500):
926
  example_video = gr.Video(
927
  label=f"Example: {word}",
928
  autoplay=False,
929
  show_download_button=True
930
  )
931
-
932
  example_btn.click(
933
  fn=lambda w=word, p=pid: get_example_video(w, p),
934
  inputs=[],
935
  outputs=[example_video, tokens_output]
936
  )
937
-
938
  gr.Markdown("---")
939
  gr.Markdown("*SignMotionGPT: LLM-based sign language motion generation with PyRender visualization*")
940
-
941
  generate_btn.click(
942
  fn=process_word,
943
  inputs=[word_input],
944
  outputs=[video_output, tokens_output]
945
  )
946
-
947
  word_input.submit(
948
  fn=process_word,
949
  inputs=[word_input],
950
  outputs=[video_output, tokens_output]
951
  )
952
-
953
- return demo
954
 
 
955
  # =====================================================================
956
  # Main Entry Point for HuggingFace Spaces
957
  # =====================================================================
@@ -965,20 +961,16 @@ print(f"Output Directory: {OUTPUT_DIR}")
965
  print(f"Dataset: {DATASET_PATH}")
966
  print(f"PyRender Available: {PYRENDER_AVAILABLE}")
967
  print("="*60 + "\n")
968
-
969
  # Initialize models at startup
970
  initialize_models()
971
-
972
  # Pre-compute example animations
973
  precompute_examples()
974
-
975
  # Create and launch interface
976
  demo = create_gradio_interface()
977
-
978
  if __name__ == "__main__":
979
  # Launch with settings for HuggingFace Spaces
980
  demo.launch(
981
  server_name="0.0.0.0",
982
  server_port=7860,
983
  share=False
984
- )
 
6
  # IMPORTANT: Set OpenGL platform BEFORE any OpenGL imports (for headless rendering)
7
  import os
8
  os.environ["PYOPENGL_PLATFORM"] = "egl"
 
9
  import sys
10
  import re
11
  import json
 
14
  import tempfile
15
  import uuid
16
  from pathlib import Path
 
17
  import torch
18
  import numpy as np
 
19
  warnings.filterwarnings("ignore")
 
20
  # =====================================================================
21
  # Configuration for HuggingFace Spaces
22
  # =====================================================================
 
25
  OUTPUT_DIR = os.path.join(WORK_DIR, "outputs")
26
  os.makedirs(DATA_DIR, exist_ok=True)
27
  os.makedirs(OUTPUT_DIR, exist_ok=True)
 
28
  # Path definitions
29
  DATASET_PATH = os.path.join(DATA_DIR, "motion_llm_dataset.json")
30
  VQVAE_CHECKPOINT = os.path.join(DATA_DIR, "vqvae_model.pt")
31
  STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
32
  SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
 
33
  # HuggingFace model config
34
  HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
35
  HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
 
36
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
37
  # Generation parameters
38
  M_START = "<M_START>"
39
  M_END = "<M_END>"
 
41
  INFERENCE_TEMPERATURE = 0.7
42
  INFERENCE_TOP_K = 50
43
  INFERENCE_REPETITION_PENALTY = 1.2
 
44
  # VQ-VAE parameters
45
  SMPL_DIM = 182
46
  CODEBOOK_SIZE = 512
 
49
  width=512, depth=3, down_t=2, stride_t=2,
50
  dilation_growth_rate=3, activation='relu', norm=None, quantizer="ema_reset"
51
  )
 
52
  PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
53
+ PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
54
+ "trans", "expression", "jaw_pose", "eye_pose"]
 
55
  # Visualization defaults
56
  AVATAR_COLOR = (0.36, 0.78, 0.36, 1.0) # Green color as RGBA
57
  VIDEO_FPS = 15
58
  VIDEO_SLOWDOWN = 2
59
  FRAME_WIDTH = 544 # Must be divisible by 16 for video codec compatibility
60
  FRAME_HEIGHT = 720
 
61
  # =====================================================================
62
  # Install/Import Dependencies
63
  # =====================================================================
 
66
  except ImportError:
67
  os.system("pip install -q gradio>=4.0.0")
68
  import gradio as gr
 
69
  try:
70
  import smplx
71
  except ImportError:
72
  os.system("pip install -q smplx==0.1.28")
73
  import smplx
 
74
  # PyRender for high-quality rendering
75
  PYRENDER_AVAILABLE = False
76
  try:
 
80
  PYRENDER_AVAILABLE = True
81
  except ImportError:
82
  pass
 
83
  try:
84
  import imageio
85
  except ImportError:
86
  os.system("pip install -q imageio[ffmpeg]")
87
  import imageio
 
88
  from transformers import AutoModelForCausalLM, AutoTokenizer
89
  import torch.nn.functional as F
 
90
  # =====================================================================
91
  # Import VQ-VAE architecture
92
  # =====================================================================
 
96
  sys.path.insert(0, parent_dir)
97
  if current_dir not in sys.path:
98
  sys.path.insert(0, current_dir)
 
99
  try:
100
  from mGPT.archs.mgpt_vq import VQVae
101
  except ImportError as e:
102
  print(f"Warning: Could not import VQVae: {e}")
103
  VQVae = None
 
104
  # =====================================================================
105
  # Global Cache
106
  # =====================================================================
 
112
  "stats": (None, None),
113
  "initialized": False
114
  }
 
115
  _word_pid_map = {}
116
  _example_cache = {}
 
117
  # =====================================================================
118
  # PyRender Setup
119
  # =====================================================================
 
122
  global PYRENDER_AVAILABLE, trimesh, pyrender, Image, ImageDraw, ImageFont
123
  if PYRENDER_AVAILABLE:
124
  return True
125
+
126
  print("Installing pyrender dependencies...")
127
  if os.path.exists("/etc/debian_version"):
128
  os.system("apt-get update -qq && apt-get install -qq -y libegl1-mesa-dev libgles2-mesa-dev > /dev/null 2>&1")
129
  os.system("pip install -q trimesh pyrender PyOpenGL PyOpenGL_accelerate Pillow")
130
+
131
  try:
132
  import trimesh
133
  import pyrender
 
137
  except ImportError as e:
138
  print(f"Could not install pyrender: {e}")
139
  return False
 
140
  # =====================================================================
141
  # Dataset Loading - Word to PID mapping
142
  # =====================================================================
143
  def load_word_pid_mapping():
144
  """Load the dataset and build word -> PIDs mapping."""
145
  global _word_pid_map
146
+
147
  if not os.path.exists(DATASET_PATH):
148
  print(f"Dataset not found: {DATASET_PATH}")
149
  return
150
+
151
  print(f"Loading dataset from: {DATASET_PATH}")
152
  try:
153
  with open(DATASET_PATH, 'r', encoding='utf-8') as f:
154
  data = json.load(f)
155
+
156
  for entry in data:
157
  word = entry.get('word', '').lower()
158
  pid = entry.get('participant_id', '')
 
160
  if word not in _word_pid_map:
161
  _word_pid_map[word] = set()
162
  _word_pid_map[word].add(pid)
163
+
164
  for word in _word_pid_map:
165
  _word_pid_map[word] = sorted(list(_word_pid_map[word]))
166
+
167
  print(f"Loaded {len(_word_pid_map)} unique words from dataset")
168
  except Exception as e:
169
  print(f"Error loading dataset: {e}")
 
 
170
  def get_pids_for_word(word: str) -> list:
171
  """Get valid PIDs for a word from the dataset."""
172
  word = word.lower().strip()
173
  return _word_pid_map.get(word, [])
 
 
174
  def get_random_pids_for_word(word: str, count: int = 2) -> list:
175
  """Get random PIDs for a word. Returns up to 'count' PIDs."""
176
  pids = get_pids_for_word(word)
 
179
  if len(pids) <= count:
180
  return pids
181
  return random.sample(pids, count)
 
 
182
  def get_example_words_with_pids(count: int = 3) -> list:
183
  """Get example words with valid PIDs from dataset."""
184
  examples = []
185
  preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
186
+
187
  for word in preferred:
188
  pids = get_pids_for_word(word)
189
  if pids:
190
  examples.append((word, pids[0]))
191
  if len(examples) >= count:
192
  break
193
+
194
  if len(examples) < count:
195
  available = [w for w in _word_pid_map.keys() if w not in [e[0] for e in examples]]
196
  random.shuffle(available)
197
  for word in available[:count - len(examples)]:
198
  pids = _word_pid_map[word]
199
  examples.append((word, pids[0]))
 
 
200
 
201
+ return examples
202
  # =====================================================================
203
  # VQ-VAE Wrapper
204
  # =====================================================================
 
211
  nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
212
  output_emb_width=code_dim, **kwargs
213
  )
 
214
  # =====================================================================
215
  # Model Loading Functions
216
  # =====================================================================
217
  def load_llm_model():
218
  print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
219
  token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
220
+
221
  tokenizer = AutoTokenizer.from_pretrained(
222
  HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token
223
  )
 
233
  model.eval()
234
  print(f"LLM loaded (vocab size: {len(tokenizer)})")
235
  return model, tokenizer
 
 
236
  def load_vqvae_model():
237
  if not os.path.exists(VQVAE_CHECKPOINT):
238
  print(f"VQ-VAE checkpoint not found: {VQVAE_CHECKPOINT}")
 
245
  model.eval()
246
  print(f"VQ-VAE loaded")
247
  return model
 
 
248
  def load_stats():
249
  if not os.path.exists(STATS_PATH):
250
  return None, None
 
253
  if torch.is_tensor(mean): mean = mean.cpu().numpy()
254
  if torch.is_tensor(std): std = std.cpu().numpy()
255
  return mean, std
 
 
256
  def load_smplx_model():
257
  if not os.path.exists(SMPLX_MODEL_DIR):
258
  print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
 
266
  ).to(DEVICE)
267
  print(f"SMPL-X loaded")
268
  return model
 
 
269
  def initialize_models():
270
  global _model_cache
271
  if _model_cache["initialized"]:
272
  return
273
+
274
  print("\n" + "="*60)
275
  print(" Initializing SignMotionGPT Models")
276
  print("="*60)
277
+
278
  load_word_pid_mapping()
279
+
280
  _model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()
281
+
282
  try:
283
  _model_cache["vqvae_model"] = load_vqvae_model()
284
  _model_cache["stats"] = load_stats()
285
  _model_cache["smplx_model"] = load_smplx_model()
286
  except Exception as e:
287
  print(f"Could not load visualization models: {e}")
288
+
289
  # Ensure PyRender is available
290
  ensure_pyrender()
291
+
292
  _model_cache["initialized"] = True
293
  print("All models initialized")
294
  print("="*60)
 
 
295
  def precompute_examples():
296
  """Pre-compute animations for example words at startup."""
297
  global _example_cache
298
+
299
  if not _model_cache["initialized"]:
300
  return
301
+
302
  examples = get_example_words_with_pids(3)
303
+
304
  print(f"\nPre-computing {len(examples)} example animations...")
305
+
306
  for word, pid in examples:
307
  key = f"{word}_{pid}"
308
  print(f" Computing: {word} ({pid})...")
 
313
  except Exception as e:
314
  print(f" Failed: {word} - {e}")
315
  _example_cache[key] = {"video_path": None, "tokens": "", "word": word, "pid": pid}
 
 
316
 
317
+ print("Example pre-computation complete\n")
318
  # =====================================================================
319
  # Motion Generation Functions
320
  # =====================================================================
321
  def generate_motion_tokens(word: str, variant: str) -> str:
322
  model = _model_cache["llm_model"]
323
  tokenizer = _model_cache["llm_tokenizer"]
324
+
325
  if model is None or tokenizer is None:
326
  raise RuntimeError("LLM model not loaded")
327
+
328
  prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
329
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
330
+
331
  with torch.no_grad():
332
  output = model.generate(
333
  **inputs, max_new_tokens=100, do_sample=True,
 
337
  eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
338
  early_stopping=True
339
  )
340
+
341
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
342
  motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
343
  return motion_part.strip()
 
 
344
  def parse_motion_tokens(token_str: str) -> list:
345
  if isinstance(token_str, (list, tuple, np.ndarray)):
346
  return [int(x) for x in token_str]
347
  if not isinstance(token_str, str):
348
  return []
349
+
350
  matches = re.findall(r'<M(\d+)>', token_str)
351
  if matches:
352
  return [int(x) for x in matches]
353
+
354
  matches = re.findall(r'<motion_(\d+)>', token_str)
355
  if matches:
356
  return [int(x) for x in matches]
 
 
 
357
 
358
+ return []
359
  def decode_tokens_to_params(tokens: list) -> np.ndarray:
360
  vqvae_model = _model_cache["vqvae_model"]
361
  mean, std = _model_cache["stats"]
362
+
363
  if vqvae_model is None or not tokens:
364
  return np.zeros((0, SMPL_DIM), dtype=np.float32)
365
+
366
  idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
367
  T_q = idx.shape[1]
368
  quantizer = vqvae_model.vqvae.quantizer
369
+
370
  if hasattr(quantizer, "codebook"):
371
  codebook = quantizer.codebook.to(DEVICE)
372
  code_dim = codebook.shape[1]
373
  else:
374
  code_dim = CODE_DIM
375
+
376
  x_quantized = None
377
  if hasattr(quantizer, "dequantize"):
378
  try:
 
386
  x_quantized = dq.permute(0, 2, 1).contiguous()
387
  except Exception:
388
  pass
389
+
390
  if x_quantized is None:
391
  if not hasattr(quantizer, "codebook"):
392
  return np.zeros((0, SMPL_DIM), dtype=np.float32)
393
  with torch.no_grad():
394
  emb = codebook[idx]
395
  x_quantized = emb.permute(0, 2, 1).contiguous()
396
+
397
  with torch.no_grad():
398
  x_dec = vqvae_model.vqvae.decoder(x_quantized)
399
  smpl_out = vqvae_model.vqvae.postprocess(x_dec)
400
  params_np = smpl_out.squeeze(0).cpu().numpy()
401
+
402
  if (mean is not None) and (std is not None):
403
  params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)
 
 
 
404
 
405
+ return params_np
406
+ def params_to_vertices(params_seq: np.ndarray, smplx_model, batch_size=32) -> tuple:
407
+ """
408
+ Convert SMPL-X parameters to 3D vertices.
409
+ FIXED: Properly handles jaw_pose and expression to prevent lip/mouth issues.
410
+ """
411
+ # Compute parameter slicing indices
412
  starts = np.cumsum([0] + PARAM_DIMS[:-1])
413
  ends = starts + np.array(PARAM_DIMS)
414
+
415
  T = params_seq.shape[0]
416
  all_verts = []
417
+
418
+ # Infer number of body joints
419
  num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
420
 
421
  with torch.no_grad():
422
  for s in range(0, T, batch_size):
423
+ batch = params_seq[s:s+batch_size] # (B, SMPL_DIM)
424
  B = batch.shape[0]
425
 
426
+ # Extract parameters
427
+ np_parts = {}
428
+ for name, st, ed in zip(PARAM_NAMES, starts, ends):
429
+ np_parts[name] = batch[:, st:ed].astype(np.float32)
430
+
431
+ # Convert to tensors
432
+ tensor_parts = {
433
+ name: torch.from_numpy(arr).to(DEVICE)
434
+ for name, arr in np_parts.items()
435
+ }
436
 
437
+ # Handle body pose (may or may not include global orient)
 
438
  body_t = tensor_parts['body_pose']
439
  L_body = body_t.shape[1]
440
  expected_no_go = num_body_joints * 3
 
447
  global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
448
  body_pose_only = body_t
449
  else:
450
+ # Best-effort fallback
451
  if L_body > expected_no_go:
452
  global_orient = body_t[:, :3].contiguous()
453
  body_pose_only = body_t[:, 3:].contiguous()
454
  else:
455
+ pad_len = max(0, expected_no_go - L_body)
456
+ body_pose_only = F.pad(body_t, (0, pad_len))
457
  global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
458
 
459
+ # ✅ FIX: Ensure jaw_pose is properly shaped (should be B x 3)
460
+ jaw_pose = tensor_parts['jaw_pose']
461
+ if jaw_pose.shape[1] != 3:
462
+ print(f"Warning: jaw_pose has shape {jaw_pose.shape}, padding/trimming to (B, 3)")
463
+ if jaw_pose.shape[1] < 3:
464
+ jaw_pose = F.pad(jaw_pose, (0, 3 - jaw_pose.shape[1]))
465
+ else:
466
+ jaw_pose = jaw_pose[:, :3]
467
+ jaw_pose = jaw_pose.contiguous()
468
+
469
+ # ✅ FIX: Ensure expression is properly shaped (should be B x 10)
470
+ expression = tensor_parts['expression']
471
+ if expression.shape[1] != 10:
472
+ print(f"Warning: expression has shape {expression.shape}, padding/trimming to (B, 10)")
473
+ if expression.shape[1] < 10:
474
+ expression = F.pad(expression, (0, 10 - expression.shape[1]))
475
+ else:
476
+ expression = expression[:, :10]
477
+ expression = expression.contiguous()
478
+
479
+ # ✅ FIX: Ensure eye_pose is properly shaped (should be B x 3)
480
+ eye_pose = tensor_parts['eye_pose']
481
+ if eye_pose.shape[1] != 3:
482
+ print(f"Warning: eye_pose has shape {eye_pose.shape}, padding/trimming to (B, 3)")
483
+ if eye_pose.shape[1] < 3:
484
+ eye_pose = F.pad(eye_pose, (0, 3 - eye_pose.shape[1]))
485
+ else:
486
+ eye_pose = eye_pose[:, :3]
487
+ eye_pose = eye_pose.contiguous()
488
+
489
+ # Call SMPL-X with validated parameters
490
  out = smplx_model(
491
+ betas=tensor_parts['betas'],
492
+ global_orient=global_orient,
493
+ body_pose=body_pose_only,
494
+ left_hand_pose=tensor_parts['left_hand_pose'],
495
+ right_hand_pose=tensor_parts['right_hand_pose'],
496
+ expression=expression, # ✅ Using validated expression
497
+ jaw_pose=jaw_pose, # ✅ Using validated jaw_pose
498
+ leye_pose=eye_pose, # ✅ Using validated eye_pose
499
+ reye_pose=eye_pose, # ✅ Using validated eye_pose
500
+ transl=tensor_parts['trans'],
501
+ return_verts=True
502
  )
503
+
504
+ verts = out.vertices.detach().cpu().numpy() # (B, V, 3)
505
+ all_verts.append(verts)
506
 
507
+ verts_all = np.concatenate(all_verts, axis=0) # (T, V, 3)
508
+ faces = smplx_model.faces.astype(np.int32)
509
+
510
+ return verts_all, faces
511
  # =====================================================================
512
  # PyRender Visualization Functions
513
  # =====================================================================
 
526
  """Render a single mesh frame using PyRender."""
527
  if not PYRENDER_AVAILABLE:
528
  raise RuntimeError("PyRender not available")
529
+
530
  # Check for invalid vertices
531
  if not np.isfinite(verts).all():
532
  blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 200
533
  return blank
534
+
 
 
 
 
 
535
  # Create scene
536
  scene = pyrender.Scene(bg_color=bg_color, ambient_light=[0.4, 0.4, 0.4])
537
+
538
  # Material
539
  material = pyrender.MetallicRoughnessMaterial(
540
  metallicFactor=0.0,
 
542
  alphaMode='OPAQUE',
543
  baseColorFactor=color
544
  )
545
+
546
+ # Create mesh
547
+ mesh = trimesh.Trimesh(vertices=verts, faces=faces)
548
  mesh_render = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=True)
549
  scene.add(mesh_render)
550
+
551
+ # Compute center for camera positioning
552
+ mesh_center = verts.mean(axis=0)
553
  camera_target = fixed_center if fixed_center is not None else mesh_center
554
+
555
  # Camera setup
556
  camera = pyrender.IntrinsicsCamera(
557
  fx=focal_length, fy=focal_length,
558
  cx=frame_width / 2, cy=frame_height / 2,
559
  znear=0.1, zfar=20.0
560
  )
561
+
562
+ # Camera pose: After 180-degree rotation around X-axis, coordinate system changes
563
+ # Camera should be positioned in front (negative Z) with flipped orientation
564
+ # This matches visualize.py and ensures proper face visibility
565
  camera_pose = np.eye(4)
566
  camera_pose[0, 3] = camera_target[0] # Center X
567
  camera_pose[1, 3] = camera_target[1] # Center Y (body center)
568
  camera_pose[2, 3] = camera_target[2] - camera_distance # In front (negative Z)
569
+
570
  # Camera orientation: flip to look at subject (SOKE-style)
571
  # This rotation makes camera look toward +Z (at the subject)
572
  camera_pose[:3, :3] = np.array([
 
574
  [0, -1, 0],
575
  [0, 0, -1]
576
  ])
577
+
578
  scene.add(camera, pose=camera_pose)
579
+
580
  # Lighting
581
  key_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
582
  key_pose = np.eye(4)
583
  key_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-30), np.radians(-20), 0)[:3, :3]
584
  scene.add(key_light, pose=key_pose)
585
+
586
  fill_light = pyrender.DirectionalLight(color=[0.9, 0.9, 1.0], intensity=1.5)
587
  fill_pose = np.eye(4)
588
  fill_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-20), np.radians(30), 0)[:3, :3]
589
  scene.add(fill_light, pose=fill_pose)
590
+
591
  rim_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
592
  rim_pose = np.eye(4)
593
  rim_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(30), np.radians(180), 0)[:3, :3]
594
  scene.add(rim_light, pose=rim_pose)
595
+
596
  # Render
597
  renderer = pyrender.OffscreenRenderer(viewport_width=frame_width, viewport_height=frame_height, point_size=1.0)
598
  color_img, _ = renderer.render(scene)
599
  renderer.delete()
600
+
601
  # Add label
602
  if label:
603
  img = Image.fromarray(color_img)
604
  draw = ImageDraw.Draw(img)
605
+
606
  try:
607
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
608
  except:
609
  font = ImageFont.load_default()
610
+
611
  text_width = len(label) * 10 + 20
612
  draw.rectangle([10, 10, 10 + text_width, 35], fill=(0, 0, 0, 180))
613
  draw.text((15, 12), label, fill=(255, 255, 255), font=font)
 
 
 
 
614
 
615
+ color_img = np.array(img)
616
 
617
+ return color_img
618
  def render_side_by_side_frame(
619
  verts_list: list,
620
  faces: np.ndarray,
 
629
  """Render multiple meshes side-by-side for comparison."""
630
  if not PYRENDER_AVAILABLE:
631
  raise RuntimeError("PyRender not available")
632
+
633
  # Colors for each avatar
634
  colors = [
635
  (0.3, 0.8, 0.4, 1.0), # Green
636
  (0.3, 0.6, 0.9, 1.0), # Blue
637
  (0.9, 0.5, 0.2, 1.0), # Orange
638
  ]
639
+
640
  frames = []
641
  for i, verts in enumerate(verts_list):
642
  fixed_center = fixed_centers[i] if fixed_centers else None
643
  color = colors[i % len(colors)]
644
  label = labels[i] if i < len(labels) else ""
645
+
646
  frame = render_single_frame(
647
  verts, faces, label=label, color=color,
648
  fixed_center=fixed_center, camera_distance=camera_distance,
 
650
  frame_height=frame_height, bg_color=bg_color
651
  )
652
  frames.append(frame)
 
 
 
653
 
654
+ return np.concatenate(frames, axis=1)
655
  def render_video(
656
  verts: np.ndarray,
657
  faces: np.ndarray,
 
667
  """Render single avatar animation to video."""
668
  if not ensure_pyrender():
669
  raise RuntimeError("PyRender not available")
670
+
671
+ # Apply orientation fix: rotate 180 degrees around X-axis
672
+ verts = verts.copy()
673
+ verts[..., 1:] *= -1
674
+
675
  # Trim last few frames to remove end-of-sequence artifacts
676
  T_total = verts.shape[0]
677
  trim_amount = min(8, int(T_total * 0.15))
678
  T = max(5, T_total - trim_amount)
679
+
680
+ # Compute fixed camera target from first frame
681
+ fixed_center = verts[0].mean(axis=0)
682
+
 
 
683
  frames = []
684
  for t in range(T):
685
  frame = render_single_frame(
 
690
  )
691
  for _ in range(slowdown):
692
  frames.append(frame)
693
+
694
  # Save video
695
  Path(output_path).parent.mkdir(parents=True, exist_ok=True)
696
+
697
  if len(frames) > 0:
698
  imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
 
 
 
699
 
700
+ return output_path
701
  def render_comparison_video(
702
  verts1: np.ndarray,
703
  faces1: np.ndarray,
 
716
  """Render side-by-side comparison video."""
717
  if not ensure_pyrender():
718
  raise RuntimeError("PyRender not available")
719
+
720
+ # Apply orientation fix
721
+ verts1 = verts1.copy()
722
+ verts2 = verts2.copy()
723
+ verts1[..., 1:] *= -1
724
+ verts2[..., 1:] *= -1
725
+
726
  # Match lengths and trim
727
  T_total = min(verts1.shape[0], verts2.shape[0])
728
  trim_amount = min(8, int(T_total * 0.15))
729
  T = max(5, T_total - trim_amount)
730
+
731
  verts1 = verts1[:T]
732
  verts2 = verts2[:T]
733
+
734
+ # Compute fixed camera targets
735
+ fixed_center1 = verts1[0].mean(axis=0)
736
+ fixed_center2 = verts2[0].mean(axis=0)
737
+
 
 
 
738
  labels = [label1, label2]
739
+
740
  frames = []
741
  for t in range(T):
742
  frame = render_side_by_side_frame(
 
747
  )
748
  for _ in range(slowdown):
749
  frames.append(frame)
750
+
751
  # Save video
752
  Path(output_path).parent.mkdir(parents=True, exist_ok=True)
753
+
754
  if len(frames) > 0:
755
  imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
 
 
756
 
757
+ return output_path
758
  # =====================================================================
759
  # Main Processing Functions
760
  # =====================================================================
 
762
  """Generate vertices and faces for a word-PID pair."""
763
  generated_tokens = generate_motion_tokens(word, pid)
764
  token_ids = parse_motion_tokens(generated_tokens)
765
+
766
  if not token_ids:
767
  return None, None, generated_tokens
768
+
769
  if _model_cache["vqvae_model"] is None or _model_cache["smplx_model"] is None:
770
  return None, None, generated_tokens
771
+
772
  params = decode_tokens_to_params(token_ids)
773
  if params.shape[0] == 0:
774
  return None, None, generated_tokens
775
+
776
  verts, faces = params_to_vertices(params)
777
  return verts, faces, generated_tokens
 
 
778
  def generate_video_for_word(word: str, pid: str) -> tuple:
779
  """Generate video and tokens for a word. Returns (video_path, tokens)."""
780
  verts, faces, tokens = generate_verts_for_word(word, pid)
781
+
782
  if verts is None:
783
  return None, tokens
784
+
785
  # Generate unique filename
786
  video_filename = f"motion_{word}_{pid}_{uuid.uuid4().hex[:8]}.mp4"
787
  video_path = os.path.join(OUTPUT_DIR, video_filename)
788
+
789
  render_video(verts, faces, video_path, label=f"{pid}")
790
  return video_path, tokens
 
 
791
  def process_word(word: str):
792
  """Main processing: generate side-by-side comparison video for two random PIDs."""
793
  if not word or not word.strip():
794
  return None, ""
795
+
796
  word = word.strip().lower()
797
+
798
  pids = get_random_pids_for_word(word, 2)
799
+
800
  if not pids:
801
  return None, f"Word '{word}' not found in dataset"
802
+
803
  if len(pids) == 1:
804
  pids = [pids[0], pids[0]]
805
+
806
  try:
807
  verts1, faces1, tokens1 = generate_verts_for_word(word, pids[0])
808
  verts2, faces2, tokens2 = generate_verts_for_word(word, pids[1])
809
+
810
  if verts1 is None and verts2 is None:
811
  return None, tokens1 or tokens2 or "Failed to generate motion"
812
+
813
  # Generate unique filename
814
  video_filename = f"comparison_{word}_{uuid.uuid4().hex[:8]}.mp4"
815
  video_path = os.path.join(OUTPUT_DIR, video_filename)
816
+
817
  if verts1 is None:
818
  render_video(verts2, faces2, video_path, label=pids[1])
819
  return video_path, tokens2
820
  if verts2 is None:
821
  render_video(verts1, faces1, video_path, label=pids[0])
822
  return video_path, tokens1
823
+
824
  render_comparison_video(
825
  verts1, faces1, verts2, faces2, video_path,
826
  label1=pids[0], label2=pids[1]
827
  )
828
  combined_tokens = f"[{pids[0]}] {tokens1}\n\n[{pids[1]}] {tokens2}"
829
  return video_path, combined_tokens
830
+
831
  except Exception as e:
832
  return None, f"Error: {str(e)[:100]}"
 
 
833
  def get_example_video(word: str, pid: str):
834
  """Get pre-computed example video."""
835
  key = f"{word}_{pid}"
 
838
  return cached.get("video_path"), cached.get("tokens", "")
839
  video_path, tokens = generate_video_for_word(word, pid)
840
  return video_path, tokens
 
841
  # =====================================================================
842
  # Gradio Interface
843
  # =====================================================================
844
  def create_gradio_interface():
845
+
846
  custom_css = """
847
  .gradio-container { max-width: 1400px !important; }
848
+ .example-row { margin-top: 15px; padding: 12px; background:
849
+ #f8f9fa; border-radius: 6px; }
850
  .example-word-label {
851
  text-align: center;
852
  font-size: 28px !important;
853
  font-weight: bold !important;
854
+ color:
855
+ #2c3e50 !important;
856
  margin: 10px 0 !important;
857
  padding: 10px !important;
858
  }
859
  .example-variant-label {
860
  text-align: center;
861
  font-size: 14px !important;
862
+ color:
863
+ #7f8c8d !important;
864
  margin-bottom: 10px !important;
865
  }
866
  """
867
+
868
  example_list = list(_example_cache.values()) if _example_cache else []
869
+
870
  with gr.Blocks(title="SignMotionGPT", css=custom_css, theme=gr.themes.Default()) as demo:
871
+
872
  gr.Markdown("# SignMotionGPT Demo")
873
  gr.Markdown("Text-to-Sign Language Motion Generation with Variant Comparison")
874
  gr.Markdown("*High-quality PyRender visualization with proper hand motion rendering*")
875
+
876
  with gr.Row():
877
  with gr.Column(scale=1, min_width=280):
878
  gr.Markdown("### Input")
879
+
880
  word_input = gr.Textbox(
881
  label="Word",
882
  placeholder="Enter a word from the dataset...",
883
  lines=1, max_lines=1
884
  )
885
+
886
  generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
887
+
888
  gr.Markdown("---")
889
  gr.Markdown("### Generated Tokens")
890
+
891
  tokens_output = gr.Textbox(
892
  label="Motion Tokens (both variants)",
893
  lines=8,
894
  interactive=False,
895
  show_copy_button=True
896
  )
897
+
898
  if _word_pid_map:
899
  sample_words = list(_word_pid_map.keys())[:10]
900
  gr.Markdown(f"**Available words:** {', '.join(sample_words)}, ...")
901
+
902
  with gr.Column(scale=2, min_width=700):
903
  gr.Markdown("### Motion Comparison (Two Signer Variants)")
904
  video_output = gr.Video(
 
906
  autoplay=True,
907
  show_download_button=True
908
  )
909
+
910
  if example_list:
911
  gr.Markdown("---")
912
  gr.Markdown("### Pre-computed Examples")
913
+
914
  for item in example_list:
915
  word, pid = item['word'], item['pid']
916
  with gr.Row(elem_classes="example-row"):
 
918
  gr.HTML(f'<div class="example-word-label">{word.upper()}</div>')
919
  gr.HTML(f'<div class="example-variant-label">Variant: {pid}</div>')
920
  example_btn = gr.Button("Load Example", size="sm", variant="secondary")
921
+
922
  with gr.Column(scale=3, min_width=500):
923
  example_video = gr.Video(
924
  label=f"Example: {word}",
925
  autoplay=False,
926
  show_download_button=True
927
  )
928
+
929
  example_btn.click(
930
  fn=lambda w=word, p=pid: get_example_video(w, p),
931
  inputs=[],
932
  outputs=[example_video, tokens_output]
933
  )
934
+
935
  gr.Markdown("---")
936
  gr.Markdown("*SignMotionGPT: LLM-based sign language motion generation with PyRender visualization*")
937
+
938
  generate_btn.click(
939
  fn=process_word,
940
  inputs=[word_input],
941
  outputs=[video_output, tokens_output]
942
  )
943
+
944
  word_input.submit(
945
  fn=process_word,
946
  inputs=[word_input],
947
  outputs=[video_output, tokens_output]
948
  )
 
 
949
 
950
+ return demo
951
  # =====================================================================
952
  # Main Entry Point for HuggingFace Spaces
953
  # =====================================================================
 
961
  print(f"Dataset: {DATASET_PATH}")
962
  print(f"PyRender Available: {PYRENDER_AVAILABLE}")
963
  print("="*60 + "\n")
 
964
  # Initialize models at startup
965
  initialize_models()
 
966
  # Pre-compute example animations
967
  precompute_examples()
 
968
  # Create and launch interface
969
  demo = create_gradio_interface()
 
970
  if __name__ == "__main__":
971
  # Launch with settings for HuggingFace Spaces
972
  demo.launch(
973
  server_name="0.0.0.0",
974
  server_port=7860,
975
  share=False
976
+ )