rdz-falcon committed on
Commit
1ea37cf
·
verified ·
1 Parent(s): 2d54a11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -301
app.py CHANGED
@@ -1,349 +1,375 @@
1
- import gradio as gr
2
- import torch
 
3
  import os
4
  import sys
5
- import warnings
6
  import re
7
  import json
8
  import random
 
 
 
9
  from pathlib import Path
10
 
11
- # Add root to path to allow imports from project root when running from demo-code/
12
- # or when running from root
13
- current_dir = os.path.dirname(os.path.abspath(__file__))
14
- parent_dir = os.path.dirname(current_dir)
15
- sys.path.append(current_dir)
16
- sys.path.append(parent_dir)
17
 
18
- # Import project modules
19
- try:
20
- from visualize import visualize
21
- # Try importing what we can, but we will implement generation logic directly here
22
- # to match test_overfit.py / metrics.py exactly and avoid dependency issues.
23
- # We catch Exception because unsloth in model.py might raise NotImplementedError on CPU
24
- from model import get_motion_token_info
25
- except Exception as e:
26
- print(f"Error importing project modules: {e}")
27
- print("Make sure you are running this from the project root or have the project structure intact.")
28
- # Fallback for explicit relative imports if needed in some environments
29
- try:
30
- from visualize import visualize
31
- except Exception as vis_e:
32
- print(f"Visualize import failed too: {vis_e}")
33
 
34
- # Constants
35
- HF_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
36
- EPOCH_SUBFOLDER = "stage2_v2/epoch-030"
37
- CODEBOOK_SIZE = 512
38
- DATASET_PATH = os.environ.get("DATASET_PATH", "enriched_dataset.json")
39
 
40
- # Hardcoded Config from test_overfit.py / config.py
 
 
 
 
 
 
 
 
 
 
 
 
41
  INFERENCE_TEMPERATURE = 0.7
42
  INFERENCE_TOP_K = 50
43
  INFERENCE_REPETITION_PENALTY = 1.2
44
- M_START = "<M_START>"
45
- M_END = "<M_END>"
46
 
47
- # Global model cache
48
- MODEL = None
49
- TOKENIZER = None
50
- # We use M_START/M_END as in test_overfit.py
51
- M_START_ID = None
52
- M_END_ID = None
53
- VARIANT_MAP = {}
54
 
55
- def load_variant_map():
56
- """Load dataset to map words to valid participant IDs."""
57
- global VARIANT_MAP
58
-
59
- # Try multiple possible paths for the dataset
60
- candidates = [
61
- DATASET_PATH,
62
- os.path.join(os.path.dirname(__file__), DATASET_PATH),
63
- os.path.join(os.path.dirname(__file__), "..", DATASET_PATH),
64
- "data/motion_llm_dataset.json", # Fallback to raw dataset if enriched missing
65
- "motion_llm_dataset.json"
66
- ]
67
-
68
- found_path = None
69
- for p in candidates:
70
- if os.path.exists(p):
71
- found_path = p
72
- break
73
-
74
- if found_path:
75
- print(f"Loading variants from {found_path}...")
76
- try:
77
- with open(found_path, 'r', encoding='utf-8') as f:
78
- data = json.load(f)
79
-
80
- mapping = {}
81
- count = 0
82
- for entry in data:
83
- # Support both formats (enriched or raw)
84
- word = entry.get("word") or entry.get("text_query")
85
- if not word: continue
86
-
87
- # Clean word (sometimes text_query is "Motion for word 'hello'")
88
- if "motion for word" in word.lower():
89
- # extraction heuristic if needed, but 'word' field is preferred
90
- pass
91
-
92
- word = word.lower().strip()
93
- pid = entry.get("participant_id")
94
-
95
- if word and pid:
96
- if word not in mapping:
97
- mapping[word] = []
98
- if pid not in mapping[word]:
99
- mapping[word].append(str(pid))
100
- count += 1
101
-
102
- VARIANT_MAP = mapping
103
- print(f"Loaded {count} variants for {len(VARIANT_MAP)} words.")
104
-
105
- # Debug check for 'push'
106
- if 'push' in VARIANT_MAP:
107
- print(f" 'push' variants: {VARIANT_MAP['push']}")
108
- else:
109
- print(" 'push' NOT found in dataset.")
110
-
111
- except Exception as e:
112
- print(f"Error loading dataset: {e}")
113
- else:
114
- print(f"⚠️ Dataset not found. Tried: {candidates}. Variants will default to 'unknown'.")
115
-
116
- # Hardcoded fallback for demonstration words if missing from dataset
117
- defaults = {
118
- "push": ["P40", "P123", "P1"],
119
- "send": ["P40", "P123"],
120
- "library": ["P40"],
121
- "passport": ["P40"]
122
- }
123
- for w, pids in defaults.items():
124
- if w not in VARIANT_MAP:
125
- VARIANT_MAP[w] = pids
126
- print(f" Added fallback variants for '{w}': {pids}")
127
 
