rdz-falcon committed on
Commit
2d54a11
·
verified ·
1 Parent(s): c586336

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +202 -122
app.py CHANGED
@@ -6,10 +6,10 @@ import warnings
6
  import re
7
  import json
8
  import random
9
- import base64
10
  from pathlib import Path
11
 
12
- # Add root to path to allow imports from project root
 
13
  current_dir = os.path.dirname(os.path.abspath(__file__))
14
  parent_dir = os.path.dirname(current_dir)
15
  sys.path.append(current_dir)
@@ -18,12 +18,18 @@ sys.path.append(parent_dir)
18
  # Import project modules
19
  try:
20
  from visualize import visualize
 
 
 
 
21
  except Exception as e:
22
  print(f"Error importing project modules: {e}")
23
- # Fallback/Dummy visualize for testing if module is missing
24
- def visualize(**kwargs):
25
- print("Visualizer called (Dummy)")
26
- return None
 
 
27
 
28
  # Constants
29
  HF_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
@@ -31,7 +37,7 @@ EPOCH_SUBFOLDER = "stage2_v2/epoch-030"
31
  CODEBOOK_SIZE = 512
32
  DATASET_PATH = os.environ.get("DATASET_PATH", "enriched_dataset.json")
33
 
34
- # Inference Params
35
  INFERENCE_TEMPERATURE = 0.7
36
  INFERENCE_TOP_K = 50
37
  INFERENCE_REPETITION_PENALTY = 1.2
@@ -41,18 +47,24 @@ M_END = "<M_END>"
41
  # Global model cache
42
  MODEL = None
43
  TOKENIZER = None
 
44
  M_START_ID = None
45
  M_END_ID = None
46
  VARIANT_MAP = {}
47
 
48
  def load_variant_map():
 
49
  global VARIANT_MAP
 
 
50
  candidates = [
51
  DATASET_PATH,
52
  os.path.join(os.path.dirname(__file__), DATASET_PATH),
53
- "data/motion_llm_dataset.json",
 
54
  "motion_llm_dataset.json"
55
  ]
 
56
  found_path = None
57
  for p in candidates:
58
  if os.path.exists(p):
@@ -64,63 +76,136 @@ def load_variant_map():
64
  try:
65
  with open(found_path, 'r', encoding='utf-8') as f:
66
  data = json.load(f)
 
67
  mapping = {}
 
68
  for entry in data:
 
69
  word = entry.get("word") or entry.get("text_query")
70
  if not word: continue
 
 
 
 
 
 
71
  word = word.lower().strip()
72
  pid = entry.get("participant_id")
 
73
  if word and pid:
74
- mapping.setdefault(word, []).append(str(pid))
 
 
 
 
 
75
  VARIANT_MAP = mapping
76
- print(f"Loaded variants for {len(VARIANT_MAP)} words.")
 
 
 
 
 
 
 
77
  except Exception as e:
78
  print(f"Error loading dataset: {e}")
79
  else:
