# motionReFit / src/app/setup_models.py
# Uploaded by Yzy00518.
# Commit message: "Upload src/app/setup_models.py with huggingface_hub"
# Commit: 1ce0df3
import torch
from omegaconf import OmegaConf
import os
import sys
# Resolve the directory one level above this file once, then reuse it:
# it is put on sys.path so the `utils.*` / `model.*` imports further down
# resolve, echoed to stdout, and used as the anchor for config paths.
src_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(src_root)
print(src_root)

# Parent of src_root; checkpoint directories are joined onto this path.
project_root = os.path.abspath(os.path.join(src_root, '..'))
from utils.inference_utils import set_all_seeds, fix_state_dict
from model.gaussian_diffusion import GaussianDiffusion
from model.unet import Unet
from utils.normalize import set_up_normalization
from utils.constants import TO_24
# Seed first, so everything constructed below runs after seeding.
set_all_seeds(135)

# Use CUDA when PyTorch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

import clip

# clip.load returns a pair; only the first element (the model) is kept.
text_embedder, _ = clip.load("ViT-B/32", device=device)
text_embedder.eval()
def print_config(config):
    """Print ``config`` to stdout as the YAML text from ``OmegaConf.to_yaml``."""
    yaml_text = OmegaConf.to_yaml(config)
    print(yaml_text)
def getmodel(model_used, device, model_root, use_step=False, is_disc=False, config=None):
    """Build a ``Unet``, load a saved checkpoint into it and return it in eval mode.

    Args:
        model_used: Epoch or step identifier interpolated into the checkpoint
            file name.
        device: Device the network is moved to; checkpoint tensors are mapped
            onto the same device while loading.
        model_root: Directory that contains the checkpoint file.
        use_step: If true, load ``model_h3d_step{model_used}.pth``; otherwise
            load ``model_h3d_epoch{model_used}.pth``.
        is_disc: Forwarded to ``Unet`` as its ``Disc`` argument.
        config: Object exposing ``dim_model``, ``num_heads``, ``num_layers``,
            ``dropout_p``, ``dim_input``, ``dim_output`` and ``text_emb``.
            Required, even though the signature defaults it to ``None``.

    Returns:
        The ``Unet`` instance with the checkpoint weights loaded, in eval mode.

    Raises:
        ValueError: If ``config`` is ``None``.
    """
    # Fail early with a clear message instead of an attribute error on None.
    if config is None:
        raise ValueError("getmodel() requires a `config` with the Unet hyper-parameters")

    model = Unet(
        dim_model=config.dim_model,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout_p=config.dropout_p,
        dim_input=config.dim_input,
        dim_output=config.dim_output,
        text_emb=config.text_emb,
        device=device,
        Disc=is_disc,
    ).to(device)

    # Checkpoints are named either by epoch or by training step.
    checkpoint_kind = 'step' if use_step else 'epoch'
    model_path = os.path.join(model_root, f'model_h3d_{checkpoint_kind}{model_used}.pth')
    print("==>", model_path)

    # Map the stored tensors onto the requested device. This honours the
    # `device` argument rather than global CUDA availability, and does not
    # depend on which device the checkpoint was saved from.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)

    # fix_state_dict is applied to the whole checkpoint and then again to the
    # extracted 'model_state_dict' entry (presumably key-name cleanup; confirm
    # in utils.inference_utils).
    fixed_state_dict = fix_state_dict(state_dict)['model_state_dict']
    fixed_state_dict = fix_state_dict(fixed_state_dict)
    model.load_state_dict(fixed_state_dict)
    model.eval()
    return model
# Shared settings that every task configuration is merged on top of.
base_config = OmegaConf.load(os.path.join(src_root, "configs/base.yaml"))

# Each task config is the base config merged with its task-specific file;
# the task file is passed last to OmegaConf.merge.
regen_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/regen.yaml")),
)
style_transfer_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/style_transfer.yaml")),
)
adjustment_config = OmegaConf.merge(
    base_config,
    OmegaConf.load(os.path.join(src_root, "configs/inference/adjustment.yaml")),
)
# For every task, load two networks in this order: the main model (keyed by
# the task name) and its companion keyed "<task>_disc". Tasks are processed
# in the order regen, style_transfer, adjustment.
models = {}
for task_key, task_config in (
    ('regen', regen_config),
    ('style_transfer', style_transfer_config),
    ('adjustment', adjustment_config),
):
    # Main network: checkpoint file is named by epoch (use_step=False).
    models[task_key] = getmodel(
        task_config.model_used,
        device=device,
        model_root=os.path.join(project_root, task_config.model_path, task_config.task),
        use_step=False,
        is_disc=False,
        config=task_config.unet,
    )
    # "_disc" network: checkpoint file is named by step (use_step=True) and
    # the Unet is built with is_disc=True; it reuses the task's unet config.
    models[f'{task_key}_disc'] = getmodel(
        task_config.disc_model_used,
        device=device,
        model_root=os.path.join(project_root, task_config.disc_model_path, task_config.task),
        use_step=True,
        is_disc=True,
        config=task_config.unet,
    )
# Do not leave the loop variables behind as module attributes.
del task_key, task_config
# Diffusion process shared by all tasks. Every option except `device` is read
# from the `diffusion` section of the base config and passed through under
# the same keyword name.
diffuser = GaussianDiffusion(
    device=device,
    **{
        option: getattr(base_config.diffusion, option)
        for option in (
            'fix_mode',
            'text_emb',
            'fixed_frames',
            'seq_len',
            'timesteps',
            'beta_schedule',
        )
    },
)
# set_up_normalization returns a pair that is unpacked into `normalize` and
# `denormalize`. The statistics file is located two directories above this
# file, under data/norm_scaled.npy.
normalize, denormalize = set_up_normalization(
    device=device,
    seq_len=base_config.seq_len,
    scale=3,
    norm_path=os.path.abspath(
        os.path.join(os.path.dirname(__file__), '../../data/norm_scaled.npy')
    ),
)
# Inference-time settings collected into a plain dict. Sequence/channel
# settings come from the base config; the cfg/cg guidance values are read
# from the regen config.
test_configs = dict(
    batch_size=1,
    seq_len=base_config.seq_len,
    channels=base_config.channels,
    fixed_frame=base_config.fixed_frame,
    use_cfg=base_config.use_cfg,
    cfg_alpha=regen_config.cfg_alpha,
    cg_alpha=regen_config.cg_alpha,
    cg_diffusion_steps=regen_config.cg_diffusion_steps,
)