rdz-falcon committed on
Commit
cfc32bc
·
verified ·
1 Parent(s): 0ea30b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +617 -317
app.py CHANGED
@@ -1,37 +1,24 @@
 
 
 
 
1
  import os
2
  import sys
3
  import re
4
  import json
5
  import random
6
- import argparse
7
  import warnings
8
  import html as html_module
9
- import shutil
10
- from pathlib import Path
11
 
12
  import torch
13
  import numpy as np
14
- from huggingface_hub import hf_hub_download, snapshot_download
15
-
16
- # Clean imports for Spaces (relies on requirements.txt)
17
- import gradio as gr
18
- import plotly.graph_objects as go
19
- from plotly.subplots import make_subplots
20
- import smplx
21
-
22
- from transformers import AutoModelForCausalLM, AutoTokenizer
23
 
24
  warnings.filterwarnings("ignore")
25
 
26
  # =====================================================================
27
- # Configuration
28
  # =====================================================================
29
- # The Repo ID where your LLM and auxiliary files (vqvae, dataset) are stored
30
- HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
31
- HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
32
-
33
- # Spaces run in /home/user/app. We set up paths relative to that.
34
- WORK_DIR = os.getcwd()
35
  DATA_DIR = os.path.join(WORK_DIR, "data")
36
  os.makedirs(DATA_DIR, exist_ok=True)
37
 
@@ -41,19 +28,21 @@ VQVAE_CHECKPOINT = os.path.join(DATA_DIR, "vqvae_model.pt")
41
  STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
42
  SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
43
 
 
 
 
 
44
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
 
46
- # Token definitions
47
  M_START = "<M_START>"
48
  M_END = "<M_END>"
49
  PAD_TOKEN = "<PAD>"
50
-
51
- # Inference settings
52
  INFERENCE_TEMPERATURE = 0.7
53
  INFERENCE_TOP_K = 50
54
  INFERENCE_REPETITION_PENALTY = 1.2
55
 
56
- # Architecture settings
57
  SMPL_DIM = 182
58
  CODEBOOK_SIZE = 512
59
  CODE_DIM = 512
@@ -67,77 +56,45 @@ PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
67
  "trans", "expression", "jaw_pose", "eye_pose"]
68
 
69
  # =====================================================================
70
- # Helper: Download Assets from HF Hub
71
  # =====================================================================
72
- def download_artifacts():
73
- """
74
- Attempts to download missing auxiliary files (VQVAE, Stats, Dataset, SMPLX)
75
- from the Hugging Face Hub Repository if they don't exist locally.
76
- """
77
- print(f"Checking for artifacts in {HF_REPO_ID}...")
78
- token = os.environ.get("HF_TOKEN") # Ensure this is set in Space Settings if repo is private
79
 
80
- # 1. Download Dataset
81
- if not os.path.exists(DATASET_PATH):
82
- try:
83
- print("Downloading dataset...")
84
- hf_hub_download(repo_id=HF_REPO_ID, filename="enriched_dataset.json",
85
- local_dir=WORK_DIR, token=token)
86
- except Exception as e:
87
- print(f"Warning: Could not download dataset: {e}")
88
 
89
- # 2. Download VQVAE Model
90
- if not os.path.exists(VQVAE_CHECKPOINT):
91
- try:
92
- print("Downloading VQVAE model...")
93
- # Assuming these are in a 'data' folder in your repo, or root. Adjust filename path as needed.
94
- hf_hub_download(repo_id=HF_REPO_ID, filename="data/vqvae_model.pt",
95
- local_dir=WORK_DIR, token=token)
96
- except Exception as e:
97
- # Fallback try root
98
- try:
99
- hf_hub_download(repo_id=HF_REPO_ID, filename="vqvae_model.pt",
100
- local_dir=DATA_DIR, token=token)
101
- except:
102
- print(f"Warning: Could not download VQVAE model: {e}")
103
-
104
- # 3. Download Stats
105
- if not os.path.exists(STATS_PATH):
106
- try:
107
- print("Downloading VQVAE stats...")
108
- hf_hub_download(repo_id=HF_REPO_ID, filename="data/vqvae_stats.pt",
109
- local_dir=WORK_DIR, token=token)
110
- except Exception as e:
111
- try:
112
- hf_hub_download(repo_id=HF_REPO_ID, filename="vqvae_stats.pt",
113
- local_dir=DATA_DIR, token=token)
114
- except:
115
- print(f"Warning: Could not download VQVAE stats: {e}")
116
-
117
- # 4. SMPLX Models
118
- # Note: SMPLX models are licensed. If you can't host them, users must upload them.
119
- # If they are in your repo (e.g. inside a zip or folder), download them here.
120
- if not os.path.exists(SMPLX_MODEL_DIR):
121
- print("Looking for SMPL-X models...")
122
- try:
123
- # Attempt to download a folder if it exists in the repo
124
- snapshot_download(repo_id=HF_REPO_ID, allow_patterns="smplx_models/*",
125
- local_dir=DATA_DIR, token=token)
126
- except Exception as e:
127
- print(f"Warning: Could not download SMPL-X models. Ensure 'smplx_models' folder exists in {DATA_DIR} or repo.")
128
 
 
129
 
130
  # =====================================================================
131
  # Import VQ-VAE architecture
132
  # =====================================================================
133
- # Ensure current directory is in path so mGPT import works
134
- sys.path.append(os.getcwd())
 
 
 
 
 
135
 
136
  try:
137
- # This requires the mGPT folder to be uploaded to the Space
138
  from mGPT.archs.mgpt_vq import VQVae
139
  except ImportError as e:
140
- print(f"Error: Could not import VQVae. Ensure the 'mGPT' folder is uploaded to the Space files. Details: {e}")
141
  VQVae = None
142
 
143
  # =====================================================================
@@ -152,14 +109,16 @@ _model_cache = {
152
  "initialized": False
153
  }
154
 
155
- _word_pid_map = {}
156
- _example_cache = {}
157
 
158
  # =====================================================================
159
- # Dataset Loading
160
  # =====================================================================
161
  def load_word_pid_mapping():
 
162
  global _word_pid_map
 
163
  if not os.path.exists(DATASET_PATH):
164
  print(f"Dataset not found: {DATASET_PATH}")
165
  return
@@ -177,37 +136,50 @@ def load_word_pid_mapping():
177
  _word_pid_map[word] = set()
178
  _word_pid_map[word].add(pid)
179
 
 
180
  for word in _word_pid_map:
181
  _word_pid_map[word] = sorted(list(_word_pid_map[word]))
 
182
  print(f"Loaded {len(_word_pid_map)} unique words from dataset")
183
  except Exception as e:
184
  print(f"Error loading dataset: {e}")
185
 
 
186
  def get_pids_for_word(word: str) -> list:
187
- return _word_pid_map.get(word.lower().strip(), [])
 
 
 
188
 
189
  def get_random_pids_for_word(word: str, count: int = 2) -> list:
 
190
  pids = get_pids_for_word(word)
191
- if not pids: return []
192
- if len(pids) <= count: return pids
 
 
193
  return random.sample(pids, count)
194
 
 
195
  def get_example_words_with_pids(count: int = 3) -> list:
 
196
  examples = []
197
  preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
 
198
  for word in preferred:
199
  pids = get_pids_for_word(word)
200
  if pids:
201
  examples.append((word, pids[0]))
202
- if len(examples) >= count: break
 
203
 
204
  if len(examples) < count:
205
  available = [w for w in _word_pid_map.keys() if w not in [e[0] for e in examples]]
206
- if available:
207
- random.shuffle(available)
208
- for word in available[:count - len(examples)]:
209
- pids = _word_pid_map[word]
210
- examples.append((word, pids[0]))
211
  return examples
212
 
213
  # =====================================================================
@@ -224,117 +196,127 @@ class MotionGPT_VQVAE_Wrapper(torch.nn.Module):
224
  )
225
 
226
  # =====================================================================
227
- # Model Loading
228
  # =====================================================================
229
  def load_llm_model():
230
  print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
