Thastp commited on
Commit
a94d598
·
verified ·
1 Parent(s): 47dc41d

Upload model

Browse files
Files changed (1) hide show
  1. modeling_rf_detr.py +2 -2
modeling_rf_detr.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
2
  from typing import List, Dict
3
 
4
  import torch
5
- from torchvision.transforms import Resize
6
  from transformers import PreTrainedModel
7
  from transformers.utils import ModelOutput
8
  from rfdetr import RFDETRBase, RFDETRLarge
@@ -118,7 +118,7 @@ class RFDetrModelForObjectDetection(PreTrainedModel):
118
  label["labels"] = label["labels"].to(self.config.device)
119
 
120
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
121
- resize = Resize((self.config.resolution, self.config.resolution), antialias=False) # antialias set to False for onnx export
122
 
123
  if labels is not None:
124
  self.validate_labels(labels)
 
2
  from typing import List, Dict
3
 
4
  import torch
5
+ from torchvision.transforms import Resize, InterpolationMode
6
  from transformers import PreTrainedModel
7
  from transformers.utils import ModelOutput
8
  from rfdetr import RFDETRBase, RFDETRLarge
 
118
  label["labels"] = label["labels"].to(self.config.device)
119
 
120
  def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
121
+ resize = Resize((self.config.resolution, self.config.resolution), interpolation=InterpolationMode.NEAREST) # interpolation mode set to nearest for onnx export
122
 
123
  if labels is not None:
124
  self.validate_labels(labels)