Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from src.model import TrashNetClassifier | |
| from torchvision.models import resnet18, efficientnet_b0 | |
class EnsembleModel(nn.Module):
    """Three-member classifier ensemble that averages the members' outputs.

    The members are a project-local ``TrashNetClassifier``, a torchvision
    ``resnet18`` and a torchvision ``efficientnet_b0``.  The final linear
    layer of each torchvision network is replaced so that it emits
    ``num_classes`` values.

    Args:
        num_classes: Output width of every member's final layer
            (defaults to 6).
    """

    def __init__(self, num_classes=6):
        super().__init__()

        # Member 1: classifier defined in this project (src.model).
        self.model1 = TrashNetClassifier(num_classes=num_classes)

        # Member 2: ResNet-18 loaded with pretrained weights; swap its `fc`
        # head for a fresh linear layer of the requested width.
        # NOTE(review): newer torchvision releases deprecate `pretrained=`
        # in favour of `weights=` -- confirm against the pinned version.
        resnet = resnet18(pretrained=True)
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
        self.model2 = resnet

        # Member 3: EfficientNet-B0 loaded with pretrained weights; the
        # linear layer sits at index 1 of its `classifier` container.
        effnet = efficientnet_b0(pretrained=True)
        head_in_features = effnet.classifier[1].in_features
        effnet.classifier[1] = nn.Linear(head_in_features, num_classes)
        self.model3 = effnet

    def forward(self, x):
        """Run ``x`` through all three members and return the mean output.

        Args:
            x: Input passed unchanged to each member model.

        Returns:
            The element-wise average of the three members' outputs.
        """
        members = (self.model1, self.model2, self.model3)
        first, second, third = [member(x) for member in members]
        return (first + second + third) / 3