matejpekar committed · Commit d57f1a3 · verified · 1 parent: 3960669

Upload model
Files changed (4):
  1. config.json +1 -7
  2. configuration.py +20 -4
  3. model.safetensors +2 -2
  4. modeling.py +109 -125
config.json CHANGED
@@ -14,14 +14,8 @@
   ],
   "dim": 384,
   "dropout": 0.1,
-  "in_channels": [
-    768,
-    384,
-    192,
-    96
-  ],
   "model_type": "lsp_detr",
-  "num_classes": 2,
+  "num_classes": 1,
   "num_heads": 12,
   "num_radial_distances": 64,
   "query_block_size": 16,
configuration.py CHANGED
@@ -1,4 +1,7 @@
+from typing import Any
+
 from transformers import PretrainedConfig
+from transformers.utils.backbone_utils import verify_backbone_config_arguments
 
 
 class LSPDetrConfig(PretrainedConfig):
@@ -6,11 +9,14 @@ class LSPDetrConfig(PretrainedConfig):
 
     def __init__(
         self,
-        backbone="microsoft/swinv2-tiny-patch4-window16-256",
+        use_timm_backbone: bool = False,
+        use_pretrained_backbone: bool = True,
+        backbone: str = "microsoft/swinv2-tiny-patch4-window16-256",
+        backbone_kwargs: dict[str, Any] | None = None,
+        backbone_config: Any | None = None,
         dim: int = 384,
-        num_classes: int = 2,
+        num_classes: int = 1,
         depths: tuple[int, ...] = (6, 2, 2),
-        in_channels: tuple[int, ...] = (768, 384, 192, 96),
         query_block_size: int = 16,
         num_heads: int = 12,
         window_size: int = 16,
@@ -20,11 +26,21 @@
         dropout: float = 0.1,
         **kwargs,
     ) -> None:
+        if backbone_kwargs is None:
+            backbone_kwargs = {"out_features": ["stage1", "stage2", "stage3", "stage4"]}
+
+        verify_backbone_config_arguments(
+            use_timm_backbone=use_timm_backbone,
+            use_pretrained_backbone=use_pretrained_backbone,
+            backbone=backbone,
+            backbone_config=backbone_config,
+            backbone_kwargs=backbone_kwargs,
+        )
+
         self.backbone = backbone
         self.dim = dim
         self.num_classes = num_classes
         self.depths = depths
-        self.in_channels = in_channels
         self.query_block_size = query_block_size
         self.num_heads = num_heads
         self.window_size = window_size
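
For reference, a minimal usage sketch of the updated config (an illustration, not part of the commit; it assumes configuration.py is importable from the repo root, and the values mirror the defaults shown above):

# Construct the config with its new defaults.
from configuration import LSPDetrConfig

config = LSPDetrConfig()
assert config.num_classes == 1
assert config.backbone == "microsoft/swinv2-tiny-patch4-window16-256"

# backbone_kwargs defaults to all four Swin stages; explicit overrides are
# validated by verify_backbone_config_arguments() before attributes are set.
config = LSPDetrConfig(
    backbone_kwargs={"out_features": ["stage1", "stage2", "stage3", "stage4"]}
)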
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:34df87bc194a31875e6fad557746c2c5e94f027039211df60054807f61107bd0
-size 205650424
+oid sha256:6411cad5a0ebad05cbeb8324502f020a4a2a145fa4605dd09757cedb1018ad45
+size 205648888
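
The LFS pointer above pins the new weights by content hash; a minimal sketch for checking a downloaded copy against it (the local path is a placeholder):

# Compare a local model.safetensors against the sha256 recorded in the
# LFS pointer above.
import hashlib

def file_sha256(path: str) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "6411cad5a0ebad05cbeb8324502f020a4a2a145fa4605dd09757cedb1018ad45"
assert file_sha256("model.safetensors") == expected  # placeholder path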
modeling.py CHANGED
@@ -5,12 +5,65 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 from torch import Tensor, nn
 from torch.nn.utils import parametrize
-from transformers import PreTrainedModel, Swinv2Backbone
+from transformers import PreTrainedModel
 from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
+from transformers.utils.backbone_utils import load_backbone
 
 from .configuration import LSPDetrConfig
 
 
+class MLP(nn.Sequential):
+    """Very simple multi-layer perceptron."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act_layer: type[nn.Module] = nn.ReLU,
+        dropout: float = 0.0,
+    ) -> None:
+        assert num_layers > 1
+
+        layers = []
+        h = [hidden_dim] * (num_layers - 1)
+        for n, k in zip([input_dim, *h], h, strict=False):
+            layers.append(nn.Linear(n, k))
+            layers.append(act_layer())
+            if dropout > 0:
+                layers.append(nn.Dropout(dropout))
+
+        layers.append(nn.Linear(hidden_dim, output_dim))
+        super().__init__(*layers)
+
+
+class FeedForward(nn.Module):
+    """FeedForward module.
+
+    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
+    """
+
+    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
+        """Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
 def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
     freqs_x = []
     freqs_y = []
@@ -107,56 +160,11 @@ class CayleySTRING(nn.Module):
         return out.type_as(x)
 
 
-class MLP(nn.Sequential):
-    """Very simple multi-layer perceptron."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dim: int,
-        output_dim: int,
-        num_layers: int,
-        act_layer: type[nn.Module] = nn.ReLU,
-        dropout: float = 0.0,
-    ) -> None:
-        assert num_layers > 1
-
-        layers = []
-        h = [hidden_dim] * (num_layers - 1)
-        for n, k in zip([input_dim, *h], h, strict=False):
-            layers.append(nn.Linear(n, k))
-            layers.append(act_layer())
-            if dropout > 0:
-                layers.append(nn.Dropout(dropout))
-
-        layers.append(nn.Linear(hidden_dim, output_dim))
-        super().__init__(*layers)
-
-
-class FeedForward(nn.Module):
-    """FeedForward module.
-
-    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
-    """
-
-    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
-        """Initialize the FeedForward module.
-
-        Args:
-            dim (int): Input dimension.
-            hidden_dim (int): Hidden dimension of the feedforward layer.
-            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        """
-        super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+def maybe_pad(x: Tensor, window_size: int) -> Tensor:
+    h, w = x.shape[1:3]
+    pad_right = (window_size - w % window_size) % window_size
+    pad_bottom = (window_size - h % window_size) % window_size
+    return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
 
 
 @torch.autocast("cuda", enabled=False)
@@ -261,6 +269,13 @@ class WindowCrossAttention(nn.Module):
         self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coord: Tensor
     ) -> Tensor:
         b, h, w, c = tgt.shape
+
+        # pad to multiples of window size
+        tgt = maybe_pad(tgt, self.tgt_window_size)
+        src = maybe_pad(src, self.src_window_size)
+        tgt_coords = maybe_pad(tgt_coords, self.tgt_window_size)
+        src_coord = maybe_pad(src_coord, self.src_window_size)
+        h_pad, w_pad = tgt.shape[1:3]
         src_h, src_w = src.shape[1:3]
 
         # cyclic shift
@@ -286,7 +301,9 @@
         src = window_partition(src, self.src_window_size).flatten(1, 2)
         src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)
 
-        attn_mask = self.get_attn_mask(h, w, src_h, src_w, tgt.device, tgt.dtype)
+        attn_mask = self.get_attn_mask(
+            h_pad, w_pad, src_h, src_w, tgt.device, tgt.dtype
+        )
 
         if attn_mask is not None:
             attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
@@ -307,7 +324,7 @@
 
         # merge windows
         tgt = tgt.view(-1, self.tgt_window_size, self.tgt_window_size, c)
-        tgt = window_reverse(tgt, self.tgt_window_size, h, w)
+        tgt = window_reverse(tgt, self.tgt_window_size, h_pad, w_pad)
 
         # reverse cyclic shift
         if self.tgt_shift_size > 0:
@@ -315,7 +332,7 @@
                 tgt, shifts=(self.tgt_shift_size, self.tgt_shift_size), dims=(1, 2)
             )
 
-        return tgt
+        return tgt[:, :h, :w, :].contiguous()  # remove padding
 
 
 class WindowSelfAttention(nn.Module):
@@ -360,6 +377,11 @@ class WindowSelfAttention(nn.Module):
         """
        b, h, w, c = x.shape
 
+        # pad to multiples of window size
+        x = maybe_pad(x, self.window_size)
+        coords = maybe_pad(coords, self.window_size)
+        h_pad, w_pad = x.shape[1:3]
+
         # cyclic shift
         if self.shift_size > 0:
             x = x.roll(shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
@@ -371,7 +393,7 @@
         x = window_partition(x, self.window_size).flatten(1, 2)
         coords = window_partition(coords, self.window_size).flatten(1, 2)
 
-        attn_mask = self.get_attn_mask(h, w, x.device, x.dtype)
+        attn_mask = self.get_attn_mask(h_pad, w_pad, x.device, x.dtype)
         if attn_mask is not None:
             attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
 
@@ -390,13 +412,13 @@
 
         # merge windows
         x = x.view(-1, self.window_size, self.window_size, c)
-        x = window_reverse(x, self.window_size, h, w)
+        x = window_reverse(x, self.window_size, h_pad, w_pad)
 
         # reverse cyclic shift
         if self.shift_size > 0:
             x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
 
-        return x
+        return x[:, :h, :w, :].contiguous()  # remove padding
 
 
 class Block(nn.Module):
@@ -501,47 +523,39 @@ class FeatureSampling(nn.Module):
 class LSPTransformer(nn.Module):
     def __init__(
         self,
-        dim: int,
-        num_classes: int,
-        query_block_size: int,
-        in_channels: list[int],
-        depths: list[int],
-        num_heads: int,
-        window_size: int,
-        tgt_window_sizes: list[int],
-        src_window_sizes: list[int],
-        num_radial_distances: int,
-        dropout: float = 0.0,
+        config: LSPDetrConfig,
+        bottleneck_channels: int,
+        feature_channels: list[int],
     ) -> None:
         super().__init__()
 
-        self.dim = dim
-        self.query_block_size = query_block_size
-        self.num_radial_distances = num_radial_distances
+        self.query_block_size = config.query_block_size
+        self.num_radial_distances = config.num_radial_distances
 
-        bottleneck, *in_channels = in_channels
-        self.feature_sampling = FeatureSampling(bottleneck, dim)
+        self.feature_sampling = FeatureSampling(bottleneck_channels, config.dim)
 
         self.stages = nn.ModuleList()
-        for i, depth in enumerate(depths):
+        for i, depth in enumerate(config.depths):
             stage = Stage(
-                dim=dim,
-                src_dim=in_channels[i],
+                dim=config.dim,
+                src_dim=feature_channels[i],
                 depth=depth,
-                num_heads=num_heads,
-                window_size=window_size,
-                tgt_window_size=tgt_window_sizes[i],
-                src_window_size=src_window_sizes[i],
-                dropout=dropout,
+                num_heads=config.num_heads,
+                window_size=config.window_size,
+                tgt_window_size=config.tgt_window_sizes[i],
+                src_window_size=config.src_window_sizes[i],
+                dropout=config.dropout,
             )
             self.stages.append(stage)
 
-        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in in_channels)
+        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in feature_channels)
 
         # output heads
-        self.class_head = nn.Linear(dim, num_classes + 1, bias=False)
-        self.point_head = MLP(dim, dim, 2, 3)
-        self.radial_distances_head = MLP(dim, dim, num_radial_distances, 3)
+        self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
+        self.point_head = MLP(config.dim, config.dim, 2, 3)
+        self.radial_distances_head = MLP(
+            config.dim, config.dim, config.num_radial_distances, 3
+        )
 
         self.init_weights()
 
@@ -551,15 +565,13 @@
         nn.init.constant_(self.point_head[-1].bias, 0.0)
 
     def forward(
-        self, multi_scale_features: list[Tensor], height: int, width: int
+        self, bottleneck: Tensor, features: list[Tensor], height: int, width: int
     ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
-        *multi_scale_features, bottleneck = multi_scale_features
-
         b = bottleneck.size(0)
 
         src = []
         src_coords = []
-        for i, feature in enumerate(reversed(multi_scale_features)):
+        for i, feature in enumerate(reversed(features)):
             h, w = feature.shape[2:4]
             coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
             src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
@@ -610,10 +622,9 @@
             "logits": logits_list[-1],
             "points": ref_points_list[-1],
             "radial_distances": radial_distances_list[-1],
-            "polygons": self.get_polygons(
-                relative_to_absolute_points(ref_points, height, width).flatten(1, 2),
-                radial_distances_list[-1],
-            ),
+            "absolute_points": relative_to_absolute_points(
+                ref_points, height, width
+            ).flatten(1, 2),
             "aux_outputs": [
                 {
                     "logits": a,
@@ -629,19 +640,6 @@
             ],
         }
 
-    @torch.no_grad()
-    @torch.autocast("cuda", enabled=False)
-    def get_polygons(self, ref_points: Tensor, radial_distances: Tensor) -> Tensor:
-        t = torch.linspace(
-            0, 1, self.num_radial_distances + 1, device=ref_points.device
-        )[:-1]
-        cos = torch.cos(2 * torch.pi * t)
-        sin = torch.sin(2 * torch.pi * t)
-
-        radial_distances = radial_distances.expm1()
-        polar = radial_distances.unsqueeze(-1) * torch.stack([sin, cos], dim=-1)
-        return ref_points.unsqueeze(-2) + polar
-
 
 class LSPDetrModel(PreTrainedModel):
     config_class = LSPDetrConfig
@@ -649,25 +647,11 @@ class LSPDetrModel(PreTrainedModel):
     def __init__(self, config: LSPDetrConfig) -> None:
         super().__init__(config)
 
-        self.backbone = Swinv2Backbone.from_pretrained(
-            config.backbone, out_features=["stage1", "stage2", "stage3", "stage4"]
-        )
-
-        self.decode_head = LSPTransformer(
-            dim=config.dim,
-            num_classes=config.num_classes,
-            query_block_size=config.query_block_size,
-            in_channels=config.in_channels,
-            depths=config.depths,
-            num_heads=config.num_heads,
-            window_size=config.window_size,
-            tgt_window_sizes=config.tgt_window_sizes,
-            src_window_sizes=config.src_window_sizes,
-            num_radial_distances=config.num_radial_distances,
-            dropout=config.dropout,
-        )
+        self.backbone = load_backbone(config)
+        _, *feature_channels, bottleneck = self.backbone.num_features
+        self.decode_head = LSPTransformer(config, bottleneck, feature_channels[::-1])
 
     def forward(self, image: Tensor) -> dict[str, Tensor]:
-        features = self.backbone(image).feature_maps
+        *features, bottleneck = self.backbone(image).feature_maps
         height, width = image.shape[2:]
-        return self.decode_head(features, height, width)
+        return self.decode_head(bottleneck, features, height, width)
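
A minimal end-to-end sketch of the refactored model (an illustration under two assumptions: the repo id is a placeholder, and the repo's config.json maps the custom classes via auto_map so that trust_remote_code can resolve them):

# Load the custom model from the Hub and run one forward pass.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<namespace>/<repo>", trust_remote_code=True)
model.eval()

# maybe_pad() now pads feature maps to window-size multiples inside the
# attention blocks, so decode-head inputs need not be window-size multiples.
image = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    outputs = model(image)

# The "polygons" key is gone; pixel-space reference points are returned
# under "absolute_points".
print(outputs["logits"].shape, outputs["absolute_points"].shape)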