128
- def load_model_from_hf(repo_id, subfolder, token=None):
129
- from transformers import AutoModelForCausalLM, AutoTokenizer
130
- print(f"Loading model from HF: {repo_id}/{subfolder}")
 
 
 
 
131
  try:
132
- tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
133
- model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
134
- return model, tokenizer
135
- except Exception as e:
136
- print(f"Error loading model: {e}")
137
- return None, None
 
 
 
 
 
 
 
 
 
 
 
138
 
139
- def init_model():
140
- global MODEL, TOKENIZER, M_START_ID, M_END_ID
141
- if MODEL is not None:
 
 
 
 
 
 
 
 
142
  return
143
 
144
- load_variant_map()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
- token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
147
-
148
- # Load model/tokenizer
149
- MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, token)
150
-
151
- if MODEL is None:
152
- raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")
153
 
154
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
- MODEL.to(device)
156
- MODEL.eval()
 
 
 
 
 
157
 
158
- # Setup special tokens matching test_overfit.py
159
- # test_overfit.py uses M_START="<M_START>" and M_END="<M_END>"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
- # Check if tokens exist
162
- if M_START not in TOKENIZER.get_vocab() or M_END not in TOKENIZER.get_vocab():
163
- print(f"⚠️ Warning: {M_START} or {M_END} not found in tokenizer. Adding them now...")
164
- num_added = TOKENIZER.add_special_tokens({"additional_special_tokens": [M_START, M_END]})
165
- if num_added > 0:
166
- MODEL.resize_token_embeddings(len(TOKENIZER))
167
- print(f" Added {num_added} special tokens.")
168
 
169
- M_START_ID = TOKENIZER.convert_tokens_to_ids(M_START)
170
- M_END_ID = TOKENIZER.convert_tokens_to_ids(M_END)
 
 
171
 
