PITICHA commited on
Commit
14bd5ef
·
verified ·
1 Parent(s): 671a240

Create model/modeling_maskrcnn.py

Browse files
Files changed (1) hide show
  1. model/modeling_maskrcnn.py +42 -0
model/modeling_maskrcnn.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
3
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
4
+ from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
5
+
6
+ from transformers import PreTrainedModel
7
+ from .configuration_maskrcnn import MaskRCNNConfig
8
+
9
+
10
+ class MaskRCNNForInstanceSegmentation(PreTrainedModel):
11
+ config_class = MaskRCNNConfig
12
+
13
+ def __init__(self, config: MaskRCNNConfig):
14
+ super().__init__(config)
15
+ self.config = config
16
+
17
+ model = maskrcnn_resnet50_fpn(weights=None)
18
+
19
+ # box head
20
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
21
+ model.roi_heads.box_predictor = FastRCNNPredictor(
22
+ in_features, config.num_classes
23
+ )
24
+
25
+ # mask head
26
+ in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
27
+ model.roi_heads.mask_predictor = MaskRCNNPredictor(
28
+ in_features_mask,
29
+ config.hidden_layer,
30
+ config.num_classes
31
+ )
32
+
33
+ self.model = model
34
+
35
+ def forward(self, images, targets=None):
36
+ """
37
+ Train:
38
+ returns dict(loss_classifier, loss_box_reg, loss_mask, ...)
39
+ Eval:
40
+ returns List[Dict(boxes, labels, scores, masks)]
41
+ """
42
+ return self.model(images, targets)