80
- print("⚠️ Dataset not found. Variants will default to 'unknown'.")
81
- # Fallbacks
82
- VARIANT_MAP["push"] = ["P40", "P123"]
83
- VARIANT_MAP["send"] = ["P40"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def init_model():
86
  global MODEL, TOKENIZER, M_START_ID, M_END_ID
87
  if MODEL is not None:
88
  return
89
-
90
  load_variant_map()
 
 
91
 
92
- from transformers import AutoModelForCausalLM, AutoTokenizer
93
- token = os.environ.get("HF_TOKEN")
94
-
95
- print(f"Loading model from HF: {HF_REPO_ID}/{EPOCH_SUBFOLDER}")
96
- TOKENIZER = AutoTokenizer.from_pretrained(HF_REPO_ID, subfolder=EPOCH_SUBFOLDER, token=token, trust_remote_code=True)
97
- MODEL = AutoModelForCausalLM.from_pretrained(HF_REPO_ID, subfolder=EPOCH_SUBFOLDER, token=token, trust_remote_code=True)
98
 
 
 
 
99
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
100
  MODEL.to(device)
101
  MODEL.eval()
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- # Add special tokens if missing
104
- if M_START not in TOKENIZER.get_vocab():
105
- TOKENIZER.add_special_tokens({"additional_special_tokens": [M_START, M_END]})
106
- MODEL.resize_token_embeddings(len(TOKENIZER))
107
-
108
  M_START_ID = TOKENIZER.convert_tokens_to_ids(M_START)
109
  M_END_ID = TOKENIZER.convert_tokens_to_ids(M_END)
110
 
111
- if "<motion_0>" not in TOKENIZER.get_vocab():
112
- print("Adding motion tokens...")
 
 
 
 
113
  motion_tokens = [f"<motion_{i}>" for i in range(CODEBOOK_SIZE)]
114
- TOKENIZER.add_tokens(motion_tokens, special_tokens=True)
115
- MODEL.resize_token_embeddings(len(TOKENIZER))
 
 
 
 
 
116
 
117
  def generate_motion_simple(model, tokenizer, prompt_text, device):
 
 
 
 
 
 
 
118
  word_lower = prompt_text.lower().strip()
119
- variants = VARIANT_MAP.get(word_lower, ["unknown"])
120
- pid = random.choice(variants)
121
 
 
 
 
 
 
 
 
 
122
  prompt = f"Instruction: Generate motion for word '{prompt_text}' with variant '{pid}'.\nMotion: "
123
- print(f"Input Prompt: {prompt}")
 
124
 
125
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
126
 
@@ -133,51 +218,94 @@ def generate_motion_simple(model, tokenizer, prompt_text, device):
133
  top_k=INFERENCE_TOP_K,
134
  repetition_penalty=INFERENCE_REPETITION_PENALTY,
135
  pad_token_id=tokenizer.pad_token_id,
136
- eos_token_id=M_END_ID,
137
  early_stopping=True
138
  )
139
 
140
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
141
- return decoded.split("Motion: ")[-1].strip() if "Motion: " in decoded else decoded.strip()
 
 
 
 
 
 
 
 
142
 
143
  def generate_motion_app(text_prompt):
144
- # Returns: (iframe_html, file_path, status_text)
145
  if not text_prompt:
146
- return None, None, "Please enter a prompt."
147
 
148
  if MODEL is None:
149
  try:
150
  init_model()
151
  except Exception as e:
152
- return None, None, f"Model Init Error: {e}"
153
 
 
154
  print(f"Generating for: {text_prompt}")
155
 
156
  try:
157
- # 1. Generate Tokens
158
- generated_sequence = generate_motion_simple(MODEL, TOKENIZER, text_prompt, MODEL.device)
 
 
 
 
 
159
 
160
- # Clean tokens
 
 
 
 
 
 
 
 
 
 
 
 
161
  m_tokens = re.findall(r'<M(\d+)>', generated_sequence)
162
  if not m_tokens:
 
163
  m_tokens = re.findall(r'<motion_(\d+)>', generated_sequence)
 
 
 
 
 
 
 
164
 
165
- tokens_for_vis = " ".join(m_tokens) if m_tokens else generated_sequence
166
 
167
- # 2. Visualization Paths
 
 
 
 
 
168
  data_dir = os.environ.get("DATA_DIR", "data")
169
  vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
170
  stats_path = os.path.join(data_dir, "vqvae_stats.pt")
171
  smplx_dir = os.path.join(data_dir, "smplx_models")
172
 
173
- # Check files
174
- if not os.path.exists(vqvae_ckpt):
175
- return None, None, f"Missing VQ-VAE model at {vqvae_ckpt}"
 
 
176
 
177
- output_html = f"motion_{text_prompt.replace(' ', '_')}.html"
 
 
 
 
178
 
179
- # 3. Create Visualization (Saves HTML to disk)
180
- visualize(
181
  tokens=tokens_for_vis,
182
  vqvae_ckpt=vqvae_ckpt,
183
  stats_path=stats_path,
@@ -187,83 +315,35 @@ def generate_motion_app(text_prompt):
187
  fps=20
188
  )
189
 
190
- # 4. Prepare Outputs
191
- if not os.path.exists(output_html):
192
- return None, None, "Error: HTML file was not generated."
193
-
194
- # A) Prepare Iframe for Preview (Base64 encoding)
195
- with open(output_html, "rb") as f:
196
- encoded_html = base64.b64encode(f.read()).decode('utf-8')
197
-
198
- iframe = f"""<iframe
199
- src="data:text/html;base64,{encoded_html}"
200
- width="100%"
201
- height="600px"
202
- style="border:none;">
203
- </iframe>"""
204
-
205
- # B) Prepare Status Message
206
- status_msg = f"✅ Success! Generated {len(m_tokens)} tokens.\nSequence: {tokens_for_vis[:50]}..."
207
 
