NoelShin
commited on
Commit
·
bcc8459
1
Parent(s):
8513597
fix device error
Browse files- .idea/workspace.xml +2 -3
- utils.py +3 -1
.idea/workspace.xml
CHANGED
|
@@ -6,8 +6,7 @@
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
-
<change beforePath="$PROJECT_DIR$/
|
| 10 |
-
<change beforePath="$PROJECT_DIR$/networks/backbone/__init__.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/backbone/__init__.py" afterDir="false" />
|
| 11 |
</list>
|
| 12 |
<option name="SHOW_DIALOG" value="false" />
|
| 13 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
@@ -49,7 +48,7 @@
|
|
| 49 |
<option name="presentableId" value="Default" />
|
| 50 |
<updated>1664204268713</updated>
|
| 51 |
<workItem from="1664204270261" duration="37000" />
|
| 52 |
-
<workItem from="1664204316867" duration="
|
| 53 |
</task>
|
| 54 |
<servers />
|
| 55 |
</component>
|
|
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
+
<change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
|
|
|
|
| 10 |
</list>
|
| 11 |
<option name="SHOW_DIALOG" value="false" />
|
| 12 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
|
| 48 |
<option name="presentableId" value="Default" />
|
| 49 |
<updated>1664204268713</updated>
|
| 50 |
<workItem from="1664204270261" duration="37000" />
|
| 51 |
+
<workItem from="1664204316867" duration="5530000" />
|
| 52 |
</task>
|
| 53 |
<servers />
|
| 54 |
</component>
|
utils.py
CHANGED
|
@@ -7,8 +7,10 @@ from networks import convert_to_separable_conv, set_bn_momentum
|
|
| 7 |
|
| 8 |
def get_network() -> torch.nn.Module:
|
| 9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
|
|
|
| 10 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 11 |
-
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"
|
|
|
|
| 12 |
)
|
| 13 |
network.backbone.load_state_dict(state_dict, strict=True)
|
| 14 |
convert_to_separable_conv(network.classifier)
|
|
|
|
| 7 |
|
| 8 |
def get_network() -> torch.nn.Module:
|
| 9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
| 10 |
+
|
| 11 |
state_dict = torch.hub.load_state_dict_from_url(
|
| 12 |
+
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
|
| 13 |
+
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
)
|
| 15 |
network.backbone.load_state_dict(state_dict, strict=True)
|
| 16 |
convert_to_separable_conv(network.classifier)
|