mohammed-aljafry committed on
Commit
8deaea0
·
verified ·
1 Parent(s): 8fc986f

Final fix v9: Correctly filter all unexpected kwargs during init

Browse files
Files changed (2) hide show
  1. config.json +0 -3
  2. modeling_interfuser.py +24 -98
config.json CHANGED
@@ -20,9 +20,6 @@
20
  "direct_concat": true,
21
  "separate_view_attention": false,
22
  "separate_all_attention": false,
23
- "act_layer": null,
24
- "weight_init": "",
25
- "freeze_num": -1,
26
  "with_lidar": false,
27
  "with_right_left_sensors": true,
28
  "with_center_sensor": false,
 
20
  "direct_concat": true,
21
  "separate_view_attention": false,
22
  "separate_all_attention": false,
 
 
 
23
  "with_lidar": false,
24
  "with_right_left_sensors": true,
25
  "with_center_sensor": false,
modeling_interfuser.py CHANGED
@@ -13,74 +13,15 @@ try:
13
  except ImportError:
14
  raise ImportError("This model requires timm. Please install with 'pip install timm==0.4.12' or a compatible version.")
15
 
 
16
  def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
17
- def build_attn_mask(mask_type):
18
- mask=torch.ones((151,151),dtype=torch.bool).cuda();
19
- if mask_type=="seperate_all":mask[:50,:50]=False;mask[50:67,50:67]=False;mask[67:84,67:84]=False;mask[84:101,84:101]=False;mask[101:151,101:151]=False
20
- elif mask_type=="seperate_view":mask[:50,:50]=False;mask[50:67,50:67]=False;mask[67:84,67:84]=False;mask[84:101,84:101]=False;mask[101:151,:]=False;mask[:,101:151]=False
21
- return mask
22
- class HybridEmbed(nn.Module):
23
- def __init__(self,backbone,img_size=224,patch_size=1,feature_size=None,in_chans=3,embed_dim=768):
24
- super().__init__();assert isinstance(backbone,nn.Module);img_size=to_2tuple(img_size);patch_size=to_2tuple(patch_size);self.img_size=img_size;self.patch_size=patch_size;self.backbone=backbone
25
- if feature_size is None:
26
- with torch.no_grad():
27
- training=backbone.training;
28
- if training:backbone.eval()
29
- o=self.backbone(torch.zeros(1,in_chans,img_size[0],img_size[1]))
30
- if isinstance(o,(list,tuple)):o=o[-1]
31
- feature_size=o.shape[-2:];feature_dim=o.shape[1];backbone.train(training)
32
- else:
33
- feature_size=to_2tuple(feature_size)
34
- if hasattr(self.backbone,"feature_info"):feature_dim=self.backbone.feature_info.channels()[-1]
35
- else:feature_dim=self.backbone.num_features
36
- self.proj=nn.Conv2d(feature_dim,embed_dim,kernel_size=1,stride=1)
37
- def forward(self,x):
38
- x=self.backbone(x)
39
- if isinstance(x,(list,tuple)):x=x[-1]
40
- x=self.proj(x);global_x=torch.mean(x,[2,3],keepdim=False)[:,:,None];return x,global_x
41
- class PositionEmbeddingSine(nn.Module):
42
- def __init__(self,num_pos_feats=64,temperature=10000,normalize=False,scale=None):
43
- super().__init__();self.num_pos_feats=num_pos_feats;self.temperature=temperature;self.normalize=normalize
44
- if scale is not None and normalize is False:raise ValueError("normalize should be True if scale is passed")
45
- if scale is None:scale=2*math.pi
46
- self.scale=scale
47
- def forward(self,tensor):
48
- x=tensor;bs,_,h,w=x.shape;not_mask=torch.ones((bs,h,w),device=x.device);y_embed=not_mask.cumsum(1,dtype=torch.float32);x_embed=not_mask.cumsum(2,dtype=torch.float32)
49
- if self.normalize:eps=1e-6;y_embed=y_embed/(y_embed[:,-1:,:]+eps)*self.scale;x_embed=x_embed/(x_embed[:,:,-1:]+eps)*self.scale
50
- dim_t=torch.arange(self.num_pos_feats,dtype=torch.float32,device=x.device);dim_t=self.temperature**(2*(dim_t//2)/self.num_pos_feats);pos_x=x_embed[:,:,:,None]/dim_t;pos_y=y_embed[:,:,:,None]/dim_t
51
- pos_x=torch.stack((pos_x[:,:,:,0::2].sin(),pos_x[:,:,:,1::2].cos()),dim=4).flatten(3);pos_y=torch.stack((pos_y[:,:,:,0::2].sin(),pos_y[:,:,:,1::2].cos()),dim=4).flatten(3);return torch.cat((pos_y,pos_x),dim=3).permute(0,3,1,2)
52
- class TransformerEncoder(nn.Module):
53
- def __init__(self,encoder_layer,num_layers,norm=None):super().__init__();self.layers=_get_clones(encoder_layer,num_layers);self.num_layers=num_layers;self.norm=norm
54
- def forward(self,src,mask:Optional[Tensor]=None,src_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None):
55
- output=src;
56
- for layer in self.layers:output=layer(output,src_mask=mask,src_key_padding_mask=src_key_padding_mask,pos=pos)
57
- if self.norm is not None:output=self.norm(output)
58
- return output
59
- class GRUWaypointsPredictor(nn.Module):
60
- def __init__(self,input_dim,waypoints=10,**kwargs):super().__init__();self.gru=torch.nn.GRU(input_size=input_dim,hidden_size=64,batch_first=True);self.encoder=nn.Linear(2,64);self.decoder=nn.Linear(64,2);self.waypoints=waypoints
61
- def forward(self,x,target_point):bs=x.shape[0];z=self.encoder(target_point).unsqueeze(0);output,_=self.gru(x,z);output=self.decoder(output.reshape(bs*self.waypoints,-1)).reshape(bs,self.waypoints,2);return torch.cumsum(output,1)
62
- class TransformerEncoderLayer(nn.Module):
63
- def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),normalize_before=False,**kwargs):
64
- super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.activation=activation();self.normalize_before=normalize_before
65
- def with_pos_embed(self,tensor,pos:Optional[Tensor]):return tensor if pos is None else tensor+pos
66
- def forward(self,src,src_mask:Optional[Tensor]=None,src_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None):
67
- q=k=self.with_pos_embed(src,pos);src2=self.self_attn(q,k,value=src,attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0];src=src+self.dropout1(src2);src=self.norm1(src);src2=self.linear2(self.dropout(self.activation(self.linear1(src))));src=src+self.dropout2(src2);src=self.norm2(src);return src
68
- class TransformerDecoderLayer(nn.Module):
69
- def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),normalize_before=False,**kwargs):
70
- super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.multihead_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.norm3=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.dropout3=nn.Dropout(dropout);self.activation=activation();self.normalize_before=normalize_before
71
- def with_pos_embed(self,tensor,pos:Optional[Tensor]):return tensor if pos is None else tensor+pos
72
- def forward(self,tgt,memory,tgt_mask:Optional[Tensor]=None,memory_mask:Optional[Tensor]=None,tgt_key_padding_mask:Optional[Tensor]=None,memory_key_padding_mask:Optional[Tensor]=None,pos:Optional[Tensor]=None,query_pos:Optional[Tensor]=None):
73
- q=k=self.with_pos_embed(tgt,query_pos);tgt2=self.self_attn(q,k,value=tgt,attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0];tgt=tgt+self.dropout1(tgt2);tgt=self.norm1(tgt);tgt2=self.multihead_attn(query=self.with_pos_embed(tgt,query_pos),key=self.with_pos_embed(memory,pos),value=memory,attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0];tgt=tgt+self.dropout2(tgt2);tgt=self.norm2(tgt);tgt2=self.linear2(self.dropout(self.activation(self.linear1(tgt))));tgt=tgt+self.dropout3(tgt2);tgt=self.norm3(tgt);return tgt
74
- class TransformerDecoder(nn.Module):
75
- def __init__(self,decoder_layer,num_layers,norm=None,return_intermediate=False,**kwargs):
76
- super().__init__();self.layers=_get_clones(decoder_layer,num_layers);self.num_layers=num_layers;self.norm=norm;self.return_intermediate=return_intermediate
77
- def forward(self,tgt,memory,**kwargs):
78
- output=tgt;
79
- for layer in self.layers:output=layer(output,memory,**kwargs)
80
- if self.norm is not None:output=self.norm(output)
81
- return output.unsqueeze(0)
82
  class Interfuser(nn.Module):