231
- # Use environment token if available for private repos
232
- token = os.environ.get("HF_TOKEN")
233
- try:
234
- tokenizer = AutoTokenizer.from_pretrained(HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token)
235
- model = AutoModelForCausalLM.from_pretrained(
236
- HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True,
237
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
238
- token=token
239
- )
240
- if tokenizer.pad_token is None:
241
- tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
242
- model.resize_token_embeddings(len(tokenizer))
243
- model.config.pad_token_id = tokenizer.pad_token_id
244
- model.to(DEVICE)
245
- model.eval()
246
- print(f"LLM loaded (vocab size: {len(tokenizer)})")
247
- return model, tokenizer
248
- except Exception as e:
249
- print(f"Error loading LLM: {e}")
250
- return None, None
251
 
252
  def load_vqvae_model():
253
  if not os.path.exists(VQVAE_CHECKPOINT):
254
- print(f"VQ-VAE checkpoint not found at {VQVAE_CHECKPOINT}")
255
  return None
256
  print(f"Loading VQ-VAE from: {VQVAE_CHECKPOINT}")
257
- try:
258
- model = MotionGPT_VQVAE_Wrapper(smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE, code_dim=CODE_DIM, **VQ_ARGS).to(DEVICE)
259
- ckpt = torch.load(VQVAE_CHECKPOINT, map_location=DEVICE, weights_only=False) # Removed weights_only=False for compatibility, add back if torch version requires
260
- state_dict = ckpt.get('model_state_dict', ckpt)
261
- model.load_state_dict(state_dict, strict=False)
262
- model.eval()
263
- return model
264
- except Exception as e:
265
- print(f"Error loading VQVAE: {e}")
266
- return None
267
 
268
  def load_stats():
269
  if not os.path.exists(STATS_PATH):
270
  return None, None
271
- try:
272
- st = torch.load(STATS_PATH, map_location='cpu')
273
- mean, std = st.get('mean', 0), st.get('std', 1)
274
- if torch.is_tensor(mean): mean = mean.cpu().numpy()
275
- if torch.is_tensor(std): std = std.cpu().numpy()
276
- return mean, std
277
- except Exception as e:
278
- print(f"Error loading stats: {e}")
279
- return None, None
280
 
281
  def load_smplx_model():
282
  if not os.path.exists(SMPLX_MODEL_DIR):
283
  print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
284
  return None
285
  print(f"Loading SMPL-X from: {SMPLX_MODEL_DIR}")
286
- try:
287
- model = smplx.SMPLX(
288
- model_path=SMPLX_MODEL_DIR, model_type='smplx', gender='neutral', use_pca=False,
289
- create_global_orient=True, create_body_pose=True, create_betas=True,
290
- create_expression=True, create_jaw_pose=True, create_left_hand_pose=True,
291
- create_right_hand_pose=True, create_transl=True
292
- ).to(DEVICE)
293
- return model
294
- except Exception as e:
295
- print(f"Error loading SMPL-X: {e}")
296
- return None
297
 
298
  def initialize_models():
299
  global _model_cache
300
- if _model_cache["initialized"]: return
 
301
 
302
- print("Initializing Models...")
303
- # Download assets first
304
- download_artifacts()
305
 
 
306
  load_word_pid_mapping()
 
307
  _model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()
308
- _model_cache["vqvae_model"] = load_vqvae_model()
309
- _model_cache["stats"] = load_stats()
310
- _model_cache["smplx_model"] = load_smplx_model()
 
 
 
 
311
 
312
  _model_cache["initialized"] = True
313
- print("Initialization complete.")
 
 
314
 
315
  def precompute_examples():
 
316
  global _example_cache
317
- if not _model_cache["initialized"]: return
 
 
318
 
319
  examples = get_example_words_with_pids(3)
320
- if not examples: return
321
-
322
- print(f"Pre-computing {len(examples)} examples...")
323
  for word, pid in examples:
324
  key = f"{word}_{pid}"
 
325
  try:
326
  html, tokens = generate_animation_for_word(word, pid, upper_body_only=True)
327
  _example_cache[key] = {"html": html, "tokens": tokens, "word": word, "pid": pid}
 
328
  except Exception as e:
329
- print(f"Failed pre-compute {word}: {e}")
 
 
 
330
 
331
  # =====================================================================
332
- # Motion Generation & Visualization Logic (Kept largely the same)
333
  # =====================================================================
334
  def generate_motion_tokens(word: str, variant: str) -> str:
335
  model = _model_cache["llm_model"]
336
  tokenizer = _model_cache["llm_tokenizer"]
337
- if model is None: return "Error: LLM not loaded."
 
 
338
 
339
  prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
340
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
@@ -348,269 +330,587 @@ def generate_motion_tokens(word: str, variant: str) -> str:
348
  eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
349
  early_stopping=True
350
  )
 
351
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
352
  motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
353
  return motion_part.strip()
354
 
 
355
  def parse_motion_tokens(token_str: str) -> list:
356
- if isinstance(token_str, str):
357
- matches = re.findall(r'<M(\d+)>', token_str)
358
- if not matches: matches = re.findall(r'<motion_(\d+)>', token_str)
359
- if matches: return [int(x) for x in matches]
 
 
 
 
 
 
 
 
 
360
  return []
361
 
 
362
  def decode_tokens_to_params(tokens: list) -> np.ndarray:
363
  vqvae_model = _model_cache["vqvae_model"]
364
  mean, std = _model_cache["stats"]
365
- if vqvae_model is None or not tokens: return np.zeros((0, SMPL_DIM), dtype=np.float32)
 
 
366
 
367
  idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
368
- with torch.no_grad():
369
- quantizer = vqvae_model.vqvae.quantizer
370
- if hasattr(quantizer, "codebook"):
371
- codebook = quantizer.codebook.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  emb = codebook[idx]
373
  x_quantized = emb.permute(0, 2, 1).contiguous()
374
- else:
375
- # Fallback if specific quantizer logic fails
376
- return np.zeros((0, SMPL_DIM), dtype=np.float32)
377
-
378
  x_dec = vqvae_model.vqvae.decoder(x_quantized)
379
  smpl_out = vqvae_model.vqvae.postprocess(x_dec)
380
  params_np = smpl_out.squeeze(0).cpu().numpy()
381
-
382
- if mean is not None and std is not None:
383
  params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)
 
384
  return params_np
385
 
 
386
  def params_to_vertices(params_seq: np.ndarray) -> tuple:
387
  smplx_model = _model_cache["smplx_model"]
388
- if smplx_model is None: return None, None
 
389
 
390
  starts = np.cumsum([0] + PARAM_DIMS[:-1])
391
  ends = starts + np.array(PARAM_DIMS)
392
  T = params_seq.shape[0]
393
  all_verts = []
394
-
395
- # Process in chunks to avoid memory issues on CPU spaces
396
- batch_size = 10
397
 
398
  with torch.no_grad():
399
  for s in range(0, T, batch_size):
400
  batch = params_seq[s:s+batch_size]
 
 
401
  np_parts = {name: batch[:, st:ed].astype(np.float32) for name, st, ed in zip(PARAM_NAMES, starts, ends)}
402
  tensor_parts = {name: torch.from_numpy(arr).to(DEVICE) for name, arr in np_parts.items()}
403
 
404
- # Simple handling for body pose/orient split
405
  body_t = tensor_parts['body_pose']
406
- # Assumption: Model output matches SMPL-X expectations.
407
- # Simplified logic for demo stability:
408
- global_orient = body_t[:, :3].contiguous()
409
- body_pose_only = body_t[:, 3:66].contiguous() # Trim to standard 63 if needed, or keep dynamic
410
-
411
- try:
412
- out = smplx_model(
413
- betas=tensor_parts['betas'], global_orient=global_orient, body_pose=body_pose_only,
414
- left_hand_pose=tensor_parts['left_hand_pose'], right_hand_pose=tensor_parts['right_hand_pose'],
415
- expression=tensor_parts['expression'], jaw_pose=tensor_parts['jaw_pose'],
416
- leye_pose=tensor_parts['eye_pose'], reye_pose=tensor_parts['eye_pose'],
417
- transl=tensor_parts['trans'], return_verts=True
418
- )
419
- all_verts.append(out.vertices.detach().cpu().numpy())
420
- except Exception as e:
421
- print(f"SMPL-X Forward pass error: {e}")
422
- return None, None
423
-
424
- if not all_verts: return None, None
 
 
 
 
 
 
 
 
425
  return np.concatenate(all_verts, axis=0), smplx_model.faces.astype(np.int32)
