Spaces:
Sleeping
Sleeping
Commit
·
c442067
1
Parent(s):
ed99578
fixed map location issue in models.py
Browse files
models.py
CHANGED
|
@@ -3,13 +3,13 @@ from torchvision.models import ViT_B_16_Weights, EfficientNet_B2_Weights
|
|
| 3 |
|
| 4 |
|
| 5 |
def get_vit_16_base_transformer():
|
| 6 |
-
vit_b_16_model = torch.load(
|
| 7 |
vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
|
| 8 |
|
| 9 |
return vit_b_16_model, vit_b_16_transforms
|
| 10 |
|
| 11 |
def get_effnet_b2():
|
| 12 |
-
eff_net_b2_model = torch.load(
|
| 13 |
eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
|
| 14 |
|
| 15 |
return eff_net_b2_model, eff_net_b2_transforms
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
def get_vit_16_base_transformer():
|
| 6 |
+
vit_b_16_model = torch.load("models/ViT_16_base_101_classes_pretrained_custom_head.pth", map_location = torch.device("cpu"))
|
| 7 |
vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
|
| 8 |
|
| 9 |
return vit_b_16_model, vit_b_16_transforms
|
| 10 |
|
| 11 |
def get_effnet_b2():
|
| 12 |
+
eff_net_b2_model = torch.load("models/effnet_b2_101_classes_pretrained_custom_head.pth", map_location = torch.device("cpu"))
|
| 13 |
eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
|
| 14 |
|
| 15 |
return eff_net_b2_model, eff_net_b2_transforms
|