Spaces:
Build error
Build error
| import argparse | |
| import os | |
| import torch | |
class BaseOptions:
    """Base command-line options plus compute-device selection.

    Wraps an ``argparse.ArgumentParser``. ``initialize`` registers the
    option flags on it, and ``parse`` parses them and attaches a
    ``torch.device`` to the resulting namespace as ``opt.device``.
    """

    def __init__(self) -> None:
        self.parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter
        )
        # Set to True once the flags below have been registered on the parser.
        self.initialized = False

    def initialize(self) -> None:
        """Register the base arguments on ``self.parser``.

        Calling this again after the flags are registered is a no-op;
        re-adding the same option strings would make argparse raise a
        conflicting-option error.
        """
        if self.initialized:
            return

        add = self.parser.add_argument

        # Experiment / checkpoint identification.
        add('--name', type=str, default="t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns", help='Name of this trial')
        add('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')
        add("--gpu_id", type=int, default=-1, help='GPU id')
        add('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
        add('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')

        # Transformer hyper-parameters.
        add('--latent_dim', type=int, default=384, help='Dimension of transformer latent.')
        add('--n_heads', type=int, default=6, help='Number of heads.')
        add('--n_layers', type=int, default=8, help='Number of attention layers.')
        add('--ff_size', type=int, default=1024, help='FF_Size')
        add('--dropout', type=float, default=0.2, help='Dropout ratio in transformer')

        # Motion / quantizer settings.
        add("--max_motion_length", type=int, default=196, help="Max length of motion")
        add("--unit_length", type=int, default=4, help="Downscale ratio of VQ")

        add('--force_mask', action="store_true", help='True: mask out conditions')

        self.initialized = True

    def parse(self, args=None) -> argparse.Namespace:
        """Parse the options and pick a compute device.

        Args:
            args: Optional sequence of argument strings to parse. When
                ``None`` (the default) argparse reads ``sys.argv[1:]``,
                which is the behaviour callers had before this parameter
                existed.

        Returns:
            The parsed namespace, also stored on ``self.opt``, with an
            extra ``device`` attribute holding the selected
            ``torch.device``.

        Device selection order:
            1. ``cuda:<gpu_id>`` when ``gpu_id >= 0`` and CUDA is available
               (also made the current CUDA device).
            2. ``mps`` when the MPS backend exists and is available. Note
               this branch is taken even for a negative ``gpu_id``.
            3. ``cpu`` otherwise.
        """
        if not self.initialized:
            self.initialize()

        self.opt = self.parser.parse_args(args)

        # ----- device handling -----
        if self.opt.gpu_id >= 0 and torch.cuda.is_available():
            self.opt.device = torch.device(f"cuda:{self.opt.gpu_id}")
            torch.cuda.set_device(self.opt.gpu_id)
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # hasattr guard: older torch builds have no `backends.mps` module.
            self.opt.device = torch.device("mps")
        else:
            # CPU fallback (e.g. CPU-only hosting).
            self.opt.device = torch.device("cpu")

        print("Using device:", self.opt.device)
        return self.opt