172
- # Check motion tokens
173
- # We expect <motion_0> ... <motion_511>
174
- # If missing, add them
175
- first_motion = "<motion_0>"
176
- if first_motion not in TOKENIZER.get_vocab():
177
- print("⚠️ Warning: Motion tokens not found. Adding them now...")
178
- motion_tokens = [f"<motion_{i}>" for i in range(CODEBOOK_SIZE)]
179
- num_added = TOKENIZER.add_tokens(motion_tokens, special_tokens=True)
180
- if num_added > 0:
181
- MODEL.resize_token_embeddings(len(TOKENIZER))
182
- print(f" Added {num_added} motion tokens.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- print(f"Model initialized. Vocab size: {len(TOKENIZER)}")
185
- print(f"M_START_ID: {M_START_ID}, M_END_ID: {M_END_ID}")
 
 
 
 
 
 
 
 
 
 
186
 
187
- def generate_motion_simple(model, tokenizer, prompt_text, device):
188
- """
189
- Replicates the simple generation logic from metrics.py / test_overfit.py
190
- """
191
- # Construct prompt exactly as in test_overfit.py:
192
- # prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
193
 
194
- # Get a valid participant ID if possible
195
- word_lower = prompt_text.lower().strip()
196
- variants = VARIANT_MAP.get(word_lower, [])
 
197
 
198
- if variants:
199
- pid = random.choice(variants)
200
- print(f"Selected variant '{pid}' for word '{prompt_text}'")
201
- else:
202
- # Fallback to 'unknown' or a common PID if known (e.g., P1)
203
- pid = "unknown"
204
- print(f"No variants found for '{prompt_text}', using '{pid}'")
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
- prompt = f"Instruction: Generate motion for word '{prompt_text}' with variant '{pid}'.\nMotion: "
 
 
 
 
 
 
207
 
208
- print(f"Input Prompt:\n{prompt}")
 
209
 
210
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- with torch.no_grad():
213
- output = model.generate(
214
- **inputs,
215
- max_new_tokens=100,
216
- do_sample=True,
217
- temperature=INFERENCE_TEMPERATURE,
218
- top_k=INFERENCE_TOP_K,
219
- repetition_penalty=INFERENCE_REPETITION_PENALTY,
220
- pad_token_id=tokenizer.pad_token_id,
221
- eos_token_id=M_END_ID, # Stop at <M_END>
222
- early_stopping=True
223
- )
224
 
225
- decoded = tokenizer.decode(output[0], skip_special_tokens=False)
 
 
 
 
 
 
226
 
227
- # Parse output to extract just the motion part
228
- # We expect: ... \nMotion: <M_START> <motion_...> ... <M_END>
229
- if "Motion: " in decoded:
230
- motion_part = decoded.split("Motion: ")[-1]
231
- else:
232
- motion_part = decoded
233
-
234
- return motion_part.strip()
235
 
236
- def generate_motion_app(text_prompt):
237
- if not text_prompt:
238
- return None, "Please enter a prompt."
 
 
 
239
 
240
- if MODEL is None:
241
- try:
242
- init_model()
243
- except Exception as e:
244
- return None, f"Model Initialization Failed: {e}"
245
-
246
- device = MODEL.device
247
- print(f"Generating for: {text_prompt}")
248
 
249
- try:
250
- generated_sequence = generate_motion_simple(MODEL, TOKENIZER, text_prompt, device)
251
- print("Generated sequence (raw):", generated_sequence)
252
-
253
- # Extract tokens for visualization
254
- # Logic from metrics.py: _extract_motion_tokens_from_sequence
255
- # Expect tokens like <M123> or <motion_123>
256
- # The generation might include M_START/M_END.
257
-
258
- # Clean up for visualization input
259
- # We need a string of tokens.
260
- # If the output is like "<M_START> <motion_1> <motion_2> <M_END>", we pass that.
261
- # visualize.py's parse_motion_tokens handles <motion_ID> regex.
262
- # BUT visualize.py expects either "123 456" OR "<motion_123> <motion_456>"
263
- # It does NOT explicitly handle <M123> which is what we might have here if M_START was used.
264
- # Let's convert <M123> to space-separated integers for safety.
265
-
266
- # Extract integers from <M123> or <motion_123>
267
- # generated_sequence is raw string from tokenizer decode
268
-
269
- import re
270
- # Try <M123> format (test_overfit style)
271
- m_tokens = re.findall(r'<M(\d+)>', generated_sequence)
272
- if not m_tokens:
273
- # Try <motion_123> format
274
- m_tokens = re.findall(r'<motion_(\d+)>', generated_sequence)
275
-
276
- if m_tokens:
277
- # Reconstruct as space-separated string for visualize.py
278
- tokens_for_vis = " ".join(m_tokens)
279
- else:
280
- # Fallback to raw string if regex failed (visualize.py might handle other formats)
281
- tokens_for_vis = generated_sequence
282
-
283
- print(f"Tokens for visualization: {tokens_for_vis[:50]}...")
284
-
285
- except Exception as e:
286
- return None, f"Generation Error: {e}"
287
 
288
- # Visualization
289
- try:
290
- # Ensure paths for VQ-VAE and SMPL-X
291
- data_dir = os.environ.get("DATA_DIR", "data")
292
- vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
293
- stats_path = os.path.join(data_dir, "vqvae_stats.pt")
294
- smplx_dir = os.path.join(data_dir, "smplx_models")
295
-
296
- # Check existence
297
- missing = []
298
- if not os.path.exists(vqvae_ckpt): missing.append(vqvae_ckpt)
299
- if not os.path.exists(stats_path): missing.append(stats_path)
300
- if not os.path.exists(smplx_dir): missing.append(smplx_dir)
301
 
302
- if missing:
303
- return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."
 
 
 
 
 
 
 
 
304
 
305
- # Output to a temporary file
306
- output_html = "temp_viz.html"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
- fig = visualize(
309
- tokens=tokens_for_vis,
310
- vqvae_ckpt=vqvae_ckpt,
311
- stats_path=stats_path,
312
- smplx_dir=smplx_dir,
313
- output_html=output_html,
314
- title=f"Motion: {text_prompt}",
315
- fps=20
316
- )
 
 
317
 
318
- if fig is None:
319
- return None, "Visualization failed (no frames produced)."
320
-
321
- # Count tokens for display
322
- matches = re.findall(r'<motion_(\d+)>', tokens_for_vis)
323
- # Also check for <M...> format just in case
324
- if not matches:
325
- matches = re.findall(r'<M(\d+)>', tokens_for_vis)
326
-
327
- num_tokens = len(matches)
328
 
329
- return fig, f"Success! Generated tokens length: {num_tokens}. Sequence: {tokens_for_vis[:100]}..."
 
330
 
331
- except Exception as e:
332
- return None, f"Visualization Error: {e}"
333
-
334
-
335
- # Gradio UI
336
- with gr.Interface(
337
- fn=generate_motion_app,
338
- inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
339
- outputs=[
340
- gr.Plot(label="Motion Visualization"),
341
- gr.Textbox(label="Status/Output")
342
- ],
343
- title="SignMotionGPT Demo",
344
- description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30."
345
- ) as demo:
346
- pass
347
 
348
  if __name__ == "__main__":
349
- demo.launch()
 
 
 
 
1
+ """
2
+ Gradio Interface for SignMotionGPT (HF Spaces Compatible)
3
+ """
4
  import os
5
  import sys
 
6
  import re
7
  import json
8
  import random
9
+ import argparse
10
+ import time
11
+ import warnings
12
  from pathlib import Path
13
 
14
+ import torch
15
+ import numpy as np
16
+ import gradio as gr
17
+ import plotly.graph_objects as go
18
+ from plotly.subplots import make_subplots
19
+ import smplx
20
 
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+
23
+ warnings.filterwarnings("ignore")
24
+
25
+ # =====================================================================
26
+ # Configuration & Paths
27
+ # =====================================================================
28
+ # Setup directories for HF Spaces
29
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
30
+ OUTPUT_DIR = os.path.join(BASE_DIR, "generated_outputs")
31
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
32
+
33
+ # Add project root to path
34
+ sys.path.append(BASE_DIR)
 
35
 
36
+ HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
37
+ HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
 
 
 
38
 
39
+ DATA_DIR = os.environ.get("DATA_DIR", os.path.join(BASE_DIR, "data"))
40
+ DATASET_PATH = os.environ.get("DATASET_PATH", os.path.join(BASE_DIR, "enriched_dataset.json"))
41
+
42
+ VQVAE_CHECKPOINT = os.environ.get("VQVAE_CHECKPOINT", os.path.join(DATA_DIR, "vqvae_model.pt"))
43
+ STATS_PATH = os.environ.get("VQVAE_STATS_PATH", os.path.join(DATA_DIR, "vqvae_stats.pt"))
44
+ SMPLX_MODEL_DIR = os.environ.get("SMPLX_MODEL_DIR", os.path.join(DATA_DIR, "smplx_models"))
45
+
46
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+
48
+ # Model Config
49
+ M_START = "<M_START>"
50
+ M_END = "<M_END>"
51
+ PAD_TOKEN = "<PAD>"
52
  INFERENCE_TEMPERATURE = 0.7
53
  INFERENCE_TOP_K = 50
54
  INFERENCE_REPETITION_PENALTY = 1.2
 
 
55
 
56
+ SMPL_DIM = 182
57
+ CODEBOOK_SIZE = 512
58
+ CODE_DIM = 512
59
+ VQ_ARGS = dict(
60
+ width=512, depth=3, down_t=2, stride_t=2,
61
+ dilation_growth_rate=3, activation='relu', norm=None, quantizer="ema_reset"
62
+ )
63
 
64
+ PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
65
+ PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
66
+ "trans", "expression", "jaw_pose", "eye_pose"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ # =====================================================================
69
+ # Import VQ-VAE architecture
70
+ # =====================================================================
71
+ try:
72
+ # Try importing from local project structure
73
+ from mGPT.archs.mgpt_vq import VQVae
74
+ except ImportError:
75
  try:
76
+ # Fallback for flat structure
77
+ from archs.mgpt_vq import VQVae
78
+ except ImportError:
79
+ print("⚠️ Warning: Could not import VQVae architecture.")
80
+ VQVae = None
81
+
82
+ # =====================================================================
83
+ # Global Cache
84
+ # =====================================================================
85
+ _model_cache = {
86
+ "llm_model": None,
87
+ "llm_tokenizer": None,
88
+ "vqvae_model": None,
89
+ "smplx_model": None,
90
+ "stats": (None, None),
91
+ "initialized": False
92
+ }
93
 
94
+ _word_pid_map = {}
95
+ _example_cache = {}
96
+
97
+ # =====================================================================
98
+ # Dataset Loading
99
+ # =====================================================================
100
+ def load_word_pid_mapping():
101
+ global _word_pid_map
102
+ if not os.path.exists(DATASET_PATH):
103
+ # Fallback defaults if dataset missing
104
+ _word_pid_map = {"push": ["P40"], "send": ["P40"]}
105
  return
106
 
107
+ try:
108
+ with open(DATASET_PATH, 'r', encoding='utf-8') as f:
109
+ data = json.load(f)
110
+
111
+ mapping = {}
112
+ for entry in data:
113
+ word = (entry.get('word') or entry.get('text_query', '')).lower().strip()
114
+ pid = entry.get('participant_id')
115
+ if word and pid:
116
+ mapping.setdefault(word, set()).add(str(pid))
117
+
118
+ _word_pid_map = {k: sorted(list(v)) for k, v in mapping.items()}
119
+ print(f"Loaded {len(_word_pid_map)} words from dataset")
120
+ except Exception as e:
121
+ print(f"Error loading dataset: {e}")
122
 
123
+ def get_random_pids_for_word(word: str, count: int = 2) -> list:
124
+ pids = _word_pid_map.get(word.lower().strip(), [])
125
+ if not pids: return []
126
+ if len(pids) <= count: return pids
127
+ return random.sample(pids, count)
 
 
128
 
129
+ # =====================================================================
130
+ # Models
131
+ # =====================================================================
132
+ class MotionGPT_VQVAE_Wrapper(torch.nn.Module):
133
+ def __init__(self, smpl_dim=SMPL_DIM, codebook_size=CODEBOOK_SIZE, code_dim=CODE_DIM, **kwargs):
134
+ super().__init__()
135
+ if VQVae is None: raise RuntimeError("VQVae architecture missing")
136
+ self.vqvae = VQVae(nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim, output_emb_width=code_dim, **kwargs)
137
 
138
+ def initialize_models():
139
+ global _model_cache
140
+ if _model_cache["initialized"]: return
141
+
142
+ print("Initializing Models...")
143
+ load_word_pid_mapping()
144
+
145
+ # LLM
146
+ print(f"Loading LLM: {HF_REPO_ID}")
147
+ tok = AutoTokenizer.from_pretrained(HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True)
148
+ model = AutoModelForCausalLM.from_pretrained(HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True)
149
+ if tok.pad_token is None: tok.add_special_tokens({"pad_token": PAD_TOKEN})
150
+ model.resize_token_embeddings(len(tok))
151
+ model.to(DEVICE).eval()
152
+ _model_cache["llm_model"] = model
153
+ _model_cache["llm_tokenizer"] = tok
154
 
155
+ # VQ-VAE
156
+ if os.path.exists(VQVAE_CHECKPOINT):
157
+ vq = MotionGPT_VQVAE_Wrapper(**VQ_ARGS).to(DEVICE)
158
+ ckpt = torch.load(VQVAE_CHECKPOINT, map_location=DEVICE)
159
+ vq.load_state_dict(ckpt.get('model_state_dict', ckpt), strict=False)
160
+ vq.eval()
161
+ _model_cache["vqvae_model"] = vq
162
 
163
+ # Stats
164
+ if os.path.exists(STATS_PATH):
165
+ st = torch.load(STATS_PATH, map_location='cpu')
166
+ _model_cache["stats"] = (st.get('mean', 0), st.get('std', 1))
167
 
168
+ # SMPL-X
169
+ if os.path.exists(SMPLX_MODEL_DIR):
170
+ _model_cache["smplx_model"] = smplx.SMPLX(
171
+ model_path=SMPLX_MODEL_DIR, model_type='smplx', gender='neutral', use_pca=False,
172
+ create_global_orient=True, create_body_pose=True, create_betas=True,
173
+ create_expression=True, create_jaw_pose=True, create_left_hand_pose=True,
174
+ create_right_hand_pose=True, create_transl=True
175
+ ).to(DEVICE)
176
+
177
+ _model_cache["initialized"] = True
178
+ print("Models Initialized.")
179
+
180
+ # =====================================================================
181
+ # Generation Logic
182
+ # =====================================================================
183
+ def generate_motion_tokens(word: str, variant: str) -> str:
184
+ model, tok = _model_cache["llm_model"], _model_cache["llm_tokenizer"]
185
+ prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
186
+ inputs = tok(prompt, return_tensors="pt").to(DEVICE)
187
+ with torch.no_grad():
188
+ out = model.generate(
189
+ **inputs, max_new_tokens=100, do_sample=True,
190
+ temperature=INFERENCE_TEMPERATURE, top_k=INFERENCE_TOP_K,
191
+ repetition_penalty=INFERENCE_REPETITION_PENALTY,
192
+ eos_token_id=tok.convert_tokens_to_ids(M_END)
193
+ )
194
+ decoded = tok.decode(out[0], skip_special_tokens=False)
195
+ return decoded.split("Motion: ")[-1].strip() if "Motion: " in decoded else decoded.strip()
196
 
197
+ def decode_tokens_to_params(tokens: list) -> np.ndarray:
198
+ vq, (mean, std) = _model_cache["vqvae_model"], _model_cache["stats"]
199
+ if not vq or not tokens: return np.zeros((0, SMPL_DIM))
200
+
201
+ idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
202
+ with torch.no_grad():
203
+ emb = vq.vqvae.quantizer.codebook[idx].permute(0, 2, 1)
204
+ decoded = vq.vqvae.decoder(emb)
205
+ params = vq.vqvae.postprocess(decoded).squeeze(0).cpu().numpy()
206
+
207
+ if mean is not None: params = (params * std) + mean
208
+ return params
209
 
210
+ def params_to_vertices(params: np.ndarray):
211
+ smpl = _model_cache["smplx_model"]
212
+ if not smpl or params.shape[0] == 0: return None, None
 
 
 
213
 
214
+ # Split params (simplified logic for brevity)
215
+ dims = [10, 63, 45, 45, 3, 10, 3, 3]
216
+ split_params = np.split(params, np.cumsum(dims)[:-1], axis=1)
217
+ tensor_parts = [torch.from_numpy(p).to(DEVICE).float() for p in split_params]
218
 
219
+ # Batch processing to avoid OOM
220
+ verts_list = []
221
+ for i in range(0, params.shape[0], 32):
222
+ batch = [t[i:i+32] for t in tensor_parts]
223
+ with torch.no_grad():
224
+ # Handle global_orient vs body_pose split
225
+ bp_full = batch[1]
226
+ go = bp_full[:, :3]
227
+ bp = bp_full[:, 3:]
228
+
229
+ out = smpl(
230
+ betas=batch[0], global_orient=go, body_pose=bp,
231
+ left_hand_pose=batch[2], right_hand_pose=batch[3],
232
+ transl=batch[4], expression=batch[5],
233
+ jaw_pose=batch[6], leye_pose=batch[7], reye_pose=batch[7]
234
+ )
235
+ verts_list.append(out.vertices.cpu().numpy())
236
+
237
+ return np.concatenate(verts_list, axis=0), smpl.faces
238
 
239
+ # =====================================================================
240
+ # Visualization (Plotly -> HTML)
241
+ # =====================================================================
242
+ def create_side_by_side_html(verts1, faces1, verts2, faces2, title1="", title2="", fps=20):
243
+ # Truncate to matching length
244
+ min_len = min(len(verts1), len(verts2))
245
+ v1, v2 = verts1[:min_len], verts2[:min_len]
246
 
247
+ fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scene'}, {'type': 'scene'}]],
248
+ subplot_titles=[title1, title2])
249
 