208
- # Return: (HTML Preview, File Path for Download, Status)
209
- return iframe, output_html, status_msg
210
 
211
  except Exception as e:
212
- import traceback
213
- traceback.print_exc()
214
- return None, None, f"Error: {str(e)}"
215
 
216
- # --- Gradio UI ---
217
- custom_css = """
218
- .gradio-container { max-width: 1400px !important; }
219
- .viz-section { min-height: 700px; }
220
- """
221
 
222
- with gr.Blocks(css=custom_css, title="SignMotionGPT Demo") as demo:
223
- gr.Markdown("# 🤟 SignMotionGPT Demo")
224
-
225
- with gr.Row():
226
- # INPUT COLUMN
227
- with gr.Column(scale=1):
228
- text_input = gr.Textbox(label="Enter Word", placeholder="e.g., push")
229
- with gr.Row():
230
- clear_btn = gr.Button("Clear")
231
- submit_btn = gr.Button("Generate Motion", variant="primary")
232
-
233
- status_output = gr.Textbox(label="Status", lines=5, interactive=False)
234
-
235
- # DOWNLOAD BUTTON (New!)
236
- gr.Markdown("### 📥 Download Result")
237
- file_output = gr.File(label="Download HTML Animation")
238
-
239
- # PREVIEW COLUMN
240
- with gr.Column(scale=3, elem_classes="viz-section"):
241
- gr.Markdown("### 🎭 Preview")
242
- plot_output = gr.HTML(label="Avatar Motion")
243
-
244
- # Examples
245
- gr.Markdown("### Examples")
246
- with gr.Row():
247
- for word in ["push", "send", "library", "passport"]:
248
- gr.Button(word).click(
249
- fn=lambda w=word: w, outputs=text_input
250
- ).then(
251
- fn=generate_motion_app,
252
- inputs=text_input,
253
- outputs=[plot_output, file_output, status_output]
254
- )
255
-
256
- # Main Events
257
- submit_btn.click(
258
- fn=generate_motion_app,
259
- inputs=[text_input],
260
- outputs=[plot_output, file_output, status_output] # 3 Outputs
261
- )
262
-
263
- clear_btn.click(
264
- fn=lambda: ("", None, None, ""),
265
- outputs=[text_input, plot_output, file_output, status_output]
266
- )
267
 
268
  if __name__ == "__main__":
269
- demo.launch()
 
6
  import re
7
  import json
8
  import random
 
9
  from pathlib import Path
10
 
11
+ # Add root to path to allow imports from project root when running from demo-code/
12
+ # or when running from root
13
  current_dir = os.path.dirname(os.path.abspath(__file__))
14
  parent_dir = os.path.dirname(current_dir)
15
  sys.path.append(current_dir)
 
18
  # Import project modules
19
  try:
20
  from visualize import visualize
21
+ # Try importing what we can, but we will implement generation logic directly here
22
+ # to match test_overfit.py / metrics.py exactly and avoid dependency issues.
23
+ # We catch Exception because unsloth in model.py might raise NotImplementedError on CPU
24
+ from model import get_motion_token_info
25
  except Exception as e:
26
  print(f"Error importing project modules: {e}")
27
+ print("Make sure you are running this from the project root or have the project structure intact.")
28
+ # Fallback for explicit relative imports if needed in some environments
29
+ try:
30
+ from visualize import visualize
31
+ except Exception as vis_e:
32
+ print(f"Visualize import failed too: {vis_e}")
33
 
34
  # Constants
35
  HF_REPO_ID = "rdz-falcon/SignMotionGPTfit-archive"
 
37
  CODEBOOK_SIZE = 512
38
  DATASET_PATH = os.environ.get("DATASET_PATH", "enriched_dataset.json")
39
 
40
+ # Hardcoded Config from test_overfit.py / config.py
41
  INFERENCE_TEMPERATURE = 0.7
42
  INFERENCE_TOP_K = 50
