NeoPy's picture
EXP
0a0615c verified
raw
history blame
2.2 kB
import os
import torch
from torch import nn
from io import BytesIO
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
def decrypt_model(configs, input_path):
    """Decrypt an AES-CBC encrypted file and return the plaintext bytes.

    Expected file layout: the first ``AES.block_size`` (16) bytes are the IV,
    the remainder is the ciphertext. The key is the full contents of
    ``decrypt.bin`` inside ``configs["binary_path"]``. After decryption the
    block padding is removed with ``unpad`` (pycryptodome's default style is
    PKCS#7).

    Args:
        configs: mapping providing ``"binary_path"``, the directory that
            holds ``decrypt.bin``.
        input_path: path of the encrypted file.

    Returns:
        The decrypted, unpadded contents as ``bytes``.

    Raises:
        OSError: if either file cannot be read.
        ValueError: raised by pycryptodome for an invalid key/IV length,
            ciphertext not a multiple of the block size, or bad padding
            (typically the symptom of a wrong key or a truncated file).
    """
    with open(input_path, "rb") as f:
        data = f.read()
    with open(os.path.join(configs["binary_path"], "decrypt.bin"), "rb") as f:
        key = f.read()

    iv, ciphertext = data[:AES.block_size], data[AES.block_size:]
    cipher = AES.new(key, AES.MODE_CBC, iv)
    # unpad already returns bytes; wrapping it in BytesIO(...).read() only
    # allocated a buffer and copied the data once more.
    return unpad(cipher.decrypt(ciphertext), AES.block_size)
def calc_same_padding(kernel_size):
    """Return ``(left, right)`` padding totalling ``kernel_size - 1``.

    That total keeps a stride-1, dilation-1 convolution's output length equal
    to its input length. Odd kernels are padded symmetrically; even kernels
    get one element less on the right.
    """
    left = kernel_size // 2
    even_correction = (kernel_size + 1) % 2  # 1 for even kernels, 0 for odd
    return left, left - even_correction
def torch_interp(x, xp, fp):
    """Piecewise-linear interpolation of the samples ``(xp, fp)`` at ``x``.

    A torch counterpart of ``numpy.interp`` for 1-D tensors. ``xp`` does not
    need to be sorted; it is sorted here together with ``fp``. Queries at or
    beyond the ends of the sorted ``xp`` take the corresponding end value
    ``fp[0]`` / ``fp[-1]``.

    Args:
        x: query coordinates.
        xp: sample coordinates; must be non-empty.
        fp: sample values, one per entry of ``xp``.

    Returns:
        Tensor of interpolated values, one per element of ``x``.
    """
    sort_idx = xp.argsort()
    xp = xp[sort_idx]
    fp = fp[sort_idx]

    # searchsorted (left) gives the first index with xp >= x, so for
    # xp[0] < x <= xp[-1] we have xp[left] < x <= xp[right] and a positive span.
    right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
    left_idxs = (right_idxs - 1).clamp(min=0)
    x_left = xp[left_idxs]
    y_left = fp[left_idxs]

    # For x <= xp[0] both indices clamp to 0 and the span is zero. Dividing by
    # it gave NaN/inf, which also poisons gradients even where the value is
    # overwritten below. With right == left the numerator is 0 anyway, so
    # dividing by 1 there is harmless.
    span = xp[right_idxs] - x_left
    span = torch.where(span == 0, torch.ones_like(span), span)
    interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / span)

    # Inclusive bounds: with a strict `<`, x == xp[0] kept the 0/0 result.
    interp_vals[x <= xp[0]] = fp[0]
    interp_vals[x >= xp[-1]] = fp[-1]
    return interp_vals
def batch_interp_with_replacement_detach(uv, f0):
    """Replace the positions flagged by ``uv`` in each row of ``f0`` by linear
    interpolation over the unflagged positions of the same row.

    Interpolation coordinates are the positions along the row; the sample
    values are ``f0`` at the unflagged positions.

    Args:
        uv: boolean mask indexed per row like ``f0``; ``True`` marks entries
            to replace. (Presumably an "unvoiced" mask — confirm with callers.)
        f0: tensor whose flagged entries are filled.
            NOTE(review): the ``torch.where(...)[-1]`` indexing assumes each
            row ``uv[i]`` / ``f0[i]`` is 1-D.

    Returns:
        A clone of ``f0`` with flagged entries replaced. The interpolated
        values are detached from the autograd graph. Rows with no unflagged
        entries have nothing to interpolate from and are returned unchanged.
    """
    result = f0.clone()
    for i in range(uv.shape[0]):
        replace = uv[i]
        keep = ~replace
        # Skip rows with nothing to fill, or with no anchor points. An empty
        # xp would make torch_interp index into an empty tensor and raise.
        if not replace.any() or not keep.any():
            continue
        interp_vals = torch_interp(
            torch.where(replace)[-1],
            torch.where(keep)[-1],
            f0[i][keep],
        ).detach()
        result[i][replace] = interp_vals
    return result
class DotDict(dict):
    """``dict`` whose keys can also be read, set and deleted as attributes.

    - Reading a missing attribute returns ``None`` instead of raising.
    - A value whose exact type is ``dict`` is returned wrapped in a new
      ``DotDict``. That wrapper is a shallow copy, so ``d.a.b = 1`` does not
      write back into ``d["a"]``.
    - Keys that collide with real ``dict`` attributes (e.g. ``"items"``) must
      be read with ``[]``.
    - Deleting a missing attribute raises ``KeyError``.
    """

    def __getattr__(self, name):
        value = dict.get(self, name)
        if type(value) is dict:
            return DotDict(value)
        return value

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class Transpose(nn.Module):
    """Module form of ``torch.transpose``: swaps the two axes given in ``dims``.

    Args:
        dims: length-2 sequence ``(dim0, dim1)``. The length is checked with
            ``assert``, which is skipped under ``python -O``.
    """

    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        return torch.transpose(x, *self.dims)
class GLU(nn.Module):
    """Gated linear unit.

    Splits ``x`` into two equal halves ``(out, gate)`` along ``dim`` and
    returns ``out * sigmoid(gate)``.

    Args:
        dim: axis to split; its size must be even.

    Raises:
        ValueError: (from ``forward``) if ``x`` has an odd size along ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        size = x.shape[self.dim]
        # chunk(2) on an odd size gives unequal halves (e.g. 3 -> 2 + 1), which
        # silently broadcast into a wrong result or fail with an obscure error.
        if size % 2:
            raise ValueError(
                f"GLU expects an even size along dim {self.dim}, got {size}"
            )
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()