250
+ # Add first frame
251
+ for col, v, c in [(1, v1, '#6FA8DC'), (2, v2, '#93C47D')]:
252
+ fig.add_trace(go.Mesh3d(
253
+ x=v[0,:,0], y=v[0,:,1], z=v[0,:,2],
254
+ i=faces1[:,0], j=faces1[:,1], k=faces1[:,2],
255
+ color=c, opacity=0.8, flatshading=True
256
+ ), row=1, col=col)
257
+
258
+ # Frames
259
+ frames = []
260
+ for t in range(min_len):
261
+ frames.append(go.Frame(data=[
262
+ go.Mesh3d(x=v1[t,:,0], y=v1[t,:,1], z=v1[t,:,2]),
263
+ go.Mesh3d(x=v2[t,:,0], y=v2[t,:,1], z=v2[t,:,2])
264
+ ], name=str(t)))
265
 
266
+ fig.frames = frames
 
 
 
 
 
 
 
 
 
 
 
267
 
268
+ # Animation settings
269
+ fig.update_layout(
270
+ updatemenus=[dict(type="buttons", buttons=[dict(label="Play", method="animate", args=[None])])],
271
+ scene=dict(aspectmode='data', xaxis_visible=False, yaxis_visible=False, zaxis_visible=False),
272
+ scene2=dict(aspectmode='data', xaxis_visible=False, yaxis_visible=False, zaxis_visible=False),
273
+ height=500, margin=dict(l=0, r=0, t=30, b=0)
274
+ )
275
 