426
 
427
- def compute_upper_body_bounds(verts):
428
- if verts is None: return None
 
 
 
 
429
  v = verts[0]
430
  y_min, y_max = v[:, 1].min(), v[:, 1].max()
431
  x_min, x_max = v[:, 0].min(), v[:, 0].max()
432
  z_min, z_max = v[:, 2].min(), v[:, 2].max()
 
433
  body_height = y_max - y_min
434
  waist_y = y_min + body_height * 0.45
 
 
 
 
435
 
436
- # Add margins
437
  return {
438
- 'y_range': [waist_y, y_max + 0.1],
439
- 'x_range': [x_min - 0.2, x_max + 0.2],
440
- 'z_range': [z_min - 0.2, z_max + 0.2],
441
- 'center': [(x_min + x_max)/2, (waist_y + y_max)/2, (z_min + z_max)/2]
 
 
442
  }
443
 
444
  # =====================================================================
445
- # HTML Generation
446
  # =====================================================================
447
- def create_animation_html(verts, faces, upper_body_only=True, title=""):
448
- if verts is None: return create_error_html("Model generation failed.")
 
 
 
449
 
450
- T = verts.shape[0]
451
  i, j, k = faces.T.tolist()
 
 
452
  bounds = compute_upper_body_bounds(verts) if upper_body_only else None
453
 
454
- mesh = go.Mesh3d(x=verts[0,:,0], y=verts[0,:,1], z=verts[0,:,2], i=i, j=j, k=k,
455
- color='#6FA8DC', opacity=0.8, flatshading=True)
 
 
 
 
456
 
457
- frames = [go.Frame(data=[go.Mesh3d(x=verts[t,:,0], y=verts[t,:,1], z=verts[t,:,2], i=i, j=j, k=k)], name=str(t)) for t in range(T)]
 
 
 
 
 
 
 
 
 
 
 
458
 
459
- scene_cfg = dict(aspectmode='data', xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False))
460
- if bounds:
461
- scene_cfg.update(dict(
462
- xaxis=dict(range=bounds['x_range'], visible=False),
463
- yaxis=dict(range=bounds['y_range'], visible=False),
464
- zaxis=dict(range=bounds['z_range'], visible=False),
465
- aspectmode='manual', aspectratio=dict(x=1, y=1, z=1),
466
- camera=dict(eye=dict(x=0, y=0.5, z=2.0))
467
- ))
468
-
469
  fig = go.Figure(data=[mesh], frames=frames)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
  fig.update_layout(
471
- title=title, scene=scene_cfg, height=500, margin=dict(l=0, r=0, t=30, b=0),
472
- updatemenus=[dict(type="buttons", buttons=[dict(label="Play", method="animate", args=[None, {"frame": {"duration": 50}}])])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  )
474
- return fig.to_html(include_plotlyjs='cdn', full_html=True)
475
 
476
- def create_side_by_side_html(verts1, faces1, verts2, faces2, title1="", title2=""):
477
- if verts1 is None or verts2 is None: return create_error_html("One or both models failed.")
 
 
 
 
478
  T = min(verts1.shape[0], verts2.shape[0])
479
  verts1, verts2 = verts1[:T], verts2[:T]
 
480
  i1, j1, k1 = faces1.T.tolist()
481
  i2, j2, k2 = faces2.T.tolist()
 
482
 
483
- fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scene'}, {'type': 'scene'}]], subplot_titles=[title1, title2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
- fig.add_trace(go.Mesh3d(x=verts1[0,:,0], y=verts1[0,:,1], z=verts1[0,:,2], i=i1, j=j1, k1=k1, color='#6FA8DC'), row=1, col=1)
486
- fig.add_trace(go.Mesh3d(x=verts2[0,:,0], y=verts2[0,:,1], z=verts2[0,:,2], i=i2, j=j2, k2=k2, color='#93C47D'), row=1, col=2)
487
 
488
  frames = []
489
  for t in range(T):
490
- frames.append(go.Frame(data=[
491
- go.Mesh3d(x=verts1[t,:,0], y=verts1[t,:,1], z=verts1[t,:,2], i=i1, j=j1, k=k1),
492
- go.Mesh3d(x=verts2[t,:,0], y=verts2[t,:,1], z=verts2[t,:,2], i=i2, j=j2, k=k2)
493
- ], name=str(t)))
494
-
 
 
 
 
 
 
495
  fig.frames = frames
496
 
497
- # Generic simple camera
498
- cam = dict(eye=dict(x=0, y=0, z=2.2), up=dict(x=0, y=1, z=0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  fig.update_layout(
500
- scene=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), camera=cam, aspectmode='data'),
501
- scene2=dict(xaxis=dict(visible=False), yaxis=dict(visible=False), zaxis=dict(visible=False), camera=cam, aspectmode='data'),
502
- height=500, margin=dict(l=0, r=0, t=30, b=0),
503
- updatemenus=[dict(type="buttons", buttons=[dict(label="Play", method="animate", args=[None, {"frame": {"duration": 50}}])])]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  )
505
- return fig.to_html(include_plotlyjs='cdn', full_html=True)
506
 
507
- def create_iframe_html(html_content):
508
- escaped = html_module.escape(html_content)
509
- return f'<iframe srcdoc="{escaped}" style="width: 100%; height: 520px; border: none;"></iframe>'
510
 
511
- def create_error_html(msg):
512
- return f'<div style="text-align:center; padding:50px;">{msg}</div>'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
- def create_placeholder_html():
515
- return '<div style="text-align:center; padding:50px; color:#666;">Enter a word to generate animation</div>'
 
 
 
 
 
516
 
517
  # =====================================================================
518
- # Main Generators
519
  # =====================================================================
520
- def generate_verts_for_word(word, pid):
521
- gen_tokens = generate_motion_tokens(word, pid)
522
- ids = parse_motion_tokens(gen_tokens)
523
- if not ids: return None, None, gen_tokens
524
- params = decode_tokens_to_params(ids)
 
 
 
 
 
 
 
 
 
 
525
  verts, faces = params_to_vertices(params)
526
- return verts, faces, gen_tokens
 
527
 
528
- def generate_animation_for_word(word, pid, upper_body_only=True):
 
529
  verts, faces, tokens = generate_verts_for_word(word, pid)
530
- html = create_animation_html(verts, faces, upper_body_only, title=pid)
531
- return html, tokens
 
 
 
 
 
532
 
533
- def process_word(word):
534
- if not _model_cache["initialized"]: initialize_models()
 
 
535
 
536
  word = word.strip().lower()
 
537
  pids = get_random_pids_for_word(word, 2)
538
 
539
  if not pids:
540
- return create_iframe_html(create_error_html(f"Word '{word}' not found in dataset.")), ""
541
 
542
- if len(pids) == 1: pids = [pids[0], pids[0]]
 
543
 
544
  try:
545
- verts1, faces1, tok1 = generate_verts_for_word(word, pids[0])
546
- verts2, faces2, tok2 = generate_verts_for_word(word, pids[1])
547
 
548
  if verts1 is None and verts2 is None:
549
- return create_iframe_html(create_error_html("Motion generation failed.")), f"{tok1}\n{tok2}"
550
 
551
- # If one fails, show single
552
- if verts1 is None: return create_iframe_html(create_animation_html(verts2, faces2, title=pids[1])), tok2
553
- if verts2 is None: return create_iframe_html(create_animation_html(verts1, faces1, title=pids[0])), tok1
554
-
555
- html = create_side_by_side_html(verts1, faces1, verts2, faces2, title1=pids[0], title2=pids[1])
556
- return create_iframe_html(html), f"[{pids[0]}] {tok1}\n\n[{pids[1]}] {tok2}"
 
 
 
 
 
557
 
558
  except Exception as e:
559
- return create_iframe_html(create_error_html(f"Error: {str(e)}")), ""
560
 
561
- def get_example(word, pid):
562
- if not _model_cache["initialized"]: initialize_models()
 
563
  key = f"{word}_{pid}"
564
  if key in _example_cache:
565
- return create_iframe_html(_example_cache[key]["html"]), _example_cache[key]["tokens"]
566
- # Generate on fly if cache miss
567
- html, tok = generate_animation_for_word(word, pid)
568
- return create_iframe_html(html), tok
569
 
570
  # =====================================================================
571
- # App Launch
572
  # =====================================================================
573
- def create_ui():
574
- initialize_models()
575
- precompute_examples()
 
 
 
 
 
 
 
576
 
577
- with gr.Blocks(title="SignMotionGPT", theme=gr.themes.Default()) as demo:
 
578
  gr.Markdown("# SignMotionGPT Demo")
579
- gr.Markdown("Input a word to generate sign language motion.")
580
-
581
  with gr.Row():
582
- with gr.Column(scale=1):
583
- txt_input = gr.Textbox(label="Word", placeholder="e.g. hello, help, computer")
584
- btn = gr.Button("Generate", variant="primary")
585
- txt_out = gr.Textbox(label="Generated Tokens", lines=5)
 
 
 
 
 
 
586
 
587
- # Examples
588
- if _example_cache:
589
- gr.Markdown("### Examples")
590
- for k, v in _example_cache.items():
591
- gr.Button(f"{v['word']} ({v['pid']})").click(
592
- fn=lambda w=v['word'], p=v['pid']: get_example(w, p),
593
- outputs=[gr.HTML(), txt_out] # Hack: we need to target the main output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  )
595
- # To keep UI simple, I'll just skip complex example buttons in this condensed version
596
- # and rely on the user typing.
597
-
598
- with gr.Column(scale=2):
599
- html_out = gr.HTML(label="Visual", value=create_iframe_html(create_placeholder_html()))
600
-
601
- # Wire up
602
- btn.click(process_word, inputs=[txt_input], outputs=[html_out, txt_out])
603
- txt_input.submit(process_word, inputs=[txt_input], outputs=[html_out, txt_out])
 
 
 
 
 
 
604
 
 
 
 
 
 
 
605
  return demo
606
 
607
- if __name__ == "__main__":
608
- # Initialize immediately on startup to fail fast if files missing
609
- try:
610
- initialize_models()
611
- except Exception as e:
612
- print(f"Startup initialization warning: {e}")
 
 
 
 
 
 
 
 
 
 
 
613
 
614
- demo = create_ui()
615
- # In Spaces, simply use .launch() without arguments
616
- demo.launch()
 
 
 
1
+ """
2
+ SignMotionGPT - HuggingFace Spaces Demo
3
+ Text-to-Sign Language Motion Generation
4
+ """
5
  import os
6
  import sys
7
  import re
8
  import json
9
  import random
 
10
  import warnings
11
  import html as html_module
 
 
12
 
13
  import torch
14
  import numpy as np
 
 
 
 
 
 
 
 
 
15
 
16
  warnings.filterwarnings("ignore")
17
 
18
  # =====================================================================
19
+ # Configuration for HuggingFace Spaces
20
  # =====================================================================
21
+ WORK_DIR = os.getcwd()
 
 
 
 
 
22
  DATA_DIR = os.path.join(WORK_DIR, "data")
23
  os.makedirs(DATA_DIR, exist_ok=True)
24
 
 
28
  STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
29
  SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
30
 
31
+ # HuggingFace model config
32
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
33
+ HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
34
+
35
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
37
+ # Generation parameters
38
  M_START = "<M_START>"
39
  M_END = "<M_END>"
40
  PAD_TOKEN = "<PAD>"
 
 
41
  INFERENCE_TEMPERATURE = 0.7
42
  INFERENCE_TOP_K = 50
43
  INFERENCE_REPETITION_PENALTY = 1.2
44
 
45
+ # VQ-VAE parameters
46
  SMPL_DIM = 182
47
  CODEBOOK_SIZE = 512
48
  CODE_DIM = 512
 
56
  "trans", "expression", "jaw_pose", "eye_pose"]
57
 
58
  # =====================================================================
59
+ # Install/Import Dependencies
60
  # =====================================================================
61
+ try:
62
+ import gradio as gr
63
+ except ImportError:
64
+ os.system("pip install -q gradio>=4.0.0")
65
+ import gradio as gr
 
 
66
 
67
+ try:
68
+ import plotly.graph_objects as go
69
+ from plotly.subplots import make_subplots
70
+ except ImportError:
71
+ os.system("pip install -q plotly>=5.18.0")
72
+ import plotly.graph_objects as go
73
+ from plotly.subplots import make_subplots
 
74
 
75
+ try:
76
+ import smplx
77
+ except ImportError:
78
+ os.system("pip install -q smplx==0.1.28")
79
+ import smplx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ from transformers import AutoModelForCausalLM, AutoTokenizer
82
 
83
  # =====================================================================
84
  # Import VQ-VAE architecture
85
  # =====================================================================
86
+ # Add parent directory to path for mGPT imports
87
+ current_dir = os.path.dirname(os.path.abspath(__file__))
88
+ parent_dir = os.path.dirname(current_dir)
89
+ if parent_dir not in sys.path:
90
+ sys.path.insert(0, parent_dir)
91
+ if current_dir not in sys.path:
92
+ sys.path.insert(0, current_dir)
93
 
94
  try:
 
95
  from mGPT.archs.mgpt_vq import VQVae
96
  except ImportError as e:
97
+ print(f"Warning: Could not import VQVae: {e}")
98
  VQVae = None
99
 
100
  # =====================================================================
 
109
  "initialized": False
110
  }
