sposhiy committed on
Commit
28e25c7
·
verified ·
1 Parent(s): 11abf87

Upload mlp.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlp.py +183 -0
mlp.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Counting-network heads that apply one shared MLP to every feature-map location:
3
+ a classification-only model (MLP_Classifier) and a classification + regression model (MLP).
4
+ '''
5
+
6
+ import time
7
+
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from huggingface_hub import PyTorchModelHubMixin
15
+
16
+ from util.misc import (NestedTensor, accuracy, get_world_size, interpolate,
17
+ is_dist_avail_and_initialized,
18
+ nested_tensor_from_tensor_list)
19
+
20
+ from .backbone import build_backbone
21
+ from .matcher import build_matcher_crowd
22
+ from .p2pnet import SetCriterion_Crowd
23
+ from .classification import *
24
+
25
+
26
class Linear(nn.Module):
    """Three-layer MLP applied independently to every location of a feature map.

    The same weights are shared by all locations because the layers are
    ``nn.Linear`` modules acting on the last (channel) dimension only.

    Args:
        in_feat: number of channels of the incoming feature map.
        out_feat: number of values predicted per location.
        hidden_feat: width of the two hidden layers. Defaults to 512, the
            value that used to be hard-coded.
    """

    def __init__(self, in_feat, out_feat, hidden_feat=512):
        super().__init__()

        # The attribute name ``lin`` is part of the state-dict keys, so it
        # must not be renamed if existing checkpoints are to keep loading.
        self.lin = nn.Sequential(
            nn.Linear(in_feat, hidden_feat),
            nn.ReLU(),
            nn.Linear(hidden_feat, hidden_feat),
            nn.ReLU(),
            nn.Linear(hidden_feat, out_feat),
        )

    def forward(self, x):
        """Run the MLP over every location of ``x``.

        Args:
            x: 4-D tensor whose dim 1 has size ``in_feat`` (presumably
                ``(batch, channels, height, width)`` -- confirm with caller).

        Returns:
            Tensor of shape ``(x.shape[0], x.shape[2] * x.shape[3], out_feat)``.
        """
        # Move dim 1 (channels) last so nn.Linear operates on it.
        x = x.permute([0, 2, 3, 1])
        # Collapse the two middle dims into a single "location" axis.
        x = x.flatten(start_dim=1, end_dim=2)

        return self.lin(x)
47
+
48
+
49
class MLP_Classifier(
    nn.Module,
    PyTorchModelHubMixin,
    # NOTE(review): the hub metadata below looks like unedited template
    # values ("your-repo-url" is a placeholder and "text-to-image" does not
    # describe a counting model). Left unchanged here because they affect the
    # generated model card -- confirm the intended values before publishing.
    repo_url="your-repo-url",
    pipeline_tag="text-to-image",
    license="mit",
):
    """MLP classifier-only predictor on top of FPN feature space.

    Args:
        backbone: module called on the input samples; its output is indexed
            with ``[0]`` .. ``[3]`` in ``forward``.
        num_classes: number of foreground classes; one extra class is added
            internally (``self.num_classes = num_classes + 1``).
        row: forwarded to ``AnchorPoints`` as ``row``.
        line: forwarded to ``AnchorPoints`` as ``line``.
    """

    def __init__(self, backbone, num_classes, row=2, line=2):
        super().__init__()

        self.vgg_backbone = backbone
        self.num_classes = num_classes + 1
        self.row = row
        self.line = line

        # Shared per-location classifier over the 256-channel decoder output.
        self.linear = Linear(
            in_feat=256,
            out_feat=self.num_classes,
        )

        self.anchor_points = AnchorPoints(
            pyramid_levels=[
                3,
            ],
            row=row,
            line=line,
        )

        self.fpn = Decoder(256, 512, 512)

    def forward(self, samples: NestedTensor):
        """Classify every feature-map location and pair it with anchor points.

        Args:
            samples: input batch passed both to the backbone and to the
                anchor-point generator.

        Returns:
            dict with ``"pred_logits"`` (output of the classifier head) and
            ``"pred_points"`` (the anchor points repeated once per batch item;
            no regression offset is applied in this model).
        """
        # get the backbone (vgg) features
        features = self.vgg_backbone(samples)
        # construct the feature space
        features_fpn = self.fpn([features[1], features[2], features[3]])
        batch_size = features[0].shape[0]

        # pass through the classifier
        classification = self.linear(features_fpn[1])
        # Anchor points are generated once and tiled along the batch axis.
        anchor_points = self.anchor_points(samples).repeat(batch_size, 1, 1)

        return {"pred_logits": classification, "pred_points": anchor_points}
96
+
97
class MLP(nn.Module, PyTorchModelHubMixin):
    """MLP model for both regression and classification tasks.

    Args:
        backbone: module called on the input samples; its output is indexed
            with ``[0]`` .. ``[3]`` in ``forward``.
        num_classes: number of foreground classes; one extra class is added
            internally (``self.num_classes = num_classes + 1``).
        row: forwarded to ``AnchorPoints`` as ``row``.
        line: forwarded to ``AnchorPoints`` as ``line``.
    """

    def __init__(self, backbone, num_classes, row=2, line=2):
        super().__init__()

        self.vgg_backbone = backbone
        self.num_classes = num_classes + 1
        self.row = row
        self.line = line

        # Shared per-location classification head.
        self.lin_class = Linear(
            in_feat=256,
            out_feat=self.num_classes,
        )

        # The regression head predicts an offset that is added to the anchor
        # points, so its width must match the point dimensionality rather
        # than the class count. It was previously ``self.num_classes``, which
        # only broadcast correctly when ``num_classes == 1``; for that case
        # the layer shape (and therefore existing checkpoints) is unchanged.
        # NOTE(review): assumes AnchorPoints yields (x, y) pairs -- confirm.
        self.lin_reg = Linear(
            in_feat=256,
            out_feat=2,
        )

        self.anchor_points = AnchorPoints(
            pyramid_levels=[
                3,
            ],
            row=row,
            line=line,
        )

        self.fpn = Decoder(256, 512, 512)

    def forward(self, samples: NestedTensor):
        """Predict class logits and point coordinates for a batch.

        Args:
            samples: input batch passed both to the backbone and to the
                anchor-point generator.

        Returns:
            dict with ``"pred_logits"`` (output of the classification head)
            and ``"pred_points"`` (anchor points plus the scaled regression
            output).
        """
        # get the backbone (vgg) features
        features = self.vgg_backbone(samples)
        # construct the feature space
        features_fpn = self.fpn([features[1], features[2], features[3]])
        batch_size = features[0].shape[0]

        # pass sample through classification and regression branch
        classification = self.lin_class(features_fpn[1])
        # Raw regression outputs are multiplied by a fixed factor of 100.
        regression = self.lin_reg(features_fpn[1]) * 100
        # Anchor points are generated once and tiled along the batch axis.
        # NOTE(review): the heads emit one prediction per feature-map
        # location; if AnchorPoints emits row * line points per location the
        # sum below will not broadcast unless row == line == 1 -- confirm.
        anchor_points = self.anchor_points(samples).repeat(batch_size, 1, 1)

        output_coord = regression + anchor_points

        return {"pred_logits": classification, "pred_points": output_coord}
148
+
149
def build_mlp(args, training):
    """Build the MLP counting model and, when training, its loss criterion.

    Args:
        args: configuration namespace. ``mlp_classifier`` / ``mlp`` select the
            model; ``num_classes``, ``row`` and ``line`` configure it. When
            ``training`` is true the ``*_loss_coef``, ``loss``, ``eos_coef``,
            ``ce_coef``, ``map_res`` and ``gauss_kernel_res`` attributes are
            read as well.
        training: if falsy, only the model is returned.

    Returns:
        ``model`` when ``training`` is falsy, otherwise ``(model, criterion)``.

    Raises:
        ValueError: if neither ``args.mlp_classifier`` nor ``args.mlp`` is set.
    """
    # model selection logic -- resolved before the backbone is built so that a
    # misconfiguration fails fast with a clear message instead of the
    # UnboundLocalError the missing ``else`` branch used to produce.
    if args.mlp_classifier:
        model_cls = MLP_Classifier
    elif args.mlp:
        model_cls = MLP
    else:
        raise ValueError(
            "build_mlp requires either args.mlp_classifier or args.mlp to be set"
        )

    backbone = build_backbone(args)
    model = model_cls(backbone, args.num_classes, args.row, args.line)

    if not training:
        return model

    # Per-loss weights, keyed by the loss names the criterion expects.
    weight_dict = {
        "loss_ce": args.label_loss_coef,
        "loss_point": args.point_loss_coef,
        "loss_dense": args.dense_loss_coef,
        "loss_distance": args.distance_loss_coef,
        "loss_count": args.count_loss_coef,
    }

    matcher = build_matcher_crowd(args)
    criterion = SetCriterion_Crowd(
        args.num_classes,
        matcher=matcher,
        weight_dict=weight_dict,
        eos_coef=args.eos_coef,
        ce_coef=args.ce_coef,
        map_res=args.map_res,
        gauss_kernel_res=args.gauss_kernel_res,
        losses=args.loss,
    )

    return model, criterion