mohammed-aljafry committed on
Commit ddeccc7 · verified · 1 Parent(s): 47f9ae4

Add model architecture code

Files changed (1):
  modeling_interfuser.py +579 -0
modeling_interfuser.py ADDED
@@ -0,0 +1,579 @@
# modeling_interfuser.py

import torch
from torch import nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.utils.generic import ModelOutput
from functools import partial
import math
from collections import OrderedDict
import copy
from typing import Optional, List, Tuple, Union
from torch import Tensor
from dataclasses import dataclass
import numpy as np

# ==============================================================================
# Note: this file contains every definition needed to run the model.
# ==============================================================================

# --- Dummy backbone classes ---
# In real use, these must be replaced with the real networks from a library such as timm.
class DummyResNet(nn.Module):
    def __init__(self, name="r26", **kwargs):
        super().__init__()
        out_channels = 512 if name == "r18" else 2048
        self.features = nn.Sequential(
            nn.Conv2d(kwargs.get('in_chans', 3), out_channels, kernel_size=7, stride=2, padding=3),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.num_features = out_channels

    def forward(self, x):
        return [self.features(x)]

def resnet18d(**kwargs): return DummyResNet(name="r18", **kwargs)
def resnet26d(**kwargs): return DummyResNet(name="r26", **kwargs)
def resnet50d(**kwargs): return DummyResNet(name="r50", **kwargs)
def to_2tuple(x): return (x, x) if not isinstance(x, tuple) else x
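
# A minimal sketch of swapping in the real backbones (an assumption: timm is
# installed; create_model with features_only/out_indices is timm's standard API):
#
#   import timm
#   def resnet26d(**kwargs):
#       return timm.create_model("resnet26d", **kwargs)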

# --- All helper classes ---
# (HybridEmbed, PositionEmbeddingSine, TransformerEncoder, SpatialSoftmax, etc.)
# They are copied here in full.
class HybridEmbed(nn.Module):
    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                # use the tupled size; the raw img_size argument may be a bare int
                o = self.backbone(torch.zeros(1, in_chans, self.img_size[0], self.img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]
        x = self.proj(x)
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
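
# (HybridEmbed returns a pair: the projected spatial token map [B, C, H', W']
#  and a mean-pooled per-view global token [B, C, 1]; the model consumes both.)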

# ... (the rest of the helper classes are pasted here: PositionEmbeddingSine, Transformer*, ...)
# (For brevity they are not all shown again, but every one of them must live in this file.)
class PositionEmbeddingSine(nn.Module):
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos

# (paste the remaining helper classes here)
class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is not None:
            output = self.norm(output)
        return output

class SpatialSoftmax(nn.Module):
    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel
        if temperature:
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        if self.data_format == "NHWC":
            feature = (
                feature.transpose(1, 3)
                .transpose(2, 3)
                .view(-1, self.height * self.width)
            )
        else:
            feature = feature.view(-1, self.height * self.width)
        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        expected_xy = torch.cat([expected_x, expected_y], 1)
        feature_keypoints = expected_xy.view(-1, self.channel, 2)
        feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12
        feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12
        return feature_keypoints
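
# (SpatialSoftmax is a soft-argmax: each channel's heatmap becomes an expected
#  (x, y) position on a [-1, 1] grid; the trailing *12 rescaling maps that grid
#  into the metric BEV frame, keeping the constants of the original InterFuser code.)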

# ... (the remaining helpers, e.g. MultiPath_Generator, LinearWaypointsPredictor, etc., go here)
class MultiPath_Generator(nn.Module):
    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)
        self.tconv0 = nn.Sequential(nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True))
        self.tconv1 = nn.Sequential(nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True))
        self.tconv2 = nn.Sequential(nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False), nn.BatchNorm2d(192), nn.ReLU(True))
        self.tconv3 = nn.Sequential(nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True))
        self.tconv4_list = torch.nn.ModuleList([nn.Sequential(nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False), nn.Tanh()) for _ in range(6)])
        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        # measurements[:, :6] acts as a per-command selection mask; measurements[:, 6] is velocity
        mask = measurements[:, :6].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100)
        velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1).repeat(1, 32, 2, 2)
        n, d, c = x.shape
        x = x.transpose(1, 2).view(n, -1, 2, 2)
        x = torch.cat([x, velocity], dim=1)
        x = self.tconv0(x)
        x = self.tconv1(x)
        x = self.tconv2(x)
        x = self.tconv3(x)
        x = self.upsample(x)
        xs = torch.stack([self.tconv4_list[i](x) for i in range(6)], dim=1)
        x = torch.sum(xs * mask, dim=1)
        return self.spatial_softmax(x)

class LinearWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        bs, n, dim = x.shape
        x = (x + self.rank_embed).reshape(-1, dim)
        mask = measurements[:, :6].unsqueeze(-1).repeat(n, 1, 2)
        rs = [self.head_fc2_list[i](self.head_relu(self.head_fc1_list[i](x))) for i in range(6)]
        x = torch.sum(torch.stack(rs, 1) * mask, dim=1).view(bs, n, 2)
        return torch.cumsum(x, 1) if self.cumsum else x

class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)
        output, _ = self.gru(x, z)
        output = self.decoder(output.reshape(bs * self.waypoints, -1)).reshape(bs, self.waypoints, 2)
        return torch.cumsum(output, 1)

class GRUWaypointsPredictorWithCommand(nn.Module):
    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)])
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, n, dim = x.shape
        mask = measurements[:, :6, None, None].repeat(1, 1, self.waypoints, 2)
        z = self.encoder(target_point).unsqueeze(0)
        outputs = []
        for i in range(6):
            output, _ = self.grus[i](x, z)
            output = self.decoders[i](output.reshape(bs * self.waypoints, -1)).reshape(bs, self.waypoints, 2)
            outputs.append(torch.cumsum(output, 1))
        return torch.sum(torch.stack(outputs, 1) * mask, dim=1)
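
# (All waypoint heads above emit per-step offsets that are (optionally)
#  cumulative-summed into absolute ego-frame waypoints; the GRU variants seed
#  their hidden state with the encoded target point, and the *WithCommand
#  variant selects among six per-command GRU heads via the measurement mask.)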

class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        output = tgt
        intermediate = []
        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)
        return torch.stack(intermediate) if self.return_intermediate else output.unsqueeze(0)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def forward_pre(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        return src + self.dropout2(src2)

    def forward(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        return self.forward_pre(src, src_mask, src_key_padding_mask, pos) if self.normalize_before else self.forward_post(src, src_mask, src_key_padding_mask, pos)

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU, normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        return self.norm3(tgt)

    def forward_pre(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        return tgt + self.dropout3(tgt2)

    def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

def _get_activation_fn(activation):
    if activation == "relu": return F.relu
    if activation == "gelu": return F.gelu
    if activation == "glu": return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")

def build_attn_mask(mask_type, device):
    # (the "seperate_*" spellings are kept as-is: they are the string keys used below)
    mask = torch.ones((151, 151), dtype=torch.bool, device=device)
    if mask_type == "seperate_all":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, 101:151] = False
    elif mask_type == "seperate_view":
        mask[:50, :50] = False
        mask[50:67, 50:67] = False
        mask[67:84, 67:84] = False
        mask[84:101, 84:101] = False
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
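
# (For nn.MultiheadAttention, a boolean attn_mask entry of True means "not
#  allowed to attend"; the False blocks above therefore confine attention to
#  within each per-sensor token group of the 151-token sequence.)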

# --- Config class definition ---
class InterfuserConfig(PretrainedConfig):
    model_type = "interfuser"

    def __init__(self, img_size=224, embed_dim=256, enc_depth=6, dec_depth=6, num_heads=8,
                 rgb_backbone_name="r26", lidar_backbone_name="r18", use_different_backbone=True,
                 waypoints_pred_head="gru", **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.embed_dim = embed_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.num_heads = num_heads
        self.rgb_backbone_name = rgb_backbone_name
        self.lidar_backbone_name = lidar_backbone_name
        self.use_different_backbone = use_different_backbone
        self.waypoints_pred_head = waypoints_pred_head
        # Add any other required settings here
        self.patch_size = 8; self.in_chans = 3
        self.dim_feedforward = 2048; self.normalize_before = False; self.dropout = 0.1
        self.end2end = False; self.direct_concat = False
        self.separate_view_attention = False; self.separate_all_attention = False
        self.freeze_num = -1
        self.with_lidar = True; self.with_right_left_sensors = True; self.with_center_sensor = True
        self.traffic_pred_head_type = "det"; self.reverse_pos = True
        self.use_view_embed = True; self.use_mmad_pretrain = None
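
# A minimal sketch of wiring this pair into the Auto classes for local use
# (assumption: for Hub-hosted remote code, an auto_map entry in config.json
# plus trust_remote_code=True serves the same purpose):
#
#   from transformers import AutoConfig, AutoModel
#   AutoConfig.register("interfuser", InterfuserConfig)
#   AutoModel.register(InterfuserConfig, InterfuserForHuggingFace)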

# --- Model output class (ModelOutput) ---
@dataclass
class InterfuserOutput(ModelOutput):
    waypoints: Optional[torch.FloatTensor] = None
    traffic_predictions: Optional[torch.FloatTensor] = None
    is_junction: Optional[torch.FloatTensor] = None
    traffic_light_state: Optional[torch.FloatTensor] = None
    stop_sign: Optional[torch.FloatTensor] = None
    traffic_features: Optional[torch.FloatTensor] = None

# --- Original model definition (Interfuser) ---
# (the full Interfuser class must be pasted here)
class Interfuser(nn.Module):
    def __init__(self, config: InterfuserConfig):
        super().__init__()
        self.config = config

        # Pull the hyperparameters out of the config object
        embed_dim = config.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.num_features = self.embed_dim = embed_dim
        self.traffic_pred_head_type = config.traffic_pred_head_type
        self.waypoints_pred_head = config.waypoints_pred_head
        self.end2end = config.end2end

        # ... remaining init attributes from the original code
        self.direct_concat = config.direct_concat
        self.with_center_sensor = config.with_center_sensor
        self.with_right_left_sensors = config.with_right_left_sensors
        self.with_lidar = config.with_lidar
        self.use_view_embed = config.use_view_embed
        self.separate_view_attention = config.separate_view_attention
        self.separate_all_attention = config.separate_all_attention

        if self.direct_concat:
            in_chans = config.in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False
        else:
            in_chans = config.in_chans

        if self.separate_view_attention:
            # Build on CPU and register as a buffer so model.to(device) moves it
            # (a plain nn.Module has no .device attribute at construction time).
            self.register_buffer("attn_mask", build_attn_mask("seperate_view", device="cpu"))
        elif self.separate_all_attention:
            self.register_buffer("attn_mask", build_attn_mask("seperate_all", device="cpu"))
        else:
            self.attn_mask = None

        # Backbone definitions (DummyResNet is used as a stand-in here)
        # In real use, replace this with actual loading of the networks
        backbone_map = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}

        # RGB backbone
        rgb_backbone_class = backbone_map.get(config.rgb_backbone_name, resnet26d)
        self.rgb_backbone = rgb_backbone_class(pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4])

        # Lidar backbone
        if config.use_different_backbone:
            lidar_backbone_class = backbone_map.get(config.lidar_backbone_name, resnet26d)
            self.lidar_backbone = lidar_backbone_class(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
        else:
            self.lidar_backbone = self.rgb_backbone

        rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
        lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)

        self.rgb_patch_embed = rgb_embed_layer(
            img_size=config.img_size, patch_size=config.patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        self.lidar_patch_embed = lidar_embed_layer(
            img_size=config.img_size, patch_size=config.patch_size, in_chans=3, embed_dim=embed_dim
        )

        # ... remaining layer definitions from the original code
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        # Waypoints generator
        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(embed_dim + 32, embed_dim, 10)
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=False)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)

        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)

        self.traffic_pred_head = nn.Sequential(
            nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()
        )
        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
        self.encoder = TransformerEncoder(encoder_layer, config.enc_depth, None)

        decoder_layer = TransformerDecoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(decoder_layer, config.dec_depth, decoder_norm, return_intermediate=False)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)

    # ... forward_features and forward must be copied here in full from the original code
    def forward_features(self, front_image, left_image, right_image, front_center_image, lidar, measurements):
        features = []

        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = front_image_token + self.view_embed[:, :, 0:1, :] + self.position_encoding(front_image_token)
        else:
            front_image_token = front_image_token + self.position_encoding(front_image_token)
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = front_image_token_global + self.view_embed[:, :, 0, :] + self.global_embed[:, :, 0:1]
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = left_image_token + self.view_embed[:, :, 1:2, :] + self.position_encoding(left_image_token)
            else:
                left_image_token = left_image_token + self.position_encoding(left_image_token)
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = left_image_token_global + self.view_embed[:, :, 1, :] + self.global_embed[:, :, 1:2]
            left_image_token_global = left_image_token_global.permute(2, 0, 1)

            right_image_token, right_image_token_global = self.rgb_patch_embed(right_image)
            if self.use_view_embed:
                right_image_token = right_image_token + self.view_embed[:, :, 2:3, :] + self.position_encoding(right_image_token)
            else:
                right_image_token = right_image_token + self.position_encoding(right_image_token)
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = right_image_token_global + self.view_embed[:, :, 2, :] + self.global_embed[:, :, 2:3]
            right_image_token_global = right_image_token_global.permute(2, 0, 1)

            features.extend([left_image_token, left_image_token_global, right_image_token, right_image_token_global])

        if self.with_center_sensor:
            front_center_image_token, front_center_image_token_global = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = front_center_image_token + self.view_embed[:, :, 3:4, :] + self.position_encoding(front_center_image_token)
            else:
                front_center_image_token = front_center_image_token + self.position_encoding(front_center_image_token)
            front_center_image_token = front_center_image_token.flatten(2).permute(2, 0, 1)
            front_center_image_token_global = front_center_image_token_global + self.view_embed[:, :, 3, :] + self.global_embed[:, :, 3:4]
            front_center_image_token_global = front_center_image_token_global.permute(2, 0, 1)
            features.extend([front_center_image_token, front_center_image_token_global])

        if self.with_lidar:
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = lidar_token + self.view_embed[:, :, 4:5, :] + self.position_encoding(lidar_token)
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = lidar_token_global + self.view_embed[:, :, 4, :] + self.global_embed[:, :, 4:5]
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])

        return torch.cat(features, 0)

    def forward(self, x):
        front_image, left_image, right_image, front_center_image = x["rgb"], x["rgb_left"], x["rgb_right"], x["rgb_center"]
        measurements, target_point, lidar = x["measurements"], x["target_point"], x["lidar"]

        if self.direct_concat:
            img_size = front_image.shape[-1]
            left_image = F.interpolate(left_image, size=(img_size, img_size))
            right_image = F.interpolate(right_image, size=(img_size, img_size))
            front_center_image = F.interpolate(front_center_image, size=(img_size, img_size))
            front_image = torch.cat([front_image, left_image, right_image, front_center_image], dim=1)

        features = self.forward_features(front_image, left_image, right_image, front_center_image, lidar, measurements)
        bs = front_image.shape[0]

        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(torch.ones((bs, 1, 20, 20), device=x["rgb"].device)).flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        # applies to both branches: (bs, C, T) -> (T, bs, C) for the decoder
        tgt = tgt.permute(2, 0, 1)

        memory = self.encoder(features, mask=self.attn_mask)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0].permute(1, 0, 2)

        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints

        if self.waypoints_pred_head != "heatmap":
            traffic_feature, is_junction_feature, waypoints_feature = hs[:, :400], hs[:, 400], hs[:, 401:411]
        else:
            traffic_feature, is_junction_feature, waypoints_feature = hs[:, :400], hs[:, 400], hs[:, 401:405]

        # Waypoints prediction
        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head.startswith("gru"):
            if "command" in self.waypoints_pred_head:
                waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
            else:
                waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head.startswith("linear"):
            waypoints = self.waypoints_generator(waypoints_feature, measurements)

        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(is_junction_feature)  # original code uses the same feature
        stop_sign = self.stop_sign_head(is_junction_feature)  # original code uses the same feature

        velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)

        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature

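# (Query layout in the non-end2end path: decoder queries 0-399 carry the 20x20
#  traffic map, query 400 feeds the junction / traffic-light / stop-sign heads,
#  and queries 401+ drive the waypoint predictor.)
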
# --- Main wrapper class definition ---
# This is the class that AutoModel will instantiate
class InterfuserForHuggingFace(PreTrainedModel):
    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        super().__init__(config)
        self.model = Interfuser(config)  # the original model is built here

    def _init_weights(self, module):
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()

    def forward(self, rgb: torch.FloatTensor, rgb_left: torch.FloatTensor, rgb_right: torch.FloatTensor,
                rgb_center: torch.FloatTensor, lidar: torch.FloatTensor, measurements: torch.FloatTensor,
                target_point: torch.FloatTensor, return_dict: Optional[bool] = None) -> Union[Tuple, InterfuserOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        inputs = {"rgb": rgb, "rgb_left": rgb_left, "rgb_right": rgb_right, "rgb_center": rgb_center,
                  "lidar": lidar, "measurements": measurements, "target_point": target_point}
        outputs = self.model(inputs)
        if self.config.end2end:
            if not return_dict:
                return (outputs,)
            return InterfuserOutput(waypoints=outputs)
        traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature = outputs
        if not return_dict:
            return outputs
        return InterfuserOutput(waypoints=waypoints, traffic_predictions=traffic, is_junction=is_junction,
                                traffic_light_state=traffic_light_state, stop_sign=stop_sign,
                                traffic_features=traffic_feature)
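
# A minimal smoke test under the dummy backbones above (a real checkpoint would
# need the actual timm backbones and trained weights; the measurement layout --
# six command flags followed by velocity -- is assumed from the indexing above):
if __name__ == "__main__":
    config = InterfuserConfig()
    model = InterfuserForHuggingFace(config).eval()
    dummy_img = torch.randn(1, 3, 224, 224)
    batch = dict(
        rgb=dummy_img, rgb_left=dummy_img, rgb_right=dummy_img, rgb_center=dummy_img,
        lidar=torch.randn(1, 3, 224, 224),
        measurements=torch.zeros(1, 7),
        target_point=torch.zeros(1, 2),
    )
    with torch.no_grad():
        out = model(**batch)
    print(out.waypoints.shape)  # expected: torch.Size([1, 10, 2])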