"""Gradio app that classifies an uploaded image with a small CNN.

The network (``ComplexCNN``) is restored from ``model_weights.pth`` onto the
CPU and served through a single-image Gradio interface that returns one of
the labels in ``class_labels``.
"""

from typing import Optional

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# Height/width every input image is resized to before inference. It also
# determines the input width of ``ComplexCNN.fc1`` (see ``ComplexCNN``).
INPUT_SIZE = 32


def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
    """Return a 3x3 conv (same padding) -> BatchNorm -> ReLU -> 2x2 max-pool stage.

    The submodule order matters for checkpoint compatibility: ``nn.Sequential``
    names its children "0", "1", ..., so saved keys look like
    ``layer1.0.weight`` (conv) and ``layer1.1.running_mean`` (batch norm).
    Do not reorder these modules.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
    )


# Define the complex CNN model
class ComplexCNN(nn.Module):
    """Four conv stages followed by a three-layer fully connected classifier.

    Args:
        num_classes: Number of output logits.
        input_size: Height/width of the square RGB input. The convs preserve
            spatial size and each of the four 2x2 max-pools halves it
            (flooring), so the flattened feature vector has
            ``512 * (input_size // 16) ** 2`` elements.
    """

    def __init__(self, num_classes=4, input_size=32):
        super().__init__()
        self.layer1 = _conv_block(3, 64)
        self.layer2 = _conv_block(64, 128)
        self.layer3 = _conv_block(128, 256)
        self.layer4 = _conv_block(256, 512)
        # Four successive floor-halvings == floor division by 16
        # (32 -> 16 -> 8 -> 4 -> 2 for the default size).
        final_size = input_size // 16
        self.fc1 = nn.Linear(512 * final_size * final_size, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return raw (unnormalised) class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, input_size, input_size)``. The
        channel count comes from ``layer1`` and the spatial size from ``fc1``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C*H*W
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


# Define the class labels; index i corresponds to output logit i.
class_labels = ['dog', 'goat', 'lion', 'sheep']

# Load the trained model
model = ComplexCNN(num_classes=len(class_labels), input_size=INPUT_SIZE)
# NOTE(review): torch.load unpickles the file, so only load weights from a
# trusted source. On torch >= 1.13, passing ``weights_only=True`` restricts
# loading to tensors/state dicts.
model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
model.eval()  # use BatchNorm running stats and disable Dropout for inference

# Define the transformation.
# NOTE(review): presumably mirrors the preprocessing used during training —
# confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    # Maps ToTensor's [0, 1] range to [-1, 1] per channel.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


# Function to predict the class of an uploaded image
def predict(image: Optional[Image.Image]) -> str:
    """Return the predicted class label for a PIL image.

    Args:
        image: The uploaded image. Gradio passes ``None`` when the input is
            empty.

    Returns:
        One of the strings in ``class_labels``.

    Raises:
        gr.Error: If no image was supplied. Gradio shows the message to the
            user.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # Grayscale ("L"), palette ("P") or RGBA images would give the wrong
    # channel count for the 3-channel first conv, so normalise to RGB first.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(batch)
    predicted_index = int(logits.argmax(dim=1).item())
    return class_labels[predicted_index]


# Create the Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Class"),
)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch()