# DL_assignment / app.py
# Uzairabbasi's picture
# Update app.py
# 433d89f verified
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
# Define the complex CNN model
class ComplexCNN(nn.Module):
    """Four-stage convolutional image classifier with a 3-layer MLP head.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, stride 2). The convolution preserves the spatial size and the
    pooling halves it (floor), while channels grow 3 -> 64 -> 128 -> 256 -> 512.

    The attribute names (``layer1`` .. ``layer4``, ``fc1`` .. ``fc3``) and the
    order of modules inside each stage determine the ``state_dict`` keys, so
    they must stay unchanged for previously saved weights to load.

    Args:
        num_classes: Number of output logits produced by ``fc3``.
        input_size: Side length in pixels of the square input images. The
            default of 32 reproduces the original ``512 * 2 * 2`` input width
            of ``fc1``.

    Raises:
        ValueError: If ``input_size`` is smaller than 16, i.e. too small to
            leave at least one pixel after four 2x poolings.
    """

    _NUM_STAGES = 4

    def __init__(self, num_classes=4, input_size=32):
        super().__init__()
        self.layer1 = self._conv_block(3, 64)
        self.layer2 = self._conv_block(64, 128)
        self.layer3 = self._conv_block(128, 256)
        self.layer4 = self._conv_block(256, 512)

        # Each MaxPool2d(kernel_size=2, stride=2) floors the side length by 2;
        # the 3x3/padding-1 convolutions leave it unchanged.
        side = input_size
        for _ in range(self._NUM_STAGES):
            side //= 2
        if side < 1:
            raise ValueError(
                f"input_size must be at least 16 to survive four 2x poolings, "
                f"got {input_size}"
            )

        self.fc1 = nn.Linear(512 * side * side, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return one Conv -> BatchNorm -> ReLU -> 2x2 MaxPool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` must have 3 channels (the first Conv2d has ``in_channels=3``)
        and spatial size ``input_size`` x ``input_size`` so the flattened
        features match ``fc1``.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
# Load the trained model: build the network, restore the saved parameters
# with every tensor mapped onto the CPU, then switch to inference mode.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (newer torch versions accept weights_only=True to restrict this).
model = ComplexCNN(num_classes=4)
_weights = torch.load('model_weights.pth', map_location=torch.device('cpu'))
model.load_state_dict(_weights)
del _weights  # values were copied into the model's parameters/buffers
# eval() makes BatchNorm use its running statistics and disables Dropout.
model.eval()
# Preprocessing applied to each image before inference: resize to the 32x32
# resolution ComplexCNN's fc1 was sized for (four 2x poolings -> 2x2),
# convert to a float tensor in [0, 1], then shift each channel to [-1, 1].
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Model output index -> label. Presumably this is the index order used during
# training; verify against the training script before reordering.
class_labels = ["dog", "goat", "lion", "sheep"]
# Function to predict the class of an uploaded image
def predict(image):
    """Classify a single uploaded image and return its label.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an upload.

    Returns:
        The predicted entry of ``class_labels``, or an empty string when no
        image was provided.
    """
    if image is None:
        # Nothing uploaded: calling the transform on None would raise.
        return ""
    # Uploads may be RGBA (PNG with alpha), grayscale or palette images.
    # Normalize uses 3-element mean/std and the first Conv2d expects 3 input
    # channels, so force exactly 3 channels before transforming.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Create the Gradio interface: one PIL image upload in, the predicted label
# out as text.
image_input = gr.Image(type="pil")
label_output = gr.Textbox(label="Predicted Class")
interface = gr.Interface(fn=predict, inputs=image_input, outputs=label_output)

if __name__ == "__main__":
    # Only start the Gradio server when this file is executed directly.
    interface.launch()