Upload folder using huggingface_hub
Browse files. This view is limited to 50 files because it contains too many changes. See the raw diff.
- .gitattributes +2 -0
- models/MAE_SDT.py +639 -0
- models/__init__.py +0 -0
- models/__pycache__/MAE_SDT.cpython-311.pyc +0 -0
- models/__pycache__/MAE_SDT.cpython-312.pyc +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/__init__.cpython-312.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/encoder.cpython-311.pyc +0 -0
- models/__pycache__/encoder.cpython-312.pyc +0 -0
- models/__pycache__/metaformer.cpython-311.pyc +0 -0
- models/__pycache__/metaformer.cpython-312.pyc +0 -0
- models/__pycache__/neuron.cpython-311.pyc +0 -0
- models/__pycache__/neuron.cpython-312.pyc +0 -0
- models/__pycache__/qk_model_v1_1003.cpython-311.pyc +0 -0
- models/__pycache__/qkformer.cpython-311.pyc +0 -0
- models/__pycache__/qkformer.cpython-312.pyc +0 -0
- models/__pycache__/sd_former_v1.cpython-311.pyc +0 -0
- models/__pycache__/sd_former_v1.cpython-312.pyc +0 -0
- models/__pycache__/sdtv3.cpython-311.pyc +0 -0
- models/__pycache__/sdtv3.cpython-312.pyc +0 -0
- models/__pycache__/sdtv3.cpython-39.pyc +0 -0
- models/__pycache__/sdtv3_large.cpython-311.pyc +0 -0
- models/__pycache__/sdtv3_large.cpython-312.pyc +0 -0
- models/__pycache__/spikformer.cpython-311.pyc +0 -0
- models/__pycache__/spikformer.cpython-312.pyc +0 -0
- models/__pycache__/vit.cpython-311.pyc +3 -0
- models/__pycache__/vit.cpython-312.pyc +3 -0
- models/encoder.py +158 -0
- models/metaformer.py +1538 -0
- models/neuron.py +1587 -0
- models/q_vit/Quant.py +185 -0
- models/q_vit/__init__.py +0 -0
- models/q_vit/__pycache__/Quant.cpython-311.pyc +0 -0
- models/q_vit/__pycache__/Quant.cpython-312.pyc +0 -0
- models/q_vit/__pycache__/__init__.cpython-311.pyc +0 -0
- models/q_vit/__pycache__/__init__.cpython-312.pyc +0 -0
- models/q_vit/__pycache__/_quan_base.cpython-311.pyc +0 -0
- models/q_vit/__pycache__/_quan_base.cpython-312.pyc +0 -0
- models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc +0 -0
- models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc +0 -0
- models/q_vit/_quan_base.py +208 -0
- models/q_vit/quant_vision_transformer.py +527 -0
- models/qk_model_v1_1003.py +426 -0
- models/qk_model_with_delay/__init__.py +0 -0
- models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc +0 -0
- models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc +0 -0
- models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc +0 -0
- models/qk_model_with_delay/delay_synaptic_func_inter.py +169 -0
- models/qk_model_with_delay/delay_synaptic_inter_model.py +459 -0
.gitattributes
CHANGED
|
@@ -91,3 +91,5 @@ visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B8_att
|
|
| 91 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B9_attn_proj.pdf filter=lfs diff=lfs merge=lfs -text
|
| 92 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_all_layers.pdf filter=lfs diff=lfs merge=lfs -text
|
| 93 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_average.pdf filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 91 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B9_attn_proj.pdf filter=lfs diff=lfs merge=lfs -text
|
| 92 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_all_layers.pdf filter=lfs diff=lfs merge=lfs -text
|
| 93 |
visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_average.pdf filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
models/__pycache__/vit.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
models/__pycache__/vit.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text
|
models/MAE_SDT.py
ADDED
|
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchinfo
|
| 5 |
+
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
|
| 6 |
+
from timm.models.registry import register_model
|
| 7 |
+
from timm.models.vision_transformer import _cfg
|
| 8 |
+
from einops.layers.torch import Rearrange
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from timm.models.vision_transformer import PatchEmbed, Block
|
| 11 |
+
|
| 12 |
+
from spikingjelly.clock_driven import layer
|
| 13 |
+
import copy
|
| 14 |
+
from torchvision import transforms
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
|
| 17 |
+
import models.encoder as encoder
|
| 18 |
+
from .util.pos_embed import get_2d_sincos_pos_embed
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
#timestep
|
| 23 |
+
T=4
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class multispike(torch.autograd.Function):
    """Integer-level spike quantiser with a straight-through surrogate gradient.

    Forward: round-to-nearest of the input after clamping to [0, lens]
    (lens defaults to the module-level timestep count T).
    Backward: gradient passes through unchanged wherever 0 < input < lens,
    and is zeroed outside that window.
    """

    @staticmethod
    def forward(ctx, input, lens=T):
        ctx.save_for_backward(input)
        ctx.lens = lens
        # clamp to [0, lens], then round to the nearest integer level
        clamped = torch.clamp(input, 0, lens)
        return torch.floor(clamped + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Straight-through estimator: 1 inside the open interval (0, lens), else 0.
        in_range = (0 < input) & (input < ctx.lens)
        # Second return is None: no gradient w.r.t. the `lens` argument.
        return grad_output * in_range.float(), None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Multispike(nn.Module):
    """nn.Module wrapper around :class:`multispike`.

    Applies the integer spike quantiser and rescales the result by ``norm``
    so the output lies in [0, 1] when norm equals the quantisation range.
    """

    def __init__(self, spike=multispike, norm=T):
        super().__init__()
        # `lens` mirrors `norm`; both kept for interface compatibility.
        self.lens = norm
        self.spike = spike
        self.norm = norm

    def forward(self, inputs):
        quantised = self.spike.apply(inputs)
        return quantised / self.norm
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def MS_conv_unit(in_channels, out_channels, kernel_size=1, padding=0, groups=1):
    """Build a sparse Conv2d + BatchNorm2d pair wrapped for (T, B, ...) input.

    `layer.SeqToANNContainer` flattens the leading timestep dimension before
    applying the stateless conv/BN and restores it afterwards.
    """
    conv = encoder.SparseConv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=padding,
        groups=groups,
        bias=True,
    )
    bn = encoder.SparseBatchNorm2d(out_channels)
    return nn.Sequential(layer.SeqToANNContainer(conv, bn))
|
| 62 |
+
class MS_ConvBlock(nn.Module):
    """Spiking convolutional residual block.

    Pattern: spike -> 3x3 conv (channel expansion by mlp_ratio) ->
    spike -> 3x3 conv (projection back to dim), with an identity shortcut.
    Input/output layout: (T, B, C, H, W).
    """

    def __init__(self, dim, mlp_ratio=4.0):
        super().__init__()
        # BUGFIX: mlp_ratio defaults to the float 4.0, so `dim * mlp_ratio`
        # would produce a float channel count; cast to int for the conv layers.
        hidden_dim = int(dim * mlp_ratio)

        self.neuron1 = Multispike()
        self.conv1 = MS_conv_unit(dim, hidden_dim, 3, 1)

        self.neuron2 = Multispike()
        self.conv2 = MS_conv_unit(hidden_dim, dim, 3, 1)

    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but unused here.
        short_cut = x
        x = self.neuron1(x)
        x = self.conv1(x)
        x = self.neuron2(x)
        x = self.conv2(x)
        x = x + short_cut
        return x
|
| 82 |
+
|
| 83 |
+
class MS_MLP(nn.Module):
    """Spiking MLP implemented with pointwise (1x1) Conv1d layers.

    Operates on token sequences laid out as (T, B, C, N): timesteps, batch,
    channels, tokens. Each stage is spike -> 1x1 conv -> BatchNorm1d.

    Args:
        in_features: input channel width C.
        hidden_features: expansion width (defaults to in_features).
        out_features: output width (defaults to in_features).
        drop: accepted for interface compatibility; unused.
        layer: accepted for interface compatibility; unused.
    """

    def __init__(
        self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # fc1: expand channels; a 1x1 conv acts as a per-token linear layer.
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = Multispike()

        # fc2: project back down to the output width.
        self.fc2_conv = nn.Conv1d(
            hidden_features, out_features, kernel_size=1, stride=1
        )
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = Multispike()

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.fc1_lif(x)
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()

        x = self.fc2_lif(x)
        x = self.fc2_conv(x.flatten(0, 1))
        # BUGFIX: reshape with the actual output width (was C, the input
        # width, which crashes whenever out_features != in_features).
        x = self.fc2_bn(x).reshape(T, B, self.c_output, N).contiguous()

        return x
|
| 116 |
+
|
| 117 |
+
class RepConv(nn.Module):
    """Two stacked pointwise (1x1) Conv1d + BatchNorm1d stages.

    Expands channels from in_channel to int(in_channel * 1.5), then maps to
    out_channel. The ``bias`` argument is accepted but unused: both convs are
    bias-free (BatchNorm supplies the affine shift).
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        # TODO in_channel-> 2*in_channel->in_channel
        mid_channel = int(in_channel * 1.5)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, mid_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(mid_channel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(mid_channel, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        hidden = self.conv1(x)
        return self.conv2(hidden)
|
| 130 |
+
class RepConv2(nn.Module):
    """Two pointwise (1x1) Conv1d + BatchNorm1d stages (1.5x channel expansion).

    NOTE(review): this class is a verbatim duplicate of :class:`RepConv`;
    consider consolidating the two in a follow-up. ``bias`` is accepted but
    unused — both convs are bias-free.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        # TODO in_channel-> 2*in_channel->in_channel
        expanded = int(in_channel * 1.5)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, expanded, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(expanded),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(expanded, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
|
| 143 |
+
|
| 144 |
+
class MS_Attention_Conv_qkv_id(nn.Module):
    """Spiking self-attention with conv-based q/k/v projections.

    Operates on (T, B, C, N) token sequences. q, k, v are produced by
    RepConv (1x1 conv) stacks followed by Multispike activations; attention
    is computed as q @ (k^T @ v) with a fixed 0.125 scale. ``sr_ratio``
    widens the value/attention channel dimension by that factor before the
    final projection back to ``dim``.

    ``qkv_bias``, ``qk_scale``, ``attn_drop`` and ``proj_drop`` are accepted
    for interface compatibility but unused in this implementation.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        # Fixed attention scale (1/8), independent of head dimension.
        self.scale = 0.125
        self.sr_ratio = sr_ratio

        self.head_lif = Multispike()

        # track 1: split convs (one projection stack per q/k/v)
        self.q_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.k_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        # v is widened by sr_ratio relative to q/k.
        self.v_conv = nn.Sequential(RepConv(dim, dim * sr_ratio), nn.BatchNorm1d(dim * sr_ratio))

        # track 2: merge (prefer) NOTE: need `chunk` in forward
        # self.qkv_conv = nn.Sequential(RepConv(dim,dim * 3), nn.BatchNorm2d(dim * 3))

        self.q_lif = Multispike()

        self.k_lif = Multispike()

        self.v_lif = Multispike()

        self.attn_lif = Multispike()

        # Project the (sr_ratio * dim)-wide attention output back to dim.
        self.proj_conv = nn.Sequential(RepConv(sr_ratio * dim, dim), nn.BatchNorm1d(dim))

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.head_lif(x)

        # Fold timesteps into the batch for the Conv1d projections.
        x_for_qkv = x.flatten(0, 1)
        q_conv_out = self.q_conv(x_for_qkv).reshape(T, B, C, N)

        q_conv_out = self.q_lif(q_conv_out)

        # (T, B, C, N) -> (T, B, heads, N, C/heads)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4)

        k_conv_out = self.k_conv(x_for_qkv).reshape(T, B, C, N)

        k_conv_out = self.k_lif(k_conv_out)

        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4)

        v_conv_out = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio * C, N)

        v_conv_out = self.v_lif(v_conv_out)

        # (T, B, heads, N, sr_ratio*C/heads)
        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio * C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                                       4)

        # Linear-attention ordering: compute k^T @ v first (D x D'), then q @ (.).
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale
        # Back to channel-first token layout: (T, B, sr_ratio*C, N).
        x = x.transpose(3, 4).reshape(T, B, self.sr_ratio * C, N)
        x = self.attn_lif(x)

        x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N)
        return x
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class MS_DownSampling(nn.Module):
    """Strided sparse-conv downsampling stage for (T, B, C, H, W) inputs.

    A Multispike activation precedes the conv on every stage except the
    first, which receives raw (non-spiking) input.
    """

    def __init__(
        self,
        in_channels=2,
        embed_dims=256,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__()

        self.encode_conv = encoder.SparseConv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.encode_bn = encoder.SparseBatchNorm2d(embed_dims)
        self.first_layer = first_layer
        # Only non-first stages quantise their input with a spike activation.
        if not first_layer:
            self.encode_spike = Multispike()

    def forward(self, x):
        timesteps, batch = x.shape[0], x.shape[1]
        if hasattr(self, "encode_spike"):
            x = self.encode_spike(x)
        # Fold timesteps into the batch for the 2-D conv, then restore them.
        x = self.encode_conv(x.flatten(0, 1))
        out_h, out_w = x.shape[-2], x.shape[-1]
        return self.encode_bn(x).reshape(timesteps, batch, -1, out_h, out_w)
|
| 246 |
+
|
| 247 |
+
class MS_Block(nn.Module):
    """Spiking transformer block: optional RepConv branch ("base" variant),
    then attention and MLP residual branches.

    Input/output layout: (T, B, C, N). With ``finetune=True`` the residual
    branches are scaled by learnable per-channel layer-scale parameters and
    wrapped in DropPath; otherwise plain residual sums are used.
    """

    def __init__(
        self,
        dim,
        choice,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        sr_ratio=1, init_values=1e-6, finetune=False,
    ):
        super().__init__()
        self.model = choice
        if self.model == "base":
            # Extra depthwise-style RepConv branch; adds parameters (~83M total).
            self.rep_conv = RepConv2(dim, dim)
            self.lif = Multispike()
        self.attn = MS_Attention_Conv_qkv_id(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.finetune = finetune
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        if self.finetune:
            self.layer_scale1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        # BUGFIX: this unpack was commented out, leaving B/C/N undefined (and
        # T silently resolving to the module-level constant) in the "base"
        # branch below, which raised NameError at runtime.
        T, B, C, N = x.shape
        if self.model == "base":
            x = x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)
        # TODO: need channel-wise layer scale, init as 1e-6
        if self.finetune:
            x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
            x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
|
| 297 |
+
|
| 298 |
+
class Spikmae(nn.Module):
    """Masked autoencoder with a spiking (SDT-style) encoder and a plain ViT decoder.

    The encoder downsamples 224x224 images through sparse-conv stages
    (overall stride 16 -> 14x14 = 196 patch tokens) and a stack of MS_Block
    transformer blocks; masking is applied by zeroing masked pixels and
    informing the sparse-conv layers via ``encoder._cur_active``. The decoder
    is a standard timm ViT Block stack reconstructing pixel patches.

    Tensor layout in the encoder is (T, B, C, H, W) / (T, B, C, N) with T the
    number of timesteps (hard-coded to 1 here).
    """

    def __init__(self, T=1, choice=None,
                 img_size_h=224,
                 img_size_w=224,
                 patch_size=16,
                 # NOTE(review): mutable default list — shared across instances;
                 # safe only because it is never mutated.
                 embed_dim=[128, 256, 512],
                 num_heads=8,
                 mlp_ratios=4,
                 in_channels=3,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.0,
                 num_classes=1000,
                 qkv_bias=False,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),  # norm_layer=nn.LayerNorm shaokun
                 depths=8,
                 sr_ratios=1,
                 decoder_embed_dim=768,
                 decoder_depth=4,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_pix_loss=False, nb_classes=1000):
        super().__init__()

        self.num_classes = num_classes
        self.depths = depths
        # NOTE(review): hard-coded to 1 — the `T` constructor argument is ignored.
        self.T = 1

        # Stochastic depth decay rule: linearly increasing drop-path per block.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depths)
        ]

        # Stage 1a: stride-2 stem on raw pixels (no spike on the input).
        self.downsample1_1 = MS_DownSampling(
            in_channels=in_channels,
            embed_dims=embed_dim[0] // 2,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )

        self.ConvBlock1_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
        )

        # Stage 1b: stride-2 downsample to embed_dim[0].
        self.downsample1_2 = MS_DownSampling(
            in_channels=embed_dim[0] // 2,
            embed_dims=embed_dim[0],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock1_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
        )

        # Stage 2: stride-2 downsample to embed_dim[1], two conv blocks.
        self.downsample2 = MS_DownSampling(
            in_channels=embed_dim[0],
            embed_dims=embed_dim[1],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock2_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.ConvBlock2_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        # Stage 3: final stride-2 downsample (overall stride 16 -> 14x14 tokens).
        self.downsample3 = MS_DownSampling(
            in_channels=embed_dim[1],
            embed_dims=embed_dim[2],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        # Spiking transformer blocks operating on the flattened token sequence.
        self.block3 = nn.ModuleList(
            [
                MS_Block(
                    dim=embed_dim[2],
                    choice=choice,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                    finetune=False,
                )
                for j in range(depths)
            ]
        )

        self.norm = nn.BatchNorm1d(embed_dim[-1])
        # NOTE(review): attribute name is a typo for "downsample_ratio"; kept
        # as-is because it is referenced below.
        self.downsample_raito = 16

        num_patches = 196  # 14 * 14 patch grid for 224/16

        # Fixed (non-learnable) sin-cos positional embedding, channel-first.
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim[-1], num_patches), requires_grad=False)

        ## MAE decoder vit
        self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Try learned decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels, bias=True)  # decoder to patch
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos positional embeddings and init all weights."""
        num_patches = 196
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[1], int(num_patches ** .5),
                                            cls_token=False)

        # Encoder pos-embed is stored channel-first, hence the transpose.
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed.transpose(1, 0)).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(num_patches ** .5), cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Standard ViT-style init: trunc-normal linear weights, unit LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns (ids_keep, active, ids_restore) where `active` is a (B, L)
        binary mask with 1 = kept patch, 0 = masked patch.
        """
        num_patches = 196
        T, N, _, _, _ = x.shape  # batch, length, dim
        L = num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]

        # generate the binary mask: 0 is keep, 1 is remove
        # (computed for reference; only `active` below is returned)
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        # active is inverse mask
        active = torch.ones([N, L], device=x.device)
        active[:, len_keep:] = 0
        active = torch.gather(active, dim=1, index=ids_restore)

        return ids_keep, active, ids_restore

    def forward_encoder(self, x, mask_ratio=1.0):
        """Mask the image, run the spiking encoder, return token features.

        Returns (tokens (B, N, C), active mask (B, L), ids_restore, active_hw).
        """
        # Replicate the image across T timesteps: (B,C,H,W) -> (T,B,C,H,W).
        x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1)
        # step1. Mask
        ids_keep, active, ids_restore = self.random_masking(x, mask_ratio)
        B, N = active.shape
        active_b1ff = active.reshape(B, 1, 14, 14)

        # Inform the sparse-conv layers which patches are active.
        encoder._cur_active = active_b1ff
        # Upsample the 14x14 patch mask to full pixel resolution (stride 16).
        active_hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3)
        active_hw = active_hw.unsqueeze(0)
        # Zero out masked pixels instead of dropping tokens.
        masked_bchw = x * active_hw
        x = masked_bchw
        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)
        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)

        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)

        x = self.downsample3(x)
        x = x.flatten(3)  # (T, B, C, H, W) -> (T, B, C, N)
        for blk in self.block3:
            x = blk(x)

        # Average over timesteps, normalise, and go token-major: (B, N, C).
        x = x.mean(0)
        x = self.norm(x).transpose(-1, -2).contiguous()
        return x, active, ids_restore, active_hw

    def forward_decoder(self, x, ids_restore):
        """Project encoder tokens, add pos-embed, run ViT decoder blocks."""
        # embed tokens
        B, N, C = x.shape
        x = self.decoder_embed(x)  # B, N, C
        # append mask tokens to sequence
        # ids_restore#1,196
        # NOTE(review): the encoder keeps all 196 tokens (masking is done by
        # zeroing pixels), so this repeat count is presumably 0 here — confirm.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = x_
        #
        # add pos embed
        x = x + self.decoder_pos_embed
        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)

        return x

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = 16
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = 16
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L] binary "active" mask (1 = kept, 0 = masked); the loss is
        computed only on the masked (non-active) patches.
        """

        inp, rec = self.patchify(imgs), pred  # inp and rec: (B, L = f*f, N = C*downsample_raito**2)
        # Per-patch normalisation of the target (mean/std over patch pixels).
        mean = inp.mean(dim=-1, keepdim=True)
        var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
        inp = (inp - mean) / var
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)  # (B, L, C) ==mean==> (B, L)
        non_active = mask.logical_not().int().view(mask.shape[0], -1)  # (B, 1, f, f) => (B, L)
        recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)  # loss only on masked (non-active) patches
        return recon_loss, mean, var

    def forward(self, imgs, mask_ratio=0.5, vis=False):
        """Full MAE pass; returns the reconstruction loss, or (with vis=True)
        the original, masked, and reconstructed images for visualisation."""

        latent, active, ids_restore, active_hw = self.forward_encoder(imgs, mask_ratio)
        rec = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        recon_loss, mean, var = self.forward_loss(imgs, rec, active)
        if vis:
            masked_bchw = imgs * active_hw.flatten(0, 1)
            # Undo the per-patch normalisation before unpatchifying.
            rec_bchw = self.unpatchify(rec * var + mean)
            # Show real pixels where kept, reconstructions where masked.
            rec_or_inp = torch.where(active_hw.flatten(0, 1).bool(), imgs, rec_bchw)
            return imgs, masked_bchw, rec_or_inp
        else:
            return recon_loss
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
def spikmae_12_512(**kwargs):
    """Spiking MAE: 12 transformer blocks, 512-dim final stage, "base" variant."""
    defaults = dict(
        T=1,
        choice="base",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[128, 256, 512],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1,
        decoder_embed_dim=256,
        decoder_depth=4,
        decoder_num_heads=4,
    )
    return Spikmae(**defaults, **kwargs)
