diff --git "a/model_definition.py" "b/model_definition.py" --- "a/model_definition.py" +++ "b/model_definition.py" @@ -1,1338 +1,1338 @@ -# model_definition.py -# ============================================================================ -# الاستيرادات الأساسية -# ============================================================================ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torch.optim import AdamW -from torch.optim.lr_scheduler import OneCycleLR -from torch.utils.data import Dataset, DataLoader -from torchvision import transforms -from functools import partial -from typing import Optional, List -from torch import Tensor -import os -import json -import numpy as np -import cv2 -from PIL import Image -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision import transforms -from functools import partial -from collections import deque, OrderedDict -import math -from torch.nn import MultiheadAttention -from torch.nn import TransformerEncoder, TransformerEncoderLayer -from torch.nn import TransformerDecoder, TransformerDecoderLayer -from timm.models.resnet import resnet50d, resnet26d, resnet18d -try: - from timm.layers import trunc_normal_ -except ImportError: - from timm.models.layers import trunc_normal_ - -# مكتبات إضافية -import os -import json -import logging -import math -import copy -from pathlib import Path -from collections import OrderedDict - -# مكتبات معالجة البيانات -import numpy as np -import cv2 - -# مكتبات اختيارية (يمكن تعطيلها إذا لم تكن متوفرة) -try: - import wandb - WANDB_AVAILABLE = True -except ImportError: - WANDB_AVAILABLE = False - -try: - from tqdm import tqdm -except ImportError: - # إذا لم تكن tqdm متوفرة، استخدم دالة بديلة - def tqdm(iterable, *args, **kwargs): - return iterable - -# ============================================================================ -# دوال مساعدة -# ============================================================================ -def to_2tuple(x): - """تحويل قيمة إلى tuple من عنصرين""" - if isinstance(x, (list, tuple)): - return tuple(x) - return (x, x) -# ============================================================================ -# ============================================================================ - -class HybridEmbed(nn.Module): - def __init__( - self, - backbone, - img_size=224, - patch_size=1, - feature_size=None, - in_chans=3, - embed_dim=768, - ): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = to_2tuple(feature_size) - if hasattr(self.backbone, "feature_info"): - feature_dim = self.backbone.feature_info.channels()[-1] - else: - feature_dim = self.backbone.num_features - - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1) - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x) - global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None] - return x, global_x - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__( - self, num_pos_feats=64, temperature=10000, normalize=False, scale=None - ): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, tensor): - x = tensor - bs, _, h, w = x.shape - not_mask = torch.ones((bs, h, w), device=x.device) - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - - -class TransformerEncoder(nn.Module): - def __init__(self, encoder_layer, num_layers, norm=None): - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - src, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - output = src - - for layer in self.layers: - output = layer( - output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - pos=pos, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class SpatialSoftmax(nn.Module): - def __init__(self, height, width, channel, temperature=None, data_format="NCHW"): - super().__init__() - - self.data_format = data_format - self.height = height - self.width = width - self.channel = channel - - if temperature: - self.temperature = Parameter(torch.ones(1) * temperature) - else: - self.temperature = 1.0 - - pos_x, pos_y = np.meshgrid( - np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width) - ) - pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float() - pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float() - self.register_buffer("pos_x", pos_x) - self.register_buffer("pos_y", pos_y) - - def forward(self, feature): - # Output: - # (N, C*2) x_0 y_0 ... - - if self.data_format == "NHWC": - feature = ( - feature.transpose(1, 3) - .tranpose(2, 3) - .view(-1, self.height * self.width) - ) - else: - feature = feature.view(-1, self.height * self.width) - - weight = F.softmax(feature / self.temperature, dim=-1) - expected_x = torch.sum( - torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True - ) - expected_y = torch.sum( - torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True - ) - expected_xy = torch.cat([expected_x, expected_y], 1) - feature_keypoints = expected_xy.view(-1, self.channel, 2) - feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12 - feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12 - return feature_keypoints - - -class MultiPath_Generator(nn.Module): - def __init__(self, in_channel, embed_dim, out_channel): - super().__init__() - self.spatial_softmax = SpatialSoftmax(100, 100, out_channel) - self.tconv0 = nn.Sequential( - nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(True), - ) - self.tconv1 = nn.Sequential( - nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(True), - ) - self.tconv2 = nn.Sequential( - nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False), - nn.BatchNorm2d(192), - nn.ReLU(True), - ) - self.tconv3 = nn.Sequential( - nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False), - nn.BatchNorm2d(64), - nn.ReLU(True), - ) - self.tconv4_list = torch.nn.ModuleList( - [ - nn.Sequential( - nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False), - nn.Tanh(), - ) - for _ in range(6) - ] - ) - - self.upsample = nn.Upsample(size=(50, 50), mode="bilinear") - - def forward(self, x, measurements): - mask = measurements[:, :6] - mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100) - velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1) - velocity = velocity.repeat(1, 32, 2, 2) - - n, d, c = x.shape - x = x.transpose(1, 2) - x = x.view(n, -1, 2, 2) - x = torch.cat([x, velocity], dim=1) - x = self.tconv0(x) - x = self.tconv1(x) - x = self.tconv2(x) - x = self.tconv3(x) - x = self.upsample(x) - xs = [] - for i in range(6): - xt = self.tconv4_list[i](x) - xs.append(xt) - xs = torch.stack(xs, dim=1) - x = torch.sum(xs * mask, dim=1) - x = self.spatial_softmax(x) - return x - - -class LinearWaypointsPredictor(nn.Module): - def __init__(self, input_dim, cumsum=True): - super().__init__() - self.cumsum = cumsum - self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim)) - self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)]) - self.head_relu = nn.ReLU(inplace=True) - self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) - - def forward(self, x, measurements): - # input shape: n 10 embed_dim - bs, n, dim = x.shape - x = x + self.rank_embed - x = x.reshape(-1, dim) - - mask = measurements[:, :6] - mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2) - - rs = [] - for i in range(6): - res = self.head_fc1_list[i](x) - res = self.head_relu(res) - res = self.head_fc2_list[i](res) - rs.append(res) - rs = torch.stack(rs, 1) - x = torch.sum(rs * mask, dim=1) - - x = x.view(bs, n, 2) - if self.cumsum: - x = torch.cumsum(x, 1) - return x - - -class GRUWaypointsPredictor(nn.Module): - def __init__(self, input_dim, waypoints=10): - super().__init__() - # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) - self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) - self.encoder = nn.Linear(2, 64) - self.decoder = nn.Linear(64, 2) - self.waypoints = waypoints - - def forward(self, x, target_point): - bs = x.shape[0] - z = self.encoder(target_point).unsqueeze(0) - output, _ = self.gru(x, z) - output = output.reshape(bs * self.waypoints, -1) - output = self.decoder(output).reshape(bs, self.waypoints, 2) - output = torch.cumsum(output, 1) - return output - -class GRUWaypointsPredictorWithCommand(nn.Module): - def __init__(self, input_dim, waypoints=10): - super().__init__() - # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) - self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)]) - self.encoder = nn.Linear(2, 64) - self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) - self.waypoints = waypoints - - def forward(self, x, target_point, measurements): - bs, n, dim = x.shape - mask = measurements[:, :6, None, None] - mask = mask.repeat(1, 1, self.waypoints, 2) - - z = self.encoder(target_point).unsqueeze(0) - outputs = [] - for i in range(6): - output, _ = self.grus[i](x, z) - output = output.reshape(bs * self.waypoints, -1) - output = self.decoders[i](output).reshape(bs, self.waypoints, 2) - output = torch.cumsum(output, 1) - outputs.append(output) - outputs = torch.stack(outputs, 1) - output = torch.sum(outputs * mask, dim=1) - return output - - -class TransformerDecoder(nn.Module): - def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - self.return_intermediate = return_intermediate - - def forward( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - output = tgt - - intermediate = [] - - for layer in self.layers: - output = layer( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - pos=pos, - query_pos=query_pos, - ) - if self.return_intermediate: - intermediate.append(self.norm(output)) - - if self.norm is not None: - output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) - - if self.return_intermediate: - return torch.stack(intermediate) - - return output.unsqueeze(0) - - -class TransformerEncoderLayer(nn.Module): - def __init__( - self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation=nn.ReLU(), - normalize_before=False, - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = activation() - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - q = k = self.with_pos_embed(src, pos) - src2 = self.self_attn( - q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask - )[0] - src = src + self.dropout1(src2) - src = self.norm1(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - src = self.norm2(src) - return src - - def forward_pre( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - src2 = self.norm1(src) - q = k = self.with_pos_embed(src2, pos) - src2 = self.self_attn( - q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask - )[0] - src = src + self.dropout1(src2) - src2 = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) - src = src + self.dropout2(src2) - return src - - def forward( - self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - ): - if self.normalize_before: - return self.forward_pre(src, src_mask, src_key_padding_mask, pos) - return self.forward_post(src, src_mask, src_key_padding_mask, pos) - - -class TransformerDecoderLayer(nn.Module): - def __init__( - self, - d_model, - nhead, - dim_feedforward=2048, - dropout=0.1, - activation=nn.ReLU(), - normalize_before=False, - ): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = activation() - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - q = k = self.with_pos_embed(tgt, query_pos) - tgt2 = self.self_attn( - q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask - )[0] - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) - tgt2 = self.multihead_attn( - query=self.with_pos_embed(tgt, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = tgt + self.dropout3(tgt2) - tgt = self.norm3(tgt) - return tgt - - def forward_pre( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn( - q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask - )[0] - tgt = tgt + self.dropout1(tgt2) - tgt2 = self.norm2(tgt) - tgt2 = self.multihead_attn( - query=self.with_pos_embed(tgt2, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = tgt + self.dropout2(tgt2) - tgt2 = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout3(tgt2) - return tgt - - def forward( - self, - tgt, - memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None, - ): - if self.normalize_before: - return self.forward_pre( - tgt, - memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask, - pos, - query_pos, - ) - return self.forward_post( - tgt, - memory, - tgt_mask, - memory_mask, - tgt_key_padding_mask, - memory_key_padding_mask, - pos, - query_pos, - ) - - -def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(f"activation should be relu/gelu, not {activation}.") - - -def build_attn_mask(mask_type): - mask = torch.ones((151, 151), dtype=torch.bool).cuda() - if mask_type == "seperate_all": - mask[:50, :50] = False - mask[50:67, 50:67] = False - mask[67:84, 67:84] = False - mask[84:101, 84:101] = False - mask[101:151, 101:151] = False - elif mask_type == "seperate_view": - mask[:50, :50] = False - mask[50:67, 50:67] = False - mask[67:84, 67:84] = False - mask[84:101, 84:101] = False - mask[101:151, :] = False - mask[:, 101:151] = False - return mask -# class InterfuserModel(nn.Module): - -class InterfuserModel(nn.Module): - def __init__( - self, - img_size=224, - multi_view_img_size=112, - patch_size=8, - in_chans=3, - embed_dim=768, - enc_depth=6, - dec_depth=6, - dim_feedforward=2048, - normalize_before=False, - rgb_backbone_name="r50", - lidar_backbone_name="r50", - num_heads=8, - norm_layer=None, - dropout=0.1, - end2end=False, - direct_concat=False, - separate_view_attention=False, - separate_all_attention=False, - act_layer=None, - weight_init="", - freeze_num=-1, - with_lidar=False, - with_right_left_sensors=False, - with_center_sensor=False, - traffic_pred_head_type="det", - waypoints_pred_head="heatmap", - reverse_pos=True, - use_different_backbone=False, - use_view_embed=False, - use_mmad_pretrain=None, - ): - super().__init__() - self.traffic_pred_head_type = traffic_pred_head_type - self.num_features = ( - self.embed_dim - ) = embed_dim # num_features for consistency with other models - norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) - act_layer = act_layer or nn.GELU - - self.reverse_pos = reverse_pos - self.waypoints_pred_head = waypoints_pred_head - self.with_lidar = with_lidar - self.with_right_left_sensors = with_right_left_sensors - self.with_center_sensor = with_center_sensor - - self.direct_concat = direct_concat - self.separate_view_attention = separate_view_attention - self.separate_all_attention = separate_all_attention - self.end2end = end2end - self.use_view_embed = use_view_embed - - if self.direct_concat: - in_chans = in_chans * 4 - self.with_center_sensor = False - self.with_right_left_sensors = False - - if self.separate_view_attention: - self.attn_mask = build_attn_mask("seperate_view") - elif self.separate_all_attention: - self.attn_mask = build_attn_mask("seperate_all") - else: - self.attn_mask = None - - if use_different_backbone: - if rgb_backbone_name == "r50": - self.rgb_backbone = resnet50d( - pretrained=True, - in_chans=in_chans, - features_only=True, - out_indices=[4], - ) - elif rgb_backbone_name == "r26": - self.rgb_backbone = resnet26d( - pretrained=True, - in_chans=in_chans, - features_only=True, - out_indices=[4], - ) - elif rgb_backbone_name == "r18": - self.rgb_backbone = resnet18d( - pretrained=True, - in_chans=in_chans, - features_only=True, - out_indices=[4], - ) - if lidar_backbone_name == "r50": - self.lidar_backbone = resnet50d( - pretrained=False, - in_chans=in_chans, - features_only=True, - out_indices=[4], - ) - elif lidar_backbone_name == "r26": - self.lidar_backbone = resnet26d( - pretrained=False, - in_chans=in_chans, - features_only=True, - out_indices=[4], - ) - elif lidar_backbone_name == "r18": - self.lidar_backbone = resnet18d( - pretrained=False, in_chans=3, features_only=True, out_indices=[4] - ) - rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) - lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone) - - if use_mmad_pretrain: - params = torch.load(use_mmad_pretrain)["state_dict"] - updated_params = OrderedDict() - for key in params: - if "backbone" in key: - updated_params[key.replace("backbone.", "")] = params[key] - self.rgb_backbone.load_state_dict(updated_params) - - self.rgb_patch_embed = rgb_embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - ) - self.lidar_patch_embed = lidar_embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=3, - embed_dim=embed_dim, - ) - else: - if rgb_backbone_name == "r50": - self.rgb_backbone = resnet50d( - pretrained=True, in_chans=3, features_only=True, out_indices=[4] - ) - elif rgb_backbone_name == "r101": - self.rgb_backbone = resnet101d( - pretrained=True, in_chans=3, features_only=True, out_indices=[4] - ) - elif rgb_backbone_name == "r26": - self.rgb_backbone = resnet26d( - pretrained=True, in_chans=3, features_only=True, out_indices=[4] - ) - elif rgb_backbone_name == "r18": - self.rgb_backbone = resnet18d( - pretrained=True, in_chans=3, features_only=True, out_indices=[4] - ) - embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) - - self.rgb_patch_embed = embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - ) - self.lidar_patch_embed = embed_layer( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - ) - - self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) - self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1)) - - if self.end2end: - self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4)) - self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim)) - elif self.waypoints_pred_head == "heatmap": - self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) - self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim)) - else: - self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11)) - self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim)) - - if self.end2end: - self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4) - elif self.waypoints_pred_head == "heatmap": - self.waypoints_generator = MultiPath_Generator( - embed_dim + 32, embed_dim, 10 - ) - elif self.waypoints_pred_head == "gru": - self.waypoints_generator = GRUWaypointsPredictor(embed_dim) - elif self.waypoints_pred_head == "gru-command": - self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim) - elif self.waypoints_pred_head == "linear": - self.waypoints_generator = LinearWaypointsPredictor(embed_dim) - elif self.waypoints_pred_head == "linear-sum": - self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True) - - self.junction_pred_head = nn.Linear(embed_dim, 2) - self.traffic_light_pred_head = nn.Linear(embed_dim, 2) - self.stop_sign_head = nn.Linear(embed_dim, 2) - - if self.traffic_pred_head_type == "det": - self.traffic_pred_head = nn.Sequential( - *[ - nn.Linear(embed_dim + 32, 64), - nn.ReLU(), - nn.Linear(64, 7), - # nn.Sigmoid(), - ] - ) - elif self.traffic_pred_head_type == "seg": - self.traffic_pred_head = nn.Sequential( - *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()] - ) - - self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True) - - encoder_layer = TransformerEncoderLayer( - embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before - ) - self.encoder = TransformerEncoder(encoder_layer, enc_depth, None) - - decoder_layer = TransformerDecoderLayer( - embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before - ) - decoder_norm = nn.LayerNorm(embed_dim) - self.decoder = TransformerDecoder( - decoder_layer, dec_depth, decoder_norm, return_intermediate=False - ) - self.reset_parameters() - - def reset_parameters(self): - nn.init.uniform_(self.global_embed) - nn.init.uniform_(self.view_embed) - nn.init.uniform_(self.query_embed) - nn.init.uniform_(self.query_pos_embed) - - def forward_features( - self, - front_image, - left_image, - right_image, - front_center_image, - lidar, - measurements, - ): - features = [] - - # Front view processing - front_image_token, front_image_token_global = self.rgb_patch_embed(front_image) - if self.use_view_embed: - front_image_token = ( - front_image_token - + self.view_embed[:, :, 0:1, :] - + self.position_encoding(front_image_token) - ) - else: - front_image_token = front_image_token + self.position_encoding( - front_image_token - ) - front_image_token = front_image_token.flatten(2).permute(2, 0, 1) - front_image_token_global = ( - front_image_token_global - + self.view_embed[:, :, 0, :] - + self.global_embed[:, :, 0:1] - ) - front_image_token_global = front_image_token_global.permute(2, 0, 1) - features.extend([front_image_token, front_image_token_global]) - - if self.with_right_left_sensors: - # Left view processing - left_image_token, left_image_token_global = self.rgb_patch_embed(left_image) - if self.use_view_embed: - left_image_token = ( - left_image_token - + self.view_embed[:, :, 1:2, :] - + self.position_encoding(left_image_token) - ) - else: - left_image_token = left_image_token + self.position_encoding( - left_image_token - ) - left_image_token = left_image_token.flatten(2).permute(2, 0, 1) - left_image_token_global = ( - left_image_token_global - + self.view_embed[:, :, 1, :] - + self.global_embed[:, :, 1:2] - ) - left_image_token_global = left_image_token_global.permute(2, 0, 1) - - # Right view processing - right_image_token, right_image_token_global = self.rgb_patch_embed( - right_image - ) - if self.use_view_embed: - right_image_token = ( - right_image_token - + self.view_embed[:, :, 2:3, :] - + self.position_encoding(right_image_token) - ) - else: - right_image_token = right_image_token + self.position_encoding( - right_image_token - ) - right_image_token = right_image_token.flatten(2).permute(2, 0, 1) - right_image_token_global = ( - right_image_token_global - + self.view_embed[:, :, 2, :] - + self.global_embed[:, :, 2:3] - ) - right_image_token_global = right_image_token_global.permute(2, 0, 1) - - features.extend( - [ - left_image_token, - left_image_token_global, - right_image_token, - right_image_token_global, - ] - ) - - if self.with_center_sensor: - # Front center view processing - ( - front_center_image_token, - front_center_image_token_global, - ) = self.rgb_patch_embed(front_center_image) - if self.use_view_embed: - front_center_image_token = ( - front_center_image_token - + self.view_embed[:, :, 3:4, :] - + self.position_encoding(front_center_image_token) - ) - else: - front_center_image_token = ( - front_center_image_token - + self.position_encoding(front_center_image_token) - ) - - front_center_image_token = front_center_image_token.flatten(2).permute( - 2, 0, 1 - ) - front_center_image_token_global = ( - front_center_image_token_global - + self.view_embed[:, :, 3, :] - + self.global_embed[:, :, 3:4] - ) - front_center_image_token_global = front_center_image_token_global.permute( - 2, 0, 1 - ) - features.extend([front_center_image_token, front_center_image_token_global]) - - if self.with_lidar: - lidar_token, lidar_token_global = self.lidar_patch_embed(lidar) - if self.use_view_embed: - lidar_token = ( - lidar_token - + self.view_embed[:, :, 4:5, :] - + self.position_encoding(lidar_token) - ) - else: - lidar_token = lidar_token + self.position_encoding(lidar_token) - lidar_token = lidar_token.flatten(2).permute(2, 0, 1) - lidar_token_global = ( - lidar_token_global - + self.view_embed[:, :, 4, :] - + self.global_embed[:, :, 4:5] - ) - lidar_token_global = lidar_token_global.permute(2, 0, 1) - features.extend([lidar_token, lidar_token_global]) - - features = torch.cat(features, 0) - return features - - def forward(self, x): - front_image = x["rgb"] - left_image = x["rgb_left"] - right_image = x["rgb_right"] - front_center_image = x["rgb_center"] - measurements = x["measurements"] - target_point = x["target_point"] - lidar = x["lidar"] - - if self.direct_concat: - img_size = front_image.shape[-1] - left_image = torch.nn.functional.interpolate( - left_image, size=(img_size, img_size) - ) - right_image = torch.nn.functional.interpolate( - right_image, size=(img_size, img_size) - ) - front_center_image = torch.nn.functional.interpolate( - front_center_image, size=(img_size, img_size) - ) - front_image = torch.cat( - [front_image, left_image, right_image, front_center_image], dim=1 - ) - features = self.forward_features( - front_image, - left_image, - right_image, - front_center_image, - lidar, - measurements, - ) - - bs = front_image.shape[0] - - if self.end2end: - tgt = self.query_pos_embed.repeat(bs, 1, 1) - else: - tgt = self.position_encoding( - torch.ones((bs, 1, 20, 20), device=x["rgb"].device) - ) - tgt = tgt.flatten(2) - tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2) - tgt = tgt.permute(2, 0, 1) - - memory = self.encoder(features, mask=self.attn_mask) - hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0] - - hs = hs.permute(1, 0, 2) # Batchsize , N, C - if self.end2end: - waypoints = self.waypoints_generator(hs, target_point) - return waypoints - - if self.waypoints_pred_head != "heatmap": - traffic_feature = hs[:, :400] - is_junction_feature = hs[:, 400] - traffic_light_state_feature = hs[:, 400] - stop_sign_feature = hs[:, 400] - waypoints_feature = hs[:, 401:411] - else: - traffic_feature = hs[:, :400] - is_junction_feature = hs[:, 400] - traffic_light_state_feature = hs[:, 400] - stop_sign_feature = hs[:, 400] - waypoints_feature = hs[:, 401:405] - - if self.waypoints_pred_head == "heatmap": - waypoints = self.waypoints_generator(waypoints_feature, measurements) - elif self.waypoints_pred_head == "gru": - waypoints = self.waypoints_generator(waypoints_feature, target_point) - elif self.waypoints_pred_head == "gru-command": - waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements) - elif self.waypoints_pred_head == "linear": - waypoints = self.waypoints_generator(waypoints_feature, measurements) - elif self.waypoints_pred_head == "linear-sum": - waypoints = self.waypoints_generator(waypoints_feature, measurements) - - is_junction = self.junction_pred_head(is_junction_feature) - traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature) - stop_sign = self.stop_sign_head(stop_sign_feature) - - velocity = measurements[:, 6:7].unsqueeze(-1) - velocity = velocity.repeat(1, 400, 32) - traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2) - traffic = self.traffic_pred_head(traffic_feature_with_vel) - return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature - def load_pretrained(self, model_path, strict=False): - """ - تحميل الأوزان المدربة مسبقاً - نسخة محسنة - - Args: - model_path (str): مسار ملف الأوزان - strict (bool): إذا كان True، يتطلب تطابق تام للمفاتيح - """ - if not model_path or not Path(model_path).exists(): - logging.warning(f"ملف الأوزان غير موجود: {model_path}") - logging.info("سيتم استخدام أوزان عشوائية") - return False - - try: - logging.info(f"محاولة تحميل الأوزان من: {model_path}") - - # تحميل الملف مع معالجة أنواع مختلفة من ملفات الحفظ - checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) - - # استخراج state_dict من أنواع مختلفة من ملفات الحفظ - if isinstance(checkpoint, dict): - if 'model_state_dict' in checkpoint: - state_dict = checkpoint['model_state_dict'] - logging.info("تم العثور على 'model_state_dict' في الملف") - elif 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - logging.info("تم العثور على 'state_dict' في الملف") - elif 'model' in checkpoint: - state_dict = checkpoint['model'] - logging.info("تم العثور على 'model' في الملف") - else: - state_dict = checkpoint - logging.info("استخدام الملف كـ state_dict مباشرة") - else: - state_dict = checkpoint - logging.info("استخدام الملف كـ state_dict مباشرة") - - # تنظيف أسماء المفاتيح (إزالة 'module.' إذا كانت موجودة) - clean_state_dict = OrderedDict() - for k, v in state_dict.items(): - # إزالة 'module.' من بداية اسم المفتاح إذا كان موجوداً - clean_key = k[7:] if k.startswith('module.') else k - clean_state_dict[clean_key] = v - - # تحميل الأوزان - missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict) - - # تقرير حالة التحميل - if missing_keys: - logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}") - - if unexpected_keys: - logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}") - - if not missing_keys and not unexpected_keys: - logging.info("✅ تم تحميل جميع الأوزان بنجاح تام") - elif not strict: - logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)") - - return True - - except Exception as e: - logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}") - logging.info("سيتم استخدام أوزان عشوائية") - return False - - -# ============================================================================ -# دوال مساعدة لتحميل النموذج -# ============================================================================ - -def load_and_prepare_model(config, device): - """ - يقوم بإنشاء النموذج وتحميل الأوزان المدربة مسبقًا. - - Args: - config (dict): إعدادات النموذج والمسارات - device (torch.device): الجهاز المستهدف (CPU/GPU) - - Returns: - InterfuserModel: النموذج المحمل - """ - try: - # إنشاء النموذج - model = InterfuserModel(**config.get('model_params', {})).to(device) - logging.info(f"تم إنشاء النموذج على الجهاز: {device}") - - # تحميل الأوزان إذا كان المسار محدد - checkpoint_path = config.get('paths', {}).get('pretrained_weights') - if checkpoint_path: - success = model.load_pretrained(checkpoint_path, strict=False) - if success: - logging.info("✅ تم تحميل النموذج والأوزان بنجاح") - else: - logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية") - else: - logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية") - - # وضع النموذج في وضع التقييم - model.eval() - - return model - - except Exception as e: - logging.error(f"خطأ في إنشاء النموذج: {str(e)}") - raise - - -def create_model_config(model_path="model/best_model.pth", **model_params): - """ - إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب - - Args: - model_path (str): مسار ملف الأوزان - **model_params: معاملات النموذج الإضافية - - Returns: - dict: إعدادات النموذج - """ - # الإعدادات الصحيحة من كونفيج التدريب الأصلي - training_config_params = { - "img_size": 224, - "embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي - "enc_depth": 6, - "dec_depth": 6, - "rgb_backbone_name": 'r50', - "lidar_backbone_name": 'r18', - "waypoints_pred_head": 'gru', - "use_different_backbone": True, - "with_lidar": False, - "with_right_left_sensors": False, - "with_center_sensor": False, - - # إعدادات إضافية من الكونفيج الأصلي - "multi_view_img_size": 112, - "patch_size": 8, - "in_chans": 3, - "dim_feedforward": 2048, - "normalize_before": False, - "num_heads": 8, - "dropout": 0.1, - "end2end": False, - "direct_concat": False, - "separate_view_attention": False, - "separate_all_attention": False, - "freeze_num": -1, - "traffic_pred_head_type": "det", - "reverse_pos": True, - "use_view_embed": False, - "use_mmad_pretrain": None, - } - - # دمج المعاملات المخصصة مع الإعدادات من التدريب - training_config_params.update(model_params) - - config = { - 'model_params': training_config_params, - 'paths': { - 'pretrained_weights': model_path - }, - - # إضافة إعدادات الشبكة من التدريب - 'grid_conf': { - 'h': 20, 'w': 20, - 'x_res': 1.0, 'y_res': 1.0, - 'y_min': 0.0, 'y_max': 20.0, - 'x_min': -10.0, 'x_max': 10.0, - }, - - # معلومات إضافية عن التدريب - 'training_info': { - 'original_project': 'Interfuser_Finetuning', - 'run_name': 'Finetune_Focus_on_Detection_v5', - 'focus': 'traffic_detection_and_iou', - 'backbone': 'ResNet50 + ResNet18', - 'trained_on': 'PDM_Lite_Carla' - } - } - - return config - - -def get_training_config(): - """ - إرجاع إعدادات التدريب الأصلية للمرجع - هذه الإعدادات توضح كيف تم تدريب النموذج - """ - return { - 'project_info': { - 'project': 'Interfuser_Finetuning', - 'entity': None, - 'run_name': 'Finetune_Focus_on_Detection_v5' - }, - 'training': { - 'epochs': 50, - 'batch_size': 8, - 'num_workers': 2, - 'learning_rate': 1e-4, # معدل تعلم منخفض للـ Fine-tuning - 'weight_decay': 1e-2, - 'patience': 15, - 'clip_grad_norm': 1.0, - }, - 'loss_weights': { - 'iou': 2.0, # أولوية قصوى لدقة الصناديق - 'traffic_map': 25.0, # تركيز عالي على اكتشاف الكائنات - 'waypoints': 1.0, # مرجع أساسي - 'junction': 0.25, # مهام متقنة بالفعل - 'traffic_light': 0.5, - 'stop_sign': 0.25, - }, - 'data_split': { - 'strategy': 'interleaved', - 'segment_length': 100, - 'validation_frequency': 10, - }, - 'transforms': { - 'use_data_augmentation': False, # معطل للتركيز على البيانات الأصلية - } - } +# model_definition.py +# ============================================================================ +# الاستيرادات الأساسية +# ============================================================================ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim import AdamW +from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from functools import partial +from typing import Optional, List +from torch import Tensor +import os +import json +import numpy as np +import cv2 +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from functools import partial +from collections import deque, OrderedDict +import math +from torch.nn import MultiheadAttention +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from torch.nn import TransformerDecoder, TransformerDecoderLayer +from timm.models.resnet import resnet50d, resnet26d, resnet18d +try: + from timm.layers import trunc_normal_ +except ImportError: + from timm.models.layers import trunc_normal_ + +# مكتبات إضافية +import os +import json +import logging +import math +import copy +from pathlib import Path +from collections import OrderedDict + +# مكتبات معالجة البيانات +import numpy as np +import cv2 + +# مكتبات اختيارية (يمكن تعطيلها إذا لم تكن متوفرة) +try: + import wandb + WANDB_AVAILABLE = True +except ImportError: + WANDB_AVAILABLE = False + +try: + from tqdm import tqdm +except ImportError: + # إذا لم تكن tqdm متوفرة، استخدم دالة بديلة + def tqdm(iterable, *args, **kwargs): + return iterable + +# ============================================================================ +# دوال مساعدة +# ============================================================================ +def to_2tuple(x): + """تحويل قيمة إلى tuple من عنصرين""" + if isinstance(x, (list, tuple)): + return tuple(x) + return (x, x) +# ============================================================================ +# ============================================================================ + +class HybridEmbed(nn.Module): + def __init__( + self, + backbone, + img_size=224, + patch_size=1, + feature_size=None, + in_chans=3, + embed_dim=768, + ): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, "feature_info"): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x) + global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None] + return x, global_x + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor): + x = tensor + bs, _, h, w = x.shape + not_mask = torch.ones((bs, h, w), device=x.device) + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class TransformerEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + output = src + + for layer in self.layers: + output = layer( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + pos=pos, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class SpatialSoftmax(nn.Module): + def __init__(self, height, width, channel, temperature=None, data_format="NCHW"): + super().__init__() + + self.data_format = data_format + self.height = height + self.width = width + self.channel = channel + + if temperature: + self.temperature = Parameter(torch.ones(1) * temperature) + else: + self.temperature = 1.0 + + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width) + ) + pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float() + pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float() + self.register_buffer("pos_x", pos_x) + self.register_buffer("pos_y", pos_y) + + def forward(self, feature): + # Output: + # (N, C*2) x_0 y_0 ... + + if self.data_format == "NHWC": + feature = ( + feature.transpose(1, 3) + .tranpose(2, 3) + .view(-1, self.height * self.width) + ) + else: + feature = feature.view(-1, self.height * self.width) + + weight = F.softmax(feature / self.temperature, dim=-1) + expected_x = torch.sum( + torch.autograd.Variable(self.pos_x) * weight, dim=1, keepdim=True + ) + expected_y = torch.sum( + torch.autograd.Variable(self.pos_y) * weight, dim=1, keepdim=True + ) + expected_xy = torch.cat([expected_x, expected_y], 1) + feature_keypoints = expected_xy.view(-1, self.channel, 2) + feature_keypoints[:, :, 1] = (feature_keypoints[:, :, 1] - 1) * 12 + feature_keypoints[:, :, 0] = feature_keypoints[:, :, 0] * 12 + return feature_keypoints + + +class MultiPath_Generator(nn.Module): + def __init__(self, in_channel, embed_dim, out_channel): + super().__init__() + self.spatial_softmax = SpatialSoftmax(100, 100, out_channel) + self.tconv0 = nn.Sequential( + nn.ConvTranspose2d(in_channel, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(True), + ) + self.tconv1 = nn.Sequential( + nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(True), + ) + self.tconv2 = nn.Sequential( + nn.ConvTranspose2d(256, 192, 4, 2, 1, bias=False), + nn.BatchNorm2d(192), + nn.ReLU(True), + ) + self.tconv3 = nn.Sequential( + nn.ConvTranspose2d(192, 64, 4, 2, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(True), + ) + self.tconv4_list = torch.nn.ModuleList( + [ + nn.Sequential( + nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False), + nn.Tanh(), + ) + for _ in range(6) + ] + ) + + self.upsample = nn.Upsample(size=(50, 50), mode="bilinear") + + def forward(self, x, measurements): + mask = measurements[:, :6] + mask = mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, 1, 100, 100) + velocity = measurements[:, 6:7].unsqueeze(-1).unsqueeze(-1) + velocity = velocity.repeat(1, 32, 2, 2) + + n, d, c = x.shape + x = x.transpose(1, 2) + x = x.view(n, -1, 2, 2) + x = torch.cat([x, velocity], dim=1) + x = self.tconv0(x) + x = self.tconv1(x) + x = self.tconv2(x) + x = self.tconv3(x) + x = self.upsample(x) + xs = [] + for i in range(6): + xt = self.tconv4_list[i](x) + xs.append(xt) + xs = torch.stack(xs, dim=1) + x = torch.sum(xs * mask, dim=1) + x = self.spatial_softmax(x) + return x + + +class LinearWaypointsPredictor(nn.Module): + def __init__(self, input_dim, cumsum=True): + super().__init__() + self.cumsum = cumsum + self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim)) + self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)]) + self.head_relu = nn.ReLU(inplace=True) + self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) + + def forward(self, x, measurements): + # input shape: n 10 embed_dim + bs, n, dim = x.shape + x = x + self.rank_embed + x = x.reshape(-1, dim) + + mask = measurements[:, :6] + mask = torch.unsqueeze(mask, -1).repeat(n, 1, 2) + + rs = [] + for i in range(6): + res = self.head_fc1_list[i](x) + res = self.head_relu(res) + res = self.head_fc2_list[i](res) + rs.append(res) + rs = torch.stack(rs, 1) + x = torch.sum(rs * mask, dim=1) + + x = x.view(bs, n, 2) + if self.cumsum: + x = torch.cumsum(x, 1) + return x + + +class GRUWaypointsPredictor(nn.Module): + def __init__(self, input_dim, waypoints=10): + super().__init__() + # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) + self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) + self.encoder = nn.Linear(2, 64) + self.decoder = nn.Linear(64, 2) + self.waypoints = waypoints + + def forward(self, x, target_point): + bs = x.shape[0] + z = self.encoder(target_point).unsqueeze(0) + output, _ = self.gru(x, z) + output = output.reshape(bs * self.waypoints, -1) + output = self.decoder(output).reshape(bs, self.waypoints, 2) + output = torch.cumsum(output, 1) + return output + +class GRUWaypointsPredictorWithCommand(nn.Module): + def __init__(self, input_dim, waypoints=10): + super().__init__() + # self.gru = torch.nn.GRUCell(input_size=input_dim, hidden_size=64) + self.grus = nn.ModuleList([torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True) for _ in range(6)]) + self.encoder = nn.Linear(2, 64) + self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)]) + self.waypoints = waypoints + + def forward(self, x, target_point, measurements): + bs, n, dim = x.shape + mask = measurements[:, :6, None, None] + mask = mask.repeat(1, 1, self.waypoints, 2) + + z = self.encoder(target_point).unsqueeze(0) + outputs = [] + for i in range(6): + output, _ = self.grus[i](x, z) + output = output.reshape(bs * self.waypoints, -1) + output = self.decoders[i](output).reshape(bs, self.waypoints, 2) + output = torch.cumsum(output, 1) + outputs.append(output) + outputs = torch.stack(outputs, 1) + output = torch.sum(outputs * mask, dim=1) + return output + + +class TransformerDecoder(nn.Module): + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer( + output, + memory, + tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, + query_pos=query_pos, + ) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=nn.ReLU(), + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = activation() + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn( + q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn( + q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + )[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward( + self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + def __init__( + self, + d_model, + nhead, + dim_feedforward=2048, + dropout=0.1, + activation=nn.ReLU(), + normalize_before=False, + ): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = activation() + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn( + q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask + )[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn( + query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward( + self, + tgt, + memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + ): + if self.normalize_before: + return self.forward_pre( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + return self.forward_post( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask, + pos, + query_pos, + ) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def build_attn_mask(mask_type): + mask = torch.ones((151, 151), dtype=torch.bool).cuda() + if mask_type == "seperate_all": + mask[:50, :50] = False + mask[50:67, 50:67] = False + mask[67:84, 67:84] = False + mask[84:101, 84:101] = False + mask[101:151, 101:151] = False + elif mask_type == "seperate_view": + mask[:50, :50] = False + mask[50:67, 50:67] = False + mask[67:84, 67:84] = False + mask[84:101, 84:101] = False + mask[101:151, :] = False + mask[:, 101:151] = False + return mask +# class InterfuserModel(nn.Module): + +class InterfuserModel(nn.Module): + def __init__( + self, + img_size=224, + multi_view_img_size=112, + patch_size=8, + in_chans=3, + embed_dim=768, + enc_depth=6, + dec_depth=6, + dim_feedforward=2048, + normalize_before=False, + rgb_backbone_name="r50", + lidar_backbone_name="r50", + num_heads=8, + norm_layer=None, + dropout=0.1, + end2end=False, + direct_concat=False, + separate_view_attention=False, + separate_all_attention=False, + act_layer=None, + weight_init="", + freeze_num=-1, + with_lidar=False, + with_right_left_sensors=False, + with_center_sensor=False, + traffic_pred_head_type="det", + waypoints_pred_head="heatmap", + reverse_pos=True, + use_different_backbone=False, + use_view_embed=False, + use_mmad_pretrain=None, + ): + super().__init__() + self.traffic_pred_head_type = traffic_pred_head_type + self.num_features = ( + self.embed_dim + ) = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.reverse_pos = reverse_pos + self.waypoints_pred_head = waypoints_pred_head + self.with_lidar = with_lidar + self.with_right_left_sensors = with_right_left_sensors + self.with_center_sensor = with_center_sensor + + self.direct_concat = direct_concat + self.separate_view_attention = separate_view_attention + self.separate_all_attention = separate_all_attention + self.end2end = end2end + self.use_view_embed = use_view_embed + + if self.direct_concat: + in_chans = in_chans * 4 + self.with_center_sensor = False + self.with_right_left_sensors = False + + if self.separate_view_attention: + self.attn_mask = build_attn_mask("seperate_view") + elif self.separate_all_attention: + self.attn_mask = build_attn_mask("seperate_all") + else: + self.attn_mask = None + + if use_different_backbone: + if rgb_backbone_name == "r50": + self.rgb_backbone = resnet50d( + pretrained=True, + in_chans=in_chans, + features_only=True, + out_indices=[4], + ) + elif rgb_backbone_name == "r26": + self.rgb_backbone = resnet26d( + pretrained=True, + in_chans=in_chans, + features_only=True, + out_indices=[4], + ) + elif rgb_backbone_name == "r18": + self.rgb_backbone = resnet18d( + pretrained=True, + in_chans=in_chans, + features_only=True, + out_indices=[4], + ) + if lidar_backbone_name == "r50": + self.lidar_backbone = resnet50d( + pretrained=False, + in_chans=in_chans, + features_only=True, + out_indices=[4], + ) + elif lidar_backbone_name == "r26": + self.lidar_backbone = resnet26d( + pretrained=False, + in_chans=in_chans, + features_only=True, + out_indices=[4], + ) + elif lidar_backbone_name == "r18": + self.lidar_backbone = resnet18d( + pretrained=False, in_chans=3, features_only=True, out_indices=[4] + ) + rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) + lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone) + + if use_mmad_pretrain: + params = torch.load(use_mmad_pretrain)["state_dict"] + updated_params = OrderedDict() + for key in params: + if "backbone" in key: + updated_params[key.replace("backbone.", "")] = params[key] + self.rgb_backbone.load_state_dict(updated_params) + + self.rgb_patch_embed = rgb_embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + self.lidar_patch_embed = lidar_embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=3, + embed_dim=embed_dim, + ) + else: + if rgb_backbone_name == "r50": + self.rgb_backbone = resnet50d( + pretrained=True, in_chans=3, features_only=True, out_indices=[4] + ) + elif rgb_backbone_name == "r101": + self.rgb_backbone = resnet101d( + pretrained=True, in_chans=3, features_only=True, out_indices=[4] + ) + elif rgb_backbone_name == "r26": + self.rgb_backbone = resnet26d( + pretrained=True, in_chans=3, features_only=True, out_indices=[4] + ) + elif rgb_backbone_name == "r18": + self.rgb_backbone = resnet18d( + pretrained=True, in_chans=3, features_only=True, out_indices=[4] + ) + embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone) + + self.rgb_patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + self.lidar_patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) + self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1)) + + if self.end2end: + self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4)) + self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim)) + elif self.waypoints_pred_head == "heatmap": + self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5)) + self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim)) + else: + self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11)) + self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim)) + + if self.end2end: + self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4) + elif self.waypoints_pred_head == "heatmap": + self.waypoints_generator = MultiPath_Generator( + embed_dim + 32, embed_dim, 10 + ) + elif self.waypoints_pred_head == "gru": + self.waypoints_generator = GRUWaypointsPredictor(embed_dim) + elif self.waypoints_pred_head == "gru-command": + self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim) + elif self.waypoints_pred_head == "linear": + self.waypoints_generator = LinearWaypointsPredictor(embed_dim) + elif self.waypoints_pred_head == "linear-sum": + self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True) + + self.junction_pred_head = nn.Linear(embed_dim, 2) + self.traffic_light_pred_head = nn.Linear(embed_dim, 2) + self.stop_sign_head = nn.Linear(embed_dim, 2) + + if self.traffic_pred_head_type == "det": + self.traffic_pred_head = nn.Sequential( + *[ + nn.Linear(embed_dim + 32, 64), + nn.ReLU(), + nn.Linear(64, 7), + # nn.Sigmoid(), + ] + ) + elif self.traffic_pred_head_type == "seg": + self.traffic_pred_head = nn.Sequential( + *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()] + ) + + self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True) + + encoder_layer = TransformerEncoderLayer( + embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before + ) + self.encoder = TransformerEncoder(encoder_layer, enc_depth, None) + + decoder_layer = TransformerDecoderLayer( + embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before + ) + decoder_norm = nn.LayerNorm(embed_dim) + self.decoder = TransformerDecoder( + decoder_layer, dec_depth, decoder_norm, return_intermediate=False + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.global_embed) + nn.init.uniform_(self.view_embed) + nn.init.uniform_(self.query_embed) + nn.init.uniform_(self.query_pos_embed) + + def forward_features( + self, + front_image, + left_image, + right_image, + front_center_image, + lidar, + measurements, + ): + features = [] + + # Front view processing + front_image_token, front_image_token_global = self.rgb_patch_embed(front_image) + if self.use_view_embed: + front_image_token = ( + front_image_token + + self.view_embed[:, :, 0:1, :] + + self.position_encoding(front_image_token) + ) + else: + front_image_token = front_image_token + self.position_encoding( + front_image_token + ) + front_image_token = front_image_token.flatten(2).permute(2, 0, 1) + front_image_token_global = ( + front_image_token_global + + self.view_embed[:, :, 0, :] + + self.global_embed[:, :, 0:1] + ) + front_image_token_global = front_image_token_global.permute(2, 0, 1) + features.extend([front_image_token, front_image_token_global]) + + if self.with_right_left_sensors: + # Left view processing + left_image_token, left_image_token_global = self.rgb_patch_embed(left_image) + if self.use_view_embed: + left_image_token = ( + left_image_token + + self.view_embed[:, :, 1:2, :] + + self.position_encoding(left_image_token) + ) + else: + left_image_token = left_image_token + self.position_encoding( + left_image_token + ) + left_image_token = left_image_token.flatten(2).permute(2, 0, 1) + left_image_token_global = ( + left_image_token_global + + self.view_embed[:, :, 1, :] + + self.global_embed[:, :, 1:2] + ) + left_image_token_global = left_image_token_global.permute(2, 0, 1) + + # Right view processing + right_image_token, right_image_token_global = self.rgb_patch_embed( + right_image + ) + if self.use_view_embed: + right_image_token = ( + right_image_token + + self.view_embed[:, :, 2:3, :] + + self.position_encoding(right_image_token) + ) + else: + right_image_token = right_image_token + self.position_encoding( + right_image_token + ) + right_image_token = right_image_token.flatten(2).permute(2, 0, 1) + right_image_token_global = ( + right_image_token_global + + self.view_embed[:, :, 2, :] + + self.global_embed[:, :, 2:3] + ) + right_image_token_global = right_image_token_global.permute(2, 0, 1) + + features.extend( + [ + left_image_token, + left_image_token_global, + right_image_token, + right_image_token_global, + ] + ) + + if self.with_center_sensor: + # Front center view processing + ( + front_center_image_token, + front_center_image_token_global, + ) = self.rgb_patch_embed(front_center_image) + if self.use_view_embed: + front_center_image_token = ( + front_center_image_token + + self.view_embed[:, :, 3:4, :] + + self.position_encoding(front_center_image_token) + ) + else: + front_center_image_token = ( + front_center_image_token + + self.position_encoding(front_center_image_token) + ) + + front_center_image_token = front_center_image_token.flatten(2).permute( + 2, 0, 1 + ) + front_center_image_token_global = ( + front_center_image_token_global + + self.view_embed[:, :, 3, :] + + self.global_embed[:, :, 3:4] + ) + front_center_image_token_global = front_center_image_token_global.permute( + 2, 0, 1 + ) + features.extend([front_center_image_token, front_center_image_token_global]) + + if self.with_lidar: + lidar_token, lidar_token_global = self.lidar_patch_embed(lidar) + if self.use_view_embed: + lidar_token = ( + lidar_token + + self.view_embed[:, :, 4:5, :] + + self.position_encoding(lidar_token) + ) + else: + lidar_token = lidar_token + self.position_encoding(lidar_token) + lidar_token = lidar_token.flatten(2).permute(2, 0, 1) + lidar_token_global = ( + lidar_token_global + + self.view_embed[:, :, 4, :] + + self.global_embed[:, :, 4:5] + ) + lidar_token_global = lidar_token_global.permute(2, 0, 1) + features.extend([lidar_token, lidar_token_global]) + + features = torch.cat(features, 0) + return features + + def forward(self, x): + front_image = x["rgb"] + left_image = x["rgb_left"] + right_image = x["rgb_right"] + front_center_image = x["rgb_center"] + measurements = x["measurements"] + target_point = x["target_point"] + lidar = x["lidar"] + + if self.direct_concat: + img_size = front_image.shape[-1] + left_image = torch.nn.functional.interpolate( + left_image, size=(img_size, img_size) + ) + right_image = torch.nn.functional.interpolate( + right_image, size=(img_size, img_size) + ) + front_center_image = torch.nn.functional.interpolate( + front_center_image, size=(img_size, img_size) + ) + front_image = torch.cat( + [front_image, left_image, right_image, front_center_image], dim=1 + ) + features = self.forward_features( + front_image, + left_image, + right_image, + front_center_image, + lidar, + measurements, + ) + + bs = front_image.shape[0] + + if self.end2end: + tgt = self.query_pos_embed.repeat(bs, 1, 1) + else: + tgt = self.position_encoding( + torch.ones((bs, 1, 20, 20), device=x["rgb"].device) + ) + tgt = tgt.flatten(2) + tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2) + tgt = tgt.permute(2, 0, 1) + + memory = self.encoder(features, mask=self.attn_mask) + hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0] + + hs = hs.permute(1, 0, 2) # Batchsize , N, C + if self.end2end: + waypoints = self.waypoints_generator(hs, target_point) + return waypoints + + if self.waypoints_pred_head != "heatmap": + traffic_feature = hs[:, :400] + is_junction_feature = hs[:, 400] + traffic_light_state_feature = hs[:, 400] + stop_sign_feature = hs[:, 400] + waypoints_feature = hs[:, 401:411] + else: + traffic_feature = hs[:, :400] + is_junction_feature = hs[:, 400] + traffic_light_state_feature = hs[:, 400] + stop_sign_feature = hs[:, 400] + waypoints_feature = hs[:, 401:405] + + if self.waypoints_pred_head == "heatmap": + waypoints = self.waypoints_generator(waypoints_feature, measurements) + elif self.waypoints_pred_head == "gru": + waypoints = self.waypoints_generator(waypoints_feature, target_point) + elif self.waypoints_pred_head == "gru-command": + waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements) + elif self.waypoints_pred_head == "linear": + waypoints = self.waypoints_generator(waypoints_feature, measurements) + elif self.waypoints_pred_head == "linear-sum": + waypoints = self.waypoints_generator(waypoints_feature, measurements) + + is_junction = self.junction_pred_head(is_junction_feature) + traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature) + stop_sign = self.stop_sign_head(stop_sign_feature) + + velocity = measurements[:, 6:7].unsqueeze(-1) + velocity = velocity.repeat(1, 400, 32) + traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2) + traffic = self.traffic_pred_head(traffic_feature_with_vel) + return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature + def load_pretrained(self, model_path, strict=False): + """ + تحميل الأوزان المدربة مسبقاً - نسخة محسنة + + Args: + model_path (str): مسار ملف الأوزان + strict (bool): إذا كان True، يتطلب تطابق تام للمفاتيح + """ + if not model_path or not Path(model_path).exists(): + logging.warning(f"ملف الأوزان غير موجود: {model_path}") + logging.info("سيتم استخدام أوزان عشوائية") + return False + + try: + logging.info(f"محاولة تحميل الأوزان من: {model_path}") + + # تحميل الملف مع معالجة أنواع مختلفة من ملفات الحفظ + checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) + + # استخراج state_dict من أنواع مختلفة من ملفات الحفظ + if isinstance(checkpoint, dict): + if 'model_state_dict' in checkpoint: + state_dict = checkpoint['model_state_dict'] + logging.info("تم العثور على 'model_state_dict' في الملف") + elif 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + logging.info("تم العثور على 'state_dict' في الملف") + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + logging.info("تم العثور على 'model' في الملف") + else: + state_dict = checkpoint + logging.info("استخدام الملف كـ state_dict مباشرة") + else: + state_dict = checkpoint + logging.info("استخدام الملف كـ state_dict مباشرة") + + # تنظيف أسماء المفاتيح (إزالة 'module.' إذا كانت موجودة) + clean_state_dict = OrderedDict() + for k, v in state_dict.items(): + # إزالة 'module.' من بداية اسم المفتاح إذا كان موجوداً + clean_key = k[7:] if k.startswith('module.') else k + clean_state_dict[clean_key] = v + + # تحميل الأوزان + missing_keys, unexpected_keys = self.load_state_dict(clean_state_dict, strict=strict) + + # تقرير حالة التحميل + if missing_keys: + logging.warning(f"مفاتيح مفقودة ({len(missing_keys)}): {missing_keys[:5]}..." if len(missing_keys) > 5 else f"مفاتيح مفقودة: {missing_keys}") + + if unexpected_keys: + logging.warning(f"مفاتيح غير متوقعة ({len(unexpected_keys)}): {unexpected_keys[:5]}..." if len(unexpected_keys) > 5 else f"مفاتيح غير متوقعة: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + logging.info("✅ تم تحميل جميع الأوزان بنجاح تام") + elif not strict: + logging.info("✅ تم تحميل الأوزان بنجاح (مع تجاهل عدم التطابق)") + + return True + + except Exception as e: + logging.error(f"❌ خطأ في تحميل الأوزان: {str(e)}") + logging.info("سيتم استخدام أوزان عشوائية") + return False + + +# ============================================================================ +# دوال مساعدة لتحميل النموذج +# ============================================================================ + +def load_and_prepare_model(config, device): + """ + يقوم بإنشاء النموذج وتحميل الأوزان المدربة مسبقًا. + + Args: + config (dict): إعدادات النموذج والمسارات + device (torch.device): الجهاز المستهدف (CPU/GPU) + + Returns: + InterfuserModel: النموذج المحمل + """ + try: + # إنشاء النموذج + model = InterfuserModel(**config.get('model_params', {})).to(device) + logging.info(f"تم إنشاء النموذج على الجهاز: {device}") + + # تحميل الأوزان إذا كان المسار محدد + checkpoint_path = config.get('paths', {}).get('pretrained_weights') + if checkpoint_path: + success = model.load_pretrained(checkpoint_path, strict=False) + if success: + logging.info("✅ تم تحميل النموذج والأوزان بنجاح") + else: + logging.warning("⚠️ تم إنشاء النموذج بأوزان عشوائية") + else: + logging.info("لم يتم تحديد مسار الأوزان، سيتم استخدام أوزان عشوائية") + + # وضع النموذج في وضع التقييم + model.eval() + + return model + + except Exception as e: + logging.error(f"خطأ في إنشاء النموذج: {str(e)}") + raise + + +def create_model_config(model_path="model/best_model.pth", **model_params): + """ + إنشاء إعدادات النموذج باستخدام الإعدادات الصحيحة من التدريب + + Args: + model_path (str): مسار ملف الأوزان + **model_params: معاملات النموذج الإضافية + + Returns: + dict: إعدادات النموذج + """ + # الإعدادات الصحيحة من كونفيج التدريب الأصلي + training_config_params = { + "img_size": 224, + "embed_dim": 256, # مهم: هذه القيمة من التدريب الأصلي + "enc_depth": 6, + "dec_depth": 6, + "rgb_backbone_name": 'r50', + "lidar_backbone_name": 'r18', + "waypoints_pred_head": 'gru', + "use_different_backbone": True, + "with_lidar": False, + "with_right_left_sensors": False, + "with_center_sensor": False, + + # إعدادات إضافية من الكونفيج الأصلي + "multi_view_img_size": 112, + "patch_size": 8, + "in_chans": 3, + "dim_feedforward": 2048, + "normalize_before": False, + "num_heads": 8, + "dropout": 0.1, + "end2end": False, + "direct_concat": False, + "separate_view_attention": False, + "separate_all_attention": False, + "freeze_num": -1, + "traffic_pred_head_type": "det", + "reverse_pos": True, + "use_view_embed": False, + "use_mmad_pretrain": None, + } + + # دمج المعاملات المخصصة مع الإعدادات من التدريب + training_config_params.update(model_params) + + config = { + 'model_params': training_config_params, + 'paths': { + 'pretrained_weights': model_path + }, + + # إضافة إعدادات الشبكة من التدريب + 'grid_conf': { + 'h': 20, 'w': 20, + 'x_res': 1.0, 'y_res': 1.0, + 'y_min': 0.0, 'y_max': 20.0, + 'x_min': -10.0, 'x_max': 10.0, + }, + + # معلومات إضافية عن التدريب + 'training_info': { + 'original_project': 'Interfuser_Finetuning', + 'run_name': 'Finetune_Focus_on_Detection_v5', + 'focus': 'traffic_detection_and_iou', + 'backbone': 'ResNet50 + ResNet18', + 'trained_on': 'PDM_Lite_Carla' + } + } + + return config + + +def get_training_config(): + """ + إرجاع إعدادات التدريب الأصلية للمرجع + هذه الإعدادات توضح كيف تم تدريب النموذج + """ + return { + 'project_info': { + 'project': 'Interfuser_Finetuning', + 'entity': None, + 'run_name': 'Finetune_Focus_on_Detection_v5' + }, + 'training': { + 'epochs': 50, + 'batch_size': 8, + 'num_workers': 2, + 'learning_rate': 1e-4, # معدل تعلم منخفض للـ Fine-tuning + 'weight_decay': 1e-2, + 'patience': 15, + 'clip_grad_norm': 1.0, + }, + 'loss_weights': { + 'iou': 2.0, # أولوية قصوى لدقة الصناديق + 'traffic_map': 25.0, # تركيز عالي على اكتشاف الكائنات + 'waypoints': 1.0, # مرجع أساسي + 'junction': 0.25, # مهام متقنة بالفعل + 'traffic_light': 0.5, + 'stop_sign': 0.25, + }, + 'data_split': { + 'strategy': 'interleaved', + 'segment_length': 100, + 'validation_frequency': 10, + }, + 'transforms': { + 'use_data_augmentation': False, # معطل للتركيز على البيانات الأصلية + } + }