---
tags:
- image-classification
- pytorch
- mnist
license: apache-2.0
library_name: pytorch
pipeline_tag: image-classification
---

# MNIST Digit Classifier

A convolutional neural network (CNN) trained on the MNIST dataset to classify 28×28 grayscale images of handwritten digits (0-9).

## Usage

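```python
import torch
from PIL import Image
from torchvision import transforms

from src.model import DigitClassifier

# Load the trained weights (map_location lets GPU-saved weights load on CPU)
model = DigitClassifier()
model.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
model.eval()  # inference mode

# Preprocessing (same as training):
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Classify one image; "digit.png" is a placeholder path
image = Image.open("digit.png")
batch = transform(image).unsqueeze(0)  # add batch dimension -> (1, 1, 28, 28)
with torch.no_grad():
    # Assumes the model returns one score per class (10 outputs)
    prediction = model(batch).argmax(dim=1).item()
print(f"Predicted digit: {prediction}")
```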
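
For reference, `ToTensor()` scales 8-bit pixel values from 0-255 to [0, 1], and `Normalize((0.1307,), (0.3081,))` then subtracts 0.1307 and divides by 0.3081. These are the widely used MNIST training-set mean and standard deviation. The pure-Python sketch below reproduces that per-pixel arithmetic for a single grayscale value. `normalize_pixel` is a hypothetical helper for illustration and is not part of this repository.

```python
# MNIST statistics passed to transforms.Normalize in the usage example
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(value: int) -> float:
    """Map an 8-bit grayscale pixel to the value the model receives.

    Hypothetical helper: mirrors transforms.ToTensor() (scale 0-255 to [0, 1])
    followed by transforms.Normalize((0.1307,), (0.3081,)) for one channel.
    """
    if not 0 <= value <= 255:
        raise ValueError("pixel value must be in [0, 255]")
    scaled = value / 255.0  # ToTensor step
    return (scaled - MNIST_MEAN) / MNIST_STD  # Normalize step


if __name__ == "__main__":
    for pixel in (0, 128, 255):
        print(pixel, round(normalize_pixel(pixel), 4))
```

A black background pixel (0) maps to about -0.42 and a white stroke pixel (255) to about 2.82, so the model sees roughly zero-mean, unit-variance inputs rather than raw 0-255 values.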