mohammed-aljafry commited on
Commit
1ce0d83
·
verified ·
1 Parent(s): d68f054

Final fix v9: Correctly filter all unexpected kwargs during init

Browse files
Files changed (1) hide show
  1. interfuser.py +1086 -0
interfuser.py ADDED
@@ -0,0 +1,1086 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # -*- coding: utf-8 -*-
3
+ import torch, math, copy, inspect
4
+ from torch import nn, Tensor
5
+ from functools import partial
6
+ from typing import Optional, List
7
+ from collections import OrderedDict
8
+ from transformers import PreTrainedModel, PretrainedConfig
9
+
10
+ try:
11
+ from timm.models.layers import to_2tuple
12
+ from timm.models.resnet import resnet50d, resnet26d, resnet18d
13
+ except ImportError:
14
+ raise ImportError("This model requires timm. Please install with 'pip install timm==0.4.12' or a compatible version.")
15
+
16
+
17
class HybridEmbed(nn.Module):
    """Embed an image into patch tokens using a CNN backbone.

    Runs the backbone, projects its final feature map to ``embed_dim`` with a
    1x1 convolution, and additionally returns a globally average-pooled token.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # Probe the backbone with a dummy forward pass to discover the
            # spatial size and channel count of its last feature map.
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)  # restore the original train/eval mode
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        # Global token: spatial mean of the projected map, shape (B, C, 1).
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
61
+
62
+
63
class PositionEmbeddingSine(nn.Module):
    """DETR-style 2-D sinusoidal position encoding.

    Produces a ``(B, 2 * num_pos_feats, H, W)`` tensor of interleaved
    sine/cosine positional features for the given feature map.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        # Default scale is a full period of the sinusoid.
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor):
        bs, _, h, w = tensor.shape
        device = tensor.device

        # Row/column positions starting at 1 (cumsum over a mask of ones),
        # replicated across the batch.
        ones = torch.ones((bs, h, w), device=device)
        y_embed = ones.cumsum(1, dtype=torch.float32)
        x_embed = ones.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Geometric frequency ladder shared by the sin/cos pairs.
        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        freq = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        px = x_embed[..., None] / freq
        py = y_embed[..., None] / freq
        px = torch.stack((px[..., 0::2].sin(), px[..., 1::2].cos()), dim=4).flatten(3)
        py = torch.stack((py[..., 0::2].sin(), py[..., 1::2].cos()), dim=4).flatten(3)
        return torch.cat((py, px), dim=3).permute(0, 3, 1, 2)
