mohammed-aljafry committed
Commit 201bea6 · verified · Parent: 0ced52a

Final fix: Upload self-contained model with corrected 12-channel config and full code

Files changed (2)
  1. config.json +4 -1
  2. modeling_interfuser.py +86 -43
config.json CHANGED
@@ -13,5 +13,8 @@
   "rgb_backbone_name": "r50",
   "lidar_backbone_name": "r18",
   "use_different_backbone": true,
-  "waypoints_pred_head": "gru"
+  "waypoints_pred_head": "gru",
+  "auto_map": {
+    "AutoModel": "modeling_interfuser.InterfuserForHuggingFace"
+  }
  }
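The new auto_map entry is what lets transformers resolve the custom class in modeling_interfuser.py when the repository is loaded with remote code enabled, and the corrected 12-channel setting (in_chans=12) matches the four RGB views (4 x 3 channels) that direct_concat stacks before the r50 backbone. A minimal loading sketch, assuming the upload is complete; the repository id below is a placeholder, and some transformers versions may also expect an "AutoConfig" entry in auto_map:

from transformers import AutoModel

# "<user>/<repo>" is a placeholder for the actual Hub repository id.
# trust_remote_code=True lets transformers import modeling_interfuser.py from
# the repository and follow the auto_map entry to InterfuserForHuggingFace.
model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)
print(type(model).__name__)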
modeling_interfuser.py CHANGED
@@ -1,85 +1,122 @@
 
  # -*- coding: utf-8 -*-
+ # This is a self-contained file for the Interfuser model.
  import torch, math, copy
  from torch import nn, Tensor
  from functools import partial
  from typing import Optional, List
+ from collections import OrderedDict
  from transformers import PreTrainedModel, PretrainedConfig
- try: from timm.models.layers import to_2tuple
- except ImportError: raise ImportError("Please install timm 'pip install timm==0.3.2'")
- try: from timm.models.resnet import resnet50d, resnet26d, resnet18d
- except ImportError: raise ImportError("Please install timm 'pip install timm==0.3.2'")
+
+ try:
+     from timm.models.layers import to_2tuple
+     from timm.models.resnet import resnet50d, resnet26d, resnet18d
+ except ImportError:
+     raise ImportError("This model requires timm. Please install with 'pip install timm==0.3.2'")
+
+ # ==================================================
+ # SECTION 1: ALL DEPENDENCY CLASSES
+ # ==================================================
+ def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
 
  class HybridEmbed(nn.Module):
      def __init__(self, backbone, img_size=224, in_chans=3, embed_dim=768, **kwargs):
          super().__init__()
-         self.backbone=backbone
+         self.backbone = backbone
          with torch.no_grad():
-             dummy_input=torch.zeros(1,in_chans,img_size,img_size)
-             features=self.backbone(dummy_input)
-             feature_dim=features[-1].shape[1]
-         self.proj=nn.Conv2d(feature_dim,embed_dim,kernel_size=1)
-     def forward(self,x):
-         features=self.backbone(x);x_feat=features[-1];x=self.proj(x_feat);global_x=torch.mean(x,[2,3],keepdim=False)[:,:,None]
-         return x,global_x
+             dummy_input = torch.zeros(1, in_chans, img_size, img_size)
+             features = self.backbone(dummy_input)
+             feature_dim = features[-1].shape[1]
+         self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1)
+     def forward(self, x):
+         features = self.backbone(x); x_feat = features[-1]
+         x = self.proj(x_feat); global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
+         return x, global_x
+
  class PositionEmbeddingSine(nn.Module):
-     def __init__(self,num_pos_feats=64,temperature=10000,normalize=True,**kwargs):
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, **kwargs):
          super().__init__();self.num_pos_feats=num_pos_feats;self.temperature=temperature;self.normalize=normalize;self.scale=2*math.pi
-     def forward(self,tensor):
+     def forward(self, tensor):
          x=tensor;bs,_,h,w=x.shape;not_mask=torch.ones((bs,h,w),device=x.device);y_embed=not_mask.cumsum(1,dtype=torch.float32);x_embed=not_mask.cumsum(2,dtype=torch.float32)
          if self.normalize:eps=1e-6;y_embed=y_embed/(y_embed[:,-1:,:]+eps)*self.scale;x_embed=x_embed/(x_embed[:,:,-1:]+eps)*self.scale
          dim_t=torch.arange(self.num_pos_feats,dtype=torch.float32,device=x.device);dim_t=self.temperature**(2*(dim_t//2)/self.num_pos_feats)
         pos_x=x_embed[:,:,:,None]/dim_t;pos_y=y_embed[:,:,:,None]/dim_t
         pos_x=torch.stack((pos_x[:,:,:,0::2].sin(),pos_x[:,:,:,1::2].cos()),dim=4).flatten(3);pos_y=torch.stack((pos_y[:,:,:,0::2].sin(),pos_y[:,:,:,1::2].cos()),dim=4).flatten(3)
         return torch.cat((pos_y,pos_x),dim=3).permute(0,3,1,2)
- def _get_clones(module,N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
  class TransformerEncoderLayer(nn.Module):
      def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),**kwargs):
-         super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.activation=activation
+         super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.activation=activation()
      def with_pos_embed(self,tensor,pos:Optional[Tensor]):return tensor if pos is None else tensor+pos
      def forward(self,src,src_mask:Optional[Tensor]=None,src_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None):
         q=k=self.with_pos_embed(src,pos);src2=self.self_attn(q,k,value=src,attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0];src=src+self.dropout1(src2);src=self.norm1(src);src2=self.linear2(self.dropout(self.activation(self.linear1(src))));src=src+self.dropout2(src2);src=self.norm2(src);return src
+
  class TransformerEncoder(nn.Module):
      def __init__(self,encoder_layer,num_layers,norm=None):super().__init__();self.layers=_get_clones(encoder_layer,num_layers);self.num_layers=num_layers;self.norm=norm
      def forward(self,src,mask:Optional[Tensor]=None,src_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None):
-         output=src;[output:=layer(output,src_mask=mask,src_key_padding_mask=src_key_padding_mask,pos=pos) for layer in self.layers];
+         output=src
+         for layer in self.layers: output = layer(output,src_mask=mask,src_key_padding_mask=src_key_padding_mask,pos=pos)
         if self.norm is not None:output=self.norm(output)
         return output
+
  class TransformerDecoderLayer(nn.Module):
      def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),**kwargs):
-         super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.multihead_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.norm3=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.dropout3=nn.Dropout(dropout);self.activation=activation
+         super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.multihead_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.norm3=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.dropout3=nn.Dropout(dropout);self.activation=activation()
      def with_pos_embed(self,tensor,pos:Optional[Tensor]):return tensor if pos is None else tensor+pos
      def forward(self,tgt,memory,tgt_mask:Optional[Tensor]=None,memory_mask:Optional[Tensor]=None,tgt_key_padding_mask:Optional[Tensor]=None,memory_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None,query_pos:Optional[Tensor]=None):
-         q=k=self.with_pos_embed(tgt,query_pos);tgt2=self.self_attn(q,k,value=tgt,attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0];tgt=tgt+self.dropout1(tgt2);tgt=self.norm1(tgt);tgt2=self.multihead_attn(query=self.with_pos_embed(tgt,query_pos),key=self.with_pos_embed(memory,pos),value=memory,attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0];tgt=tgt+self.dropout2(tgt2);tgt=self.norm2(tgt);tgt2=self.linear2(self.dropout(self.activation(self.linear1(tgt))));tgt=tgt+self.dropout3(tgt2);tgt=self.norm3(tgt);return tgt
+         q=k=self.with_pos_embed(tgt,query_pos);tgt2=self.self_attn(q,k,value=tgt,attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0];tgt=tgt+self.dropout1(tgt2);tgt=self.norm1(tgt)
+         tgt2=self.multihead_attn(query=self.with_pos_embed(tgt,query_pos),key=self.with_pos_embed(memory,pos),value=memory,attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0];tgt=tgt+self.dropout2(tgt2);tgt=self.norm2(tgt)
+         tgt2=self.linear2(self.dropout(self.activation(self.linear1(tgt))));tgt=tgt+self.dropout3(tgt2);tgt=self.norm3(tgt);return tgt
+
  class TransformerDecoder(nn.Module):
      def __init__(self,decoder_layer,num_layers,norm=None,return_intermediate=False):super().__init__();self.layers=_get_clones(decoder_layer,num_layers);self.num_layers=num_layers;self.norm=norm;self.return_intermediate=return_intermediate
-     def forward(self,tgt,memory,**kwargs):output=tgt;[output:=layer(output,memory,**kwargs) for layer in self.layers];return self.norm(output).unsqueeze(0) if self.norm is not None else output.unsqueeze(0)
+     def forward(self,tgt,memory,**kwargs):
+         output=tgt
+         for layer in self.layers: output = layer(output,memory,**kwargs)
+         if self.norm is not None:output=self.norm(output)
+         return output.unsqueeze(0)
+
  class GRUWaypointsPredictor(nn.Module):
-     def __init__(self,input_dim,waypoints=10):super().__init__();self.gru=torch.nn.GRU(input_size=input_dim,hidden_size=64,batch_first=True);self.encoder=nn.Linear(2,64);self.decoder=nn.Linear(64,2);self.waypoints=waypoints
+     def __init__(self,input_dim,waypoints=10,**kwargs):super().__init__();self.gru=torch.nn.GRU(input_size=input_dim,hidden_size=64,batch_first=True);self.encoder=nn.Linear(2,64);self.decoder=nn.Linear(64,2);self.waypoints=waypoints
      def forward(self,x,target_point):bs=x.shape[0];z=self.encoder(target_point).unsqueeze(0);output,_=self.gru(x,z);output=self.decoder(output.reshape(bs*self.waypoints,-1)).reshape(bs,self.waypoints,2);return torch.cumsum(output,1)
 
+ # ==================================================
+ # SECTION 2: ORIGINAL INTERFUSER MODEL
+ # ==================================================
  class Interfuser(nn.Module):
-     def __init__(self,config):
-         super().__init__();self.config=config;self.embed_dim=config.embed_dim;self.waypoints_pred_head=config.waypoints_pred_head
-         if config.rgb_backbone_name=="r50":self.rgb_backbone=resnet50d(pretrained=False,in_chans=config.in_chans,features_only=True,out_indices=[4])
-         else:self.rgb_backbone=resnet18d(pretrained=False,in_chans=config.in_chans,features_only=True,out_indices=[4])
-         self.rgb_patch_embed=partial(HybridEmbed,backbone=self.rgb_backbone)(in_chans=config.in_chans,embed_dim=config.embed_dim)
-         if getattr(config,'lidar_backbone_name',None):self.lidar_backbone=resnet18d(pretrained=False,in_chans=3,features_only=True,out_indices=[4]);self.lidar_patch_embed=partial(HybridEmbed,backbone=self.lidar_backbone)(in_chans=3,embed_dim=config.embed_dim)
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.traffic_pred_head_type=getattr(config,"traffic_pred_head_type","det");self.embed_dim=config.embed_dim;self.waypoints_pred_head=config.waypoints_pred_head;self.use_view_embed=getattr(config,"use_view_embed",True)
+         if config.use_different_backbone:
+             if config.rgb_backbone_name=="r50":self.rgb_backbone=resnet50d(pretrained=False,in_chans=config.in_chans,features_only=True,out_indices=[4])
+             elif config.rgb_backbone_name=="r26":self.rgb_backbone=resnet26d(pretrained=False,in_chans=config.in_chans,features_only=True,out_indices=[4])
+             else:self.rgb_backbone=resnet18d(pretrained=False,in_chans=config.in_chans,features_only=True,out_indices=[4])
+             if config.lidar_backbone_name=="r18":self.lidar_backbone=resnet18d(pretrained=False,in_chans=3,features_only=True,out_indices=[4])
+         self.rgb_patch_embed=partial(HybridEmbed,backbone=self.rgb_backbone)(in_chans=config.in_chans,embed_dim=config.embed_dim)
+         self.lidar_patch_embed=partial(HybridEmbed,backbone=self.lidar_backbone)(in_chans=3,embed_dim=config.embed_dim)
         self.global_embed=nn.Parameter(torch.zeros(1,config.embed_dim,5));self.view_embed=nn.Parameter(torch.zeros(1,config.embed_dim,5,1));self.query_pos_embed=nn.Parameter(torch.zeros(1,config.embed_dim,11));self.query_embed=nn.Parameter(torch.zeros(400+11,1,config.embed_dim))
         self.waypoints_generator=GRUWaypointsPredictor(config.embed_dim);self.junction_pred_head=nn.Linear(config.embed_dim,2);self.traffic_light_pred_head=nn.Linear(config.embed_dim,2);self.stop_sign_head=nn.Linear(config.embed_dim,2)
         self.traffic_pred_head=nn.Sequential(nn.Linear(config.embed_dim+32,64),nn.ReLU(),nn.Linear(64,7),nn.Sigmoid());self.position_encoding=PositionEmbeddingSine(config.embed_dim//2)
-         encoder_layer=TransformerEncoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.encoder=TransformerEncoder(encoder_layer,config.enc_depth);decoder_layer=TransformerDecoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.decoder=TransformerDecoder(decoder_layer,config.dec_depth,nn.LayerNorm(config.embed_dim))
+         encoder_layer=TransformerEncoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.encoder=TransformerEncoder(encoder_layer,config.enc_depth)
+         decoder_layer=TransformerDecoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.decoder=TransformerDecoder(decoder_layer,config.dec_depth,nn.LayerNorm(config.embed_dim))
      def forward_features(self,front_image,left_image,right_image,front_center_image,lidar,**kwargs):
-         features=[];front_image_token,_=self.rgb_patch_embed(front_image);features.append(front_image_token+self.position_encoding(front_image_token))
-         left_image_token,_=self.rgb_patch_embed(left_image);features.append(left_image_token+self.position_encoding(left_image_token))
-         right_image_token,_=self.rgb_patch_embed(right_image);features.append(right_image_token+self.position_encoding(right_image_token))
-         return torch.cat([f.flatten(2) for f in features],-1).permute(2,0,1)
+         features=[];front_image_token,front_image_token_global=self.rgb_patch_embed(front_image)
+         front_image_token=(front_image_token+self.view_embed[:,:,0:1,:]+self.position_encoding(front_image_token)).flatten(2).permute(2,0,1);front_image_token_global=(front_image_token_global+self.view_embed[:,:,0,:]+self.global_embed[:,:,0:1]).permute(2,0,1)
+         features.extend([front_image_token,front_image_token_global])
+         left_image_token,left_image_token_global=self.rgb_patch_embed(left_image)
+         left_image_token=(left_image_token+self.view_embed[:,:,1:2,:]+self.position_encoding(left_image_token)).flatten(2).permute(2,0,1);left_image_token_global=(left_image_token_global+self.view_embed[:,:,1,:]+self.global_embed[:,:,1:2]).permute(2,0,1)
+         features.extend([left_image_token,left_image_token_global])
+         right_image_token,right_image_token_global=self.rgb_patch_embed(right_image)
+         right_image_token=(right_image_token+self.view_embed[:,:,2:3,:]+self.position_encoding(right_image_token)).flatten(2).permute(2,0,1);right_image_token_global=(right_image_token_global+self.view_embed[:,:,2,:]+self.global_embed[:,:,2:3]).permute(2,0,1)
+         features.extend([right_image_token,right_image_token_global])
+         return torch.cat(features,0)
      def forward(self,x):
         front_image=x["rgb"];left_image=x["rgb_left"];right_image=x["rgb_right"];front_center_image=x["rgb_center"];measurements=x["measurements"];target_point=x["target_point"];lidar=x["lidar"]
         if self.config.direct_concat:
             img_size=front_image.shape[-1];left_image=torch.nn.functional.interpolate(left_image,size=(img_size,img_size));right_image=torch.nn.functional.interpolate(right_image,size=(img_size,img_size));front_center_image=torch.nn.functional.interpolate(front_center_image,size=(img_size,img_size))
-             input_image=torch.cat([front_image,left_image,right_image,front_center_image],dim=1)
-             features,_=self.rgb_patch_embed(input_image);features=features+self.position_encoding(features);features=features.flatten(2).permute(2,0,1)
+             front_image=torch.cat([front_image,left_image,right_image,front_center_image],dim=1)
+         features=self.forward_features(front_image,left_image,right_image,front_center_image,lidar,measurements=measurements)
         bs=front_image.shape[0];tgt=self.position_encoding(torch.ones((bs,1,20,20),device=x["rgb"].device)).flatten(2);tgt=torch.cat([tgt,self.query_pos_embed.repeat(bs,1,1)],2).permute(2,0,1)
         memory=self.encoder(features,mask=None);hs=self.decoder(self.query_embed.repeat(1,bs,1),memory,query_pos=tgt)[0].permute(1,0,2)
         waypoints=self.waypoints_generator(hs[:,401:411],target_point)
@@ -87,17 +124,23 @@ class Interfuser(nn.Module):
         velocity=measurements[:,6:7].unsqueeze(-1).repeat(1,400,32);traffic_feature_with_vel=torch.cat([hs[:,:400],velocity],dim=2);traffic=self.traffic_pred_head(traffic_feature_with_vel)
         return traffic,waypoints,is_junction,traffic_light_state,stop_sign,hs[:,:400]
 
+ # ==================================================
+ # SECTION 3: HUGGING FACE WRAPPER CLASSES
+ # ==================================================
  class InterfuserConfig(PretrainedConfig):
      model_type="interfuser"
-     def __init__(self,in_chans=12,embed_dim=256,enc_depth=6,dec_depth=6,num_heads=8,dim_feedforward=2048,rgb_backbone_name="r50",lidar_backbone_name="r18",use_different_backbone=True,waypoints_pred_head="gru",direct_concat=True,**kwargs):
-         super().__init__(**kwargs);self.in_chans=in_chans;self.embed_dim,self.enc_depth,self.dec_depth,self.num_heads,self.dim_feedforward=embed_dim,enc_depth,dec_depth,num_heads,dim_feedforward
-         self.rgb_backbone_name,self.lidar_backbone_name,self.use_different_backbone,self.waypoints_pred_head,self.direct_concat=rgb_backbone_name,lidar_backbone_name,use_different_backbone,waypoints_pred_head,direct_concat
-         self.auto_map={"AutoModel":"modeling_interfuser.InterfuserForHuggingFace"}
+     def __init__(self, in_chans=12, embed_dim=256, enc_depth=6, dec_depth=6, num_heads=8, dim_feedforward=2048, rgb_backbone_name="r50", lidar_backbone_name="r18", use_different_backbone=True, waypoints_pred_head="gru", direct_concat=True, **kwargs):
+         super().__init__(**kwargs)
+         self.in_chans = in_chans; self.embed_dim, self.enc_depth, self.dec_depth, self.num_heads, self.dim_feedforward = embed_dim, enc_depth, dec_depth, num_heads, dim_feedforward
+         self.rgb_backbone_name, self.lidar_backbone_name, self.use_different_backbone, self.waypoints_pred_head, self.direct_concat = rgb_backbone_name, lidar_backbone_name, use_different_backbone, waypoints_pred_head, direct_concat
+         self.auto_map = {"AutoModel": "modeling_interfuser.InterfuserForHuggingFace"}
 
  class InterfuserForHuggingFace(PreTrainedModel):
-     config_class=InterfuserConfig
-     def __init__(self,config:InterfuserConfig):
-         super().__init__(config);self.config=config;self.interfuser_model=Interfuser(config)
-     def forward(self,rgb,rgb_left,rgb_right,rgb_center,lidar,measurements,target_point,**kwargs):
-         inputs_dict={'rgb':rgb,'rgb_left':rgb_left,'rgb_right':rgb_right,'rgb_center':rgb_center,'lidar':lidar,'measurements':measurements,'target_point':target_point}
+     config_class = InterfuserConfig
+     def __init__(self, config: InterfuserConfig):
+         super().__init__(config)
+         self.config = config
+         self.interfuser_model = Interfuser(config)
+     def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
+         inputs_dict = {'rgb':rgb, 'rgb_left':rgb_left, 'rgb_right':rgb_right, 'rgb_center':rgb_center, 'lidar':lidar, 'measurements':measurements, 'target_point':target_point}
         return self.interfuser_model.forward(inputs_dict)
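For reference, a sketch of the input dictionary that InterfuserForHuggingFace.forward expects, built with dummy tensors. Every shape below is an assumption for illustration; the camera resolution, LiDAR encoding, and measurement vector length are not documented on this page:

import torch

batch = 1
dummy_inputs = {
    "rgb":          torch.randn(batch, 3, 224, 224),   # front camera, assumed 224x224
    "rgb_left":     torch.randn(batch, 3, 224, 224),
    "rgb_right":    torch.randn(batch, 3, 224, 224),
    "rgb_center":   torch.randn(batch, 3, 224, 224),
    "lidar":        torch.randn(batch, 3, 224, 224),   # BEV pseudo-image, assumed 3-channel
    "measurements": torch.randn(batch, 10),            # forward() reads velocity at index 6
    "target_point": torch.randn(batch, 2),             # (x, y) goal in the ego frame
}
# With direct_concat=True, Interfuser.forward concatenates the four RGB views
# channel-wise, which is where the 12-channel in_chans setting comes from.
# outputs = model(**dummy_inputs)  # model loaded as in the config.json example above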