HichTala committed on
Commit
a28e591
·
verified ·
1 Parent(s): 9c5dfaa

Delete diffusiondet/head.py

Browse files
Files changed (1) hide show
  1. diffusiondet/head.py +0 -386
diffusiondet/head.py DELETED
@@ -1,386 +0,0 @@
1
- import copy
2
- import math
3
- from dataclasses import astuple
4
-
5
- import torch
6
- from torch import nn
7
- from torch.nn.modules.transformer import _get_activation_fn
8
- from torchvision.ops import RoIAlign
9
-
10
- _DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
11
-
12
def convert_boxes_to_pooler_format(bboxes):
    """Flatten per-image proposal boxes into RoIAlign's expected row format.

    Args:
        bboxes: tensor of shape (batch_size, num_proposals, 4).

    Returns:
        Tensor of shape (batch_size * num_proposals, 5) whose rows are
        (batch_index, x1, y1, x2, y2), as consumed by torchvision RoIAlign.
    """
    batch_size, num_proposals = bboxes.shape[:2]
    flat_boxes = bboxes.view(batch_size * num_proposals, -1)
    # every image contributes the same number of proposals
    counts = torch.full((batch_size,), num_proposals).to(bboxes.device)
    image_ids = torch.arange(batch_size, dtype=flat_boxes.dtype, device=flat_boxes.device)
    batch_column = torch.repeat_interleave(image_ids, counts)
    return torch.cat([batch_column[:, None], flat_boxes], dim=1)
20
-
21
-
22
def assign_boxes_to_levels(
    bboxes,
    min_level,
    max_level,
    canonical_box_size,
    canonical_level,
):
    """Assign each box to an FPN pyramid level by its size.

    Args:
        bboxes: tensor of shape (batch_size, num_proposals, 4) in xyxy format.
        min_level, max_level: inclusive range of available pyramid levels.
        canonical_box_size: box size (sqrt of area) mapped to canonical_level.
        canonical_level: the level a canonical-size box is assigned to.

    Returns:
        int64 tensor of shape (batch_size * num_proposals,) holding the
        assigned level for each box, offset so min_level maps to 0.
    """
    flat = bboxes.view(bboxes.shape[0] * bboxes.shape[1], -1)
    box_area = (flat[:, 2] - flat[:, 0]) * (flat[:, 3] - flat[:, 1])
    sqrt_area = torch.sqrt(box_area)
    # Eqn.(1) in FPN paper; 1e-8 guards log2 against zero-area boxes
    raw_levels = torch.floor(canonical_level + torch.log2(sqrt_area / canonical_box_size + 1e-8))
    # clamp level to (min, max), in case the box size is too large or too small
    # for the available feature maps
    clamped = torch.clamp(raw_levels, min=min_level, max=max_level)
    return clamped.to(torch.int64) - min_level
