Spaces:
Runtime error
Runtime error
| import torch | |
| from torchvision.models import vit_b_16, ViT_B_16_Weights | |
| from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights | |
def get_vit_16_base_transformer(
    model_path="models/vit_16_base_custom_head_3_classes.pth",
    device="cpu",
):
    """Load the fine-tuned ViT-B/16 model together with its preprocessing transforms.

    The object returned by ``torch.load`` is handed back directly as the
    model (no ``load_state_dict`` call), so the file is expected to hold a
    whole pickled model. Judging by the file name it is presumably a ViT-B/16
    with a custom 3-class head -- confirm against the training code.

    Args:
        model_path: Path of the saved model file. The default is resolved
            relative to the current working directory, as before.
        device: Device that tensors are mapped onto while loading
            (anything ``torch.device`` accepts). Defaults to ``"cpu"``.

    Returns:
        A ``(model, transforms)`` tuple, where ``transforms`` is the
        preprocessing pipeline of ``ViT_B_16_Weights.DEFAULT``.
    """
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which
    # refuses to unpickle full nn.Module objects and makes this load fail.
    # Opt out explicitly (the argument exists since PyTorch 1.13).
    # SECURITY: unpickling can execute arbitrary code -- only load trusted,
    # locally bundled checkpoint files here.
    vit_b_16_model = torch.load(
        model_path,
        map_location=torch.device(device),
        weights_only=False,
    )
    vit_b_16_transforms = ViT_B_16_Weights.DEFAULT.transforms()
    return vit_b_16_model, vit_b_16_transforms
def get_effnet_b2(
    model_path="models/eff_netb2_custom_head_3_classes.pth",
    device="cpu",
):
    """Load the fine-tuned EfficientNet-B2 model together with its preprocessing transforms.

    The object returned by ``torch.load`` is handed back directly as the
    model (no ``load_state_dict`` call), so the file is expected to hold a
    whole pickled model. Judging by the file name it is presumably an
    EfficientNet-B2 with a custom 3-class head -- confirm against the
    training code.

    Args:
        model_path: Path of the saved model file. The default is resolved
            relative to the current working directory, as before.
        device: Device that tensors are mapped onto while loading
            (anything ``torch.device`` accepts). Defaults to ``"cpu"``.

    Returns:
        A ``(model, transforms)`` tuple, where ``transforms`` is the
        preprocessing pipeline of ``EfficientNet_B2_Weights.DEFAULT``.
    """
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which
    # refuses to unpickle full nn.Module objects and makes this load fail.
    # Opt out explicitly (the argument exists since PyTorch 1.13).
    # SECURITY: unpickling can execute arbitrary code -- only load trusted,
    # locally bundled checkpoint files here.
    eff_net_b2_model = torch.load(
        model_path,
        map_location=torch.device(device),
        weights_only=False,
    )
    eff_net_b2_transforms = EfficientNet_B2_Weights.DEFAULT.transforms()
    return eff_net_b2_model, eff_net_b2_transforms