Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from lib.models import build_body_model | |
| from .losses import SMPLifyLoss | |
class TemporalSMPLify:
    """Two-stage LBFGS refinement of an initial SMPL prediction.

    Stage 1 optimizes only the camera/translation parameter; stage 2
    optimizes pose, betas and camera jointly. The objective is supplied by
    ``SMPLifyLoss`` (defined in ``.losses``; its terms are not visible here).
    """

    def __init__(self,
                 smpl=None,
                 lr=1e-2,
                 num_iters=5,
                 num_steps=10,
                 img_w=None,
                 img_h=None,
                 device=None
                 ):
        """Store the optimization settings.

        Args:
            smpl: Body model handed to the loss closure.
            lr: Base LBFGS learning rate (stage 2 scales it, see ``fit``).
            num_iters: ``max_iter`` for each LBFGS ``step`` call.
            num_steps: Number of outer ``step`` calls per stage.
            img_w: Image width; stored but not read inside this class.
            img_h: Image height; stored but not read inside this class.
            device: Torch device the optimized tensors are moved to.
        """
        self.smpl = smpl
        self.lr = lr
        self.num_iters = num_iters
        self.num_steps = num_steps
        self.img_w = img_w
        self.img_h = img_h
        self.device = device

    def fit(self, init_pred, keypoints, bbox, **kwargs):
        """Refine ``init_pred`` against 2D keypoints and return it.

        Args:
            init_pred: Dict with tensor entries ``'pose'``, ``'betas'`` and
                ``'cam'``. It is updated IN PLACE with the optimized values.
            keypoints: NumPy array of keypoints; a leading dimension is
                added before it is passed to the loss.
            bbox: Forwarded unchanged to ``SMPLifyLoss.create_closure``.
            **kwargs: Forwarded to the ``SMPLifyLoss`` constructor.

        Returns:
            The same ``init_pred`` dict, with ``'pose'``, ``'betas'`` and
            ``'cam'`` replaced by detached tensors on ``self.device``.
        """
        def to_params(param):
            # Fresh float leaf tensor on the target device, tracked by autograd.
            return torch.from_numpy(param).float().to(self.device).requires_grad_(True)

        # Round-trip through NumPy so the optimized leaves are fully
        # decoupled from whatever graph produced the initial prediction.
        pose = init_pred['pose'].detach().cpu().numpy()
        betas = init_pred['betas'].detach().cpu().numpy()
        cam = init_pred['cam'].detach().cpu().numpy()
        keypoints = torch.from_numpy(keypoints).float().unsqueeze(0).to(self.device)

        # Second dimension of pose; used only to scale the stage-2 learning
        # rate. NOTE(review): presumably the number of frames -- confirm.
        BN = pose.shape[1]
        lr = self.lr

        params = [to_params(pose), to_params(betas), to_params(cam)]

        def make_closure(optimizer):
            # The closure must be bound to the optimizer that is stepping.
            return loss_fn.create_closure(optimizer,
                                          self.smpl,
                                          params,
                                          bbox,
                                          keypoints)

        # Stage 1. Optimize translation
        optimizer = torch.optim.LBFGS(
            [params[2]],
            lr=lr,
            max_iter=self.num_iters,
            line_search_fn='strong_wolfe')

        loss_fn = SMPLifyLoss(init_pose=pose, device=self.device, **kwargs)
        self._run_stage(optimizer, make_closure(optimizer))

        # Stage 2. Optimize all params
        optimizer = torch.optim.LBFGS(
            params,
            lr=lr * BN,
            max_iter=self.num_iters,
            line_search_fn='strong_wolfe')

        # Rebuild the closure for the new optimizer. Reusing the stage-1
        # closure would keep it bound to an optimizer that only owns
        # params[2]; if the closure zeroes gradients through its optimizer
        # (create_closure is not visible here -- TODO confirm), pose/betas
        # gradients would accumulate across LBFGS line-search evaluations.
        self._run_stage(optimizer, make_closure(optimizer))

        init_pred['pose'] = params[0].detach()
        init_pred['betas'] = params[1].detach()
        init_pred['cam'] = params[2].detach()

        return init_pred

    def _run_stage(self, optimizer, closure):
        """Run ``num_steps`` LBFGS steps, showing the latest loss on a bar."""
        progress = tqdm(range(self.num_steps), leave=False)
        for _ in progress:
            optimizer.zero_grad()
            loss = optimizer.step(closure)
            progress.set_postfix_str(f'Loss: {loss.item():.1f}')