111
 
112
+ _word_pid_map = {} # word -> list of valid PIDs
113
+ _example_cache = {} # Pre-computed example animations
114
 
115
  # =====================================================================
116
+ # Dataset Loading - Word to PID mapping
117
  # =====================================================================
118
  def load_word_pid_mapping():
119
+ """Load the dataset and build word -> PIDs mapping."""
120
  global _word_pid_map
121
+
122
  if not os.path.exists(DATASET_PATH):
123
  print(f"Dataset not found: {DATASET_PATH}")
124
  return
 
136
  _word_pid_map[word] = set()
137
  _word_pid_map[word].add(pid)
138
 
139
+ # Convert sets to sorted lists
140
  for word in _word_pid_map:
141
  _word_pid_map[word] = sorted(list(_word_pid_map[word]))
142
+
143
  print(f"Loaded {len(_word_pid_map)} unique words from dataset")
144
  except Exception as e:
145
  print(f"Error loading dataset: {e}")
146
 
147
+
148
  def get_pids_for_word(word: str) -> list:
149
+ """Get valid PIDs for a word from the dataset."""
150
+ word = word.lower().strip()
151
+ return _word_pid_map.get(word, [])
152
+
153
 
154
  def get_random_pids_for_word(word: str, count: int = 2) -> list:
155
+ """Get random PIDs for a word. Returns up to 'count' PIDs."""
156
  pids = get_pids_for_word(word)
157
+ if not pids:
158
+ return []
159
+ if len(pids) <= count:
160
+ return pids
161
  return random.sample(pids, count)
162
 
163
+
164
  def get_example_words_with_pids(count: int = 3) -> list:
165
+ """Get example words with valid PIDs from dataset."""
166
  examples = []
167
  preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
168
+
169
  for word in preferred:
170
  pids = get_pids_for_word(word)
171
  if pids:
172
  examples.append((word, pids[0]))
173
+ if len(examples) >= count:
174
+ break
175
 
176
  if len(examples) < count:
177
  available = [w for w in _word_pid_map.keys() if w not in [e[0] for e in examples]]
178
+ random.shuffle(available)
179
+ for word in available[:count - len(examples)]:
180
+ pids = _word_pid_map[word]
181
+ examples.append((word, pids[0]))
182
+
183
  return examples
184
 
185
  # =====================================================================
 
196
  )
197
 
198
  # =====================================================================
199
+ # Model Loading Functions
200
  # =====================================================================
201
  def load_llm_model():
202
  print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
