|
|
|
|
|
|
|
import re
from enum import Enum
from typing import Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.utils import spectral_norm, weight_norm
import torchvision.models as models
|
|
|
|
|
class NormType(Enum):
    "Normalization types; the layer helpers below refer to these by member name as strings."
|
|
Batch = 1 |
|
|
BatchZero = 2 |
|
|
Weight = 3 |
|
|
Spectral = 4 |
|
|
Instance = 5 |
|
|
InstanceZero = 6 |
|
|
Pixel = 7 |
|
|
|
|
|
SplitFuncOrIdxList = Optional[Union[Callable, List[int]]] |
|
|
|
|
|
def to_device(m, device): |
|
|
return m.to(device) |
|
|
|
|
|
def apply_init(m, init_func):
    "Initialize the weights of every submodule of `m` that has a `weight` parameter with `init_func`."
    if hasattr(m, 'apply'):
        m.apply(lambda x: init_func(x.weight) if hasattr(x, 'weight') and hasattr(x.weight, 'data') else None)
|
|
|
|
|
|
|
|
|
|
|
|
|
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    "Convert a CamelCase `name` to snake_case."
    return _camel_re2.sub(r'\1_\2', _camel_re1.sub(r'\1_\2', name)).lower()
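# A quick illustration of the conversion (added example, not part of the original module):
#   camel2snake('SelfAttention')  -> 'self_attention'
#   camel2snake('BatchNorm2d')    -> 'batch_norm2d'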
|
|
|
|
|
class SequentialEx(nn.Module): |
|
|
"Like `nn.Sequential`, but with ModuleList semantics, and can handle `MergeLayer`." |
|
|
def __init__(self, *layers): |
|
|
super().__init__() |
|
|
self.layers = nn.ModuleList(layers) |
|
|
|
|
|
def forward(self, x): |
|
|
res = x |
|
|
for l in self.layers: |
|
|
res.orig = x |
|
|
nres = l(res) |
|
|
|
|
|
            # Remove the stored reference to the input to avoid hanging references
            # (and therefore memory leaks) between calls.
            res.orig = None
|
|
res = nres |
|
|
return res |
|
|
|
|
|
def __getitem__(self, i): return self.layers[i] |
|
|
def append(self, l): return self.layers.append(l) |
|
|
def extend(self, l): return self.layers.extend(l) |
|
|
def insert(self, i, l): return self.layers.insert(i, l) |
|
|
|
|
|
class MergeLayer(nn.Module): |
|
|
"Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`." |
|
|
def __init__(self, dense:bool=False): |
|
|
super().__init__() |
|
|
self.dense = dense |
|
|
|
|
|
def forward(self, x): |
|
|
return torch.cat([x, x.orig], dim=1) if self.dense else (x+x.orig) |
|
|
|
|
|
class SigmoidRange(nn.Module): |
|
|
"Sigmoid module with range `(low, x_max)`" |
|
|
def __init__(self, low, high): |
|
|
super().__init__() |
|
|
self.low, self.high = low, high |
|
|
|
|
|
def forward(self, x): |
|
|
return torch.sigmoid(x) * (self.high - self.low) + self.low |
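# Illustrative usage (added example): squash raw scores into a fixed range,
# e.g. ratings in (0, 5.5).
#   layer = SigmoidRange(0, 5.5)
#   layer(torch.zeros(1))  # -> tensor([2.7500]), since sigmoid(0) = 0.5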
|
|
|
|
|
def conv_layer(ni:int, nf:int, ks:int=3, stride:int=1, padding:Optional[int]=None, bias:Optional[bool]=None,
               is_1d:bool=False, norm_type:Optional[str]=None, use_activ:bool=True, leaky:Optional[float]=None,
               transpose:bool=False, init:Callable=nn.init.kaiming_normal_, self_attention:bool=False):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `norm_type` is a batch type) layers."
|
|
if padding is None: padding = (ks-1)//2 if not transpose else 0 |
|
|
bn = norm_type in ('Batch', 'BatchZero') |
|
|
if bias is None: bias = not bn |
|
|
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d |
|
|
conv = conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding) |
|
|
|
|
|
if init: init(conv.weight) |
|
|
|
|
|
if norm_type == 'Weight': conv = weight_norm(conv) |
|
|
elif norm_type == 'Spectral': conv = spectral_norm(conv) |
|
|
|
|
|
layers = [conv] |
|
|
if use_activ: layers.append(relu(leaky=leaky)) |
|
|
if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf)) |
|
|
if self_attention: layers.append(SelfAttention(nf)) |
|
|
return nn.Sequential(*layers) |
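# Illustrative usage (added example): a strided conv + ReLU + BatchNorm block.
# With norm_type='Batch' the conv bias is dropped, since batchnorm makes it redundant.
#   block = conv_layer(3, 64, ks=3, stride=2, norm_type='Batch')
#   block(torch.randn(1, 3, 32, 32)).shape  # -> torch.Size([1, 64, 16, 16])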
|
|
|
|
|
def relu(leaky:Optional[float]=None):
    "Return an in-place `nn.LeakyReLU` with slope `leaky` if given, else an in-place `nn.ReLU`."
    return nn.LeakyReLU(leaky, inplace=True) if leaky is not None else nn.ReLU(inplace=True)
|
|
|
|
|
def batchnorm_2d(nf:int, norm_type:str='Batch'):
    "A `nn.BatchNorm2d` layer with `nf` features; the weight is zeroed when `norm_type` is 'BatchZero'."
    bn = nn.BatchNorm2d(nf)
    if norm_type == 'BatchZero': nn.init.zeros_(bn.weight)
    return bn
|
|
|
|
|
class SelfAttention(nn.Module): |
|
|
"Self attention layer for nd." |
|
|
def __init__(self, n_channels:int): |
|
|
super().__init__() |
|
|
self.query = conv1d(n_channels, n_channels//8) |
|
|
self.key = conv1d(n_channels, n_channels//8) |
|
|
self.value = conv1d(n_channels, n_channels) |
|
|
self.gamma = nn.Parameter(tensor([0.])) |
|
|
|
|
|
def forward(self, x): |
|
|
|
|
|
        # Notation follows the SAGAN paper (https://arxiv.org/abs/1805.08318).
        size = x.size()
|
|
x = x.view(*size[:2],-1) |
|
|
f,g,h = self.query(x),self.key(x),self.value(x) |
|
|
beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1) |
|
|
o = self.gamma * torch.bmm(h, beta) + x |
|
|
return o.view(*size) |
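# Illustrative usage (added example): the output shape always matches the input shape,
# so the layer can be dropped into any convolutional stack.
#   sa = SelfAttention(64)
#   sa(torch.randn(2, 64, 8, 8)).shape  # -> torch.Size([2, 64, 8, 8])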
|
|
|
|
|
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False): |
|
|
"Create and initialize a `nn.Conv1d` layer with spectral normalization." |
|
|
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) |
|
|
nn.init.kaiming_normal_(conv.weight) |
|
|
if bias: conv.bias.data.zero_() |
|
|
return spectral_norm(conv) |
|
|
|
|
|
def res_block(nf, dense:bool=False, norm_type:str='Batch', bottle:bool=False, **kwargs): |
|
|
"Resnet block of `nf` features." |
|
|
norm2 = norm_type |
|
|
if not dense and (norm_type=='Batch'): norm2 = 'BatchZero' |
|
|
nf_inner = nf//2 if bottle else nf |
|
|
return SequentialEx(conv_layer(nf, nf_inner, norm_type=norm_type, **kwargs), |
|
|
conv_layer(nf_inner, nf, norm_type=norm2, **kwargs), |
|
|
MergeLayer(dense)) |
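# Illustrative usage (added example): the two conv layers preserve the channel and
# spatial dimensions, so `MergeLayer` can add the shortcut stored in `x.orig` back in.
#   block = res_block(32)
#   block(torch.randn(1, 32, 16, 16)).shape  # -> torch.Size([1, 32, 16, 16])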
|
|
|
|
|
|
|
|
|
|
|
class Hook(): |
|
|
"Create a hook on `m` with `hook_func`." |
|
|
def __init__(self, m:nn.Module, hook_func:Callable, is_forward:bool=True, detach:bool=True): |
|
|
self.hook_func,self.detach,self.stored = hook_func,detach,None |
|
|
        # NB: `register_backward_hook` is deprecated in recent PyTorch in favour of
        # `register_full_backward_hook`; it is kept here to match the original behaviour.
        f = m.register_forward_hook if is_forward else m.register_backward_hook
|
|
self.hook = f(self.hook_fn) |
|
|
self.removed = False |
|
|
|
|
|
    def hook_fn(self, module:nn.Module, input:Tensor, output:Tensor):
        "Apply `hook_func` to `module`, `input`, `output` and store the result in `self.stored`."
        if self.detach:
            # Materialize as tuples so the detached values can be read more than once.
            input = tuple(o.detach() for o in input) if isinstance(input, tuple) else input.detach()
            output = tuple(o.detach() for o in output) if isinstance(output, tuple) else output.detach()
        self.stored = self.hook_func(module, input, output)
|
|
|
|
|
def remove(self): |
|
|
if not self.removed: |
|
|
self.hook.remove() |
|
|
self.removed = True |
|
|
|
|
|
def __enter__(self, *args): return self |
|
|
def __exit__(self, *args): self.remove() |
|
|
|
|
|
class Hooks(): |
|
|
"Create several hooks on the modules in `ms` with `hook_func`." |
|
|
def __init__(self, ms:List[nn.Module], hook_func:Callable, is_forward:bool=True, detach:bool=True): |
|
|
self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms] |
|
|
|
|
|
def __getitem__(self,i:int)->Hook: return self.hooks[i] |
|
|
def __len__(self)->int: return len(self.hooks) |
|
|
def __iter__(self): return iter(self.hooks) |
|
|
@property |
|
|
def stored(self): return [o.stored for o in self] |
|
|
|
|
|
def remove(self): |
|
|
for h in self.hooks: h.remove() |
|
|
|
|
|
def __enter__(self, *args): return self |
|
|
def __exit__(self, *args): self.remove() |
|
|
|
|
|
def _hook_inner(m,i,o): return o if isinstance(o, (Tensor, list)) else list(o)
|
|
|
|
|
def hook_outputs(modules:List[nn.Module], detach:bool=True, grad:bool=False)->Hooks: |
|
|
"Return `Hooks` that store activations of all `modules` in `self.stored`" |
|
|
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) |
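# Illustrative usage (added example): capture the activation of every child of a
# small model during one forward pass, then let the context manager remove the hooks.
#   m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
#   with hook_outputs(m) as hooks:
#       m(torch.randn(1, 3, 16, 16))
#       shapes = [o.stored.shape for o in hooks]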
|
|
|
|
|
def dummy_eval(m:nn.Module, size:Tuple=(64,64)): |
|
|
"Evaluate `m` on a dummy input of a certain `size`" |
|
|
ch_in = in_channels(m) |
|
|
    # Create the dummy batch on the same device as the model's parameters.
    x = torch.randn(1, ch_in, *size).to(next(m.parameters()).device)
    return m.eval()(x)
|
|
|
|
|
def model_sizes(m:nn.Module, size:Tuple=(64,64)): |
|
|
"Pass a dummy input through the model `m` to get the various sizes of activations." |
|
|
with hook_outputs(m) as hooks: |
|
|
dummy_eval(m, size) |
|
|
return [o.stored.shape for o in hooks] |
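# Illustrative usage (added example); `in_channels` below finds the first conv
# (3 channels here), so the dummy input is (1, 3, 64, 64).
#   m = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2), nn.Conv2d(8, 16, 3, stride=2))
#   model_sizes(m)  # -> [torch.Size([1, 8, 31, 31]), torch.Size([1, 16, 15, 15])]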
|
|
|
|
|
def in_channels(m:nn.Module) -> int:
    "Return the number of input channels of the first convolutional layer in `m`."
    for l in m.modules():
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
            return l.in_channels
    raise ValueError('No convolutional layer found in model')
|
|
|
|
|
|
|
|
|
|
|
def create_body(arch:Callable, pretrained:bool=True, cut:Optional[Union[int, Callable]]=None):
    "Cut off the head of a typically pretrained `arch` at `cut` (a child index, or a function of the model)."
    model = arch(pretrained=pretrained)
    # By default, cut just before the last pooling layer, i.e. drop the classifier head.
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    if callable(cut): return cut(model)
    raise ValueError('cut must be None, an int, or a callable')
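# Illustrative usage (added example): with no `cut`, the search above finds the final
# pooling layer of a torchvision resnet and drops it along with the classifier.
#   body = create_body(models.resnet34, pretrained=False)
#   body(torch.randn(1, 3, 64, 64)).shape  # -> torch.Size([1, 512, 2, 2])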
|
|
|
|
|
def has_pool_type(m):
    "Return `True` if `m` is a pooling layer."
    return isinstance(m, (nn.AdaptiveAvgPool2d, nn.AdaptiveMaxPool2d, nn.AvgPool2d, nn.MaxPool2d))
|
|
|
|
|
def cnn_config(arch):
    "Get the metadata for `arch`."
    # This shim only needs the split point for a resnet-style (body, head) model:
    # at the body's seventh child and at the head.
    return {'split': lambda m: (m[0][6], m[1])}
|
|
|
|
|
|
|
|
|
|
|
class Learner: |
|
|
"Minimal Learner shim for inference." |
|
|
def __init__(self, data, model, path=None): |
|
|
self.data = data |
|
|
self.model = model |
|
|
self.path = path |
|
|
|
|
|
|
|
|
        # Resolve the target device from DeOldify's global device settings.
        from deoldify import device as device_settings
|
|
self.device = device_settings.get_torch_device() |
|
|
|
|
|
self.model.to(self.device) |
|
|
|
|
|
    def load(self, name):
        "Load weights for `name` from `self.path/models/{name}.pth`, or `models/{name}.pth` if no path is set."
        # `self.path` is expected to be a `pathlib.Path` when provided.
        if self.path:
            path = self.path / 'models' / f'{name}.pth'
        else:
            path = f'models/{name}.pth'
|
|
|
|
|
|
|
|
state = torch.load(path, map_location=self.device) |
|
|
        # fastai-style checkpoints nest the weights under a 'model' key.
        if 'model' in state: state = state['model']
|
|
self.model.load_state_dict(state, strict=True) |
|
|
|
|
|
    def split(self, split_on):
        "No-op in this inference shim: layer-group splitting is only needed for training."
        pass
|
|
|
|
|
def freeze(self): |
|
|
for p in self.model.parameters(): p.requires_grad = False |
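# Illustrative usage (added example); `my_model` and the weight-file name are
# placeholders, and `Path` comes from `pathlib`:
#   learn = Learner(get_dummy_databunch(), my_model, path=Path('.'))
#   learn.load('my_weights')   # reads ./models/my_weights.pth
#   learn.freeze()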
|
|
|
|
|
class DataBunch: |
|
|
"Minimal DataBunch shim." |
|
|
def __init__(self, c=3, device=None): |
|
|
self.c = c |
|
|
if device: |
|
|
self.device = device |
|
|
else: |
|
|
from deoldify import device as device_settings |
|
|
self.device = device_settings.get_torch_device() |
|
|
|
|
|
# Alias kept for compatibility with code written against the fastai name.
# It must be defined here, after `DataBunch`, so the name exists when the alias is created.
ImageDataBunch = DataBunch


def get_dummy_databunch():
|
|
return DataBunch() |
|
|
|
|
|
def tensor(x, *args, **kwargs):
    "Thin wrapper around `torch.tensor`."
    return torch.tensor(x, *args, **kwargs)
|
|
|