Spaces:
Sleeping
Sleeping
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
# Define the complex CNN model
class ComplexCNN(nn.Module):
    """Four-stage convolutional classifier for 32x32 RGB images.

    Each stage doubles the channel count (64 -> 128 -> 256 -> 512) and
    halves the spatial resolution with a 2x2 max-pool, so a 32x32 input
    reaches the fully connected head as a 512 x 2 x 2 feature map.
    """

    def __init__(self, num_classes=4):
        super(ComplexCNN, self).__init__()

        def conv_stage(c_in, c_out):
            # One conv -> batch-norm -> ReLU -> 2x2 max-pool stage.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            )

        # Attribute names must stay exactly as-is: they define the
        # state_dict keys used by the saved checkpoint.
        self.layer1 = conv_stage(3, 64)
        self.layer2 = conv_stage(64, 128)
        self.layer3 = conv_stage(128, 256)
        self.layer4 = conv_stage(256, 512)
        # Head sized for the 2x2 feature map produced from 32x32 inputs.
        self.fc1 = nn.Linear(512 * 2 * 2, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(x, 1)  # collapse C x H x W into one feature vector
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
# Load the trained model (weights saved as a plain state dict).
model = ComplexCNN(num_classes=4)
# weights_only=True restricts unpickling to tensors and simple containers,
# which is the recommended safe way to load a bare state dict -- the default
# torch.load path runs arbitrary pickle code from the file.
model.load_state_dict(
    torch.load(
        'model_weights.pth',
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()  # inference mode: disables dropout, uses running BN statistics
# Preprocessing pipeline: the model expects 32x32 RGB tensors with each
# channel normalized from [0, 1] into the [-1, 1] range.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Index -> label mapping; the order must match the training label encoding.
class_labels = ['dog', 'goat', 'lion', 'sheep']
# Function to predict the class of an uploaded image
def predict(image):
    """Classify a PIL image and return the predicted label string.

    Args:
        image: PIL.Image.Image supplied by the Gradio image input.

    Returns:
        One of the strings in ``class_labels``.
    """
    # Gradio may hand over RGBA, palette, or grayscale images; the
    # 3-channel Normalize inside `transform` requires RGB, so convert
    # defensively before preprocessing.
    image = image.convert('RGB')
    batch = transform(image).unsqueeze(0)  # add batch dim: (1, 3, 32, 32)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(batch)
    _, predicted = torch.max(logits, 1)
    return class_labels[predicted.item()]
# Build the web UI: one image upload in, one text label out.
image_input = gr.Image(type="pil")  # deliver uploads as PIL images for predict()
label_output = gr.Textbox(label="Predicted Class")
interface = gr.Interface(fn=predict, inputs=image_input, outputs=label_output)

# Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()