Max-Ploter committed on
Commit
fe103f6
·
verified ·
1 Parent(s): c591d54

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"num_channels": 2048, "num_freq_bands": 6, "max_freq": 10, "enc_layers": 1, "num_queries": 100, "hidden_dim": 256, "enc_nheads_cross": 8, "nheads": 1, "dropout": 0, "self_per_cross_attn": 8, "num_classes": 91}
detection_perceiver_model_weights_pytorch.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4658c298c59a2d5915ac18550a963152cdfbfb973990a3a45bc3073f8ba8f05e
3
+ size 170587066
modeling_detection_perceiver.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torchvision.models import ResNet50_Weights
3
+ from torchvision.models._utils import IntermediateLayerGetter
4
+ import torchvision
5
+ from perceiver_pytorch import Perceiver
6
+ from torchvision.ops.misc import FrozenBatchNorm2d
7
+
8
+
9
class DetectionPerceiver(nn.Module):
    """Detection model chaining a backbone, a perceiver and two output heads.

    The four components are supplied by the caller and stored as attributes;
    this class only wires them together in ``forward``.

    Args:
        backbone: Callable whose result is indexed with the key ``'0'``.
        perceiver: Callable invoked with ``data=...`` and
            ``return_embeddings=True``.
        class_embed: Head applied to the perceiver output to produce
            ``'pred_logits'``.
        bbox_embed: Head applied to the perceiver output to produce
            ``'pred_boxes'``.
    """

    def __init__(self, backbone, perceiver, class_embed, bbox_embed):
        super().__init__()
        # Feature extraction followed by the latent encoder.
        self.backbone, self.perceiver = backbone, perceiver
        # Two heads that both consume the perceiver output.
        self.class_embed, self.bbox_embed = class_embed, bbox_embed

    def forward(self, x):
        """Run the full pipeline and return both predictions.

        Args:
            x: Input passed unchanged to the backbone.

        Returns:
            Dict with keys ``'pred_logits'`` and ``'pred_boxes'``.
        """
        feature_maps = self.backbone(x)
        # Move axis 1 to the last position before handing the features to the
        # perceiver (presumably channels-first -> channels-last; confirm
        # against the backbone's output layout).
        rearranged = feature_maps['0'].permute(0, 2, 3, 1)
        latents = self.perceiver(data=rearranged, return_embeddings=True)

        logits = self.class_embed(latents)
        boxes = self.bbox_embed(latents)
        return {'pred_logits': logits, 'pred_boxes': boxes}
22
+
23
+
24
def build_model(config):
    """Assemble a ``DetectionPerceiver`` from a configuration mapping.

    Args:
        config: Mapping providing the keys ``num_channels``,
            ``num_freq_bands``, ``max_freq``, ``enc_layers``, ``num_queries``,
            ``hidden_dim``, ``enc_nheads_cross``, ``nheads``,
            ``self_per_cross_attn``, ``dropout`` and ``num_classes``.

    Returns:
        A ``DetectionPerceiver`` wrapping a ResNet-50 ``layer4`` feature
        extractor, a ``Perceiver`` and the class / box heads.
    """
    # Backbone: ResNet-50 with ImageNet weights and frozen batch norm,
    # exposing only ``layer4`` under the key "0".
    resnet = torchvision.models.resnet50(
        weights=ResNet50_Weights.IMAGENET1K_V1,
        norm_layer=FrozenBatchNorm2d,
    )
    backbone = IntermediateLayerGetter(resnet, return_layers={'layer4': "0"})

    # Perceiver: number of extra channels added by the Fourier position
    # encoding, used below to size the cross-attention heads.
    input_axis = 2
    freq_bands = config["num_freq_bands"]
    fourier_channels = input_axis * ((freq_bands * 2) + 1)

    in_channels = config["num_channels"]
    latent_dim = config["hidden_dim"]
    cross_heads = config["enc_nheads_cross"]
    latent_heads = config["nheads"]
    dropout = config["dropout"]

    perceiver_kwargs = {
        "input_channels": in_channels,
        "input_axis": input_axis,
        "num_freq_bands": freq_bands,
        "max_freq": config["max_freq"],
        "depth": config["enc_layers"],
        "num_latents": config["num_queries"],
        "latent_dim": latent_dim,
        "cross_heads": cross_heads,
        "latent_heads": latent_heads,
        "cross_dim_head": (in_channels + fourier_channels) // cross_heads,
        "latent_dim_head": latent_dim // latent_heads,
        "self_per_cross_attn": config["self_per_cross_attn"],
        "fourier_encode_data": True,
        "attn_dropout": dropout,
        "ff_dropout": dropout,
        "final_classifier_head": False,
    }
    perceiver = Perceiver(**perceiver_kwargs)

    # Embeddings: a three-layer MLP ending in a sigmoid for the 4 box values,
    # built before the class head to keep the original construction order.
    box_layers = [
        nn.Linear(latent_dim, latent_dim),
        nn.ReLU(),
        nn.Linear(latent_dim, latent_dim),
        nn.ReLU(),
        nn.Linear(latent_dim, 4),
        nn.Sigmoid(),
    ]
    bbox_embed = nn.Sequential(*box_layers)

    # One output more than ``num_classes`` (presumably a "no object" slot;
    # confirm against the training loss).
    class_embed = nn.Linear(latent_dim, config["num_classes"] + 1)

    return DetectionPerceiver(backbone, perceiver, class_embed, bbox_embed)