Spaces:
Runtime error
Runtime error
| import torch | |
| from omegaconf import OmegaConf | |
| import os | |
| import sys | |
# Directory one level above this file. It is put on sys.path before the
# `utils.*` / `model.*` imports that follow (presumably so they resolve when
# this file is run directly -- confirm against the repo layout).
src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(src_root)
print(src_root)

# Parent of src_root; checkpoint directories are joined onto this path.
project_root = os.path.abspath(os.path.join(src_root, '..'))
| from utils.inference_utils import set_all_seeds, fix_state_dict | |
| from model.gaussian_diffusion import GaussianDiffusion | |
| from model.unet import Unet | |
| from utils.normalize import set_up_normalization | |
| from utils.constants import TO_24 | |
# Fixed seed applied once at import time (presumably for reproducible
# sampling -- see utils.inference_utils.set_all_seeds).
set_all_seeds(135)

# Use the GPU when one is visible, otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

import clip

# clip.load returns a pair; only the first element (the model) is kept and
# it is switched to eval mode because it is used for inference only.
text_embedder, _ = clip.load("ViT-B/32", device=device)
text_embedder.eval()
def print_config(config):
    """Print ``config`` to stdout in the YAML form produced by ``OmegaConf.to_yaml``."""
    yaml_text = OmegaConf.to_yaml(config)
    print(yaml_text)
def getmodel(model_used, device, model_root, use_step=False, is_disc=False, config=None):
    """Build a ``Unet`` from ``config`` and load a saved checkpoint into it.

    Args:
        model_used: Epoch or step identifier interpolated into the checkpoint
            file name.
        device: Device the network is moved to and the checkpoint tensors are
            mapped onto.
        model_root: Directory that contains the checkpoint file.
        use_step: If true, load ``model_h3d_step{model_used}.pth``; otherwise
            load ``model_h3d_epoch{model_used}.pth``.
        is_disc: Forwarded to ``Unet`` as its ``Disc`` argument.
        config: Object exposing ``dim_model``, ``num_heads``, ``num_layers``,
            ``dropout_p``, ``dim_input``, ``dim_output`` and ``text_emb``.
            Required, even though the signature defaults it to ``None``.

    Returns:
        The ``Unet`` with the checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    # Fail early with a clear message instead of an AttributeError on
    # ``None.dim_model`` a few lines further down.
    if config is None:
        raise ValueError("getmodel() requires `config` with the Unet hyper-parameters")

    model = Unet(
        dim_model=config.dim_model,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout_p=config.dropout_p,
        dim_input=config.dim_input,
        dim_output=config.dim_output,
        text_emb=config.text_emb,
        device=device,
        Disc=is_disc,
    ).to(device)

    # Checkpoints are named either by training step or by epoch.
    checkpoint_kind = 'step' if use_step else 'epoch'
    model_path = os.path.join(model_root, f'model_h3d_{checkpoint_kind}{model_used}.pth')
    print("==>", model_path)

    # Map the stored tensors straight onto the requested device. This honours
    # the caller's ``device`` and also works for checkpoints saved on a GPU
    # index that does not exist on this machine.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)

    # fix_state_dict is applied to the whole checkpoint and then again to its
    # 'model_state_dict' entry (presumably to normalise parameter key names at
    # both levels -- confirm in utils.inference_utils).
    fixed_state_dict = fix_state_dict(state_dict)['model_state_dict']
    fixed_state_dict = fix_state_dict(fixed_state_dict)
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model
# Shared defaults that every task-specific configuration is layered on.
base_config = OmegaConf.load(os.path.join(src_root, "configs/base.yaml"))


def _load_task_config(relative_path):
    """Load the YAML at ``relative_path`` (under src_root) merged over ``base_config``."""
    task_overrides = OmegaConf.load(os.path.join(src_root, relative_path))
    return OmegaConf.merge(base_config, task_overrides)


regen_config = _load_task_config("configs/inference/regen.yaml")
style_transfer_config = _load_task_config("configs/inference/style_transfer.yaml")
adjustment_config = _load_task_config("configs/inference/adjustment.yaml")
# For each task, load two networks: the main model (epoch-named checkpoint)
# and its `*_disc` counterpart (step-named checkpoint, built with
# is_disc=True). Both are built from the task's `unet` settings.
# Keys are inserted in the order regen, regen_disc, style_transfer,
# style_transfer_disc, adjustment, adjustment_disc.
models = {}
for _main_key, _disc_key, _task_cfg in (
    ('regen', 'regen_disc', regen_config),
    ('style_transfer', 'style_transfer_disc', style_transfer_config),
    ('adjustment', 'adjustment_disc', adjustment_config),
):
    models[_main_key] = getmodel(
        _task_cfg.model_used,
        device=device,
        model_root=os.path.join(project_root, _task_cfg.model_path, _task_cfg.task),
        use_step=False,
        is_disc=False,
        config=_task_cfg.unet,
    )
    models[_disc_key] = getmodel(
        _task_cfg.disc_model_used,
        device=device,
        model_root=os.path.join(project_root, _task_cfg.disc_model_path, _task_cfg.task),
        use_step=True,
        is_disc=True,
        config=_task_cfg.unet,
    )
# One diffusion process shared by all tasks, configured entirely from the
# `diffusion` section of the base configuration.
_diffusion_cfg = base_config.diffusion
diffuser = GaussianDiffusion(
    device=device,
    fix_mode=_diffusion_cfg.fix_mode,
    text_emb=_diffusion_cfg.text_emb,
    fixed_frames=_diffusion_cfg.fixed_frames,
    seq_len=_diffusion_cfg.seq_len,
    timesteps=_diffusion_cfg.timesteps,
    beta_schedule=_diffusion_cfg.beta_schedule,
)
# Normalisation statistics file, located two directories above this file.
_norm_stats_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '../../data/norm_scaled.npy')
)

# set_up_normalization returns a (normalize, denormalize) pair.
normalize, denormalize = set_up_normalization(
    device=device,
    seq_len=base_config.seq_len,
    scale=3,
    norm_path=_norm_stats_path,
)
# Settings bundle for test-time sampling.
# NOTE(review): the three guidance values are read from regen_config only;
# confirm the other tasks are meant to share them.
test_configs = dict(
    batch_size=1,
    # Taken from the base configuration.
    seq_len=base_config.seq_len,
    channels=base_config.channels,
    fixed_frame=base_config.fixed_frame,
    use_cfg=base_config.use_cfg,
    # Taken from the regen task configuration.
    cfg_alpha=regen_config.cfg_alpha,
    cg_alpha=regen_config.cg_alpha,
    cg_diffusion_steps=regen_config.cg_diffusion_steps,
)