Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
·
f7bf9fb
1
Parent(s):
5b7158a
fix deeplab device
Browse files- criteria/deeplab.py +2 -2
- criteria/mask.py +1 -0
criteria/deeplab.py
CHANGED
|
@@ -309,7 +309,7 @@ def resnet50(pretrained=False, **kwargs):
|
|
| 309 |
return model
|
| 310 |
|
| 311 |
|
| 312 |
-
def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **kwargs):
|
| 313 |
"""Constructs a ResNet-101 model.
|
| 314 |
|
| 315 |
Args:
|
|
@@ -326,7 +326,7 @@ def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **
|
|
| 326 |
model_dict = model.state_dict()
|
| 327 |
if num_groups and weight_std:
|
| 328 |
path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
|
| 329 |
-
pretrained_dict = torch.load(path)
|
| 330 |
overlap_dict = {
|
| 331 |
k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
|
| 332 |
}
|
|
|
|
| 309 |
return model
|
| 310 |
|
| 311 |
|
| 312 |
+
def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, device="cpu", **kwargs):
|
| 313 |
"""Constructs a ResNet-101 model.
|
| 314 |
|
| 315 |
Args:
|
|
|
|
| 326 |
model_dict = model.state_dict()
|
| 327 |
if num_groups and weight_std:
|
| 328 |
path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
|
| 329 |
+
pretrained_dict = torch.load(path, map_location=device)
|
| 330 |
overlap_dict = {
|
| 331 |
k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
|
| 332 |
}
|
criteria/mask.py
CHANGED
|
@@ -36,6 +36,7 @@ class Mask(nn.Module):
|
|
| 36 |
num_groups=32,
|
| 37 |
weight_std=True,
|
| 38 |
beta=False,
|
|
|
|
| 39 |
)
|
| 40 |
.eval()
|
| 41 |
.requires_grad_(False)
|
|
|
|
| 36 |
num_groups=32,
|
| 37 |
weight_std=True,
|
| 38 |
beta=False,
|
| 39 |
+
device=device,
|
| 40 |
)
|
| 41 |
.eval()
|
| 42 |
.requires_grad_(False)
|