# -*- coding: utf-8 -*-
import copy
import math

import torch
from torch import nn, Tensor
from functools import partial
from typing import Optional
from transformers import PreTrainedModel, PretrainedConfig

try:
    from timm.models.layers import to_2tuple
    from timm.models.resnet import resnet50d, resnet26d, resnet18d
except ImportError:
    raise ImportError("This model requires timm. Please install with 'pip install timm==0.4.12' or a compatible version.")


def _get_clones(module, N):
    """Return N independent deep copies of a module as an nn.ModuleList."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

class HybridEmbed(nn.Module):
    """Embed an image with a CNN backbone, then project the final feature map to embed_dim via a 1x1 conv."""

    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # Dry run to infer the backbone's output resolution and channel count.
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature map of a features_only backbone
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = to_2tuple(feature_size)
            if hasattr(self.backbone, "feature_info"):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]
        x = self.proj(x)
        # Global token: spatial mean of the projected map, shape (B, embed_dim, 1).
        global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
        return x, global_x
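
# Shape note (illustrative, assuming the default 224x224 input and the stride-32
# ResNet backbones configured below): forward() returns a (B, embed_dim, 7, 7)
# token map plus a (B, embed_dim, 1) mean-pooled global token.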

class PositionEmbeddingSine(nn.Module):
    """2D sine-cosine positional encoding, following DETR's PositionEmbeddingSine (without padding masks)."""

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, tensor):
        x = tensor
        bs, _, h, w = x.shape
        not_mask = torch.ones((bs, h, w), device=x.device)
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        # (B, 2 * num_pos_feats, H, W): y-encoding channels first, then x-encoding channels.
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
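
# For reference: with normalize=True the row/column indices are rescaled to (0, 2*pi],
# and each coordinate is expanded with the standard interleaved sine-cosine scheme,
#     PE(pos, 2i)   = sin(pos / temperature^(2i / num_pos_feats))
#     PE(pos, 2i+1) = cos(pos / temperature^(2i / num_pos_feats)),
# yielding 2 * num_pos_feats channels per location.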

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
        if self.norm is not None:
            output = self.norm(output)
        return output

class GRUWaypointsPredictor(nn.Module):
    def __init__(self, input_dim, waypoints=10, **kwargs):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)   # embeds the (x, y) target point
        self.decoder = nn.Linear(64, 2)   # maps each GRU state to an (x, y) offset
        self.waypoints = waypoints

    def forward(self, x, target_point):
        bs = x.shape[0]
        z = self.encoder(target_point).unsqueeze(0)  # initial hidden state, (1, B, 64)
        output, _ = self.gru(x, z)
        output = self.decoder(output.reshape(bs * self.waypoints, -1)).reshape(bs, self.waypoints, 2)
        return torch.cumsum(output, 1)
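
# Design note: the initial GRU hidden state is the encoded goal location, so every
# decoding step is conditioned on the target point; the head emits per-step (x, y)
# offsets and the final cumsum accumulates them into absolute waypoint positions.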

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=nn.ReLU, normalize_before=False, **kwargs):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # `activation` must be a module class (e.g. nn.ReLU or nn.GELU); it is
        # instantiated here, so an already-constructed instance would fail.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, src, src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation=nn.ReLU, normalize_before=False, **kwargs):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        # As above, `activation` must be a module class, not an instance.
        self.activation = activation()
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
        # Self-attention over the queries.
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Cross-attention into the encoder memory.
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, **kwargs):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, **kwargs):
        output = tgt
        for layer in self.layers:
            output = layer(output, memory, **kwargs)
        if self.norm is not None:
            output = self.norm(output)
        # Leading dim of size 1 keeps the DETR-style interface that would otherwise
        # stack intermediate layer outputs.
        return output.unsqueeze(0)
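
# Note: the encoder/decoder layers above accept normalize_before for interface
# compatibility, but their forward() methods only implement the post-norm order
# (attention -> residual -> LayerNorm), i.e. DETR's default path.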

class Interfuser(nn.Module):
    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=True,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=True,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
    ):
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor
        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed

        if self.direct_concat:
            # All four camera views are stacked along the channel axis and fed
            # through a single backbone, so the side/center branches are disabled.
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False

        self.attn_mask = None

        if use_different_backbone:
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(pretrained=False, in_chans=in_chans, features_only=True, out_indices=[4])
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(pretrained=False, in_chans=in_chans, features_only=True, out_indices=[4])
            else:
                self.rgb_backbone = resnet18d(pretrained=False, in_chans=in_chans, features_only=True, out_indices=[4])
            if lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
            self.rgb_patch_embed = HybridEmbed(backbone=self.rgb_backbone, img_size=img_size,
                                               patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
            self.lidar_patch_embed = HybridEmbed(backbone=self.lidar_backbone, img_size=img_size,
                                                 patch_size=patch_size, in_chans=3, embed_dim=embed_dim)

        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))

        if waypoints_pred_head == "gru":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))

        self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)
        self.traffic_pred_head = nn.Sequential(
            nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()
        )
        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)

        encoder_layer = TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before)
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
        decoder_layer = TransformerDecoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before)
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(decoder_layer, dec_depth, decoder_norm, return_intermediate=False)
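
    # Decoder query layout (400 + 11 = 411 entries in query_embed, matching forward()):
    #   indices 0..399   -> one query per cell of the 20x20 grid used for the traffic map
    #   index 400        -> shared state query for the junction / traffic-light / stop-sign heads
    #   indices 401..410 -> ten waypoint queries consumed by the GRU waypoint predictor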

    def forward_features(self, front_image, left_image, right_image, front_center_image, lidar, measurements):
        # lidar and measurements are accepted for interface compatibility but are
        # not consumed on this code path (with_lidar is False in this trimmed variant).
        features = []

        # Front view: per-location tokens plus one global (mean-pooled) token.
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = front_image_token + self.view_embed[:, :, 0:1, :] + self.position_encoding(front_image_token)
        else:
            front_image_token = front_image_token + self.position_encoding(front_image_token)
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global + self.view_embed[:, :, 0, :] + self.global_embed[:, :, 0:1]
        ).permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])

        if self.with_right_left_sensors:
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            right_image_token, right_image_token_global = self.rgb_patch_embed(right_image)
            if self.use_view_embed:
                left_image_token = left_image_token + self.view_embed[:, :, 1:2, :] + self.position_encoding(left_image_token)
                right_image_token = right_image_token + self.view_embed[:, :, 2:3, :] + self.position_encoding(right_image_token)
            else:
                left_image_token = left_image_token + self.position_encoding(left_image_token)
                right_image_token = right_image_token + self.position_encoding(right_image_token)
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global + self.view_embed[:, :, 1, :] + self.global_embed[:, :, 1:2]
            ).permute(2, 0, 1)
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global + self.view_embed[:, :, 2, :] + self.global_embed[:, :, 2:3]
            ).permute(2, 0, 1)
            features.extend([left_image_token, left_image_token_global,
                             right_image_token, right_image_token_global])

        # (num_tokens, B, embed_dim) sequence for the transformer encoder.
        return torch.cat(features, 0)
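
    # Token-count note (derived from the 7x7 feature maps, assuming 224x224 inputs):
    # each view contributes 49 spatial tokens plus 1 global token, so the encoder
    # sequence is 50 tokens under direct_concat (the default) and 150 tokens when
    # with_right_left_sensors is enabled instead.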

    def forward(self, x):
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]

        if self.direct_concat:
            # Resize side/center views to the front view's resolution and stack
            # all four along the channel axis.
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(left_image, size=(img_size, img_size))
            right_image = torch.nn.functional.interpolate(right_image, size=(img_size, img_size))
            front_center_image = torch.nn.functional.interpolate(front_center_image, size=(img_size, img_size))
            front_image = torch.cat([front_image, left_image, right_image, front_center_image], dim=1)

        features = self.forward_features(front_image, left_image, right_image, front_center_image, lidar, measurements)

        # Decoder query positions: sine encodings of a 20x20 grid (400 traffic-map
        # queries) followed by the 11 learned positions for the state/waypoint queries.
        bs = front_image.shape[0]
        tgt = self.position_encoding(torch.ones((bs, 1, 20, 20), device=x["rgb"].device)).flatten(2)
        tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2).permute(2, 0, 1)

        memory = self.encoder(features, mask=None)
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0].permute(1, 0, 2)

        waypoints = self.waypoints_generator(hs[:, 401:411], target_point)
        is_junction = self.junction_pred_head(hs[:, 400])
        traffic_light_state = self.traffic_light_pred_head(hs[:, 400])
        stop_sign = self.stop_sign_head(hs[:, 400])

        # Concatenate the ego velocity (measurements channel 6) onto every grid
        # feature before predicting the 7-dim traffic map.
        velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([hs[:, :400], velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)

        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, hs[:, :400]
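
    # Output shapes, for B input frames: traffic (B, 400, 7), waypoints (B, 10, 2),
    # is_junction / traffic_light_state / stop_sign logits (B, 2) each, plus the raw
    # (B, 400, embed_dim) grid features used by the traffic head.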

class InterfuserConfig(PretrainedConfig):
    model_type = "interfuser"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror every config entry as an attribute so it can be forwarded to Interfuser.
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.auto_map = {"AutoModel": "modeling_interfuser.InterfuserForHuggingFace"}

class InterfuserForHuggingFace(PreTrainedModel):
    config_class = InterfuserConfig

    def __init__(self, config: InterfuserConfig):
        super().__init__(config)
        init_args = config.to_dict()
        # Drop the Hugging Face bookkeeping keys; everything else is an Interfuser kwarg.
        for key in ["model_type", "architectures", "auto_map", "transformers_version", "_name_or_path"]:
            init_args.pop(key, None)
        self.interfuser = Interfuser(**init_args)

    def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
        inputs_dict = {
            "rgb": rgb,
            "rgb_left": rgb_left,
            "rgb_right": rgb_right,
            "rgb_center": rgb_center,
            "lidar": lidar,
            "measurements": measurements,
            "target_point": target_point,
        }
        return self.interfuser.forward(inputs_dict)
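
# --- Illustrative usage (not part of the original file) -----------------------
# Loading from the Hub would typically go through AutoModel with remote code,
# e.g. AutoModel.from_pretrained("<repo_id>", trust_remote_code=True).
if __name__ == "__main__":
    # Minimal smoke test with assumed, non-canonical settings: a small Interfuser
    # fed random tensors. Real checkpoints define their own hyperparameters in
    # config.json; the flags below are only what this trimmed file requires.
    model = Interfuser(
        embed_dim=256,
        enc_depth=2,
        dec_depth=2,
        dim_feedforward=512,
        lidar_backbone_name="r18",
        waypoints_pred_head="gru",    # required here: only the GRU head is wired up
        use_different_backbone=True,  # required here: builds the patch embeddings
    )
    model.eval()
    dummy = {
        "rgb": torch.randn(2, 3, 224, 224),
        "rgb_left": torch.randn(2, 3, 224, 224),
        "rgb_right": torch.randn(2, 3, 224, 224),
        "rgb_center": torch.randn(2, 3, 224, 224),
        "lidar": torch.randn(2, 3, 224, 224),  # unused when with_lidar is False
        "measurements": torch.randn(2, 7),     # channel 6 is read as ego velocity
        "target_point": torch.randn(2, 2),
    }
    with torch.no_grad():
        traffic, waypoints, is_junction, light_state, stop_sign, grid_feats = model(dummy)
    print(traffic.shape, waypoints.shape)  # torch.Size([2, 400, 7]) torch.Size([2, 10, 2])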