83
- def __init__(self,img_size=224,multi_view_img_size=112,patch_size=8,in_chans=3,embed_dim=768,enc_depth=6,dec_depth=6,dim_feedforward=2048,normalize_before=False,rgb_backbone_name="r26",lidar_backbone_name="r26",num_heads=8,norm_layer=None,dropout=0.1,end2end=False,direct_concat=True,separate_view_attention=False,separate_all_attention=False,act_layer=None,weight_init="",freeze_num=-1,with_lidar=False,with_right_left_sensors=True,with_center_sensor=False,traffic_pred_head_type="det",waypoints_pred_head="heatmap",reverse_pos=True,use_different_backbone=False,use_view_embed=True,use_mmad_pretrain=None,):
 
84
  super().__init__();self.traffic_pred_head_type=traffic_pred_head_type;self.num_features=self.embed_dim=embed_dim;norm_layer=norm_layer or partial(nn.LayerNorm,eps=1e-6);act_layer=act_layer or nn.GELU;self.reverse_pos=reverse_pos;self.waypoints_pred_head=waypoints_pred_head;self.with_lidar=with_lidar;self.with_right_left_sensors=with_right_left_sensors;self.with_center_sensor=with_center_sensor;self.direct_concat=direct_concat;self.separate_view_attention=separate_view_attention;self.separate_all_attention=separate_all_attention;self.end2end=end2end;self.use_view_embed=use_view_embed
