Spaces:
Runtime error
Runtime error
| import torch | |
| import argparse | |
| from Project.models.psp import pSp | |
| import torch.onnx | |
| import onnx | |
| import onnxruntime | |
| def setup_model(checkpoint_path, device='cpu'): | |
| model = onnx.load("e4e.onnx") | |
| onnx.checker.check_model(model) | |
| session = onnxruntime.InferenceSession("e4e.onnx") | |
| ckpt = torch.load(checkpoint_path) | |
| opts = ckpt['opts'] | |
| opts['checkpoint_path'] = checkpoint_path | |
| opts['device'] = device | |
| latent_avg=ckpt['latent_avg'] | |
| opts = argparse.Namespace(**opts) | |
| return session, opts,latent_avg | |
| ''' | |
| ckpt = torch.load(checkpoint_path) | |
| opts = ckpt['opts'] | |
| opts['checkpoint_path'] = checkpoint_path | |
| opts['device'] = device | |
| opts = argparse.Namespace(**opts) | |
| net = pSp(opts).encoder | |
| net.eval() | |
| #net = net.to(device) | |
| x=torch.randn(1,3,256,256,requires_grad=False) | |
| torch.onnx.export( | |
| net, | |
| x, | |
| "e4e.onnx", | |
| export_params = True, | |
| opset_version=11, # the ONNX version to export the model to | |
| do_constant_folding=True, # whether to execute constant folding for optimization | |
| input_names=['input'], # the model's input names | |
| output_names=['output'] | |
| ) | |
| return net, opts | |
| ''' | |
| ''' | |
| def load_e4e_standalone(checkpoint_path, device='cuda'): | |
| ckpt = torch.load(checkpoint_path, map_location='cpu') | |
| opts = argparse.Namespace(**ckpt['opts']) | |
| e4e = Encoder4Editing(50, 'ir_se', opts) | |
| e4e_dict = {k.replace('encoder.', ''): v for k, v in ckpt['state_dict'].items() if k.startswith('encoder.')} | |
| e4e.load_state_dict(e4e_dict) | |
| e4e.eval() | |
| e4e = e4e.to(device) | |
| latent_avg = ckpt['latent_avg'].to(device) | |
| def add_latent_avg(model, inputs, outputs): | |
| return outputs + latent_avg.repeat(outputs.shape[0], 1, 1) | |
| e4e.register_forward_hook(add_latent_avg) | |
| return e4e | |
| ''' |