Spaces:
Running
Running
| """ | |
| Inference script for generating motion tokens from text prompts. | |
| Run after training to generate motion sequences from any text description. | |
| Usage: | |
| python inference.py --prompt "walking forward" --stage 3 | |
| python inference.py --prompt "dancing" --stage 2 --output motion_output.txt | |
| """ | |
| import os | |
| import argparse | |
| import torch | |
| from pathlib import Path | |
| from config import ( | |
| OUT_S1, OUT_S2, OUT_S3, MAX_SEQ_LEN, DATA_JSON_PATH, | |
| WORK_DIR | |
| ) | |
| from data import ( | |
| load_dataset, compute_length_stats, build_prompt_vocab, | |
| check_has_participant_id | |
| ) | |
| from model import setup_model_and_tokenizer, get_motion_token_info | |
| from generate import generate_t2m | |
| def load_trained_model(stage: int, device: torch.device): | |
| """ | |
| Load a trained model from a specific stage checkpoint. | |
| Args: | |
| stage: Stage number (1, 2, or 3) | |
| device: Device to load model on | |
| Returns: | |
| model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id | |
| """ | |
| stage_dirs = {1: OUT_S1, 2: OUT_S2, 3: OUT_S3} | |
| stage_dir = stage_dirs.get(stage) | |
| if not stage_dir or not os.path.exists(stage_dir): | |
| raise FileNotFoundError( | |
| f"Stage {stage} checkpoint not found at {stage_dir}. " | |
| f"Train stage {stage} first." | |
| ) | |
| print(f"\nLoading Stage {stage} model from: {stage_dir}") | |
| # Load dataset to build vocab (needed for model setup) | |
| if not os.path.exists(DATA_JSON_PATH): | |
| raise FileNotFoundError(f"Dataset not found: {DATA_JSON_PATH}") | |
| raw_ds = load_dataset(DATA_JSON_PATH) | |
| # Build motion vocab | |
| def max_token_in_example(ex): | |
| return max(int(x) for x in ex["motion_tokens"].split()) | |
| global_max_id = max(max_token_in_example(ex) for ex in raw_ds) | |
| codebook_size = global_max_id + 1 | |
| # Check for participant IDs | |
| has_pid = check_has_participant_id(raw_ds) | |
| unique_pids = None | |
| if has_pid: | |
| unique_pids = sorted({str(ex["participant_id"]) for ex in raw_ds}) | |
| # Setup model and tokenizer with same config as training | |
| model, tokenizer, _ = setup_model_and_tokenizer(codebook_size, unique_pids) | |
| # Load trained weights from checkpoint | |
| # Try different checkpoint naming patterns | |
| possible_ckpts = [ | |
| os.path.join(stage_dir, "pytorch_model.bin"), | |
| os.path.join(stage_dir, "model.safetensors"), | |
| os.path.join(stage_dir, "adapter_model.bin"), | |
| ] | |
| loaded = False | |
| for ckpt_path in possible_ckpts: | |
| if os.path.exists(ckpt_path): | |
| print(f"Loading checkpoint: {ckpt_path}") | |
| # Unsloth/PEFT models save adapters separately | |
| # The model will auto-load from the directory | |
| loaded = True | |
| break | |
| if not loaded: | |
| print(f"⚠️ No explicit checkpoint file found, using model directory: {stage_dir}") | |
| # Move model to device | |
| model.to(device) | |
| model.eval() | |
| # Get motion token info | |
| motion_token_ids, mot_begin_id, mot_end_id = get_motion_token_info( | |
| tokenizer, codebook_size | |
| ) | |
| print(f"✅ Stage {stage} model loaded successfully") | |
| print(f" Vocabulary size: {len(tokenizer)}") | |
| print(f" Motion tokens: {len(motion_token_ids)}") | |
| return model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds | |
| def inference( | |
| prompt: str, | |
| stage: int = 3, | |
| pid: str = None, | |
| output_file: str = None, | |
| per_prompt_vocab: bool = True, | |
| device: torch.device = None | |
| ): | |
| """ | |
| Generate motion tokens from a text prompt. | |
| Args: | |
| prompt: Text description of desired motion | |
| stage: Which training stage model to use (1, 2, or 3) | |
| pid: Optional participant ID for personalization | |
| output_file: Optional file to save output tokens | |
| per_prompt_vocab: Whether to use per-prompt vocabulary constraints | |
| device: Device to run inference on | |
| Returns: | |
| Generated motion token string | |
| """ | |
| if device is None: | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print("="*60) | |
| print(f"Motion Generation Inference - Stage {stage}") | |
| print("="*60) | |
| print(f"Prompt: '{prompt}'") | |
| print(f"Device: {device}") | |
| # Load model and dataset | |
| model, tokenizer, motion_token_ids, mot_begin_id, mot_end_id, raw_ds = load_trained_model(stage, device) | |
| # Compute length stats and prompt vocab | |
| print("\nComputing dataset statistics...") | |
| length_stats_by_text, global_median_len = compute_length_stats(raw_ds) | |
| prompt_vocab = build_prompt_vocab(raw_ds) | |
| has_pid = check_has_participant_id(raw_ds) | |
| # Generate motion tokens | |
| print(f"\nGenerating motion for: '{prompt}'") | |
| print(f"Per-prompt vocabulary: {per_prompt_vocab}") | |
| generated = generate_t2m( | |
| model=model, | |
| tokenizer=tokenizer, | |
| prompt_text=prompt, | |
| mot_begin_id=mot_begin_id, | |
| mot_end_id=mot_end_id, | |
| motion_token_ids=motion_token_ids, | |
| length_stats_by_text=length_stats_by_text, | |
| global_median_len=global_median_len, | |
| prompt_vocab=prompt_vocab, | |
| has_pid=has_pid, | |
| per_prompt_vocab=per_prompt_vocab, | |
| pid=pid | |
| ) | |
| print("\n" + "="*60) | |
| print("Generated Motion:") | |
| print("="*60) | |
| print(generated) | |
| print("="*60) | |
| # Optionally save to file | |
| if output_file: | |
| output_path = Path(output_file) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| with open(output_path, 'w') as f: | |
| f.write(generated) | |
| print(f"\n✅ Output saved to: {output_file}") | |
| return generated | |
| def main(): | |
| parser = argparse.ArgumentParser( | |
| description="Generate motion tokens from text prompts using trained SignMotionGPT model" | |
| ) | |
| parser.add_argument( | |
| "--prompt", | |
| type=str, | |
| required=True, | |
| help="Text description of the desired motion (e.g., 'walking forward', 'dancing')" | |
| ) | |
| parser.add_argument( | |
| "--stage", | |
| type=int, | |
| default=3, | |
| choices=[1, 2, 3], | |
| help="Which training stage model to use (1=motion-only, 2=multi-task, 3=T2M SFT, default=3)" | |
| ) | |
| parser.add_argument( | |
| "--pid", | |
| type=str, | |
| default=None, | |
| help="Optional participant ID for personalized generation (e.g., 'P40')" | |
| ) | |
| parser.add_argument( | |
| "--output", | |
| type=str, | |
| default=None, | |
| help="Optional output file to save generated tokens" | |
| ) | |
| parser.add_argument( | |
| "--no-per-prompt-vocab", | |
| action="store_true", | |
| help="Disable per-prompt vocabulary constraints (allows all motion tokens)" | |
| ) | |
| parser.add_argument( | |
| "--device", | |
| type=str, | |
| default=None, | |
| choices=["cpu", "cuda", "cuda:0", "cuda:1"], | |
| help="Device to run inference on (default: auto-detect)" | |
| ) | |
| args = parser.parse_args() | |
| # Setup device | |
| if args.device: | |
| device = torch.device(args.device) | |
| else: | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # Run inference | |
| inference( | |
| prompt=args.prompt, | |
| stage=args.stage, | |
| pid=args.pid, | |
| output_file=args.output, | |
| per_prompt_vocab=not args.no_per_prompt_vocab, | |
| device=device | |
| ) | |
| if __name__ == "__main__": | |
| main() | |