Spaces: Build error
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
class CNN(torch.nn.Module):
    """Two-convolution network that maps 1-channel images to 10 class scores.

    The layer attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``) become
    the keys of ``state_dict()``, so renaming them would break loading a
    previously saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        # 3x3 kernels with padding=1 and stride=1 preserve the spatial size;
        # only the single 2x2 max-pool below halves it.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        # 64 * 14 * 14 hard-codes a 28x28 input (28 -> 14 after the one pool);
        # other input sizes will fail at this layer.
        self.fc1 = torch.nn.Linear(64 * 14 * 14, 128)
        self.fc2 = torch.nn.Linear(128, 10)
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(0.25)

    def forward(self, x):
        """Return raw (unnormalised) logits of shape (batch, 10) for ``x``.

        ``x`` is expected to be (batch, 1, 28, 28), as implied by ``conv1``'s
        single input channel and ``fc1``'s input width.
        """
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        features = self.pool(features)
        # Collapse channel and spatial dims into one vector per sample.
        flat = features.view(features.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)
# Build the network and restore its trained parameters from the local
# checkpoint. map_location pins every tensor to the CPU regardless of the
# device it was saved from; weights_only=True restricts torch.load to plain
# tensor/container data rather than arbitrary pickled objects.
model = CNN()
model.load_state_dict(
    torch.load(
        "pytorch_model.bin",
        weights_only=True,
        map_location=torch.device("cpu"),
    )
)
# Switch to evaluation mode so Dropout is a no-op during inference.
model.eval()
def predict(image):
    """Classify a single uploaded digit image with the global ``model``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        ``"Predicted digit: <k>"``, where ``k`` is the index of the highest
        logit, or a prompt asking for an image when ``image`` is ``None``.
    """
    # An empty image input arrives as None; feeding that to the transform
    # pipeline would raise instead of giving the user a useful message.
    if image is None:
        return "Please upload an image of a handwritten digit."

    # Bring arbitrary uploads to the 1x28x28 layout the CNN expects.
    # Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] range to [-1, 1].
    # NOTE(review): presumably mirrors the training-time preprocessing; if the
    # model was trained on MNIST-style white-on-black digits, dark-on-light
    # uploads may need inverting — confirm against the training script.
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    image_tensor = transform(image).unsqueeze(0)  # add batch dim -> (1, 1, 28, 28)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(image_tensor)
    predicted_class = logits.argmax(dim=1).item()
    return f"Predicted digit: {predicted_class}"
# Web UI: one PIL image in, one line of text out.
interface = gr.Interface(
    title="Handwritten Digit Classifier",
    description=(
        "Upload an image of a handwritten digit, "
        "and the model will predict the digit."
    ),
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

if __name__ == "__main__":
    # Launched without share=True; the original author dropped it for
    # Hugging Face Spaces hosting.
    interface.launch()