103
+
104
+
105
class TransformerEncoder(nn.Module):
    """A stack of identical encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Thread the activations through every layer in order.
        out = src
        for enc_layer in self.layers:
            out = enc_layer(
                out,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        return out if self.norm is None else self.norm(out)
133
+
134
+
135
class SpatialSoftmax(nn.Module):
    """Soft-argmax over a feature map.

    For each channel, computes the softmax-weighted expected (x, y)
    coordinate of the activations on a [-1, 1] grid, then rescales the
    coordinates to the metric range used by the waypoint heatmap head.

    Fixes over the original:
      * ``nn.Parameter`` and ``torch.softmax`` replace the undefined module
        names ``Parameter`` and ``F``.
      * coordinate grids are built with torch (numpy was never imported).
      * ``.tranpose`` typo corrected in the NHWC branch.
      * the deprecated ``torch.autograd.Variable`` wrapper is removed.
    """

    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()

        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            # Learnable softmax temperature when an initial value is given.
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        # Flattened coordinate grids in [-1, 1]; matches the row-major
        # flattening of np.meshgrid(linspace(height), linspace(width)).
        pos_x = torch.linspace(-1.0, 1.0, self.height).repeat(self.width).float()
        pos_y = torch.linspace(-1.0, 1.0, self.width).repeat_interleave(self.height).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        # Output:
        # (N, C, 2) expected (x, y) per channel.
        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = torch.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        # Rescale normalized coordinates to metres in the ego frame
        # (y shifted so the ego sits at the bottom edge of the map).
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints
182
+
183
+
184
class MultiPath_Generator(nn.Module):
    """Heatmap-based waypoint head.

    Deconvolves decoder tokens into six spatial maps (one per driving
    command), selects one with the command mask, and soft-argmaxes it into
    keypoints via :class:`SpatialSoftmax`.
    """

    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        # One output deconv head per driving command (6 branches).
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )

        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        # measurements[:, :6] is used as a per-command selection mask
        # (presumably one-hot) and measurements[:, 6:7] as the velocity
        # scalar — TODO confirm against the caller.
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        # Reshape tokens (N, D, C) into a tiny 2x2 map and append the
        # tiled velocity channels before deconvolution.
        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        # Run all six command branches, then select via the mask.
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x
243
+
244
+
245
class LinearWaypointsPredictor(nn.Module):
    """Predict waypoints with six command-specific linear heads.

    The head matching the one-hot command in ``measurements[:, :6]`` is
    selected by masking; deltas are optionally accumulated into positions.
    """

    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        # x: (B, 10, D) — one token per waypoint slot.
        bs, n, dim = x.shape
        flat = (x + self.rank_embed).reshape(-1, dim)

        # Command mask broadcast over the 2 output coordinates.
        cmd_mask = torch.unsqueeze(measurements[:, :6], -1).repeat(n, 1, 2)

        per_head = [
            fc2(self.head_relu(fc1(flat)))
            for fc1, fc2 in zip(self.head_fc1_list, self.head_fc2_list)
        ]
        stacked = torch.stack(per_head, 1)
        selected = torch.sum(stacked * cmd_mask, dim=1)

        out = selected.view(bs, n, 2)
        return torch.cumsum(out, 1) if self.cumsum else out
276
+
277
+
278
class GRUWaypointsPredictor(nn.Module):
    """Roll out waypoints with a GRU conditioned on the goal point."""

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        batch = x.shape[0]
        # Encode the goal point as the GRU's initial hidden state.
        hidden0 = self.encoder(target_point).unsqueeze(0)
        seq_out, _ = self.gru(x, hidden0)
        deltas = self.decoder(seq_out.reshape(batch * self.waypoints, -1))
        deltas = deltas.reshape(batch, self.waypoints, 2)
        # Integrate per-step displacements into absolute positions.
        return torch.cumsum(deltas, 1)
295
+
296
class GRUWaypointsPredictorWithCommand(nn.Module):
    """Six command-specific GRU waypoint heads; the command mask selects one."""

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        batch = x.shape[0]
        # (B, 6, W, 2) command mask broadcast over waypoints and coords.
        cmd_mask = measurements[:, :6, None, None].repeat(1, 1, self.waypoints, 2)

        # Goal point becomes the shared initial hidden state.
        hidden0 = self.encoder(target_point).unsqueeze(0)
        branch_outputs = []
        for gru, decoder in zip(self.grus, self.decoders):
            seq_out, _ = gru(x, hidden0)
            deltas = decoder(seq_out.reshape(batch * self.waypoints, -1))
            deltas = deltas.reshape(batch, self.waypoints, 2)
            branch_outputs.append(torch.cumsum(deltas, 1))
        stacked = torch.stack(branch_outputs, 1)
        # Select the branch matching the active command.
        return torch.sum(stacked * cmd_mask, dim=1)
321
+
322
+
323
class TransformerDecoder(nn.Module):
    """A stack of decoder layers; optionally returns per-layer outputs.

    When ``return_intermediate`` is True the (normed) output of every layer
    is stacked along a new leading dimension; otherwise only the final
    output is returned, still with a leading singleton dimension.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # NOTE: assumes self.norm is not None when
                # return_intermediate is True (as constructed in Interfuser).
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last entry so it reflects the final norm.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
370
+
371
+
372
class TransformerEncoderLayer(nn.Module):
    """Single transformer encoder layer supporting pre- or post-norm.

    Args:
        d_model: token embedding dimension.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the MLP.
        dropout: dropout probability used throughout.
        activation: activation *class* (e.g. ``nn.ReLU``/``nn.GELU``); it is
            instantiated inside ``__init__``. The original default was an
            ``nn.ReLU()`` *instance*, which made ``activation()`` crash when
            the default was used — fixed by defaulting to the class.
        normalize_before: True selects the pre-norm variant.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Additive positional encoding (skipped when pos is None).
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Post-norm: attention/MLP first, LayerNorm after each residual.
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        # Pre-norm: LayerNorm before attention/MLP, residual adds raw input.
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
446
+
447
+
448
class TransformerDecoderLayer(nn.Module):
    """Single transformer decoder layer (self-attn, cross-attn, MLP).

    Args:
        d_model: token embedding dimension.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the MLP.
        dropout: dropout probability used throughout.
        activation: activation *class* (e.g. ``nn.ReLU``/``nn.GELU``); it is
            instantiated inside ``__init__``. The original default was an
            ``nn.ReLU()`` *instance*, which made ``activation()`` crash when
            the default was used — fixed by defaulting to the class.
        normalize_before: True selects the pre-norm variant.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        # Additive positional encoding (skipped when pos is None).
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Post-norm: self-attn -> cross-attn -> MLP, norm after each residual.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        # Pre-norm: LayerNorm before each sublayer, residual adds raw input.
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
573
+
574
+
575
+ def _get_clones(module, N):
576
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
577
+
578
+
579
+ def _get_activation_fn(activation):
580
+ if activation == "relu":
581
+ return F.relu
582
+ if activation == "gelu":
583
+ return F.gelu
584
+ if activation == "glu":
585
+ return F.glu
586
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
587
+
588
+
589
def build_attn_mask(mask_type, device="cuda"):
    """Build the 151x151 boolean attention mask between token groups.

    ``True`` entries block attention. The block boundaries (0:50, 50:67,
    67:84, 84:101, 101:151) presumably correspond to the per-view and
    query token groups assembled elsewhere — TODO confirm against the
    token layout in ``forward_features``.

    Args:
        mask_type: "seperate_all" or "seperate_view" (original spelling
            kept for caller compatibility); any other value returns an
            all-True mask.
        device: device to allocate the mask on. Previously hard-coded via
            ``.cuda()``; the default keeps the old behaviour while allowing
            CPU use.
    """
    mask = torch.ones((151, 151), dtype=torch.bool, device=device)
    if mask_type == "seperate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "seperate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
