# NOTE: removed repository-viewer artifacts (file size, commit-hash blame
# gutter, and line-number gutter) that were scraped into the file and are
# not part of the Python source.
# -*- coding: utf-8 -*-
import torch, math, copy
from torch import nn, Tensor
from functools import partial
from typing import Optional, List
from collections import OrderedDict
from transformers import PreTrainedModel, PretrainedConfig
try:
from timm.models.layers import to_2tuple
from timm.models.resnet import resnet50d, resnet26d, resnet18d
except ImportError:
raise ImportError("This model requires timm. Please install with 'pip install timm==0.4.12' or a compatible version.")
def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class HybridEmbed(nn.Module):
    """Embed an image through a CNN backbone, then project the resulting
    feature map to the transformer embedding dimension with a 1x1 conv.

    forward() returns both the projected spatial map and a global token
    obtained by averaging the map over its spatial dimensions.
    """

    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None,
                 in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.backbone = backbone
        if feature_size is None:
            # Probe the backbone with a dummy forward pass to discover its
            # output spatial size and channel count.
            with torch.no_grad():
                was_training = backbone.training
                if was_training:
                    backbone.eval()
                probe = self.backbone(
                    torch.zeros(1, in_chans, self.img_size[0], self.img_size[1]))
                if isinstance(probe, (list, tuple)):
                    probe = probe[-1]  # features_only backbones return a list
                feature_size = probe.shape[-2:]
                feature_dim = probe.shape[1]
                backbone.train(was_training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            feats = feats[-1]
        feats = self.proj(feats)
        # Global descriptor: spatial mean, shaped (B, C, 1).
        global_feats = torch.mean(feats, [2, 3], keepdim=False)[:, :, None]
        return feats, global_feats
class PositionEmbeddingSine(nn.Module):
    """Fixed 2-D sine/cosine positional encoding in the DETR style.

    Given a (B, C, H, W) tensor, produces a (B, 2*num_pos_feats, H, W)
    encoding built from cumulative row/column positions.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor):
        bs, _, h, w = tensor.shape
        # Every position is "unmasked"; cumulative sums give 1-based coords.
        ones = torch.ones((bs, h, w), device=tensor.device)
        y_embed = ones.cumsum(1, dtype=torch.float32)
        x_embed = ones.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=tensor.device)
        # Pairs of channels share a frequency: sin on even, cos on odd.
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
class TransformerEncoder(nn.Module):
    """A stack of identical encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        out = src
        for layer in self.layers:
            out = layer(out, src_mask=mask,
                        src_key_padding_mask=src_key_padding_mask, pos=pos)
        return out if self.norm is None else self.norm(out)
class GRUWaypointsPredictor(nn.Module):
    """Decode a sequence of waypoint embeddings into a 2-D trajectory.

    The 2-D target point is encoded as the GRU's initial hidden state; the
    decoded per-step displacements are accumulated with a cumulative sum so
    the output is an absolute (bs, waypoints, 2) trajectory.
    """

    def __init__(self, input_dim, waypoints=10, **kwargs):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        # (bs, 2) -> (1, bs, 64): initial hidden state for the single GRU layer.
        hidden = self.encoder(target_point).unsqueeze(0)
        seq, _ = self.gru(x, hidden)
        deltas = self.decoder(seq.reshape(bs * self.waypoints, -1))
        deltas = deltas.reshape(bs, self.waypoints, 2)
        return torch.cumsum(deltas, 1)
class TransformerEncoderLayer(nn.Module):
    """Post-norm transformer encoder layer (DETR style).

    Positional embeddings are added to the self-attention query/key at every
    layer rather than only at the input.

    Note: ``normalize_before`` is stored but no pre-norm forward path exists
    in this implementation; the layer always runs post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=nn.ReLU, normalize_before=False, **kwargs):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # Bug fix: the original default was the *instance* nn.ReLU(), and
        # `activation()` then invoked it with no input, raising a TypeError
        # whenever the default was used. Accept either a module class
        # (e.g. nn.GELU, as Interfuser passes) or a ready-made instance.
        self.activation = activation() if isinstance(activation, type) else activation
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional embedding to *tensor* (no-op when pos is None)."""
        return tensor if pos is None else tensor + pos

    def forward(self, src, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Run self-attention then the feed-forward block, each with residual + norm."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(src2))
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src + self.dropout2(src2))
        return src
class TransformerDecoderLayer(nn.Module):
    """Post-norm transformer decoder layer (DETR style): self-attention over
    the queries, cross-attention into the encoder memory, then a feed-forward
    block — each followed by residual + LayerNorm.

    Note: ``normalize_before`` is stored but no pre-norm forward path exists
    in this implementation; the layer always runs post-norm.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=nn.ReLU, normalize_before=False, **kwargs):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        # Bug fix: the original default was the *instance* nn.ReLU(), and
        # `activation()` then invoked it with no input, raising a TypeError
        # whenever the default was used. Accept either a module class
        # (e.g. nn.GELU, as Interfuser passes) or a ready-made instance.
        self.activation = activation() if isinstance(activation, type) else activation
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional embedding to *tensor* (no-op when pos is None)."""
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Decode *tgt* against encoder *memory*; returns a tensor shaped like tgt."""
        # 1) Self-attention among the queries.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = self.norm1(tgt + self.dropout1(tgt2))
        # 2) Cross-attention into the encoder memory (pos added to keys).
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(tgt2))
        # 3) Position-wise feed-forward block.
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = self.norm3(tgt + self.dropout3(tgt2))
        return tgt
class TransformerDecoder(nn.Module):
    """A stack of decoder layers with an optional final norm.

    The output carries a leading singleton dimension, mirroring the
    (num_intermediate, ...) layout of DETR's return_intermediate mode even
    though only the final layer's output is returned here.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, **kwargs):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, **kwargs):
        out = tgt
        for layer in self.layers:
            out = layer(out, memory, **kwargs)
        if self.norm is not None:
            out = self.norm(out)
        return out.unsqueeze(0)
class Interfuser(nn.Module):
    """Interfuser perception/planning transformer (camera-only variant).

    Multi-view RGB inputs are embedded by a CNN backbone (HybridEmbed),
    fused by a transformer encoder, and decoded by a transformer decoder
    whose 400+11 learned queries produce:
      * a 20x20 traffic/occupancy grid prediction (queries 0..399),
      * junction / traffic-light / stop-sign classifications (query 400),
      * 10 future waypoints via a GRU head (queries 401..410).

    NOTE(review): several constructor arguments (multi_view_img_size,
    end2end, weight_init, freeze_num, with_lidar, use_mmad_pretrain,
    separate_view_attention, separate_all_attention, traffic_pred_head_type,
    reverse_pos) are stored or accepted but never used in the code visible
    here — presumably left over from the full upstream implementation.
    """
    def __init__(self,img_size=224,multi_view_img_size=112,patch_size=8,in_chans=3,embed_dim=768,enc_depth=6,dec_depth=6,dim_feedforward=2048,normalize_before=False,rgb_backbone_name="r26",lidar_backbone_name="r26",num_heads=8,norm_layer=None,dropout=0.1,end2end=False,direct_concat=True,separate_view_attention=False,separate_all_attention=False,act_layer=None,weight_init="",freeze_num=-1,with_lidar=False,with_right_left_sensors=True,with_center_sensor=False,traffic_pred_head_type="det",waypoints_pred_head="heatmap",reverse_pos=True,use_different_backbone=False,use_view_embed=True,use_mmad_pretrain=None):
        super().__init__();self.traffic_pred_head_type=traffic_pred_head_type;self.num_features=self.embed_dim=embed_dim;norm_layer=norm_layer or partial(nn.LayerNorm,eps=1e-6);act_layer=act_layer or nn.GELU;self.reverse_pos=reverse_pos;self.waypoints_pred_head=waypoints_pred_head;self.with_lidar=with_lidar;self.with_right_left_sensors=with_right_left_sensors;self.with_center_sensor=with_center_sensor;self.direct_concat=direct_concat;self.separate_view_attention=separate_view_attention;self.separate_all_attention=separate_all_attention;self.end2end=end2end;self.use_view_embed=use_view_embed
        # direct_concat mode stacks all four camera views on the channel axis
        # of a single backbone input, so the per-view sensor paths are disabled.
        if self.direct_concat:in_chans=in_chans*4;self.with_center_sensor=False;self.with_right_left_sensors=False
        self.attn_mask=None
        if use_different_backbone:
            # ResNet backbones from timm, features_only so forward returns a
            # list of feature maps; out_indices=[4] keeps only the last stage.
            if rgb_backbone_name=="r50":self.rgb_backbone=resnet50d(pretrained=False,in_chans=in_chans,features_only=True,out_indices=[4])
            elif rgb_backbone_name=="r26":self.rgb_backbone=resnet26d(pretrained=False,in_chans=in_chans,features_only=True,out_indices=[4])
            else:self.rgb_backbone=resnet18d(pretrained=False,in_chans=in_chans,features_only=True,out_indices=[4])
            # NOTE(review): self.lidar_backbone is only created when
            # lidar_backbone_name == "r18"; any other value makes the
            # lidar_patch_embed line below fail with AttributeError. Likewise,
            # when use_different_backbone is False no patch embed is created
            # at all — confirm the shipped config always sets both correctly.
            if lidar_backbone_name=="r18":self.lidar_backbone=resnet18d(pretrained=False,in_chans=3,features_only=True,out_indices=[4])
            self.rgb_patch_embed=partial(HybridEmbed,backbone=self.rgb_backbone)(img_size=img_size,patch_size=patch_size,in_chans=in_chans,embed_dim=embed_dim)
            self.lidar_patch_embed=partial(HybridEmbed,backbone=self.lidar_backbone)(img_size=img_size,patch_size=patch_size,in_chans=3,embed_dim=embed_dim)
        # Learned per-view embeddings: global token slots (5 views) and a
        # per-view bias added to the spatial feature maps.
        self.global_embed=nn.Parameter(torch.zeros(1,embed_dim,5));self.view_embed=nn.Parameter(torch.zeros(1,embed_dim,5,1))
        # 400 grid queries + 11 extra queries (1 state token + 10 waypoints);
        # query_pos_embed covers only the 11 extras (the grid queries get a
        # sine positional encoding in forward()).
        if waypoints_pred_head=="gru":self.query_pos_embed=nn.Parameter(torch.zeros(1,embed_dim,11));self.query_embed=nn.Parameter(torch.zeros(400+11,1,embed_dim))
        # Prediction heads + encoder/decoder stacks. traffic_pred_head takes
        # embed_dim+32 because a 32-wide tiled velocity feature is appended.
        self.waypoints_generator=GRUWaypointsPredictor(embed_dim);self.junction_pred_head=nn.Linear(embed_dim,2);self.traffic_light_pred_head=nn.Linear(embed_dim,2);self.stop_sign_head=nn.Linear(embed_dim,2);self.traffic_pred_head=nn.Sequential(nn.Linear(embed_dim+32,64),nn.ReLU(),nn.Linear(64,7),nn.Sigmoid());self.position_encoding=PositionEmbeddingSine(embed_dim//2,normalize=True);encoder_layer=TransformerEncoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);self.encoder=TransformerEncoder(encoder_layer,enc_depth,None);decoder_layer=TransformerDecoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);decoder_norm=nn.LayerNorm(embed_dim);self.decoder=TransformerDecoder(decoder_layer,dec_depth,decoder_norm,return_intermediate=False)
    def forward_features(self,front_image,left_image,right_image,front_center_image,lidar,measurements):
        """Tokenize the camera views into one (seq, batch, embed) sequence.

        Each view contributes its flattened feature-map tokens plus one global
        token, tagged with per-view embeddings (slot 0 = front, 1 = left,
        2 = right) and the sine positional encoding.

        NOTE(review): front_center_image, lidar and measurements are accepted
        but unused in this visible implementation.
        """
        features=[];front_image_token,front_image_token_global=self.rgb_patch_embed(front_image)
        if self.use_view_embed:front_image_token=front_image_token+self.view_embed[:,:,0:1,:]+self.position_encoding(front_image_token)
        else:front_image_token=front_image_token+self.position_encoding(front_image_token)
        # (B, C, H, W) -> (H*W, B, C) for the seq-first transformer stack.
        front_image_token=front_image_token.flatten(2).permute(2,0,1);front_image_token_global=(front_image_token_global+self.view_embed[:,:,0,:]+self.global_embed[:,:,0:1]).permute(2,0,1);features.extend([front_image_token,front_image_token_global])
        if self.with_right_left_sensors:
            left_image_token,left_image_token_global=self.rgb_patch_embed(left_image);right_image_token,right_image_token_global=self.rgb_patch_embed(right_image)
            if self.use_view_embed:left_image_token=left_image_token+self.view_embed[:,:,1:2,:]+self.position_encoding(left_image_token);right_image_token=right_image_token+self.view_embed[:,:,2:3,:]+self.position_encoding(right_image_token)
            else:left_image_token=left_image_token+self.position_encoding(left_image_token);right_image_token=right_image_token+self.position_encoding(right_image_token)
            left_image_token=left_image_token.flatten(2).permute(2,0,1);left_image_token_global=(left_image_token_global+self.view_embed[:,:,1,:]+self.global_embed[:,:,1:2]).permute(2,0,1)
            right_image_token=right_image_token.flatten(2).permute(2,0,1);right_image_token_global=(right_image_token_global+self.view_embed[:,:,2,:]+self.global_embed[:,:,2:3]).permute(2,0,1)
            features.extend([left_image_token,left_image_token_global,right_image_token,right_image_token_global])
        return torch.cat(features,0)
    def forward(self,x):
        """Full model forward.

        Args:
            x: dict with keys "rgb", "rgb_left", "rgb_right", "rgb_center",
               "lidar", "measurements", "target_point".

        Returns:
            (traffic, waypoints, is_junction, traffic_light_state, stop_sign,
             traffic_features) where traffic is (B, 400, 7), waypoints is
             (B, 10, 2), the three classifications are (B, 2), and
             traffic_features is the raw (B, 400, embed_dim) grid features.
        """
        front_image=x["rgb"];left_image=x["rgb_left"];right_image=x["rgb_right"];front_center_image=x["rgb_center"];measurements=x["measurements"];target_point=x["target_point"];lidar=x["lidar"]
        if self.direct_concat:
            # Resize the side/center views to the front view's resolution,
            # then stack all four along the channel axis.
            img_size=front_image.shape[-1];left_image=torch.nn.functional.interpolate(left_image,size=(img_size,img_size));right_image=torch.nn.functional.interpolate(right_image,size=(img_size,img_size));front_center_image=torch.nn.functional.interpolate(front_center_image,size=(img_size,img_size))
            front_image=torch.cat([front_image,left_image,right_image,front_center_image],dim=1)
        features=self.forward_features(front_image,left_image,right_image,front_center_image,lidar,measurements)
        # Decoder query positions: sine encoding of a 20x20 grid (400 slots)
        # concatenated with the 11 learned extra-query positions.
        bs=front_image.shape[0];tgt=self.position_encoding(torch.ones((bs,1,20,20),device=x["rgb"].device)).flatten(2);tgt=torch.cat([tgt,self.query_pos_embed.repeat(bs,1,1)],2).permute(2,0,1)
        memory=self.encoder(features,mask=None);hs=self.decoder(self.query_embed.repeat(1,bs,1),memory,query_pos=tgt)[0].permute(1,0,2)
        # Queries 401..410 are the 10 waypoint tokens; query 400 is the
        # shared junction/traffic-light/stop-sign state token.
        waypoints=self.waypoints_generator(hs[:,401:411],target_point)
        is_junction=self.junction_pred_head(hs[:,400]);traffic_light_state=self.traffic_light_pred_head(hs[:,400]);stop_sign=self.stop_sign_head(hs[:,400])
        # Tile ego velocity (measurements column 6) to a 32-wide feature per
        # grid cell before the traffic prediction head.
        velocity=measurements[:,6:7].unsqueeze(-1).repeat(1,400,32);traffic_feature_with_vel=torch.cat([hs[:,:400],velocity],dim=2);traffic=self.traffic_pred_head(traffic_feature_with_vel)
        return traffic,waypoints,is_junction,traffic_light_state,stop_sign,hs[:,:400]
class InterfuserConfig(PretrainedConfig):
    """HuggingFace configuration for the Interfuser model.

    Every keyword argument is mirrored verbatim as an attribute so it can be
    forwarded to the Interfuser constructor, and auto_map points AutoModel at
    the wrapper class in modeling_interfuser.
    """
    model_type = "interfuser"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every provided kwarg as a plain attribute.
        for name in kwargs:
            setattr(self, name, kwargs[name])
        self.auto_map = {"AutoModel": "modeling_interfuser.InterfuserForHuggingFace"}
class InterfuserForHuggingFace(PreTrainedModel):
    """Thin HuggingFace wrapper around Interfuser.

    Builds the model from the config's stored kwargs and repacks the named
    forward arguments into the single input dict Interfuser expects.
    """
    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        super().__init__(config)
        init_kwargs = config.to_dict()
        # Drop HF bookkeeping keys that Interfuser.__init__ does not accept.
        for hf_key in ("model_type", "architectures", "auto_map",
                       "transformers_version", "_name_or_path"):
            init_kwargs.pop(hf_key, None)
        self.interfuser = Interfuser(**init_kwargs)

    def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar,
                measurements, target_point, **kwargs):
        return self.interfuser.forward({
            "rgb": rgb,
            "rgb_left": rgb_left,
            "rgb_right": rgb_right,
            "rgb_center": rgb_center,
            "lidar": lidar,
            "measurements": measurements,
            "target_point": target_point,
        })