Thastp commited on
Commit
bf40aa3
·
verified ·
1 Parent(s): a94d598

Upload model

Browse files
Files changed (2) hide show
  1. configuration_rf_detr.py +19 -0
  2. modeling_rf_detr.py +38 -1
configuration_rf_detr.py CHANGED
@@ -3,6 +3,7 @@ from typing import Dict, Literal, List, OrderedDict
3
  import torch
4
  from transformers.configuration_utils import PretrainedConfig
5
  from optimum.exporters.onnx.model_configs import ViTOnnxConfig
 
6
 
7
  ### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
8
 
@@ -66,7 +67,25 @@ class RFDetrConfig(PretrainedConfig):
66
  super().__init__(**kwargs)
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  class RFDetrOnnxConfig(ViTOnnxConfig):
 
 
70
  @property
71
  def inputs(self) -> Dict[str, Dict[int, str]]:
72
  return OrderedDict(
 
3
  import torch
4
  from transformers.configuration_utils import PretrainedConfig
5
  from optimum.exporters.onnx.model_configs import ViTOnnxConfig
6
+ from optimum.utils import DummyVisionInputGenerator
7
 
8
  ### modified from https://github.com/roboflow/rf-detr/blob/main/rfdetr/config.py
9
 
 
67
  super().__init__(**kwargs)
68
 
69
 
70
class RFDetrDummyInputGenerator(DummyVisionInputGenerator):
    """Dummy input generator for RF-DETR ONNX export.

    Extends the stock vision generator with a boolean ``pixel_mask``
    input alongside the usual float image tensor.
    """

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        # Padding mask: one boolean flag per pixel, no channel axis.
        if input_name == "pixel_mask":
            mask_shape = [self.batch_size, self.height, self.width]
            return self.random_mask_tensor(shape=mask_shape, framework=framework, dtype="bool")
        # Every other input (e.g. pixel_values) is a float NCHW image batch.
        image_shape = [self.batch_size, self.num_channels, self.height, self.width]
        return self.random_float_tensor(shape=image_shape, framework=framework, dtype=float_dtype)
86
  class RFDetrOnnxConfig(ViTOnnxConfig):
87
+ DUMMY_INPUT_GENERATOR_CLASSES = (RFDetrDummyInputGenerator,)
88
+
89
  @property
90
  def inputs(self) -> Dict[str, Dict[int, str]]:
91
  return OrderedDict(
modeling_rf_detr.py CHANGED
@@ -1,5 +1,6 @@
1
  from dataclasses import dataclass
2
  from typing import List, Dict
 
3
 
4
  import torch
5
  from torchvision.transforms import Resize, InterpolationMode
@@ -12,6 +13,38 @@ from .configuration_rf_detr import RFDetrConfig
12
 
13
  ### ONLY WORKS WITH Transformers version 4.50.3 and python 3.11
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  @dataclass
16
  class RFDetrObjectDetectionOutput(ModelOutput):
17
  loss: torch.Tensor = None
@@ -118,7 +151,11 @@ class RFDetrModelForObjectDetection(PreTrainedModel):
118
  label["labels"] = label["labels"].to(self.config.device)
119
 
120
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
121
- resize = Resize((self.config.resolution, self.config.resolution), interpolation=InterpolationMode.NEAREST) # interpolation mode set to nearest for onnx export
 
 
 
 
122
 
123
  if labels is not None:
124
  self.validate_labels(labels)
 
1
  from dataclasses import dataclass
2
  from typing import List, Dict
3
+ import math
4
 
5
  import torch
6
  from torchvision.transforms import Resize, InterpolationMode
 
13
 
14
  ### ONLY WORKS WITH Transformers version 4.50.3 and python 3.11
15
 
16
+ # modified from https://github.com/roboflow/rf-detr/blob/develop/rfdetr/models/backbone/dinov2.py make_new_interpolated_pos_encoding
17
+ def _onnx_make_new_interpolated_pos_encoding(
18
+ position_embeddings, patch_size, height, width
19
+ ):
20
+
21
+ num_positions = position_embeddings.shape[1] - 1
22
+ dim = position_embeddings.shape[-1]
23
+ height = height // patch_size
24
+ width = width // patch_size
25
+
26
+ class_pos_embed = position_embeddings[:, 0]
27
+ patch_pos_embed = position_embeddings[:, 1:]
28
+
29
+ # Reshape and permute
30
+ patch_pos_embed = patch_pos_embed.reshape(
31
+ 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
32
+ )
33
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
34
+
35
+ # Use bilinear interpolation without antialias
36
+ patch_pos_embed = F.interpolate(
37
+ patch_pos_embed,
38
+ size=(height, width),
39
+ mode="bicubic",
40
+ align_corners=False,
41
+ antialias=False,
42
+ )
43
+
44
+ # Reshape back
45
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
46
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
47
+
48
  @dataclass
49
  class RFDetrObjectDetectionOutput(ModelOutput):
50
  loss: torch.Tensor = None
 
151
  label["labels"] = label["labels"].to(self.config.device)
152
 
153
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
154
+ if torch.jit.is_tracing():
155
+ resize = Resize((self.config.resolution, self.config.resolution), interpolation=InterpolationMode.NEAREST) # interpolation mode set to nearest for onnx export
156
+ self.model.backbone[0].encoder.encoder.embeddings.interpolate_pos_encoding = lambda self_mod, embeddings, height, width : self.model.backbone[0].encoder.encoder.embeddings.position_embeddings # skip interpolation for onnx export
157
+ else:
158
+ resize = Resize((self.config.resolution, self.config.resolution))
159
 
160
  if labels is not None:
161
  self.validate_labels(labels)