43
  INFERENCE_REPETITION_PENALTY = 1.2
 
47
  # Global model cache
48
  MODEL = None
49
  TOKENIZER = None
50
+ # We use M_START/M_END as in test_overfit.py
51
  M_START_ID = None
52
  M_END_ID = None
53
  VARIANT_MAP = {}
54
 
55
  def load_variant_map():
56
+ """Load dataset to map words to valid participant IDs."""
57
  global VARIANT_MAP
58
+
59
+ # Try multiple possible paths for the dataset
60
  candidates = [
61
  DATASET_PATH,
62
  os.path.join(os.path.dirname(__file__), DATASET_PATH),
63
+ os.path.join(os.path.dirname(__file__), "..", DATASET_PATH),
64
+ "data/motion_llm_dataset.json", # Fallback to raw dataset if enriched missing
65
  "motion_llm_dataset.json"
66
  ]
67
+
68
  found_path = None
69
  for p in candidates:
70
  if os.path.exists(p):
 
76
  try:
77
  with open(found_path, 'r', encoding='utf-8') as f:
78
  data = json.load(f)
79
+
80
  mapping = {}
81
+ count = 0
82
  for entry in data:
83
+ # Support both formats (enriched or raw)
84
  word = entry.get("word") or entry.get("text_query")
85
  if not word: continue
86
+
87
+ # Clean word (sometimes text_query is "Motion for word 'hello'")
88
+ if "motion for word" in word.lower():
89
+ # extraction heuristic if needed, but 'word' field is preferred
90
+ pass
91
+
92
  word = word.lower().strip()
93
  pid = entry.get("participant_id")
94
+
95
  if word and pid:
96
+ if word not in mapping:
97
+ mapping[word] = []
98
+ if pid not in mapping[word]:
99
+ mapping[word].append(str(pid))
100
+ count += 1
101
+
102
  VARIANT_MAP = mapping
103
+ print(f"Loaded {count} variants for {len(VARIANT_MAP)} words.")
104
+
105
+ # Debug check for 'push'
106
+ if 'push' in VARIANT_MAP:
107
+ print(f" 'push' variants: {VARIANT_MAP['push']}")
108
+ else:
109
+ print(" 'push' NOT found in dataset.")
110
+
111
  except Exception as e:
112
  print(f"Error loading dataset: {e}")
113
  else:
114
+ print(f"⚠️ Dataset not found. Tried: {candidates}. Variants will default to 'unknown'.")
115
+
116
+ # Hardcoded fallback for demonstration words if missing from dataset
117
+ defaults = {
118
+ "push": ["P40", "P123", "P1"],
119
+ "send": ["P40", "P123"],
120
+ "library": ["P40"],
121
+ "passport": ["P40"]
122
+ }
123
+ for w, pids in defaults.items():
124
+ if w not in VARIANT_MAP:
125
+ VARIANT_MAP[w] = pids
126
+ print(f" Added fallback variants for '{w}': {pids}")
127
+
128
+ def load_model_from_hf(repo_id, subfolder, token=None):
129
+ from transformers import AutoModelForCausalLM, AutoTokenizer
130
+ print(f"Loading model from HF: {repo_id}/{subfolder}")
131
+ try:
132
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
133
+ model = AutoModelForCausalLM.from_pretrained(repo_id, subfolder=subfolder, token=token, trust_remote_code=True)
134
+ return model, tokenizer
135
+ except Exception as e:
136
+ print(f"Error loading model: {e}")
137
+ return None, None
138
 
139
  def init_model():
140
  global MODEL, TOKENIZER, M_START_ID, M_END_ID
141
  if MODEL is not None:
142
  return
143
+
144
  load_variant_map()
145
+
146
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
147
 
148
+ # Load model/tokenizer
149
+ MODEL, TOKENIZER = load_model_from_hf(HF_REPO_ID, EPOCH_SUBFOLDER, token)
 
 
 
 
150
 
151
+ if MODEL is None:
152
+ raise RuntimeError(f"Failed to load model from {HF_REPO_ID}/{EPOCH_SUBFOLDER}")
153
+
154
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
155
  MODEL.to(device)
156
  MODEL.eval()
