Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import argparse | |
| import logging | |
| from pathlib import Path | |
| import torch | |
# Configure root logging at import time. The level is read from the LOGLEVEL
# environment variable (default "INFO"). It is upper-cased because logging only
# recognises upper-case level names; e.g. LOGLEVEL=debug would otherwise make
# basicConfig raise ValueError("Unknown level: 'debug'") on import.
logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO").upper())
# Module-level logger used by the functions in this file.
log = logging.getLogger(__name__)
def get_args(argv=None):
    """Parse the ESPER demo command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as
            before; passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace holding every option defined below, plus:
          * ``cuda`` -- the result of ``torch.cuda.is_available()``.
          * ``checkpoint`` -- rewritten as an absolute, resolved path string.
    """
    parser = argparse.ArgumentParser(description='ESPER')
    parser.add_argument(
        '--init-model', type=str, default='gpt2', help='language model used for policy.')
    parser.add_argument(
        '--label_path', type=str, default='./data/esper_demo/labels_all.json', help='style label info file path')
    parser.add_argument(
        '--checkpoint', type=str, default='./data/esper_demo/ckpt', help='checkpoint file path')
    parser.add_argument(
        '--prefix_length', type=int, default=10, help='prefix length for the visual mapper')
    parser.add_argument(
        '--clipcap_num_layers', type=int, default=1, help='num_layers for the visual mapper')
    parser.add_argument(
        '--use_transformer_mapper', action='store_true', default=False, help='use transformer mapper instead of mlp')
    parser.add_argument(
        '--use_label_prefix', action='store_true', default=False, help='label as prefixes')
    parser.add_argument(
        '--clip_model_type', type=str, default='ViT-B/32', help='clip backbone type')
    parser.add_argument(
        '--infer_no_repeat_size', type=int, default=2, help="no repeat ngram size for inference")
    parser.add_argument(
        '--response-length', type=int, default=20, help='number of tokens to generate for each prompt.')
    parser.add_argument(
        '--port', type=int, default=None, help="port for the demo server")

    # argv=None keeps argparse's default of reading sys.argv[1:].
    args = parser.parse_args(argv)

    # Derived field: whether a CUDA device is visible to torch right now.
    args.cuda = torch.cuda.is_available()

    if args.use_label_prefix:
        log.info('using label prefix')

    # Resolve against the current working directory (and follow symlinks) so the
    # stored checkpoint path is absolute.
    if args.checkpoint is not None:
        args.checkpoint = str(Path(args.checkpoint).resolve())
    return args