NoelShin
commited on
Commit
·
6d6f3c6
1
Parent(s):
bcc8459
fix state dict loading error
Browse files- .idea/workspace.xml +1 -1
- utils.py +4 -3
.idea/workspace.xml
CHANGED
|
@@ -48,7 +48,7 @@
|
|
| 48 |
<option name="presentableId" value="Default" />
|
| 49 |
<updated>1664204268713</updated>
|
| 50 |
<workItem from="1664204270261" duration="37000" />
|
| 51 |
-
<workItem from="1664204316867" duration="
|
| 52 |
</task>
|
| 53 |
<servers />
|
| 54 |
</component>
|
|
|
|
| 48 |
<option name="presentableId" value="Default" />
|
| 49 |
<updated>1664204268713</updated>
|
| 50 |
<workItem from="1664204270261" duration="37000" />
|
| 51 |
+
<workItem from="1664204316867" duration="5840000" />
|
| 52 |
</task>
|
| 53 |
<servers />
|
| 54 |
</component>
|
utils.py
CHANGED
|
@@ -8,13 +8,14 @@ from networks import convert_to_separable_conv, set_bn_momentum
|
|
| 8 |
def get_network() -> torch.nn.Module:
|
| 9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 12 |
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
|
| 13 |
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
)
|
| 15 |
-
network.
|
| 16 |
-
convert_to_separable_conv(network.classifier)
|
| 17 |
-
set_bn_momentum(network.backbone, momentum=0.01)
|
| 18 |
return network
|
| 19 |
|
| 20 |
|
|
|
|
| 8 |
def get_network() -> torch.nn.Module:
|
| 9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
| 10 |
|
| 11 |
+
convert_to_separable_conv(network.classifier)
|
| 12 |
+
set_bn_momentum(network.backbone, momentum=0.01)
|
| 13 |
+
|
| 14 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 15 |
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
|
| 16 |
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 17 |
)
|
| 18 |
+
network.load_state_dict(state_dict, strict=True)
|
|
|
|
|
|
|
| 19 |
return network
|
| 20 |
|
| 21 |
|