276
+ return fig.to_html(include_plotlyjs='cdn', full_html=True)
 
 
 
 
 
 
 
277
 
278
+ def create_single_html(verts, faces, title="", fps=20):
279
+ fig = go.Figure(go.Mesh3d(
280
+ x=verts[0,:,0], y=verts[0,:,1], z=verts[0,:,2],
281
+ i=faces[:,0], j=faces[:,1], k=faces[:,2],
282
+ color='#6FA8DC', opacity=0.8, flatshading=True
283
+ ))
284
 
285
+ frames = [go.Frame(data=[go.Mesh3d(x=verts[t,:,0], y=verts[t,:,1], z=verts[t,:,2])], name=str(t))
286
+ for t in range(len(verts))]
287
+ fig.frames = frames
 
 
 
 
 
288
 
289
+ fig.update_layout(
290
+ title=title,
291
+ updatemenus=[dict(type="buttons", buttons=[dict(label="Play", method="animate", args=[None])])],
292
+ scene=dict(aspectmode='data', xaxis_visible=False, yaxis_visible=False, zaxis_visible=False),
293
+ height=500
294
+ )
295
+ return fig.to_html(include_plotlyjs='cdn', full_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
 
297
+ # =====================================================================
298
+ # Main Logic with File Saving
299
+ # =====================================================================
300
+ def save_and_get_iframe(html_content, filename_suffix=""):
301
+ """Saves HTML to disk and returns an Iframe pointing to it."""
302
+ filename = f"vis_{int(time.time())}_{filename_suffix}.html"
303
+ filepath = os.path.join(OUTPUT_DIR, filename)
304
+
305
+ with open(filepath, "w", encoding="utf-8") as f:
306
+ f.write(html_content)
 
 
 
307
 
308
+ # Use the /file= route to serve the absolute path
309
+ # allow-same-origin and allow-scripts are crucial for Plotly
310
+ iframe = f"""
311
+ <iframe src="/file={filepath}"
312
+ width="100%" height="550px"
313
+ style="border:none; background:#fafafa;"
314
+ sandbox="allow-scripts allow-same-origin">
315
+ </iframe>
316
+ """
317
+ return iframe
318
 
319
+ def process_word(word: str):
320
+ if not word.strip(): return None, ""
321
+
322
+ pids = get_random_pids_for_word(word, 2)
323
+ if not pids: pids = ["Unknown", "Unknown"]
324
+ elif len(pids) == 1: pids = [pids[0], pids[0]]
325
+
326
+ # Generate 1
327
+ raw1 = generate_motion_tokens(word, pids[0])
328
+ toks1 = [int(x) for x in re.findall(r'<M(\d+)>', raw1)]
329
+ verts1, faces = params_to_vertices(decode_tokens_to_params(toks1))
330
+
331
+ # Generate 2
332
+ raw2 = generate_motion_tokens(word, pids[1])
333
+ toks2 = [int(x) for x in re.findall(r'<M(\d+)>', raw2)]
334
+ verts2, _ = params_to_vertices(decode_tokens_to_params(toks2))
335
+
336
+ if verts1 is not None and verts2 is not None:
337
+ html = create_side_by_side_html(verts1, faces, verts2, faces, title1=f"{pids[0]}", title2=f"{pids[1]}")
338
+ elif verts1 is not None:
339
+ html = create_single_html(verts1, faces, title=f"{pids[0]}")
340
+ else:
341
+ return "<div>Error generating motion</div>", ""
342
 
343
+ iframe = save_and_get_iframe(html, f"{word}")
344
+ return iframe, f"[{pids[0]}] {len(toks1)} toks\n[{pids[1]}] {len(toks2)} toks"
345
+
346
+ # =====================================================================
347
+ # UI
348
+ # =====================================================================
349
+ def create_ui():
350
+ custom_css = ".gradio-container { max-width: 1400px !important; }"
351
+
352
+ with gr.Blocks(css=custom_css, title="SignMotionGPT") as demo:
353
+ gr.Markdown("# SignMotionGPT Comparison Demo")
354
 
355
+ with gr.Row():
356
+ with gr.Column(scale=1):
357
+ txt_input = gr.Textbox(label="Word", placeholder="push")
358
+ btn = gr.Button("Generate Comparison", variant="primary")
359
+ out_toks = gr.Textbox(label="Details", lines=4)
360
+
361
+ with gr.Column(scale=2):
362
+ out_html = gr.HTML(label="Visualization")
363
+
364
+ btn.click(process_word, inputs=txt_input, outputs=[out_html, out_toks])
365
 
366
+ # Initialize
367
+ initialize_models()
368
 
369
+ return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
  if __name__ == "__main__":
372
+ demo = create_ui()
373
+ print(f"🚀 Launching. Output dir: {OUTPUT_DIR}")
374
+ # allowed_paths=[OUTPUT_DIR] is the magic key for HF Spaces
375
+ demo.launch(server_name="0.0.0.0", server_port=7860, allowed_paths=[OUTPUT_DIR])