38
-
39
-
40
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a (diffusion) timestep, transformer-style.

    Maps a 1-D tensor of timesteps to shape (len(time), dim), with the first
    half of the channels holding sines and the second half cosines over a
    geometric frequency ladder.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # log-spaced frequencies from 1 down to 1/10000
        freq_step = math.log(10000) / (half - 1)
        frequencies = torch.exp(torch.arange(half, device=time.device) * -freq_step)
        angles = time[:, None] * frequencies[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
53
-
54
-
55
class HeadDynamicK(nn.Module):
    """Cascade of DiffusionDet detection heads with diffusion-time conditioning.

    Builds ``config.num_heads`` deep copies of :class:`DiffusionDetHead` and runs
    them in sequence: each stage refines the (detached) box predictions of the
    previous one, conditioned on a sinusoidal embedding of the diffusion
    timestep ``t``.
    """

    def __init__(self, config, roi_input_shape):
        super().__init__()
        num_classes = config.num_labels

        ddet_head = DiffusionDetHead(config, roi_input_shape, num_classes)
        self.num_head = config.num_heads
        # each refinement stage is an independent deep copy of the same head
        self.head_series = nn.ModuleList([copy.deepcopy(ddet_head) for _ in range(self.num_head)])
        # when True, forward returns predictions from every stage (deep supervision)
        self.return_intermediate = config.deep_supervision

        # Gaussian random feature embedding layer for time
        self.hidden_dim = config.hidden_dim
        time_dim = self.hidden_dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.hidden_dim),
            nn.Linear(self.hidden_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # Init parameters.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        self.num_classes = num_classes
        if self.use_focal or self.use_fed_loss:
            prior_prob = config.prior_prob
            # bias such that the initial sigmoid classification score equals prior_prob
            # (standard focal-loss initialization, cf. RetinaNet)
            self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self._reset_parameters()

    def _reset_parameters(self):
        # init all parameters.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

            # initialize the bias for focal loss and fed loss.
            # Parameters whose last dim matches the class count are assumed to be
            # the classification-head bias/weight rows.
            if self.use_focal or self.use_fed_loss:
                if p.shape[-1] == self.num_classes or p.shape[-1] == self.num_classes + 1:
                    nn.init.constant_(p, self.bias_value)

    def forward(self, features, bboxes, t):
        """Run the head cascade.

        Args:
            features: list of per-level feature maps (each (batch, C, H, W)).
            bboxes: initial/noisy boxes, shape (batch, num_proposals, 4).
            t: diffusion timesteps, shape (batch_size,).

        Returns:
            (class_logits, pred_bboxes), stacked over all stages when
            ``return_intermediate`` is set, otherwise the last stage's outputs
            with a leading singleton stage dimension.
        """
        # assert t shape (batch_size)
        time = self.time_mlp(t)

        inter_class_logits = []
        inter_pred_bboxes = []

        # NOTE(review): bs is computed but never used below — confirm it can be dropped.
        bs = len(features[0])

        class_logits, pred_bboxes = None, None
        for head_idx, ddet_head in enumerate(self.head_series):
            class_logits, pred_bboxes, proposal_features = ddet_head(features, bboxes, time)
            if self.return_intermediate:
                inter_class_logits.append(class_logits)
                inter_pred_bboxes.append(pred_bboxes)
            # detach so each stage is trained against fixed inputs (no gradient
            # flow through earlier stages' box predictions)
            bboxes = pred_bboxes.detach()

        if self.return_intermediate:
            return torch.stack(inter_class_logits), torch.stack(inter_pred_bboxes)

        return class_logits[None], pred_bboxes[None]
117
-
118
-
119
- class DynamicConv(nn.Module):
120
- def __init__(self, config):
121
- super().__init__()
122
-
123
- self.hidden_dim = config.hidden_dim
124
- self.dim_dynamic = config.dim_dynamic
125
- self.num_dynamic = config.num_dynamic
126
- self.num_params = self.hidden_dim * self.dim_dynamic
127
- self.dynamic_layer = nn.Linear(self.hidden_dim, self.num_dynamic * self.num_params)
128
-
129
- self.norm1 = nn.LayerNorm(self.dim_dynamic)
130
- self.norm2 = nn.LayerNorm(self.hidden_dim)
131
-
132
- self.activation = nn.ReLU(inplace=True)
133
-
134
- pooler_resolution = config.pooler_resolution
135
- num_output = self.hidden_dim * pooler_resolution ** 2
136
- self.out_layer = nn.Linear(num_output, self.hidden_dim)
137
- self.norm3 = nn.LayerNorm(self.hidden_dim)
138
-
139
-
140
- def forward(self, pro_features, roi_features):
141
- features = roi_features.permute(1, 0, 2)
142
- parameters = self.dynamic_layer(pro_features).permute(1, 0, 2)
143
-
144
- param1 = parameters[:, :, :self.num_params].view(-1, self.hidden_dim, self.dim_dynamic)
145
- param2 = parameters[:, :, self.num_params:].view(-1, self.dim_dynamic, self.hidden_dim)
146
-
147
- features = torch.bmm(features, param1)
148
- features = self.norm1(features)
149
- features = self.activation(features)
150
-
151
- features = torch.bmm(features, param2)
152
- features = self.norm2(features)
153
- features = self.activation(features)
154
-
155
- features = features.flatten(1)
156
- features = self.out_layer(features)
157
- features = self.norm3(features)
158
- features = self.activation(features)
159
-
160
- return features
161
-
162
-
163
class DiffusionDetHead(nn.Module):
    """One refinement stage of the DiffusionDet head.

    Pools RoI features for the current boxes, mixes proposal features with
    self-attention and dynamic instance interaction, modulates them with the
    diffusion-time embedding, and predicts class logits plus box deltas.
    """

    def __init__(self, config, roi_input_shape, num_classes):
        super().__init__()

        dim_feedforward = config.dim_feedforward
        nhead = config.num_attn_heads
        dropout = config.dropout
        activation = config.activation
        in_features = config.roi_head_in_features
        pooler_resolution = config.pooler_resolution
        # RoIAlign scale per level is the inverse of that level's feature stride
        pooler_scales = tuple(1.0 / roi_input_shape[k]['stride'] for k in in_features)
        sampling_ratio = config.sampling_ratio

        self.hidden_dim = config.hidden_dim

        self.pooler = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
        )

        # dynamic.
        self.self_attn = nn.MultiheadAttention(self.hidden_dim, nhead, dropout=dropout)
        self.inst_interact = DynamicConv(config)

        # feed-forward network applied after instance interaction
        self.linear1 = nn.Linear(self.hidden_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, self.hidden_dim)

        self.norm1 = nn.LayerNorm(self.hidden_dim)
        self.norm2 = nn.LayerNorm(self.hidden_dim)
        self.norm3 = nn.LayerNorm(self.hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        # block time mlp: maps the (hidden_dim*4) time embedding to per-channel
        # scale and shift (FiLM-style modulation)
        self.block_time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(self.hidden_dim * 4, self.hidden_dim * 2))

        # cls. — stack of Linear(no bias) + LayerNorm + ReLU blocks
        num_cls = config.num_cls
        cls_module = list()
        for _ in range(num_cls):
            cls_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            cls_module.append(nn.LayerNorm(self.hidden_dim))
            cls_module.append(nn.ReLU(inplace=True))
        self.cls_module = nn.ModuleList(cls_module)

        # reg. — same structure as the classification tower
        num_reg = config.num_reg
        reg_module = list()
        for _ in range(num_reg):
            reg_module.append(nn.Linear(self.hidden_dim, self.hidden_dim, False))
            reg_module.append(nn.LayerNorm(self.hidden_dim))
            reg_module.append(nn.ReLU(inplace=True))
        self.reg_module = nn.ModuleList(reg_module)

        # pred.
        self.use_focal = config.use_focal
        self.use_fed_loss = config.use_fed_loss
        if self.use_focal or self.use_fed_loss:
            # sigmoid-based losses: no explicit background class
            self.class_logits = nn.Linear(self.hidden_dim, num_classes)
        else:
            # softmax classification: extra background slot
            self.class_logits = nn.Linear(self.hidden_dim, num_classes + 1)
        self.bboxes_delta = nn.Linear(self.hidden_dim, 4)
        self.scale_clamp = _DEFAULT_SCALE_CLAMP
        # (wx, wy, ww, wh) delta weights, as in Faster R-CNN box coders
        self.bbox_weights = (2.0, 2.0, 1.0, 1.0)

    def forward(self, features, bboxes, time_emb):
        """Refine one stage.

        Args:
            features: list of per-level feature maps.
            bboxes: (batch, num_proposals, 4) boxes in xyxy format.
            time_emb: (batch, hidden_dim*4) diffusion-time embedding.

        Returns:
            class_logits (batch, num_proposals, num_classes[+1]),
            pred_bboxes (batch, num_proposals, 4),
            obj_features (proposal features for potential downstream use).
        """
        bs, num_proposals = bboxes.shape[:2]

        # roi_feature.
        roi_features = self.pooler(features, bboxes)

        # per-proposal feature = spatial average of its RoI features
        pro_features = roi_features.view(bs, num_proposals, self.hidden_dim, -1).mean(-1)

        # (res*res, bs*num_proposals, hidden_dim) layout expected by DynamicConv
        roi_features = roi_features.view(bs * num_proposals, self.hidden_dim, -1).permute(2, 0, 1)

        # self_att. — proposals attend to each other within an image
        pro_features = pro_features.view(bs, num_proposals, self.hidden_dim).permute(1, 0, 2)
        pro_features2 = self.self_attn(pro_features, pro_features, value=pro_features)[0]
        pro_features = pro_features + self.dropout1(pro_features2)
        pro_features = self.norm1(pro_features)

        # inst_interact.
        pro_features = pro_features.view(num_proposals, bs, self.hidden_dim).permute(1, 0, 2).reshape(1, bs * num_proposals, self.hidden_dim)
        pro_features2 = self.inst_interact(pro_features, roi_features)
        pro_features = pro_features + self.dropout2(pro_features2)
        obj_features = self.norm2(pro_features)

        # obj_feature. — position-wise feed-forward with residual
        obj_features2 = self.linear2(self.dropout(self.activation(self.linear1(obj_features))))
        obj_features = obj_features + self.dropout3(obj_features2)
        obj_features = self.norm3(obj_features)

        fc_feature = obj_features.transpose(0, 1).reshape(bs * num_proposals, -1)

        # FiLM modulation by the diffusion timestep: feature * (scale+1) + shift
        scale_shift = self.block_time_mlp(time_emb)
        scale_shift = torch.repeat_interleave(scale_shift, num_proposals, dim=0)
        scale, shift = scale_shift.chunk(2, dim=1)
        fc_feature = fc_feature * (scale + 1) + shift

        cls_feature = fc_feature.clone()
        reg_feature = fc_feature.clone()
        for cls_layer in self.cls_module:
            cls_feature = cls_layer(cls_feature)
        for reg_layer in self.reg_module:
            reg_feature = reg_layer(reg_feature)
        class_logits = self.class_logits(cls_feature)
        bboxes_deltas = self.bboxes_delta(reg_feature)
        pred_bboxes = self.apply_deltas(bboxes_deltas, bboxes.view(-1, 4))

        return class_logits.view(bs, num_proposals, -1), pred_bboxes.view(bs, num_proposals, -1), obj_features

    def apply_deltas(self, deltas, boxes):
        """
        Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`.

        Args:
            deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1.
                deltas[i] represents k potentially different class-specific
                box transformations for the single box boxes[i].
            boxes (Tensor): boxes to transform, of shape (N, 4)

        Returns:
            Tensor of the same shape as `deltas` holding transformed xyxy boxes.
        """
        boxes = boxes.to(deltas.dtype)

        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        ctr_x = boxes[:, 0] + 0.5 * widths
        ctr_y = boxes[:, 1] + 0.5 * heights

        # deltas are normalized by these weights (Faster R-CNN convention)
        wx, wy, ww, wh = self.bbox_weights
        dx = deltas[:, 0::4] / wx
        dy = deltas[:, 1::4] / wy
        dw = deltas[:, 2::4] / ww
        dh = deltas[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=self.scale_clamp)
        dh = torch.clamp(dh, max=self.scale_clamp)

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(deltas)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w  # x1
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h  # y1
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w  # x2
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h  # y2

        return pred_boxes
