"""Command-line argument parsing: custom training/eval flags plus the original open_clip flag set."""
import argparse
import ast


def get_default_params(model_name):
    """Return default AdamW hyperparameters for *model_name*.

    Params from the CLIP paper (https://arxiv.org/pdf/2103.00020.pdf):
    ViT backbones use beta2=0.98 / eps=1e-6, ResNet-style backbones use
    the standard beta2=0.999 / eps=1e-8.
    """
    model_name = model_name.lower()
    if "vit" in model_name:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


class ParseKwargs(argparse.Action):
    """argparse action that collects repeated ``key=value`` tokens into a dict.

    Values are interpreted with ``ast.literal_eval`` so numbers, tuples,
    booleans, etc. arrive typed; anything that is not a valid Python literal
    falls back to the raw string (avoids the need to escape on the command line).
    """

    def __call__(self, parser, namespace, values, option_string=None):
        kw = {}
        for value in values:
            # Split only on the first '=' so values may themselves contain '='.
            key, value = value.split('=', 1)
            try:
                kw[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                # literal_eval raises SyntaxError (not only ValueError) on
                # malformed input; fall back to the string in both cases.
                kw[key] = str(value)
        setattr(namespace, self.dest, kw)


def parse_args(args):
    """Parse the command-line *args* list and return the populated namespace.

    After parsing, optimizer hyperparameters (lr/beta1/beta2/eps) left unset
    are filled from :func:`get_default_params` based on ``--model`` — unless a
    timm optimizer was selected via ``--opt``.
    """
    parser = argparse.ArgumentParser()

    # ------------------------------------------------------------------
    # Newly added flags (project-specific, on top of open_clip's set)
    # ------------------------------------------------------------------
    parser.add_argument(
        "--datasets-for-testing",
        type=str,
        nargs='*',
        default=None,
        help='A list of names of datasets for zero-shot classification testing'
    )
    parser.add_argument(
        "--root-train-img-dir",
        type=str,
        default=None,
        help="Root directory to training images",
    )
    parser.add_argument(
        "--root-val-img-dir",
        type=str,
        default=None,
        help="Root directory to validation images",
    )
    parser.add_argument(
        "--classification-mode",
        type=str,
        default="multiclass",
        help="Choose either binary or multiclass",
    )
    parser.add_argument(
        "--csv-class-key",
        type=str,
        default="label",
        help="For csv-like datasets, the name of the key for image labels (for classification)."
    )
    parser.add_argument(
        "--debugging",
        default=False,
        action="store_true",
        help=""
    )
    parser.add_argument(
        "--test-data-dir",
        type=str,
        default=None,
        help="Root directory to test datasets"
    )
    # parser.add_argument(
    #     "--random-rotation",
    #     action="store_true",
    #     default=False,
    #     help="If True, add random rotation into image transform for data augmentation (only for training)."
    # )
    parser.add_argument(
        "--test-data",
        type=str,
        default=None,
        help="Path to file(s) with test data (e.g., for testing zero-shot classification)",
    )
    parser.add_argument(
        "--classnames",
        type=str,
        default=None,
        help="Path to txt file containing class names",
    )
    parser.add_argument(
        "--test-data-name",
        type=str,
        default=None,
        help="The name of the test data (e.g., RSICD, EuroSat)",
    )
    # NOTE(review): default is the *string* 'None', not the None object —
    # downstream code likely compares against the literal; confirm before changing.
    parser.add_argument(
        "--test-result-save-path",
        type=str,
        default='None',
        help="The path to save test results as a pickle file."
    )
    # parser.add_argument(
    #     "--csv-actual-label-key",
    #     type=str,
    #     default="binary",
    #     help="If classification_model=binary, then specify the name of the key for actual binary labels (i.e., 0/1)."
    # )
    # parser.add_argument(
    #     "--alpha",
    #     type=float,
    #     default=None,
    #     help="The regularization multiplier of logistic regression to try for linear probing. If None, do a search."
    # )
    parser.add_argument(
        "--method",
        type=str,
        default='vanilla',
        choices=['vanilla', 'farslip1', 'farslip2'],
        help="alignment method"
    )
    parser.add_argument(
        "--local-method",
        type=str,
        default='objects',
        choices=['randomcrops', 'objects', 'grids'],
    )
    parser.add_argument(
        "--loss-type",
        type=str,
        nargs='+',
        default=["global_itc", "local_itc", "distill"],
        help='A list of loss types'
    )
    parser.add_argument(
        "--max-boxes",
        type=int,
        default=20,
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--min-size",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--wandb-tags",
        type=str,
        nargs='*',
        default=[],
    )
    parser.add_argument(
        "--EMA-momentum",
        type=float,
        default=0.99,
    )
    parser.add_argument(
        "--distill-type",
        type=str,
        default='ema',
        choices=["ema", "active", "frozen"]
    )
    parser.add_argument(
        "--long-clip",
        type=str,
        default='disable',
        choices=["disable", "load_from_clip", "load_from_scratch"]
    )
    parser.add_argument(
        "--train-dataset-name",
        type=str,
    )
    parser.add_argument(
        "--local-itc-align",
        type=str,
        choices=["cls", "pooled", "roi", "combined"]
    )
    parser.add_argument(
        "--distill-align",
        type=str,
        choices=["roi2cls", "roi2pooled", "combined"]
    )
    parser.add_argument(
        "--last-attn-type",
        type=str,
        choices=["MaskCLIP", "SegEarth"]
    )
    parser.add_argument(
        "--find-unused-parameters",
        action="store_true",
    )
    parser.add_argument(
        "--w-d",
        type=float,
        default=0.1
    )
    parser.add_argument(
        "--w-l",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--w-g",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--step",
        type=int,
        default=1
    )
    parser.add_argument(
        "--use-imagecrop-aug",
        action="store_true",
    )
    parser.add_argument(
        "--mpcl-loss",
        action="store_true"
    )
    parser.add_argument(
        "--global-type",
        type=str,
        default='ori',
        choices=['ori', 'pooled_tokens', 'merged', 'merged2']
    )
    parser.add_argument(
        "--frozen-text",
        action="store_true"
    )

    # ------------------------------------------------------------------
    # Original open_clip flags
    # ------------------------------------------------------------------
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to file(s) with training data. When using webdataset, multiple datasources can be combined using the `::` separator.",
    )
    parser.add_argument(
        "--train-data-upsampling-factors",
        type=str,
        default=None,
        help=(
            "When using multiple data sources with webdataset and sampling with replacement, this can be used to upsample specific data sources. "
            "Similar to --train-data, this should be a string with as many numbers as there are data sources, separated by `::` (e.g. 1::2::0.5) "
            "By default, datapoints are sampled uniformly regardless of the dataset sizes."
        )
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to file(s) with validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--train-dataset-type",
        choices=["webdataset", "csv", "synthetic", "auto", "json"],
        default="auto",
        help="Which type of dataset to process."
    )
    parser.add_argument(
        "--val-dataset-type",
        choices=["webdataset", "csv", "synthetic", "auto", "json"],
        default="auto",
        help="Which type of dataset to process."
    )
    parser.add_argument(
        "--dataset-resampled",
        default=False,
        action="store_true",
        help="Whether to use sampling with replacement for webdataset shard selection."
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use."
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths."
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions."
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Override system default cache path for model & tokenizer file downloads.",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=4, help="Number of dataloader workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    parser.add_argument(
        "--epochs-cooldown", type=int, default=None,
        help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument("--momentum", type=float, default=None, help="Momentum (for timm optimizers).")
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--opt", type=str, default='adamw',
        help="Which optimizer to use. Choices are ['adamw', or any timm optimizer 'timm/{opt_name}']."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.")
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--lr-scheduler",
        type=str,
        default='cosine',
        help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",
    )
    parser.add_argument(
        "--lr-cooldown-end", type=float, default=0.0,
        help="End learning rate for cooldown schedule. Default: 0"
    )
    parser.add_argument(
        "--lr-cooldown-power", type=float, default=1.0,
        help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)"
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--val-frequency", type=int, default=1, help="How often to run evaluation with val data."
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "pure_bf16", "pure_fp16", "fp32"],
        default="amp",
        help="Floating point precision."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="RN50",
        help="Name of the vision backbone to use.",
    )
    parser.add_argument(
        "--pretrained",
        default='',
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action='store_true',
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action='store_true',
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action='store_true',
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        '--image-mean', type=float, nargs='+', default=None, metavar='MEAN',
        help='Override default image mean value of dataset')
    parser.add_argument(
        '--image-std', type=float, nargs='+', default=None, metavar='STD',
        help='Override default image std deviation of of dataset')
    parser.add_argument(
        '--image-interpolation',
        default=None, type=str, choices=['bicubic', 'bilinear', 'random'],
        help="Override default image resize interpolation"
    )
    parser.add_argument(
        '--image-resize-mode',
        default=None, type=str, choices=['shortest', 'longest', 'squash'],
        help="Override default image resize (& crop) mode during inference"
    )
    parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action='store_true',
        help="Enable gradient checkpointing.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)"
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather"
    )
    parser.add_argument(
        '--force-image-size', type=int, nargs='+', default=None,
        help='Override default image size'
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action='store_true',
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--force-patch-dropout",
        default=None,
        type=float,
        help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",
    )
    parser.add_argument(
        "--force-custom-text",
        default=False,
        action='store_true',
        help="Force use of CustomTextCLIP model (separate text-tower).",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action='store_true',
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--torchcompile",
        default=False,
        action='store_true',
        help="torch.compile() the model, requires pytorch 2.0 or later.",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action='store_true',
        help="torch.jit.trace the model for inference / eval only",
    )
    parser.add_argument(
        "--accum-freq", type=int, default=1,
        help="Update the model every --accum-freq steps."
    )
    parser.add_argument(
        "--device", default="cuda", type=str, help="Accelerator to use."
    )

    # arguments for distributed training
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend",
        default=None,
        type=str,
        help="distributed backend. \"nccl\" for GPU, \"hccl\" for Ascend NPU"
    )
    parser.add_argument(
        "--report-to",
        default='',
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']"
    )
    parser.add_argument(
        "--wandb-notes",
        default='',
        type=str,
        help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default='open-clip',
        help="Name of the project if logging with wandb.",
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged."
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire base on the log directory, and execute from there."
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training."
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action='store_true',
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc)."
    )
    parser.add_argument(
        "--seed", type=int, default=0, help="Default random seed."
    )
    parser.add_argument(
        "--grad-clip-norm", type=float, default=None, help="Gradient clip."
    )
    parser.add_argument(
        "--lock-text",
        default=False,
        action='store_true',
        help="Lock full text tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-text-unlocked-layers",
        type=int,
        default=0,
        help="Leave last n text tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-text-freeze-layer-norm",
        default=False,
        action='store_true',
        help="Freeze LayerNorm running stats in text tower for any locked layers.",
    )
    parser.add_argument(
        "--log-every-n-steps",
        type=int,
        default=100,
        help="Log every n steps to tensorboard/console/wandb.",
    )
    parser.add_argument(
        "--coca-caption-loss-weight",
        type=float,
        default=2.0,
        help="Weight assigned to caption loss in CoCa."
    )
    parser.add_argument(
        "--coca-contrastive-loss-weight",
        type=float,
        default=1.0,
        help="Weight assigned to contrastive loss when training CoCa."
    )
    parser.add_argument(
        "--remote-sync",
        type=str,
        default=None,
        help="Optionally sync with a remote path specified by this arg",
    )
    parser.add_argument(
        "--remote-sync-frequency",
        type=int,
        default=300,
        help="How frequently to sync to a remote directory if --remote-sync is not None.",
    )
    parser.add_argument(
        "--remote-sync-protocol",
        choices=["s3", "fsspec"],
        default="s3",
        help="How to do the remote sync backup if --remote-sync is not None.",
    )
    parser.add_argument(
        "--delete-previous-checkpoint",
        default=False,
        action="store_true",
        help="If true, delete previous checkpoint after storing a new one."
    )
    parser.add_argument(
        "--distill-model",
        default=None,
        help='Which model arch to distill from, if any.'
    )
    parser.add_argument(
        "--distill-pretrained",
        default=None,
        help='Which pre-trained weights to distill from, if any.'
    )
    parser.add_argument(
        "--use-bnb-linear",
        default=None,
        help='Replace the network linear layers from the bitsandbytes library. '
             'Allows int8 training/inference, etc.'
    )
    parser.add_argument(
        "--siglip",
        default=False,
        action="store_true",
        help='Use SigLip (sigmoid) loss.'
    )
    parser.add_argument(
        "--loss-dist-impl",
        default=None,
        type=str,
        help='A string to specify a specific distributed loss implementation.'
    )

    args = parser.parse_args(args)

    if 'timm' not in args.opt:
        # set default opt params based on model name (only if timm optimizer not used)
        default_params = get_default_params(args.model)
        for name, val in default_params.items():
            if getattr(args, name) is None:
                setattr(args, name, val)

    return args