mohammed-aljafry committed on
Commit
cfc2547
·
verified ·
1 Parent(s): e170a78

Add model architecture code

Browse files
Files changed (1) hide show
  1. modeling_interfuser.py +326 -78
modeling_interfuser.py CHANGED
@@ -1,5 +1,3 @@
1
- # modeling_interfuser.py
2
-
3
  import torch
4
  from torch import nn
5
  import torch.nn.functional as F
@@ -9,74 +7,170 @@ from functools import partial
9
  import math
10
  from collections import OrderedDict
11
  import copy
12
- from typing import Optional, List, Tuple, Union
13
  from torch import Tensor
14
  from dataclasses import dataclass
15
- import numpy as np
16
 
17
  # ==============================================================================
18
- # ملاحظة: هذا الملف يحتوي على كل التعريفات اللازمة لتشغيل النموذج.
 
19
  # ==============================================================================
20
 
21
- # --- الكلاسات الوهمية للـ Backbones ---
22
- # في الاستخدام الحقيقي، يجب استبدالها بالشبكات الحقيقية من مكتبة مثل timm.
 
 
 
 
 
 
 
 
 
 
 
23
  class DummyResNet(nn.Module):
 
 
 
 
24
  def __init__(self, name="r26", **kwargs):
25
  super().__init__()
26
- out_channels = 512 if name == "r18" else 2048
 
 
 
 
 
 
 
27
  self.features = nn.Sequential(
28
  nn.Conv2d(kwargs.get('in_chans', 3), out_channels, kernel_size=7, stride=2, padding=3),
29
  nn.AdaptiveAvgPool2d((1, 1))
30
  )
31
  self.num_features = out_channels
 
32
  def forward(self, x):
33
  return [self.features(x)]
34
 
35
- def resnet18d(**kwargs): return DummyResNet(name="r18", **kwargs)
36
- def resnet26d(**kwargs): return DummyResNet(name="r26", **kwargs)
37
- def resnet50d(**kwargs): return DummyResNet(name="r50", **kwargs)
38
- def to_2tuple(x): return (x, x) if not isinstance(x, tuple) else x
39
 
40
- # --- جميع الكلاسات المساعدة ---
41
- # (HybridEmbed, PositionEmbeddingSine, TransformerEncoder, SpatialSoftmax, etc.)
42
- # تم نسخها بالكامل هنا.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  class HybridEmbed(nn.Module):
44
  def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
45
  super().__init__()
46
- self.img_size = to_2tuple(img_size)
 
 
 
 
 
 
 
47
  self.patch_size = to_2tuple(patch_size)
48
  self.backbone = backbone
 
49
  if feature_size is None:
50
  with torch.no_grad():
51
  training = backbone.training
