Final fix v4: Corrected model logic to properly use config
Browse files- modeling_interfuser.py +72 -35
modeling_interfuser.py
CHANGED
|
@@ -10,54 +10,91 @@ from transformers import PreTrainedModel, PretrainedConfig
|
|
| 10 |
try:
|
| 11 |
from timm.models.layers import to_2tuple
|
| 12 |
from timm.models.resnet import resnet50d, resnet26d, resnet18d
|
| 13 |
-
except ImportError:
|
| 14 |
-
raise ImportError("This model requires timm. Please install with 'pip install timm==0.3.2'")
|
| 15 |
|
| 16 |
def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 17 |
class HybridEmbed(nn.Module):
|
| 18 |
def __init__(self, backbone, img_size=224, in_chans=3, embed_dim=768, **kwargs):
|
| 19 |
-
super().__init__();
|
| 20 |
with torch.no_grad():
|
| 21 |
-
dummy_input=torch.zeros(1,in_chans,img_size,img_size);
|
| 22 |
self.proj=nn.Conv2d(feature_dim,embed_dim,kernel_size=1)
|
| 23 |
-
def forward(self,
|
| 24 |
-
features=self.backbone(x);
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
# --- THE CORRECTED Interfuser CLASS ---
|
| 29 |
class Interfuser(nn.Module):
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
# --- يستخدم القيم من config حصرياً ---
|
| 35 |
-
self.embed_dim = config.embed_dim
|
| 36 |
-
self.waypoints_pred_head = config.waypoints_pred_head
|
| 37 |
if config.use_different_backbone:
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
in_chans_effective = config.in_chans * 4 if config.direct_concat else config.in_chans
|
| 41 |
-
|
| 42 |
-
# نستخدم in_chans_effective هنا
|
| 43 |
-
if config.rgb_backbone_name=="r50": self.rgb_backbone = resnet50d(pretrained=False, in_chans=in_chans_effective, features_only=True, out_indices=[4])
|
| 44 |
-
else: self.rgb_backbone=resnet18d(pretrained=False, in_chans=in_chans_effective, features_only=True, out_indices=[4])
|
| 45 |
self.rgb_patch_embed=partial(HybridEmbed,backbone=self.rgb_backbone)(in_chans=in_chans_effective,embed_dim=config.embed_dim)
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
def forward(self, x):
|
| 49 |
-
# ... (تعريف forward كما هو)
|
| 50 |
-
|
| 51 |
-
# --- HF WRAPPER CLASSES ---
|
| 52 |
class InterfuserConfig(PretrainedConfig):
|
| 53 |
-
model_type
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
class InterfuserForHuggingFace(PreTrainedModel):
|
| 57 |
-
config_class
|
| 58 |
-
def __init__(self,
|
| 59 |
-
super().__init__(config)
|
| 60 |
-
|
| 61 |
-
def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
|
| 62 |
inputs_dict={'rgb':rgb,'rgb_left':rgb_left,'rgb_right':rgb_right,'rgb_center':rgb_center,'lidar':lidar,'measurements':measurements,'target_point':target_point}
|
| 63 |
return self.interfuser_model.forward(inputs_dict)
|
|
|
|
| 10 |
try:
|
| 11 |
from timm.models.layers import to_2tuple
|
| 12 |
from timm.models.resnet import resnet50d, resnet26d, resnet18d
|
| 13 |
+
except ImportError: raise ImportError("This model requires timm. Please install with 'pip install timm==0.3.2'")
|
|
|
|
| 14 |
|
| 15 |
def _get_clones(module, N): return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
|
| 16 |
class HybridEmbed(nn.Module):
|
| 17 |
def __init__(self, backbone, img_size=224, in_chans=3, embed_dim=768, **kwargs):
|
| 18 |
+
super().__init__();self.backbone=backbone
|
| 19 |
with torch.no_grad():
|
| 20 |
+
dummy_input=torch.zeros(1,in_chans,img_size,img_size);features=self.backbone(dummy_input);feature_dim=features[-1].shape[1]
|
| 21 |
self.proj=nn.Conv2d(feature_dim,embed_dim,kernel_size=1)
|
| 22 |
+
def forward(self,x):
|
| 23 |
+
features=self.backbone(x);x_feat=features[-1];x=self.proj(x_feat);global_x=torch.mean(x,[2,3],keepdim=False)[:,:,None];return x,global_x
|
| 24 |
+
class PositionEmbeddingSine(nn.Module):
|
| 25 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, **kwargs):
|
| 26 |
+
super().__init__();self.num_pos_feats=num_pos_feats;self.temperature=temperature;self.normalize=normalize;self.scale=2*math.pi
|
| 27 |
+
def forward(self, tensor):
|
| 28 |
+
x=tensor;bs,_,h,w=x.shape;not_mask=torch.ones((bs,h,w),device=x.device);y_embed=not_mask.cumsum(1,dtype=torch.float32);x_embed=not_mask.cumsum(2,dtype=torch.float32)
|
| 29 |
+
if self.normalize:eps=1e-6;y_embed=y_embed/(y_embed[:,-1:,:]+eps)*self.scale;x_embed=x_embed/(x_embed[:,:,-1:]+eps)*self.scale
|
| 30 |
+
dim_t=torch.arange(self.num_pos_feats,dtype=torch.float32,device=x.device);dim_t=self.temperature**(2*(dim_t//2)/self.num_pos_feats);pos_x=x_embed[:,:,:,None]/dim_t;pos_y=y_embed[:,:,:,None]/dim_t
|
| 31 |
+
pos_x=torch.stack((pos_x[:,:,:,0::2].sin(),pos_x[:,:,:,1::2].cos()),dim=4).flatten(3);pos_y=torch.stack((pos_y[:,:,:,0::2].sin(),pos_y[:,:,:,1::2].cos()),dim=4).flatten(3);return torch.cat((pos_y,pos_x),dim=3).permute(0,3,1,2)
|
| 32 |
+
class TransformerEncoderLayer(nn.Module):
|
| 33 |
+
def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),**kwargs):
|
| 34 |
+
super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.activation=activation()
|
| 35 |
+
def with_pos_embed(self,tensor,pos):return tensor if pos is None else tensor+pos
|
| 36 |
+
def forward(self,src,src_mask=None,src_key_padding_mask=None,pos=None):
|
| 37 |
+
q=k=self.with_pos_embed(src,pos);src2=self.self_attn(q,k,value=src,attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0];src=src+self.dropout1(src2);src=self.norm1(src);src2=self.linear2(self.dropout(self.activation(self.linear1(src))));src=src+self.dropout2(src2);src=self.norm2(src);return src
|
| 38 |
+
class TransformerEncoder(nn.Module):
|
| 39 |
+
def __init__(self,encoder_layer,num_layers,norm=None):super().__init__();self.layers=_get_clones(encoder_layer,num_layers);self.norm=norm
|
| 40 |
+
def forward(self,src,mask=None,src_key_padding_mask=None,pos=None):
|
| 41 |
+
output=src;[output:=layer(output,src_mask=mask,src_key_padding_mask=src_key_padding_mask,pos=pos) for layer in self.layers];
|
| 42 |
+
if self.norm is not None:output=self.norm(output)
|
| 43 |
+
return output
|
| 44 |
+
class TransformerDecoderLayer(nn.Module):
|
| 45 |
+
def __init__(self,d_model,nhead,dim_feedforward=2048,dropout=0.1,activation=nn.ReLU(),**kwargs):
|
| 46 |
+
super().__init__();self.self_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.multihead_attn=nn.MultiheadAttention(d_model,nhead,dropout=dropout);self.linear1=nn.Linear(d_model,dim_feedforward);self.dropout=nn.Dropout(dropout);self.linear2=nn.Linear(dim_feedforward,d_model);self.norm1=nn.LayerNorm(d_model);self.norm2=nn.LayerNorm(d_model);self.norm3=nn.LayerNorm(d_model);self.dropout1=nn.Dropout(dropout);self.dropout2=nn.Dropout(dropout);self.dropout3=nn.Dropout(dropout);self.activation=activation()
|
| 47 |
+
def with_pos_embed(self,tensor,pos):return tensor if pos is None else tensor+pos
|
| 48 |
+
def forward(self,tgt,memory,tgt_mask=None,memory_mask=None,tgt_key_padding_mask=None,memory_key_padding_mask=None,pos=None,query_pos=None):
|
| 49 |
+
q=k=self.with_pos_embed(tgt,query_pos);tgt2=self.self_attn(q,k,value=tgt,attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0];tgt=tgt+self.dropout1(tgt2);tgt=self.norm1(tgt);tgt2=self.multihead_attn(query=self.with_pos_embed(tgt,query_pos),key=self.with_pos_embed(memory,pos),value=memory,attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0];tgt=tgt+self.dropout2(tgt2);tgt=self.norm2(tgt);tgt2=self.linear2(self.dropout(self.activation(self.linear1(tgt))));tgt=tgt+self.dropout3(tgt2);tgt=self.norm3(tgt);return tgt
|
| 50 |
+
class TransformerDecoder(nn.Module):
|
| 51 |
+
def __init__(self,decoder_layer,num_layers,norm=None,return_intermediate=False):super().__init__();self.layers=_get_clones(decoder_layer,num_layers);self.norm=norm;self.return_intermediate=return_intermediate
|
| 52 |
+
def forward(self,tgt,memory,**kwargs):output=tgt;[output:=layer(output,memory,**kwargs) for layer in self.layers];return self.norm(output).unsqueeze(0) if self.norm is not None else output.unsqueeze(0)
|
| 53 |
+
class GRUWaypointsPredictor(nn.Module):
|
| 54 |
+
def __init__(self,input_dim,waypoints=10,**kwargs):super().__init__();self.gru=torch.nn.GRU(input_size=input_dim,hidden_size=64,batch_first=True);self.encoder=nn.Linear(2,64);self.decoder=nn.Linear(64,2);self.waypoints=waypoints
|
| 55 |
+
def forward(self,x,target_point):bs=x.shape[0];z=self.encoder(target_point).unsqueeze(0);output,_=self.gru(x,z);output=self.decoder(output.reshape(bs*self.waypoints,-1)).reshape(bs,self.waypoints,2);return torch.cumsum(output,1)
|
| 56 |
|
|
|
|
| 57 |
class Interfuser(nn.Module):
|
| 58 |
+
def __init__(self, config):
|
| 59 |
+
super().__init__();self.config=config;self.embed_dim=config.embed_dim;self.waypoints_pred_head=config.waypoints_pred_head
|
| 60 |
+
# ** THE LOGICAL FIX IS HERE **
|
| 61 |
+
in_chans_effective=config.in_chans*4 if config.direct_concat else config.in_chans
|
|
|
|
|
|
|
|
|
|
| 62 |
if config.use_different_backbone:
|
| 63 |
+
if config.rgb_backbone_name=="r50":self.rgb_backbone=resnet50d(pretrained=False,in_chans=in_chans_effective,features_only=True,out_indices=[4])
|
| 64 |
+
else:self.rgb_backbone=resnet18d(pretrained=False,in_chans=in_chans_effective,features_only=True,out_indices=[4])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
self.rgb_patch_embed=partial(HybridEmbed,backbone=self.rgb_backbone)(in_chans=in_chans_effective,embed_dim=config.embed_dim)
|
| 66 |
+
self.global_embed=nn.Parameter(torch.zeros(1,config.embed_dim,5));self.view_embed=nn.Parameter(torch.zeros(1,config.embed_dim,5,1));self.query_pos_embed=nn.Parameter(torch.zeros(1,config.embed_dim,11));self.query_embed=nn.Parameter(torch.zeros(400+11,1,config.embed_dim))
|
| 67 |
+
self.waypoints_generator=GRUWaypointsPredictor(config.embed_dim);self.junction_pred_head=nn.Linear(config.embed_dim,2);self.traffic_light_pred_head=nn.Linear(config.embed_dim,2);self.stop_sign_head=nn.Linear(config.embed_dim,2)
|
| 68 |
+
self.traffic_pred_head=nn.Sequential(nn.Linear(config.embed_dim+32,64),nn.ReLU(),nn.Linear(64,7),nn.Sigmoid());self.position_encoding=PositionEmbeddingSine(config.embed_dim//2)
|
| 69 |
+
encoder_layer=TransformerEncoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.encoder=TransformerEncoder(encoder_layer,config.enc_depth);decoder_layer=TransformerDecoderLayer(config.embed_dim,config.num_heads,config.dim_feedforward);self.decoder=TransformerDecoder(decoder_layer,config.dec_depth,nn.LayerNorm(config.embed_dim))
|
| 70 |
+
def forward_features(self,input_image,**kwargs):
|
| 71 |
+
features,_=self.rgb_patch_embed(input_image);features=(features+self.position_encoding(features)).flatten(2).permute(2,0,1)
|
| 72 |
+
return features
|
| 73 |
+
def forward(self,x):
|
| 74 |
+
front=x["rgb"];left=x["rgb_left"];right=x["rgb_right"];center=x["rgb_center"];measurements=x["measurements"];target_point=x["target_point"]
|
| 75 |
+
if self.config.direct_concat:
|
| 76 |
+
s=front.shape[-1];left=torch.nn.functional.interpolate(left,size=(s,s));right=torch.nn.functional.interpolate(right,size=(s,s));center=torch.nn.functional.interpolate(center,size=(s,s))
|
| 77 |
+
input_image=torch.cat([front,left,right,center],dim=1)
|
| 78 |
+
else:input_image=front
|
| 79 |
+
features=self.forward_features(input_image)
|
| 80 |
+
bs=front.shape[0];tgt=self.position_encoding(torch.ones((bs,1,20,20),device=x["rgb"].device)).flatten(2);tgt=torch.cat([tgt,self.query_pos_embed.repeat(bs,1,1)],2).permute(2,0,1)
|
| 81 |
+
memory=self.encoder(features);hs=self.decoder(self.query_embed.repeat(1,bs,1),memory,query_pos=tgt)[0].permute(1,0,2)
|
| 82 |
+
waypoints=self.waypoints_generator(hs[:,401:411],target_point)
|
| 83 |
+
is_junc=self.junction_pred_head(hs[:,400]);light=self.traffic_light_pred_head(hs[:,400]);stop=self.stop_sign_head(hs[:,400])
|
| 84 |
+
vel=measurements[:,6:7].unsqueeze(-1).repeat(1,400,32);traffic_feat=torch.cat([hs[:,:400],vel],dim=2);traffic=self.traffic_pred_head(traffic_feat)
|
| 85 |
+
return traffic,waypoints,is_junc,light,stop,hs[:,:400]
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
class InterfuserConfig(PretrainedConfig):
|
| 88 |
+
model_type="interfuser"
|
| 89 |
+
def __init__(self,in_chans=3,embed_dim=256,enc_depth=6,dec_depth=6,num_heads=8,dim_feedforward=2048,rgb_backbone_name="r50",lidar_backbone_name="r18",use_different_backbone=True,waypoints_pred_head="gru",direct_concat=True,**kwargs):
|
| 90 |
+
super().__init__(**kwargs);self.in_chans=in_chans;self.embed_dim=embed_dim;self.enc_depth=enc_depth;self.dec_depth=dec_depth;self.num_heads=num_heads;self.dim_feedforward=dim_feedforward
|
| 91 |
+
self.rgb_backbone_name=rgb_backbone_name;self.lidar_backbone_name=lidar_backbone_name;self.use_different_backbone=use_different_backbone;self.waypoints_pred_head=waypoints_pred_head;self.direct_concat=direct_concat
|
| 92 |
+
self.auto_map={"AutoModel":"modeling_interfuser.InterfuserForHuggingFace"}
|
| 93 |
|
| 94 |
class InterfuserForHuggingFace(PreTrainedModel):
|
| 95 |
+
config_class=InterfuserConfig
|
| 96 |
+
def __init__(self,config:InterfuserConfig):
|
| 97 |
+
super().__init__(config);self.config=config;self.interfuser_model=Interfuser(config)
|
| 98 |
+
def forward(self,rgb,rgb_left,rgb_right,rgb_center,lidar,measurements,target_point,**kwargs):
|
|
|
|
| 99 |
inputs_dict={'rgb':rgb,'rgb_left':rgb_left,'rgb_right':rgb_right,'rgb_center':rgb_center,'lidar':lidar,'measurements':measurements,'target_point':target_point}
|
| 100 |
return self.interfuser_model.forward(inputs_dict)
|