---
datasets:
- garythung/trashnet
pipeline_tag: image-classification
---
To load this model from its saved state dict, use the following steps:
import torch
import torch.nn as nn
from torchvision.models import resnet50

# Define the model: a ResNet-50 whose final fully connected layer is replaced
# by a 6-way linear classifier, matching the architecture the checkpoint was
# saved from.
# weights=None: every parameter is overwritten by load_state_dict below
# (strict=True by default), so downloading ImageNet weights here is wasted work.
model = resnet50(weights=None)
# Freeze the backbone. The new `fc` layer is created after this loop, so its
# parameters keep requires_grad=True. This is not required for inference.
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 6)

# Load the weights. map_location="cpu" lets a checkpoint saved on GPU load on
# CPU-only machines; call model.to(device) afterwards to move it elsewhere.
# NOTE: torch.load unpickles the file. Only load checkpoints you trust, or pass
# weights_only=True on PyTorch versions that support it.
state_dict = torch.load("trashnet_resnet50.pth", map_location="cpu")
model.load_state_dict(state_dict)

# Switch to evaluation mode (puts layers such as batch norm into inference behavior).
model.eval()