52
- if training: backbone.eval()
53
- o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
54
- if isinstance(o, (list, tuple)): o = o[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  feature_dim = o.shape[1]
56
  backbone.train(training)
57
  else:
58
  feature_dim = self.backbone.num_features
 
59
  self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
 
 
60
  def forward(self, x):
61
  x = self.backbone(x)
62
- if isinstance(x, (list, tuple)): x = x[-1]
 
63
  x = self.proj(x)
64
  global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
65
  return x, global_x
66
 
67
- # ... (يتم لصق بقية الكلاسات المساعدة هنا: PositionEmbeddingSine, Transformer*...)
68
- # (للاختصار، لن أعرضها كلها مرة أخرى، ولكن يجب أن تكون كلها في هذا الملف)
69
  class PositionEmbeddingSine(nn.Module):
70
- def __init__( self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
 
 
71
  super().__init__()
72
  self.num_pos_feats = num_pos_feats
73
  self.temperature = temperature
74
  self.normalize = normalize
75
- if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed")
76
- if scale is None: scale = 2 * math.pi
 
 
77
  self.scale = scale
 
78
  def forward(self, tensor):
79
- x = tensor; bs, _, h, w = x.shape
 
80
  not_mask = torch.ones((bs, h, w), device=x.device)
81
  y_embed = not_mask.cumsum(1, dtype=torch.float32)
82
  x_embed = not_mask.cumsum(2, dtype=torch.float32)
@@ -84,14 +178,22 @@ class PositionEmbeddingSine(nn.Module):
84
  eps = 1e-6
85
  y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
86
  x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
 
87
  dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
88
  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
89
- pos_x = x_embed[:, :, :, None] / dim_t; pos_y = y_embed[:, :, :, None] / dim_t
90
- pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
91
- pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
 
 
 
 
 
 
92
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
93
  return pos
94
- # (لصق باقي الكلاسات المساعدة هنا)
 
95
  class TransformerEncoder(nn.Module):
96
  def __init__(self, encoder_layer, num_layers, norm=None):
97
  super().__init__()
@@ -322,26 +424,80 @@ def build_attn_mask(mask_type, device):
322
  mask[84:101, 84:101] = False; mask[101:151, :] = False; mask[:, 101:151] = False
323
  return mask
324
 
325
- # --- تعريف فئة الإعدادات (Config) ---
 
 
 
326
  class InterfuserConfig(PretrainedConfig):
327
  model_type = "interfuser"
328
- def __init__(self, img_size=224, embed_dim=256, enc_depth=6, dec_depth=6, num_heads=8, rgb_backbone_name="r26", lidar_backbone_name="r18", use_different_backbone=True, waypoints_pred_head="gru", **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  super().__init__(**kwargs)
330
  self.img_size = img_size
 
 
331
  self.embed_dim = embed_dim
332
  self.enc_depth = enc_depth
333
  self.dec_depth = dec_depth
334
- self.num_heads = num_heads
 
335
  self.rgb_backbone_name = rgb_backbone_name
336
  self.lidar_backbone_name = lidar_backbone_name
337
- self.use_different_backbone = use_different_backbone
 
 
 
 
 
 
 
 
 
 
338
  self.waypoints_pred_head = waypoints_pred_head
339
- # أضف أي إعدادات أخرى ضرورية هنا
340
- self.patch_size=8; self.in_chans=3; self.dim_feedforward=2048; self.normalize_before=False; self.dropout=0.1; self.end2end=False; self.direct_concat=False; self.separate_view_attention=False; self.separate_all_attention=False; self.freeze_num=-1; self.with_lidar=True; self.with_right_left_sensors=True; self.with_center_sensor=True; self.traffic_pred_head_type="det"; self.reverse_pos=True; self.use_view_embed=True; self.use_mmad_pretrain=None
 
 
 
 
 
 
341
 
342
- # --- تعريف فئة مخرجات النموذج (ModelOutput) ---
343
  @dataclass
344
  class InterfuserOutput(ModelOutput):
 
 
 
345
  waypoints: torch.FloatTensor = None
346
  traffic_predictions: Optional[torch.FloatTensor] = None
347
  is_junction: Optional[torch.FloatTensor] = None
@@ -349,13 +505,15 @@ class InterfuserOutput(ModelOutput):
349
  stop_sign: Optional[torch.FloatTensor] = None
350
  traffic_features: Optional[torch.FloatTensor] = None
351
 
352
- # --- تعريف النموذج الأصلي (Interfuser) ---
353
- # (يجب لصق كلاس Interfuser بالكامل هنا)
 
 
354
  class Interfuser(nn.Module):
355
  def __init__(self, config: InterfuserConfig):
356
  super().__init__()
357
  self.config = config
358
-
359
  # استخلاص المتغيرات من كائن الـ config
360
  embed_dim = config.embed_dim
361
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
@@ -365,7 +523,7 @@ class Interfuser(nn.Module):
365
  self.traffic_pred_head_type = config.traffic_pred_head_type
366
  self.waypoints_pred_head = config.waypoints_pred_head
367
  self.end2end = config.end2end
368
-
369
  # ... باقي متغيرات الـ init من الكود الأصلي
370
  self.direct_concat = config.direct_concat
371
  self.with_center_sensor = config.with_center_sensor
@@ -374,7 +532,7 @@ class Interfuser(nn.Module):
374
  self.use_view_embed = config.use_view_embed
375
  self.separate_view_attention = config.separate_view_attention
376
  self.separate_all_attention = config.separate_all_attention
377
-
378
  if self.direct_concat:
379
  in_chans = config.in_chans * 4
380
  self.with_center_sensor = False
@@ -392,11 +550,11 @@ class Interfuser(nn.Module):
392
  # تعريف الـ backbones (استخدام DummyResNet كمثال)
393
  # في الاستخدام الحقيقي، استبدل هذا بالتحميل الفعلي للشبكات
394
  backbone_map = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}
395
-
396
  # RGB Backbone
397
  rgb_backbone_class = backbone_map.get(config.rgb_backbone_name, resnet26d)
398
  self.rgb_backbone = rgb_backbone_class(pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4])
399
-
400
  # Lidar Backbone
401
  if config.use_different_backbone:
402
  lidar_backbone_class = backbone_map.get(config.lidar_backbone_name, resnet26d)
@@ -435,21 +593,21 @@ class Interfuser(nn.Module):
435
  elif self.waypoints_pred_head == "gru-command": self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
436
  elif self.waypoints_pred_head == "linear": self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=False)