605
+
606
+
607
+ class Interfuser(nn.Module):
608
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,  # NOTE(review): appears unused in this method
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        norm_layer=None,  # NOTE(review): assigned below but not used afterwards
        dropout=0.1,
        end2end=False,
        direct_concat=True,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",  # NOTE(review): appears unused in this method
        freeze_num=-1,  # NOTE(review): appears unused in this method
        with_lidar=False,
        with_right_left_sensors=True,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
    ):
        """Assemble the InterFuser perception/planning transformer.

        Builds CNN patch embedders for the camera (and optionally lidar)
        streams, learned view/global/query embeddings, a transformer
        encoder/decoder, and the prediction heads selected by
        ``waypoints_pred_head`` and ``traffic_pred_head_type``.
        """
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            # All four camera views are stacked along the channel axis of a
            # single input, so the side/center streams are disabled.
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("seperate_all")
        else:
            self.attn_mask = None

        if use_different_backbone:
            # Separate RGB and lidar backbones (lidar never pretrained).
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

            if use_mmad_pretrain:
                # Load backbone weights from an MMAD checkpoint, stripping
                # the "backbone." prefix from state-dict keys.
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)

            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                # NOTE(review): resnet101d is not imported at the top of this
                # file — selecting "r101" would raise NameError. Verify.
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)

            # A single shared backbone embeds both RGB and lidar inputs.
            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        # Learned per-view embeddings (5 views) and per-view global tokens.
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        # Decoder query count depends on the head: 4 (end2end),
        # 400 map queries + 5 or + 11 auxiliary queries otherwise.
        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        # Binary classification heads for scene state.
        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()
829
+
830
    def reset_parameters(self):
        # Randomly initialise the learned embeddings (uniform in [0, 1)).
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)
835
+
836
    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        """Embed all enabled sensor streams into one token sequence.

        Each view is patch-embedded, given its view embedding (slot 0..4) and
        a sine position encoding, then flattened to (HW, B, C); each view's
        pooled global token gets the view + global embeddings. All tokens are
        concatenated along the sequence dimension.
        """
        features = []

        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        # (B, C, H, W) -> (HW, B, C) for the transformer.
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            # Left view processing (view slot 1)
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)

            # Right view processing (view slot 2)
            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)

            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )

        if self.with_center_sensor:
            # Front center view processing (view slot 3)
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )

            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )
            features.extend([front_center_image_token, front_center_image_token_global])

        if self.with_lidar:
            # Lidar BEV processing (view slot 4)
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])

        # Concatenate all streams along the sequence (token) dimension.
        features = torch.cat(features, 0)
        return features