157
+
158
+ # Setup special tokens matching test_overfit.py
159
+ # test_overfit.py uses M_START="<M_START>" and M_END="<M_END>"
160
+
161
+ # Check if tokens exist
162
+ if M_START not in TOKENIZER.get_vocab() or M_END not in TOKENIZER.get_vocab():
163
+ print(f"⚠️ Warning: {M_START} or {M_END} not found in tokenizer. Adding them now...")
164
+ num_added = TOKENIZER.add_special_tokens({"additional_special_tokens": [M_START, M_END]})
165
+ if num_added > 0:
166
+ MODEL.resize_token_embeddings(len(TOKENIZER))
167
+ print(f" Added {num_added} special tokens.")
168
 
 
 
 
 
 
169
  M_START_ID = TOKENIZER.convert_tokens_to_ids(M_START)
170
  M_END_ID = TOKENIZER.convert_tokens_to_ids(M_END)
171
 
172
+ # Check motion tokens
173
+ # We expect <motion_0> ... <motion_511>
174
+ # If missing, add them
175
+ first_motion = "<motion_0>"
176
+ if first_motion not in TOKENIZER.get_vocab():
177
+ print("⚠️ Warning: Motion tokens not found. Adding them now...")
178
  motion_tokens = [f"<motion_{i}>" for i in range(CODEBOOK_SIZE)]
179
+ num_added = TOKENIZER.add_tokens(motion_tokens, special_tokens=True)
180
+ if num_added > 0:
181
+ MODEL.resize_token_embeddings(len(TOKENIZER))
182
+ print(f" Added {num_added} motion tokens.")
183
+
184
+ print(f"Model initialized. Vocab size: {len(TOKENIZER)}")
185
+ print(f"M_START_ID: {M_START_ID}, M_END_ID: {M_END_ID}")
186
 
187
  def generate_motion_simple(model, tokenizer, prompt_text, device):
188
+ """
189
+ Replicates the simple generation logic from metrics.py / test_overfit.py
190
+ """
191
+ # Construct prompt exactly as in test_overfit.py:
192
+ # prompt = f"Instruction: Generate motion for word '{sample['word']}' with variant '{sample['participant_id']}'.\nMotion: "
193
+
194
+ # Get a valid participant ID if possible
195
  word_lower = prompt_text.lower().strip()
196
+ variants = VARIANT_MAP.get(word_lower, [])
 
197
 
198
+ if variants:
199
+ pid = random.choice(variants)
200
+ print(f"Selected variant '{pid}' for word '{prompt_text}'")
201
+ else:
202
+ # Fallback to 'unknown' or a common PID if known (e.g., P1)
203
+ pid = "unknown"
204
+ print(f"No variants found for '{prompt_text}', using '{pid}'")
205
+
206
  prompt = f"Instruction: Generate motion for word '{prompt_text}' with variant '{pid}'.\nMotion: "
207
+
208
+ print(f"Input Prompt:\n{prompt}")
209
 
210
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
211
 
 
218
  top_k=INFERENCE_TOP_K,
219
  repetition_penalty=INFERENCE_REPETITION_PENALTY,
220
  pad_token_id=tokenizer.pad_token_id,
221
+ eos_token_id=M_END_ID, # Stop at <M_END>
222
  early_stopping=True
223
  )
224
 
225
  decoded = tokenizer.decode(output[0], skip_special_tokens=False)
226
+
227
+ # Parse output to extract just the motion part
228
+ # We expect: ... \nMotion: <M_START> <motion_...> ... <M_END>
229
+ if "Motion: " in decoded:
230
+ motion_part = decoded.split("Motion: ")[-1]
231
+ else:
232
+ motion_part = decoded
233
+
234
+ return motion_part.strip()
235
 
236
  def generate_motion_app(text_prompt):
 
237
  if not text_prompt:
238
+ return None, "Please enter a prompt."
239
 
240
  if MODEL is None:
241
  try:
242
  init_model()
243
  except Exception as e:
244
+ return None, f"Model Initialization Failed: {e}"
245
 
246
+ device = MODEL.device
247
  print(f"Generating for: {text_prompt}")
248
 
249
  try:
