Upload model
Browse files- modeling_rf_detr.py +2 -2
modeling_rf_detr.py
CHANGED
|
@@ -2,7 +2,7 @@ from dataclasses import dataclass
|
|
| 2 |
from typing import List, Dict
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
from torchvision.transforms import Resize
|
| 6 |
from transformers import PreTrainedModel
|
| 7 |
from transformers.utils import ModelOutput
|
| 8 |
from rfdetr import RFDETRBase, RFDETRLarge
|
|
@@ -118,7 +118,7 @@ class RFDetrModelForObjectDetection(PreTrainedModel):
|
|
| 118 |
label["labels"] = label["labels"].to(self.config.device)
|
| 119 |
|
| 120 |
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
|
| 121 |
-
resize = Resize((self.config.resolution, self.config.resolution),
|
| 122 |
|
| 123 |
if labels is not None:
|
| 124 |
self.validate_labels(labels)
|
|
|
|
| 2 |
from typing import List, Dict
|
| 3 |
|
| 4 |
import torch
|
| 5 |
+
from torchvision.transforms import Resize, InterpolationMode
|
| 6 |
from transformers import PreTrainedModel
|
| 7 |
from transformers.utils import ModelOutput
|
| 8 |
from rfdetr import RFDETRBase, RFDETRLarge
|
|
|
|
| 118 |
label["labels"] = label["labels"].to(self.config.device)
|
| 119 |
|
| 120 |
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor, labels=None, **kwargs) -> ModelOutput:
|
| 121 |
+
resize = Resize((self.config.resolution, self.config.resolution), interpolation=InterpolationMode.NEAREST) # interpolation mode set to nearest for onnx export
|
| 122 |
|
| 123 |
if labels is not None:
|
| 124 |
self.validate_labels(labels)
|