437
  elif self.waypoints_pred_head == "linear-sum": self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
438
-
439
  self.junction_pred_head = nn.Linear(embed_dim, 2)
440
  self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
441
  self.stop_sign_head = nn.Linear(embed_dim, 2)
442
-
443
  self.traffic_pred_head = nn.Sequential(*[nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()])
444
  self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
445
-
446
  encoder_layer = TransformerEncoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
447
  self.encoder = TransformerEncoder(encoder_layer, config.enc_depth, None)
448
-
449
  decoder_layer = TransformerDecoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
450
  decoder_norm = nn.LayerNorm(embed_dim)
451
  self.decoder = TransformerDecoder(decoder_layer, config.dec_depth, decoder_norm, return_intermediate=False)
452
-
453
  self.reset_parameters()
454
 
455
  def reset_parameters(self):
@@ -505,7 +663,7 @@ class Interfuser(nn.Module):
505
  lidar_token_global = lidar_token_global + self.view_embed[:, :, 4, :] + self.global_embed[:, :, 4:5]
506
  lidar_token_global = lidar_token_global.permute(2, 0, 1)
507
  features.extend([lidar_token, lidar_token_global])
508
-
509
  return torch.cat(features, 0)
510
 
511
  def forward(self, x):
@@ -518,7 +676,7 @@ class Interfuser(nn.Module):
518
  right_image = F.interpolate(right_image, size=(img_size, img_size))
519
  front_center_image = F.interpolate(front_center_image, size=(img_size, img_size))
520
  front_image = torch.cat([front_image, left_image, right_image, front_center_image], dim=1)
521
-
522
  features = self.forward_features(front_image, left_image, right_image, front_center_image, lidar, measurements)
523
  bs = front_image.shape[0]
524
 
@@ -545,7 +703,7 @@ class Interfuser(nn.Module):
545
  if self.waypoints_pred_head == "heatmap": waypoints = self.waypoints_generator(waypoints_feature, measurements)
546
  elif self.waypoints_pred_head.startswith("gru"): waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements) if "command" in self.waypoints_pred_head else self.waypoints_generator(waypoints_feature, target_point)
547
  elif self.waypoints_pred_head.startswith("linear"): waypoints = self.waypoints_generator(waypoints_feature, measurements)
548
-
549
  is_junction = self.junction_pred_head(is_junction_feature)
550
  traffic_light_state = self.traffic_light_pred_head(is_junction_feature) # Original code uses same feature
551
  stop_sign = self.stop_sign_head(is_junction_feature) # Original code uses same feature
@@ -553,44 +711,134 @@ class Interfuser(nn.Module):
553
  velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
554
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
555
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
556
-
557
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
558
 
 
 
 
 
559
 
560
- # --- تعريف الغلاف الرئيسي (Wrapper) ---
561
- # هذا هو الكلاس الذي سيتم استدعاؤه بواسطة AutoModel
562
  class InterfuserForHuggingFace(PreTrainedModel):
563
  config_class = InterfuserConfig
 
564
  def __init__(self, config: InterfuserConfig):
565
  super().__init__(config)
566
- self.model = Interfuser(config) # سيتم بناء النموذج الأصلي هنا
 
567
  def _init_weights(self, module):
 
 
 
 
568
  if hasattr(module, 'reset_parameters'):
569
  module.reset_parameters()
570
- def forward(self, rgb: torch.FloatTensor, rgb_left: torch.FloatTensor, rgb_right: torch.FloatTensor, rgb_center: torch.FloatTensor, lidar: torch.FloatTensor, measurements: torch.FloatTensor, target_point: torch.FloatTensor, return_dict: Optional[bool] = None) -> Union[Tuple, InterfuserOutput]:
 
 
 
 
 
 
 
 
 
 
 
 
 
