Spaces:
Build error
Build error
| import torch | |
| import numpy as np | |
def faceshifter_batch(source_emb: torch.Tensor,
                      target: torch.Tensor,
                      G: torch.nn.Module) -> np.ndarray:
    """
    Apply faceshifter model for batch of target images.

    Args:
        source_emb: Source embedding. When the batch size ``B`` is greater
            than 1 it is tiled ``B`` times along dim 0, so it is presumably
            a single embedding with a leading dimension of 1
            (TODO confirm against callers).
        target: Batch of target images laid out as ``B x C x H x W``.
        G: Generator, called as ``G(target, source_emb)``. Only the first
            element of the pair it returns is used. Its output is presumably
            in ``[-1, 1]`` (TODO confirm); values outside that range are
            saturated rather than wrapped.

    Returns:
        ``uint8`` NumPy array of shape ``B x H x W x 3``: the generator
        output mapped through ``(x * 0.5 + 0.5) * 255`` with channels
        taken in the order ``[2, 1, 0]`` (presumably RGB -> BGR for OpenCV).

    Raises:
        AssertionError: If ``target`` is not 4-dimensional.
    """
    # Validate rank before touching shape[0], so a malformed input reports
    # the intended message instead of an IndexError.
    assert target.ndim == 4, "target should have 4 dimentions -- B x C x H x W"
    bs = target.shape[0]

    if bs > 1:
        # One copy of the embedding per target image.
        source_emb = torch.cat([source_emb] * bs)

    with torch.no_grad():
        Y_st, _ = G(target, source_emb)
        # Channels-last, then rescale to the 0..255 range.
        Y_st = (Y_st.permute(0, 2, 3, 1) * 0.5 + 0.5) * 255
        # Saturate before the integer cast: an out-of-range float -> uint8
        # conversion would otherwise wrap around (or be undefined).
        Y_st = Y_st.clamp(0, 255)
        Y_st = Y_st[:, :, :, [2, 1, 0]].to(torch.uint8)

    return Y_st.cpu().numpy()