mohammed-aljafry committed
Commit 54355d5 (verified)
1 Parent(s): 45256a3

Upload folder using huggingface_hub

Files changed (3)
  1. config.json +17 -0
  2. modeling_interfuser.py +311 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "architectures": [
+     "InterfuserForHuggingFace"
+   ],
+   "dec_depth": 6,
+   "dim_feedforward": 2048,
+   "embed_dim": 256,
+   "enc_depth": 6,
+   "lidar_backbone_name": "r18",
+   "model_type": "interfuser",
+   "num_heads": 8,
+   "rgb_backbone_name": "r50",
+   "torch_dtype": "float32",
+   "transformers_version": "4.52.4",
+   "use_different_backbone": true,
+   "waypoints_pred_head": "gru"
+ }
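
For reference, the sketch below shows how this config pairs with the wrapper classes added in modeling_interfuser.py. Because config.json carries no auto_map entry, the class is imported directly rather than resolved through AutoModel; the local directory path is a placeholder assumption, not something pinned by this commit:

from modeling_interfuser import InterfuserForHuggingFace

# "./interfuser-hf" is a hypothetical local clone holding config.json and
# pytorch_model.bin; from_pretrained reads the values shown above.
model = InterfuserForHuggingFace.from_pretrained("./interfuser-hf")
model.eval()
print(model.config.embed_dim, model.config.rgb_backbone_name)  # 256 r50
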
modeling_interfuser.py ADDED
@@ -0,0 +1,311 @@
+ # -*- coding: utf-8 -*-
+ # This file contains all custom class definitions required to run the Interfuser model.
+
+ import torch
+ from torch import nn
+ from transformers import PreTrainedModel, PretrainedConfig
+ from functools import partial
+ import math
+ import copy
+ from typing import Optional
+ from torch import Tensor
+
+ # These helpers are imported from the timm package vendored inside the original
+ # InterFuser repository, so that repository must be on the Python path.
+ from InterFuser.interfuser.timm.models.layers import to_2tuple
+ from InterFuser.interfuser.timm.models.resnet import resnet50d, resnet26d, resnet18d
+
+ # ==============================================================================
+ # SECTION 1: ALL DEPENDENCY CLASSES FROM THE ORIGINAL CODE
+ # ==============================================================================
+
+ class HybridEmbed(nn.Module):
+     """Runs a CNN backbone and projects its last feature map to embed_dim tokens."""
+
+     def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
+         super().__init__()
+         assert isinstance(backbone, nn.Module)
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.backbone = backbone
+         if feature_size is None:
+             # Probe the backbone with a dummy input to discover its output shape.
+             with torch.no_grad():
+                 training = backbone.training
+                 if training:
+                     backbone.eval()
+                 o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+                 if isinstance(o, (list, tuple)):
+                     o = o[-1]
+                 feature_size = o.shape[-2:]
+                 feature_dim = o.shape[1]
+                 backbone.train(training)
+         else:
+             feature_size = to_2tuple(feature_size)
+             if hasattr(self.backbone, "feature_info"):
+                 feature_dim = self.backbone.feature_info.channels()[-1]
+             else:
+                 feature_dim = self.backbone.num_features
+         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
+
+     def forward(self, x):
+         x = self.backbone(x)
+         if isinstance(x, (list, tuple)):
+             x = x[-1]
+         x = self.proj(x)
+         # Global token: spatial mean of the projected feature map.
+         global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
+         return x, global_x
+
+ class PositionEmbeddingSine(nn.Module):
+     """DETR-style 2D sine/cosine positional encoding."""
+
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, tensor):
+         x = tensor
+         bs, _, h, w = x.shape
+         not_mask = torch.ones((bs, h, w), device=x.device)
+         y_embed = not_mask.cumsum(1, dtype=torch.float32)
+         x_embed = not_mask.cumsum(2, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+         pos_x = x_embed[:, :, :, None] / dim_t
+         pos_y = y_embed[:, :, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+         return pos
+
+ def _get_clones(module, N):
+     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+ class TransformerEncoderLayer(nn.Module):
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.activation = activation
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
+         q = k = self.with_pos_embed(src, pos)
+         src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src
+
+ class TransformerEncoder(nn.Module):
+     def __init__(self, encoder_layer, num_layers, norm=None):
+         super().__init__()
+         self.layers = _get_clones(encoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+
+     def forward(self, src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
+         output = src
+         for layer in self.layers:
+             output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
+         if self.norm is not None:
+             output = self.norm(output)
+         return output
+
+ class TransformerDecoderLayer(nn.Module):
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+         self.activation = activation
+         self.normalize_before = normalize_before
+
+     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
+                 tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
+         q = k = self.with_pos_embed(tgt, query_pos)
+         tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos),
+                                    value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+         super().__init__()
+         self.layers = _get_clones(decoder_layer, num_layers)
+         self.num_layers = num_layers
+         self.norm = norm
+         self.return_intermediate = return_intermediate
+
+     def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
+                 tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
+                 pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
+         output = tgt
+         for layer in self.layers:
+             output = layer(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
+                            tgt_key_padding_mask=tgt_key_padding_mask,
+                            memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos)
+         if self.norm is not None:
+             output = self.norm(output)
+         return output.unsqueeze(0)
+
+ class GRUWaypointsPredictor(nn.Module):
+     """Decodes waypoint tokens into (x, y) offsets, accumulated into a trajectory."""
+
+     def __init__(self, input_dim, waypoints=10):
+         super().__init__()
+         self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
+         self.encoder = nn.Linear(2, 64)
+         self.decoder = nn.Linear(64, 2)
+         self.waypoints = waypoints
+
+     def forward(self, x, target_point):
+         bs = x.shape[0]
+         # The encoded target point initializes the GRU hidden state.
+         z = self.encoder(target_point).unsqueeze(0)
+         output, _ = self.gru(x, z)
+         output = output.reshape(bs * self.waypoints, -1)
+         output = self.decoder(output).reshape(bs, self.waypoints, 2)
+         # Cumulative sum turns per-step offsets into absolute waypoints.
+         output = torch.cumsum(output, 1)
+         return output
+
+ # ... (Add other dependency classes like SpatialSoftmax, MultiPath_Generator, etc. if needed by other configs)
+
+ # --- The ORIGINAL Interfuser Model Class ---
+ class Interfuser(nn.Module):
+     def __init__(self, img_size=224, multi_view_img_size=112, patch_size=8, in_chans=3, embed_dim=768,
+                  enc_depth=6, dec_depth=6, dim_feedforward=2048, normalize_before=False,
+                  rgb_backbone_name="r26", lidar_backbone_name="r26", num_heads=8, norm_layer=None,
+                  dropout=0.1, end2end=False, direct_concat=True, separate_view_attention=False,
+                  separate_all_attention=False, act_layer=None, weight_init="", freeze_num=-1,
+                  with_lidar=False, with_right_left_sensors=True, with_center_sensor=False,
+                  traffic_pred_head_type="det", waypoints_pred_head="heatmap", reverse_pos=True,
+                  use_different_backbone=False, use_view_embed=True, use_mmad_pretrain=None):
+         super().__init__()
+         self.num_features = self.embed_dim = embed_dim
+         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+         act_layer = act_layer or nn.GELU
+
+         self.waypoints_pred_head = waypoints_pred_head
+         self.with_lidar = with_lidar
+         self.with_right_left_sensors = with_right_left_sensors
+         self.attn_mask = None  # Simplified
+
+         if use_different_backbone:
+             if rgb_backbone_name == "r50":
+                 self.rgb_backbone = resnet50d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
+             elif rgb_backbone_name == "r26":
+                 self.rgb_backbone = resnet26d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
+             else:
+                 raise NotImplementedError(f"Unsupported rgb backbone: {rgb_backbone_name}")
+             if lidar_backbone_name == "r18":
+                 self.lidar_backbone = resnet18d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
+             else:
+                 raise NotImplementedError(f"Unsupported lidar backbone: {lidar_backbone_name}")
+
+             rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
+             lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
+             self.rgb_patch_embed = rgb_embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+             self.lidar_patch_embed = lidar_embed_layer(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
+         else:
+             raise NotImplementedError("Only use_different_backbone=True is supported in this wrapper")
+
+         self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
+         self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
+         self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
+         # 400 traffic-map queries plus 11 task queries (see the layout comment in forward).
+         self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
+
+         if self.waypoints_pred_head == "gru":
+             self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
+         else:
+             raise NotImplementedError("Only the GRU waypoints head is supported in this wrapper")
+
+         self.junction_pred_head = nn.Linear(embed_dim, 2)
+         self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
+         self.stop_sign_head = nn.Linear(embed_dim, 2)
+         self.traffic_pred_head = nn.Sequential(*[nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()])
+         self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
+
+         # act_layer is a class (e.g. nn.GELU); instantiate it before handing it to the layers,
+         # which expect an activation module.
+         encoder_layer = TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer(), normalize_before)
+         self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
+         decoder_layer = TransformerDecoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer(), normalize_before)
+         decoder_norm = nn.LayerNorm(embed_dim)
+         self.decoder = TransformerDecoder(decoder_layer, dec_depth, decoder_norm, return_intermediate=False)
+
+     def forward_features(self, front_image, left_image, right_image, front_center_image, lidar, measurements):
+         # front_center_image, lidar and measurements are accepted for interface
+         # compatibility but unused in this simplified three-camera configuration.
+         features = []
+         front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
+         front_image_token = front_image_token + self.position_encoding(front_image_token)
+         front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
+         front_image_token_global = front_image_token_global + self.global_embed[:, :, 0:1]
+         front_image_token_global = front_image_token_global.permute(2, 0, 1)
+         features.extend([front_image_token, front_image_token_global])
+         left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
+         left_image_token = (left_image_token + self.position_encoding(left_image_token)).flatten(2).permute(2, 0, 1)
+         left_image_token_global = (left_image_token_global + self.global_embed[:, :, 1:2]).permute(2, 0, 1)
+         right_image_token, right_image_token_global = self.rgb_patch_embed(right_image)
+         right_image_token = (right_image_token + self.position_encoding(right_image_token)).flatten(2).permute(2, 0, 1)
+         right_image_token_global = (right_image_token_global + self.global_embed[:, :, 2:3]).permute(2, 0, 1)
+         features.extend([left_image_token, left_image_token_global, right_image_token, right_image_token_global])
+         return torch.cat(features, 0)
+
+     def forward(self, x):
+         front_image, left_image, right_image = x["rgb"], x["rgb_left"], x["rgb_right"]
+         measurements, target_point = x["measurements"], x["target_point"]
+         features = self.forward_features(front_image, left_image, right_image, x["rgb_center"], x["lidar"], measurements)
+         bs = front_image.shape[0]
+         tgt = self.position_encoding(torch.ones((bs, 1, 20, 20), device=x["rgb"].device)).flatten(2)
+         tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2).permute(2, 0, 1)
+         memory = self.encoder(features, mask=self.attn_mask)
+         hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0].permute(1, 0, 2)
+         # Query layout: 0-399 traffic map, 400 shared junction/light/stop token, 401-410 waypoints.
+         traffic_feature = hs[:, :400]
+         waypoints_feature = hs[:, 401:411]
+         is_junction_feature = hs[:, 400]
+         traffic_light_state_feature, stop_sign_feature = hs[:, 400], hs[:, 400]
+         waypoints = self.waypoints_generator(waypoints_feature, target_point)
+         is_junction = self.junction_pred_head(is_junction_feature)
+         traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
+         stop_sign = self.stop_sign_head(stop_sign_feature)
+         # Ego velocity (measurements index 6) is broadcast onto every traffic query.
+         velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
+         traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
+         traffic = self.traffic_pred_head(traffic_feature_with_vel)
+         return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
+
+ # ==============================================================================
+ # SECTION 2: HUGGING FACE WRAPPER CLASSES
+ # ==============================================================================
+
+ class InterfuserConfig(PretrainedConfig):
+     model_type = "interfuser"
+
+     def __init__(self, embed_dim=256, enc_depth=6, dec_depth=6, num_heads=8, dim_feedforward=2048,
+                  rgb_backbone_name="r50", lidar_backbone_name="r18", waypoints_pred_head="gru",
+                  use_different_backbone=True, **kwargs):
+         super().__init__(**kwargs)
+         self.embed_dim = embed_dim
+         self.enc_depth = enc_depth
+         self.dec_depth = dec_depth
+         self.num_heads = num_heads
+         self.dim_feedforward = dim_feedforward
+         self.rgb_backbone_name = rgb_backbone_name
+         self.lidar_backbone_name = lidar_backbone_name
+         self.waypoints_pred_head = waypoints_pred_head
+         self.use_different_backbone = use_different_backbone
+         self.architectures = ["InterfuserForHuggingFace"]
+
+ class InterfuserForHuggingFace(PreTrainedModel):
+     config_class = InterfuserConfig
+
+     def __init__(self, config: InterfuserConfig):
+         super().__init__(config)
+         self.interfuser_model = Interfuser(
+             embed_dim=config.embed_dim, enc_depth=config.enc_depth, dec_depth=config.dec_depth,
+             rgb_backbone_name=config.rgb_backbone_name, lidar_backbone_name=config.lidar_backbone_name,
+             waypoints_pred_head=config.waypoints_pred_head, use_different_backbone=config.use_different_backbone,
+         )
+
+     def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
+         inputs_dict = {
+             "rgb": rgb, "rgb_left": rgb_left, "rgb_right": rgb_right, "rgb_center": rgb_center,
+             "lidar": lidar, "measurements": measurements, "target_point": target_point,
+         }
+         return self.interfuser_model(inputs_dict)
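
As a quick sanity check, the forward pass can be exercised with dummy tensors. This is a sketch, not part of the commit: the shapes are inferred from the code above (224x224 camera views, a measurements vector read at index 6 for velocity, a 2D target point), and it assumes the InterFuser repository is importable so the timm backbones resolve:

import torch
from modeling_interfuser import InterfuserConfig, InterfuserForHuggingFace

model = InterfuserForHuggingFace(InterfuserConfig()).eval()
batch = {
    "rgb": torch.randn(1, 3, 224, 224),
    "rgb_left": torch.randn(1, 3, 224, 224),
    "rgb_right": torch.randn(1, 3, 224, 224),
    "rgb_center": torch.randn(1, 3, 224, 224),  # accepted but unused by this wrapper
    "lidar": torch.randn(1, 3, 224, 224),       # accepted but unused by this wrapper
    "measurements": torch.randn(1, 7),
    "target_point": torch.randn(1, 2),
}
with torch.no_grad():
    traffic, waypoints, is_junction, light, stop_sign, _ = model(**batch)
print(traffic.shape, waypoints.shape)  # torch.Size([1, 400, 7]) torch.Size([1, 10, 2])
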
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b9217e7ecb188304fb3c33b621df6120b6458ba25d969594af358cd5ef5d9db
+ size 212292994
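
This file is a Git LFS pointer, not the weights themselves; the ~212 MB binary is fetched from LFS storage on clone or download. A small sketch (the local path is hypothetical) for verifying a downloaded copy against the sha256 recorded above:

import hashlib

EXPECTED = "4b9217e7ecb188304fb3c33b621df6120b6458ba25d969594af358cd5ef5d9db"

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file in 1 MiB chunks to avoid loading 212 MB at once.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

# Assumes pytorch_model.bin sits in the current directory after download.
assert sha256_of("pytorch_model.bin") == EXPECTED
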