250
+ generated_sequence = generate_motion_simple(MODEL, TOKENIZER, text_prompt, device)
251
+ print("Generated sequence (raw):", generated_sequence)
252
+
253
+ # Extract tokens for visualization
254
+ # Logic from metrics.py: _extract_motion_tokens_from_sequence
255
+ # Expect tokens like <M123> or <motion_123>
256
+ # The generation might include M_START/M_END.
257
 
258
+ # Clean up for visualization input
259
+ # We need a string of tokens.
260
+ # If the output is like "<M_START> <motion_1> <motion_2> <M_END>", we pass that.
261
+ # visualize.py's parse_motion_tokens handles <motion_ID> regex.
262
+ # BUT visualize.py expects either "123 456" OR "<motion_123> <motion_456>"
263
+ # It does NOT explicitly handle <M123> which is what we might have here if M_START was used.
264
+ # Let's convert <M123> to space-separated integers for safety.
265
+
266
+ # Extract integers from <M123> or <motion_123>
267
+ # generated_sequence is raw string from tokenizer decode
268
+
269
+ import re
270
+ # Try <M123> format (test_overfit style)
271
  m_tokens = re.findall(r'<M(\d+)>', generated_sequence)
272
  if not m_tokens:
273
+ # Try <motion_123> format
274
  m_tokens = re.findall(r'<motion_(\d+)>', generated_sequence)
275
+
276
+ if m_tokens:
277
+ # Reconstruct as space-separated string for visualize.py
278
+ tokens_for_vis = " ".join(m_tokens)
279
+ else:
280
+ # Fallback to raw string if regex failed (visualize.py might handle other formats)
281
+ tokens_for_vis = generated_sequence
282
 
283
+ print(f"Tokens for visualization: {tokens_for_vis[:50]}...")
284
 
285
+ except Exception as e:
286
+ return None, f"Generation Error: {e}"
287
+
288
+ # Visualization
289
+ try:
290
+ # Ensure paths for VQ-VAE and SMPL-X
291
  data_dir = os.environ.get("DATA_DIR", "data")
292
  vqvae_ckpt = os.path.join(data_dir, "vqvae_model.pt")
293
  stats_path = os.path.join(data_dir, "vqvae_stats.pt")
294
  smplx_dir = os.path.join(data_dir, "smplx_models")
295
 
296
+ # Check existence
297
+ missing = []
298
+ if not os.path.exists(vqvae_ckpt): missing.append(vqvae_ckpt)
299
+ if not os.path.exists(stats_path): missing.append(stats_path)
300
+ if not os.path.exists(smplx_dir): missing.append(smplx_dir)
301
 
302
+ if missing:
303
+ return None, f"Missing visualization files in {data_dir}: {missing}. Please ensure they are uploaded to the Space."
304
+
305
+ # Output to a temporary file
306
+ output_html = "temp_viz.html"
307
 
308
+ fig = visualize(
 
309
  tokens=tokens_for_vis,
310
  vqvae_ckpt=vqvae_ckpt,
311
  stats_path=stats_path,
 
315
  fps=20
316
  )
317
 
318
+ if fig is None:
319
+ return None, "Visualization failed (no frames produced)."
320
+
321
+ # Count tokens for display
322
+ matches = re.findall(r'<motion_(\d+)>', tokens_for_vis)
323
+ # Also check for <M...> format just in case
324
+ if not matches:
325
+ matches = re.findall(r'<M(\d+)>', tokens_for_vis)
326
+
327
+ num_tokens = len(matches)
 
 
 
 
 
 
 
328
 
329
+ return fig, f"Success! Generated tokens length: {num_tokens}. Sequence: {tokens_for_vis[:100]}..."
 
330
 
331
  except Exception as e:
332
+ return None, f"Visualization Error: {e}"
 
 
333
 
 
 
 
 
 
334
 
335
+ # Gradio UI
336
+ with gr.Interface(
337
+ fn=generate_motion_app,
338
+ inputs=gr.Textbox(label="Enter Motion Prompt", placeholder="e.g. walking forward"),
339
+ outputs=[
340
+ gr.Plot(label="Motion Visualization"),
341
+ gr.Textbox(label="Status/Output")
342
+ ],
343
+ title="SignMotionGPT Demo",
344
+ description="Generate Sign Language/Motion Avatars from Text. Using model checkpoint: epoch 30."
345
+ ) as demo:
346
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  if __name__ == "__main__":
349
+ demo.launch()