|
| 611 |
+
def spikmae_12_768(**kwargs):
|
| 612 |
+
model = Spikmae(
|
| 613 |
+
T=1,
|
| 614 |
+
choice="large",
|
| 615 |
+
img_size_h=224,
|
| 616 |
+
img_size_w=224,
|
| 617 |
+
patch_size=16,
|
| 618 |
+
embed_dim=[192,384,768],
|
| 619 |
+
num_heads=8,
|
| 620 |
+
mlp_ratios=4,
|
| 621 |
+
in_channels=3,
|
| 622 |
+
num_classes=1000,
|
| 623 |
+
qkv_bias=False,
|
| 624 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
| 625 |
+
depths=12,
|
| 626 |
+
sr_ratios=1, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4,
|
| 627 |
+
**kwargs)
|
| 628 |
+
return model
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
if __name__ == "__main__":
|
| 634 |
+
model = spikmae_12_768()
|
| 635 |
+
x=torch.randn(1,3,224,224)
|
| 636 |
+
loss = model(x,mask_ratio=0.50)
|
| 637 |
+
print('loss',loss)
|
| 638 |
+
torchinfo.summary(model, (1, 3, 224, 224))
|
| 639 |
+
print(f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
models/__init__.py
ADDED
|
File without changes
|
models/__pycache__/MAE_SDT.cpython-311.pyc
ADDED
|
Binary file (36.2 kB). View file
|
|
|
models/__pycache__/MAE_SDT.cpython-312.pyc
ADDED
|
Binary file (32.1 kB). View file
|
|
|
models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (201 Bytes). View file
|
|
|
models/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (212 Bytes). View file
|
|
|
models/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (153 Bytes). View file
|
|
|
models/__pycache__/encoder.cpython-311.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
models/__pycache__/encoder.cpython-312.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
models/__pycache__/metaformer.cpython-311.pyc
ADDED
|
Binary file (72.1 kB). View file
|
|
|
models/__pycache__/metaformer.cpython-312.pyc
ADDED
|
Binary file (63.8 kB). View file
|
|
|
models/__pycache__/neuron.cpython-311.pyc
ADDED
|
Binary file (78.9 kB). View file
|
|
|
models/__pycache__/neuron.cpython-312.pyc
ADDED
|
Binary file (75.7 kB). View file
|
|
|
models/__pycache__/qk_model_v1_1003.cpython-311.pyc
ADDED
|
Binary file (30 kB). View file
|
|
|
models/__pycache__/qkformer.cpython-311.pyc
ADDED
|
Binary file (31.5 kB). View file
|
|
|
models/__pycache__/qkformer.cpython-312.pyc
ADDED
|
Binary file (27.1 kB). View file
|
|
|
models/__pycache__/sd_former_v1.cpython-311.pyc
ADDED
|
Binary file (29.7 kB). View file
|
|
|
models/__pycache__/sd_former_v1.cpython-312.pyc
ADDED
|
Binary file (25.6 kB). View file
|
|
|
models/__pycache__/sdtv3.cpython-311.pyc
ADDED
|
Binary file (64.6 kB). View file
|
|
|
models/__pycache__/sdtv3.cpython-312.pyc
ADDED
|
Binary file (55.6 kB). View file
|
|
|
models/__pycache__/sdtv3.cpython-39.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
models/__pycache__/sdtv3_large.cpython-311.pyc
ADDED
|
Binary file (27.3 kB). View file
|
|
|
models/__pycache__/sdtv3_large.cpython-312.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
models/__pycache__/spikformer.cpython-311.pyc
ADDED
|
Binary file (28 kB). View file
|
|
|
models/__pycache__/spikformer.cpython-312.pyc
ADDED
|
Binary file (24.6 kB). View file
|
|
|
models/__pycache__/vit.cpython-311.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c206e15daa2f79c2abc87acf17b7e6263bb292fe86a0581f10e58b94da50c3d5
|
| 3 |
+
size 204918
|
models/__pycache__/vit.cpython-312.pyc
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:541616a16f1f3839624aff7ca6c0d0f168227ee19a642340a85e71e77d6ea63d
|
| 3 |
+
size 183274
|
models/encoder.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) ByteDance, Inc. and its affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from timm.models.layers import DropPath
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
_cur_active: torch.Tensor = None # B1ff
|
| 13 |
+
# todo: try to use `gather` for speed?
|
| 14 |
+
def _get_active_ex_or_ii(H, W, returning_active_ex=True):
|
| 15 |
+
h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1]
|
| 16 |
+
active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3)
|
| 17 |
+
return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def sp_conv_forward(self, x: torch.Tensor):
|
| 21 |
+
x = super(type(self), self).forward(x)
|
| 22 |
+
x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv
|
| 23 |
+
return x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def sp_bn_forward(self, x: torch.Tensor):
|
| 27 |
+
ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
|
| 28 |
+
|
| 29 |
+
bhwc = x.permute(0, 2, 3, 1)
|
| 30 |
+
nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc`
|
| 31 |
+
nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc`
|
| 32 |
+
|
| 33 |
+
bchw = torch.zeros_like(bhwc)
|
| 34 |
+
bchw[ii] = nc
|
| 35 |
+
bchw = bchw.permute(0, 3, 1, 2)
|
| 36 |
+
return bchw
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SparseConv2d(nn.Conv2d):
|
| 40 |
+
forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class SparseMaxPooling(nn.MaxPool2d):
|
| 44 |
+
forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class SparseAvgPooling(nn.AvgPool2d):
|
| 48 |
+
forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SparseBatchNorm2d(nn.BatchNorm1d):
|
| 52 |
+
forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
|
| 56 |
+
forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SparseConvNeXtLayerNorm(nn.LayerNorm):
|
| 60 |
+
r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 61 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
| 62 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
| 63 |
+
with shape (batch_size, channels, height, width).
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True):
|
| 67 |
+
if data_format not in ["channels_last", "channels_first"]:
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
super().__init__(normalized_shape, eps, elementwise_affine=True)
|
| 70 |
+
self.data_format = data_format
|
| 71 |
+
self.sparse = sparse
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
if x.ndim == 4: # BHWC or BCHW
|
| 75 |
+
if self.data_format == "channels_last": # BHWC
|
| 76 |
+
if self.sparse:
|
| 77 |
+
ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False)
|
| 78 |
+
nc = x[ii]
|
| 79 |
+
nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
|
| 80 |
+
|
| 81 |
+
x = torch.zeros_like(x)
|
| 82 |
+
x[ii] = nc
|
| 83 |
+
return x
|
| 84 |
+
else:
|
| 85 |
+
return super(SparseConvNeXtLayerNorm, self).forward(x)
|
| 86 |
+
else: # channels_first, BCHW
|
| 87 |
+
if self.sparse:
|
| 88 |
+
ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False)
|
| 89 |
+
bhwc = x.permute(0, 2, 3, 1)
|
| 90 |
+
nc = bhwc[ii]
|
| 91 |
+
nc = super(SparseConvNeXtLayerNorm, self).forward(nc)
|
| 92 |
+
|
| 93 |
+
x = torch.zeros_like(bhwc)
|
| 94 |
+
x[ii] = nc
|
| 95 |
+
return x.permute(0, 3, 1, 2)
|
| 96 |
+
else:
|
| 97 |
+
u = x.mean(1, keepdim=True)
|
| 98 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 99 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 100 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 101 |
+
return x
|
| 102 |
+
else: # BLC or BC
|
| 103 |
+
if self.sparse:
|
| 104 |
+
raise NotImplementedError
|
| 105 |
+
else:
|
| 106 |
+
return super(SparseConvNeXtLayerNorm, self).forward(x)
|
| 107 |
+
|
| 108 |
+
def __repr__(self):
|
| 109 |
+
return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})'
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class SparseConvNeXtBlock(nn.Module):
|
| 113 |
+
r""" ConvNeXt Block. There are two equivalent implementations:
|
| 114 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 115 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 116 |
+
We use (2) as we find it slightly faster in PyTorch
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
dim (int): Number of input channels.
|
| 120 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 121 |
+
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7):
|
| 125 |
+
super().__init__()
|
| 126 |
+
self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv
|
| 127 |
+
self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse)
|
| 128 |
+
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
|
| 129 |
+
self.act = nn.GELU()
|
| 130 |
+
self.pwconv2 = nn.Linear(4 * dim, dim)
|
| 131 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
|
| 132 |
+
requires_grad=True) if layer_scale_init_value > 0 else None
|
| 133 |
+
self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 134 |
+
self.sparse = sparse
|
| 135 |
+
|
| 136 |
+
def forward(self, x):
|
| 137 |
+
input = x
|
| 138 |
+
x = self.dwconv(x)
|
| 139 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 140 |
+
x = self.norm(x)
|
| 141 |
+
x = self.pwconv1(x)
|
| 142 |
+
x = self.act(x) # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`)
|
| 143 |
+
x = self.pwconv2(x)
|
| 144 |
+
if self.gamma is not None:
|
| 145 |
+
x = self.gamma * x
|
| 146 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 147 |
+
|
| 148 |
+
if self.sparse:
|
| 149 |
+
x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True)
|
| 150 |
+
|
| 151 |
+
x = input + self.drop_path(x)
|
| 152 |
+
return x
|
| 153 |
+
|
| 154 |
+
def __repr__(self):
|
| 155 |
+
return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})'
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
models/metaformer.py
ADDED
|
@@ -0,0 +1,1538 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2022 Garena Online Private Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2,
|
| 17 |
+
ConvFormer and CAFormer.
|
| 18 |
+
Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models).
|
| 19 |
+
"""
|
| 20 |
+
from functools import partial
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
|
| 25 |
+
from timm.layers import trunc_normal_, DropPath
|
| 26 |
+
from timm.models.registry import register_model
|
| 27 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 28 |
+
from timm.layers.helpers import to_2tuple
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _cfg(url='', **kwargs):
|
| 32 |
+
return {
|
| 33 |
+
'url': url,
|
| 34 |
+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
| 35 |
+
'crop_pct': 1.0, 'interpolation': 'bicubic',
|
| 36 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
|
| 37 |
+
**kwargs
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
default_cfgs = {
|
| 42 |
+
'identityformer_s12': _cfg(
|
| 43 |
+
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'),
|
| 44 |
+
'identityformer_s24': _cfg(
|
| 45 |
+
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'),
|
| 46 |
+
'identityformer_s36': _cfg(
|
| 47 |
+
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'),
|
| 48 |
+
'identityformer_m36': _cfg(
|
| 49 |
+
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'),
|
| 50 |
+
'identityformer_m48': _cfg(
|
| 51 |
+
url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'),
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
'randformer_s12': _cfg(
|
| 55 |
+
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'),
|
| 56 |
+
'randformer_s24': _cfg(
|
| 57 |
+
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'),
|
| 58 |
+
'randformer_s36': _cfg(
|
| 59 |
+
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'),
|
| 60 |
+
'randformer_m36': _cfg(
|
| 61 |
+
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'),
|
| 62 |
+
'randformer_m48': _cfg(
|
| 63 |
+
url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'),
|
| 64 |
+
|
| 65 |
+
'poolformerv2_s12': _cfg(
|
| 66 |
+
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'),
|
| 67 |
+
'poolformerv2_s24': _cfg(
|
| 68 |
+
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'),
|
| 69 |
+
'poolformerv2_s36': _cfg(
|
| 70 |
+
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'),
|
| 71 |
+
'poolformerv2_m36': _cfg(
|
| 72 |
+
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'),
|
| 73 |
+
'poolformerv2_m48': _cfg(
|
| 74 |
+
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'),
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
'convformer_s18': _cfg(
|
| 79 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'),
|
| 80 |
+
'convformer_s18_384': _cfg(
|
| 81 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth',
|
| 82 |
+
input_size=(3, 384, 384)),
|
| 83 |
+
'convformer_s18_in21ft1k': _cfg(
|
| 84 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'),
|
| 85 |
+
'convformer_s18_384_in21ft1k': _cfg(
|
| 86 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth',
|
| 87 |
+
input_size=(3, 384, 384)),
|
| 88 |
+
'convformer_s18_in21k': _cfg(
|
| 89 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth',
|
| 90 |
+
num_classes=21841),
|
| 91 |
+
|
| 92 |
+
'convformer_s36': _cfg(
|
| 93 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'),
|
| 94 |
+
'convformer_s36_384': _cfg(
|
| 95 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth',
|
| 96 |
+
input_size=(3, 384, 384)),
|
| 97 |
+
'convformer_s36_in21ft1k': _cfg(
|
| 98 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'),
|
| 99 |
+
'convformer_s36_384_in21ft1k': _cfg(
|
| 100 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth',
|
| 101 |
+
input_size=(3, 384, 384)),
|
| 102 |
+
'convformer_s36_in21k': _cfg(
|
| 103 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth',
|
| 104 |
+
num_classes=21841),
|
| 105 |
+
|
| 106 |
+
'convformer_m36': _cfg(
|
| 107 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'),
|
| 108 |
+
'convformer_m36_384': _cfg(
|
| 109 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth',
|
| 110 |
+
input_size=(3, 384, 384)),
|
| 111 |
+
'convformer_m36_in21ft1k': _cfg(
|
| 112 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'),
|
| 113 |
+
'convformer_m36_384_in21ft1k': _cfg(
|
| 114 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth',
|
| 115 |
+
input_size=(3, 384, 384)),
|
| 116 |
+
'convformer_m36_in21k': _cfg(
|
| 117 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth',
|
| 118 |
+
num_classes=21841),
|
| 119 |
+
|
| 120 |
+
'convformer_b36': _cfg(
|
| 121 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'),
|
| 122 |
+
'convformer_b36_384': _cfg(
|
| 123 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth',
|
| 124 |
+
input_size=(3, 384, 384)),
|
| 125 |
+
'convformer_b36_in21ft1k': _cfg(
|
| 126 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'),
|
| 127 |
+
'convformer_b36_384_in21ft1k': _cfg(
|
| 128 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth',
|
| 129 |
+
input_size=(3, 384, 384)),
|
| 130 |
+
'convformer_b36_in21k': _cfg(
|
| 131 |
+
url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth',
|
| 132 |
+
num_classes=21841),
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
'caformer_s18': _cfg(
|
| 136 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'),
|
| 137 |
+
'caformer_s18_384': _cfg(
|
| 138 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth',
|
| 139 |
+
input_size=(3, 384, 384)),
|
| 140 |
+
'caformer_s18_in21ft1k': _cfg(
|
| 141 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'),
|
| 142 |
+
'caformer_s18_384_in21ft1k': _cfg(
|
| 143 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth',
|
| 144 |
+
input_size=(3, 384, 384)),
|
| 145 |
+
'caformer_s18_in21k': _cfg(
|
| 146 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth',
|
| 147 |
+
num_classes=21841),
|
| 148 |
+
|
| 149 |
+
'caformer_s36': _cfg(
|
| 150 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'),
|
| 151 |
+
'caformer_s36_384': _cfg(
|
| 152 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth',
|
| 153 |
+
input_size=(3, 384, 384)),
|
| 154 |
+
'caformer_s36_in21ft1k': _cfg(
|
| 155 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'),
|
| 156 |
+
'caformer_s36_384_in21ft1k': _cfg(
|
| 157 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth',
|
| 158 |
+
input_size=(3, 384, 384)),
|
| 159 |
+
'caformer_s36_in21k': _cfg(
|
| 160 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth',
|
| 161 |
+
num_classes=21841),
|
| 162 |
+
|
| 163 |
+
'caformer_m36': _cfg(
|
| 164 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'),
|
| 165 |
+
'caformer_m36_384': _cfg(
|
| 166 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth',
|
| 167 |
+
input_size=(3, 384, 384)),
|
| 168 |
+
'caformer_m36_in21ft1k': _cfg(
|
| 169 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'),
|
| 170 |
+
'caformer_m36_384_in21ft1k': _cfg(
|
| 171 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth',
|
| 172 |
+
input_size=(3, 384, 384)),
|
| 173 |
+
'caformer_m36_in21k': _cfg(
|
| 174 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth',
|
| 175 |
+
num_classes=21841),
|
| 176 |
+
|
| 177 |
+
'caformer_b36': _cfg(
|
| 178 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'),
|
| 179 |
+
'caformer_b36_384': _cfg(
|
| 180 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth',
|
| 181 |
+
input_size=(3, 384, 384)),
|
| 182 |
+
'caformer_b36_in21ft1k': _cfg(
|
| 183 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'),
|
| 184 |
+
'caformer_b36_384_in21ft1k': _cfg(
|
| 185 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth',
|
| 186 |
+
input_size=(3, 384, 384)),
|
| 187 |
+
'caformer_b36_in21k': _cfg(
|
| 188 |
+
url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth',
|
| 189 |
+
num_classes=21841),
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Downsampling(nn.Module):
    """
    Downsampling implemented by a single convolution layer, with optional
    normalization before and after the conv. Output is channels-last.
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size, stride=1, padding=0,
                 pre_norm=None, post_norm=None, pre_permute=False):
        super().__init__()
        # Norm factories are optional; fall back to identity when unset.
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        out = self.pre_norm(x)
        if self.pre_permute:
            # Input arrives as [B, H, W, C]; conv wants [B, C, H, W].
            out = out.permute(0, 3, 1, 2)
        out = self.conv(out)
        # Hand the result back in channels-last layout: [B, H, W, C].
        out = out.permute(0, 2, 3, 1)
        return self.post_norm(out)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class Scale(nn.Module):
    """
    Elementwise scaling of the last dimension by a (possibly trainable) vector.
    """
    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        weight = torch.ones(dim) * init_value
        self.scale = nn.Parameter(weight, requires_grad=trainable)

    def forward(self, x):
        return self.scale * x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class SquaredReLU(nn.Module):
    """
    Squared ReLU activation: relu(x) ** 2.
    See https://arxiv.org/abs/2109.08668.
    """
    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        rectified = self.relu(x)
        return rectified * rectified
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class StarReLU(nn.Module):
    """
    StarReLU activation: s * relu(x) ** 2 + b, where the scalar scale s and
    bias b are (optionally) learnable.
    """
    def __init__(self, scale_value=1.0, bias_value=0.0,
                 scale_learnable=True, bias_learnable=True,
                 mode=None, inplace=False):
        super().__init__()
        # `inplace` is stored for API parity; the flag lives on the ReLU.
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(torch.ones(1) * scale_value,
                                  requires_grad=scale_learnable)
        self.bias = nn.Parameter(torch.ones(1) * bias_value,
                                 requires_grad=bias_learnable)

    def forward(self, x):
        rectified = self.relu(x)
        return self.scale * rectified * rectified + self.bias
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class Attention(nn.Module):
    """
    Vanilla self-attention from Transformer (https://arxiv.org/abs/1706.03762)
    operating directly on [B, H, W, C] feature maps. Modified from timm.
    """
    def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False,
                 attn_drop=0., proj_drop=0., proj_bias=False, **kwargs):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5

        # Derive the head count from `dim` unless given; guarantee >= 1 head.
        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape
        N = H * W
        # Project and split into per-head query/key/value tensors.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        scores = q @ k.transpose(-2, -1)
        weights = (scores * self.scale).softmax(dim=-1)
        weights = self.attn_drop(weights)

        out = (weights @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim)
        out = self.proj(out)
        return self.proj_drop(out)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class RandomMixing(nn.Module):
    """
    Token mixing with a frozen random matrix; the row-wise softmax makes the
    matrix row-stochastic. Requires H * W == num_tokens at call time.
    """
    def __init__(self, num_tokens=196, **kwargs):
        super().__init__()
        self.random_matrix = nn.parameter.Parameter(
            data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1),
            requires_grad=False)

    def forward(self, x):
        B, H, W, C = x.shape
        tokens = x.reshape(B, H * W, C)
        mixed = torch.einsum('mn, bnc -> bmc', self.random_matrix, tokens)
        return mixed.reshape(B, H, W, C)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class LayerNormGeneral(nn.Module):
    r""" General LayerNorm for different situations.

    Args:
        affine_shape (int, list or tuple): Shape of the affine weight/bias.
            Unlike torch.nn.LayerNorm (where it equals normalized_dim),
            this can be specified independently to adapt to input layout.
        normalized_dim (tuple or list): Dims over which mean/variance are taken.
        scale (bool): Whether to apply a learnable scale.
        bias (bool): Whether to apply a learnable bias.

    Examples:
        LayerNorm (https://arxiv.org/abs/1607.06450):
            (B, *, C) input: affine_shape=C, normalized_dim=(-1,);
            (B, C, H, W) input: affine_shape=(C, 1, 1), normalized_dim=(1,).
        Modified LayerNorm (https://arxiv.org/abs/2111.11418), identical to
        partial(torch.nn.GroupNorm, num_groups=1):
            (B, N, C): affine_shape=C, normalized_dim=(1, 2);
            (B, H, W, C): affine_shape=C, normalized_dim=(1, 2, 3);
            (B, C, H, W): affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3).

    For the MetaFormer baselines: IdentityFormer, RandFormer and PoolFormerV2
    use Modified LayerNorm without bias; ConvFormer and CAFormer use LayerNorm
    without bias (bias=False).
    """
    def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
                 bias=True, eps=1e-5):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        # Center, then divide by the (biased) standard deviation.
        centered = x - x.mean(self.normalized_dim, keepdim=True)
        variance = centered.pow(2).mean(self.normalized_dim, keepdim=True)
        out = centered / torch.sqrt(variance + self.eps)
        if self.use_scale:
            out = out * self.weight
        if self.use_bias:
            out = out + self.bias
        return out
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class LayerNormWithoutBias(nn.Module):
    """
    Equivalent to partial(LayerNormGeneral, bias=False) but faster, because it
    dispatches to the optimized F.layer_norm kernel.
    """
    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None  # no bias by design
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape,
                            weight=self.weight, bias=self.bias, eps=self.eps)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
class SepConv(nn.Module):
    r"""
    Inverted separable convolution from MobileNetV2
    (https://arxiv.org/abs/1801.04381): pointwise expand -> depthwise conv ->
    pointwise project, on channels-last [B, H, W, C] input.
    """
    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=StarReLU, act2_layer=nn.Identity,
                 bias=False, kernel_size=7, padding=3,
                 **kwargs, ):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        # Depthwise conv (groups == channels) does the spatial token mixing.
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        out = self.act1(self.pwconv1(x))
        out = out.permute(0, 3, 1, 2)   # [B, H, W, C] -> [B, C, H, W]
        out = self.dwconv(out)
        out = out.permute(0, 2, 3, 1)   # back to channels-last
        return self.pwconv2(self.act2(out))
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class Pooling(nn.Module):
    """
    PoolFormer token mixer (https://arxiv.org/abs/2111.11418): average
    pooling minus the identity, adapted for [B, H, W, C] input.
    """
    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2,
            count_include_pad=False)

    def forward(self, x):
        # Pool in channels-first layout, then subtract the input itself.
        pooled = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return pooled - x
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class Mlp(nn.Module):
    """
    MLP as used in MetaFormer models (Transformer, MLP-Mixer, PoolFormer,
    MetaFormer baselines and related networks). Mostly copied from timm.
    """
    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        # Separate dropout probabilities for the two linear layers.
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
class MlpHead(nn.Module):
    """
    MLP classification head: expand -> activation -> norm -> dropout -> logits.
    """
    def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU,
                 norm_layer=nn.LayerNorm, head_dropout=0., bias=True):
        super().__init__()
        hidden_features = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias)
        self.head_dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_dropout(hidden))
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
class MetaFormerBlock(nn.Module):
    """
    One MetaFormer block: a token-mixing sub-block followed by a channel-MLP
    sub-block, each wrapped with pre-norm, drop-path and optional layer-scale
    / residual-scale branches.
    """
    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp,
                 norm_layer=nn.LayerNorm,
                 drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None
                 ):
        super().__init__()

        def make_drop_path():
            # Stochastic depth only when a positive rate is requested.
            return DropPath(drop_path) if drop_path > 0. else nn.Identity()

        def make_scale(init_value):
            # Scale branch only when an init value is provided.
            return Scale(dim=dim, init_value=init_value) if init_value else nn.Identity()

        # Token-mixing sub-block.
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, drop=drop)
        self.drop_path1 = make_drop_path()
        self.layer_scale1 = make_scale(layer_scale_init_value)
        self.res_scale1 = make_scale(res_scale_init_value)

        # Channel-mixing (MLP) sub-block.
        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, drop=drop)
        self.drop_path2 = make_drop_path()
        self.layer_scale2 = make_scale(layer_scale_init_value)
        self.res_scale2 = make_scale(res_scale_init_value)

    def forward(self, x):
        mixed = self.token_mixer(self.norm1(x))
        x = self.res_scale1(x) + self.layer_scale1(self.drop_path1(mixed))
        expanded = self.mlp(self.norm2(x))
        x = self.res_scale2(x) + self.layer_scale2(self.drop_path2(expanded))
        return x
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
r"""
|
| 525 |
+
downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2
|
| 526 |
+
downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1
|
| 527 |
+
DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling]
|
| 528 |
+
use `partial` to specify some arguments
|
| 529 |
+
"""
|
| 530 |
+
DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling,
|
| 531 |
+
kernel_size=7, stride=4, padding=2,
|
| 532 |
+
post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6)
|
| 533 |
+
)] + \
|
| 534 |
+
[partial(Downsampling,
|
| 535 |
+
kernel_size=3, stride=2, padding=1,
|
| 536 |
+
pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True
|
| 537 |
+
)]*3
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
class MetaFormer(nn.Module):
    r""" MetaFormer
        A PyTorch impl of : `MetaFormer Baselines for Vision` -
          https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
        dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512].
        downsample_layers: (list or tuple): Downsampling layers before each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
        mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False).
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        head_dropout (float): dropout for MLP classifier. Default: 0.
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None.
            None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0].
            None means not use the layer scale. From: https://arxiv.org/abs/2110.09456.
        output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6).
        head_fn: classification head. Default: nn.Linear.
    """
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES,
                 token_mixers=nn.Identity,
                 mlps=Mlp,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6), # partial(LayerNormGeneral, eps=1e-6, bias=False),
                 drop_path_rate=0.,
                 head_dropout=0.0,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 output_norm=partial(nn.LayerNorm, eps=1e-6),
                 head_fn=nn.Linear,
                 **kwargs,
                 ):
        super().__init__()
        self.num_classes = num_classes

        # Scalar args are promoted to one-element lists: a single-stage model.
        if not isinstance(depths, (list, tuple)):
            depths = [depths] # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        num_stage = len(depths)
        self.num_stage = num_stage

        # A single downsample-layer factory is broadcast to every stage.
        if not isinstance(downsample_layers, (list, tuple)):
            downsample_layers = [downsample_layers] * num_stage
        # Channel progression: in_chans -> dims[0] -> ... -> dims[-1].
        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)]
        )

        # Likewise broadcast per-stage factories when a single one is given.
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage

        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage

        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        # Per-block stochastic-depth rates, linearly increasing over all blocks.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks
        cur = 0  # running block index into dp_rates
        for i in range(num_stage):
            stage = nn.Sequential(
                *[MetaFormerBlock(dim=dims[i],
                token_mixer=token_mixers[i],
                mlp=mlps[i],
                norm_layer=norm_layers[i],
                drop_path=dp_rates[cur + j],
                layer_scale_init_value=layer_scale_init_values[i],
                res_scale_init_value=res_scale_init_values[i],
                ) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = output_norm(dims[-1])

        # head_fn is assumed to accept a head_dropout kwarg when dropout is
        # requested (MlpHead does; plain nn.Linear does not) — hence the branch.
        if head_dropout > 0.0:
            self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout)
        else:
            self.head = head_fn(dims[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; zero their biases.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): returns the attribute name to be excluded from weight
        # decay, presumably consumed by a timm-style optimizer setup — confirm.
        return {'norm'}

    def forward_features(self, x):
        # Each stage: downsample, then its stack of MetaFormer blocks.
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        # Global average pool over the spatial dims, then final norm.
        return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
@register_model
def identityformer_s12(pretrained=False, **kwargs):
    """IdentityFormer-S12: identity token mixer, depths [2, 2, 6, 2]."""
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['identityformer_s12']
    model.default_cfg = cfg
    if pretrained:
        # Load the released checkpoint from the hub.
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
@register_model
def identityformer_s24(pretrained=False, **kwargs):
    """IdentityFormer-S24: identity token mixer, depths [4, 4, 12, 4]."""
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['identityformer_s24']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
@register_model
def identityformer_s36(pretrained=False, **kwargs):
    """IdentityFormer-S36: identity token mixer, depths [6, 6, 18, 6]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['identityformer_s36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
@register_model
def identityformer_m36(pretrained=False, **kwargs):
    """IdentityFormer-M36: identity token mixer, wider dims [96, 192, 384, 768]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['identityformer_m36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
@register_model
def identityformer_m48(pretrained=False, **kwargs):
    """IdentityFormer-M48: identity token mixer, depths [8, 8, 24, 8]."""
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=nn.Identity,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['identityformer_m48']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
@register_model
def randformer_s12(pretrained=False, **kwargs):
    """RandFormer-S12: random mixing in the last two stages (49 tokens in stage 4)."""
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['randformer_s12']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
@register_model
def randformer_s24(pretrained=False, **kwargs):
    """RandFormer-S24: random mixing in the last two stages, depths [4, 4, 12, 4]."""
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['randformer_s24']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 770 |
+
|
| 771 |
+
|
| 772 |
+
@register_model
def randformer_s36(pretrained=False, **kwargs):
    """RandFormer-S36: random mixing in the last two stages, depths [6, 6, 18, 6]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['randformer_s36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
@register_model
def randformer_m36(pretrained=False, **kwargs):
    """RandFormer-M36: random mixing in the last two stages, dims [96, 192, 384, 768]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['randformer_m36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
@register_model
def randformer_m48(pretrained=False, **kwargs):
    """RandFormer-M48: random mixing in the last two stages, depths [8, 8, 24, 8]."""
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)],
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['randformer_m48']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
@register_model
def poolformerv2_s12(pretrained=False, **kwargs):
    """PoolFormerV2-S12: average-pooling token mixer, depths [2, 2, 6, 2]."""
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['poolformerv2_s12']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 835 |
+
|
| 836 |
+
|
| 837 |
+
@register_model
def poolformerv2_s24(pretrained=False, **kwargs):
    """PoolFormerV2-S24: average-pooling token mixer, depths [4, 4, 12, 4]."""
    model = MetaFormer(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['poolformerv2_s24']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
@register_model
def poolformerv2_s36(pretrained=False, **kwargs):
    """PoolFormerV2-S36: average-pooling token mixer, depths [6, 6, 18, 6]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[64, 128, 320, 512],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['poolformerv2_s36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 867 |
+
|
| 868 |
+
|
| 869 |
+
@register_model
def poolformerv2_m36(pretrained=False, **kwargs):
    """PoolFormerV2-M36: average-pooling token mixer, dims [96, 192, 384, 768]."""
    model = MetaFormer(
        depths=[6, 6, 18, 6],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['poolformerv2_m36']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
@register_model
def poolformerv2_m48(pretrained=False, **kwargs):
    """PoolFormerV2-M48: average-pooling token mixer, depths [8, 8, 24, 8]."""
    model = MetaFormer(
        depths=[8, 8, 24, 8],
        dims=[96, 192, 384, 768],
        token_mixers=Pooling,
        norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False),
        **kwargs)
    cfg = default_cfgs['poolformerv2_m48']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
@register_model
def convformer_s18(pretrained=False, **kwargs):
    """ConvFormer-S18: separable-conv token mixer, MLP classification head."""
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    cfg = default_cfgs['convformer_s18']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
@register_model
def convformer_s18_384(pretrained=False, **kwargs):
    """ConvFormer-S18 fine-tuned at 384x384 resolution."""
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    cfg = default_cfgs['convformer_s18_384']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
@register_model
def convformer_s18_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-S18 pretrained on ImageNet-21k, fine-tuned on ImageNet-1k."""
    model = MetaFormer(
        depths=[3, 3, 9, 3],
        dims=[64, 128, 320, 512],
        token_mixers=SepConv,
        head_fn=MlpHead,
        **kwargs)
    cfg = default_cfgs['convformer_s18_in21ft1k']
    model.default_cfg = cfg
    if pretrained:
        weights = torch.hub.load_state_dict_from_url(
            url=cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(weights)
    return model
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
@register_model
|
| 950 |
+
def convformer_s18_384_in21ft1k(pretrained=False, **kwargs):
|
| 951 |
+
model = MetaFormer(
|
| 952 |
+
depths=[3, 3, 9, 3],
|
| 953 |
+
dims=[64, 128, 320, 512],
|
| 954 |
+
token_mixers=SepConv,
|
| 955 |
+
head_fn=MlpHead,
|
| 956 |
+
**kwargs)
|
| 957 |
+
model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k']
|
| 958 |
+
if pretrained:
|
| 959 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 960 |
+
url= model.default_cfg['url'], map_location="cpu", check_hash=True)
|
| 961 |
+
model.load_state_dict(state_dict)
|
| 962 |
+
return model
|
| 963 |
+
|
| 964 |
+
|
| 965 |
+
@register_model
|
| 966 |
+
def convformer_s18_in21k(pretrained=False, **kwargs):
|
| 967 |
+
model = MetaFormer(
|
| 968 |
+
depths=[3, 3, 9, 3],
|
| 969 |
+
dims=[64, 128, 320, 512],
|
| 970 |
+
token_mixers=SepConv,
|
| 971 |
+
head_fn=MlpHead,
|
| 972 |
+
**kwargs)
|
| 973 |
+
model.default_cfg = default_cfgs['convformer_s18_in21k']
|
| 974 |
+
if pretrained:
|
| 975 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 976 |
+
url= model.default_cfg['url'], map_location="cpu", check_hash=True)
|
| 977 |
+
model.load_state_dict(state_dict)
|
| 978 |
+
return model
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_s36(pretrained=False, **kwargs):
    """ConvFormer-S36: SepConv token mixer, deeper 3/12/18/3 stage layout."""
    return _metaformer_model('convformer_s36', [3, 12, 18, 3],
                             [64, 128, 320, 512], SepConv, pretrained, **kwargs)


@register_model
def convformer_s36_384(pretrained=False, **kwargs):
    """ConvFormer-S36 trained at 384x384 resolution."""
    return _metaformer_model('convformer_s36_384', [3, 12, 18, 3],
                             [64, 128, 320, 512], SepConv, pretrained, **kwargs)


@register_model
def convformer_s36_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-S36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('convformer_s36_in21ft1k', [3, 12, 18, 3],
                             [64, 128, 320, 512], SepConv, pretrained, **kwargs)


@register_model
def convformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-S36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('convformer_s36_384_in21ft1k', [3, 12, 18, 3],
                             [64, 128, 320, 512], SepConv, pretrained, **kwargs)


@register_model
def convformer_s36_in21k(pretrained=False, **kwargs):
    """ConvFormer-S36 pretrained on ImageNet-21k."""
    return _metaformer_model('convformer_s36_in21k', [3, 12, 18, 3],
                             [64, 128, 320, 512], SepConv, pretrained, **kwargs)
|
| 1059 |
+
|
| 1060 |
+
|
| 1061 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_m36(pretrained=False, **kwargs):
    """ConvFormer-M36: SepConv token mixer, wider 96-576 channel layout."""
    return _metaformer_model('convformer_m36', [3, 12, 18, 3],
                             [96, 192, 384, 576], SepConv, pretrained, **kwargs)


@register_model
def convformer_m36_384(pretrained=False, **kwargs):
    """ConvFormer-M36 trained at 384x384 resolution."""
    return _metaformer_model('convformer_m36_384', [3, 12, 18, 3],
                             [96, 192, 384, 576], SepConv, pretrained, **kwargs)


@register_model
def convformer_m36_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-M36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('convformer_m36_in21ft1k', [3, 12, 18, 3],
                             [96, 192, 384, 576], SepConv, pretrained, **kwargs)


@register_model
def convformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-M36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('convformer_m36_384_in21ft1k', [3, 12, 18, 3],
                             [96, 192, 384, 576], SepConv, pretrained, **kwargs)


@register_model
def convformer_m36_in21k(pretrained=False, **kwargs):
    """ConvFormer-M36 pretrained on ImageNet-21k."""
    return _metaformer_model('convformer_m36_in21k', [3, 12, 18, 3],
                             [96, 192, 384, 576], SepConv, pretrained, **kwargs)
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def convformer_b36(pretrained=False, **kwargs):
    """ConvFormer-B36: SepConv token mixer, widest 128-768 channel layout."""
    return _metaformer_model('convformer_b36', [3, 12, 18, 3],
                             [128, 256, 512, 768], SepConv, pretrained, **kwargs)


@register_model
def convformer_b36_384(pretrained=False, **kwargs):
    """ConvFormer-B36 trained at 384x384 resolution."""
    return _metaformer_model('convformer_b36_384', [3, 12, 18, 3],
                             [128, 256, 512, 768], SepConv, pretrained, **kwargs)


@register_model
def convformer_b36_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-B36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('convformer_b36_in21ft1k', [3, 12, 18, 3],
                             [128, 256, 512, 768], SepConv, pretrained, **kwargs)


@register_model
def convformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    """ConvFormer-B36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('convformer_b36_384_in21ft1k', [3, 12, 18, 3],
                             [128, 256, 512, 768], SepConv, pretrained, **kwargs)


@register_model
def convformer_b36_in21k(pretrained=False, **kwargs):
    """ConvFormer-B36 pretrained on ImageNet-21k."""
    return _metaformer_model('convformer_b36_in21k', [3, 12, 18, 3],
                             [128, 256, 512, 768], SepConv, pretrained, **kwargs)
|
| 1219 |
+
|
| 1220 |
+
|
| 1221 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s18(pretrained=False, **kwargs):
    """CAFormer-S18: SepConv in the first two stages, Attention in the last two."""
    return _metaformer_model('caformer_s18', [3, 3, 9, 3], [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s18_384(pretrained=False, **kwargs):
    """CAFormer-S18 trained at 384x384 resolution."""
    return _metaformer_model('caformer_s18_384', [3, 3, 9, 3], [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s18_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-S18, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('caformer_s18_in21ft1k', [3, 3, 9, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s18_384_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-S18 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('caformer_s18_384_in21ft1k', [3, 3, 9, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s18_in21k(pretrained=False, **kwargs):
    """CAFormer-S18 pretrained on ImageNet-21k."""
    return _metaformer_model('caformer_s18_in21k', [3, 3, 9, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)
|
| 1299 |
+
|
| 1300 |
+
|
| 1301 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_s36(pretrained=False, **kwargs):
    """CAFormer-S36: conv/conv/attn/attn mixers, deeper 3/12/18/3 layout."""
    return _metaformer_model('caformer_s36', [3, 12, 18, 3], [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s36_384(pretrained=False, **kwargs):
    """CAFormer-S36 trained at 384x384 resolution."""
    return _metaformer_model('caformer_s36_384', [3, 12, 18, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s36_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-S36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('caformer_s36_in21ft1k', [3, 12, 18, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s36_384_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-S36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('caformer_s36_384_in21ft1k', [3, 12, 18, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_s36_in21k(pretrained=False, **kwargs):
    """CAFormer-S36 pretrained on ImageNet-21k."""
    return _metaformer_model('caformer_s36_in21k', [3, 12, 18, 3],
                             [64, 128, 320, 512],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)
|
| 1379 |
+
|
| 1380 |
+
|
| 1381 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_m36(pretrained=False, **kwargs):
    """CAFormer-M36: conv/conv/attn/attn mixers, 96-576 channel layout."""
    return _metaformer_model('caformer_m36', [3, 12, 18, 3], [96, 192, 384, 576],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_m36_384(pretrained=False, **kwargs):
    """CAFormer-M36 trained at 384x384 resolution."""
    return _metaformer_model('caformer_m36_384', [3, 12, 18, 3],
                             [96, 192, 384, 576],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_m36_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-M36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('caformer_m36_in21ft1k', [3, 12, 18, 3],
                             [96, 192, 384, 576],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_m36_384_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-M36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('caformer_m36_384_in21ft1k', [3, 12, 18, 3],
                             [96, 192, 384, 576],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_m364_in21k(pretrained=False, **kwargs):
    """CAFormer-M36 pretrained on ImageNet-21k.

    NOTE(review): the name looks like a typo for ``caformer_m36_in21k``, but
    both the function name and the ``default_cfgs`` key use 'm364'; kept
    as-is so the model registry and existing callers keep working.
    """
    return _metaformer_model('caformer_m364_in21k', [3, 12, 18, 3],
                             [96, 192, 384, 576],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)
|
| 1459 |
+
|
| 1460 |
+
|
| 1461 |
+
def _metaformer_model(cfg_key, depths, dims, token_mixers, pretrained, **kwargs):
    """Build a MetaFormer with an MlpHead classifier, attach its default
    config, and optionally load the pretrained checkpoint it points to.

    :param cfg_key: key into the module-level ``default_cfgs`` dict
    :param depths: number of blocks in each of the four stages
    :param dims: channel width of each stage
    :param token_mixers: token-mixer class, or a per-stage list of classes
    :param pretrained: when True, download ``default_cfgs[cfg_key]['url']``
        via torch.hub (hash-checked, loaded on CPU) into the model
    """
    model = MetaFormer(
        depths=depths,
        dims=dims,
        token_mixers=token_mixers,
        head_fn=MlpHead,
        **kwargs)
    model.default_cfg = default_cfgs[cfg_key]
    if pretrained:
        state_dict = torch.hub.load_state_dict_from_url(
            url=model.default_cfg['url'], map_location="cpu", check_hash=True)
        model.load_state_dict(state_dict)
    return model


@register_model
def caformer_b36(pretrained=False, **kwargs):
    """CAFormer-B36: conv/conv/attn/attn mixers, widest 128-768 layout."""
    return _metaformer_model('caformer_b36', [3, 12, 18, 3], [128, 256, 512, 768],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_b36_384(pretrained=False, **kwargs):
    """CAFormer-B36 trained at 384x384 resolution."""
    return _metaformer_model('caformer_b36_384', [3, 12, 18, 3],
                             [128, 256, 512, 768],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_b36_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-B36, ImageNet-21k pretrained, finetuned on ImageNet-1k."""
    return _metaformer_model('caformer_b36_in21ft1k', [3, 12, 18, 3],
                             [128, 256, 512, 768],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_b36_384_in21ft1k(pretrained=False, **kwargs):
    """CAFormer-B36 at 384x384, ImageNet-21k pretrained, finetuned on 1k."""
    return _metaformer_model('caformer_b36_384_in21ft1k', [3, 12, 18, 3],
                             [128, 256, 512, 768],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)


@register_model
def caformer_b36_in21k(pretrained=False, **kwargs):
    """CAFormer-B36 pretrained on ImageNet-21k."""
    return _metaformer_model('caformer_b36_in21k', [3, 12, 18, 3],
                             [128, 256, 512, 768],
                             [SepConv, SepConv, Attention, Attention],
                             pretrained, **kwargs)
|
models/neuron.py
ADDED
|
@@ -0,0 +1,1587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from typing import Callable, overload
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from spikingjelly.clock_driven import surrogate, base, lava_exchange
|
| 6 |
+
from spikingjelly import configure
|
| 7 |
+
import math
|
| 8 |
+
import numpy as np
|
| 9 |
+
import logging
|
| 10 |
+
import cupy
|
| 11 |
+
from spikingjelly.clock_driven import neuron_kernel, cu_kernel_opt
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
import lava.lib.dl.slayer as slayer
|
| 16 |
+
|
| 17 |
+
except BaseException as e:
|
| 18 |
+
logging.info(f'spikingjelly.clock_driven.neuron: {e}')
|
| 19 |
+
slayer = None
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def check_backend(backend: str):
    """
    Validate that *backend* is a supported compute backend and that its
    optional dependency is importable.

    :param backend: one of ``'torch'``, ``'cupy'`` or ``'lava'``
    :raises NotImplementedError: if *backend* is not a known backend
    :raises AssertionError: if the backend's dependency is not installed
    """
    # 'torch' is always available; nothing to verify.
    if backend == 'torch':
        return
    if backend == 'cupy':
        assert cupy is not None, 'CuPy is not installed! You can install it from "https://github.com/cupy/cupy".'
    elif backend == 'lava':
        assert slayer is not None, 'Lava-DL is not installed! You can install it from "https://github.com/lava-nc/lava-dl".'
    else:
        raise NotImplementedError(backend)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class BaseNode(base.MemoryModule):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        Base class of differentiable spiking neurons.

        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float

        :param v_reset: reset voltage. If not ``None``, the voltage of neurons that
            just fired is set to ``v_reset`` (hard reset); if ``None``,
            ``v_threshold`` is subtracted instead (soft reset)
        :type v_reset: float

        :param surrogate_function: surrogate function replacing the gradient of the
            spiking (Heaviside) function during back-propagation
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        :type detach_reset: bool
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # 'v' is stateful memory; its initial value doubles as the reset target
        # for soft reset (0.) or hard reset (v_reset).
        if v_reset is None:
            self.register_memory('v', 0.)
        else:
            self.register_memory('v', v_reset)

        self.register_memory('v_threshold', v_threshold)
        self.register_memory('v_reset', v_reset)

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
        Charge-step difference equation of the neuron. Sub-classes must
        implement this to update ``self.v`` from the input ``x``.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        Compute the output spikes from the current membrane potential and the
        threshold voltage (via the surrogate function, so it is differentiable).
        """

        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        Reset the membrane potential according to the neurons' output spikes.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # soft reset
            self.v = self.v - spike_d * self.v_threshold

        else:
            # hard reset
            self.v = (1. - spike_d) * self.v + spike_d * self.v_reset

    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}'

    def forward(self, x: torch.Tensor):
        """
        Single-step forward in the order charge -> fire -> reset.

        :param x: increment of voltage injected into the neurons
        :type x: torch.Tensor

        :return: output spikes of the neurons
        :rtype: torch.Tensor
        """
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class AdaptiveBaseNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        Base class for spiking neurons with an adaptation current ``w``
        (AdEx-style dynamics).

        :param v_threshold: threshold voltage of the neurons
        :param v_reset: reset voltage; hard reset if not ``None``, soft reset otherwise
        :param v_rest: resting potential used by the subthreshold coupling term
        :param w_rest: initial (resting) value of the adaptation current ``w``
        :param tau_w: time constant of the adaptation current
        :param a: subthreshold coupling between the membrane potential and ``w``
        :param b: spike-triggered jump amplitude of ``w``
        :param surrogate_function: surrogate gradient function for the spike
        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        """
        # NOTE: defaults must be floats — the original shipped ``w_rest=0`` (an
        # int), which tripped its own isinstance assertion on default construction.
        assert isinstance(w_rest, float)
        assert isinstance(v_rest, float)
        assert isinstance(tau_w, float)
        assert isinstance(a, float)
        assert isinstance(b, float)

        # BUG FIX: was ``super.__init__(...)`` (missing call parentheses), which
        # raises at runtime instead of initializing BaseNode.
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        # adaptation current is stateful memory, reset alongside ``v``
        self.register_memory('w', w_rest)

        self.w_rest = w_rest
        self.v_rest = v_rest
        self.tau_w = tau_w
        self.a = a
        self.b = b

    def neuronal_adaptation(self, spike):
        """Update the adaptation current ``w`` from the membrane potential and spikes."""
        self.w = self.w + 1. / self.tau_w * (self.a * (self.v - self.v_rest) - self.w) + self.b * spike

    def extra_repr(self):
        return super().extra_repr() + f', v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}'

    # BUG FIX: the original decorated this with ``@typing.overload``, which turns
    # the method into a type-checking stub — calling it at runtime does not run
    # this body. This is the real implementation, so no decorator.
    def forward(self, x: torch.Tensor):
        """
        Forward in the order charge -> fire -> adapt -> reset.

        :param x: increment of voltage injected into the neurons
        :return: output spikes of the neurons
        """
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_adaptation(spike)
        self.neuronal_reset(spike)
        return spike
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class IFNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 cupy_fp32_inference=False):
        """
        The Integrate-and-Fire neuron, an ideal integrator: unlike the LIF
        neuron its voltage does not decay. Subthreshold dynamics:

        .. math::
            V[t] = V[t-1] + X[t]

        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float

        :param v_reset: reset voltage; hard reset if not ``None``, soft reset otherwise
        :type v_reset: float

        :param surrogate_function: surrogate gradient function for the spike
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        :type detach_reset: bool

        :param cupy_fp32_inference: if ``True``, inference in ``eval`` mode with
            float32 inputs on GPU (and ``cupy`` installed) is accelerated with a
            hand-written CuPy kernel
        :type cupy_fp32_inference: bool
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        if cupy_fp32_inference:
            check_backend('cupy')
        self.cupy_fp32_inference = cupy_fp32_inference

    def neuronal_charge(self, x: torch.Tensor):
        # ideal integrator: no leak term
        self.v = self.v + x

    def forward(self, x: torch.Tensor):
        if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
            # cupy is installed && eval mode && fp32
            device_id = x.get_device()
            if device_id < 0:
                # CPU tensor: fall back to the plain PyTorch path
                return super().forward(x)

            # use cupy to accelerate
            if isinstance(self.v, float):
                # v is still the scalar initial value: expand it to the input shape
                v = torch.zeros_like(x)
                if self.v != 0.:
                    torch.fill_(v, self.v)
                self.v = v

            if self.v_reset is None:
                hard_reset = False
            else:
                hard_reset = True

            # build the CUDA source; the reset style is baked into the kernel name
            code = rf'''
                extern "C" __global__
                void IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward(
                const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''}
                float * spike, float * v,
                const int & numel)
            '''

            code += r'''
                {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < numel)
                {
                    v[index] += x[index];
                    spike[index] = (float) (v[index] >= v_threshold);
            '''

            code += rf'''
                    {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
            '''

            code += r'''
                }
                }
            '''
            if hasattr(self, 'cp_kernel'):
                if self.cp_kernel.code != code:
                    # kernel configuration changed (e.g. reset mode): rebuild
                    del self.cp_kernel
                    self.cp_kernel = cupy.RawKernel(code,
                                                    f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward",
                                                    options=configure.cuda_compiler_options,
                                                    backend=configure.cuda_compiler_backend)
            else:
                self.cp_kernel = cupy.RawKernel(code,
                                                f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward",
                                                options=configure.cuda_compiler_options,
                                                backend=configure.cuda_compiler_backend)

            with cu_kernel_opt.DeviceEnvironment(device_id):
                numel = x.numel()
                threads = configure.cuda_threads
                blocks = cu_kernel_opt.cal_blocks(numel)
                cp_numel = cupy.asarray(numel)
                cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
                if hard_reset:
                    cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)

                spike = torch.zeros_like(x)
                if hard_reset:
                    x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x,
                                                                                                         cp_v_threshold,
                                                                                                         cp_v_reset,
                                                                                                         spike, self.v,
                                                                                                         cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel]
                else:
                    x, cp_v_threshold, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, spike,
                                                                                              self.v, cp_numel)
                    kernel_args = [x, cp_v_threshold, spike, self.v, cp_numel]
                self.cp_kernel(
                    (blocks,), (threads,),
                    cu_kernel_opt.wrap_args_to_raw_kernel(
                        device_id,
                        *kernel_args
                    )
                )
                return spike
        else:
            return super().forward(x)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
class MultiStepIFNode(IFNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch',
                 lava_s_cale=1 << 6):
        """
        The multi-step version of :class:`IFNode`. The input is
        ``x_seq.shape = [T, *]``; besides ``.v`` and ``.spike`` (state at
        ``t = T - 1``), ``.v_seq`` holds the membrane potential at all ``T``
        time-steps.

        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float

        :param v_reset: reset voltage; hard reset if not ``None``, soft reset otherwise
        :type v_reset: float

        :param surrogate_function: surrogate gradient function for the spike
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        :type detach_reset: bool

        :param backend: which backend to use: ``'torch'``, ``'cupy'`` (faster,
            GPU only) or ``'lava'``
        :type backend: str

        :param lava_s_cale: scale factor used when exporting this neuron to Lava-DL
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)

        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

        self.lava_s_cale = lava_s_cale

        if backend == 'lava':
            self.lava_neuron = self.to_lava()
        else:
            self.lava_neuron = None

    def forward(self, x_seq: torch.Tensor):
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # plain Python loop over the time dimension, one single-step call per t
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            if isinstance(self.v, float):
                # v is still the scalar initial value: expand it to the input shape
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # fused CUDA kernel over all T steps (custom autograd Function)
            spike_seq, self.v_seq = neuron_kernel.MultiStepIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.detach_reset,
                self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            self.v = self.v_seq[-1].clone()

            return spike_seq

        elif self.backend == 'lava':
            if self.lava_neuron is None:
                self.lava_neuron = self.to_lava()

            spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

            return spike

        else:
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'

    def to_lava(self):
        # export this neuron as a Lava-DL (slayer) neuron
        return lava_exchange.to_lava_neuron(self)

    def reset(self):
        super().reset()
        # the lava neuron keeps its own state tensors; clear them too
        if self.lava_neuron is not None:
            self.lava_neuron.current_state.zero_()
            self.lava_neuron.voltage_state.zero_()
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class LIFNode(BaseNode):
    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, cupy_fp32_inference=False):
        """
        The Leaky Integrate-and-Fire neuron, a leaky integrator.
        Subthreshold dynamics:

        If ``decay_input == True``:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        If ``decay_input == False``:

        .. math::
            V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        :param tau: membrane time constant (must be > 1)
        :type tau: float

        :param decay_input: whether the input is also divided by ``tau``
        :type decay_input: bool

        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float

        :param v_reset: reset voltage; hard reset if not ``None``, soft reset otherwise
        :type v_reset: float

        :param surrogate_function: surrogate gradient function for the spike
        :type surrogate_function: Callable

        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        :type detach_reset: bool

        :param cupy_fp32_inference: if ``True``, inference in ``eval`` mode with
            float32 inputs on GPU (and ``cupy`` installed) is accelerated with a
            hand-written CuPy kernel
        :type cupy_fp32_inference: bool
        """
        assert isinstance(tau, float) and tau > 1.

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.decay_input = decay_input

        if cupy_fp32_inference:
            check_backend('cupy')
        self.cupy_fp32_inference = cupy_fp32_inference

    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}'

    def neuronal_charge(self, x: torch.Tensor):
        # v_reset == 0. (or soft reset) lets the reset term drop out of the update
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v + (x - self.v) / self.tau
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) / self.tau

        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v * (1. - 1. / self.tau) + x
            else:
                self.v = self.v - (self.v - self.v_reset) / self.tau + x

    def forward(self, x: torch.Tensor):
        if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32:
            # cupy is installed && eval mode && fp32
            device_id = x.get_device()
            if device_id < 0:
                # CPU tensor: fall back to the plain PyTorch path
                return super().forward(x)

            # use cupy to accelerate
            if isinstance(self.v, float):
                # v is still the scalar initial value: expand it to the input shape
                v = torch.zeros_like(x)
                if self.v != 0.:
                    torch.fill_(v, self.v)
                self.v = v

            if self.v_reset is None:
                hard_reset = False
            else:
                hard_reset = True

            # build the CUDA source; reset style and decay mode are baked into the name
            code = rf'''
                extern "C" __global__
                void LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward(
                const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''} const float & tau,
                float * spike, float * v,
                const int & numel)
            '''

            code += r'''
                {
                const int index = blockIdx.x * blockDim.x + threadIdx.x;
                if (index < numel)
                {

            '''

            if self.decay_input:
                if hard_reset:
                    code += r'''
                        v[index] += (x[index] - (v[index] - v_reset)) / tau;
                    '''
                else:
                    code += r'''
                        v[index] += (x[index] - v[index]) / tau;
                    '''
            else:
                if hard_reset:
                    code += r'''
                        v[index] = x[index] + v[index] - (v[index] - v_reset) / tau;
                    '''
                else:
                    code += r'''
                        v[index] = x[index] + v[index] * (1.0f - 1.0f / tau);
                    '''

            code += rf'''
                spike[index] = (float) (v[index] >= v_threshold);
                {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'}
            '''

            code += r'''
                }
                }
            '''
            if hasattr(self, 'cp_kernel'):
                if self.cp_kernel.code != code:
                    # kernel configuration changed (reset/decay mode): rebuild
                    del self.cp_kernel
                    self.cp_kernel = cupy.RawKernel(code,
                                                    f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward",
                                                    options=configure.cuda_compiler_options,
                                                    backend=configure.cuda_compiler_backend)
            else:
                self.cp_kernel = cupy.RawKernel(code,
                                                f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward",
                                                options=configure.cuda_compiler_options,
                                                backend=configure.cuda_compiler_backend)

            with cu_kernel_opt.DeviceEnvironment(device_id):
                numel = x.numel()
                threads = configure.cuda_threads
                blocks = cu_kernel_opt.cal_blocks(numel)
                cp_numel = cupy.asarray(numel)
                cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32)
                if hard_reset:
                    cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32)
                cp_tau = cupy.asarray(self.tau, dtype=np.float32)
                spike = torch.zeros_like(x)
                if hard_reset:
                    x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x,
                                                                                                                  cp_v_threshold,
                                                                                                                  cp_v_reset,
                                                                                                                  cp_tau,
                                                                                                                  spike,
                                                                                                                  self.v,
                                                                                                                  cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel]
                else:
                    x, cp_v_threshold, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold,
                                                                                                      cp_tau, spike,
                                                                                                      self.v, cp_numel)
                    kernel_args = [x, cp_v_threshold, cp_tau, spike, self.v, cp_numel]

                self.cp_kernel(
                    (blocks,), (threads,),
                    cu_kernel_opt.wrap_args_to_raw_kernel(
                        device_id,
                        *kernel_args
                    )
                )
            return spike
        else:
            return super().forward(x)
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
class MultiStepLIFNode(LIFNode):
|
| 765 |
+
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
|
| 766 |
+
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
|
| 767 |
+
detach_reset: bool = False, backend='torch', lava_s_cale=1 << 6):
|
| 768 |
+
"""
|
| 769 |
+
* :ref:`API in English <MultiStepLIFNode.__init__-en>`
|
| 770 |
+
|
| 771 |
+
.. _MultiStepLIFNode.__init__-cn:
|
| 772 |
+
|
| 773 |
+
:param tau: 膜电位时间常数
|
| 774 |
+
:type tau: float
|
| 775 |
+
|
| 776 |
+
:param decay_input: 输入是否会衰减
|
| 777 |
+
:type decay_input: bool
|
| 778 |
+
|
| 779 |
+
:param v_threshold: 神经元的阈值电压
|
| 780 |
+
:type v_threshold: float
|
| 781 |
+
|
| 782 |
+
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
|
| 783 |
+
如果设置为 ``None``,则电压会被减去 ``v_threshold``
|
| 784 |
+
:type v_reset: float
|
| 785 |
+
|
| 786 |
+
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
|
| 787 |
+
:type surrogate_function: Callable
|
| 788 |
+
|
| 789 |
+
:param detach_reset: 是否将reset过程的计算图分离
|
| 790 |
+
:type detach_reset: bool
|
| 791 |
+
|
| 792 |
+
:param backend: 使用哪种计算后端,可以为 ``'torch'`` 或 ``'cupy'``。``'cupy'`` 速度更快,但仅支持GPU。
|
| 793 |
+
:type backend: str
|
| 794 |
+
|
| 795 |
+
多步版本的 :class:`spikingjelly.clock_driven.neuron.LIFNode`。
|
| 796 |
+
|
| 797 |
+
.. tip::
|
| 798 |
+
|
| 799 |
+
对于多步神经元,输入 ``x_seq.shape = [T, *]``,不仅可以使用 ``.v`` 和 ``.spike`` 获取 ``t = T - 1`` 时刻的电压和脉冲,还能够
|
| 800 |
+
使用 ``.v_seq`` 和 ``.spike_seq`` 获取完整的 ``T`` 个时刻的电压和脉冲。
|
| 801 |
+
|
| 802 |
+
.. tip::
|
| 803 |
+
|
| 804 |
+
阅读 :doc:`传播模式 <./clock_driven/10_propagation_pattern>` 以获取更多关于单步和多步传播的信息。
|
| 805 |
+
|
| 806 |
+
* :ref:`中文API <MultiStepLIFNode.__init__-cn>`
|
| 807 |
+
|
| 808 |
+
.. _MultiStepLIFNode.__init__-en:
|
| 809 |
+
|
| 810 |
+
:param tau: membrane time constant
|
| 811 |
+
:type tau: float
|
| 812 |
+
|
| 813 |
+
:param decay_input: whether the input will decay
|
| 814 |
+
:type decay_input: bool
|
| 815 |
+
|
| 816 |
+
:param v_threshold: threshold voltage of neurons
|
| 817 |
+
:type v_threshold: float
|
| 818 |
+
|
| 819 |
+
:param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to
|
| 820 |
+
``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold``
|
| 821 |
+
:type v_reset: float
|
| 822 |
+
|
| 823 |
+
:param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation
|
| 824 |
+
:type surrogate_function: Callable
|
| 825 |
+
|
| 826 |
+
:param detach_reset: whether detach the computation graph of reset
|
| 827 |
+
:type detach_reset: bool
|
| 828 |
+
|
| 829 |
+
:param backend: use which backend, ``'torch'`` or ``'cupy'``. ``'cupy'`` is faster but only supports GPU
|
| 830 |
+
:type backend: str
|
| 831 |
+
|
| 832 |
+
The multi-step version of :class:`spikingjelly.clock_driven.neuron.LIFNode`.
|
| 833 |
+
|
| 834 |
+
.. admonition:: Tip
|
| 835 |
+
:class: tip
|
| 836 |
+
|
| 837 |
+
The input for multi-step neurons are ``x_seq.shape = [T, *]``. We can get membrane potential and spike at
|
| 838 |
+
time-step ``t = T - 1`` by ``.v`` and ``.spike``. We can also get membrane potential and spike at all ``T``
|
| 839 |
+
time-steps by ``.v_seq`` and ``.spike_seq``.
|
| 840 |
+
|
| 841 |
+
.. admonition:: Tip
|
| 842 |
+
:class: tip
|
| 843 |
+
|
| 844 |
+
Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step
|
| 845 |
+
and multi-step propagation.
|
| 846 |
+
|
| 847 |
+
"""
|
| 848 |
+
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
|
| 849 |
+
self.register_memory('v_seq', None)
|
| 850 |
+
|
| 851 |
+
check_backend(backend)
|
| 852 |
+
|
| 853 |
+
self.backend = backend
|
| 854 |
+
|
| 855 |
+
self.lava_s_cale = lava_s_cale
|
| 856 |
+
|
| 857 |
+
if backend == 'lava':
|
| 858 |
+
self.lava_neuron = self.to_lava()
|
| 859 |
+
else:
|
| 860 |
+
self.lava_neuron = None
|
| 861 |
+
|
| 862 |
+
    def forward(self, x_seq: torch.Tensor):
        """Run the multi-step LIF neuron over a whole input sequence.

        :param x_seq: input sequence of shape ``[T, *]``, where ``T`` is the
            number of time-steps
        :return: the spike sequence, same shape as ``x_seq`` for the
            ``'torch'``/``'cupy'`` backends

        After the call, ``self.v_seq`` holds the membrane potential at every
        time-step (``'torch'``/``'cupy'`` backends) and ``self.v`` the
        potential at the last time-step.
        """
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Pure-Python loop: run the single-step parent forward once per time-step.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            # The fused kernel needs a tensor state: promote a scalar initial
            # voltage to a tensor shaped like one time-step of the input.
            if isinstance(self.v, float):
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # Fused CUDA kernel over all T steps at once (project-local kernel).
            spike_seq, self.v_seq = neuron_kernel.MultiStepLIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.decay_input, self.tau, self.v_threshold, self.v_reset,
                self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # Keep the single-step state consistent: v is the last step's potential.
            self.v = self.v_seq[-1].clone()

            return spike_seq

        elif self.backend == 'lava':
            # Lazily (re)build the lava neuron if it has not been created yet.
            if self.lava_neuron is None:
                self.lava_neuron = self.to_lava()

            spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v)

            return spike

        else:
            raise NotImplementedError(self.backend)
|
| 904 |
+
|
| 905 |
+
def extra_repr(self):
|
| 906 |
+
return super().extra_repr() + f', backend={self.backend}'
|
| 907 |
+
|
| 908 |
+
    def to_lava(self):
        # Convert this neuron to its lava-dl equivalent via the project's
        # exchange helper (used by the 'lava' backend).
        return lava_exchange.to_lava_neuron(self)
|
| 910 |
+
|
| 911 |
+
def reset(self):
|
| 912 |
+
super().reset()
|
| 913 |
+
if self.lava_neuron is not None:
|
| 914 |
+
self.lava_neuron.current_state.zero_()
|
| 915 |
+
self.lava_neuron.voltage_state.zero_()
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
class ParametricLIFNode(BaseNode):
    def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False):
        """
        The Parametric Leaky Integrate-and-Fire (PLIF) neuron proposed in
        `Incorporating Learnable Membrane Time Constant to Enhance Learning of
        Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_, a leaky
        integrator whose membrane time constant is learned.

        Subthreshold dynamics, with :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`
        and :math:`w` a learnable parameter:

        if ``decay_input == True``:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        if ``decay_input == False``:

        .. math::
            V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        :param init_tau: initial value of the membrane time constant
        :type init_tau: float
        :param decay_input: whether the input decays
        :type decay_input: bool
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_reset: reset voltage. If not ``None``, neurons that fired are
            reset to ``v_reset``; if ``None``, ``v_threshold`` is subtracted
        :type v_reset: float
        :param surrogate_function: surrogate gradient for the spiking function
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph
        :type detach_reset: bool
        """
        assert isinstance(init_tau, float) and init_tau > 1.
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.decay_input = decay_input
        # Invert the sigmoid so that sigmoid(w) == 1 / init_tau at initialization.
        initial_w = - math.log(init_tau - 1.)
        self.w = nn.Parameter(torch.as_tensor(initial_w))

    def extra_repr(self):
        with torch.no_grad():
            # Current effective time constant recovered from the learnable weight.
            current_tau = 1. / self.w.sigmoid()
        return super().extra_repr() + f', tau={current_tau}'

    def neuronal_charge(self, x: torch.Tensor):
        decay = self.w.sigmoid()
        rest_is_zero = self.v_reset is None or self.v_reset == 0.
        if self.decay_input:
            if rest_is_zero:
                self.v = self.v + (x - self.v) * decay
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) * decay
        else:
            if rest_is_zero:
                self.v = self.v * (1. - decay) + x
            else:
                self.v = self.v - (self.v - self.v_reset) * decay + x
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
class MultiStepParametricLIFNode(ParametricLIFNode):
    def __init__(self, init_tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, backend='torch'):
        """
        The multi-step version of :class:`ParametricLIFNode`, the Parametric
        Leaky Integrate-and-Fire (PLIF) neuron proposed in `Incorporating
        Learnable Membrane Time Constant to Enhance Learning of Spiking Neural
        Networks <https://arxiv.org/abs/2007.05785>`_.

        Subthreshold dynamics:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)` and :math:`w` is a
        learnable parameter.

        :param init_tau: the initial value of the membrane time constant
        :type init_tau: float
        :param decay_input: whether the input will decay
        :type decay_input: bool
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_reset: reset voltage. If not ``None``, the voltage of neurons
            that fired is set to ``v_reset``; if ``None``, ``v_threshold`` is
            subtracted instead
        :type v_reset: float
        :param surrogate_function: surrogate function replacing the gradient of
            the spiking function during back-propagation
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the computation graph of the reset
        :type detach_reset: bool
        :param backend: which backend to use, ``'torch'`` or ``'cupy'``
            (``'cupy'`` is faster but only supports GPU)
        :type backend: str

        The input is ``x_seq.shape = [T, *]``. After ``forward``, ``.v`` holds
        the membrane potential at ``t = T - 1`` while ``.v_seq`` holds the
        potential at all ``T`` time-steps.
        """
        super().__init__(init_tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset)
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over the whole sequence; see the class docstring."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Python loop over time-steps using the single-step parent forward.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            # Promote a scalar initial voltage to a tensor for the fused kernel.
            if isinstance(self.v, float):
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # Fused CUDA kernel: note the learnable decay self.w.sigmoid() is
            # evaluated once here and passed by value to the kernel.
            spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.w.sigmoid(), self.decay_input, self.v_threshold, self.v_reset,
                self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # Keep single-step state consistent with the kernel output.
            self.v = self.v_seq[-1].clone()

            return spike_seq
        else:
            raise NotImplementedError

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
|
| 1159 |
+
|
| 1160 |
+
|
| 1161 |
+
class QIFNode(BaseNode):
    def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0.,
                 v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        The Quadratic Integrate-and-Fire (QIF) neuron, a nonlinear
        integrate-and-fire model that approximates the Exponential
        Integrate-and-Fire model. Subthreshold dynamics:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))

        :param tau: membrane time constant
        :type tau: float
        :param v_c: critical voltage
        :type v_c: float
        :param a0: strength of the quadratic term (must be positive)
        :type a0: float
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_rest: resting potential
        :type v_rest: float
        :param v_reset: reset voltage. If not ``None``, neurons that fired are
            reset to ``v_reset``; if ``None``, ``v_threshold`` is subtracted
        :type v_reset: float
        :param surrogate_function: surrogate gradient for the spiking function
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph
        :type detach_reset: bool
        """
        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            # The threshold must lie above the reset, and the resting point
            # must not be below the reset voltage.
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert a0 > 0

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.v_c = v_c
        self.v_rest = v_rest
        self.a0 = a0

    def extra_repr(self):
        details = f', tau={self.tau}, v_c={self.v_c}, a0={self.a0}, v_rest={self.v_rest}'
        return super().extra_repr() + details

    def neuronal_charge(self, x: torch.Tensor):
        # Quadratic drive term: a0 * (V - V_rest) * (V - V_c).
        quad = self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)
        self.v = self.v + (x + quad) / self.tau
|
| 1254 |
+
|
| 1255 |
+
|
| 1256 |
+
class EIFNode(BaseNode):
    def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
                 v_rest: float = 0., v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        The Exponential Integrate-and-Fire (EIF) neuron, a nonlinear
        integrate-and-fire model derived as a one-dimensional reduction of the
        Hodgkin-Huxley model. It degenerates to the LIF model as
        :math:`\\Delta_T\\to 0`. Subthreshold dynamics:

        .. math::
            V[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right)

        :param tau: membrane time constant
        :type tau: float
        :param delta_T: sharpness parameter (must be positive)
        :type delta_T: float
        :param theta_rh: rheobase threshold
        :type theta_rh: float
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_rest: resting potential
        :type v_rest: float
        :param v_reset: reset voltage. If not ``None``, neurons that fired are
            reset to ``v_reset``; if ``None``, ``v_threshold`` is subtracted
        :type v_reset: float
        :param surrogate_function: surrogate gradient for the spiking function
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph
        :type detach_reset: bool
        """
        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert delta_T > 0

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        self.tau = tau
        self.delta_T = delta_T
        self.v_rest = v_rest
        self.theta_rh = theta_rh

    def extra_repr(self):
        details = f', tau={self.tau}, delta_T={self.delta_T}, theta_rh={self.theta_rh}'
        return super().extra_repr() + details

    def neuronal_charge(self, x: torch.Tensor):
        with torch.no_grad():
            # Promote a scalar initial voltage to a tensor on the input's device
            # so the exponential below can operate on it.
            if not isinstance(self.v, torch.Tensor):
                self.v = torch.as_tensor(self.v, device=x.device)

        exp_term = self.delta_T * torch.exp((self.v - self.theta_rh) / self.delta_T)
        self.v = self.v + (x + self.v_rest - self.v + exp_term) / self.tau
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
class MultiStepEIFNode(EIFNode):
    def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
                 v_rest: float = 0., v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):
        """
        The multi-step version of :class:`EIFNode` (Exponential
        Integrate-and-Fire neuron).

        :param tau: membrane time constant
        :type tau: float
        :param delta_T: sharpness parameter
        :type delta_T: float
        :param theta_rh: rheobase threshold
        :type theta_rh: float
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_rest: resting potential
        :type v_rest: float
        :param v_reset: reset voltage. If not ``None``, the voltage of neurons
            that fired is set to ``v_reset``; if ``None``, ``v_threshold`` is
            subtracted instead
        :type v_reset: float
        :param surrogate_function: surrogate function replacing the gradient of
            the spiking function during back-propagation
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the computation graph of the reset
        :type detach_reset: bool
        :param backend: which backend to use, ``'torch'`` or ``'cupy'``
            (``'cupy'`` is faster but only supports GPU)
        :type backend: str

        The input is ``x_seq.shape = [T, *]``. After ``forward``, ``.v`` holds
        the membrane potential at ``t = T - 1`` while ``.v_seq`` holds the
        potential at all ``T`` time-steps.
        """
        super().__init__(tau, delta_T, theta_rh, v_threshold, v_rest, v_reset,
                         surrogate_function, detach_reset)
        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over the whole sequence; see the class docstring."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            # Python loop over time-steps using the single-step parent forward.
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq

        elif self.backend == 'cupy':
            # Promote a scalar initial voltage to a tensor for the fused kernel.
            if isinstance(self.v, float):
                v_init = self.v
                self.v = torch.zeros_like(x_seq[0].data)
                if v_init != 0.:
                    torch.fill_(self.v, v_init)

            # Fused CUDA kernel over all T steps at once (project-local kernel).
            spike_seq, self.v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
                self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code)

            spike_seq = spike_seq.reshape(x_seq.shape)
            self.v_seq = self.v_seq.reshape(x_seq.shape)

            # Keep single-step state consistent with the kernel output.
            self.v = self.v_seq[-1].clone()

            return spike_seq
        else:
            raise NotImplementedError

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
class GeneralNode(BaseNode):
    def __init__(self, a: float or torch.Tensor, b: float or torch.Tensor, c: float or torch.Tensor = 0.,
                 v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        """
        A generic linear neuron whose subthreshold dynamics are

        .. math::
            V[t] = a \\cdot V[t-1] + b \\cdot X[t] + c

        :param a: decay coefficient applied to the previous membrane potential
        :param b: gain applied to the input
        :param c: constant bias added at every step
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_reset: reset voltage. If not ``None``, neurons that fired are
            reset to ``v_reset``; if ``None``, ``v_threshold`` is subtracted
        :type v_reset: float
        :param surrogate_function: surrogate gradient for the spiking function
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph
        :type detach_reset: bool
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)
        # Bug fix: ``register_buffer`` returns None, so the previous
        # ``self.a = self.register_buffer('a', ...)`` immediately clobbered the
        # freshly registered buffer with None (and neuronal_charge would then
        # operate on None). Registering alone is sufficient: the buffers become
        # accessible as self.a / self.b / self.c.
        self.register_buffer('a', torch.as_tensor(a))
        self.register_buffer('b', torch.as_tensor(b))
        self.register_buffer('c', torch.as_tensor(c))

    def neuronal_charge(self, x: torch.Tensor):
        # Linear state update: V <- a*V + b*x + c.
        self.v = self.a * self.v + self.b * x + self.c
|
| 1504 |
+
|
| 1505 |
+
|
| 1506 |
+
class MultiStepGeneralNode(GeneralNode):
    def __init__(self, a: float, b: float, c: float, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'):
        """
        The multi-step version of :class:`GeneralNode`, the generic linear
        neuron ``V[t] = a*V[t-1] + b*X[t] + c``.

        :param a: decay coefficient applied to the previous membrane potential
        :param b: gain applied to the input
        :param c: constant bias added at every step
        :param v_threshold: threshold voltage of the neurons
        :type v_threshold: float
        :param v_reset: reset voltage. If not ``None``, neurons that fired are
            reset to ``v_reset``; if ``None``, ``v_threshold`` is subtracted
        :type v_reset: float
        :param surrogate_function: surrogate gradient for the spiking function
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph
        :type detach_reset: bool
        :param backend: which backend to use; only ``'torch'`` is implemented
        :type backend: str
        """
        # Bug fix: the previous code called
        # ``super().__init__(v_threshold, v_reset, surrogate_function, detach_reset)``
        # against GeneralNode's ``(a, b, c, v_threshold, v_reset, ...)``
        # signature, silently dropping a/b/c and shifting v_threshold into
        # ``a``, v_reset into ``b``, etc. Forward the arguments correctly.
        super().__init__(a, b, c, v_threshold, v_reset, surrogate_function, detach_reset)

        self.register_memory('v_seq', None)

        check_backend(backend)

        self.backend = backend

    def forward(self, x_seq: torch.Tensor):
        """Run the neuron over ``x_seq`` of shape ``[T, *]`` and return the
        spike sequence; ``self.v_seq`` keeps the per-step membrane potential."""
        assert x_seq.dim() > 1
        # x_seq.shape = [T, *]

        if self.backend == 'torch':
            spike_seq = []
            self.v_seq = []
            for t in range(x_seq.shape[0]):
                spike_seq.append(super().forward(x_seq[t]).unsqueeze(0))
                self.v_seq.append(self.v.unsqueeze(0))
            spike_seq = torch.cat(spike_seq, 0)
            self.v_seq = torch.cat(self.v_seq, 0)
            return spike_seq
        else:
            # No fused kernel exists for the general linear neuron; the old
            # 'cupy' branch raised NotImplementedError after some dead setup
            # code, so both non-torch paths collapse to one explicit error.
            raise NotImplementedError(self.backend)

    def extra_repr(self):
        return super().extra_repr() + f', backend={self.backend}'
|
| 1552 |
+
|
| 1553 |
+
|
| 1554 |
+
class LIAFNode(LIFNode):
    def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs):
        """
        The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire
        Network for Lightweight and Efficient Spatiotemporal Information
        Processing <https://arxiv.org/abs/2011.06176>`_.

        :param act: the activation function
        :type act: Callable
        :param threshold_related: whether the neuron uses threshold related
            (TR mode). If true, ``y = act(h - v_th)``, otherwise ``y = act(h)``
        :type threshold_related: bool

        Other parameters in ``*args, **kwargs`` are the same as :class:`LIFNode`.

        .. admonition:: Warning
            :class: warning

            The outputs of this neuron are not binary spikes.
        """
        super().__init__(*args, **kwargs)
        self.act = act
        self.threshold_related = threshold_related

    def forward(self, x: torch.Tensor):
        self.neuronal_charge(x)
        # TR mode shifts the membrane potential by the threshold before the
        # analog activation; either way the analog value is the output.
        pre_activation = self.v - self.v_threshold if self.threshold_related else self.v
        y = self.act(pre_activation)
        # Spiking and reset still run so the internal state evolves exactly
        # like a LIF neuron, but the spike itself is discarded.
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return y
|
| 1586 |
+
|
| 1587 |
+
|
models/q_vit/Quant.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch.nn.modules.linear import Linear
|
| 4 |
+
import math
|
| 5 |
+
from torch.nn.parameter import Parameter
|
| 6 |
+
from ._quan_base import _Conv2dQ, Qmodes, _LinearQ, _ActQ
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
__all__ = ['Conv2dQ', 'LinearQ', 'ActQ']
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FunQ(torch.autograd.Function):
    """LSQ-style fake quantization with a custom straight-through backward.

    forward:  w_q = clamp(round(w / alpha), Qn, Qp) * alpha
    backward: the gradient w.r.t. ``alpha`` follows the LSQ formulation
    (scaled by ``g``); the gradient w.r.t. ``weight`` passes through only
    where the quantized value was not clipped.
    """

    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp):
        # alpha is the learned step size; it must stay strictly positive.
        assert alpha > 0, 'alpha = {}'.format(alpha)
        ctx.save_for_backward(weight, alpha)
        ctx.other = g, Qn, Qp
        q_w = (weight / alpha).round().clamp(Qn, Qp)
        w_q = q_w * alpha
        return w_q

    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha = ctx.saved_tensors
        g, Qn, Qp = ctx.other
        q_w = weight / alpha
        indicate_small = (q_w < Qn).float()  # clipped below the range
        indicate_big = (q_w > Qp).float()    # clipped above the range
        # indicate_middle = torch.ones(indicate_small.shape).to(indicate_small.device) - indicate_small - indicate_big
        indicate_middle = 1.0 - indicate_small - indicate_big  # Thanks to @haolibai
        # d(w_q)/d(alpha): Qn/Qp on the clipped regions, (round(q) - q) in between.
        grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * (
                -q_w + q_w.round())) * grad_weight * g).sum().unsqueeze(dim=0)
        grad_weight = indicate_middle * grad_weight
        # The following operation can make sure that alpha is always greater than zero in any case and can also
        # suppress the update speed of alpha. (Personal understanding)
        # grad_alpha.clamp_(-alpha.item(), alpha.item())  # FYI
        return grad_weight, grad_alpha, None, None, None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def grad_scale(x, scale):
    """Identity in the forward pass; multiplies the gradient by ``scale``."""
    scaled = x * scale
    # Forward value equals x; only `scaled` stays on the autograd tape.
    return (x - scaled).detach() + scaled
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def round_pass(x):
    """Round to the nearest integer with a straight-through gradient."""
    # Forward: round(x); backward: identity (the detached delta has no grad).
    return (x.round() - x).detach() + x
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class Conv2dQ(_Conv2dQ):
    """2-D convolution with LSQ-quantized weights and activations.

    Weights are fake-quantized per output channel (kernel-wise) to
    ``nbits_w`` bits; the input is quantized by an internal ``ActQ``.
    When ``nbits_w < 0`` the base class registers ``alpha = None`` and the
    layer falls back to an ordinary float convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, nbits_w=8, mode=Qmodes.kernel_wise, **kwargs):
        super(Conv2dQ, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
            nbits=nbits_w, mode=mode)
        # Quantize the incoming activation with the same bit width as the weights.
        self.act = ActQ(in_features=in_channels, nbits_a=nbits_w)

    def forward(self, x):
        if self.alpha is None:
            # Quantization disabled (nbits < 0): plain float convolution.
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        # w_reshape = self.weight.reshape([self.weight.shape[0], -1]).transpose(0, 1)
        # Signed integer range for the chosen bit width.
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            # One-time LSQ initialization of the step size from weight statistics.
            # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1))
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            # self.alpha.data.copy_(self.weight.abs().max() * 2)
            self.init_state.fill_(1)
        """
        Implementation according to paper.
        Feels wrong ...
        When we initialize the alpha as a big number (e.g., self.weight.abs().max() * 2),
        the clamp function can be skipped.
        Then we get w_q = w / alpha * alpha = w, and $\frac{\partial w_q}{\partial \alpha} = 0$
        As a result, I don't think the pseudo-code in the paper echoes the formula.

        Please see jupyter/STE_LSQ.ipynb fo detailed comparison.
        """
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)  # LSQ gradient scale for alpha

        # Method1: 31GB GPU memory (AlexNet w4a4 bs 2048) 17min/epoch
        alpha = grad_scale(self.alpha, g)
        # print(alpha.shape)
        # print(self.weight.shape)
        # Broadcast the per-channel step over (out_ch, in_ch, kh, kw).
        alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha

        x = self.act(x)  # quantize the input activation
        # w = w.clamp(Qn, Qp)
        # q_w = round_pass(w)
        # w_q = q_w * alpha

        # Method2: 25GB GPU memory (AlexNet w4a4 bs 2048) 32min/epoch
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        # wq = y.transpose(0, 1).reshape(self.weight.shape).detach() + self.weight - self.weight.detach()
        return F.conv2d(x, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class LinearQ(_LinearQ):
    """Fully connected layer with LSQ-quantized weights and activations.

    Weights get a per-output-feature (kernel-wise) learned step size of
    ``nbits_w`` bits; the input passes through an internal ``ActQ``. With
    ``nbits_w < 0`` the base class disables quantization (``alpha is None``).
    """

    def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs):
        super(LinearQ, self).__init__(in_features=in_features,
                                      out_features=out_features, bias=bias, nbits=nbits_w, mode=Qmodes.kernel_wise)
        # Quantize the incoming activation with the same bit width as the weights.
        self.act = ActQ(in_features=in_features, nbits_a=nbits_w)

    def forward(self, x):
        if self.alpha is None:
            # Quantization disabled: plain float linear layer.
            return F.linear(x, self.weight, self.bias)
        # Signed integer range for the chosen bit width.
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            # One-time LSQ initialization of the step size from weight statistics.
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)  # LSQ gradient scale for alpha

        # Method1:
        alpha = grad_scale(self.alpha, g)
        alpha = alpha.unsqueeze(1)  # broadcast the per-output-feature step over weight columns
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha

        x = self.act(x)  # quantize the input activation
        # w = self.weight / alpha
        # w = w.clamp(Qn, Qp)
        # q_w = round_pass(w)
        # w_q = q_w * alpha

        # Method2:
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        return F.linear(x, w_q, self.bias)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class ActQ(_ActQ):
    """LSQ+ activation fake-quantizer with a learned step size and zero point.

    On the first training batch it detects whether the input is signed
    (contains negative values), picks the matching integer range, and
    initializes ``alpha``/``zero_point`` from the input statistics.
    """

    def __init__(self, in_features, nbits_a=4, mode=Qmodes.kernel_wise, **kwargs):
        super(ActQ, self).__init__(in_features=in_features, nbits=nbits_a, mode=mode)
        # print(self.alpha.shape, self.zero_point.shape)

    def forward(self, x):
        if self.alpha is None:
            # Quantization disabled (nbits < 0): pass through unchanged.
            return x

        if self.training and self.init_state == 0:
            # The init alpha for activation is very very important as the experimental results shows.
            # Please select a init_rate for activation.
            # self.alpha.data.copy_(x.max() / 2 ** (self.nbits - 1) * self.init_rate)
            if x.min() < -1e-5:
                self.signed.data.fill_(1)  # negatives observed -> use a signed range
            if self.signed == 1:
                Qn = -2 ** (self.nbits - 1)
                Qp = 2 ** (self.nbits - 1) - 1
            else:
                Qn = 0
                Qp = 2 ** self.nbits - 1
            self.alpha.data.copy_(2 * x.abs().mean() / math.sqrt(Qp))
            # EMA-style blend of the zero point toward the observed minimum.
            self.zero_point.data.copy_(self.zero_point.data * 0.9 + 0.1 * (torch.min(x.detach()) - self.alpha.data * Qn))
            self.init_state.fill_(1)

        # Recompute the integer range (signed may have been set during init).
        if self.signed == 1:
            Qn = -2 ** (self.nbits - 1)
            Qp = 2 ** (self.nbits - 1) - 1
        else:
            Qn = 0
            Qp = 2 ** self.nbits - 1

        g = 1.0 / math.sqrt(x.numel() * Qp)  # LSQ gradient scale

        # Method1:
        # Straight-through rounding of the zero point, then gradient scaling.
        zero_point = (self.zero_point.round() - self.zero_point).detach() + self.zero_point
        alpha = grad_scale(self.alpha, g)
        zero_point = grad_scale(zero_point, g)
        # x = round_pass((x / alpha).clamp(Qn, Qp)) * alpha
        # Reshape the per-feature parameters so they broadcast over the input.
        # NOTE(review): only 2-D (N, C) and 4-D (N, C, H, W) inputs are reshaped
        # explicitly; other ranks rely on trailing-dim broadcasting — confirm
        # this is intended for 3-D transformer inputs.
        if len(x.shape) == 2:
            alpha = alpha.unsqueeze(0)
            zero_point = zero_point.unsqueeze(0)
        elif len(x.shape) == 4:
            alpha = alpha.unsqueeze(0).unsqueeze(2).unsqueeze(3)
            zero_point = zero_point.unsqueeze(0).unsqueeze(2).unsqueeze(3)

        # Quantize to integer codes, then dequantize back to float.
        x = round_pass((x / alpha + zero_point).clamp(Qn, Qp))
        x = (x - zero_point) * alpha

        return x
|
models/q_vit/__init__.py
ADDED
|
File without changes
|
models/q_vit/__pycache__/Quant.cpython-311.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
models/q_vit/__pycache__/Quant.cpython-312.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
models/q_vit/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (207 Bytes). View file
|
|
|
models/q_vit/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (218 Bytes). View file
|
|
|
models/q_vit/__pycache__/_quan_base.cpython-311.pyc
ADDED
|
Binary file (12.8 kB). View file
|
|
|
models/q_vit/__pycache__/_quan_base.cpython-312.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc
ADDED
|
Binary file (33.1 kB). View file
|
|
|
models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc
ADDED
|
Binary file (30.1 kB). View file
|
|
|
models/q_vit/_quan_base.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quantized modules: the base class
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
|
| 8 |
+
import math
|
| 9 |
+
from enum import Enum
|
| 10 |
+
|
| 11 |
+
__all__ = ['Qmodes', '_Conv2dQ', '_LinearQ', '_ActQ',
|
| 12 |
+
'truncation', 'get_sparsity_mask', 'FunStopGradient', 'round_pass', 'grad_scale']
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Qmodes(Enum):
    """Quantization granularity: one step size per layer, or per kernel/channel."""
    layer_wise = 1
    kernel_wise = 2
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def grad_scale(x, scale):
    """Forward identity whose backward gradient is scaled by ``scale``."""
    grad_path = x * scale
    # The detached difference carries the forward value; gradients flow
    # only through grad_path, i.e. they are multiplied by `scale`.
    return (x - grad_path).detach() + grad_path
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sparsity_mask(param, sparsity):
    """Return a {0, 1} mask keeping the largest-magnitude entries of ``param``.

    Args:
        param: tensor to prune.
        sparsity: fraction of elements to prune; the smallest
            ``int(sparsity * param.numel())`` magnitudes are dropped.
            Elements tied with the pruning threshold are also dropped
            (strict ``>`` comparison, as in the original).

    Returns:
        A tensor of the same dtype as ``param`` with 1.0 for kept entries.

    Bug fix: when ``int(sparsity * param.numel()) == 0`` the original code
    called ``topk`` with k=0 and indexed ``bottomk.data[-1]`` on an empty
    tensor, which raises; now nothing is pruned and an all-ones mask is
    returned.
    """
    num_pruned = int(sparsity * param.numel())
    if num_pruned == 0:
        return torch.ones_like(param).type(param.type())
    bottomk, _ = torch.topk(param.abs().view(-1), num_pruned, largest=False, sorted=True)
    threshold = bottomk.data[-1]  # largest magnitude among the pruned group
    return torch.gt(torch.abs(param), threshold).type(param.type())
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def round_pass(x):
    """Nearest-integer rounding with an identity (straight-through) gradient."""
    rounded = x.round()
    return (rounded - x).detach() + x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class FunStopGradient(torch.autograd.Function):
    """Identity in the forward pass; masks the gradient elementwise on backward.

    ``stopGradientMask`` should be 0 where gradients must be blocked and
    1 where they may flow.
    """

    @staticmethod
    def forward(ctx, weight, stopGradientMask):
        ctx.save_for_backward(stopGradientMask)
        return weight

    @staticmethod
    def backward(ctx, grad_outputs):
        stopGradientMask, = ctx.saved_tensors
        grad_inputs = grad_outputs * stopGradientMask
        return grad_inputs, None  # no gradient for the mask itself
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def log_shift(value_fp):
    """Round ``value_fp`` up to the nearest power of two (elementwise)."""
    exponent = torch.log2(value_fp).ceil()
    return 2 ** exponent
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def clamp(input, min, max, inplace=False):
    """Clamp ``input`` to [min, max]; mutates the tensor when ``inplace``."""
    if not inplace:
        return torch.clamp(input, min, max)
    input.clamp_(min, max)
    return input
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_quantized_range(num_bits, signed=True):
    """Return the (min, max) integer codes representable with ``num_bits``."""
    if not signed:
        return 0, 2 ** num_bits - 1
    half = 2 ** (num_bits - 1)
    return -half, half - 1
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def linear_quantize(input, scale_factor, inplace=False):
    """Scale by ``scale_factor`` and round to the nearest integer code."""
    if not inplace:
        return torch.round(scale_factor * input)
    input.mul_(scale_factor).round_()
    return input
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False):
    """Quantize to integer codes, then clamp them to [clamp_min, clamp_max]."""
    codes = linear_quantize(input, scale_factor, inplace)
    return clamp(codes, clamp_min, clamp_max, inplace)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def linear_dequantize(input, scale_factor, inplace=False):
    """Invert ``linear_quantize`` by dividing out ``scale_factor``."""
    if not inplace:
        return input / scale_factor
    input.div_(scale_factor)
    return input
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def truncation(fp_data, nbits=8):
    """Fake-quantize ``fp_data`` to an ``nbits``-bit fixed-point format.

    The integer length ``il`` is chosen so the largest magnitude fits; the
    remaining ``qcode = nbits - il`` bits become fractional. Returns the
    dequantized tensor together with ``qcode``.
    """
    # Bits needed for the integer part (the +1 accounts for the sign).
    il = torch.log2(torch.max(fp_data.max(), fp_data.min().abs())) + 1
    il = math.ceil(il - 1e-5)  # epsilon guards exact powers of two against over-rounding
    qcode = nbits - il
    scale_factor = 2 ** qcode
    clamp_min, clamp_max = get_quantized_range(nbits, signed=True)
    q_data = linear_quantize_clamp(fp_data, scale_factor, clamp_min, clamp_max)
    q_data = linear_dequantize(q_data, scale_factor)
    return q_data, qcode
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_default_kwargs_q(kwargs_q, layer_type):
    """Fill ``kwargs_q`` with per-layer-type quantization defaults.

    Args:
        kwargs_q: keyword arguments passed to the quantized layer; missing
            keys are filled in place with defaults (``nbits``, and ``mode``
            for conv layers).
        layer_type: the layer instance being configured; must be a
            ``_Conv2dQ``, ``_LinearQ`` or ``_ActQ``.

    Returns:
        The same ``kwargs_q`` dict with defaults applied.

    Raises:
        NotImplementedError: if ``layer_type`` is not a known quantized layer.
    """
    default = {
        'nbits': 4
    }
    if isinstance(layer_type, _Conv2dQ):
        default.update({
            'mode': Qmodes.layer_wise})
    elif isinstance(layer_type, _LinearQ):
        pass
    elif isinstance(layer_type, _ActQ):
        pass
        # default.update({
        #     'signed': 'Auto'})
    else:
        # Bug fix: the original `assert NotImplementedError` always passed
        # (it asserts the truthy exception class) and then returned None,
        # which later crashed the caller with a confusing TypeError when it
        # subscripted the result. Raise explicitly instead.
        raise NotImplementedError('unsupported layer type: {}'.format(type(layer_type)))
    for k, v in default.items():
        if k not in kwargs_q:
            kwargs_q[k] = v
    return kwargs_q
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class _Conv2dQ(nn.Conv2d):
    """Base class for quantized Conv2d layers: holds the learned step size.

    Registers ``alpha`` (one scalar in layer-wise mode, one value per
    output channel in kernel-wise mode) plus an ``init_state`` buffer used
    by subclasses for one-time LSQ initialization. ``nbits < 0`` disables
    quantization by registering ``alpha = None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, **kwargs_q):
        super(_Conv2dQ, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
                                       padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            # Quantization disabled; subclasses check `self.alpha is None`.
            self.register_parameter('alpha', None)
            return
        self.q_mode = kwargs_q['mode']
        if self.q_mode == Qmodes.kernel_wise:
            self.alpha = Parameter(torch.Tensor(out_channels))
        else:  # layer-wise quantization
            self.alpha = Parameter(torch.Tensor(1))
        # Stays 0 until the first training forward initializes alpha.
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def set_bit(self, nbits):
        # NOTE(review): only updates kwargs_q; self.nbits is left unchanged — confirm intent.
        self.kwargs_q['nbits'] = nbits

    def extra_repr(self):
        s_prefix = super(_Conv2dQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class _LinearQ(nn.Linear):
    """Base class for quantized Linear layers: holds the learned step size.

    ``alpha`` is one scalar in layer-wise mode or one value per output
    feature in kernel-wise mode; ``init_state`` flags whether LSQ
    initialization has run. ``nbits < 0`` disables quantization.
    """

    def __init__(self, in_features, out_features, bias=True, **kwargs_q):
        super(_LinearQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            # Quantization disabled; subclasses check `self.alpha is None`.
            self.register_parameter('alpha', None)
            return
        self.q_mode = kwargs_q['mode']
        self.alpha = Parameter(torch.Tensor(1))
        if self.q_mode == Qmodes.kernel_wise:
            # Kernel-wise mode replaces the scalar with a per-output-feature step.
            self.alpha = Parameter(torch.Tensor(out_features))
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def extra_repr(self):
        s_prefix = super(_LinearQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class _ActQ(nn.Module):
    """Base class for activation quantizers: step size, zero point, sign flag.

    ``alpha``/``zero_point`` are scalars in layer-wise mode or per-feature
    vectors in kernel-wise mode. ``signed`` is a buffer the subclass flips
    to 1 when negative inputs are observed; ``init_state`` flags the
    one-time initialization. ``nbits < 0`` disables quantization.
    """

    def __init__(self, in_features, **kwargs_q):
        super(_ActQ, self).__init__()
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = kwargs_q['nbits']
        if self.nbits < 0:
            # Quantization disabled; subclasses check `self.alpha is None`.
            self.register_parameter('alpha', None)
            self.register_parameter('zero_point', None)
            return
        # self.signed = kwargs_q['signed']
        self.q_mode = kwargs_q['mode']
        self.alpha = Parameter(torch.Tensor(1))
        self.zero_point = Parameter(torch.Tensor([0]))
        if self.q_mode == Qmodes.kernel_wise:
            # Kernel-wise mode: one step/zero-point per input feature.
            self.alpha = Parameter(torch.Tensor(in_features))
            self.zero_point = Parameter(torch.Tensor(in_features))
            torch.nn.init.zeros_(self.zero_point)
        # self.zero_point = Parameter(torch.Tensor([0]))
        self.register_buffer('init_state', torch.zeros(1))
        self.register_buffer('signed', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def set_bit(self, nbits):
        # NOTE(review): only updates kwargs_q; self.nbits is left unchanged — confirm intent.
        self.kwargs_q['nbits'] = nbits

    def extra_repr(self):
        # s_prefix = super(_ActQ, self).extra_repr()
        if self.alpha is None:
            return 'fake'
        return '{}'.format(self.kwargs_q)
|
models/q_vit/quant_vision_transformer.py
ADDED
|
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import logging
|
| 3 |
+
from functools import partial
|
| 4 |
+
from collections import OrderedDict
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 11 |
+
from timm.models.helpers import load_pretrained
|
| 12 |
+
from timm.models.layers import Mlp
|
| 13 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
| 14 |
+
from timm.models.resnet import resnet26d, resnet50d
|
| 15 |
+
from timm.models.registry import register_model
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
from .Quant import *
|
| 19 |
+
from ._quan_base import *
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
_logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _cfg(url='', **kwargs):
    """Build a default timm-style pretrained-config dict; kwargs override defaults."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
    }
    cfg.update(kwargs)
    return cfg
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Pretrained-checkpoint configurations keyed by model name; each entry is a
# timm-style config dict produced by _cfg (URL, input size, normalization, ...).
default_cfgs = {
    # patch models (my experiments)
    'vit_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
    ),

    # patch models (weights ported from official Google JAX impl)
    'vit_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'vit_base_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_base_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224': _cfg(
        url='',  # no official model weights for this combo, only for in21k
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
    'vit_large_patch32_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),

    # patch models, imagenet21k (weights ported from official Google JAX impl)
    'vit_base_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_base_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_large_patch32_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'vit_huge_patch14_224_in21k': _cfg(
        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # hybrid models (weights ported from official Google JAX impl)
    'vit_base_resnet50_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_resnet50_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),

    # hybrid models (my experiments)
    'vit_small_resnet26d_224': _cfg(),
    'vit_small_resnet50d_s3_224': _cfg(),
    'vit_base_resnet26d_224': _cfg(),
    'vit_base_resnet50d_224': _cfg(),

    # deit models (FB weights)
    'vit_deit_tiny_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'vit_deit_small_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'vit_deit_base_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
    'vit_deit_base_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'vit_deit_tiny_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
    'vit_deit_small_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
    'vit_deit_base_distilled_patch16_224': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
    'vit_deit_base_distilled_patch16_384': _cfg(
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
}
|
| 120 |
+
|
| 121 |
+
class Q_Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    (quantized variant: both linear layers are kernel-wise LSQ-quantized).
    """
    def __init__(self, nbits, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        drop_probs = to_2tuple(drop)

        self.fc1 = LinearQ(in_features, hidden_features, nbits_w=nbits, mode=Qmodes.kernel_wise)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = LinearQ(hidden_features, out_features, nbits_w=nbits, mode=Qmodes.kernel_wise)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        # print(torch.max(x), torch.min(x))
        x = self.act(x)

        # Clip activations to a fixed range before fc2 quantizes them.
        # NOTE(review): the [-10, 10] bound is hard-coded — confirm it suits
        # every activation function passed as act_layer.
        x = torch.clip(x, -10., 10.)
        # print(torch.clip(x, -10., 10.))
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Q_Attention(nn.Module):
    """Multi-head self-attention with optional LSQ quantization.

    q/k are LayerNorm-ed per head (QK-norm), then q/k/v and the attention
    matrix are fake-quantized with per-head ``ActQ`` modules; the qkv and
    output projections use ``LinearQ`` when ``quantize_attn`` is set.
    """

    def __init__(self, nbits, dim, num_heads=8, quantize_attn=True, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1/sqrt(d_k) attention scaling
        self.quantize_attn = quantize_attn

        # Per-head-dim normalization of queries and keys.
        self.norm_q = nn.LayerNorm(head_dim)
        self.norm_k = nn.LayerNorm(head_dim)

        if self.quantize_attn:
            self.qkv = LinearQ(dim, dim * 3, bias=qkv_bias, nbits_w=nbits, mode=Qmodes.kernel_wise)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = LinearQ(dim, dim, nbits_w=nbits, mode=Qmodes.kernel_wise)
            self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
        else:
            # Float projections; the activation quantizers are still created.
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads)
            self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads)

        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        q = self.norm_q(q)
        k = self.norm_k(k)

        # Quantize the attention inputs.
        q = self.q_act(q)
        k = self.k_act(k)
        v = self.v_act(v)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        attn = self.attn_act(attn)  # quantize the attention weights as well

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class Q_Block(nn.Module):
    """Pre-norm transformer block: quantized attention followed by a quantized MLP.

    Both sub-modules are wrapped with residual connections and (optional)
    stochastic-depth drop-path.
    """

    def __init__(self, nbits, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Q_Attention(nbits, dim, num_heads=num_heads, qkv_bias=qkv_bias,
                                attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = Q_Mlp(nbits=nbits, in_features=dim, hidden_features=hidden,
                         act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
|
| 226 |
+
|
| 227 |
+
class Q_PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Splits a (B, C, H, W) image into non-overlapping patches via a strided
    quantized conv and flattens them into a (B, num_patches, embed_dim)
    token sequence.

    NOTE(review): the ``nbits`` argument is accepted but never forwarded to
    Conv2dQ below, so the patch projection presumably runs at Conv2dQ's
    default bit-width — confirm whether ``nbits_w=nbits`` should be passed.
    """
    def __init__(self, nbits=4, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # Number of patches on each axis multiplied together.
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Quantized drop-in replacement for the usual nn.Conv2d patch projection.
        self.proj = Conv2dQ(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, embed_dim, H/ps, W/ps) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
|
| 249 |
+
|
| 250 |
+
class lowbit_VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877

    Low-bit variant: patch embedding uses Q_PatchEmbed, blocks use Q_Block,
    and the classifier head(s) use 8-bit LinearQ.
    """

    def __init__(self, nbits, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=Q_PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        """
        Args:
            nbits: bit-width for quantized weights/activations in the blocks
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            weight_init: (str): weight init scheme
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        # NOTE(review): distilled defaults to True here (stock ViT defaults to False).
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(
            nbits=nbits, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable class / distillation tokens and position embedding.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Q_Block(
                nbits=nbits, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Representation layer (only used for the non-distilled configuration).
        if representation_size and not distilled:
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim, representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier head(s): 8-bit quantized linear layers.
        self.head = LinearQ(self.num_features, num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
        # nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
            # self.head = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity()
            # nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        # Initialize tokens / pos embedding, then per-module weights.
        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
        # 'nlhb' = negative log head bias (log of class prior).
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.dist_token is not None:
            trunc_normal_(self.dist_token, std=.02)
        if mode.startswith('jax'):
            # leave cls token as zeros to match jax impl
            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
        else:
            trunc_normal_(self.cls_token, std=.02)
            self.apply(_init_vit_weights)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        _init_vit_weights(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by the optimizer factory.
        return {'pos_embed', 'cls_token', 'dist_token'}

    def get_classifier(self):
        if self.dist_token is None:
            return self.head
        else:
            return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=''):
        # NOTE(review): resets to plain nn.Linear, not LinearQ as in __init__ —
        # confirm whether the new head should stay quantized.
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.num_tokens == 2:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        # Patchify, prepend class (and optional distillation) token, add pos embed.
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            # Return both the class-token and distillation-token features.
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
            if self.training and not torch.jit.is_scripting():
                # Training: return both heads separately (for distillation loss);
                # during inference, return the average of both classifier predictions.
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x)
            return x
|
| 392 |
+
|
| 393 |
+
def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
|
| 394 |
+
""" ViT weight initialization
|
| 395 |
+
* When called without n, head_bias, jax_impl args it will behave exactly the same
|
| 396 |
+
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
|
| 397 |
+
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
|
| 398 |
+
"""
|
| 399 |
+
if isinstance(module, nn.Linear):
|
| 400 |
+
if name.startswith('head'):
|
| 401 |
+
nn.init.zeros_(module.weight)
|
| 402 |
+
nn.init.constant_(module.bias, head_bias)
|
| 403 |
+
elif name.startswith('pre_logits'):
|
| 404 |
+
lecun_normal_(module.weight)
|
| 405 |
+
nn.init.zeros_(module.bias)
|
| 406 |
+
else:
|
| 407 |
+
if jax_impl:
|
| 408 |
+
nn.init.xavier_uniform_(module.weight)
|
| 409 |
+
if module.bias is not None:
|
| 410 |
+
if 'mlp' in name:
|
| 411 |
+
nn.init.normal_(module.bias, std=1e-6)
|
| 412 |
+
else:
|
| 413 |
+
nn.init.zeros_(module.bias)
|
| 414 |
+
else:
|
| 415 |
+
trunc_normal_(module.weight, std=.02)
|
| 416 |
+
if module.bias is not None:
|
| 417 |
+
nn.init.zeros_(module.bias)
|
| 418 |
+
elif jax_impl and isinstance(module, nn.Conv2d):
|
| 419 |
+
# NOTE conv was left to pytorch default in my original init
|
| 420 |
+
lecun_normal_(module.weight)
|
| 421 |
+
if module.bias is not None:
|
| 422 |
+
nn.init.zeros_(module.bias)
|
| 423 |
+
elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
|
| 424 |
+
nn.init.zeros_(module.bias)
|
| 425 |
+
nn.init.ones_(module.weight)
|
| 426 |
+
|
| 427 |
+
def resize_pos_embed(posemb, posemb_new):
    """Rescale the grid of position embeddings when loading from a state_dict.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224

    Args:
        posemb: pretrained pos_embed of shape (1, 1 + gs_old**2, C), class token first.
        posemb_new: target pos_embed; its token count fixes the new grid size.

    Returns:
        pos_embed tensor resized to posemb_new's token count.
    """
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    # BUG FIX: the original guarded this split with a constant `if True:`
    # whose else-branch was unreachable dead code; it has been removed.
    # Behavior is unchanged: the first token is always the class token.
    posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    ntok_new -= 1
    gs_old = int(math.sqrt(len(posemb_grid)))
    gs_new = int(math.sqrt(ntok_new))
    _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
    # Reshape to (1, C, gs, gs), bilinearly interpolate, then flatten back.
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
    return torch.cat([posemb_tok, posemb_grid], dim=1)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    # DeiT checkpoints nest everything under a 'model' key; unwrap it first.
    if 'model' in state_dict:
        state_dict = state_dict['model']
    out_dict = {}
    for key, weight in state_dict.items():
        if 'patch_embed.proj.weight' in key and len(weight.shape) < 4:
            # Old models trained prior to conv-based patchification stored a
            # 2-D linear projection; fold it back into the conv kernel shape.
            O, I, H, W = model.patch_embed.proj.weight.shape
            weight = weight.reshape(O, -1, H, W)
        elif key == 'pos_embed' and weight.shape != model.pos_embed.shape:
            # Resize pos embedding when using the model at a different input size.
            weight = resize_pos_embed(weight, model.pos_embed)
        out_dict[key] = weight
    return out_dict
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
    """Build a (optionally distilled) VisionTransformer from a named default config.

    NOTE(review): this helper instantiates the stock timm
    VisionTransformer / DistilledVisionTransformer classes, not the
    lowbit_VisionTransformer defined in this file — confirm that is intended.
    """
    default_cfg = default_cfgs[variant]
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-1]

    # Callers may override classes / image size / representation size via kwargs.
    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        # Remove representation layer if fine-tuning. This may not always be the desired action,
        # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
        _logger.warning("Removing representation layer for fine-tuning.")
        repr_size = None

    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
    model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
    model.default_cfg = default_cfg

    if pretrained:
        # checkpoint_filter_fn remaps patchify weights and resizes pos_embed on load.
        load_pretrained(
            model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
            filter_fn=partial(checkpoint_filter_fn, model=model))
    return model
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
@register_model
def fourbits_deit_small_patch16_224(pretrained=False, **kwargs):
    """4-bit quantized DeiT (patch16, 224).

    NOTE(review): despite the "small" name this uses embed_dim=192 /
    num_heads=3, which is the DeiT-*tiny* geometry, while the pretrained
    URL points at the distilled *small* (384-dim) checkpoint — confirm the
    intended configuration before relying on ``pretrained=True``.
    """
    model = lowbit_VisionTransformer(
        nbits=4, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
        # BUG FIX: the original downloaded the checkpoint but discarded the
        # return value, so pretrained=True silently returned random weights.
        # DeiT checkpoints store weights under the 'model' key; strict=False
        # tolerates the extra quantizer parameters of the low-bit modules.
        model.load_state_dict(checkpoint['model'], strict=False)
    return model
|
| 502 |
+
|
| 503 |
+
@register_model
def threebits_deit_small_patch16_224(pretrained=False, **kwargs):
    """3-bit quantized DeiT-small (patch16, 224), embed_dim=384 / 6 heads."""
    model = lowbit_VisionTransformer(
        nbits=3, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
        # BUG FIX: the original downloaded the checkpoint but discarded the
        # return value, so pretrained=True silently returned random weights.
        # DeiT checkpoints store weights under the 'model' key; strict=False
        # tolerates the extra quantizer parameters of the low-bit modules.
        model.load_state_dict(checkpoint['model'], strict=False)
    return model
|
| 515 |
+
|
| 516 |
+
@register_model
def twobits_deit_small_patch16_224(pretrained=False, **kwargs):
    """2-bit quantized DeiT-small (patch16, 224), embed_dim=384 / 6 heads."""
    model = lowbit_VisionTransformer(
        nbits=2, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
            map_location="cpu", check_hash=True
        )
        # BUG FIX: the original downloaded the checkpoint but discarded the
        # return value, so pretrained=True silently returned random weights.
        # DeiT checkpoints store weights under the 'model' key; strict=False
        # tolerates the extra quantizer parameters of the low-bit modules.
        model.load_state_dict(checkpoint['model'], strict=False)
    return model
|
models/qk_model_v1_1003.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
|
| 4 |
+
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
|
| 5 |
+
from timm.models.registry import register_model
|
| 6 |
+
from timm.models.vision_transformer import _cfg
|
| 7 |
+
from functools import partial
|
| 8 |
+
from timm.models import create_model
|
| 9 |
+
|
| 10 |
+
__all__ = ['QKFormer']
|
| 11 |
+
|
| 12 |
+
class MLP(nn.Module):
    """Spiking two-layer MLP operating on (T, B, C, H, W) tensors.

    Each layer is a 1x1 Conv2d + BatchNorm applied on the flattened (T*B)
    batch, followed by a multi-step LIF spiking neuron.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
        self.mlp1_bn = nn.BatchNorm2d(hidden_features)
        self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
        self.mlp2_bn = nn.BatchNorm2d(out_features)
        self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, H, W = x.shape

        x = self.mlp1_conv(x.flatten(0, 1))
        x = self.mlp1_bn(x).reshape(T, B, self.c_hidden, H, W)
        x = self.mlp1_lif(x)

        x = self.mlp2_conv(x.flatten(0, 1))
        # BUG FIX: the original reshaped with the *input* channel count C,
        # which only works when out_features == in_features; use the actual
        # output channel count so out_features != in_features is supported.
        x = self.mlp2_bn(x).reshape(T, B, self.c_output, H, W)
        x = self.mlp2_lif(x)
        return x
|
| 39 |
+
|
| 40 |
+
class Token_QK_Attention(nn.Module):
    """Spiking token-level Q-K attention (linear-complexity, no V projection).

    q is summed over the channel dimension per head, passed through a LIF
    neuron to produce a binary token mask, which gates k elementwise.
    Extra constructor arguments (qk_scale, drops, sr_ratio) are accepted
    for interface compatibility but unused.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads

        # q branch: 1x1 conv over channels + BN + LIF spike generation.
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # k branch, same structure as q.
        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Lower threshold (0.5) so the summed-q attention signal spikes more easily.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch')

        # Output projection.
        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

    def forward(self, x):
        # x: (T, B, C, H, W); spatial dims are flattened into N = H*W tokens.
        T, B, C, H, W = x.shape

        x = x.flatten(3)
        T, B, C, N = x.shape
        # Fold time into the batch for the Conv1d/BN ops.
        x_for_qkv = x.flatten(0, 1)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N)
        q_conv_out = self.q_lif(q_conv_out)
        # Split channels into heads: (T, B, heads, head_dim, N).
        q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N)
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        # Collapse q over head_dim, spike it, and use the result as a
        # per-token gate broadcast over k's channels.
        q = torch.sum(q, dim=3, keepdim=True)
        attn = self.attn_lif(q)
        x = torch.mul(attn, k)

        # Merge heads back to C channels, then project.  Requires N == H*W
        # (true since x was flattened from (H, W) above).
        x = x.flatten(2, 3)
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W)
        # print(f"proj_conv out shape: {x.shape}")
        x = self.proj_lif(x)
        return x
|
| 88 |
+
|
| 89 |
+
class Spiking_Self_Attention(nn.Module):
    """Spiking self-attention computed as q @ (k^T @ v), scaled by a constant.

    All of q/k/v come from 1x1 Conv1d + BN + LIF branches over the token
    dimension.  The (k^T @ v) ordering gives linear complexity in the
    number of tokens.  Extra constructor args are accepted but unused;
    ``qkv_mp`` is also constructed but never used in forward.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads  # NOTE(review): computed but unused; scale is fixed below
        self.scale = 0.125
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')
        # Lower threshold (0.5) for the post-attention spike generation.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # NOTE(review): never used in forward — dead module, kept for
        # state_dict compatibility.
        self.qkv_mp = nn.MaxPool1d(4)

    def forward(self, x):
        # x: (T, B, C, H, W); spatial dims flattened into N = H*W tokens.
        T, B, C, H, W = x.shape

        x = x.flatten(3)
        T, B, C, N = x.shape
        # Fold time into batch for Conv1d/BN.
        x_for_qkv = x.flatten(0, 1)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out)
        # (T, B, C, N) -> (T, B, heads, N, head_dim)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        k_conv_out = self.k_conv(x_for_qkv)
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        v_conv_out = self.v_conv(x_for_qkv)
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out)
        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # Linear attention: compute (k^T @ v) first, then q @ (.) — avoids
        # the N x N attention matrix.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale

        # Merge heads back to C channels and spike.
        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x)
        x = x.flatten(0, 1)
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W)
        return x
|
| 149 |
+
|
| 150 |
+
class TokenSpikingTransformer(nn.Module):
    """Spiking transformer block: Token-QK attention + spiking MLP.

    Both sub-modules are wrapped with residual connections.  The extra
    constructor arguments are accepted for interface compatibility but
    are not used.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.tssa = Token_QK_Attention(dim, num_heads)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        x = x + self.tssa(x)
        return x + self.mlp(x)
|
| 164 |
+
|
| 165 |
+
class SpikingTransformer(nn.Module):
    """Spiking transformer block: linear spiking self-attention + spiking MLP.

    Both sub-modules are wrapped with residual connections.  The extra
    constructor arguments are accepted for interface compatibility but
    are not used.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.ssa = Spiking_Self_Attention(dim, num_heads)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        x = x + self.ssa(x)
        return x + self.mlp(x)
|
| 179 |
+
|
| 180 |
+
class PatchEmbedInit(nn.Module):
    """Initial spiking patch-embedding stem with an 8x total downsampling.

    Four conv+BN(+maxpool)+LIF stages progressively raise channels from
    in_channels to embed_dims while halving resolution three times, plus
    a strided 1x1 residual path from the first-stage features.
    """
    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        # Stage 0: channel lift to embed_dims // 8 at full resolution.
        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 1: embed_dims // 4, 2x downsample via maxpool.
        self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj1_bn = nn.BatchNorm2d(embed_dims // 4)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 2: embed_dims // 2, another 2x downsample.
        self.proj2_conv = nn.Conv2d(embed_dims//4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj2_bn = nn.BatchNorm2d(embed_dims // 2)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Stage 3: full embed_dims, final 2x downsample.
        self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj3_bn = nn.BatchNorm2d(embed_dims)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Residual path: stride-4 1x1 conv takes stage-1 features straight to
        # the stage-3 resolution/channels.
        self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')


    def forward(self, x):
        # x: (T, B, C, H, W) spike tensor; time is folded into batch around
        # every conv/BN/pool and restored before each LIF layer.
        T, B, C, H, W = x.shape
        # Downsampling + Res
        # x_feat = x.flatten(0, 1)
        x = self.proj_conv(x.flatten(0, 1))
        x = self.proj_bn(x).reshape(T, B, -1, H, W)
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj1_conv(x)
        x = self.proj1_bn(x)
        x = self.maxpool1(x)
        _, _, H1, W1 = x.shape
        x = x.reshape(T, B, -1, H1, W1).contiguous()
        x = self.proj1_lif(x).flatten(0, 1).contiguous()

        # Save stage-1 output for the residual path.
        x_feat = x
        x = self.proj2_conv(x)
        x = self.proj2_bn(x)
        x = self.maxpool2(x)
        _, _, H2, W2 = x.shape
        x = x.reshape(T, B, -1, H2, W2).contiguous()
        x = self.proj2_lif(x).flatten(0, 1).contiguous()

        x = self.proj3_conv(x)
        x = self.proj3_bn(x)
        x = self.maxpool3(x)
        _, _, H3, W3 = x.shape
        x = x.reshape(T, B, -1, H3, W3).contiguous()
        x = self.proj3_lif(x)

        # Residual branch: stride-4 conv matches stage-3 spatial size.
        x_feat = self.proj_res_conv(x_feat)
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)
        x = x + x_feat  # shortcut

        return x
|
| 252 |
+
|
| 253 |
+
class PatchEmbeddingStage(nn.Module):
    """Inter-stage patch embedding: doubles channels, halves H and W.

    Main path: 3x3 conv (embed_dims//2 -> embed_dims) + BN + LIF, then
    3x3 conv + BN + stride-2 max-pool + LIF.  A stride-2 1x1 conv residual
    branch taken from the raw input is added to the result.

    NOTE(review): img_size_h/img_size_w/patch_size/in_channels are stored
    but not used by forward(), which works on whatever size it receives.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels  # bookkeeping only; forward ignores it
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj4_bn = nn.BatchNorm2d(embed_dims)
        self.proj4_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

        # Residual: 1x1 stride-2 conv matches the main path's channel count
        # and halved resolution.
        self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch')

    def forward(self, x):
        """x: (T, B, embed_dims//2, H, W) -> (T, B, embed_dims, H/2, W/2)."""
        T, B, C, H, W = x.shape
        # Downsampling + Res

        x = x.flatten(0, 1).contiguous()  # fold T into batch for the 2D ops
        x_feat = x  # tap for the residual branch

        x = self.proj_conv(x)
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj4_conv(x)
        x = self.proj4_bn(x)
        x = self.proj4_maxpool(x)  # H, W -> H/2, W/2
        _, _, H4, W4 = x.shape
        x = x.reshape(T, B, -1, H4, W4).contiguous()
        x = self.proj4_lif(x)

        x_feat = self.proj_res_conv(x_feat)
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)

        x = x + x_feat  # shortcut

        return x
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class vit_snn(nn.Module):
    """Two-stage spiking vision transformer (QKFormer-style), timm-compatible.

    Pipeline: PatchEmbedInit stem (8x downsample, embed_dims // 2 channels)
    -> TokenSpikingTransformer x1 -> PatchEmbeddingStage (2x downsample,
    embed_dims channels) -> SpikingTransformer x1 -> spatial mean ->
    temporal mean -> linear classification head.

    NOTE(review): despite the list-valued defaults, embed_dims / depths are
    used as plain ints below (``embed_dims // 2``,
    ``torch.linspace(0, ..., depths)``); factory functions such as
    QKFormer_1003 pass ints.  num_heads is hard-overridden to [16, 16, 16].
    """

    def __init__(self,
                 img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11,
                 embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[6, 8, 6], sr_ratios=[8, 4, 2], T=4, pretrained_cfg=None, in_chans = 3, no_weight_decay = None
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.T = T  # nominal time steps; forward() takes T from the input shape
        num_heads = [16, 16, 16]  # hard override of the argument
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]  # stochastic depth decay rule

        #
        patch_embed1 = PatchEmbedInit(img_size_h=img_size_h,
                                      img_size_w=img_size_w,
                                      patch_size=patch_size,
                                      in_channels=in_channels,
                                      embed_dims=embed_dims // 2)

        stage1 = nn.ModuleList([TokenSpikingTransformer(
            dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h,
                                           img_size_w=img_size_w,
                                           patch_size=patch_size,
                                           in_channels=in_channels,
                                           embed_dims=embed_dims)

        stage2 = nn.ModuleList([SpikingTransformer(
            dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        # setattr mirrors the per-stage loop style of the original repo.
        setattr(self, f"patch_embed1", patch_embed1)
        setattr(self, f"stage1", stage1)
        setattr(self, f"patch_embed2", patch_embed2)
        setattr(self, f"stage2", stage2)

        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # NOTE(review): likely a typo for 'pos_embed'; neither name exists as
        # a parameter here, so this exclusion set is effectively empty.
        return {'pose_embed'}

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # Positional embeddings are not used by this model.
        return None

    def _init_weights(self, m):
        # trunc-normal init for linear weights; standard LayerNorm init.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """x: (T, B, C, H, W) -> (T, B, embed_dims) after spatial pooling."""
        stage1 = getattr(self, f"stage1")
        patch_embed1 = getattr(self, f"patch_embed1")
        stage2 = getattr(self, f"stage2")
        patch_embed2 = getattr(self, f"patch_embed2")

        x = patch_embed1(x)
        for blk in stage1:
            x = blk(x)

        x = patch_embed2(x)
        for blk in stage2:
            x = blk(x)

        return x.flatten(3).mean(3)  # average over the spatial positions

    def forward(self, x):
        x = x.permute(1, 0, 2, 3, 4)  # [T, N, 2, *, *]
        x = self.forward_features(x)
        x = self.head(x.mean(0))  # average over time, then classify
        return x
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
@register_model
def QKFormer_1003(pretrained=False, **kwargs):
    """Build the QKFormer_1003 spiking transformer (timm factory entry).

    ``pretrained`` is accepted for create_model() compatibility but is not
    used — no pretrained weights are loaded.  Extra keyword arguments are
    forwarded to ``vit_snn`` unchanged.
    """
    model = vit_snn(
        patch_size=16,
        embed_dims=256,
        num_heads=16,
        mlp_ratios=1,
        in_channels=2,
        num_classes=101,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=4,
        sr_ratios=1,
        **kwargs,
    )
    model.default_cfg = _cfg()
    return model
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
from timm.models import create_model
|
| 410 |
+
|
| 411 |
+
if __name__ == '__main__':
    # Smoke test: build the registered model on GPU and push one dummy
    # (B=1, T=1, 2-channel, 128x128) event tensor through it.
    x = torch.randn(1, 1, 2, 128, 128).cuda()
    model = create_model(
        'QKFormer_1003',
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0.1,
        drop_block_rate=None,
    ).cuda()
    model.eval()

    # Parameter/shape summary via torchinfo (imported lazily so the module
    # does not require it at import time).
    from torchinfo import summary
    summary(model, input_size=(1, 1, 2, 128, 128))
    y = model(x)
    print(y.shape)
    print('Test Good!')
|
models/qk_model_with_delay/__init__.py
ADDED
|
File without changes
|
models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (221 Bytes). View file
|
|
|
models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc
ADDED
|
Binary file (30.3 kB). View file
|
|
|
models/qk_model_with_delay/delay_synaptic_func_inter.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
def set_sigma_for_DCLS(model, s):
    """Assign *s* to the ``sigma`` attribute of every DelayConv submodule.

    Matching is done by class name so the helper works regardless of which
    module defined DelayConv; submodules without a ``sigma`` attribute are
    left untouched.  Used to anneal the delay-bump width during training.
    """
    for _name, sub in model.named_modules():
        looks_like_delay_conv = type(sub).__name__ == 'DelayConv'
        if looks_like_delay_conv and hasattr(sub, 'sigma'):
            sub.sigma = s
    print('Set sigma to ',s)
|
| 12 |
+
|
| 13 |
+
class DropoutNd(nn.Module):
    """Dropout for (batch, dim, lengths...) tensors.

    With ``tie=True`` one mask is shared across all trailing length
    dimensions (like Dropout1d/2d/3d).  ``transposed=False`` means the
    channel dimension is last and is moved in front of the lengths before
    masking.  The mask is drawn with ``torch.rand`` directly on the input's
    device; a no-op in eval mode.
    """

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Kept for parity with the sampling-based variant; forward() draws
        # its mask with torch.rand instead (CPU->GPU copies made this slow).
        self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if not self.training:
            return X
        if not self.transposed:
            X = rearrange(X, 'b ... d -> b d ...')
        if self.tie:
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape
        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        X = X * keep * (1.0 / (1 - self.p))
        if not self.transposed:
            X = rearrange(X, 'b d ... -> b ... d')
        return X
|
| 38 |
+
|
| 39 |
+
class DelayConv(nn.Module):
    """Temporal convolution with learnable per-channel-pair delays.

    Each (output-channel, input-channel) pair owns ``n_delay`` learnable
    delay positions ``d``.  ``update_kernel`` turns them into a soft bump
    (gaussian / triangle variants) over a length-``k`` causal window,
    multiplies by a learnable weight, and uses the result as a dense
    (groups=1) 1D conv kernel along the time axis.  ``sigma`` controls the
    bump width and is annealed externally via ``set_sigma_for_DCLS``; for
    'triangle_r_temp' the bump is hardened to its argmax in eval mode.
    """

    def __init__(
        self,
        in_c,
        k,
        dropout=0.0,
        n_delay=1,
        dilation=1,
        kernel_type='triangle_r_temp'
    ):
        super().__init__()
        self.C = in_c  # input and output channel count (square kernel)
        self.win_len = k
        self.dilation = dilation
        self.n_delay = n_delay
        self.kernel_type = kernel_type

        # Time grid for the bump functions.  NOTE(review): plain attribute,
        # not a registered buffer, so it is moved to the device manually in
        # update_kernel.
        self.t = torch.arange(self.win_len).float().unsqueeze(0)  # [1, k]
        self.sigma = self.win_len // 2

        self.delay_kernel = None
        self.bump = None

        # Delay positions, shape [C_out, C_in, n_delay]: distinct random
        # integer delays drawn from [1, win_len - 2].
        d = torch.rand(self.C, self.C, self.n_delay)
        with torch.no_grad():
            for co in range(self.C):
                for ci in range(self.C):
                    d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
        self.register("d", d, lr=1e-2)

        # Weights, shape [C_out, C_in, k]: exponentially decaying toward the
        # past — each earlier tap is half of the next one.
        weight = torch.ones([self.C, self.C, k])
        with torch.no_grad():
            for co in range(self.C):  # output channel
                for ci in range(self.C):  # input channel
                    for i in range(k - 2, -1, -1):
                        weight[co, ci, i] = weight[co, ci, i + 1] / 2

        self.weight = nn.Parameter(weight)

        # NOTE(review): dropout probability is deliberately divided by 5.
        self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()

    def register(self, name, tensor, lr=None):
        """Register a trainable parameter (or a buffer when lr == 0.0),
        attaching a per-parameter optimizer hint via a ``_optim`` attribute
        (no weight decay; custom learning rate when *lr* is given)."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))
            optim = {"weight_decay": 0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)

    def update_kernel(self, device):
        """
        Rebuild ``self.delay_kernel`` (shape [C_out, C_in, k]) from the
        current delays, kernel_type, and sigma.
        """
        t = self.t.to(device).view(1, 1, 1, -1)  # [1,1,1,k]
        d = self.d.to(device)  # [C_out, C_in, n_delay]

        # ---------- bump over the causal window ----------
        if self.kernel_type == 'gauss':
            bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2)
            bump = (bump - 1e-3).relu() + 1e-3
            bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)

        elif self.kernel_type == 'triangle':
            bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r':
            # Straight-through rounding: integer delay in the forward pass,
            # smooth gradient w.r.t. d in the backward pass.
            d_int = (d.round() - d).detach() + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r_temp':
            # Temperature-scaled straight-through rounding: the rounding
            # correction fades in as sigma grows past 1.
            scale = min(1.0, 1.0 / self.sigma)
            d_int = (d.round() - d).detach() * scale + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)  # [C_out, C_in, n_delay, k]
            # ------ harden the bump to a one-hot at eval time ------
            if not self.training:
                max_idx = bump.argmax(dim=-1, keepdim=True)  # peak position
                hard_mask = torch.zeros_like(bump)
                hard_mask.scatter_(-1, max_idx, 1.0)
                bump = bump * hard_mask
            # --------------------------------
        else:
            raise ValueError(f"Unknown kernel_type: {self.kernel_type}")

        # bump: [C_out, C_in, n_delay, k]; kept (detached) for inspection.
        self.bump = bump.detach().clone().to(device)

        # ---------- sum over the n_delay axis: [C_out, C_in, k] ----------
        bump_sum = bump.sum(dim=2)

        # ---------- final conv kernel ----------
        # weight: [C_out, C_in, k]
        self.delay_kernel = (self.weight * bump_sum).to(device)  # [C_out, C_in, k]

    def forward(self, x):
        """
        x: (T, B, C, N) — permuted internally; callers pass reshape(T,B,C,N)
        return: (T*B, C, N)
        """
        # rearrange so the time axis ends up last for the 1D convolution
        x = x.permute(0, 1, 3, 2).contiguous()  # (T, B, N, C)
        T, B, N, C = x.shape
        assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"
        x = x.permute(1, 2, 3, 0).contiguous()  # (B, N, C, T)

        # merge B*N as the conv batch
        x_reshaped = x.view(B * N, C, T)  # (B*N, C, T)
        device = x.device

        # rebuild the kernel from the current delays on every call
        self.update_kernel(device)  # -> [C_out, C_in, k]
        kernel = self.delay_kernel

        # left-pad so the convolution is causal (only past time steps seen)
        pad_left = (self.win_len - 1) * self.dilation
        x_padded = F.pad(x_reshaped, (pad_left, 0))  # (B*N, C, T+pad)

        # dense cross-channel convolution: groups=1
        y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1)  # (B*N, C, T)

        # back to the caller's (T*B, C, N) layout
        y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N)  # (T*B, C, N)

        return self.dropout(y)
|
models/qk_model_with_delay/delay_synaptic_inter_model.py
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode
|
| 4 |
+
from timm.models.layers import to_2tuple, trunc_normal_, DropPath
|
| 5 |
+
from timm.models.registry import register_model
|
| 6 |
+
from timm.models.vision_transformer import _cfg
|
| 7 |
+
from functools import partial
|
| 8 |
+
from timm.models import create_model
|
| 9 |
+
from .delay_synaptic_func_inter import DelayConv
|
| 10 |
+
|
| 11 |
+
__all__ = ['delay_QKFormer']
|
| 12 |
+
|
| 13 |
+
class MLP(nn.Module):
    """Two-layer spiking MLP built from 1x1 convolutions.

    Each layer is conv -> batch-norm -> multi-step LIF.  Time and batch are
    folded together for the 2D ops and unfolded again so the LIF neurons
    can integrate over T.  NOTE: the second reshape uses the input channel
    count C, so out_features is expected to equal in_features here.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1)
        self.mlp1_bn = nn.BatchNorm2d(hidden_features)
        self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1)
        self.mlp2_bn = nn.BatchNorm2d(out_features)
        self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, H, W = x.shape

        h = self.mlp1_conv(x.flatten(0, 1))
        h = self.mlp1_bn(h).reshape(T, B, self.c_hidden, H, W)
        h = self.mlp1_lif(h)

        h = self.mlp2_conv(h.flatten(0, 1))
        h = self.mlp2_bn(h).reshape(T, B, C, H, W)
        return self.mlp2_lif(h)
|
| 40 |
+
|
| 41 |
+
class Token_QK_Attention(nn.Module):
    """Spiking token-QK attention, linear in the token count N.

    Q comes from a 1x1 conv; K comes from a DelayConv along the time axis
    (the original 1x1 k_conv is kept commented out).  Q spikes are summed
    over each head's channel slice and passed through a v_threshold=0.5 LIF,
    producing a per-token spike gate that is multiplied elementwise into K —
    no N x N attention matrix is ever formed.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 k=16):
        # k is the DelayConv temporal window length.  qkv_bias/qk_scale/
        # attn_drop/proj_drop/sr_ratio are accepted for interface
        # compatibility but unused here.
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads

        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        # K projection replaced by a learnable-delay temporal convolution.
        self.k_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Lower threshold (0.5) so the summed-Q gate fires more readily.
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        """x: (T, B, C, H, W) -> (T, B, C, H, W) spike tensor."""
        T, B, C, H, W = x.shape

        x = x.flatten(3)  # (T, B, C, N) with N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N) for the 1D convs

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N)
        q_conv_out = self.q_lif(q_conv_out)
        # NOTE(review): unsqueeze(2) is immediately absorbed by the reshape.
        q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        # k_conv_out = self.k_conv(x_for_qkv)
        # DelayConv consumes (T, B, C, N) and returns (T*B, C, N).
        k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N))
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N)
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N)

        q = torch.sum(q, dim=3, keepdim=True)  # per-head spike count per token
        attn = self.attn_lif(q)                # spiking token gate
        x = torch.mul(attn, k)                 # gate broadcast over K channels

        x = x.flatten(2, 3)  # merge heads back into C
        x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W)
        x = self.proj_lif(x)
        return x
|
| 98 |
+
|
| 99 |
+
class Spiking_Self_Attention(nn.Module):
    """Spiking self-attention computed as Q (K^T V) — linear in tokens N.

    Q uses a 1x1 conv; K and V use DelayConv temporal projections (the 1x1
    conv originals are kept commented out).  All operands are spikes, so no
    softmax is used; a fixed 0.125 scale replaces 1/sqrt(head_dim).
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1,
                 k=16):
        # k is the DelayConv temporal window length; the remaining keyword
        # arguments are accepted for interface compatibility but unused.
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads  # NOTE(review): unused — scale is fixed below
        self.scale = 0.125
        self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.q_bn = nn.BatchNorm1d(dim)
        self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.k_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.k_bn = nn.BatchNorm1d(dim)
        self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False)
        self.v_proj_delay = DelayConv(in_c=self.dim, k=k)
        self.v_bn = nn.BatchNorm1d(dim)
        self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')
        self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy')

        self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1)
        self.proj_bn = nn.BatchNorm1d(dim)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.qkv_mp = nn.MaxPool1d(4)  # NOTE(review): defined but never used in forward

    def forward(self, x):
        """x: (T, B, C, H, W) -> (T, B, C, H, W) spike tensor."""
        T, B, C, H, W = x.shape

        x = x.flatten(3)  # (T, B, C, N), N = H*W
        T, B, C, N = x.shape
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q_conv_out = self.q_conv(x_for_qkv)
        q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous()
        q_conv_out = self.q_lif(q_conv_out)
        # split heads: (T, B, heads, N, C // heads)
        q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # DelayConv consumes (T, B, C, N), returns (T*B, C, N).
        k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N))
        k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous()
        k_conv_out = self.k_lif(k_conv_out)
        k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        v_conv_out = self.v_proj_delay(x_for_qkv.reshape(T,B,C,N))
        v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous()
        v_conv_out = self.v_lif(v_conv_out)
        v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2,
                                                                                                       4).contiguous()

        # K^T V first (head_dim x head_dim), then Q — avoids the N x N matrix.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale

        x = x.transpose(3, 4).reshape(T, B, C, N).contiguous()
        x = self.attn_lif(x)
        x = x.flatten(0, 1)
        # NOTE(review): proj_lif is applied to the folded (T*B, C, N) tensor
        # here, so its first (time) axis is T*B rather than T — confirm this
        # is intended.
        x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W)
        return x
|
| 169 |
+
|
| 170 |
+
class TokenSpikingTransformer(nn.Module):
    """Stage-1 block: token-QK attention followed by a conv MLP.

    Both sub-modules are spiking and are combined with plain additive
    residuals.  The remaining constructor arguments (qkv_bias, qk_scale,
    drop*, norm_layer, sr_ratio) exist only for signature compatibility
    with standard transformer blocks and are not forwarded.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.tssa = Token_QK_Attention(dim, num_heads)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        out = x + self.tssa(x)
        return out + self.mlp(out)
|
| 184 |
+
|
| 185 |
+
class SpikingTransformer(nn.Module):
    """Stage-2 block: spiking self-attention followed by a conv MLP.

    Attention and MLP outputs are added back through plain residuals.  The
    extra constructor arguments (qkv_bias, qk_scale, drop*, norm_layer,
    sr_ratio) are accepted only for signature compatibility and are not
    forwarded to the sub-modules.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.ssa = Spiking_Self_Attention(dim, num_heads)
        hidden = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden, drop=drop)

    def forward(self, x):
        out = x + self.ssa(x)
        return out + self.mlp(out)
|
| 199 |
+
|
| 200 |
+
class PatchEmbedInit(nn.Module):
    """Spiking stem: 8x spatial downsampling of raw event frames.

    Four conv+BN+LIF stages (the last three followed by stride-2 max-pools)
    take (T, B, in_channels, H, W) to (T, B, embed_dims, H/8, W/8); a
    stride-4 1x1 conv residual branch, tapped after the first pool, is
    added at the output.  The img_size/patch_size attributes are stored
    for bookkeeping only — forward() works on whatever size it receives.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims // 8)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj1_bn = nn.BatchNorm2d(embed_dims // 4)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.proj2_conv = nn.Conv2d(embed_dims//4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj2_bn = nn.BatchNorm2d(embed_dims // 2)
        self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj3_bn = nn.BatchNorm2d(embed_dims)
        self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Residual: taps the embed_dims // 4 feature map at 1/2 resolution
        # and matches the main path with a stride-4 1x1 conv.
        self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        T, B, C, H, W = x.shape
        # Downsampling + Res
        # x_feat = x.flatten(0, 1)
        # Fold time into batch for Conv2d, unfold for the multi-step LIF.
        x = self.proj_conv(x.flatten(0, 1))
        x = self.proj_bn(x).reshape(T, B, -1, H, W)
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj1_conv(x)
        x = self.proj1_bn(x)
        x = self.maxpool1(x)  # -> H/2, W/2
        _, _, H1, W1 = x.shape
        x = x.reshape(T, B, -1, H1, W1).contiguous()
        x = self.proj1_lif(x).flatten(0, 1).contiguous()

        x_feat = x  # residual tap at 1/2 resolution
        x = self.proj2_conv(x)
        x = self.proj2_bn(x)
        x = self.maxpool2(x)  # -> H/4, W/4
        _, _, H2, W2 = x.shape
        x = x.reshape(T, B, -1, H2, W2).contiguous()
        x = self.proj2_lif(x).flatten(0, 1).contiguous()

        x = self.proj3_conv(x)
        x = self.proj3_bn(x)
        x = self.maxpool3(x)  # -> H/8, W/8
        _, _, H3, W3 = x.shape
        x = x.reshape(T, B, -1, H3, W3).contiguous()
        x = self.proj3_lif(x)

        x_feat = self.proj_res_conv(x_feat)  # stride 4: 1/2 -> 1/8 resolution
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)
        x = x + x_feat  # shortcut

        return x
|
| 272 |
+
|
| 273 |
+
class PatchEmbeddingStage(nn.Module):
    """Inter-stage patch embedding: doubles channels, halves H and W.

    Main path: 3x3 conv (embed_dims//2 -> embed_dims) + BN + LIF, then
    3x3 conv + BN + stride-2 max-pool + LIF.  A stride-2 1x1 conv residual
    branch taken from the raw input is added to the result.

    NOTE(review): img_size_h/img_size_w/patch_size/in_channels are stored
    but not used by forward(), which works on whatever size it receives.
    """

    def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256):
        super().__init__()
        self.image_size = [img_size_h, img_size_w]
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.C = in_channels  # bookkeeping only; forward ignores it
        self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj_bn = nn.BatchNorm2d(embed_dims)
        self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False)
        self.proj4_bn = nn.BatchNorm2d(embed_dims)
        self.proj4_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

        # Residual: 1x1 stride-2 conv matches the main path's channel count
        # and halved resolution.
        self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False)
        self.proj_res_bn = nn.BatchNorm2d(embed_dims)
        self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy')

    def forward(self, x):
        """x: (T, B, embed_dims//2, H, W) -> (T, B, embed_dims, H/2, W/2)."""
        T, B, C, H, W = x.shape
        # Downsampling + Res

        x = x.flatten(0, 1).contiguous()  # fold T into batch for the 2D ops
        x_feat = x  # tap for the residual branch

        x = self.proj_conv(x)
        x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous()
        x = self.proj_lif(x).flatten(0, 1).contiguous()

        x = self.proj4_conv(x)
        x = self.proj4_bn(x)
        x = self.proj4_maxpool(x)  # H, W -> H/2, W/2
        _, _, H4, W4 = x.shape
        x = x.reshape(T, B, -1, H4, W4).contiguous()
        x = self.proj4_lif(x)

        x_feat = self.proj_res_conv(x_feat)
        x_feat = self.proj_res_bn(x_feat)
        _, _, Hres, Wres = x_feat.shape
        x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous()
        x_feat = self.proj_res_lif(x_feat)

        x = x + x_feat  # shortcut

        return x
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class vit_snn(nn.Module):
    """Two-stage spiking vision transformer.

    Stage 1 embeds the input at embed_dims // 2 channels and applies
    TokenSpikingTransformer blocks; stage 2 downsamples to embed_dims channels
    and applies SpikingTransformer blocks. Features are globally average-pooled
    over space and time before the linear classification head.

    Expects input of shape [N, T, C, H, W]; it is permuted to [T, N, C, H, W]
    internally.
    """

    def __init__(self,
                 img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11,
                 embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[6, 8, 6], sr_ratios=[8, 4, 2], T=4, pretrained_cfg=None, in_chans = 3, no_weight_decay = None
                 ):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.T = T
        # NOTE(review): the num_heads argument is overridden here with a fixed
        # value — confirm this is intentional.
        num_heads = [16, 16, 16]
        # stochastic depth decay rule; depths is expected to be an int here
        # (the registered factory passes depths=4), not the list default.
        dpr = [rate.item() for rate in torch.linspace(0, drop_path_rate, depths)]

        # --- stage 1: initial embedding at half the embedding width ---
        self.patch_embed1 = PatchEmbedInit(img_size_h=img_size_h,
                                           img_size_w=img_size_w,
                                           patch_size=patch_size,
                                           in_channels=in_channels,
                                           embed_dims=embed_dims // 2)

        self.stage1 = nn.ModuleList([TokenSpikingTransformer(
            dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        # --- stage 2: downsampling embedding + transformer blocks ---
        self.patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h,
                                                img_size_w=img_size_w,
                                                patch_size=patch_size,
                                                in_channels=in_channels,
                                                embed_dims=embed_dims)

        self.stage2 = nn.ModuleList([SpikingTransformer(
            dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias,
            qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j],
            norm_layer=norm_layer, sr_ratio=sr_ratios)
            for j in range(1)])

        # classification head
        self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by the training loop
        return {'pose_embed'}

    @torch.jit.ignore
    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        # positional embeddings are not used by this model
        return None

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, standard init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        """Run both stages and pool spatially. Returns [T, B, embed_dims]."""
        x = self.patch_embed1(x)
        for block in self.stage1:
            x = block(x)

        x = self.patch_embed2(x)
        for block in self.stage2:
            x = block(x)

        # global average pool over the spatial dimensions
        return x.flatten(3).mean(3)

    def forward(self, x):
        # [N, T, C, H, W] -> [T, N, C, H, W]
        x = x.permute(1, 0, 2, 3, 4)
        x = self.forward_features(x)
        # average over time steps before classification
        return self.head(x.mean(0))
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
@register_model
def delay_QKFormer(pretrained=False, **kwargs):
    """timm factory for the delay QKFormer configuration (101 classes, 2-channel
    event input, 256-dim embedding). `pretrained` is accepted for timm
    compatibility but no pretrained weights are loaded."""
    config = dict(
        patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4,
        in_channels=2, num_classes=101, qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=4, sr_ratios=1,
    )
    model = vit_snn(**config, **kwargs)
    model.default_cfg = _cfg()
    return model
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
from timm.models import create_model
|
| 433 |
+
|
| 434 |
+
if __name__ == '__main__':
    # smoke test: build the registered model on GPU and print a layer summary
    dummy = torch.randn(1, 1, 2, 128, 128).cuda()
    net = create_model(
        'delay_QKFormer',
        pretrained=False,
        drop_rate=0,
        drop_path_rate=0.1,
        drop_block_rate=None,
    ).cuda()
    net.eval()

    from torchinfo import summary
    summary(net, input_size=(1, 1, 2, 128, 128))
    # y = net(dummy)
    # print(y.shape)
    # print('Test Good!')
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
|