Upload model
Browse files
configuration_rf_detr.py: +19 −0
modeling_rf_detr.py: +38 −1
configuration_rf_detr.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Dict, Literal, List, OrderedDict
|
|
| 3 |
import torch
|
| 4 |
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
|
|
|
| 6 |
|
| 7 |
### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
|
| 8 |
|
|
@@ -66,7 +67,25 @@ class RFDetrConfig(PretrainedConfig):
|
|
| 66 |
super().__init__(**kwargs)
|
| 67 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
class RFDetrOnnxConfig(ViTOnnxConfig):
|
|
|
|
|
|
|
| 70 |
@property
|
| 71 |
def inputs(self) -> Dict[str, Dict[int, str]]:
|
| 72 |
return OrderedDict(
|
|
|
|
| 3 |
import torch
|
| 4 |
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
from optimum.exporters.onnx.model_configs import ViTOnnxConfig
|
| 6 |
+
from optimum.utils import DummyVisionInputGenerator
|
| 7 |
|
| 8 |
### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
|
| 9 |
|
|
|
|
| 67 |
super().__init__(**kwargs)
|
| 68 |
|
| 69 |
|
| 70 |
+
class RFDetrDummyInputGenerator(DummyVisionInputGenerator):
|
| 71 |
+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
|
| 72 |
+
if input_name == "pixel_mask":
|
| 73 |
+
return self.random_mask_tensor(
|
| 74 |
+
shape=[self.batch_size, self.height, self.width],
|
| 75 |
+
framework=framework,
|
| 76 |
+
dtype="bool",
|
| 77 |
+
)
|
| 78 |
+
else:
|
| 79 |
+
return self.random_float_tensor(
|
| 80 |
+
shape=[self.batch_size, self.num_channels, self.height, self.width],
|
| 81 |
+
framework=framework,
|
| 82 |
+
dtype=float_dtype,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
class RFDetrOnnxConfig(ViTOnnxConfig):
|
| 87 |
+
DUMMY_INPUT_GENERATOR_CLASSES = (RFDetrDummyInputGenerator,)
|
| 88 |
+
|
| 89 |
@property
|
| 90 |
def inputs(self) -> Dict[str, Dict[int, str]]:
|
| 91 |
return OrderedDict(
|
modeling_rf_detr.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import List, Dict
|
|
|
|
| 3 |
|
| 4 |
import torch
|
| 5 |
from torchvision.transforms import Resize, InterpolationMode
|
|
@@ -12,6 +13,38 @@ from .configuration_rf_detr import RFDetrConfig
|
|
| 12 |
|
| 13 |
### ONLY WORKS WITH Transformers version 4.50.3 and python 3.11
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
@dataclass
|
| 16 |
class RFDetrObjectDetectionOutput(ModelOutput):
|
| 17 |
loss: torch.Tensor = None
|
|
@@ -118,7 +151,11 @@ class RFDetrModelForObjectDetection(PreTrainedModel):
|
|
| 118 |
label["labels"] = label["labels"].to(self.config.device)
|
| 119 |
|
| 120 |
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
if labels is not None:
|
| 124 |
self.validate_labels(labels)
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import List, Dict
|
| 3 |
+
import math
|
| 4 |
|
| 5 |
import torch
|
| 6 |
from torchvision.transforms import Resize, InterpolationMode
|
|
|
|
| 13 |
|
| 14 |
### ONLY WORKS WITH Transformers version 4.50.3 and python 3.11
|
| 15 |
|
| 16 |
+
# modified from https://github.com/roboflow/rf-detr/blob/develop/rfdetr/models/backbone/dinov2.py make_new_interpolated_pos_encoding
|
| 17 |
+
def _onnx_make_new_interpolated_pos_encoding(
|
| 18 |
+
position_embeddings, patch_size, height, width
|
| 19 |
+
):
|
| 20 |
+
|
| 21 |
+
num_positions = position_embeddings.shape[1] - 1
|
| 22 |
+
dim = position_embeddings.shape[-1]
|
| 23 |
+
height = height // patch_size
|
| 24 |
+
width = width // patch_size
|
| 25 |
+
|
| 26 |
+
class_pos_embed = position_embeddings[:, 0]
|
| 27 |
+
patch_pos_embed = position_embeddings[:, 1:]
|
| 28 |
+
|
| 29 |
+
# Reshape and permute
|
| 30 |
+
patch_pos_embed = patch_pos_embed.reshape(
|
| 31 |
+
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
|
| 32 |
+
)
|
| 33 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
| 34 |
+
|
| 35 |
+
# Use bilinear interpolation without antialias
|
| 36 |
+
patch_pos_embed = F.interpolate(
|
| 37 |
+
patch_pos_embed,
|
| 38 |
+
size=(height, width),
|
| 39 |
+
mode="bicubic",
|
| 40 |
+
align_corners=False,
|
| 41 |
+
antialias=False,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Reshape back
|
| 45 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
|
| 46 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
| 47 |
+
|
| 48 |
@dataclass
|
| 49 |
class RFDetrObjectDetectionOutput(ModelOutput):
|
| 50 |
loss: torch.Tensor = None
|
|
|
|
| 151 |
label["labels"] = label["labels"].to(self.config.device)
|
| 152 |
|
| 153 |
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
|
| 154 |
+
if torch.jit.is_tracing():
|
| 155 |
+
resize = Resize((self.config.resolution, self.config.resolution), interpolation=InterpolationMode.NEAREST) # interpolation mode set to nearest for onnx export
|
| 156 |
+
self.model.backbone[0].encoder.encoder.embeddings.interpolate_pos_encoding = lambda self_mod, embeddings, height, width : self.model.backbone[0].encoder.encoder.embeddings.position_embeddings # skip interpolation for onnx export
|
| 157 |
+
else:
|
| 158 |
+
resize = Resize((self.config.resolution, self.config.resolution))
|
| 159 |
|
| 160 |
if labels is not None:
|
| 161 |
self.validate_labels(labels)
|