203
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
204
+
205
+ tokenizer = AutoTokenizer.from_pretrained(
206
+ HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token
207
+ )
208
+ model = AutoModelForCausalLM.from_pretrained(
209
+ HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token,
210
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
211
+ )
212
+ if tokenizer.pad_token is None:
213
+ tokenizer.add_special_tokens({"pad_token": PAD_TOKEN})
214
+ model.resize_token_embeddings(len(tokenizer))
215
+ model.config.pad_token_id = tokenizer.pad_token_id
216
+ model.to(DEVICE)
217
+ model.eval()
218
+ print(f"LLM loaded (vocab size: {len(tokenizer)})")
219
+ return model, tokenizer
220
+
 
 
221
 
222
def load_vqvae_model():
    """Load the motion VQ-VAE from its local checkpoint, or None if absent."""
    if not os.path.exists(VQVAE_CHECKPOINT):
        print(f"VQ-VAE checkpoint not found: {VQVAE_CHECKPOINT}")
        return None

    print(f"Loading VQ-VAE from: {VQVAE_CHECKPOINT}")
    vqvae = MotionGPT_VQVAE_Wrapper(
        smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE, code_dim=CODE_DIM, **VQ_ARGS
    ).to(DEVICE)

    checkpoint = torch.load(VQVAE_CHECKPOINT, map_location=DEVICE, weights_only=False)
    # The checkpoint may wrap weights under 'model_state_dict' or be a raw state dict.
    weights = checkpoint.get('model_state_dict', checkpoint)
    # strict=False tolerates extra/missing keys between checkpoint and model.
    vqvae.load_state_dict(weights, strict=False)
    vqvae.eval()
    print("VQ-VAE loaded")
    return vqvae
234
+
 
 
235
 
236
def load_stats():
    """Return (mean, std) normalization stats as numpy values, or (None, None)."""
    if not os.path.exists(STATS_PATH):
        return None, None

    stats = torch.load(STATS_PATH, map_location='cpu', weights_only=False)

    def _as_numpy(value):
        # Stats may be stored as tensors or as plain scalars.
        return value.cpu().numpy() if torch.is_tensor(value) else value

    return _as_numpy(stats.get('mean', 0)), _as_numpy(stats.get('std', 1))
244
+
 
 
 
245
 
246
def load_smplx_model():
    """Instantiate a neutral SMPL-X body model from the local model directory."""
    if not os.path.exists(SMPLX_MODEL_DIR):
        print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
        return None

    print(f"Loading SMPL-X from: {SMPLX_MODEL_DIR}")
    body_model = smplx.SMPLX(
        model_path=SMPLX_MODEL_DIR,
        model_type='smplx',
        gender='neutral',
        use_pca=False,  # full per-joint hand poses instead of PCA components
        create_global_orient=True,
        create_body_pose=True,
        create_betas=True,
        create_expression=True,
        create_jaw_pose=True,
        create_left_hand_pose=True,
        create_right_hand_pose=True,
        create_transl=True,
    ).to(DEVICE)
    print("SMPL-X loaded")
    return body_model
259
+
 
 
260
 
261
def initialize_models():
    """Populate the global model cache once; subsequent calls are no-ops."""
    global _model_cache
    if _model_cache["initialized"]:
        return

    print("\n" + "=" * 60)
    print(" Initializing SignMotionGPT Models")
    print("=" * 60)

    # Word -> PID lookup comes from the dataset file.
    load_word_pid_mapping()

    _model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()

    # Visualization components are optional: token generation still works
    # without them, so failures here are logged rather than raised.
    try:
        _model_cache["vqvae_model"] = load_vqvae_model()
        _model_cache["stats"] = load_stats()
        _model_cache["smplx_model"] = load_smplx_model()
    except Exception as exc:
        print(f"Could not load visualization models: {exc}")

    _model_cache["initialized"] = True
    print("All models initialized")
    print("=" * 60)
+
286
 
287
def precompute_examples():
    """Pre-compute animations for example words at startup."""
    global _example_cache
    if not _model_cache["initialized"]:
        return

    examples = get_example_words_with_pids(3)
    print(f"\nPre-computing {len(examples)} example animations...")

    for word, pid in examples:
        cache_key = f"{word}_{pid}"
        print(f" Computing: {word} ({pid})...")
        try:
            html, tokens = generate_animation_for_word(word, pid, upper_body_only=True)
            _example_cache[cache_key] = {"html": html, "tokens": tokens, "word": word, "pid": pid}
            print(f" Done: {word}")
        except Exception as exc:
            # Cache an error placeholder so the corresponding UI row still renders.
            print(f" Failed: {word} - {exc}")
            _example_cache[cache_key] = {"html": create_error_html(), "tokens": "", "word": word, "pid": pid}

    print("Example pre-computation complete\n")
310
 
311
  # =====================================================================
312
+ # Motion Generation Functions
313
  # =====================================================================
314
  def generate_motion_tokens(word: str, variant: str) -> str:
315
  model = _model_cache["llm_model"]
316
  tokenizer = _model_cache["llm_tokenizer"]
317
+
318
+ if model is None or tokenizer is None:
319
+ raise RuntimeError("LLM model not loaded")
320
 
321
  prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
322
  inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
 
330
  eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
331
  early_stopping=True
332
  )
333
+
334
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
335
  motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
336
  return motion_part.strip()
337
 
338
+
339
def parse_motion_tokens(token_str: str) -> list:
    """Extract integer motion-token ids from generated LLM output.

    Already-numeric sequences (list/tuple/ndarray) are passed through as a
    list of ints. Strings are scanned for the '<M123>' format first, then
    the legacy '<motion_123>' format. Anything else yields [].
    """
    if isinstance(token_str, (list, tuple, np.ndarray)):
        return [int(value) for value in token_str]
    if not isinstance(token_str, str):
        return []

    # First matching format wins; both are never mixed in one result.
    for pattern in (r'<M(\d+)>', r'<motion_(\d+)>'):
        found = re.findall(pattern, token_str)
        if found:
            return [int(value) for value in found]

    return []
354
 
355
+
356
def decode_tokens_to_params(tokens: list) -> np.ndarray:
    """Decode discrete motion-token ids into a (T, SMPL_DIM) parameter sequence.

    Dequantizes `tokens` into code embeddings, runs the VQ-VAE decoder and
    postprocess, then de-normalizes with the cached dataset mean/std.
    Returns an empty (0, SMPL_DIM) float32 array when the VQ-VAE is missing,
    `tokens` is empty, or no dequantization path is available.
    """
    vqvae_model = _model_cache["vqvae_model"]
    mean, std = _model_cache["stats"]

    if vqvae_model is None or not tokens:
        return np.zeros((0, SMPL_DIM), dtype=np.float32)

    # Batch of one token sequence: shape (1, T).
    idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
    T_q = idx.shape[1]
    quantizer = vqvae_model.vqvae.quantizer

    # Read code_dim from the actual codebook when possible so the shape
    # checks below match the loaded checkpoint; otherwise use the configured value.
    if hasattr(quantizer, "codebook"):
        codebook = quantizer.codebook.to(DEVICE)  # assumed (num_codes, code_dim) — TODO confirm
        code_dim = codebook.shape[1]
    else:
        code_dim = CODE_DIM

    # Preferred path: let the quantizer dequantize the indices itself. The
    # decoder expects channels-first (1, code_dim, T), so a (1, T, code_dim)
    # result is permuted. Any failure falls through to the manual lookup.
    # NOTE(review): if T_q == code_dim the two layouts are indistinguishable
    # and the channels-first interpretation wins.
    x_quantized = None
    if hasattr(quantizer, "dequantize"):
        try:
            with torch.no_grad():
                dq = quantizer.dequantize(idx)
                if dq is not None:
                    dq = dq.contiguous()
                    if dq.ndim == 3 and dq.shape[1] == code_dim:
                        x_quantized = dq
                    elif dq.ndim == 3 and dq.shape[1] == T_q:
                        x_quantized = dq.permute(0, 2, 1).contiguous()
        except Exception:
            pass

    # Fallback: direct codebook lookup, (1, T, code_dim) -> (1, code_dim, T).
    if x_quantized is None:
        if not hasattr(quantizer, "codebook"):
            return np.zeros((0, SMPL_DIM), dtype=np.float32)
        with torch.no_grad():
            emb = codebook[idx]
            x_quantized = emb.permute(0, 2, 1).contiguous()

    with torch.no_grad():
        x_dec = vqvae_model.vqvae.decoder(x_quantized)
        smpl_out = vqvae_model.vqvae.postprocess(x_dec)
        params_np = smpl_out.squeeze(0).cpu().numpy()

    # De-normalize back to raw SMPL-X parameter space.
    if (mean is not None) and (std is not None):
        params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)

    return params_np
