Update modelling_magiv2.py
Browse files- modelling_magiv2.py +4 -5
modelling_magiv2.py
CHANGED
|
@@ -1,8 +1,7 @@
|
|
| 1 |
-
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
| 5 |
-
ConditionalDetrHungarianMatcher,
|
| 6 |
inverse_sigmoid,
|
| 7 |
)
|
| 8 |
from .configuration_magiv2 import Magiv2Config
|
|
@@ -55,14 +54,14 @@ class Magiv2Model(PreTrainedModel):
|
|
| 55 |
self.class_labels_classifier = nn.Linear(
|
| 56 |
config.detection_model_config.d_model, config.detection_model_config.num_labels
|
| 57 |
)
|
| 58 |
-
|
| 59 |
config.detection_model_config.d_model, 1
|
| 60 |
)
|
| 61 |
-
self.matcher =
|
| 62 |
class_cost=config.detection_model_config.class_cost,
|
| 63 |
bbox_cost=config.detection_model_config.bbox_cost,
|
| 64 |
giou_cost=config.detection_model_config.giou_cost
|
| 65 |
-
)
|
| 66 |
|
| 67 |
def move_to_device(self, input):
|
| 68 |
return move_to_device(input, self.device)
|
|
|
|
| 1 |
+
from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel, ConditionalDetrConfig
|
| 2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
| 3 |
ConditionalDetrMLPPredictionHead,
|
| 4 |
ConditionalDetrModelOutput,
|
|
|
|
| 5 |
inverse_sigmoid,
|
| 6 |
)
|
| 7 |
from .configuration_magiv2 import Magiv2Config
|
|
|
|
| 54 |
self.class_labels_classifier = nn.Linear(
|
| 55 |
config.detection_model_config.d_model, config.detection_model_config.num_labels
|
| 56 |
)
|
| 57 |
+
self.is_this_text_a_dialogue = nn.Linear(
|
| 58 |
config.detection_model_config.d_model, 1
|
| 59 |
)
|
| 60 |
+
self.matcher = ConditionalDetrModel(ConditionalDetrConfig(
|
| 61 |
class_cost=config.detection_model_config.class_cost,
|
| 62 |
bbox_cost=config.detection_model_config.bbox_cost,
|
| 63 |
giou_cost=config.detection_model_config.giou_cost
|
| 64 |
+
))
|
| 65 |
|
| 66 |
def move_to_device(self, input):
|
| 67 |
return move_to_device(input, self.device)
|