Spaces: Build error
| import gradio as gr | |
| import torch | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import torch.nn as nn | |
class CustomEfficientNet(nn.Module):
    """EfficientNet-B0 backbone followed by a configurable MLP classification head.

    The backbone's own classifier is replaced with ``nn.Identity`` so the feature
    vector that used to feed it goes into ``custom_classifier`` instead. That head
    is ``num_layers`` blocks of Linear -> ReLU -> Dropout, then a final Linear
    producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden Linear/ReLU/Dropout blocks (0 is allowed and
            gives a single linear head on top of the backbone features).
        neurons_per_layer: Width of every hidden Linear layer.
        dropout: Dropout probability after each hidden layer (default 0.5).
        pretrained: If True (default), initialise the backbone with torchvision's
            ImageNet weights (may download them on first use). Pass False when the
            weights will be overwritten by ``load_state_dict`` anyway.
    """

    def __init__(self, num_classes, num_layers, neurons_per_layer,
                 dropout=0.5, pretrained=True):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=` flag; IMAGENET1K_V1 is
        # what `pretrained=True` used to select for efficientnet_b0.
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = models.efficientnet_b0(weights=weights)
        in_features = self.base_model.classifier[1].in_features
        # Drop the ImageNet classifier so the backbone returns its features.
        self.base_model.classifier = nn.Identity()

        layers = []
        for _ in range(num_layers):
            layers += [
                nn.Linear(in_features, neurons_per_layer),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = neurons_per_layer
        # Use the running width rather than neurons_per_layer: when num_layers == 0
        # the output layer must consume the backbone feature size directly.
        layers.append(nn.Linear(in_features, num_classes))
        self.custom_classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the image batch ``x``."""
        features = self.base_model(x).flatten(1)  # keep batch dim, flatten the rest
        return self.custom_classifier(features)
def create_model(num_classes, num_layers, neurons_per_layer):
    """Construct and return a new ``CustomEfficientNet``.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of hidden layers in the classification head.
        neurons_per_layer: Width of each hidden layer.

    Returns:
        The freshly built (untrained-head) model.
    """
    return CustomEfficientNet(
        num_classes=num_classes,
        num_layers=num_layers,
        neurons_per_layer=neurons_per_layer,
    )
def load_model(path, num_classes, num_layers, neurons_per_layer):
    """Build the network and load trained weights from ``path`` for CPU inference.

    Args:
        path: File containing a state dict saved with ``torch.save``.
        num_classes: Number of output logits; must match the checkpoint.
        num_layers: Number of hidden head layers; must match the checkpoint.
        neurons_per_layer: Hidden layer width; must match the checkpoint.

    Returns:
        The model with the loaded weights, switched to evaluation mode.
    """
    net = create_model(num_classes, num_layers, neurons_per_layer)
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # map_location forces every tensor onto the CPU regardless of where it was saved.
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    return net.eval()  # Module.eval() returns the module itself
# Architecture hyper-parameters. They must describe the same network that was
# saved to the checkpoint, because load_state_dict (strict by default) rejects
# mismatched layer names or shapes.
num_classes = 52
num_layers = 3
neurons_per_layer = 1024

# Build the model and load the trained weights once, at import time, on the CPU.
model = load_model(
    'card_classification_model.pth',
    num_classes=num_classes,
    num_layers=num_layers,
    neurons_per_layer=neurons_per_layer,
)

# Pre-processing applied to every input image: resize to a fixed 224x224,
# convert to a float tensor, then normalise each channel with the standard
# ImageNet mean/std used by torchvision's pretrained backbones.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Class names, indexed by the model's output position. The order is the plain
# sorted order of the "<suit> <rank>" strings (digits before letters, uppercase
# suits before the lowercase 'carreau'), so it is generated suit by suit.
class_names = [
    f'{suit} {rank}'
    for suit in ('Coeur', 'Pique', 'Trefle', 'carreau')
    for rank in ('1', '10', '2', '3', '4', '5', '6', '7', '8', '9',
                 'Dame', 'Roi', 'Valet')
]
def predict(image):
    """Classify a single card image and return its predicted class name.

    Args:
        image: A PIL image (the Gradio input uses ``type="pil"``), or ``None``
            when the form is submitted without an uploaded image.

    Returns:
        The entry of ``class_names`` with the highest logit, or ``None`` if no
        image was supplied.
    """
    if image is None:
        return None
    # ToTensor keeps every channel, so RGBA or greyscale images would not match
    # the 3-channel Normalize step; coerce them to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    with torch.no_grad():  # inference only: no autograd bookkeeping
        logits = model(batch)
    return class_names[logits.argmax(dim=1).item()]
# Web UI: the uploaded image is handed to `predict` as a PIL image, and the
# returned class name is shown with Gradio's "label" output component.
iface = gr.Interface(
    predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    description="Upload an image to classify",
)
iface.launch()  # starts the Gradio server when this module runs