File size: 2,431 Bytes
d3e9181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
093f1a4
 
 
d3e9181
093f1a4
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import os
import torch

class BaseOptions:
    """Build and parse the base command-line options.

    The class wraps an ``argparse.ArgumentParser``. ``initialize`` registers
    the arguments, and ``parse`` parses them and attaches a ``device``
    attribute (a ``torch.device``) to the resulting namespace.

    Attributes:
        parser: The underlying ``argparse.ArgumentParser``.
        initialized: ``True`` once ``initialize`` has registered the arguments.
        opt: The parsed namespace; only set after ``parse`` has been called.
    """

    def __init__(self):
        # ArgumentDefaultsHelpFormatter makes ``--help`` show each default value.
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register the base arguments on ``self.parser``.

        Calling this more than once is a no-op: argparse raises a
        conflicting-option error if the same flag is added twice, so the
        registration is skipped when ``self.initialized`` is already set.
        """
        if self.initialized:
            return

        self.parser.add_argument('--name', type=str, default="t2m_nlayer8_nhead6_ld384_ff1024_cdp0.1_rvq6ns", help='Name of this trial')

        self.parser.add_argument('--vq_name', type=str, default="rvq_nq1_dc512_nc512", help='Name of the rvq model.')

        # A negative gpu_id (the default) means CUDA is never selected in parse().
        self.parser.add_argument("--gpu_id", type=int, default=-1, help='GPU id')
        self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name, {t2m} for humanml3d, {kit} for kit-ml')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here.')

        self.parser.add_argument('--latent_dim', type=int, default=384, help='Dimension of transformer latent.')
        self.parser.add_argument('--n_heads', type=int, default=6, help='Number of heads.')
        self.parser.add_argument('--n_layers', type=int, default=8, help='Number of attention layers.')
        self.parser.add_argument('--ff_size', type=int, default=1024, help='FF_Size')
        self.parser.add_argument('--dropout', type=float, default=0.2, help='Dropout ratio in transformer')

        self.parser.add_argument("--max_motion_length", type=int, default=196, help="Max length of motion")
        self.parser.add_argument("--unit_length", type=int, default=4, help="Downscale ratio of VQ")

        self.parser.add_argument('--force_mask', action="store_true", help='True: mask out conditions')

        self.initialized = True

    def parse(self, args=None):
        """Parse the options and pick the compute device.

        Args:
            args: Optional list of argument strings to parse instead of the
                process command line. ``None`` (the default) makes argparse
                read ``sys.argv[1:]``.

        Returns:
            The parsed ``argparse.Namespace`` (also stored as ``self.opt``)
            with an extra ``device`` attribute holding a ``torch.device``.
        """
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args(args)

        # ----- device handling -----
        # Priority: the requested CUDA device, then MPS, then CPU.
        if self.opt.gpu_id >= 0 and torch.cuda.is_available():
            self.opt.device = torch.device(f"cuda:{self.opt.gpu_id}")
            # Also make the requested GPU the current CUDA device.
            torch.cuda.set_device(self.opt.gpu_id)
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # for Apple Silicon / MPS, not used on HF but safe
            # (the hasattr guard covers torch builds without the mps backend module)
            self.opt.device = torch.device("mps")
        else:
            # CPU fallback (Hugging Face CPU Basic)
            self.opt.device = torch.device("cpu")

        print("Using device:", self.opt.device)
        return self.opt