403
 
404
+
405
def params_to_vertices(params_seq: np.ndarray) -> tuple:
    """Run a (T, SMPL_DIM) parameter sequence through SMPL-X.

    Splits each frame into the named parameter groups (PARAM_NAMES /
    PARAM_DIMS), forwards them through the SMPL-X model in batches of 32,
    and returns (vertices, faces) where vertices has shape (T, V, 3).
    Returns (None, None) when SMPL-X is unavailable or the sequence is empty.
    """
    smplx_model = _model_cache["smplx_model"]
    if smplx_model is None or params_seq.shape[0] == 0:
        return None, None

    # Per-group [start, end) column offsets into the flat parameter vector.
    starts = np.cumsum([0] + PARAM_DIMS[:-1])
    ends = starts + np.array(PARAM_DIMS)
    T = params_seq.shape[0]
    all_verts = []
    batch_size = 32
    num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)

    with torch.no_grad():
        for s in range(0, T, batch_size):
            batch = params_seq[s:s+batch_size]
            B = batch.shape[0]

            np_parts = {name: batch[:, st:ed].astype(np.float32) for name, st, ed in zip(PARAM_NAMES, starts, ends)}
            tensor_parts = {name: torch.from_numpy(arr).to(DEVICE) for name, arr in np_parts.items()}

            # The stored 'body_pose' block may or may not include the 3-dim
            # global orientation as its first joint; disambiguate by width.
            body_t = tensor_parts['body_pose']
            L_body = body_t.shape[1]
            expected_no_go = num_body_joints * 3
            expected_with_go = (num_body_joints + 1) * 3

            if L_body == expected_with_go:
                # Orientation packed in front of the body joints.
                global_orient = body_t[:, :3].contiguous()
                body_pose_only = body_t[:, 3:].contiguous()
            elif L_body == expected_no_go:
                # Body joints only; use identity orientation.
                global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
                body_pose_only = body_t
            else:
                # Unexpected width: best-effort split or zero-pad to the
                # joint count SMPL-X expects.
                if L_body > expected_no_go:
                    global_orient = body_t[:, :3].contiguous()
                    body_pose_only = body_t[:, 3:].contiguous()
                else:
                    body_pose_only = torch.nn.functional.pad(body_t, (0, max(0, expected_no_go - L_body)))
                    global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)

            # Both eyes share the single stored 'eye_pose' block.
            out = smplx_model(
                betas=tensor_parts['betas'], global_orient=global_orient, body_pose=body_pose_only,
                left_hand_pose=tensor_parts['left_hand_pose'], right_hand_pose=tensor_parts['right_hand_pose'],
                expression=tensor_parts['expression'], jaw_pose=tensor_parts['jaw_pose'],
                leye_pose=tensor_parts['eye_pose'], reye_pose=tensor_parts['eye_pose'],
                transl=tensor_parts['trans'], return_verts=True
            )
            all_verts.append(out.vertices.detach().cpu().numpy())

    return np.concatenate(all_verts, axis=0), smplx_model.faces.astype(np.int32)
454
 
455
+
456
def compute_upper_body_bounds(verts: np.ndarray) -> dict:
    """Compute bounds for upper body view. SMPL-X: Y is up, Z is forward.

    Uses the first frame only to derive the waist line (45% of body height)
    and padded x/z extents, returning ranges suitable for fixing a Plotly
    scene on the upper body. Returns None when `verts` is None or empty.
    NOTE(review): later frames may extend past these frame-0 bounds.
    """
    if verts is None or verts.shape[0] == 0:
        return None

    first_frame = verts[0]
    x_lo, x_hi = first_frame[:, 0].min(), first_frame[:, 0].max()
    y_lo, y_hi = first_frame[:, 1].min(), first_frame[:, 1].max()
    z_lo, z_hi = first_frame[:, 2].min(), first_frame[:, 2].max()

    height = y_hi - y_lo
    waist = y_lo + height * 0.45
    mid_upper = (waist + y_hi) / 2

    pad_x = (x_hi - x_lo) * 0.15
    pad_z = (z_hi - z_lo) * 0.15

    return {
        'waist_y': waist,
        'upper_center_y': mid_upper,
        'y_range': [waist - height * 0.05, y_hi + height * 0.05],
        'x_range': [x_lo - pad_x, x_hi + pad_x],
        'z_range': [z_lo - pad_z, z_hi + pad_z],
        'center': [(x_lo + x_hi) / 2, mid_upper, (z_lo + z_hi) / 2],
    }
481
 
482
  # =====================================================================
483
+ # Visualization Functions
484
  # =====================================================================
485
def create_animation_html(verts: np.ndarray, faces: np.ndarray, fps: int = 20,
                          upper_body_only: bool = True, title: str = "") -> str:
    """Create Plotly animation HTML.

    Builds a single-scene 3D mesh animation (one go.Frame per timestep) with
    Play/Pause/Reset buttons and a frame slider, and returns it as a full,
    self-contained HTML document (plotly.js from CDN). Falls back to the
    static placeholder when inputs are missing or empty.
    """
    if verts is None or faces is None or verts.shape[0] == 0:
        return create_placeholder_html()

    T, V, _ = verts.shape
    # Plotly Mesh3d wants the triangle index columns as separate lists.
    i, j, k = faces.T.tolist()
    frame_duration = 1000 // fps  # ms per frame

    bounds = compute_upper_body_bounds(verts) if upper_body_only else None

    # Initial mesh (frame 0).
    mesh = go.Mesh3d(
        x=verts[0, :, 0], y=verts[0, :, 1], z=verts[0, :, 2],
        i=i, j=j, k=k, flatshading=True, opacity=0.6,
        color='#6FA8DC',
        lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2)
    )

    # One frame per timestep; frame names are the string index so the
    # slider/buttons can address them.
    frames = [
        go.Frame(
            data=[go.Mesh3d(
                x=verts[t, :, 0], y=verts[t, :, 1], z=verts[t, :, 2],
                i=i, j=j, k=k, flatshading=True, opacity=0.6,
                color='#6FA8DC',
                lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2)
            )],
            name=str(t)
        )
        for t in range(T)
    ]

    fig = go.Figure(data=[mesh], frames=frames)

    # Frame slider; labels only every 10th frame to avoid clutter.
    sliders = [dict(
        active=0, yanchor="top", xanchor="left",
        currentvalue=dict(font=dict(size=12), prefix="Frame: ", visible=True, xanchor="right"),
        pad=dict(b=5, t=30), len=0.75, x=0.2, y=0.02,
        steps=[
            dict(args=[[str(t)], dict(frame=dict(duration=frame_duration, redraw=True), mode="immediate", transition=dict(duration=0))],
                 label=str(t) if t % 10 == 0 else "", method="animate")
            for t in range(T)
        ]
    )]

    # Fixed upper-body framing when bounds are available; otherwise let
    # Plotly auto-fit the whole body.
    if bounds and upper_body_only:
        scene_config = dict(
            aspectmode='manual', aspectratio=dict(x=1, y=1.2, z=1),
            xaxis=dict(visible=False, showbackground=False, range=bounds['x_range']),
            yaxis=dict(visible=False, showbackground=False, range=bounds['y_range']),
            zaxis=dict(visible=False, showbackground=False, range=bounds['z_range']),
            camera=dict(
                eye=dict(x=0, y=bounds['center'][1] * 0.1, z=2.5),
                center=dict(x=0, y=bounds['center'][1], z=0),
                up=dict(x=0, y=1, z=0)
            ),
            bgcolor='rgba(250,250,250,1)'
        )
    else:
        scene_config = dict(
            aspectmode='data',
            xaxis=dict(visible=False, showbackground=False),
            yaxis=dict(visible=False, showbackground=False),
            zaxis=dict(visible=False, showbackground=False),
            camera=dict(eye=dict(x=0, y=0, z=2.5), up=dict(x=0, y=1, z=0)),
            bgcolor='rgba(250,250,250,1)'
        )

    # Optional title rendered as a paper-anchored annotation.
    annotations = []
    if title:
        annotations.append(dict(
            text=f"<b>{title}</b>",
            x=0.5, y=1.0, xref="paper", yref="paper",
            showarrow=False, font=dict(size=14),
            xanchor="center", yanchor="bottom"
        ))

    fig.update_layout(
        scene=scene_config,
        annotations=annotations,
        updatemenus=[dict(
            type="buttons", showactive=True,
            x=0.02, y=0.02, xanchor="left", yanchor="bottom",
            pad=dict(t=0, r=10), direction="right",
            buttons=[
                dict(label="Play", method="animate",
                     args=[None, {"frame": {"duration": frame_duration, "redraw": True}, "fromcurrent": True, "transition": {"duration": 0}}]),
                dict(label="Pause", method="animate",
                     args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]),
                dict(label="Reset", method="animate",
                     args=[["0"], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}])
            ]
        )],
        sliders=sliders,
        height=500,
        margin=dict(l=0, r=0, t=30 if title else 10, b=60),
        paper_bgcolor='rgba(250,250,250,1)',
        plot_bgcolor='rgba(250,250,250,1)'
    )

    return fig.to_html(
        include_plotlyjs='cdn', full_html=True,
        config={'displayModeBar': True, 'displaylogo': False, 'scrollZoom': True,
                'modeBarButtonsToRemove': ['lasso2d', 'select2d', 'toImage']}
    )
 
