Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import json | |
| from pathlib import Path | |
| import yaml | |
| import torch | |
| from policy import Policy | |
# Root logging level is taken from the LOGLEVEL environment variable,
# falling back to INFO when it is not set.
logging.basicConfig(level=os.getenv("LOGLEVEL", "INFO"))
log = logging.getLogger(name=__name__)
def load_model_args(args):
    """Overlay model hyperparameters saved next to a checkpoint onto ``args``.

    ``args.checkpoint`` is a path prefix (str or path-like) without extension.
    The weights file must exist at ``<prefix>.ckpt``. Hyperparameters are read
    from ``<prefix>.json`` if that file exists, otherwise from ``<prefix>.yaml``.
    Only a fixed whitelist of keys is copied onto ``args``. Whitelisted keys
    absent from the file leave the corresponding attribute untouched.

    Args:
        args: Namespace-like object with a ``checkpoint`` attribute. It is
            mutated in place.

    Returns:
        The same ``args`` object, with ``loaded_init_model`` set to True.

    Raises:
        FileNotFoundError: if ``<prefix>.ckpt`` is missing, or if neither
            ``<prefix>.json`` nor ``<prefix>.yaml`` exists.
    """
    # Suffixes are appended by string concatenation rather than
    # Path.with_suffix, which would replace any dot-separated tail already
    # present in the checkpoint name. str() also accepts Path-typed prefixes.
    prefix = str(args.checkpoint)

    checkpoint = Path(prefix + '.ckpt')
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not checkpoint.is_file():
        raise FileNotFoundError(f"no checkpoint file: {checkpoint}")

    json_path = Path(prefix + '.json')
    yaml_path = Path(prefix + '.yaml')
    if json_path.is_file():
        with open(json_path, encoding='utf-8') as f:
            hparams = json.load(f)
    elif yaml_path.is_file():
        with open(yaml_path, encoding='utf-8') as f:
            hparams = yaml.safe_load(f)
    else:
        raise FileNotFoundError(
            f"no hparams file: expected {json_path} or {yaml_path}")

    # An empty YAML document (or a JSON `null`) loads as None; treat it as
    # "no hyperparameters" instead of failing on the membership test below.
    if hparams is None:
        hparams = {}

    model_arg_keys = (
        'init_model', 'clip_model_type', 'use_caption', 'use_style_reward',
        'use_transformer_mapper', 'prefix_length', 'clipcap_num_layers',
        'use_ptuning_v2',
    )
    for key in model_arg_keys:
        if key in hparams:
            setattr(args, key, hparams[key])

    # Set unconditionally, even when `init_model` was not in the file.
    args.loaded_init_model = True
    return args
def load_model(args, device, finetune=False):
    """Build a ``Policy`` from ``args`` and load trained weights into it.

    Weights are read from ``<args.checkpoint>.ckpt`` in one of two layouts:

    * a dict with a ``'policy_model'`` entry, loaded into ``policy.model``
      strictly;
    * otherwise a dict with a ``'state_dict'`` entry. Policy weights in it are
      keys prefixed with ``'policy.model.'`` or, if no such key exists,
      ``'model.model.'``. The prefix is stripped and the result is loaded
      with ``strict=False``.

    Args:
        args: Namespace providing ``checkpoint`` (path prefix without
            extension), ``init_model``, ``label_path``, ``prefix_length``,
            ``clipcap_num_layers``, ``use_transformer_mapper`` and
            ``use_label_prefix``.
        device: Passed to ``Policy`` and to the final ``.to(device)`` call.
        finetune: Not used by this function; kept for caller compatibility.

    Returns:
        The ``Policy`` instance after ``.to(device)``.
    """
    log.info('loading model')
    policy = Policy(model_name=args.init_model, temperature=1.0, device=device,
                    clipcap_path='None', fix_gpt=True,
                    label_path=args.label_path,
                    prefix_length=args.prefix_length,
                    clipcap_num_layers=args.clipcap_num_layers,
                    use_transformer_mapper=args.use_transformer_mapper,
                    model_weight='None', use_label_prefix=args.use_label_prefix)

    # str() so a Path-typed prefix works too (matches load_model_args).
    ckpt = str(args.checkpoint) + '.ckpt'
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from trusted sources.
    state = torch.load(ckpt, map_location=torch.device('cpu'))

    if 'policy_model' in state:
        policy.model.load_state_dict(state['policy_model'])
    else:
        weights = state['state_dict']
        prefix = 'policy.model.'
        if not any(k.startswith(prefix) for k in weights):
            prefix = 'model.model.'
        weights = {k[len(prefix):]: v
                   for k, v in weights.items() if k.startswith(prefix)}
        if not weights:
            # With strict=False an empty dict loads nothing and raises no
            # error, so make the no-op visible instead of silently
            # returning an untrained model.
            log.warning("no weights with prefix %r found in %s; "
                        "policy model weights were not loaded", prefix, ckpt)
        # NOTE(review): strict=False presumably because some parameters
        # (e.g. frozen GPT weights, fix_gpt=True) are not stored in the
        # checkpoint. TODO confirm.
        policy.model.load_state_dict(weights, strict=False)

    return policy.to(device)