Upload pim_module.py
Browse files- pim_module.py +7 -1
pim_module.py
CHANGED
|
@@ -231,7 +231,13 @@ class FPN(nn.Module):
|
|
| 231 |
)
|
| 232 |
elif proj_type == "Linear":
|
| 233 |
m = nn.Sequential(
|
| 234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
nn.ReLU(),
|
| 236 |
nn.Linear(inputs[node_name].size(-1), fpn_size),
|
| 237 |
)
|
|
|
|
| 231 |
)
|
| 232 |
elif proj_type == "Linear":
|
| 233 |
m = nn.Sequential(
|
| 234 |
+
in_feat = inputs[node_name]
|
| 235 |
+
if isinstance(in_feat, torch.Tensor):
|
| 236 |
+
dim = in_feat.size(-1)
|
| 237 |
+
else:
|
| 238 |
+
raise ValueError(f"Entrée invalide dans FPN: {type(in_feat)} pour node_name={node_name}")
|
| 239 |
+
nn.Linear(dim, dim)
|
| 240 |
+
# nn.Linear(inputs[node_name].size(-1), inputs[node_name].size(-1)),
|
| 241 |
nn.ReLU(),
|
| 242 |
nn.Linear(inputs[node_name].size(-1), fpn_size),
|
| 243 |
)
|