972
+
973
+ def forward(self, x):
974
+ front_image = x["rgb"]
975
+ left_image = x["rgb_left"]
976
+ right_image = x["rgb_right"]
977
+ front_center_image = x["rgb_center"]
978
+ measurements = x["measurements"]
979
+ target_point = x["target_point"]
980
+ lidar = x["lidar"]
981
+
982
+ if self.direct_concat:
983
+ img_size = front_image.shape[-1]
984
+ left_image = torch.nn.functional.interpolate(
985
+ left_image, size=(img_size, img_size)
986
+ )
987
+ right_image = torch.nn.functional.interpolate(
988
+ right_image, size=(img_size, img_size)
989
+ )
990
+ front_center_image = torch.nn.functional.interpolate(
991
+ front_center_image, size=(img_size, img_size)
992
+ )
993
+ front_image = torch.cat(
994
+ [front_image, left_image, right_image, front_center_image], dim=1
995
+ )
996
+ features = self.forward_features(
997
+ front_image,
998
+ left_image,
999
+ right_image,
1000
+ front_center_image,
1001
+ lidar,
1002
+ measurements,
1003
+ )
1004
+
1005
+ bs = front_image.shape[0]
1006
+
1007
+ if self.end2end:
1008
+ tgt = self.query_pos_embed.repeat(bs, 1, 1)
1009
+ else:
1010
+ tgt = self.position_encoding(
1011
+ torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
1012
+ )
1013
+ tgt = tgt.flatten(2)
1014
+ tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
1015
+ tgt = tgt.permute(2, 0, 1)
1016
+
1017
+ memory = self.encoder(features, mask=self.attn_mask)
1018
+ hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
1019
+
1020
+ hs = hs.permute(1, 0, 2) # Batchsize , N, C
1021
+ if self.end2end:
1022
+ waypoints = self.waypoints_generator(hs, target_point)
1023
+ return waypoints
1024
+
1025
+ if self.waypoints_pred_head != "heatmap":
1026
+ traffic_feature = hs[:, :400]
1027
+ is_junction_feature = hs[:, 400]
1028
+ traffic_light_state_feature = hs[:, 400]
1029
+ stop_sign_feature = hs[:, 400]
1030
+ waypoints_feature = hs[:, 401:411]
1031
+ else:
1032
+ traffic_feature = hs[:, :400]
1033
+ is_junction_feature = hs[:, 400]
1034
+ traffic_light_state_feature = hs[:, 400]
1035
+ stop_sign_feature = hs[:, 400]
1036
+ waypoints_feature = hs[:, 401:405]
1037
+
1038
+ if self.waypoints_pred_head == "heatmap":
1039
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1040
+ elif self.waypoints_pred_head == "gru":
1041
+ waypoints = self.waypoints_generator(waypoints_feature, target_point)
1042
+ elif self.waypoints_pred_head == "gru-command":
1043
+ waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
1044
+ elif self.waypoints_pred_head == "linear":
1045
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1046
+ elif self.waypoints_pred_head == "linear-sum":
1047
+ waypoints = self.waypoints_generator(waypoints_feature, measurements)
1048
+
1049
+ is_junction = self.junction_pred_head(is_junction_feature)
1050
+ traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
1051
+ stop_sign = self.stop_sign_head(stop_sign_feature)
1052
+
1053
+ velocity = measurements[:, 6:7].unsqueeze(-1)
1054
+ velocity = velocity.repeat(1, 400, 32)
1055
+ traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
1056
+ traffic = self.traffic_pred_head(traffic_feature_with_vel)
1057
+ return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
1058
+
1059
+ # --- HF WRAPPER CLASSES (WITH THE FINAL FIX) ---
1060
class InterfuserConfig(PretrainedConfig):
    """Hugging Face config for InterFuser.

    Every keyword argument passed at construction time is mirrored onto the
    config object as an attribute, so the model wrapper can recover the raw
    model hyper-parameters via ``to_dict()``.
    """

    model_type = "interfuser"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Expose each provided hyper-parameter as a plain attribute.
        for name in kwargs:
            setattr(self, name, kwargs[name])
        # Route AutoModel (trust_remote_code) loading to the wrapper class.
        self.auto_map = {"AutoModel": "modeling_interfuser.InterfuserForHuggingFace"}
1066
+
1067
class InterfuserForHuggingFace(PreTrainedModel):
    """Hugging Face wrapper around the raw ``Interfuser`` model.

    ``transformers`` injects extra keys into the config (e.g.
    ``transformers_version``, ``auto_map``) that the original
    ``Interfuser.__init__`` does not accept; this wrapper filters the config
    dict down to the parameters the model actually declares before
    instantiating it.
    """

    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        super().__init__(config)
        init_args = config.to_dict()

        # Collect the parameter names Interfuser.__init__ declares, excluding
        # `self` and any *args/**kwargs placeholders — previously `self`
        # stayed in the expected set and could leak a spurious key through.
        sig = inspect.signature(Interfuser.__init__)
        expected = {
            name
            for name, param in sig.parameters.items()
            if name != "self"
            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        }

        # Keep only the keys the underlying model understands.
        final_args = {k: v for k, v in init_args.items() if k in expected}
        self.interfuser = Interfuser(**final_args)

    def forward(
        self,
        rgb,
        rgb_left,
        rgb_right,
        rgb_center,
        lidar,
        measurements,
        target_point,
        **kwargs,
    ):
        """Adapt HF-style keyword inputs to Interfuser's single-dict API."""
        inputs_dict = {
            "rgb": rgb,
            "rgb_left": rgb_left,
            "rgb_right": rgb_right,
            "rgb_center": rgb_center,
            "lidar": lidar,
            "measurements": measurements,
            "target_point": target_point,
        }
        return self.interfuser.forward(inputs_dict)