pepitolechevalier commited on
Commit
39fde13
·
verified ·
1 Parent(s): 2441757

Upload pim_module.py

Browse files
Files changed (1) hide show
  1. pim_module.py +7 -1
pim_module.py CHANGED
@@ -231,7 +231,13 @@ class FPN(nn.Module):
231
  )
232
  elif proj_type == "Linear":
233
  m = nn.Sequential(
234
- nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
 
 
 
 
 
 
235
  nn.ReLU(),
236
  nn.Linear(inputs[node_name].size(-1), fpn_size),
237
  )
 
231
  )
232
  elif proj_type == "Linear":
233
  m = nn.Sequential(
234
+ in_feat = inputs[node_name]
235
+ if isinstance(in_feat, torch.Tensor):
236
+ dim = in_feat.size(-1)
237
+ else:
238
+ raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
239
+ nn.Linear(dim, dim)
240
+ # nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
241
  nn.ReLU(),
242
  nn.Linear(inputs[node_name].size(-1), fpn_size),
243
  )