319
-
320
-
321
class ROIPooler(nn.Module):
    """
    Region of interest feature map pooler that supports pooling from one or more
    feature maps.

    With several input levels, each box is routed to one pyramid level based on
    its size (FPN Eqn.(1)) and pooled by that level's RoIAlign.
    """

    def __init__(
        self,
        output_size,
        scales,
        sampling_ratio,
        canonical_box_size=224,
        canonical_level=4,
    ):
        super().__init__()

        # scales are 1/stride per level, so -log2(scale) recovers the level index
        min_level = -(math.log2(scales[0]))
        max_level = -(math.log2(scales[-1]))

        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        assert len(output_size) == 2 and isinstance(output_size[0], int) and isinstance(output_size[1], int)
        # strides must be exact powers of two and contiguous across levels
        assert math.isclose(min_level, int(min_level)) and math.isclose(max_level, int(max_level))
        assert (len(scales) == max_level - min_level + 1)
        assert 0 <= min_level <= max_level
        assert canonical_box_size > 0

        self.output_size = output_size
        self.min_level = int(min_level)
        self.max_level = int(max_level)
        self.canonical_level = canonical_level
        self.canonical_box_size = canonical_box_size
        # one RoIAlign per pyramid level, each with its own spatial scale
        self.level_poolers = nn.ModuleList(
            RoIAlign(
                output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True
            )
            for scale in scales
        )

    def forward(self, x, bboxes):
        """Pool box features from the appropriate pyramid level.

        Args:
            x: list of feature maps, one per level, each (batch, C, H, W).
            bboxes: (batch, num_proposals, 4) boxes in xyxy image coordinates.

        Returns:
            Tensor of shape (batch*num_proposals, C, output_size, output_size),
            ordered to match the flattened box order.
        """
        num_level_assignments = len(self.level_poolers)
        assert len(x) == num_level_assignments and len(bboxes) == x[0].size(0)

        pooler_fmt_boxes = convert_boxes_to_pooler_format(bboxes)

        # single level: no routing needed
        if num_level_assignments == 1:
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        level_assignments = assign_boxes_to_levels(
            bboxes, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
        )

        batches = pooler_fmt_boxes.shape[0]
        channels = x[0].shape[1]
        output_size = self.output_size[0]
        sizes = (batches, channels, output_size, output_size)

        # pre-allocated output; each level scatters its boxes' results back in place
        output = torch.zeros(sizes, dtype=x[0].dtype, device=x[0].device)

        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            inds = (level_assignments == level).nonzero(as_tuple=True)[0]
            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
            # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852
            output.index_put_((inds,), pooler(x_level, pooler_fmt_boxes_level))

        return output