Update README.md
Browse filesFix the problem of missing definition of device
README.md
CHANGED
|
@@ -108,6 +108,7 @@ class EnsembleModel(nn.Module):
|
|
| 108 |
output = self.fc(torch.cat((convnext_out, mobilenet_out, efficientnet_out), dim=1))
|
| 109 |
return output
|
| 110 |
|
|
|
|
| 111 |
convnext_model = CustomConvNeXtModel(weights=ConvNeXt_Tiny_Weights.DEFAULT, num_classes=2)
|
| 112 |
mobilenet_model = CustomMobileNetModel(weights=MobileNet_V2_Weights.DEFAULT, num_classes=2)
|
| 113 |
efficientnet_model = CustomEfficientNetModel(weights=EfficientNet_B0_Weights.DEFAULT, num_classes=2)
|
|
|
|
| 108 |
output = self.fc(torch.cat((convnext_out, mobilenet_out, efficientnet_out), dim=1))
|
| 109 |
return output
|
| 110 |
|
| 111 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 112 |
convnext_model = CustomConvNeXtModel(weights=ConvNeXt_Tiny_Weights.DEFAULT, num_classes=2)
|
| 113 |
mobilenet_model = CustomMobileNetModel(weights=MobileNet_V2_Weights.DEFAULT, num_classes=2)
|
| 114 |
efficientnet_model = CustomEfficientNetModel(weights=EfficientNet_B0_Weights.DEFAULT, num_classes=2)
|