mohammed-aljafry committed on
Commit aa7de94 · verified · 1 Parent(s): 5587234

Upload model_definition.py with huggingface_hub
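For reference, a minimal sketch of how this kind of upload is typically done with huggingface_hub; the repo id and the prior `huggingface-cli login` are assumptions, not details taken from this commit:

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="model_definition.py",   # local file to push
    path_in_repo="model_definition.py",      # destination path inside the repo
    repo_id="mohammed-aljafry/interfuser",   # hypothetical repo id
    commit_message="Upload model_definition.py with huggingface_hub",
)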

Files changed (1)
  1. model_definition.py +1338 -0
model_definition.py ADDED
@@ -0,0 +1,1338 @@
# model_definition.py
# ============================================================================
# Core imports
# ============================================================================
import copy
import json
import logging
import math
import os
from collections import OrderedDict, deque
from functools import partial
from pathlib import Path
from typing import List, Optional

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch import Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# NOTE: TransformerEncoder/TransformerDecoder and their layer classes are
# defined locally below, so the torch.nn versions are intentionally not
# imported. resnet101d is imported because it is referenced for
# rgb_backbone_name="r101".
from timm.models.resnet import resnet18d, resnet26d, resnet50d, resnet101d

try:
    from timm.layers import trunc_normal_
except ImportError:
    from timm.models.layers import trunc_normal_

# Optional libraries (can be disabled if unavailable)
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

try:
    from tqdm import tqdm
except ImportError:
    # If tqdm is unavailable, fall back to a no-op wrapper
    def tqdm(iterable, *args, **kwargs):
        return iterable


# ============================================================================
# Helper functions
# ============================================================================
def to_2tuple(x):
    """Convert a value to a 2-tuple."""
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return (x, x)


# ============================================================================
# Model building blocks
# ============================================================================

class HybridEmbed(nn.Module):
    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the
    one used by the "Attention Is All You Need" paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class SpatialSoftmax(nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()

        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel

        if temperature:
            # Learnable temperature (was a bare `Parameter`, which is undefined here)
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0

        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        # Output:
        # (N, C*2) x_0 y_0 ...

        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)

        weight = F.softmax(feature / self.temperature, dim=-1)
        # `torch.autograd.Variable` is a no-op since PyTorch 0.4; use the buffers directly
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints


class MultiPath_Generator(nn.Module):
    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(
            nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
        )
        self.tconv2 = nn.Sequential(
            nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(True),
        )
        self.tconv3 = nn.Sequential(
            nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.tconv4_list = torch.nn.ModuleList(
            [
                nn.Sequential(
                    nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                    nn.Tanh(),
                )
                for _ in range(6)
            ]
        )

        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        mask = measurements[:, :6]
        mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1)
        velocity = velocity.repeat(1, 32, 2, 2)

        n, d, c = x.shape
        x = x.transpose(1, 2)
        x = x.view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = []
        for i in range(6):
            xt = self.tconv4_list[i](x)
            xs.append(xt)
        xs = torch.stack(xs, dim=1)
        x = torch.sum(xs * mask, dim=1)
        x = self.spatial_softmax(x)
        return x


class LinearWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        # input shape: (bs, 10, embed_dim)
        bs, n, dim = x.shape
        x = x + self.rank_embed
        x = x.reshape(-1, dim)

        mask = measurements[:, :6]
        mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2)

        rs = []
        for i in range(6):
            res = self.head_fc1_list[i](x)
            res = self.head_relu(res)
            res = self.head_fc2_list[i](res)
            rs.append(res)
        rs = torch.stack(rs, 1)
        x = torch.sum(rs * mask, dim=1)

        x = x.view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x


class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = output.reshape(bs * self.waypoints, -1)
        output = self.decoder(output).reshape(bs, self.waypoints, 2)
        output = torch.cumsum(output, 1)
        return output


class GRUWaypointsPredictorWithCommand(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None]
        mask = mask.repeat(1, 1, self.waypoints, 2)

        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = output.reshape(bs * self.waypoints, -1)
            output = self.decoders[i](output).reshape(bs, self.waypoints, 2)
            output = torch.cumsum(output, 1)
            outputs.append(output)
        outputs = torch.stack(outputs, 1)
        output = torch.sum(outputs * mask, dim=1)
        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # `activation` must be a module class (e.g. nn.ReLU, nn.GELU); the original
        # default was an instance, which `activation()` would have tried to call.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # As above, `activation` must be a module class, not an instance
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string."""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


def build_attn_mask(mask_type):
    # Token layout implied by the block sizes: [0:50] front view, [50:67] left,
    # [67:84] right, [84:101] front center, [101:151] lidar.
    mask = torch.ones((151, 151), dtype=torch.bool).cuda()
    if mask_type == "separate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "separate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask


class InterfuserModel(nn.Module):
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r50",
        lidar_backbone_name="r50",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=False,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=False,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=False,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor

        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("separate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("separate_all")
        else:
            self.attn_mask = None

        if use_different_backbone:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

            if use_mmad_pretrain:
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)

            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)

            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )

        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    # nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )

        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)

        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)

    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        features = []

        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            # Left view processing
            left_image_token, left_image_token_global = self.rgb_patch_embed(
                left_image
            )
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)

            # Right view processing
            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)

            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )

        if self.with_center_sensor:
            # Front center view processing
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )
            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )
            features.extend([front_center_image_token, front_center_image_token_global])

        if self.with_lidar:
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])

        features = torch.cat(features, 0)
        return features

    def forward(self, x):
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(
                left_image, size=(img_size, img_size)
            )
            right_image = torch.nn.functional.interpolate(
                right_image, size=(img_size, img_size)
            )
            front_center_image = torch.nn.functional.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )

        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]

        hs = hs.permute(1, 0, 2)  # (batch_size, N, C)
        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints

        if self.waypoints_pred_head != "heatmap":
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:411]
        else:
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:405]

        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(
                waypoints_feature, target_point, measurements
            )
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)

        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature

    def load_pretrained(self, model_path, strict=False):
        """
        Load pretrained weights (improved version).

        Args:
            model_path (str): path to the weights file
            strict (bool): if True, requires an exact key match
        """
        if not model_path or not Path(model_path).exists():
            logging.warning(f"Weights file not found: {model_path}")
            logging.info("Random weights will be used")
            return False

        try:
            logging.info(f"Attempting to load weights from: {model_path}")

            # Load the file, handling different checkpoint formats
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

            # Extract the state_dict from the various checkpoint layouts
            if isinstance(checkpoint, dict):
                if 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    logging.info("Found 'model_state_dict' in the file")
                elif 'state_dict' in checkpoint:
                    state_dict = checkpoint['state_dict']
                    logging.info("Found 'state_dict' in the file")
                elif 'model' in checkpoint:
                    state_dict = checkpoint['model']
                    logging.info("Found 'model' in the file")
                else:
                    state_dict = checkpoint
                    logging.info("Using the file directly as a state_dict")
            else:
                state_dict = checkpoint
                logging.info("Using the file directly as a state_dict")

            # Clean up key names (strip a leading 'module.' if present)
            clean_state_dict = OrderedDict()
            for k, v in state_dict.items():
                clean_key = k[7:] if k.startswith('module.') else k
                clean_state_dict[clean_key] = v

            # Load the weights
            missing_keys, unexpected_keys = self.load_state_dict(
                clean_state_dict, strict=strict
            )

            # Report loading status
            if missing_keys:
                logging.warning(
                    f"Missing keys ({len(missing_keys)}): {missing_keys[:5]}..."
                    if len(missing_keys) > 5
                    else f"Missing keys: {missing_keys}"
                )

            if unexpected_keys:
                logging.warning(
                    f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}..."
                    if len(unexpected_keys) > 5
                    else f"Unexpected keys: {unexpected_keys}"
                )

            if not missing_keys and not unexpected_keys:
                logging.info("✅ All weights loaded successfully")
            elif not strict:
                logging.info("✅ Weights loaded successfully (mismatches ignored)")

            return True

        except Exception as e:
            logging.error(f"❌ Error loading weights: {str(e)}")
            logging.info("Random weights will be used")
            return False


# ============================================================================
# Helper functions for loading the model
# ============================================================================

def load_and_prepare_model(config, device):
    """
    Create the model and load the pretrained weights.

    Args:
        config (dict): model settings and paths
        device (torch.device): target device (CPU/GPU)

    Returns:
        InterfuserModel: the loaded model
    """
    try:
        # Create the model
        model = InterfuserModel(**config.get('model_params', {})).to(device)
        logging.info(f"Model created on device: {device}")

        # Load the weights if a path is specified
        checkpoint_path = config.get('paths', {}).get('pretrained_weights')
        if checkpoint_path:
            success = model.load_pretrained(checkpoint_path, strict=False)
            if success:
                logging.info("✅ Model and weights loaded successfully")
            else:
                logging.warning("⚠️ Model created with random weights")
        else:
            logging.info("No weights path specified; random weights will be used")

        # Put the model in evaluation mode
        model.eval()

        return model

    except Exception as e:
        logging.error(f"Error creating the model: {str(e)}")
        raise


def create_model_config(model_path="model/best_model.pth", **model_params):
    """
    Create the model configuration using the correct settings from training.

    Args:
        model_path (str): path to the weights file
        **model_params: additional model parameters

    Returns:
        dict: model configuration
    """
    # The correct settings from the original training config
    training_config_params = {
        "img_size": 224,
        "embed_dim": 256,  # important: this value comes from the original training
        "enc_depth": 6,
        "dec_depth": 6,
        "rgb_backbone_name": 'r50',
        "lidar_backbone_name": 'r18',
        "waypoints_pred_head": 'gru',
        "use_different_backbone": True,
        "with_lidar": False,
        "with_right_left_sensors": False,
        "with_center_sensor": False,

        # Additional settings from the original config
        "multi_view_img_size": 112,
        "patch_size": 8,
        "in_chans": 3,
        "dim_feedforward": 2048,
        "normalize_before": False,
        "num_heads": 8,
        "dropout": 0.1,
        "end2end": False,
        "direct_concat": False,
        "separate_view_attention": False,
        "separate_all_attention": False,
        "freeze_num": -1,
        "traffic_pred_head_type": "det",
        "reverse_pos": True,
        "use_view_embed": False,
        "use_mmad_pretrain": None,
    }

    # Merge custom parameters with the training settings
    training_config_params.update(model_params)

    config = {
        'model_params': training_config_params,
        'paths': {
            'pretrained_weights': model_path
        },

        # Grid settings from training
        'grid_conf': {
            'h': 20, 'w': 20,
            'x_res': 1.0, 'y_res': 1.0,
            'y_min': 0.0, 'y_max': 20.0,
            'x_min': -10.0, 'x_max': 10.0,
        },

        # Additional information about the training run
        'training_info': {
            'original_project': 'Interfuser_Finetuning',
            'run_name': 'Finetune_Focus_on_Detection_v5',
            'focus': 'traffic_detection_and_iou',
            'backbone': 'ResNet50 + ResNet18',
            'trained_on': 'PDM_Lite_Carla'
        }
    }

    return config


def get_training_config():
    """
    Return the original training settings for reference.
    These settings document how the model was trained.
    """
    return {
        'project_info': {
            'project': 'Interfuser_Finetuning',
            'entity': None,
            'run_name': 'Finetune_Focus_on_Detection_v5'
        },
        'training': {
            'epochs': 50,
            'batch_size': 8,
            'num_workers': 2,
            'learning_rate': 1e-4,  # low learning rate for fine-tuning
            'weight_decay': 1e-2,
            'patience': 15,
            'clip_grad_norm': 1.0,
        },
        'loss_weights': {
            'iou': 2.0,           # top priority: bounding-box accuracy
            'traffic_map': 25.0,  # strong focus on object detection
            'waypoints': 1.0,     # baseline reference
            'junction': 0.25,     # tasks the model already performs well
            'traffic_light': 0.5,
            'stop_sign': 0.25,
        },
        'data_split': {
            'strategy': 'interleaved',
            'segment_length': 100,
            'validation_frequency': 10,
        },
        'transforms': {
            'use_data_augmentation': False,  # disabled to focus on the original data
        }
    }
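
Usage note: a minimal sketch of loading the model defined above and running one dummy forward pass. The checkpoint path is create_model_config's default; the tensor shapes are assumptions read off this file (224x224 RGB inputs, a measurements vector with at least 7 entries where [:6] is the command one-hot and [6] is speed, and a 2-D target point). load_pretrained falls back to random weights if the checkpoint file is missing.

import torch

config = create_model_config()  # defaults to model_path="model/best_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_and_prepare_model(config, device)

dummy = {
    "rgb": torch.zeros(1, 3, 224, 224, device=device),
    "rgb_left": torch.zeros(1, 3, 224, 224, device=device),
    "rgb_right": torch.zeros(1, 3, 224, 224, device=device),
    "rgb_center": torch.zeros(1, 3, 224, 224, device=device),
    "lidar": torch.zeros(1, 3, 224, 224, device=device),  # read but unused with with_lidar=False
    "measurements": torch.zeros(1, 7, device=device),     # 6 command flags + speed
    "target_point": torch.zeros(1, 2, device=device),
}
with torch.no_grad():
    traffic, waypoints, is_junction, light_state, stop_sign, feat = model(dummy)
print(waypoints.shape)  # expected: torch.Size([1, 10, 2]) for the "gru" head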