571
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
572
- inputs = {"rgb": rgb, "rgb_left": rgb_left, "rgb_right": rgb_right, "rgb_center": rgb_center, "lidar": lidar, "measurements": measurements, "target_point": target_point}
573
- outputs = self.model(inputs)
574
- if self.config.end2end:
575
- if not return_dict: return (outputs,)
576
- return InterfuserOutput(waypoints=outputs)
577
- traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature = outputs
578
- if not return_dict: return outputs
579
- return InterfuserOutput(waypoints=waypoints, traffic_predictions=traffic, is_junction=is_junction, traffic_light_state=traffic_light_state, stop_sign=stop_sign, traffic_features=traffic_feature)
580
- # ==============================================================================
581
- # --- التسجيل الديناميكي للنموذج في مكتبة Transformers ---
582
- # هذا هو الجزء الحاسم الذي يحل خطأ KeyError
583
- # ==============================================================================
584
- from transformers.models.auto.configuration_auto import AutoConfig
585
- from transformers.models.auto.modeling_auto import AutoModel
586
 
587
- print("Registering Interfuser model with AutoModel...")
 
 
 
 
 
 
 
 
588
 
589
- # 1. تسجيل فئة الإعدادات
590
- AutoConfig.register("interfuser", InterfuserConfig)
591
 
592
- # 2. تسجيل فئة النموذج
593
- # هذا يربط model_type="interfuser" مع الكلاس InterfuserForHuggingFace
594
- AutoModel.register(InterfuserConfig, InterfuserForHuggingFace)
 
595
 
