Upload folder using huggingface_hub
Browse files
config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"num_channels": 2048, "num_freq_bands": 6, "max_freq": 10, "enc_layers": 1, "num_queries": 100, "hidden_dim": 256, "enc_nheads_cross": 8, "nheads": 1, "dropout": 0, "self_per_cross_attn": 8, "num_classes": 91}
|
detection_perceiver_model_weights_pytorch.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4658c298c59a2d5915ac18550a963152cdfbfb973990a3a45bc3073f8ba8f05e
|
| 3 |
+
size 170587066
|
modeling_detection_perceiver.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
from torchvision.models import ResNet50_Weights
|
| 3 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
| 4 |
+
import torchvision
|
| 5 |
+
from perceiver_pytorch import Perceiver
|
| 6 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DetectionPerceiver(nn.Module):
    """Detection model chaining a backbone, a Perceiver and two prediction heads.

    The four components are supplied by the caller and stored under the
    attribute names ``backbone``, ``perceiver``, ``class_embed`` and
    ``bbox_embed``.
    """

    def __init__(self, backbone, perceiver, class_embed, bbox_embed):
        """Store the sub-modules that ``forward`` chains together.

        Args:
            backbone: Callable returning a mapping that contains the key
                ``'0'``.
            perceiver: Callable accepting ``data=...`` and
                ``return_embeddings=True`` keyword arguments.
            class_embed: Head applied to the perceiver output to produce
                ``'pred_logits'``.
            bbox_embed: Head applied to the perceiver output to produce
                ``'pred_boxes'``.
        """
        super().__init__()
        self.backbone = backbone
        self.perceiver = perceiver
        self.class_embed = class_embed
        self.bbox_embed = bbox_embed

    def forward(self, x):
        """Run the full pipeline on ``x``.

        Args:
            x: Input handed directly to the backbone (presumably a batch of
                images laid out as (batch, channels, height, width) --
                TODO confirm against the caller).

        Returns:
            A dict with two entries, ``'pred_logits'`` and ``'pred_boxes'``,
            both computed from the same perceiver output.
        """
        feature_maps = self.backbone(x)
        # The backbone output is indexed by the string key '0'; axis 1 of that
        # 4-D tensor is moved to the last position before the perceiver sees it
        # (presumably converting channels-first to channels-last).
        tokens = feature_maps['0'].permute(0, 2, 3, 1)
        latents = self.perceiver(data=tokens, return_embeddings=True)
        logits = self.class_embed(latents)
        boxes = self.bbox_embed(latents)
        return {'pred_logits': logits, 'pred_boxes': boxes}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_model(config):
    """Assemble a ``DetectionPerceiver`` from a configuration mapping.

    Args:
        config: Mapping providing the keys ``num_freq_bands``,
            ``num_channels``, ``max_freq``, ``enc_layers``, ``num_queries``,
            ``hidden_dim``, ``enc_nheads_cross``, ``nheads``,
            ``self_per_cross_attn``, ``dropout`` and ``num_classes``.

    Returns:
        A ``DetectionPerceiver`` wrapping a ResNet-50 ``layer4`` feature
        extractor, a ``Perceiver`` and the class / box prediction heads.

    Raises:
        KeyError: If ``config`` lacks one of the keys listed above.
    """
    # --- Backbone -----------------------------------------------------------
    # ResNet-50 with ImageNet (IMAGENET1K_V1) weights and frozen batch norm;
    # only the 'layer4' activation is exposed, under the output key "0".
    resnet = torchvision.models.resnet50(
        weights=ResNet50_Weights.IMAGENET1K_V1,
        norm_layer=FrozenBatchNorm2d,
    )
    backbone = IntermediateLayerGetter(resnet, return_layers={'layer4': "0"})

    # --- Perceiver ----------------------------------------------------------
    input_axis = 2
    freq_bands = config["num_freq_bands"]
    # Channels added by the Fourier position encoding; presumably mirrors the
    # count perceiver_pytorch computes internally for `input_axis` axes --
    # confirm against the installed library version.
    fourier_channels = input_axis * ((freq_bands * 2) + 1)

    in_channels = config["num_channels"]
    max_freq = config["max_freq"]
    depth = config["enc_layers"]
    num_latents = config["num_queries"]
    hidden_dim = config["hidden_dim"]
    cross_heads = config["enc_nheads_cross"]
    latent_heads = config["nheads"]
    # Per-head widths: the cross-attention width is derived from the backbone
    # channels plus the Fourier channels, the latent width from hidden_dim.
    cross_dim_head = (in_channels + fourier_channels) // cross_heads
    latent_dim_head = hidden_dim // latent_heads
    self_per_cross_attn = config["self_per_cross_attn"]
    dropout = config["dropout"]

    perceiver = Perceiver(
        input_channels=in_channels,
        input_axis=input_axis,
        num_freq_bands=freq_bands,
        max_freq=max_freq,
        depth=depth,
        num_latents=num_latents,
        latent_dim=hidden_dim,
        cross_heads=cross_heads,
        latent_heads=latent_heads,
        cross_dim_head=cross_dim_head,
        latent_dim_head=latent_dim_head,
        self_per_cross_attn=self_per_cross_attn,
        fourier_encode_data=True,
        attn_dropout=dropout,
        ff_dropout=dropout,
        final_classifier_head=False,
    )

    # --- Prediction heads ---------------------------------------------------
    # Box head: three linear layers with ReLU in between, ending in a sigmoid
    # so each of the 4 outputs lies in (0, 1).
    bbox_layers = [
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 4),
        nn.Sigmoid(),
    ]
    bbox_embed = nn.Sequential(*bbox_layers)

    # Class head: one output more than num_classes (presumably a
    # "no object" slot -- confirm against the training loss).
    class_embed = nn.Linear(hidden_dim, config["num_classes"] + 1)

    return DetectionPerceiver(backbone, perceiver, class_embed, bbox_embed)
|