File size: 3,195 Bytes
fe64bad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from src.models.get_model import LOSSES
from src.models.tools.jointstypes import JOINTSTYPES


def add_model_options(parser):
    """Register all model-related command-line options on *parser*.

    Adds an argument group "Model options" containing the model name,
    latent dimensionality, per-loss weights, joint-type selection,
    vertex-translation flags, transformer hyper-parameters, ablation
    switches, and CLIP-related loss/training options.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible) instance.
    """
    group = parser.add_argument_group('Model options')
    # Model name encodes "<type>_<architecture>_<loss1>_<loss2>_..." and is
    # decomposed by parse_modelname().
    group.add_argument("--modelname", default='motionclip_transformer_rc_rcxyz_vel', help="Choice of the model, should be like motionclip_transformer_rc_rcxyz_vel")
    group.add_argument("--latent_dim", default=256, type=int, help="dimensionality of the latent space")
    group.add_argument("--lambda_rc", default=1.0, type=float, help="weight of the rc divergence loss")
    # NOTE(review): the rcxyz/velxyz help strings were copy-pasted from the
    # rc/vel options; fixed to describe the option they belong to.
    group.add_argument("--lambda_rcxyz", default=1.0, type=float, help="weight of the rcxyz divergence loss")
    group.add_argument("--lambda_vel", default=1.0, type=float, help="weight of the vel divergence loss")
    group.add_argument("--lambda_velxyz", default=1.0, type=float, help="weight of the velxyz divergence loss")

    group.add_argument("--jointstype", default="vertices", choices=JOINTSTYPES, help="Jointstype for training with xyz")

    # Paired store_true/store_false flags sharing one destination; the
    # explicit default keeps behavior deterministic when neither is passed.
    group.add_argument('--vertstrans', dest='vertstrans', action='store_true', help="Training with vertex translations in the SMPL mesh")
    group.add_argument('--no-vertstrans', dest='vertstrans', action='store_false', help="Training without vertex translations in the SMPL mesh")
    group.set_defaults(vertstrans=False)

    group.add_argument("--num_layers", default=8, type=int, help="Number of layers for GRU and transformer")
    group.add_argument("--activation", default="gelu", help="Activation function for the transformer layers")

    # Ablations (None keeps the default architecture).
    group.add_argument("--ablation", choices=[None, "average_encoder", "zandtime", "time_encoding", "concat_bias", "extra_token"],
                       help="Ablations for the transformer architecture")

    # CLIP related losses
    group.add_argument("--clip_image_losses", default='', help="supports multiple, underscore separated, valid options are [mse, ce]. if empty, will not train on images.")
    group.add_argument("--clip_text_losses", default='', help="supports multiple, underscore separated, valid options are [mse, ce]. if empty, will not train on text.")
    group.add_argument("--clip_lambda_mse", default=1.0, type=float, help="weight of the MSE loss, for both texts and images, if in use.")
    group.add_argument("--clip_lambda_ce", default=1.0, type=float, help="weight of the CROSS-ENTROPY loss, for both texts and images, if in use.")
    group.add_argument("--clip_lambda_cosine", default=1.0, type=float, help="weight of the Cosine-dist loss, for both texts and images, if in use.")

    group.add_argument("--clip_training", default='', type=str, help="pass something here to train clip")
    group.add_argument("--clip_layers", default=12, type=int, help="# of CLIP layers to use")


def parse_modelname(modelname):
    """Decompose an underscore-separated model name into its parts.

    Expects names shaped like ``<modeltype>_<archiname>_<loss1>_<loss2>...``
    (e.g. ``motionclip_transformer_rc_rcxyz_vel``).

    Args:
        modelname: the full model-name string.

    Returns:
        A ``(modeltype, archiname, losses)`` tuple, where ``losses`` is the
        list of loss tokens.

    Raises:
        NotImplementedError: if no loss token is present, or if any loss
            token is not registered in ``LOSSES``.
    """
    modeltype, archiname, *losses = modelname.split("_")

    if not losses:
        raise NotImplementedError("You have to specify at least one loss function.")

    if any(token not in LOSSES for token in losses):
        raise NotImplementedError("This loss is not implemented.")

    return modeltype, archiname, losses