596
- print("Registration complete.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from torch import nn
3
  import torch.nn.functional as F
 
7
  import math
8
  from collections import OrderedDict
9
  import copy
10
+ from typing import Optional, List, Tuple
11
  from torch import Tensor
12
  from dataclasses import dataclass
13
+ import numpy as np # مطلوب لـ SpatialSoftmax
14
 
15
  # ==============================================================================
16
+ # ملاحظة: تم نسخ جميع الكلاسات المساعدة من الكود الأصلي هنا
17
+ # لضمان أن يكون الكود قابلاً للتشغيل بشكل مستقل.
18
  # ==============================================================================
19
 
20
+ # من الأفضل استيرادها من المصدر الأصلي إذا كان ذلك متاحًا
21
+ # لضمان قابلية النقل الكاملة، نعرّفها هنا.
22
+ # from InterFuser.interfuser.timm.models.layers import to_2tuple
23
+ # from InterFuser.interfuser.timm.models.resnet import resnet50d, resnet26d, resnet18d
24
+ # نظرًا لأن هذه الوحدات غير متوفرة مباشرة، سنستخدم كلاسات وهمية (placeholders)
25
+ # للسماح بتشغيل الكود. في الاستخدام الحقيقي، يجب استيرادها بشكل صحيح.
26
+
27
+ def to_2tuple(x):
28
+ if isinstance(x, tuple):
29
+ return x
30
+ return (x, x)
31
+
32
+ # DummyResNet المحسّن
33
class DummyResNet(nn.Module):
    """Placeholder backbone standing in for the real timm ResNets.

    Mimics the minimal interface the rest of this file relies on: a
    ``num_features`` attribute and a ``forward`` that returns a one-element
    list of feature maps (matching timm's ``features_only=True`` convention).
    Channel counts follow the naming used here: "r18" exposes 512 features,
    every other variant ("r26", "r50", ...) exposes 2048.
    """

    def __init__(self, name="r26", **kwargs):
        # NOTE: extra kwargs (pretrained, features_only, out_indices, ...) are
        # accepted for call-site compatibility with timm factories but ignored.
        super().__init__()
        # Pick the output channel count from the variant name.
        if name == "r18":
            out_channels = 512
        else:  # r26, r50, etc.
            out_channels = 2048

        # Fix: removed the unconditional print() here — a constructor in
        # library code should not write to stdout on every instantiation.
        self.features = nn.Sequential(
            nn.Conv2d(kwargs.get('in_chans', 3), out_channels, kernel_size=7, stride=2, padding=3),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.num_features = out_channels

    def forward(self, x):
        # List wrapper mirrors timm's features_only output format.
        return [self.features(x)]
 
57
+ # قم بتحديث كيفية تعريف الشبكات لاستخدام الكلاس الجديد
58
# Factory helpers mirroring the timm constructor names used elsewhere in this
# file; each builds a DummyResNet tagged with the matching variant name.
def resnet18d(**kwargs):
    """Dummy stand-in for timm's ``resnet18d`` constructor."""
    return DummyResNet(name="r18", **kwargs)


def resnet26d(**kwargs):
    """Dummy stand-in for timm's ``resnet26d`` constructor."""
    return DummyResNet(name="r26", **kwargs)


def resnet50d(**kwargs):
    """Dummy stand-in for timm's ``resnet50d`` constructor."""
    return DummyResNet(name="r50", **kwargs)
66
+ # ==============================================================================
67
+ # القسم 1: جميع الكلاسات المساعدة من الكود الأصلي
68
+ # ==============================================================================
69
+
70
+ # class HybridEmbed(nn.Module):
71
+ # def __init__(
72
+ # self,
73
+ # backbone,
74
+ # img_size=224,
75
+ # patch_size=1,
76
+ # feature_size=None,
77
+ # in_chans=3,
78
+ # embed_dim=768,
79
+ # ):
80
+ # super().__init__()
81
+ # assert isinstance(backbone, nn.Module)
82
+ # img_size = to_2tuple(img_size)
83
+ # patch_size = to_2tuple(patch_size)
84
+ # self.img_size = img_size
85
+ # self.patch_size = patch_size
86
+ # self.backbone = backbone
87
+ # if feature_size is None:
88
+ # with torch.no_grad():
89
+ # training = backbone.training
90
+ # if training:
91
+ # backbone.eval()
92
+ # o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
93
+ # if isinstance(o, (list, tuple)):
94
+ # o = o[-1]
95
+ # feature_size = o.shape[-2:]
96
+ # feature_dim = o.shape[1]
97
+ # backbone.train(training)
98
+ # else:
99
+ # feature_size = to_2tuple(feature_size)
100
+ # if hasattr(self.backbone, "feature_info"):
101
+ # feature_dim = self.backbone.feature_info.channels()[-1]
102
+ # else:
103
+ # feature_dim = self.backbone.num_features
104
+
105
+ # self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
106
+ # هذا هو الكود الجديد الذي يجب أن تستخدمه
107
class HybridEmbed(nn.Module):
    """Project CNN backbone feature maps into a transformer embedding space.

    The backbone's final feature map is mapped to ``embed_dim`` channels via a
    1x1 convolution. ``forward`` returns both the projected per-location map
    and a spatially averaged global token of shape (B, embed_dim, 1).
    """

    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()

        # Normalize img_size to a (height, width) pair for safe indexing below.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_size = img_size
        self.patch_size = to_2tuple(patch_size)
        self.backbone = backbone

        if feature_size is not None:
            # Caller supplied the feature geometry; trust the backbone's
            # advertised channel count.
            feature_dim = self.backbone.num_features
        else:
            # Probe the backbone with a dummy batch to discover how many
            # channels its final feature map carries.
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                try:
                    o = self.backbone(torch.zeros(1, in_chans, self.img_size[0], self.img_size[1]))
                except Exception as e:
                    # Fall back to a standard resolution if the configured size fails.
                    print(f"Warning: Failed to infer feature size with img_size {self.img_size}. Retrying with 224x224. Error: {e}")
                    o = self.backbone(torch.zeros(1, in_chans, 224, 224))

                # timm-style backbones may return a feature pyramid; keep the
                # last (deepest) stage.
                if isinstance(o, (list, tuple)):
                    o = o[-1]
                feature_dim = o.shape[1]
                backbone.train(training)

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            feats = feats[-1]
        feats = self.proj(feats)
        # Global token: spatial mean, reshaped to (B, C, 1) for later concatenation.
        global_feats = torch.mean(feats, [2, 3], keepdim=False)[:, :, None]
        return feats, global_feats
155
 
156
+
 
157
  class PositionEmbeddingSine(nn.Module):
158
+ def __init__(
159
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
160
+ ):
161
  super().__init__()
162
  self.num_pos_feats = num_pos_feats
163
  self.temperature = temperature
164
  self.normalize = normalize
165
+ if scale is not None and normalize is False:
166
+ raise ValueError("normalize should be True if scale is passed")
167
+ if scale is None:
168
+ scale = 2 * math.pi
169
  self.scale = scale
170
+
171
  def forward(self, tensor):
172
+ x = tensor
173
+ bs, _, h, w = x.shape
174
  not_mask = torch.ones((bs, h, w), device=x.device)
175
  y_embed = not_mask.cumsum(1, dtype=torch.float32)
176
  x_embed = not_mask.cumsum(2, dtype=torch.float32)
 
178
  eps = 1e-6
179
  y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
180
  x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
181
+
182
  dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
183
  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
184
+
185
+ pos_x = x_embed[:, :, :, None] / dim_t
186
+ pos_y = y_embed[:, :, :, None] / dim_t
187
+ pos_x = torch.stack(
188
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
189
+ ).flatten(3)
190
+ pos_y = torch.stack(
191
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
192
+ ).flatten(3)
193
  pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
194
  return pos
195
+
196
+
197
  class TransformerEncoder(nn.Module):
198
  def __init__(self, encoder_layer, num_layers, norm=None):
199
  super().__init__()
 
424
  mask[84:101, 84:101] = False; mask[101:151, :] = False; mask[:, 101:151] = False
425
  return mask
426
 
427
+ # ==============================================================================
428
+ # القسم 2: تعريف فئة الإعدادات (Config)
429
+ # ==============================================================================
430
+
431
class InterfuserConfig(PretrainedConfig):
    """Configuration class for the Interfuser model.

    Holds every architectural hyper-parameter needed to build the network;
    each constructor argument is stored as an attribute so that
    ``PretrainedConfig`` serialization (save/load via JSON) picks it up.
    """

    model_type = "interfuser"

    def __init__(
        self,
        img_size=224,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        dropout=0.1,
        end2end=False,
        direct_concat=False,  # default kept off to avoid the 4-view concat input path
        separate_view_attention=False,
        separate_all_attention=False,
        freeze_num=-1,
        with_lidar=True,
        with_right_left_sensors=True,
        with_center_sensor=True,
        traffic_pred_head_type="det",
        waypoints_pred_head="linear-sum",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Mirror every argument onto the instance, one attribute per knob.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.dim_feedforward = dim_feedforward
        self.normalize_before = normalize_before
        self.rgb_backbone_name = rgb_backbone_name
        self.lidar_backbone_name = lidar_backbone_name
        self.num_heads = num_heads
        self.dropout = dropout
        self.end2end = end2end
        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.freeze_num = freeze_num
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor
        self.traffic_pred_head_type = traffic_pred_head_type
        self.waypoints_pred_head = waypoints_pred_head
        self.reverse_pos = reverse_pos
        self.use_different_backbone = use_different_backbone
        self.use_view_embed = use_view_embed
        self.use_mmad_pretrain = use_mmad_pretrain
491
+
492
+ # ==============================================================================
493
+ # القسم 3: تعريف فئة مخرجات النموذج (ModelOutput)
494
+ # ==============================================================================
495
 
 
496
  @dataclass
497
  class InterfuserOutput(ModelOutput):
498
+ """
499
+ كلاس لتخزين مخرجات نموذج Interfuser بطريقة منظمة.
500
+ """
501
  waypoints: torch.FloatTensor = None
502
  traffic_predictions: Optional[torch.FloatTensor] = None
503
  is_junction: Optional[torch.FloatTensor] = None
 
505
  stop_sign: Optional[torch.FloatTensor] = None
506
  traffic_features: Optional[torch.FloatTensor] = None
507
 
508
+ # ==============================================================================
509
+ # القسم 4: النموذج الأصلي (تم تعديل __init__ ليقبل config)
510
+ # ==============================================================================
511
+
512
  class Interfuser(nn.Module):
513
  def __init__(self, config: InterfuserConfig):
514
  super().__init__()
515
  self.config = config
516
+
517
  # استخلاص المتغيرات من كائن الـ config
518
  embed_dim = config.embed_dim
519
  norm_layer = partial(nn.LayerNorm, eps=1e-6)
 
523
  self.traffic_pred_head_type = config.traffic_pred_head_type
524
  self.waypoints_pred_head = config.waypoints_pred_head
525
  self.end2end = config.end2end
526
+
527
  # ... باقي متغيرات الـ init من الكود الأصلي
528
  self.direct_concat = config.direct_concat
529
  self.with_center_sensor = config.with_center_sensor
 
532
  self.use_view_embed = config.use_view_embed
533
  self.separate_view_attention = config.separate_view_attention
534
  self.separate_all_attention = config.separate_all_attention
535
+
536
  if self.direct_concat:
537
  in_chans = config.in_chans * 4
538
  self.with_center_sensor = False
 
550
  # تعريف الـ backbones (استخدام DummyResNet كمثال)
551
  # في الاستخدام الحقيقي، استبدل هذا بالتحميل الفعلي للشبكات
552
  backbone_map = {"r50": resnet50d, "r26": resnet26d, "r18": resnet18d}
553
+
554
  # RGB Backbone
555
  rgb_backbone_class = backbone_map.get(config.rgb_backbone_name, resnet26d)
556
  self.rgb_backbone = rgb_backbone_class(pretrained=True, in_chans=in_chans, features_only=True, out_indices=[4])
557
+
558
  # Lidar Backbone
559
  if config.use_different_backbone:
560
  lidar_backbone_class = backbone_map.get(config.lidar_backbone_name, resnet26d)
 
593
  elif self.waypoints_pred_head == "gru-command": self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
594
  elif self.waypoints_pred_head == "linear": self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=False)
595
  elif self.waypoints_pred_head == "linear-sum": self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
596
+
597
  self.junction_pred_head = nn.Linear(embed_dim, 2)
598
  self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
599
  self.stop_sign_head = nn.Linear(embed_dim, 2)
600
+
601
  self.traffic_pred_head = nn.Sequential(*[nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()])
602
  self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
603
+
604
  encoder_layer = TransformerEncoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
605
  self.encoder = TransformerEncoder(encoder_layer, config.enc_depth, None)
606
+
607
  decoder_layer = TransformerDecoderLayer(embed_dim, config.num_heads, config.dim_feedforward, config.dropout, act_layer, config.normalize_before)
608
  decoder_norm = nn.LayerNorm(embed_dim)
609
  self.decoder = TransformerDecoder(decoder_layer, config.dec_depth, decoder_norm, return_intermediate=False)
610
+
611
  self.reset_parameters()
612
 
613
  def reset_parameters(self):
 
663
  lidar_token_global = lidar_token_global + self.view_embed[:, :, 4, :] + self.global_embed[:, :, 4:5]
664
  lidar_token_global = lidar_token_global.permute(2, 0, 1)
665
  features.extend([lidar_token, lidar_token_global])
666
+
667
  return torch.cat(features, 0)
668
 
669
  def forward(self, x):
 
676
  right_image = F.interpolate(right_image, size=(img_size, img_size))
677
  front_center_image = F.interpolate(front_center_image, size=(img_size, img_size))
678
  front_image = torch.cat([front_image, left_image, right_image, front_center_image], dim=1)
679
+
680
  features = self.forward_features(front_image, left_image, right_image, front_center_image, lidar, measurements)
681
  bs = front_image.shape[0]
682
 
 
703
  if self.waypoints_pred_head == "heatmap": waypoints = self.waypoints_generator(waypoints_feature, measurements)
704
  elif self.waypoints_pred_head.startswith("gru"): waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements) if "command" in self.waypoints_pred_head else self.waypoints_generator(waypoints_feature, target_point)
705
  elif self.waypoints_pred_head.startswith("linear"): waypoints = self.waypoints_generator(waypoints_feature, measurements)
706
+
707
  is_junction = self.junction_pred_head(is_junction_feature)
708
  traffic_light_state = self.traffic_light_pred_head(is_junction_feature) # Original code uses same feature
709
  stop_sign = self.stop_sign_head(is_junction_feature) # Original code uses same feature
 
711
  velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
712
  traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
713
  traffic = self.traffic_pred_head(traffic_feature_with_vel)
714
+
715
  return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
716
 
717
+ # ==============================================================================
718
+ # القسم 5: الغلاف (Wrapper) المتوافق مع Hugging Face
719
+ # ==============================================================================
720
+ from typing import Optional, Tuple, Union
721
 
 
 
722
class InterfuserForHuggingFace(PreTrainedModel):
    """Hugging Face wrapper around the original Interfuser network.

    Packs the individual sensor tensors into the single dict the wrapped
    model expects, then re-exposes its outputs either as a plain tuple or as
    an :class:`InterfuserOutput`, following the standard ``return_dict``
    convention of ``PreTrainedModel``.
    """

    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        super().__init__(config)
        # Build the underlying (non-HF) Interfuser network from the config.
        self.model = Interfuser(config)

    def _init_weights(self, module):
        """Required by PreTrainedModel; defer to each submodule's own reset logic."""
        if hasattr(module, 'reset_parameters'):
            module.reset_parameters()

    def forward(
        self,
        rgb: torch.FloatTensor,
        rgb_left: torch.FloatTensor,
        rgb_right: torch.FloatTensor,
        rgb_center: torch.FloatTensor,
        lidar: torch.FloatTensor,
        measurements: torch.FloatTensor,
        target_point: torch.FloatTensor,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, InterfuserOutput]:
        """Run the wrapped model on one batch of multi-sensor inputs."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The wrapped model consumes a single dict keyed by sensor name.
        outputs = self.model({
            "rgb": rgb,
            "rgb_left": rgb_left,
            "rgb_right": rgb_right,
            "rgb_center": rgb_center,
            "lidar": lidar,
            "measurements": measurements,
            "target_point": target_point,
        })

        if self.config.end2end:
            # End-to-end mode: the model returns only the waypoints tensor.
            if not return_dict:
                return (outputs,)
            return InterfuserOutput(waypoints=outputs)

        # Full mode: the model returns a fixed 6-tuple.
        (
            traffic,
            waypoints,
            is_junction,
            traffic_light_state,
            stop_sign,
            traffic_feature,
        ) = outputs

        if not return_dict:
            # Preserve the raw tuple ordering of the wrapped model.
            return outputs

        return InterfuserOutput(
            waypoints=waypoints,
            traffic_predictions=traffic,
            is_junction=is_junction,
            traffic_light_state=traffic_light_state,
            stop_sign=stop_sign,
            traffic_features=traffic_feature,
        )
792
+ # --- نهاية الكود المصحح ---
793
+ # # ==============================================================================
794
+ # # القسم 6: مثال على كيفية الاستخدام
795
+ # # ==============================================================================
796
+
797
+ # if __name__ == '__main__':
798
+ # # 1. إنشاء كائن الإعدادات
799
+ # config = InterfuserConfig(
800
+ # img_size=224,
801
+ # embed_dim=256, # تصغير البعد لسهولة التجربة
802
+ # enc_depth=2, # تصغير العمق
803
+ # dec_depth=2, # تصغير العمق
804
+ # num_heads=4, # تصغير عدد الرؤوس
805
+ # end2end=False, # اختبار الوضع الكامل
806
+ # waypoints_pred_head="linear-sum"
807
+ # )
808
+
809
+ # # 2. إنشاء النموذج من الإعدادات
810
+ # model = InterfuserForHuggingFace(config)
811
+ # model.eval()
812
+
813
+ # # 3. إنشاء بيانات وهمية (dummy data) للمدخلات
814
+ # batch_size = 2
815
+ # img_size = config.img_size
816
+
817
+ # dummy_rgb = torch.randn(batch_size, 3, img_size, img_size)
818
+ # dummy_lidar = torch.randn(batch_size, 3, img_size, img_size)
819
+ # # [command, is_junction, traffic_light_state, stop_sign, ...]
820
+ # dummy_measurements = torch.randn(batch_size, 7)
821
+ # dummy_target_point = torch.randn(batch_size, 2)
822
+
823
+ # # 4. تمرير البيانات للنموذج
824
+ # with torch.no_grad():
825
+ # outputs = model(
826
+ # rgb=dummy_rgb,
827
+ # rgb_left=dummy_rgb,
828
+ # rgb_right=dummy_rgb,
829
+ # rgb_center=dummy_rgb,
830
+ # lidar=dummy_lidar,
831
+ # measurements=dummy_measurements,
832
+ # target_point=dummy_target_point,
833
+ # return_dict=True # طلب المخرجات ككائن منظم
834
+ # )
835
+
836
+ # # 5. الوصول إلى المخرجات
837
+ # print("شكل مخرجات الـ Waypoints:", outputs.waypoints.shape)
838
+ # print("شكل مخرجات توقعات إشارات المرور:", outputs.traffic_predictions.shape)
839
+ # print("شكل مخرجات التقاطعات:", outputs.is_junction.shape)
840
+
841
+ # # يمكنك الآن حفظ النموذج وتحميله بسهولة
842
+ # # model.save_pretrained("./my_interfuser_model")
843
+ # # loaded_model = InterfuserForHuggingFace.from_pretrained("./my_interfuser_model")
844
+ # # print("\nتم تحميل النموذج بنجاح!")