590
 
591
+
592
def create_side_by_side_html(verts1, faces1, verts2, faces2, title1="", title2="", fps=20) -> str:
    """Create side-by-side animation HTML for two avatars.

    Renders two meshes in separate 3D subplots (blue left, green right) that
    share one Play/Pause/Reset control and one frame slider, truncating both
    sequences to the shorter length so they stay in sync. Returns a full,
    self-contained HTML document.
    """
    if verts1 is None or verts2 is None:
        return create_placeholder_html()

    # Animate both avatars for the same number of frames.
    T = min(verts1.shape[0], verts2.shape[0])
    verts1, verts2 = verts1[:T], verts2[:T]

    i1, j1, k1 = faces1.T.tolist()
    i2, j2, k2 = faces2.T.tolist()
    frame_duration = 1000 // fps  # ms per frame

    bounds1 = compute_upper_body_bounds(verts1)
    bounds2 = compute_upper_body_bounds(verts2)

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        horizontal_spacing=0.02,
        subplot_titles=[title1, title2]
    )

    # Initial (frame 0) meshes, each bound to its own subplot scene.
    mesh1 = go.Mesh3d(
        x=verts1[0, :, 0], y=verts1[0, :, 1], z=verts1[0, :, 2],
        i=i1, j=j1, k=k1, flatshading=True, opacity=0.6, color='#6FA8DC',
        lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2), scene='scene'
    )
    mesh2 = go.Mesh3d(
        x=verts2[0, :, 0], y=verts2[0, :, 1], z=verts2[0, :, 2],
        i=i2, j=j2, k=k2, flatshading=True, opacity=0.6, color='#93C47D',
        lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2), scene='scene2'
    )

    fig.add_trace(mesh1, row=1, col=1)
    fig.add_trace(mesh2, row=1, col=2)

    # One go.Frame per timestep carrying both meshes so one animation
    # control drives both subplots.
    frames = []
    for t in range(T):
        frames.append(go.Frame(
            name=str(t),
            data=[
                go.Mesh3d(x=verts1[t, :, 0], y=verts1[t, :, 1], z=verts1[t, :, 2],
                          i=i1, j=j1, k=k1, flatshading=True, opacity=0.6, color='#6FA8DC',
                          lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2), scene='scene'),
                go.Mesh3d(x=verts2[t, :, 0], y=verts2[t, :, 1], z=verts2[t, :, 2],
                          i=i2, j=j2, k=k2, flatshading=True, opacity=0.6, color='#93C47D',
                          lighting=dict(ambient=0.6, diffuse=0.7, specular=0.2), scene='scene2')
            ]
        ))
    fig.frames = frames

    # Shared frame slider; labels only every 10th frame.
    sliders = [dict(
        active=0, yanchor="top", xanchor="left",
        currentvalue=dict(font=dict(size=12), prefix="Frame: ", visible=True, xanchor="right"),
        pad=dict(b=5, t=30), len=0.75, x=0.15, y=0.02,
        steps=[
            dict(args=[[str(t)], dict(frame=dict(duration=frame_duration, redraw=True), mode="immediate", transition=dict(duration=0))],
                 label=str(t) if t % 10 == 0 else "", method="animate")
            for t in range(T)
        ]
    )]

    def make_scene_config(bounds):
        # Fixed upper-body framing when bounds are available, else auto-fit.
        if bounds:
            return dict(
                aspectmode='manual', aspectratio=dict(x=1, y=1.2, z=1),
                xaxis=dict(visible=False, showbackground=False, range=bounds['x_range']),
                yaxis=dict(visible=False, showbackground=False, range=bounds['y_range']),
                zaxis=dict(visible=False, showbackground=False, range=bounds['z_range']),
                camera=dict(eye=dict(x=0, y=bounds['center'][1]*0.1, z=2.5),
                            center=dict(x=0, y=bounds['center'][1], z=0), up=dict(x=0, y=1, z=0)),
                bgcolor='rgba(250,250,250,1)'
            )
        return dict(aspectmode='data', xaxis=dict(visible=False), yaxis=dict(visible=False),
                    zaxis=dict(visible=False), camera=dict(eye=dict(x=0, y=0, z=2.5), up=dict(x=0, y=1, z=0)),
                    bgcolor='rgba(250,250,250,1)')

    fig.update_layout(
        scene=make_scene_config(bounds1),
        scene2=make_scene_config(bounds2),
        updatemenus=[dict(
            type="buttons", showactive=True,
            x=0.02, y=0.02, xanchor="left", yanchor="bottom",
            pad=dict(t=0, r=10), direction="right",
            buttons=[
                dict(label="Play", method="animate",
                     args=[None, {"frame": {"duration": frame_duration, "redraw": True}, "fromcurrent": True, "transition": {"duration": 0}}]),
                dict(label="Pause", method="animate",
                     args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}]),
                dict(label="Reset", method="animate",
                     args=[["0"], {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"}])
            ]
        )],
        sliders=sliders,
        height=500,
        margin=dict(l=0, r=0, t=40, b=60),
        paper_bgcolor='rgba(250,250,250,1)'
    )

    return fig.to_html(
        include_plotlyjs='cdn', full_html=True,
        config={'displayModeBar': True, 'displaylogo': False, 'scrollZoom': True,
                'modeBarButtonsToRemove': ['lasso2d', 'select2d', 'toImage']}
    )
 
696
 
 
 
 
697
 
698
def create_placeholder_html() -> str:
    """Return the static HTML panel shown before any motion is generated."""
    placeholder = """
    <div style="display: flex; justify-content: center; align-items: center;
        height: 500px; background: #fafafa; border-radius: 4px; border: 1px solid #e0e0e0;">
        <p style="font-size: 14px; color: #888;">Enter a word to generate motion</p>
    </div>
    """
    return placeholder
705
+
706
+
707
def create_error_html(msg: str = "Error generating animation") -> str:
    """Return a static HTML panel displaying `msg` as an error message."""
    error_panel = f"""
    <div style="display: flex; justify-content: center; align-items: center;
        height: 500px; background: #fafafa; border-radius: 4px; border: 1px solid #e0e0e0;">
        <p style="font-size: 14px; color: #c00;">{msg}</p>
    </div>
    """
    return error_panel
714
+
715
 
716
def create_iframe_html(html_content: str, height: int = 530) -> str:
    """Wrap a full HTML document in a sandboxed iframe via `srcdoc`.

    The document is HTML-escaped so it can be embedded as an attribute
    value; the iframe allows scripts (needed for Plotly) and same-origin.
    """
    srcdoc = html_module.escape(html_content)
    return f'''
    <div style="width: 100%; height: {height}px; border: 1px solid #ddd; border-radius: 4px; overflow: hidden; background: #fafafa;">
        <iframe srcdoc="{srcdoc}" style="width: 100%; height: 100%; border: none;" sandbox="allow-scripts allow-same-origin"></iframe>
    </div>
    '''
723
 
724
  # =====================================================================
725
+ # Main Processing Functions
726
  # =====================================================================
727
def generate_verts_for_word(word: str, pid: str) -> tuple:
    """Generate SMPL-X vertices and faces for a word-PID pair.

    Returns (verts, faces, raw_token_string); verts/faces are None when the
    LLM output yields no token ids, the visualization models are missing,
    or decoding produces an empty parameter sequence.
    """
    raw_tokens = generate_motion_tokens(word, pid)
    token_ids = parse_motion_tokens(raw_tokens)

    if not token_ids:
        return None, None, raw_tokens

    decoders_ready = (
        _model_cache["vqvae_model"] is not None
        and _model_cache["smplx_model"] is not None
    )
    if not decoders_ready:
        return None, None, raw_tokens

    params = decode_tokens_to_params(token_ids)
    if params.shape[0] == 0:
        return None, None, raw_tokens

    verts, faces = params_to_vertices(params)
    return verts, faces, raw_tokens
744
+
745
 
746
def generate_animation_for_word(word: str, pid: str, upper_body_only: bool = True) -> tuple:
    """Generate animation HTML and tokens for a word. Returns (html, tokens)."""
    verts, faces, tokens = generate_verts_for_word(word, pid)
    if verts is None:
        # No decodable motion: show the neutral placeholder but keep tokens.
        return create_placeholder_html(), tokens
    html = create_animation_html(verts, faces, upper_body_only=upper_body_only, title=f"{pid}")
    return html, tokens
755
+
756
 
757
def process_word(word: str):
    """Main processing: generate side-by-side comparison for two random PIDs.

    Returns (iframe_html, token_text). Falls back to a single-avatar view
    when only one variant produces motion, and to an error panel when
    neither does or the word is unknown.
    """
    if not word or not word.strip():
        return create_iframe_html(create_placeholder_html()), ""

    word = word.strip().lower()
    pids = get_random_pids_for_word(word, 2)

    if not pids:
        return create_iframe_html(create_error_html(f"Word '{word}' not found in dataset")), ""
    if len(pids) == 1:
        # Only one variant available: compare it against itself.
        pids = [pids[0], pids[0]]

    try:
        verts1, faces1, tokens1 = generate_verts_for_word(word, pids[0])
        verts2, faces2, tokens2 = generate_verts_for_word(word, pids[1])

        if verts1 is None and verts2 is None:
            return create_iframe_html(create_error_html("Failed to generate motion")), tokens1 or tokens2

        # One-sided failure: show whichever variant succeeded on its own.
        if verts1 is None:
            solo = create_animation_html(verts2, faces2, upper_body_only=True, title=f"{pids[1]}")
            return create_iframe_html(solo), tokens2
        if verts2 is None:
            solo = create_animation_html(verts1, faces1, upper_body_only=True, title=f"{pids[0]}")
            return create_iframe_html(solo), tokens1

        comparison = create_side_by_side_html(verts1, faces1, verts2, faces2,
                                              title1=f"{pids[0]}", title2=f"{pids[1]}")
        combined_tokens = f"[{pids[0]}] {tokens1}\n\n[{pids[1]}] {tokens2}"
        return create_iframe_html(comparison), combined_tokens

    except Exception as e:
        return create_iframe_html(create_error_html(f"Error: {str(e)[:100]}")), ""
793
 
794
+
795
def get_example_animation(word: str, pid: str):
    """Return a cached example animation, generating it on a cache miss."""
    cache_key = f"{word}_{pid}"
    entry = _example_cache.get(cache_key)
    if entry is not None:
        return create_iframe_html(entry["html"]), entry["tokens"]
    # Not pre-computed (e.g. startup failure): generate on demand.
    html, tokens = generate_animation_for_word(word, pid, upper_body_only=True)
    return create_iframe_html(html), tokens
803
 
804
  # =====================================================================
805
+ # Gradio Interface
806
  # =====================================================================
807
def create_gradio_interface():
    """Build and return the Gradio Blocks UI.

    Layout: an input column (word textbox, generate button, generated-token
    readout, sample word list) next to the animation panel, followed by one
    row per pre-computed example with a Load button.
    """
    default_html = create_iframe_html(create_placeholder_html())

    custom_css = """
    .gradio-container { max-width: 1400px !important; }
    .example-row { margin-top: 15px; padding: 12px; background: #f8f9fa; border-radius: 6px; }
    """

    # Snapshot of whatever precompute_examples() managed to cache.
    example_list = list(_example_cache.values()) if _example_cache else []

    with gr.Blocks(title="SignMotionGPT", css=custom_css, theme=gr.themes.Default()) as demo:

        gr.Markdown("# SignMotionGPT Demo")
        gr.Markdown("Text-to-Sign Language Motion Generation with Variant Comparison")

        with gr.Row():
            with gr.Column(scale=1, min_width=280):
                gr.Markdown("### Input")

                word_input = gr.Textbox(
                    label="Word",
                    placeholder="Enter a word from the dataset...",
                    lines=1, max_lines=1
                )

                generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")

                gr.Markdown("---")
                gr.Markdown("### Generated Tokens")

                tokens_output = gr.Textbox(
                    label="Motion Tokens (both variants)",
                    lines=8,
                    interactive=False,
                    show_copy_button=True
                )

                # Hint at valid inputs using the first few dataset words.
                if _word_pid_map:
                    sample_words = list(_word_pid_map.keys())[:10]
                    gr.Markdown(f"**Available words:** {', '.join(sample_words)}, ...")

            with gr.Column(scale=2, min_width=700):
                gr.Markdown("### Motion Comparison (Two Signer Variants)")
                animation_output = gr.HTML(value=default_html, elem_id="animation-container")

        if example_list:
            gr.Markdown("---")
            gr.Markdown("### Pre-computed Examples")

            for item in example_list:
                word, pid = item['word'], item['pid']
                with gr.Row(elem_classes="example-row"):
                    with gr.Column(scale=1, min_width=120):
                        gr.Markdown(f"**{word.capitalize()}**")
                        gr.Markdown(f"Variant: {pid}")
                        example_btn = gr.Button(f"Load", size="sm")

                    with gr.Column(scale=3, min_width=500):
                        example_html = gr.HTML(
                            value=create_iframe_html(create_placeholder_html(), height=450),
                            elem_id=f"example-{word}"
                        )

                # Default-argument lambda binds this row's word/pid eagerly,
                # avoiding the late-binding closure pitfall inside the loop.
                example_btn.click(
                    fn=lambda w=word, p=pid: get_example_animation(w, p),
                    inputs=[],
                    outputs=[example_html, tokens_output]
                )

        gr.Markdown("---")
        gr.Markdown("*SignMotionGPT: LLM-based sign language motion generation*")

        # Button click and textbox Enter both trigger generation.
        generate_btn.click(
            fn=process_word,
            inputs=[word_input],
            outputs=[animation_output, tokens_output]
        )

        word_input.submit(
            fn=process_word,
            inputs=[word_input],
            outputs=[animation_output, tokens_output]
        )

    return demo
893
 
894
+ # =====================================================================
895
+ # Main Entry Point for HuggingFace Spaces
896
+ # =====================================================================
897
+ print("\n" + "="*60)
898
+ print(" SignMotionGPT - HuggingFace Spaces")
899
+ print("="*60)
900
+ print(f"Device: {DEVICE}")
901
+ print(f"Model: {HF_REPO_ID}/{HF_SUBFOLDER}")
902
+ print(f"Data Directory: {DATA_DIR}")
903
+ print(f"Dataset: {DATASET_PATH}")
904
+ print("="*60 + "\n")
905
+
906
+ # Initialize models at startup
907
+ initialize_models()
908
+
909
+ # Pre-compute example animations
910
+ precompute_examples()
911
 
912
+ # Create and launch interface
913
+ demo = create_gradio_interface()
914
+
915
+ if __name__ == "__main__":
916
+ demo.launch()