Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,6 @@ Uses PyRender for high-quality avatar visualization
|
|
| 6 |
# IMPORTANT: Set OpenGL platform BEFORE any OpenGL imports (for headless rendering)
|
| 7 |
import os
|
| 8 |
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
| 9 |
-
|
| 10 |
import sys
|
| 11 |
import re
|
| 12 |
import json
|
|
@@ -15,12 +14,9 @@ import warnings
|
|
| 15 |
import tempfile
|
| 16 |
import uuid
|
| 17 |
from pathlib import Path
|
| 18 |
-
|
| 19 |
import torch
|
| 20 |
import numpy as np
|
| 21 |
-
|
| 22 |
warnings.filterwarnings("ignore")
|
| 23 |
-
|
| 24 |
# =====================================================================
|
| 25 |
# Configuration for HuggingFace Spaces
|
| 26 |
# =====================================================================
|
|
@@ -29,19 +25,15 @@ DATA_DIR = os.path.join(WORK_DIR, "data")
|
|
| 29 |
OUTPUT_DIR = os.path.join(WORK_DIR, "outputs")
|
| 30 |
os.makedirs(DATA_DIR, exist_ok=True)
|
| 31 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 32 |
-
|
| 33 |
# Path definitions
|
| 34 |
DATASET_PATH = os.path.join(DATA_DIR, "motion_llm_dataset.json")
|
| 35 |
VQVAE_CHECKPOINT = os.path.join(DATA_DIR, "vqvae_model.pt")
|
| 36 |
STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
|
| 37 |
SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
|
| 38 |
-
|
| 39 |
# HuggingFace model config
|
| 40 |
HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
|
| 41 |
HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
|
| 42 |
-
|
| 43 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 44 |
-
|
| 45 |
# Generation parameters
|
| 46 |
M_START = "<M_START>"
|
| 47 |
M_END = "<M_END>"
|
|
@@ -49,7 +41,6 @@ PAD_TOKEN = "<PAD>"
|
|
| 49 |
INFERENCE_TEMPERATURE = 0.7
|
| 50 |
INFERENCE_TOP_K = 50
|
| 51 |
INFERENCE_REPETITION_PENALTY = 1.2
|
| 52 |
-
|
| 53 |
# VQ-VAE parameters
|
| 54 |
SMPL_DIM = 182
|
| 55 |
CODEBOOK_SIZE = 512
|
|
@@ -58,18 +49,15 @@ VQ_ARGS = dict(
|
|
| 58 |
width=512, depth=3, down_t=2, stride_t=2,
|
| 59 |
dilation_growth_rate=3, activation='relu', norm=None, quantizer="ema_reset"
|
| 60 |
)
|
| 61 |
-
|
| 62 |
PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
|
| 63 |
-
PARAM_NAMES = ["
|
| 64 |
-
"
|
| 65 |
-
|
| 66 |
# Visualization defaults
|
| 67 |
AVATAR_COLOR = (0.36, 0.78, 0.36, 1.0) # Green color as RGBA
|
| 68 |
VIDEO_FPS = 15
|
| 69 |
VIDEO_SLOWDOWN = 2
|
| 70 |
FRAME_WIDTH = 544 # Must be divisible by 16 for video codec compatibility
|
| 71 |
FRAME_HEIGHT = 720
|
| 72 |
-
|
| 73 |
# =====================================================================
|
| 74 |
# Install/Import Dependencies
|
| 75 |
# =====================================================================
|
|
@@ -78,13 +66,11 @@ try:
|
|
| 78 |
except ImportError:
|
| 79 |
os.system("pip install -q gradio>=4.0.0")
|
| 80 |
import gradio as gr
|
| 81 |
-
|
| 82 |
try:
|
| 83 |
import smplx
|
| 84 |
except ImportError:
|
| 85 |
os.system("pip install -q smplx==0.1.28")
|
| 86 |
import smplx
|
| 87 |
-
|
| 88 |
# PyRender for high-quality rendering
|
| 89 |
PYRENDER_AVAILABLE = False
|
| 90 |
try:
|
|
@@ -94,16 +80,13 @@ try:
|
|
| 94 |
PYRENDER_AVAILABLE = True
|
| 95 |
except ImportError:
|
| 96 |
pass
|
| 97 |
-
|
| 98 |
try:
|
| 99 |
import imageio
|
| 100 |
except ImportError:
|
| 101 |
os.system("pip install -q imageio[ffmpeg]")
|
| 102 |
import imageio
|
| 103 |
-
|
| 104 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 105 |
import torch.nn.functional as F
|
| 106 |
-
|
| 107 |
# =====================================================================
|
| 108 |
# Import VQ-VAE architecture
|
| 109 |
# =====================================================================
|
|
@@ -113,13 +96,11 @@ if parent_dir not in sys.path:
|
|
| 113 |
sys.path.insert(0, parent_dir)
|
| 114 |
if current_dir not in sys.path:
|
| 115 |
sys.path.insert(0, current_dir)
|
| 116 |
-
|
| 117 |
try:
|
| 118 |
from mGPT.archs.mgpt_vq import VQVae
|
| 119 |
except ImportError as e:
|
| 120 |
print(f"Warning: Could not import VQVae: {e}")
|
| 121 |
VQVae = None
|
| 122 |
-
|
| 123 |
# =====================================================================
|
| 124 |
# Global Cache
|
| 125 |
# =====================================================================
|
|
@@ -131,10 +112,8 @@ _model_cache = {
|
|
| 131 |
"stats": (None, None),
|
| 132 |
"initialized": False
|
| 133 |
}
|
| 134 |
-
|
| 135 |
_word_pid_map = {}
|
| 136 |
_example_cache = {}
|
| 137 |
-
|
| 138 |
# =====================================================================
|
| 139 |
# PyRender Setup
|
| 140 |
# =====================================================================
|
|
@@ -143,12 +122,12 @@ def ensure_pyrender():
|
|
| 143 |
global PYRENDER_AVAILABLE, trimesh, pyrender, Image, ImageDraw, ImageFont
|
| 144 |
if PYRENDER_AVAILABLE:
|
| 145 |
return True
|
| 146 |
-
|
| 147 |
print("Installing pyrender dependencies...")
|
| 148 |
if os.path.exists("/etc/debian_version"):
|
| 149 |
os.system("apt-get update -qq && apt-get install -qq -y libegl1-mesa-dev libgles2-mesa-dev > /dev/null 2>&1")
|
| 150 |
os.system("pip install -q trimesh pyrender PyOpenGL PyOpenGL_accelerate Pillow")
|
| 151 |
-
|
| 152 |
try:
|
| 153 |
import trimesh
|
| 154 |
import pyrender
|
|
@@ -158,23 +137,22 @@ def ensure_pyrender():
|
|
| 158 |
except ImportError as e:
|
| 159 |
print(f"Could not install pyrender: {e}")
|
| 160 |
return False
|
| 161 |
-
|
| 162 |
# =====================================================================
|
| 163 |
# Dataset Loading - Word to PID mapping
|
| 164 |
# =====================================================================
|
| 165 |
def load_word_pid_mapping():
|
| 166 |
"""Load the dataset and build word -> PIDs mapping."""
|
| 167 |
global _word_pid_map
|
| 168 |
-
|
| 169 |
if not os.path.exists(DATASET_PATH):
|
| 170 |
print(f"Dataset not found: {DATASET_PATH}")
|
| 171 |
return
|
| 172 |
-
|
| 173 |
print(f"Loading dataset from: {DATASET_PATH}")
|
| 174 |
try:
|
| 175 |
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
|
| 176 |
data = json.load(f)
|
| 177 |
-
|
| 178 |
for entry in data:
|
| 179 |
word = entry.get('word', '').lower()
|
| 180 |
pid = entry.get('participant_id', '')
|
|
@@ -182,21 +160,17 @@ def load_word_pid_mapping():
|
|
| 182 |
if word not in _word_pid_map:
|
| 183 |
_word_pid_map[word] = set()
|
| 184 |
_word_pid_map[word].add(pid)
|
| 185 |
-
|
| 186 |
for word in _word_pid_map:
|
| 187 |
_word_pid_map[word] = sorted(list(_word_pid_map[word]))
|
| 188 |
-
|
| 189 |
print(f"Loaded {len(_word_pid_map)} unique words from dataset")
|
| 190 |
except Exception as e:
|
| 191 |
print(f"Error loading dataset: {e}")
|
| 192 |
-
|
| 193 |
-
|
| 194 |
def get_pids_for_word(word: str) -> list:
|
| 195 |
"""Get valid PIDs for a word from the dataset."""
|
| 196 |
word = word.lower().strip()
|
| 197 |
return _word_pid_map.get(word, [])
|
| 198 |
-
|
| 199 |
-
|
| 200 |
def get_random_pids_for_word(word: str, count: int = 2) -> list:
|
| 201 |
"""Get random PIDs for a word. Returns up to 'count' PIDs."""
|
| 202 |
pids = get_pids_for_word(word)
|
|
@@ -205,29 +179,26 @@ def get_random_pids_for_word(word: str, count: int = 2) -> list:
|
|
| 205 |
if len(pids) <= count:
|
| 206 |
return pids
|
| 207 |
return random.sample(pids, count)
|
| 208 |
-
|
| 209 |
-
|
| 210 |
def get_example_words_with_pids(count: int = 3) -> list:
|
| 211 |
"""Get example words with valid PIDs from dataset."""
|
| 212 |
examples = []
|
| 213 |
preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
|
| 214 |
-
|
| 215 |
for word in preferred:
|
| 216 |
pids = get_pids_for_word(word)
|
| 217 |
if pids:
|
| 218 |
examples.append((word, pids[0]))
|
| 219 |
if len(examples) >= count:
|
| 220 |
break
|
| 221 |
-
|
| 222 |
if len(examples) < count:
|
| 223 |
available = [w for w in _word_pid_map.keys() if w not in [e[0] for e in examples]]
|
| 224 |
random.shuffle(available)
|
| 225 |
for word in available[:count - len(examples)]:
|
| 226 |
pids = _word_pid_map[word]
|
| 227 |
examples.append((word, pids[0]))
|
| 228 |
-
|
| 229 |
-
return examples
|
| 230 |
|
|
|
|
| 231 |
# =====================================================================
|
| 232 |
# VQ-VAE Wrapper
|
| 233 |
# =====================================================================
|
|
@@ -240,14 +211,13 @@ class MotionGPT_VQVAE_Wrapper(torch.nn.Module):
|
|
| 240 |
nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
|
| 241 |
output_emb_width=code_dim, **kwargs
|
| 242 |
)
|
| 243 |
-
|
| 244 |
# =====================================================================
|
| 245 |
# Model Loading Functions
|
| 246 |
# =====================================================================
|
| 247 |
def load_llm_model():
|
| 248 |
print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
|
| 249 |
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
|
| 250 |
-
|
| 251 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 252 |
HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token
|
| 253 |
)
|
|
@@ -263,8 +233,6 @@ def load_llm_model():
|
|
| 263 |
model.eval()
|
| 264 |
print(f"LLM loaded (vocab size: {len(tokenizer)})")
|
| 265 |
return model, tokenizer
|
| 266 |
-
|
| 267 |
-
|
| 268 |
def load_vqvae_model():
|
| 269 |
if not os.path.exists(VQVAE_CHECKPOINT):
|
| 270 |
print(f"VQ-VAE checkpoint not found: {VQVAE_CHECKPOINT}")
|
|
@@ -277,8 +245,6 @@ def load_vqvae_model():
|
|
| 277 |
model.eval()
|
| 278 |
print(f"VQ-VAE loaded")
|
| 279 |
return model
|
| 280 |
-
|
| 281 |
-
|
| 282 |
def load_stats():
|
| 283 |
if not os.path.exists(STATS_PATH):
|
| 284 |
return None, None
|
|
@@ -287,8 +253,6 @@ def load_stats():
|
|
| 287 |
if torch.is_tensor(mean): mean = mean.cpu().numpy()
|
| 288 |
if torch.is_tensor(std): std = std.cpu().numpy()
|
| 289 |
return mean, std
|
| 290 |
-
|
| 291 |
-
|
| 292 |
def load_smplx_model():
|
| 293 |
if not os.path.exists(SMPLX_MODEL_DIR):
|
| 294 |
print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
|
|
@@ -302,47 +266,43 @@ def load_smplx_model():
|
|
| 302 |
).to(DEVICE)
|
| 303 |
print(f"SMPL-X loaded")
|
| 304 |
return model
|
| 305 |
-
|
| 306 |
-
|
| 307 |
def initialize_models():
|
| 308 |
global _model_cache
|
| 309 |
if _model_cache["initialized"]:
|
| 310 |
return
|
| 311 |
-
|
| 312 |
print("\n" + "="*60)
|
| 313 |
print(" Initializing SignMotionGPT Models")
|
| 314 |
print("="*60)
|
| 315 |
-
|
| 316 |
load_word_pid_mapping()
|
| 317 |
-
|
| 318 |
_model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()
|
| 319 |
-
|
| 320 |
try:
|
| 321 |
_model_cache["vqvae_model"] = load_vqvae_model()
|
| 322 |
_model_cache["stats"] = load_stats()
|
| 323 |
_model_cache["smplx_model"] = load_smplx_model()
|
| 324 |
except Exception as e:
|
| 325 |
print(f"Could not load visualization models: {e}")
|
| 326 |
-
|
| 327 |
# Ensure PyRender is available
|
| 328 |
ensure_pyrender()
|
| 329 |
-
|
| 330 |
_model_cache["initialized"] = True
|
| 331 |
print("All models initialized")
|
| 332 |
print("="*60)
|
| 333 |
-
|
| 334 |
-
|
| 335 |
def precompute_examples():
|
| 336 |
"""Pre-compute animations for example words at startup."""
|
| 337 |
global _example_cache
|
| 338 |
-
|
| 339 |
if not _model_cache["initialized"]:
|
| 340 |
return
|
| 341 |
-
|
| 342 |
examples = get_example_words_with_pids(3)
|
| 343 |
-
|
| 344 |
print(f"\nPre-computing {len(examples)} example animations...")
|
| 345 |
-
|
| 346 |
for word, pid in examples:
|
| 347 |
key = f"{word}_{pid}"
|
| 348 |
print(f" Computing: {word} ({pid})...")
|
|
@@ -353,22 +313,21 @@ def precompute_examples():
|
|
| 353 |
except Exception as e:
|
| 354 |
print(f" Failed: {word} - {e}")
|
| 355 |
_example_cache[key] = {"video_path": None, "tokens": "", "word": word, "pid": pid}
|
| 356 |
-
|
| 357 |
-
print("Example pre-computation complete\n")
|
| 358 |
|
|
|
|
| 359 |
# =====================================================================
|
| 360 |
# Motion Generation Functions
|
| 361 |
# =====================================================================
|
| 362 |
def generate_motion_tokens(word: str, variant: str) -> str:
|
| 363 |
model = _model_cache["llm_model"]
|
| 364 |
tokenizer = _model_cache["llm_tokenizer"]
|
| 365 |
-
|
| 366 |
if model is None or tokenizer is None:
|
| 367 |
raise RuntimeError("LLM model not loaded")
|
| 368 |
-
|
| 369 |
prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
|
| 370 |
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
|
| 371 |
-
|
| 372 |
with torch.no_grad():
|
| 373 |
output = model.generate(
|
| 374 |
**inputs, max_new_tokens=100, do_sample=True,
|
|
@@ -378,46 +337,42 @@ def generate_motion_tokens(word: str, variant: str) -> str:
|
|
| 378 |
eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
|
| 379 |
early_stopping=True
|
| 380 |
)
|
| 381 |
-
|
| 382 |
decoded = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 383 |
motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
|
| 384 |
return motion_part.strip()
|
| 385 |
-
|
| 386 |
-
|
| 387 |
def parse_motion_tokens(token_str: str) -> list:
|
| 388 |
if isinstance(token_str, (list, tuple, np.ndarray)):
|
| 389 |
return [int(x) for x in token_str]
|
| 390 |
if not isinstance(token_str, str):
|
| 391 |
return []
|
| 392 |
-
|
| 393 |
matches = re.findall(r'<M(\d+)>', token_str)
|
| 394 |
if matches:
|
| 395 |
return [int(x) for x in matches]
|
| 396 |
-
|
| 397 |
matches = re.findall(r'<motion_(\d+)>', token_str)
|
| 398 |
if matches:
|
| 399 |
return [int(x) for x in matches]
|
| 400 |
-
|
| 401 |
-
return []
|
| 402 |
-
|
| 403 |
|
|
|
|
| 404 |
def decode_tokens_to_params(tokens: list) -> np.ndarray:
|
| 405 |
vqvae_model = _model_cache["vqvae_model"]
|
| 406 |
mean, std = _model_cache["stats"]
|
| 407 |
-
|
| 408 |
if vqvae_model is None or not tokens:
|
| 409 |
return np.zeros((0, SMPL_DIM), dtype=np.float32)
|
| 410 |
-
|
| 411 |
idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
|
| 412 |
T_q = idx.shape[1]
|
| 413 |
quantizer = vqvae_model.vqvae.quantizer
|
| 414 |
-
|
| 415 |
if hasattr(quantizer, "codebook"):
|
| 416 |
codebook = quantizer.codebook.to(DEVICE)
|
| 417 |
code_dim = codebook.shape[1]
|
| 418 |
else:
|
| 419 |
code_dim = CODE_DIM
|
| 420 |
-
|
| 421 |
x_quantized = None
|
| 422 |
if hasattr(quantizer, "dequantize"):
|
| 423 |
try:
|
|
@@ -431,47 +386,55 @@ def decode_tokens_to_params(tokens: list) -> np.ndarray:
|
|
| 431 |
x_quantized = dq.permute(0, 2, 1).contiguous()
|
| 432 |
except Exception:
|
| 433 |
pass
|
| 434 |
-
|
| 435 |
if x_quantized is None:
|
| 436 |
if not hasattr(quantizer, "codebook"):
|
| 437 |
return np.zeros((0, SMPL_DIM), dtype=np.float32)
|
| 438 |
with torch.no_grad():
|
| 439 |
emb = codebook[idx]
|
| 440 |
x_quantized = emb.permute(0, 2, 1).contiguous()
|
| 441 |
-
|
| 442 |
with torch.no_grad():
|
| 443 |
x_dec = vqvae_model.vqvae.decoder(x_quantized)
|
| 444 |
smpl_out = vqvae_model.vqvae.postprocess(x_dec)
|
| 445 |
params_np = smpl_out.squeeze(0).cpu().numpy()
|
| 446 |
-
|
| 447 |
if (mean is not None) and (std is not None):
|
| 448 |
params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)
|
| 449 |
-
|
| 450 |
-
return params_np
|
| 451 |
-
|
| 452 |
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
|
|
|
|
|
|
| 458 |
starts = np.cumsum([0] + PARAM_DIMS[:-1])
|
| 459 |
ends = starts + np.array(PARAM_DIMS)
|
|
|
|
| 460 |
T = params_seq.shape[0]
|
| 461 |
all_verts = []
|
| 462 |
-
|
|
|
|
| 463 |
num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
|
| 464 |
|
| 465 |
with torch.no_grad():
|
| 466 |
for s in range(0, T, batch_size):
|
| 467 |
-
batch = params_seq[s:s+batch_size]
|
| 468 |
B = batch.shape[0]
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
-
# Handle body pose
|
| 474 |
-
# root_pose is separate as global_orient (3 dims)
|
| 475 |
body_t = tensor_parts['body_pose']
|
| 476 |
L_body = body_t.shape[1]
|
| 477 |
expected_no_go = num_body_joints * 3
|
|
@@ -484,24 +447,67 @@ def params_to_vertices(params_seq: np.ndarray) -> tuple:
|
|
| 484 |
global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
|
| 485 |
body_pose_only = body_t
|
| 486 |
else:
|
|
|
|
| 487 |
if L_body > expected_no_go:
|
| 488 |
global_orient = body_t[:, :3].contiguous()
|
| 489 |
body_pose_only = body_t[:, 3:].contiguous()
|
| 490 |
else:
|
| 491 |
-
|
|
|
|
| 492 |
global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
|
| 493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 494 |
out = smplx_model(
|
| 495 |
-
betas=tensor_parts['betas'],
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 500 |
)
|
| 501 |
-
|
|
|
|
|
|
|
| 502 |
|
| 503 |
-
|
| 504 |
-
|
|
|
|
|
|
|
| 505 |
# =====================================================================
|
| 506 |
# PyRender Visualization Functions
|
| 507 |
# =====================================================================
|
|
@@ -520,20 +526,15 @@ def render_single_frame(
|
|
| 520 |
"""Render a single mesh frame using PyRender."""
|
| 521 |
if not PYRENDER_AVAILABLE:
|
| 522 |
raise RuntimeError("PyRender not available")
|
| 523 |
-
|
| 524 |
# Check for invalid vertices
|
| 525 |
if not np.isfinite(verts).all():
|
| 526 |
blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 200
|
| 527 |
return blank
|
| 528 |
-
|
| 529 |
-
# IMPORTANT: Rotate mesh 180 degrees around X-axis (like visualize.py)
|
| 530 |
-
# This fixes the coordinate system so we view from the front
|
| 531 |
-
rot_matrix = trimesh.transformations.rotation_matrix(np.radians(180), [1, 0, 0])
|
| 532 |
-
verts_rotated = np.dot(verts, rot_matrix[:3, :3].T)
|
| 533 |
-
|
| 534 |
# Create scene
|
| 535 |
scene = pyrender.Scene(bg_color=bg_color, ambient_light=[0.4, 0.4, 0.4])
|
| 536 |
-
|
| 537 |
# Material
|
| 538 |
material = pyrender.MetallicRoughnessMaterial(
|
| 539 |
metallicFactor=0.0,
|
|
@@ -541,29 +542,31 @@ def render_single_frame(
|
|
| 541 |
alphaMode='OPAQUE',
|
| 542 |
baseColorFactor=color
|
| 543 |
)
|
| 544 |
-
|
| 545 |
-
# Create mesh
|
| 546 |
-
mesh = trimesh.Trimesh(vertices=
|
| 547 |
mesh_render = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=True)
|
| 548 |
scene.add(mesh_render)
|
| 549 |
-
|
| 550 |
-
# Compute center for camera positioning
|
| 551 |
-
mesh_center =
|
| 552 |
camera_target = fixed_center if fixed_center is not None else mesh_center
|
| 553 |
-
|
| 554 |
# Camera setup
|
| 555 |
camera = pyrender.IntrinsicsCamera(
|
| 556 |
fx=focal_length, fy=focal_length,
|
| 557 |
cx=frame_width / 2, cy=frame_height / 2,
|
| 558 |
znear=0.1, zfar=20.0
|
| 559 |
)
|
| 560 |
-
|
| 561 |
-
|
|
|
|
|
|
|
| 562 |
camera_pose = np.eye(4)
|
| 563 |
camera_pose[0, 3] = camera_target[0] # Center X
|
| 564 |
camera_pose[1, 3] = camera_target[1] # Center Y (body center)
|
| 565 |
camera_pose[2, 3] = camera_target[2] - camera_distance # In front (negative Z)
|
| 566 |
-
|
| 567 |
# Camera orientation: flip to look at subject (SOKE-style)
|
| 568 |
# This rotation makes camera look toward +Z (at the subject)
|
| 569 |
camera_pose[:3, :3] = np.array([
|
|
@@ -571,49 +574,47 @@ def render_single_frame(
|
|
| 571 |
[0, -1, 0],
|
| 572 |
[0, 0, -1]
|
| 573 |
])
|
| 574 |
-
|
| 575 |
scene.add(camera, pose=camera_pose)
|
| 576 |
-
|
| 577 |
# Lighting
|
| 578 |
key_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
|
| 579 |
key_pose = np.eye(4)
|
| 580 |
key_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-30), np.radians(-20), 0)[:3, :3]
|
| 581 |
scene.add(key_light, pose=key_pose)
|
| 582 |
-
|
| 583 |
fill_light = pyrender.DirectionalLight(color=[0.9, 0.9, 1.0], intensity=1.5)
|
| 584 |
fill_pose = np.eye(4)
|
| 585 |
fill_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-20), np.radians(30), 0)[:3, :3]
|
| 586 |
scene.add(fill_light, pose=fill_pose)
|
| 587 |
-
|
| 588 |
rim_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
|
| 589 |
rim_pose = np.eye(4)
|
| 590 |
rim_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(30), np.radians(180), 0)[:3, :3]
|
| 591 |
scene.add(rim_light, pose=rim_pose)
|
| 592 |
-
|
| 593 |
# Render
|
| 594 |
renderer = pyrender.OffscreenRenderer(viewport_width=frame_width, viewport_height=frame_height, point_size=1.0)
|
| 595 |
color_img, _ = renderer.render(scene)
|
| 596 |
renderer.delete()
|
| 597 |
-
|
| 598 |
# Add label
|
| 599 |
if label:
|
| 600 |
img = Image.fromarray(color_img)
|
| 601 |
draw = ImageDraw.Draw(img)
|
| 602 |
-
|
| 603 |
try:
|
| 604 |
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
|
| 605 |
except:
|
| 606 |
font = ImageFont.load_default()
|
| 607 |
-
|
| 608 |
text_width = len(label) * 10 + 20
|
| 609 |
draw.rectangle([10, 10, 10 + text_width, 35], fill=(0, 0, 0, 180))
|
| 610 |
draw.text((15, 12), label, fill=(255, 255, 255), font=font)
|
| 611 |
-
|
| 612 |
-
color_img = np.array(img)
|
| 613 |
-
|
| 614 |
-
return color_img
|
| 615 |
|
|
|
|
| 616 |
|
|
|
|
| 617 |
def render_side_by_side_frame(
|
| 618 |
verts_list: list,
|
| 619 |
faces: np.ndarray,
|
|
@@ -628,20 +629,20 @@ def render_side_by_side_frame(
|
|
| 628 |
"""Render multiple meshes side-by-side for comparison."""
|
| 629 |
if not PYRENDER_AVAILABLE:
|
| 630 |
raise RuntimeError("PyRender not available")
|
| 631 |
-
|
| 632 |
# Colors for each avatar
|
| 633 |
colors = [
|
| 634 |
(0.3, 0.8, 0.4, 1.0), # Green
|
| 635 |
(0.3, 0.6, 0.9, 1.0), # Blue
|
| 636 |
(0.9, 0.5, 0.2, 1.0), # Orange
|
| 637 |
]
|
| 638 |
-
|
| 639 |
frames = []
|
| 640 |
for i, verts in enumerate(verts_list):
|
| 641 |
fixed_center = fixed_centers[i] if fixed_centers else None
|
| 642 |
color = colors[i % len(colors)]
|
| 643 |
label = labels[i] if i < len(labels) else ""
|
| 644 |
-
|
| 645 |
frame = render_single_frame(
|
| 646 |
verts, faces, label=label, color=color,
|
| 647 |
fixed_center=fixed_center, camera_distance=camera_distance,
|
|
@@ -649,10 +650,8 @@ def render_side_by_side_frame(
|
|
| 649 |
frame_height=frame_height, bg_color=bg_color
|
| 650 |
)
|
| 651 |
frames.append(frame)
|
| 652 |
-
|
| 653 |
-
return np.concatenate(frames, axis=1)
|
| 654 |
-
|
| 655 |
|
|
|
|
| 656 |
def render_video(
|
| 657 |
verts: np.ndarray,
|
| 658 |
faces: np.ndarray,
|
|
@@ -668,17 +667,19 @@ def render_video(
|
|
| 668 |
"""Render single avatar animation to video."""
|
| 669 |
if not ensure_pyrender():
|
| 670 |
raise RuntimeError("PyRender not available")
|
| 671 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 672 |
# Trim last few frames to remove end-of-sequence artifacts
|
| 673 |
T_total = verts.shape[0]
|
| 674 |
trim_amount = min(8, int(T_total * 0.15))
|
| 675 |
T = max(5, T_total - trim_amount)
|
| 676 |
-
|
| 677 |
-
# Compute fixed camera target from first frame
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
fixed_center = verts_rotated_first.mean(axis=0)
|
| 681 |
-
|
| 682 |
frames = []
|
| 683 |
for t in range(T):
|
| 684 |
frame = render_single_frame(
|
|
@@ -689,16 +690,14 @@ def render_video(
|
|
| 689 |
)
|
| 690 |
for _ in range(slowdown):
|
| 691 |
frames.append(frame)
|
| 692 |
-
|
| 693 |
# Save video
|
| 694 |
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 695 |
-
|
| 696 |
if len(frames) > 0:
|
| 697 |
imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
|
| 698 |
-
|
| 699 |
-
return output_path
|
| 700 |
-
|
| 701 |
|
|
|
|
| 702 |
def render_comparison_video(
|
| 703 |
verts1: np.ndarray,
|
| 704 |
faces1: np.ndarray,
|
|
@@ -717,24 +716,27 @@ def render_comparison_video(
|
|
| 717 |
"""Render side-by-side comparison video."""
|
| 718 |
if not ensure_pyrender():
|
| 719 |
raise RuntimeError("PyRender not available")
|
| 720 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 721 |
# Match lengths and trim
|
| 722 |
T_total = min(verts1.shape[0], verts2.shape[0])
|
| 723 |
trim_amount = min(8, int(T_total * 0.15))
|
| 724 |
T = max(5, T_total - trim_amount)
|
| 725 |
-
|
| 726 |
verts1 = verts1[:T]
|
| 727 |
verts2 = verts2[:T]
|
| 728 |
-
|
| 729 |
-
# Compute fixed camera targets
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
fixed_center1 = verts1_rotated_first.mean(axis=0)
|
| 734 |
-
fixed_center2 = verts2_rotated_first.mean(axis=0)
|
| 735 |
-
|
| 736 |
labels = [label1, label2]
|
| 737 |
-
|
| 738 |
frames = []
|
| 739 |
for t in range(T):
|
| 740 |
frame = render_side_by_side_frame(
|
|
@@ -745,15 +747,14 @@ def render_comparison_video(
|
|
| 745 |
)
|
| 746 |
for _ in range(slowdown):
|
| 747 |
frames.append(frame)
|
| 748 |
-
|
| 749 |
# Save video
|
| 750 |
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 751 |
-
|
| 752 |
if len(frames) > 0:
|
| 753 |
imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
|
| 754 |
-
|
| 755 |
-
return output_path
|
| 756 |
|
|
|
|
| 757 |
# =====================================================================
|
| 758 |
# Main Processing Functions
|
| 759 |
# =====================================================================
|
|
@@ -761,80 +762,74 @@ def generate_verts_for_word(word: str, pid: str) -> tuple:
|
|
| 761 |
"""Generate vertices and faces for a word-PID pair."""
|
| 762 |
generated_tokens = generate_motion_tokens(word, pid)
|
| 763 |
token_ids = parse_motion_tokens(generated_tokens)
|
| 764 |
-
|
| 765 |
if not token_ids:
|
| 766 |
return None, None, generated_tokens
|
| 767 |
-
|
| 768 |
if _model_cache["vqvae_model"] is None or _model_cache["smplx_model"] is None:
|
| 769 |
return None, None, generated_tokens
|
| 770 |
-
|
| 771 |
params = decode_tokens_to_params(token_ids)
|
| 772 |
if params.shape[0] == 0:
|
| 773 |
return None, None, generated_tokens
|
| 774 |
-
|
| 775 |
verts, faces = params_to_vertices(params)
|
| 776 |
return verts, faces, generated_tokens
|
| 777 |
-
|
| 778 |
-
|
| 779 |
def generate_video_for_word(word: str, pid: str) -> tuple:
|
| 780 |
"""Generate video and tokens for a word. Returns (video_path, tokens)."""
|
| 781 |
verts, faces, tokens = generate_verts_for_word(word, pid)
|
| 782 |
-
|
| 783 |
if verts is None:
|
| 784 |
return None, tokens
|
| 785 |
-
|
| 786 |
# Generate unique filename
|
| 787 |
video_filename = f"motion_{word}_{pid}_{uuid.uuid4().hex[:8]}.mp4"
|
| 788 |
video_path = os.path.join(OUTPUT_DIR, video_filename)
|
| 789 |
-
|
| 790 |
render_video(verts, faces, video_path, label=f"{pid}")
|
| 791 |
return video_path, tokens
|
| 792 |
-
|
| 793 |
-
|
| 794 |
def process_word(word: str):
|
| 795 |
"""Main processing: generate side-by-side comparison video for two random PIDs."""
|
| 796 |
if not word or not word.strip():
|
| 797 |
return None, ""
|
| 798 |
-
|
| 799 |
word = word.strip().lower()
|
| 800 |
-
|
| 801 |
pids = get_random_pids_for_word(word, 2)
|
| 802 |
-
|
| 803 |
if not pids:
|
| 804 |
return None, f"Word '{word}' not found in dataset"
|
| 805 |
-
|
| 806 |
if len(pids) == 1:
|
| 807 |
pids = [pids[0], pids[0]]
|
| 808 |
-
|
| 809 |
try:
|
| 810 |
verts1, faces1, tokens1 = generate_verts_for_word(word, pids[0])
|
| 811 |
verts2, faces2, tokens2 = generate_verts_for_word(word, pids[1])
|
| 812 |
-
|
| 813 |
if verts1 is None and verts2 is None:
|
| 814 |
return None, tokens1 or tokens2 or "Failed to generate motion"
|
| 815 |
-
|
| 816 |
# Generate unique filename
|
| 817 |
video_filename = f"comparison_{word}_{uuid.uuid4().hex[:8]}.mp4"
|
| 818 |
video_path = os.path.join(OUTPUT_DIR, video_filename)
|
| 819 |
-
|
| 820 |
if verts1 is None:
|
| 821 |
render_video(verts2, faces2, video_path, label=pids[1])
|
| 822 |
return video_path, tokens2
|
| 823 |
if verts2 is None:
|
| 824 |
render_video(verts1, faces1, video_path, label=pids[0])
|
| 825 |
return video_path, tokens1
|
| 826 |
-
|
| 827 |
render_comparison_video(
|
| 828 |
verts1, faces1, verts2, faces2, video_path,
|
| 829 |
label1=pids[0], label2=pids[1]
|
| 830 |
)
|
| 831 |
combined_tokens = f"[{pids[0]}] {tokens1}\n\n[{pids[1]}] {tokens2}"
|
| 832 |
return video_path, combined_tokens
|
| 833 |
-
|
| 834 |
except Exception as e:
|
| 835 |
return None, f"Error: {str(e)[:100]}"
|
| 836 |
-
|
| 837 |
-
|
| 838 |
def get_example_video(word: str, pid: str):
|
| 839 |
"""Get pre-computed example video."""
|
| 840 |
key = f"{word}_{pid}"
|
|
@@ -843,65 +838,67 @@ def get_example_video(word: str, pid: str):
|
|
| 843 |
return cached.get("video_path"), cached.get("tokens", "")
|
| 844 |
video_path, tokens = generate_video_for_word(word, pid)
|
| 845 |
return video_path, tokens
|
| 846 |
-
|
| 847 |
# =====================================================================
|
| 848 |
# Gradio Interface
|
| 849 |
# =====================================================================
|
| 850 |
def create_gradio_interface():
|
| 851 |
-
|
| 852 |
custom_css = """
|
| 853 |
.gradio-container { max-width: 1400px !important; }
|
| 854 |
-
.example-row { margin-top: 15px; padding: 12px; background:
|
|
|
|
| 855 |
.example-word-label {
|
| 856 |
text-align: center;
|
| 857 |
font-size: 28px !important;
|
| 858 |
font-weight: bold !important;
|
| 859 |
-
color:
|
|
|
|
| 860 |
margin: 10px 0 !important;
|
| 861 |
padding: 10px !important;
|
| 862 |
}
|
| 863 |
.example-variant-label {
|
| 864 |
text-align: center;
|
| 865 |
font-size: 14px !important;
|
| 866 |
-
color:
|
|
|
|
| 867 |
margin-bottom: 10px !important;
|
| 868 |
}
|
| 869 |
"""
|
| 870 |
-
|
| 871 |
example_list = list(_example_cache.values()) if _example_cache else []
|
| 872 |
-
|
| 873 |
with gr.Blocks(title="SignMotionGPT", css=custom_css, theme=gr.themes.Default()) as demo:
|
| 874 |
-
|
| 875 |
gr.Markdown("# SignMotionGPT Demo")
|
| 876 |
gr.Markdown("Text-to-Sign Language Motion Generation with Variant Comparison")
|
| 877 |
gr.Markdown("*High-quality PyRender visualization with proper hand motion rendering*")
|
| 878 |
-
|
| 879 |
with gr.Row():
|
| 880 |
with gr.Column(scale=1, min_width=280):
|
| 881 |
gr.Markdown("### Input")
|
| 882 |
-
|
| 883 |
word_input = gr.Textbox(
|
| 884 |
label="Word",
|
| 885 |
placeholder="Enter a word from the dataset...",
|
| 886 |
lines=1, max_lines=1
|
| 887 |
)
|
| 888 |
-
|
| 889 |
generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
|
| 890 |
-
|
| 891 |
gr.Markdown("---")
|
| 892 |
gr.Markdown("### Generated Tokens")
|
| 893 |
-
|
| 894 |
tokens_output = gr.Textbox(
|
| 895 |
label="Motion Tokens (both variants)",
|
| 896 |
lines=8,
|
| 897 |
interactive=False,
|
| 898 |
show_copy_button=True
|
| 899 |
)
|
| 900 |
-
|
| 901 |
if _word_pid_map:
|
| 902 |
sample_words = list(_word_pid_map.keys())[:10]
|
| 903 |
gr.Markdown(f"**Available words:** {', '.join(sample_words)}, ...")
|
| 904 |
-
|
| 905 |
with gr.Column(scale=2, min_width=700):
|
| 906 |
gr.Markdown("### Motion Comparison (Two Signer Variants)")
|
| 907 |
video_output = gr.Video(
|
|
@@ -909,11 +906,11 @@ def create_gradio_interface():
|
|
| 909 |
autoplay=True,
|
| 910 |
show_download_button=True
|
| 911 |
)
|
| 912 |
-
|
| 913 |
if example_list:
|
| 914 |
gr.Markdown("---")
|
| 915 |
gr.Markdown("### Pre-computed Examples")
|
| 916 |
-
|
| 917 |
for item in example_list:
|
| 918 |
word, pid = item['word'], item['pid']
|
| 919 |
with gr.Row(elem_classes="example-row"):
|
|
@@ -921,37 +918,36 @@ def create_gradio_interface():
|
|
| 921 |
gr.HTML(f'<div class="example-word-label">{word.upper()}</div>')
|
| 922 |
gr.HTML(f'<div class="example-variant-label">Variant: {pid}</div>')
|
| 923 |
example_btn = gr.Button("Load Example", size="sm", variant="secondary")
|
| 924 |
-
|
| 925 |
with gr.Column(scale=3, min_width=500):
|
| 926 |
example_video = gr.Video(
|
| 927 |
label=f"Example: {word}",
|
| 928 |
autoplay=False,
|
| 929 |
show_download_button=True
|
| 930 |
)
|
| 931 |
-
|
| 932 |
example_btn.click(
|
| 933 |
fn=lambda w=word, p=pid: get_example_video(w, p),
|
| 934 |
inputs=[],
|
| 935 |
outputs=[example_video, tokens_output]
|
| 936 |
)
|
| 937 |
-
|
| 938 |
gr.Markdown("---")
|
| 939 |
gr.Markdown("*SignMotionGPT: LLM-based sign language motion generation with PyRender visualization*")
|
| 940 |
-
|
| 941 |
generate_btn.click(
|
| 942 |
fn=process_word,
|
| 943 |
inputs=[word_input],
|
| 944 |
outputs=[video_output, tokens_output]
|
| 945 |
)
|
| 946 |
-
|
| 947 |
word_input.submit(
|
| 948 |
fn=process_word,
|
| 949 |
inputs=[word_input],
|
| 950 |
outputs=[video_output, tokens_output]
|
| 951 |
)
|
| 952 |
-
|
| 953 |
-
return demo
|
| 954 |
|
|
|
|
| 955 |
# =====================================================================
|
| 956 |
# Main Entry Point for HuggingFace Spaces
|
| 957 |
# =====================================================================
|
|
@@ -965,20 +961,16 @@ print(f"Output Directory: {OUTPUT_DIR}")
|
|
| 965 |
print(f"Dataset: {DATASET_PATH}")
|
| 966 |
print(f"PyRender Available: {PYRENDER_AVAILABLE}")
|
| 967 |
print("="*60 + "\n")
|
| 968 |
-
|
| 969 |
# Initialize models at startup
|
| 970 |
initialize_models()
|
| 971 |
-
|
| 972 |
# Pre-compute example animations
|
| 973 |
precompute_examples()
|
| 974 |
-
|
| 975 |
# Create and launch interface
|
| 976 |
demo = create_gradio_interface()
|
| 977 |
-
|
| 978 |
if __name__ == "__main__":
|
| 979 |
# Launch with settings for HuggingFace Spaces
|
| 980 |
demo.launch(
|
| 981 |
server_name="0.0.0.0",
|
| 982 |
server_port=7860,
|
| 983 |
share=False
|
| 984 |
-
)
|
|
|
|
| 6 |
# IMPORTANT: Set OpenGL platform BEFORE any OpenGL imports (for headless rendering)
|
| 7 |
import os
|
| 8 |
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
|
|
|
| 9 |
import sys
|
| 10 |
import re
|
| 11 |
import json
|
|
|
|
| 14 |
import tempfile
|
| 15 |
import uuid
|
| 16 |
from pathlib import Path
|
|
|
|
| 17 |
import torch
|
| 18 |
import numpy as np
|
|
|
|
| 19 |
warnings.filterwarnings("ignore")
|
|
|
|
| 20 |
# =====================================================================
|
| 21 |
# Configuration for HuggingFace Spaces
|
| 22 |
# =====================================================================
|
|
|
|
| 25 |
OUTPUT_DIR = os.path.join(WORK_DIR, "outputs")
|
| 26 |
os.makedirs(DATA_DIR, exist_ok=True)
|
| 27 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
| 28 |
# Path definitions
|
| 29 |
DATASET_PATH = os.path.join(DATA_DIR, "motion_llm_dataset.json")
|
| 30 |
VQVAE_CHECKPOINT = os.path.join(DATA_DIR, "vqvae_model.pt")
|
| 31 |
STATS_PATH = os.path.join(DATA_DIR, "vqvae_stats.pt")
|
| 32 |
SMPLX_MODEL_DIR = os.path.join(DATA_DIR, "smplx_models")
|
|
|
|
| 33 |
# HuggingFace model config
|
| 34 |
HF_REPO_ID = os.environ.get("HF_REPO_ID", "rdz-falcon/SignMotionGPTfit-archive")
|
| 35 |
HF_SUBFOLDER = os.environ.get("HF_SUBFOLDER", "stage2_v2/epoch-030")
|
|
|
|
| 36 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 37 |
# Generation parameters
|
| 38 |
M_START = "<M_START>"
|
| 39 |
M_END = "<M_END>"
|
|
|
|
| 41 |
INFERENCE_TEMPERATURE = 0.7
|
| 42 |
INFERENCE_TOP_K = 50
|
| 43 |
INFERENCE_REPETITION_PENALTY = 1.2
|
|
|
|
| 44 |
# VQ-VAE parameters
|
| 45 |
SMPL_DIM = 182
|
| 46 |
CODEBOOK_SIZE = 512
|
|
|
|
| 49 |
width=512, depth=3, down_t=2, stride_t=2,
|
| 50 |
dilation_growth_rate=3, activation='relu', norm=None, quantizer="ema_reset"
|
| 51 |
)
|
|
|
|
| 52 |
PARAM_DIMS = [10, 63, 45, 45, 3, 10, 3, 3]
|
| 53 |
+
PARAM_NAMES = ["betas", "body_pose", "left_hand_pose", "right_hand_pose",
|
| 54 |
+
"trans", "expression", "jaw_pose", "eye_pose"]
|
|
|
|
| 55 |
# Visualization defaults
|
| 56 |
AVATAR_COLOR = (0.36, 0.78, 0.36, 1.0) # Green color as RGBA
|
| 57 |
VIDEO_FPS = 15
|
| 58 |
VIDEO_SLOWDOWN = 2
|
| 59 |
FRAME_WIDTH = 544 # Must be divisible by 16 for video codec compatibility
|
| 60 |
FRAME_HEIGHT = 720
|
|
|
|
| 61 |
# =====================================================================
|
| 62 |
# Install/Import Dependencies
|
| 63 |
# =====================================================================
|
|
|
|
| 66 |
except ImportError:
|
| 67 |
os.system("pip install -q gradio>=4.0.0")
|
| 68 |
import gradio as gr
|
|
|
|
| 69 |
try:
|
| 70 |
import smplx
|
| 71 |
except ImportError:
|
| 72 |
os.system("pip install -q smplx==0.1.28")
|
| 73 |
import smplx
|
|
|
|
| 74 |
# PyRender for high-quality rendering
|
| 75 |
PYRENDER_AVAILABLE = False
|
| 76 |
try:
|
|
|
|
| 80 |
PYRENDER_AVAILABLE = True
|
| 81 |
except ImportError:
|
| 82 |
pass
|
|
|
|
| 83 |
try:
|
| 84 |
import imageio
|
| 85 |
except ImportError:
|
| 86 |
os.system("pip install -q imageio[ffmpeg]")
|
| 87 |
import imageio
|
|
|
|
| 88 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 89 |
import torch.nn.functional as F
|
|
|
|
| 90 |
# =====================================================================
|
| 91 |
# Import VQ-VAE architecture
|
| 92 |
# =====================================================================
|
|
|
|
| 96 |
sys.path.insert(0, parent_dir)
|
| 97 |
if current_dir not in sys.path:
|
| 98 |
sys.path.insert(0, current_dir)
|
|
|
|
| 99 |
try:
|
| 100 |
from mGPT.archs.mgpt_vq import VQVae
|
| 101 |
except ImportError as e:
|
| 102 |
print(f"Warning: Could not import VQVae: {e}")
|
| 103 |
VQVae = None
|
|
|
|
| 104 |
# =====================================================================
|
| 105 |
# Global Cache
|
| 106 |
# =====================================================================
|
|
|
|
| 112 |
"stats": (None, None),
|
| 113 |
"initialized": False
|
| 114 |
}
|
|
|
|
| 115 |
_word_pid_map = {}
|
| 116 |
_example_cache = {}
|
|
|
|
| 117 |
# =====================================================================
|
| 118 |
# PyRender Setup
|
| 119 |
# =====================================================================
|
|
|
|
| 122 |
global PYRENDER_AVAILABLE, trimesh, pyrender, Image, ImageDraw, ImageFont
|
| 123 |
if PYRENDER_AVAILABLE:
|
| 124 |
return True
|
| 125 |
+
|
| 126 |
print("Installing pyrender dependencies...")
|
| 127 |
if os.path.exists("/etc/debian_version"):
|
| 128 |
os.system("apt-get update -qq && apt-get install -qq -y libegl1-mesa-dev libgles2-mesa-dev > /dev/null 2>&1")
|
| 129 |
os.system("pip install -q trimesh pyrender PyOpenGL PyOpenGL_accelerate Pillow")
|
| 130 |
+
|
| 131 |
try:
|
| 132 |
import trimesh
|
| 133 |
import pyrender
|
|
|
|
| 137 |
except ImportError as e:
|
| 138 |
print(f"Could not install pyrender: {e}")
|
| 139 |
return False
|
|
|
|
| 140 |
# =====================================================================
|
| 141 |
# Dataset Loading - Word to PID mapping
|
| 142 |
# =====================================================================
|
| 143 |
def load_word_pid_mapping():
|
| 144 |
"""Load the dataset and build word -> PIDs mapping."""
|
| 145 |
global _word_pid_map
|
| 146 |
+
|
| 147 |
if not os.path.exists(DATASET_PATH):
|
| 148 |
print(f"Dataset not found: {DATASET_PATH}")
|
| 149 |
return
|
| 150 |
+
|
| 151 |
print(f"Loading dataset from: {DATASET_PATH}")
|
| 152 |
try:
|
| 153 |
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
|
| 154 |
data = json.load(f)
|
| 155 |
+
|
| 156 |
for entry in data:
|
| 157 |
word = entry.get('word', '').lower()
|
| 158 |
pid = entry.get('participant_id', '')
|
|
|
|
| 160 |
if word not in _word_pid_map:
|
| 161 |
_word_pid_map[word] = set()
|
| 162 |
_word_pid_map[word].add(pid)
|
| 163 |
+
|
| 164 |
for word in _word_pid_map:
|
| 165 |
_word_pid_map[word] = sorted(list(_word_pid_map[word]))
|
| 166 |
+
|
| 167 |
print(f"Loaded {len(_word_pid_map)} unique words from dataset")
|
| 168 |
except Exception as e:
|
| 169 |
print(f"Error loading dataset: {e}")
|
|
|
|
|
|
|
| 170 |
def get_pids_for_word(word: str) -> list:
    """Return the participant IDs recorded for *word* (empty list if unknown)."""
    # Lookup keys in _word_pid_map are lowercase; normalize before the get.
    key = word.lower().strip()
    return _word_pid_map.get(key, [])
|
|
|
|
|
|
|
| 174 |
def get_random_pids_for_word(word: str, count: int = 2) -> list:
|
| 175 |
"""Get random PIDs for a word. Returns up to 'count' PIDs."""
|
| 176 |
pids = get_pids_for_word(word)
|
|
|
|
| 179 |
if len(pids) <= count:
|
| 180 |
return pids
|
| 181 |
return random.sample(pids, count)
|
|
|
|
|
|
|
| 182 |
def get_example_words_with_pids(count: int = 3) -> list:
    """Pick up to *count* (word, pid) example pairs for the demo UI.

    Curated words are tried first; when they do not yield enough examples,
    the remainder is topped up with randomly chosen dataset words. Each
    pair uses the word's first participant ID.
    """
    preferred = ['push', 'passport', 'library', 'send', 'college', 'help', 'thank', 'hello']
    examples = []

    # First pass: curated words that actually exist in the dataset.
    for candidate in preferred:
        pids = get_pids_for_word(candidate)
        if not pids:
            continue
        examples.append((candidate, pids[0]))
        if len(examples) >= count:
            break

    # Second pass: fill the shortfall with random, not-yet-used words.
    if len(examples) < count:
        used = {w for w, _ in examples}
        remaining = [w for w in _word_pid_map.keys() if w not in used]
        random.shuffle(remaining)
        for candidate in remaining[:count - len(examples)]:
            examples.append((candidate, _word_pid_map[candidate][0]))

    return examples
|
| 202 |
# =====================================================================
|
| 203 |
# VQ-VAE Wrapper
|
| 204 |
# =====================================================================
|
|
|
|
| 211 |
nfeats=smpl_dim, code_num=codebook_size, code_dim=code_dim,
|
| 212 |
output_emb_width=code_dim, **kwargs
|
| 213 |
)
|
|
|
|
| 214 |
# =====================================================================
|
| 215 |
# Model Loading Functions
|
| 216 |
# =====================================================================
|
| 217 |
def load_llm_model():
|
| 218 |
print(f"Loading LLM from: {HF_REPO_ID}/{HF_SUBFOLDER}")
|
| 219 |
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
|
| 220 |
+
|
| 221 |
tokenizer = AutoTokenizer.from_pretrained(
|
| 222 |
HF_REPO_ID, subfolder=HF_SUBFOLDER, trust_remote_code=True, token=token
|
| 223 |
)
|
|
|
|
| 233 |
model.eval()
|
| 234 |
print(f"LLM loaded (vocab size: {len(tokenizer)})")
|
| 235 |
return model, tokenizer
|
|
|
|
|
|
|
| 236 |
def load_vqvae_model():
|
| 237 |
if not os.path.exists(VQVAE_CHECKPOINT):
|
| 238 |
print(f"VQ-VAE checkpoint not found: {VQVAE_CHECKPOINT}")
|
|
|
|
| 245 |
model.eval()
|
| 246 |
print(f"VQ-VAE loaded")
|
| 247 |
return model
|
|
|
|
|
|
|
| 248 |
def load_stats():
|
| 249 |
if not os.path.exists(STATS_PATH):
|
| 250 |
return None, None
|
|
|
|
| 253 |
if torch.is_tensor(mean): mean = mean.cpu().numpy()
|
| 254 |
if torch.is_tensor(std): std = std.cpu().numpy()
|
| 255 |
return mean, std
|
|
|
|
|
|
|
| 256 |
def load_smplx_model():
|
| 257 |
if not os.path.exists(SMPLX_MODEL_DIR):
|
| 258 |
print(f"SMPL-X directory not found: {SMPLX_MODEL_DIR}")
|
|
|
|
| 266 |
).to(DEVICE)
|
| 267 |
print(f"SMPL-X loaded")
|
| 268 |
return model
|
|
|
|
|
|
|
| 269 |
def initialize_models():
    """Load every model into the module-level cache (idempotent).

    Loads the word->PID mapping and the LLM (both required), then the
    visualization stack (VQ-VAE, normalization stats, SMPL-X) on a
    best-effort basis, and finally verifies PyRender is importable.
    Sets ``_model_cache["initialized"]`` so repeat calls are no-ops.
    """
    global _model_cache
    # Idempotent: a previous successful call already populated the cache.
    if _model_cache["initialized"]:
        return

    print("\n" + "="*60)
    print(" Initializing SignMotionGPT Models")
    print("="*60)

    load_word_pid_mapping()

    # The LLM is mandatory; any failure here is allowed to propagate.
    _model_cache["llm_model"], _model_cache["llm_tokenizer"] = load_llm_model()

    # Visualization models are optional — token generation still works
    # without them — so failures are logged rather than raised.
    try:
        _model_cache["vqvae_model"] = load_vqvae_model()
        _model_cache["stats"] = load_stats()
        _model_cache["smplx_model"] = load_smplx_model()
    except Exception as e:
        print(f"Could not load visualization models: {e}")

    # Ensure PyRender is available
    ensure_pyrender()

    _model_cache["initialized"] = True
    print("All models initialized")
    print("="*60)
|
|
|
|
|
|
|
| 295 |
def precompute_examples():
|
| 296 |
"""Pre-compute animations for example words at startup."""
|
| 297 |
global _example_cache
|
| 298 |
+
|
| 299 |
if not _model_cache["initialized"]:
|
| 300 |
return
|
| 301 |
+
|
| 302 |
examples = get_example_words_with_pids(3)
|
| 303 |
+
|
| 304 |
print(f"\nPre-computing {len(examples)} example animations...")
|
| 305 |
+
|
| 306 |
for word, pid in examples:
|
| 307 |
key = f"{word}_{pid}"
|
| 308 |
print(f" Computing: {word} ({pid})...")
|
|
|
|
| 313 |
except Exception as e:
|
| 314 |
print(f" Failed: {word} - {e}")
|
| 315 |
_example_cache[key] = {"video_path": None, "tokens": "", "word": word, "pid": pid}
|
|
|
|
|
|
|
| 316 |
|
| 317 |
+
print("Example pre-computation complete\n")
|
| 318 |
# =====================================================================
|
| 319 |
# Motion Generation Functions
|
| 320 |
# =====================================================================
|
| 321 |
def generate_motion_tokens(word: str, variant: str) -> str:
|
| 322 |
model = _model_cache["llm_model"]
|
| 323 |
tokenizer = _model_cache["llm_tokenizer"]
|
| 324 |
+
|
| 325 |
if model is None or tokenizer is None:
|
| 326 |
raise RuntimeError("LLM model not loaded")
|
| 327 |
+
|
| 328 |
prompt = f"Instruction: Generate motion for word '{word}' with variant '{variant}'.\nMotion: "
|
| 329 |
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
|
| 330 |
+
|
| 331 |
with torch.no_grad():
|
| 332 |
output = model.generate(
|
| 333 |
**inputs, max_new_tokens=100, do_sample=True,
|
|
|
|
| 337 |
eos_token_id=tokenizer.convert_tokens_to_ids(M_END),
|
| 338 |
early_stopping=True
|
| 339 |
)
|
| 340 |
+
|
| 341 |
decoded = tokenizer.decode(output[0], skip_special_tokens=False)
|
| 342 |
motion_part = decoded.split("Motion: ")[-1] if "Motion: " in decoded else decoded
|
| 343 |
return motion_part.strip()
|
|
|
|
|
|
|
| 344 |
def parse_motion_tokens(token_str: str) -> list:
    """Extract integer motion-token IDs from LLM output.

    Accepts either ``<M123>`` or ``<motion_123>`` style tokens (the first
    format that matches wins); an already-numeric sequence is converted to
    ints and passed through. Returns [] when nothing can be parsed.
    """
    # Already-numeric sequences pass straight through as ints.
    if isinstance(token_str, (list, tuple, np.ndarray)):
        return [int(v) for v in token_str]
    if not isinstance(token_str, str):
        return []

    # Try each known token format in order of preference.
    for pattern in (r'<M(\d+)>', r'<motion_(\d+)>'):
        found = re.findall(pattern, token_str)
        if found:
            return [int(v) for v in found]

    return []
|
| 359 |
def decode_tokens_to_params(tokens: list) -> np.ndarray:
|
| 360 |
vqvae_model = _model_cache["vqvae_model"]
|
| 361 |
mean, std = _model_cache["stats"]
|
| 362 |
+
|
| 363 |
if vqvae_model is None or not tokens:
|
| 364 |
return np.zeros((0, SMPL_DIM), dtype=np.float32)
|
| 365 |
+
|
| 366 |
idx = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
|
| 367 |
T_q = idx.shape[1]
|
| 368 |
quantizer = vqvae_model.vqvae.quantizer
|
| 369 |
+
|
| 370 |
if hasattr(quantizer, "codebook"):
|
| 371 |
codebook = quantizer.codebook.to(DEVICE)
|
| 372 |
code_dim = codebook.shape[1]
|
| 373 |
else:
|
| 374 |
code_dim = CODE_DIM
|
| 375 |
+
|
| 376 |
x_quantized = None
|
| 377 |
if hasattr(quantizer, "dequantize"):
|
| 378 |
try:
|
|
|
|
| 386 |
x_quantized = dq.permute(0, 2, 1).contiguous()
|
| 387 |
except Exception:
|
| 388 |
pass
|
| 389 |
+
|
| 390 |
if x_quantized is None:
|
| 391 |
if not hasattr(quantizer, "codebook"):
|
| 392 |
return np.zeros((0, SMPL_DIM), dtype=np.float32)
|
| 393 |
with torch.no_grad():
|
| 394 |
emb = codebook[idx]
|
| 395 |
x_quantized = emb.permute(0, 2, 1).contiguous()
|
| 396 |
+
|
| 397 |
with torch.no_grad():
|
| 398 |
x_dec = vqvae_model.vqvae.decoder(x_quantized)
|
| 399 |
smpl_out = vqvae_model.vqvae.postprocess(x_dec)
|
| 400 |
params_np = smpl_out.squeeze(0).cpu().numpy()
|
| 401 |
+
|
| 402 |
if (mean is not None) and (std is not None):
|
| 403 |
params_np = (params_np * np.array(std).reshape(1, -1)) + np.array(mean).reshape(1, -1)
|
|
|
|
|
|
|
|
|
|
| 404 |
|
| 405 |
+
return params_np
|
| 406 |
+
def params_to_vertices(params_seq: np.ndarray, smplx_model, batch_size=32) -> tuple:
|
| 407 |
+
"""
|
| 408 |
+
Convert SMPL-X parameters to 3D vertices.
|
| 409 |
+
FIXED: Properly handles jaw_pose and expression to prevent lip/mouth issues.
|
| 410 |
+
"""
|
| 411 |
+
# Compute parameter slicing indices
|
| 412 |
starts = np.cumsum([0] + PARAM_DIMS[:-1])
|
| 413 |
ends = starts + np.array(PARAM_DIMS)
|
| 414 |
+
|
| 415 |
T = params_seq.shape[0]
|
| 416 |
all_verts = []
|
| 417 |
+
|
| 418 |
+
# Infer number of body joints
|
| 419 |
num_body_joints = getattr(smplx_model, "NUM_BODY_JOINTS", 21)
|
| 420 |
|
| 421 |
with torch.no_grad():
|
| 422 |
for s in range(0, T, batch_size):
|
| 423 |
+
batch = params_seq[s:s+batch_size] # (B, SMPL_DIM)
|
| 424 |
B = batch.shape[0]
|
| 425 |
|
| 426 |
+
# Extract parameters
|
| 427 |
+
np_parts = {}
|
| 428 |
+
for name, st, ed in zip(PARAM_NAMES, starts, ends):
|
| 429 |
+
np_parts[name] = batch[:, st:ed].astype(np.float32)
|
| 430 |
+
|
| 431 |
+
# Convert to tensors
|
| 432 |
+
tensor_parts = {
|
| 433 |
+
name: torch.from_numpy(arr).to(DEVICE)
|
| 434 |
+
for name, arr in np_parts.items()
|
| 435 |
+
}
|
| 436 |
|
| 437 |
+
# Handle body pose (may or may not include global orient)
|
|
|
|
| 438 |
body_t = tensor_parts['body_pose']
|
| 439 |
L_body = body_t.shape[1]
|
| 440 |
expected_no_go = num_body_joints * 3
|
|
|
|
| 447 |
global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
|
| 448 |
body_pose_only = body_t
|
| 449 |
else:
|
| 450 |
+
# Best-effort fallback
|
| 451 |
if L_body > expected_no_go:
|
| 452 |
global_orient = body_t[:, :3].contiguous()
|
| 453 |
body_pose_only = body_t[:, 3:].contiguous()
|
| 454 |
else:
|
| 455 |
+
pad_len = max(0, expected_no_go - L_body)
|
| 456 |
+
body_pose_only = F.pad(body_t, (0, pad_len))
|
| 457 |
global_orient = torch.zeros((B, 3), dtype=torch.float32, device=DEVICE)
|
| 458 |
|
| 459 |
+
# ✅ FIX: Ensure jaw_pose is properly shaped (should be B x 3)
|
| 460 |
+
jaw_pose = tensor_parts['jaw_pose']
|
| 461 |
+
if jaw_pose.shape[1] != 3:
|
| 462 |
+
print(f"Warning: jaw_pose has shape {jaw_pose.shape}, padding/trimming to (B, 3)")
|
| 463 |
+
if jaw_pose.shape[1] < 3:
|
| 464 |
+
jaw_pose = F.pad(jaw_pose, (0, 3 - jaw_pose.shape[1]))
|
| 465 |
+
else:
|
| 466 |
+
jaw_pose = jaw_pose[:, :3]
|
| 467 |
+
jaw_pose = jaw_pose.contiguous()
|
| 468 |
+
|
| 469 |
+
# ✅ FIX: Ensure expression is properly shaped (should be B x 10)
|
| 470 |
+
expression = tensor_parts['expression']
|
| 471 |
+
if expression.shape[1] != 10:
|
| 472 |
+
print(f"Warning: expression has shape {expression.shape}, padding/trimming to (B, 10)")
|
| 473 |
+
if expression.shape[1] < 10:
|
| 474 |
+
expression = F.pad(expression, (0, 10 - expression.shape[1]))
|
| 475 |
+
else:
|
| 476 |
+
expression = expression[:, :10]
|
| 477 |
+
expression = expression.contiguous()
|
| 478 |
+
|
| 479 |
+
# ✅ FIX: Ensure eye_pose is properly shaped (should be B x 3)
|
| 480 |
+
eye_pose = tensor_parts['eye_pose']
|
| 481 |
+
if eye_pose.shape[1] != 3:
|
| 482 |
+
print(f"Warning: eye_pose has shape {eye_pose.shape}, padding/trimming to (B, 3)")
|
| 483 |
+
if eye_pose.shape[1] < 3:
|
| 484 |
+
eye_pose = F.pad(eye_pose, (0, 3 - eye_pose.shape[1]))
|
| 485 |
+
else:
|
| 486 |
+
eye_pose = eye_pose[:, :3]
|
| 487 |
+
eye_pose = eye_pose.contiguous()
|
| 488 |
+
|
| 489 |
+
# Call SMPL-X with validated parameters
|
| 490 |
out = smplx_model(
|
| 491 |
+
betas=tensor_parts['betas'],
|
| 492 |
+
global_orient=global_orient,
|
| 493 |
+
body_pose=body_pose_only,
|
| 494 |
+
left_hand_pose=tensor_parts['left_hand_pose'],
|
| 495 |
+
right_hand_pose=tensor_parts['right_hand_pose'],
|
| 496 |
+
expression=expression, # ✅ Using validated expression
|
| 497 |
+
jaw_pose=jaw_pose, # ✅ Using validated jaw_pose
|
| 498 |
+
leye_pose=eye_pose, # ✅ Using validated eye_pose
|
| 499 |
+
reye_pose=eye_pose, # ✅ Using validated eye_pose
|
| 500 |
+
transl=tensor_parts['trans'],
|
| 501 |
+
return_verts=True
|
| 502 |
)
|
| 503 |
+
|
| 504 |
+
verts = out.vertices.detach().cpu().numpy() # (B, V, 3)
|
| 505 |
+
all_verts.append(verts)
|
| 506 |
|
| 507 |
+
verts_all = np.concatenate(all_verts, axis=0) # (T, V, 3)
|
| 508 |
+
faces = smplx_model.faces.astype(np.int32)
|
| 509 |
+
|
| 510 |
+
return verts_all, faces
|
| 511 |
# =====================================================================
|
| 512 |
# PyRender Visualization Functions
|
| 513 |
# =====================================================================
|
|
|
|
| 526 |
"""Render a single mesh frame using PyRender."""
|
| 527 |
if not PYRENDER_AVAILABLE:
|
| 528 |
raise RuntimeError("PyRender not available")
|
| 529 |
+
|
| 530 |
# Check for invalid vertices
|
| 531 |
if not np.isfinite(verts).all():
|
| 532 |
blank = np.ones((frame_height, frame_width, 3), dtype=np.uint8) * 200
|
| 533 |
return blank
|
| 534 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
# Create scene
|
| 536 |
scene = pyrender.Scene(bg_color=bg_color, ambient_light=[0.4, 0.4, 0.4])
|
| 537 |
+
|
| 538 |
# Material
|
| 539 |
material = pyrender.MetallicRoughnessMaterial(
|
| 540 |
metallicFactor=0.0,
|
|
|
|
| 542 |
alphaMode='OPAQUE',
|
| 543 |
baseColorFactor=color
|
| 544 |
)
|
| 545 |
+
|
| 546 |
+
# Create mesh
|
| 547 |
+
mesh = trimesh.Trimesh(vertices=verts, faces=faces)
|
| 548 |
mesh_render = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=True)
|
| 549 |
scene.add(mesh_render)
|
| 550 |
+
|
| 551 |
+
# Compute center for camera positioning
|
| 552 |
+
mesh_center = verts.mean(axis=0)
|
| 553 |
camera_target = fixed_center if fixed_center is not None else mesh_center
|
| 554 |
+
|
| 555 |
# Camera setup
|
| 556 |
camera = pyrender.IntrinsicsCamera(
|
| 557 |
fx=focal_length, fy=focal_length,
|
| 558 |
cx=frame_width / 2, cy=frame_height / 2,
|
| 559 |
znear=0.1, zfar=20.0
|
| 560 |
)
|
| 561 |
+
|
| 562 |
+
# Camera pose: After 180-degree rotation around X-axis, coordinate system changes
|
| 563 |
+
# Camera should be positioned in front (negative Z) with flipped orientation
|
| 564 |
+
# This matches visualize.py and ensures proper face visibility
|
| 565 |
camera_pose = np.eye(4)
|
| 566 |
camera_pose[0, 3] = camera_target[0] # Center X
|
| 567 |
camera_pose[1, 3] = camera_target[1] # Center Y (body center)
|
| 568 |
camera_pose[2, 3] = camera_target[2] - camera_distance # In front (negative Z)
|
| 569 |
+
|
| 570 |
# Camera orientation: flip to look at subject (SOKE-style)
|
| 571 |
# This rotation makes camera look toward +Z (at the subject)
|
| 572 |
camera_pose[:3, :3] = np.array([
|
|
|
|
| 574 |
[0, -1, 0],
|
| 575 |
[0, 0, -1]
|
| 576 |
])
|
| 577 |
+
|
| 578 |
scene.add(camera, pose=camera_pose)
|
| 579 |
+
|
| 580 |
# Lighting
|
| 581 |
key_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
|
| 582 |
key_pose = np.eye(4)
|
| 583 |
key_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-30), np.radians(-20), 0)[:3, :3]
|
| 584 |
scene.add(key_light, pose=key_pose)
|
| 585 |
+
|
| 586 |
fill_light = pyrender.DirectionalLight(color=[0.9, 0.9, 1.0], intensity=1.5)
|
| 587 |
fill_pose = np.eye(4)
|
| 588 |
fill_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(-20), np.radians(30), 0)[:3, :3]
|
| 589 |
scene.add(fill_light, pose=fill_pose)
|
| 590 |
+
|
| 591 |
rim_light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=2.0)
|
| 592 |
rim_pose = np.eye(4)
|
| 593 |
rim_pose[:3, :3] = trimesh.transformations.euler_matrix(np.radians(30), np.radians(180), 0)[:3, :3]
|
| 594 |
scene.add(rim_light, pose=rim_pose)
|
| 595 |
+
|
| 596 |
# Render
|
| 597 |
renderer = pyrender.OffscreenRenderer(viewport_width=frame_width, viewport_height=frame_height, point_size=1.0)
|
| 598 |
color_img, _ = renderer.render(scene)
|
| 599 |
renderer.delete()
|
| 600 |
+
|
| 601 |
# Add label
|
| 602 |
if label:
|
| 603 |
img = Image.fromarray(color_img)
|
| 604 |
draw = ImageDraw.Draw(img)
|
| 605 |
+
|
| 606 |
try:
|
| 607 |
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 20)
|
| 608 |
except:
|
| 609 |
font = ImageFont.load_default()
|
| 610 |
+
|
| 611 |
text_width = len(label) * 10 + 20
|
| 612 |
draw.rectangle([10, 10, 10 + text_width, 35], fill=(0, 0, 0, 180))
|
| 613 |
draw.text((15, 12), label, fill=(255, 255, 255), font=font)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 614 |
|
| 615 |
+
color_img = np.array(img)
|
| 616 |
|
| 617 |
+
return color_img
|
| 618 |
def render_side_by_side_frame(
|
| 619 |
verts_list: list,
|
| 620 |
faces: np.ndarray,
|
|
|
|
| 629 |
"""Render multiple meshes side-by-side for comparison."""
|
| 630 |
if not PYRENDER_AVAILABLE:
|
| 631 |
raise RuntimeError("PyRender not available")
|
| 632 |
+
|
| 633 |
# Colors for each avatar
|
| 634 |
colors = [
|
| 635 |
(0.3, 0.8, 0.4, 1.0), # Green
|
| 636 |
(0.3, 0.6, 0.9, 1.0), # Blue
|
| 637 |
(0.9, 0.5, 0.2, 1.0), # Orange
|
| 638 |
]
|
| 639 |
+
|
| 640 |
frames = []
|
| 641 |
for i, verts in enumerate(verts_list):
|
| 642 |
fixed_center = fixed_centers[i] if fixed_centers else None
|
| 643 |
color = colors[i % len(colors)]
|
| 644 |
label = labels[i] if i < len(labels) else ""
|
| 645 |
+
|
| 646 |
frame = render_single_frame(
|
| 647 |
verts, faces, label=label, color=color,
|
| 648 |
fixed_center=fixed_center, camera_distance=camera_distance,
|
|
|
|
| 650 |
frame_height=frame_height, bg_color=bg_color
|
| 651 |
)
|
| 652 |
frames.append(frame)
|
|
|
|
|
|
|
|
|
|
| 653 |
|
| 654 |
+
return np.concatenate(frames, axis=1)
|
| 655 |
def render_video(
|
| 656 |
verts: np.ndarray,
|
| 657 |
faces: np.ndarray,
|
|
|
|
| 667 |
"""Render single avatar animation to video."""
|
| 668 |
if not ensure_pyrender():
|
| 669 |
raise RuntimeError("PyRender not available")
|
| 670 |
+
|
| 671 |
+
# Apply orientation fix: rotate 180 degrees around X-axis
|
| 672 |
+
verts = verts.copy()
|
| 673 |
+
verts[..., 1:] *= -1
|
| 674 |
+
|
| 675 |
# Trim last few frames to remove end-of-sequence artifacts
|
| 676 |
T_total = verts.shape[0]
|
| 677 |
trim_amount = min(8, int(T_total * 0.15))
|
| 678 |
T = max(5, T_total - trim_amount)
|
| 679 |
+
|
| 680 |
+
# Compute fixed camera target from first frame
|
| 681 |
+
fixed_center = verts[0].mean(axis=0)
|
| 682 |
+
|
|
|
|
|
|
|
| 683 |
frames = []
|
| 684 |
for t in range(T):
|
| 685 |
frame = render_single_frame(
|
|
|
|
| 690 |
)
|
| 691 |
for _ in range(slowdown):
|
| 692 |
frames.append(frame)
|
| 693 |
+
|
| 694 |
# Save video
|
| 695 |
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 696 |
+
|
| 697 |
if len(frames) > 0:
|
| 698 |
imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
|
|
|
|
|
|
|
|
|
|
| 699 |
|
| 700 |
+
return output_path
|
| 701 |
def render_comparison_video(
|
| 702 |
verts1: np.ndarray,
|
| 703 |
faces1: np.ndarray,
|
|
|
|
| 716 |
"""Render side-by-side comparison video."""
|
| 717 |
if not ensure_pyrender():
|
| 718 |
raise RuntimeError("PyRender not available")
|
| 719 |
+
|
| 720 |
+
# Apply orientation fix
|
| 721 |
+
verts1 = verts1.copy()
|
| 722 |
+
verts2 = verts2.copy()
|
| 723 |
+
verts1[..., 1:] *= -1
|
| 724 |
+
verts2[..., 1:] *= -1
|
| 725 |
+
|
| 726 |
# Match lengths and trim
|
| 727 |
T_total = min(verts1.shape[0], verts2.shape[0])
|
| 728 |
trim_amount = min(8, int(T_total * 0.15))
|
| 729 |
T = max(5, T_total - trim_amount)
|
| 730 |
+
|
| 731 |
verts1 = verts1[:T]
|
| 732 |
verts2 = verts2[:T]
|
| 733 |
+
|
| 734 |
+
# Compute fixed camera targets
|
| 735 |
+
fixed_center1 = verts1[0].mean(axis=0)
|
| 736 |
+
fixed_center2 = verts2[0].mean(axis=0)
|
| 737 |
+
|
|
|
|
|
|
|
|
|
|
| 738 |
labels = [label1, label2]
|
| 739 |
+
|
| 740 |
frames = []
|
| 741 |
for t in range(T):
|
| 742 |
frame = render_side_by_side_frame(
|
|
|
|
| 747 |
)
|
| 748 |
for _ in range(slowdown):
|
| 749 |
frames.append(frame)
|
| 750 |
+
|
| 751 |
# Save video
|
| 752 |
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 753 |
+
|
| 754 |
if len(frames) > 0:
|
| 755 |
imageio.mimsave(output_path, frames, fps=fps, codec='libx264', quality=8)
|
|
|
|
|
|
|
| 756 |
|
| 757 |
+
return output_path
|
| 758 |
# =====================================================================
|
| 759 |
# Main Processing Functions
|
| 760 |
# =====================================================================
|
|
|
|
| 762 |
"""Generate vertices and faces for a word-PID pair."""
|
| 763 |
generated_tokens = generate_motion_tokens(word, pid)
|
| 764 |
token_ids = parse_motion_tokens(generated_tokens)
|
| 765 |
+
|
| 766 |
if not token_ids:
|
| 767 |
return None, None, generated_tokens
|
| 768 |
+
|
| 769 |
if _model_cache["vqvae_model"] is None or _model_cache["smplx_model"] is None:
|
| 770 |
return None, None, generated_tokens
|
| 771 |
+
|
| 772 |
params = decode_tokens_to_params(token_ids)
|
| 773 |
if params.shape[0] == 0:
|
| 774 |
return None, None, generated_tokens
|
| 775 |
+
|
| 776 |
verts, faces = params_to_vertices(params)
|
| 777 |
return verts, faces, generated_tokens
|
|
|
|
|
|
|
| 778 |
def generate_video_for_word(word: str, pid: str) -> tuple:
    """Generate a single-avatar motion video for (word, pid).

    Returns ``(video_path, tokens)``; *video_path* is None when no mesh
    could be produced, while *tokens* still carries the raw LLM output so
    the UI can display it.
    """
    verts, faces, tokens = generate_verts_for_word(word, pid)
    if verts is None:
        return None, tokens

    # Unique name so concurrent requests never clobber each other's output.
    out_name = f"motion_{word}_{pid}_{uuid.uuid4().hex[:8]}.mp4"
    out_path = os.path.join(OUTPUT_DIR, out_name)
    render_video(verts, faces, out_path, label=f"{pid}")
    return out_path, tokens
|
|
|
|
|
|
|
| 791 |
def process_word(word: str):
    """Gradio handler: build a two-variant comparison video for *word*.

    Picks two random participant IDs for the word, generates a motion for
    each, and renders them side by side. Falls back to a single-avatar
    video when only one variant produced geometry, and to a text message
    when neither did. Returns ``(video_path_or_None, tokens_or_message)``.
    """
    if not word or not word.strip():
        return None, ""

    normalized = word.strip().lower()
    pids = get_random_pids_for_word(normalized, 2)
    if not pids:
        return None, f"Word '{normalized}' not found in dataset"

    # With a single variant available, compare the word against itself.
    if len(pids) == 1:
        pids = [pids[0], pids[0]]

    try:
        verts_a, faces_a, tokens_a = generate_verts_for_word(normalized, pids[0])
        verts_b, faces_b, tokens_b = generate_verts_for_word(normalized, pids[1])

        if verts_a is None and verts_b is None:
            return None, tokens_a or tokens_b or "Failed to generate motion"

        # Unique filename so parallel requests don't overwrite each other.
        out_path = os.path.join(
            OUTPUT_DIR, f"comparison_{normalized}_{uuid.uuid4().hex[:8]}.mp4"
        )

        # Degrade gracefully to a single-avatar render if one variant failed.
        if verts_a is None:
            render_video(verts_b, faces_b, out_path, label=pids[1])
            return out_path, tokens_b
        if verts_b is None:
            render_video(verts_a, faces_a, out_path, label=pids[0])
            return out_path, tokens_a

        render_comparison_video(
            verts_a, faces_a, verts_b, faces_b, out_path,
            label1=pids[0], label2=pids[1]
        )
        return out_path, f"[{pids[0]}] {tokens_a}\n\n[{pids[1]}] {tokens_b}"
    except Exception as e:
        # Surface a short error message in the UI instead of crashing Gradio.
        return None, f"Error: {str(e)[:100]}"
|
|
|
|
|
|
|
| 833 |
def get_example_video(word: str, pid: str):
|
| 834 |
"""Get pre-computed example video."""
|
| 835 |
key = f"{word}_{pid}"
|
|
|
|
| 838 |
return cached.get("video_path"), cached.get("tokens", "")
|
| 839 |
video_path, tokens = generate_video_for_word(word, pid)
|
| 840 |
return video_path, tokens
|
|
|
|
| 841 |
# =====================================================================
|
| 842 |
# Gradio Interface
|
| 843 |
# =====================================================================
|
| 844 |
def create_gradio_interface():
|
| 845 |
+
|
| 846 |
custom_css = """
|
| 847 |
.gradio-container { max-width: 1400px !important; }
|
| 848 |
+
.example-row { margin-top: 15px; padding: 12px; background:
|
| 849 |
+
#f8f9fa; border-radius: 6px; }
|
| 850 |
.example-word-label {
|
| 851 |
text-align: center;
|
| 852 |
font-size: 28px !important;
|
| 853 |
font-weight: bold !important;
|
| 854 |
+
color:
|
| 855 |
+
#2c3e50 !important;
|
| 856 |
margin: 10px 0 !important;
|
| 857 |
padding: 10px !important;
|
| 858 |
}
|
| 859 |
.example-variant-label {
|
| 860 |
text-align: center;
|
| 861 |
font-size: 14px !important;
|
| 862 |
+
color:
|
| 863 |
+
#7f8c8d !important;
|
| 864 |
margin-bottom: 10px !important;
|
| 865 |
}
|
| 866 |
"""
|
| 867 |
+
|
| 868 |
example_list = list(_example_cache.values()) if _example_cache else []
|
| 869 |
+
|
| 870 |
with gr.Blocks(title="SignMotionGPT", css=custom_css, theme=gr.themes.Default()) as demo:
|
| 871 |
+
|
| 872 |
gr.Markdown("# SignMotionGPT Demo")
|
| 873 |
gr.Markdown("Text-to-Sign Language Motion Generation with Variant Comparison")
|
| 874 |
gr.Markdown("*High-quality PyRender visualization with proper hand motion rendering*")
|
| 875 |
+
|
| 876 |
with gr.Row():
|
| 877 |
with gr.Column(scale=1, min_width=280):
|
| 878 |
gr.Markdown("### Input")
|
| 879 |
+
|
| 880 |
word_input = gr.Textbox(
|
| 881 |
label="Word",
|
| 882 |
placeholder="Enter a word from the dataset...",
|
| 883 |
lines=1, max_lines=1
|
| 884 |
)
|
| 885 |
+
|
| 886 |
generate_btn = gr.Button("Generate Motion", variant="primary", size="lg")
|
| 887 |
+
|
| 888 |
gr.Markdown("---")
|
| 889 |
gr.Markdown("### Generated Tokens")
|
| 890 |
+
|
| 891 |
tokens_output = gr.Textbox(
|
| 892 |
label="Motion Tokens (both variants)",
|
| 893 |
lines=8,
|
| 894 |
interactive=False,
|
| 895 |
show_copy_button=True
|
| 896 |
)
|
| 897 |
+
|
| 898 |
if _word_pid_map:
|
| 899 |
sample_words = list(_word_pid_map.keys())[:10]
|
| 900 |
gr.Markdown(f"**Available words:** {', '.join(sample_words)}, ...")
|
| 901 |
+
|
| 902 |
with gr.Column(scale=2, min_width=700):
|
| 903 |
gr.Markdown("### Motion Comparison (Two Signer Variants)")
|
| 904 |
video_output = gr.Video(
|
|
|
|
| 906 |
autoplay=True,
|
| 907 |
show_download_button=True
|
| 908 |
)
|
| 909 |
+
|
| 910 |
if example_list:
|
| 911 |
gr.Markdown("---")
|
| 912 |
gr.Markdown("### Pre-computed Examples")
|
| 913 |
+
|
| 914 |
for item in example_list:
|
| 915 |
word, pid = item['word'], item['pid']
|
| 916 |
with gr.Row(elem_classes="example-row"):
|
|
|
|
| 918 |
gr.HTML(f'<div class="example-word-label">{word.upper()}</div>')
|
| 919 |
gr.HTML(f'<div class="example-variant-label">Variant: {pid}</div>')
|
| 920 |
example_btn = gr.Button("Load Example", size="sm", variant="secondary")
|
| 921 |
+
|
| 922 |
with gr.Column(scale=3, min_width=500):
|
| 923 |
example_video = gr.Video(
|
| 924 |
label=f"Example: {word}",
|
| 925 |
autoplay=False,
|
| 926 |
show_download_button=True
|
| 927 |
)
|
| 928 |
+
|
| 929 |
example_btn.click(
|
| 930 |
fn=lambda w=word, p=pid: get_example_video(w, p),
|
| 931 |
inputs=[],
|
| 932 |
outputs=[example_video, tokens_output]
|
| 933 |
)
|
| 934 |
+
|
| 935 |
gr.Markdown("---")
|
| 936 |
gr.Markdown("*SignMotionGPT: LLM-based sign language motion generation with PyRender visualization*")
|
| 937 |
+
|
| 938 |
generate_btn.click(
|
| 939 |
fn=process_word,
|
| 940 |
inputs=[word_input],
|
| 941 |
outputs=[video_output, tokens_output]
|
| 942 |
)
|
| 943 |
+
|
| 944 |
word_input.submit(
|
| 945 |
fn=process_word,
|
| 946 |
inputs=[word_input],
|
| 947 |
outputs=[video_output, tokens_output]
|
| 948 |
)
|
|
|
|
|
|
|
| 949 |
|
| 950 |
+
return demo
|
| 951 |
# =====================================================================
|
| 952 |
# Main Entry Point for HuggingFace Spaces
|
| 953 |
# =====================================================================
|
|
|
|
| 961 |
print(f"Dataset: {DATASET_PATH}")
|
| 962 |
print(f"PyRender Available: {PYRENDER_AVAILABLE}")
|
| 963 |
print("="*60 + "\n")
|
|
|
|
| 964 |
# ---------------------------------------------------------------------
# Startup sequence — runs at import time so that HuggingFace Spaces,
# which imports this module and serves the top-level `demo` object,
# finds a fully initialized app.
# ---------------------------------------------------------------------

# Initialize models before building the UI.
initialize_models()

# Pre-compute example animations (presumably populates the example
# cache surfaced in the interface — confirm against precompute_examples).
precompute_examples()

# Must stay at module level: Spaces discovers `demo` by name.
demo = create_gradio_interface()

if __name__ == "__main__":
    # Direct/local launch; HuggingFace Spaces expects the server to bind
    # all interfaces on port 7860, with no share tunnel.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|