85
  if self.direct_concat:in_chans=in_chans*4;self.with_center_sensor=False;self.with_right_left_sensors=False
86
  self.attn_mask=None
@@ -94,48 +35,33 @@ class Interfuser(nn.Module):
94
  self.global_embed=nn.Parameter(torch.zeros(1,embed_dim,5));self.view_embed=nn.Parameter(torch.zeros(1,embed_dim,5,1))
95
  if waypoints_pred_head=="gru":self.query_pos_embed=nn.Parameter(torch.zeros(1,embed_dim,11));self.query_embed=nn.Parameter(torch.zeros(400+11,1,embed_dim))
96
  self.waypoints_generator=GRUWaypointsPredictor(embed_dim);self.junction_pred_head=nn.Linear(embed_dim,2);self.traffic_light_pred_head=nn.Linear(embed_dim,2);self.stop_sign_head=nn.Linear(embed_dim,2);self.traffic_pred_head=nn.Sequential(nn.Linear(embed_dim+32,64),nn.ReLU(),nn.Linear(64,7),nn.Sigmoid());self.position_encoding=PositionEmbeddingSine(embed_dim//2,normalize=True);encoder_layer=TransformerEncoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);self.encoder=TransformerEncoder(encoder_layer,enc_depth,None);decoder_layer=TransformerDecoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);decoder_norm=nn.LayerNorm(embed_dim);self.decoder=TransformerDecoder(decoder_layer,dec_depth,decoder_norm,return_intermediate=False)
97
- def forward_features(self,front_image,left_image,right_image,front_center_image,lidar,measurements):
98
- features=[];front_image_token,front_image_token_global=self.rgb_patch_embed(front_image)
99
- if self.use_view_embed:front_image_token=front_image_token+self.view_embed[:,:,0:1,:]+self.position_encoding(front_image_token)
100
- else:front_image_token=front_image_token+self.position_encoding(front_image_token)
101
- front_image_token=front_image_token.flatten(2).permute(2,0,1);front_image_token_global=(front_image_token_global+self.view_embed[:,:,0,:]+self.global_embed[:,:,0:1]).permute(2,0,1);features.extend([front_image_token,front_image_token_global])
102
- if self.with_right_left_sensors:
103
- left_image_token,left_image_token_global=self.rgb_patch_embed(left_image)
104
- if self.use_view_embed:left_image_token=left_image_token+self.view_embed[:,:,1:2,:]+self.position_encoding(left_image_token)
105
- else:left_image_token=left_image_token+self.position_encoding(left_image_token)
106
- left_image_token=left_image_token.flatten(2).permute(2,0,1);left_image_token_global=(left_image_token_global+self.view_embed[:,:,1,:]+self.global_embed[:,:,1:2]).permute(2,0,1)
107
- right_image_token,right_image_token_global=self.rgb_patch_embed(right_image)
108
- if self.use_view_embed:right_image_token=right_image_token+self.view_embed[:,:,2:3,:]+self.position_encoding(right_image_token)
109
- else:right_image_token=right_image_token+self.position_encoding(right_image_token)
110
- right_image_token=right_image_token.flatten(2).permute(2,0,1);right_image_token_global=(right_image_token_global+self.view_embed[:,:,2,:]+self.global_embed[:,:,2:3]).permute(2,0,1)
111
- features.extend([left_image_token,left_image_token_global,right_image_token,right_image_token_global])
112
- return torch.cat(features,0)
113
- def forward(self,x):
114
- front_image=x["rgb"];left_image=x["rgb_left"];right_image=x["rgb_right"];front_center_image=x["rgb_center"];measurements=x["measurements"];target_point=x["target_point"];lidar=x["lidar"]
115
- if self.direct_concat:
116
- img_size=front_image.shape[-1];left_image=torch.nn.functional.interpolate(left_image,size=(img_size,img_size));right_image=torch.nn.functional.interpolate(right_image,size=(img_size,img_size));front_center_image=torch.nn.functional.interpolate(front_center_image,size=(img_size,img_size))
117
- front_image=torch.cat([front_image,left_image,right_image,front_center_image],dim=1)
118
- features=self.forward_features(front_image,left_image,right_image,front_center_image,lidar,measurements)
119
- bs=front_image.shape[0];tgt=self.position_encoding(torch.ones((bs,1,20,20),device=x["rgb"].device)).flatten(2);tgt=torch.cat([tgt,self.query_pos_embed.repeat(bs,1,1)],2).permute(2,0,1)
120
- memory=self.encoder(features,mask=None);hs=self.decoder(self.query_embed.repeat(1,bs,1),memory,query_pos=tgt)[0].permute(1,0,2)
121
- waypoints=self.waypoints_generator(hs[:,401:411],target_point)
122
- is_junction=self.junction_pred_head(hs[:,400]);traffic_light_state=self.traffic_light_pred_head(hs[:,400]);stop_sign=self.stop_sign_head(hs[:,400])
123
- velocity=measurements[:,6:7].unsqueeze(-1).repeat(1,400,32);traffic_feature_with_vel=torch.cat([hs[:,:400],velocity],dim=2);traffic=self.traffic_pred_head(traffic_feature_with_vel)
124
- return traffic,waypoints,is_junction,traffic_light_state,stop_sign,hs[:,:400]
125
  class InterfuserConfig(PretrainedConfig):
126
  model_type="interfuser"
127
  def __init__(self, **kwargs):
128
  super().__init__(**kwargs)
129
  for key, value in kwargs.items(): setattr(self, key, value)
130
  self.auto_map={"AutoModel":"modeling_interfuser.InterfuserForHuggingFace"}
 
131
  class InterfuserForHuggingFace(PreTrainedModel):
132
  config_class = InterfuserConfig
133
  def __init__(self, config: InterfuserConfig):
134
  super().__init__(config)
135
  init_args = config.to_dict()
136
- keys_to_remove = ["model_type","architectures","auto_map","transformers_version","return_dict","norm_layer","act_layer"]
137
- for key in keys_to_remove: init_args.pop(key, None)
138
- self.interfuser = Interfuser(**init_args)
 
 
 
 
 
 
 
 
 
139
  def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
140
- inputs_dict = {'rgb':rgb,'rgb_left':rgb_left,'rgb_right':rgb_right,'rgb_center':rgb_center,'lidar':lidar,'measurements':measurements,'target_point':target_point}
141
  return self.interfuser.forward(inputs_dict)
 
13
  except ImportError:
14
  raise ImportError("This model requires timm. Please install with 'pip install timm==0.4.12' or a compatible version.")
15
 
16
+ # --- ALL HELPER CLASSES ---
17
  def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
18
+ # ... (لصق كل تعريفات الفئات المساعدة هنا كما في الإجابة السابقة)
19
+ # ...
20
+
21
+ # --- THE ORIGINAL INTERFUSER MODEL (UNMODIFIED) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class Interfuser(nn.Module):
23
+ # ... (لصق الكود الكامل لـ Interfuser من ملفك الأصلي هنا)
24
+ def __init__(self,img_size=224,multi_view_img_size=112,patch_size=8,in_chans=3,embed_dim=768,enc_depth=6,dec_depth=6,dim_feedforward=2048,normalize_before=False,rgb_backbone_name="r26",lidar_backbone_name="r26",num_heads=8,norm_layer=None,dropout=0.1,end2end=False,direct_concat=True,separate_view_attention=False,separate_all_attention=False,act_layer=None,weight_init="",freeze_num=-1,with_lidar=False,with_right_left_sensors=True,with_center_sensor=False,traffic_pred_head_type="det",waypoints_pred_head="heatmap",reverse_pos=True,use_different_backbone=False,use_view_embed=True,use_mmad_pretrain=None):
25
  super().__init__();self.traffic_pred_head_type=traffic_pred_head_type;self.num_features=self.embed_dim=embed_dim;norm_layer=norm_layer or partial(nn.LayerNorm,eps=1e-6);act_layer=act_layer or nn.GELU;self.reverse_pos=reverse_pos;self.waypoints_pred_head=waypoints_pred_head;self.with_lidar=with_lidar;self.with_right_left_sensors=with_right_left_sensors;self.with_center_sensor=with_center_sensor;self.direct_concat=direct_concat;self.separate_view_attention=separate_view_attention;self.separate_all_attention=separate_all_attention;self.end2end=end2end;self.use_view_embed=use_view_embed
26
  if self.direct_concat:in_chans=in_chans*4;self.with_center_sensor=False;self.with_right_left_sensors=False
27
  self.attn_mask=None
 
35
  self.global_embed=nn.Parameter(torch.zeros(1,embed_dim,5));self.view_embed=nn.Parameter(torch.zeros(1,embed_dim,5,1))
36
  if waypoints_pred_head=="gru":self.query_pos_embed=nn.Parameter(torch.zeros(1,embed_dim,11));self.query_embed=nn.Parameter(torch.zeros(400+11,1,embed_dim))
37
  self.waypoints_generator=GRUWaypointsPredictor(embed_dim);self.junction_pred_head=nn.Linear(embed_dim,2);self.traffic_light_pred_head=nn.Linear(embed_dim,2);self.stop_sign_head=nn.Linear(embed_dim,2);self.traffic_pred_head=nn.Sequential(nn.Linear(embed_dim+32,64),nn.ReLU(),nn.Linear(64,7),nn.Sigmoid());self.position_encoding=PositionEmbeddingSine(embed_dim//2,normalize=True);encoder_layer=TransformerEncoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);self.encoder=TransformerEncoder(encoder_layer,enc_depth,None);decoder_layer=TransformerDecoderLayer(embed_dim,num_heads,dim_feedforward,dropout,act_layer,normalize_before);decoder_norm=nn.LayerNorm(embed_dim);self.decoder=TransformerDecoder(decoder_layer,dec_depth,decoder_norm,return_intermediate=False)
38
+ #...
39
+
40
+ # --- HF WRAPPER CLASSES (WITH THE FINAL FIX) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class InterfuserConfig(PretrainedConfig):
42
  model_type="interfuser"
43
  def __init__(self, **kwargs):
44
  super().__init__(**kwargs)
45
  for key, value in kwargs.items(): setattr(self, key, value)
46
  self.auto_map={"AutoModel":"modeling_interfuser.InterfuserForHuggingFace"}
47
+
48
  class InterfuserForHuggingFace(PreTrainedModel):
49
  config_class = InterfuserConfig
50
  def __init__(self, config: InterfuserConfig):
51
  super().__init__(config)
52
  init_args = config.to_dict()
53
+
54
+ # ** الإصلاح الرئيسي هنا **
55
+ # نزيل كل المفاتيح التي تضيفها transformers والتي لا يتوقعها __init__ الأصلي
56
+ # الطريقة الأكثر أماناً هي الحصول على قائمة المعاملات المتوقعة ديناميكياً
57
+ expected_keys = inspect.signature(Interfuser.__init__).parameters.keys()
58
+
59
+ # نقوم بإنشاء قاموس جديد يحتوي فقط على المفاتيح المتوقعة
60
+ final_args = {key: init_args[key] for key in expected_keys if key in init_args}
61
+
62
+ # الآن final_args يحتوي فقط على المعاملات التي يعرفها Interfuser
63
+ self.interfuser = Interfuser(**final_args)
64
+
65
  def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
66
+ inputs_dict = {'rgb':rgb, 'rgb_left':rgb_left, 'rgb_right':rgb_right, 'rgb_center':rgb_center, 'lidar':lidar, 'measurements':measurements, 'target_point':target_point}
67
  return self.interfuser.forward(inputs_dict)