---
language: en
license: mit
library_name: pytorch
tags:
- mnist
- image-classification
- neural-network
datasets:
- mnist
metrics:
- accuracy
---

# Simple PyTorch Neural Network for MNIST

This model is a basic feed-forward neural network trained on the MNIST dataset as part of a PyTorch tutorial.

## Model Architecture

The model consists of:

1. **Input Layer**: 784 inputs (each 28x28 image flattened into a vector).
2. **Hidden Layer**: 128 neurons with ReLU activation.
3. **Output Layer**: 10 neurons (one for each digit, 0-9).
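
The three layers above can be sketched as a small `nn.Module`. This is a minimal sketch and not necessarily the tutorial's actual `simple_nn.py`: the attribute names `fc1` and `fc2`, the use of `nn.Flatten`, and returning raw logits are assumptions, so a checkpoint saved from a differently named module may not load into it.

```python
import torch
from torch import nn


class SimpleNN(nn.Module):
    """Hypothetical reconstruction of the 784 -> 128 -> 10 network."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()      # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = nn.Linear(784, 128)   # hidden layer
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)    # one output per digit class

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)               # raw logits (assumed; no softmax here)


model = SimpleNN()
dummy_batch = torch.zeros(4, 1, 28, 28)  # four blank 28x28 images
logits = model(dummy_batch)
num_params = sum(p.numel() for p in model.parameters())
print(logits.shape, num_params)
```

With these layer sizes the network has 784 × 128 + 128 + 128 × 10 + 10 = 101,770 trainable parameters.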

## Training Details

- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Epochs**: 5 (by default)
- **Optimizer**: Adam (lr=0.001)
- **Loss Function**: CrossEntropyLoss
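
To illustrate how the optimizer and loss function listed above fit together, here is a minimal sketch of the update step. It is not the tutorial's training script: the `nn.Sequential` stand-in model, the batch size of 64, and the random tensors used in place of MNIST data are all assumptions made so the example runs without downloading anything.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the 784 -> 128 -> 10 network (assumed layout).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)
criterion = nn.CrossEntropyLoss()            # expects raw logits and integer labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Random tensors in place of a real MNIST batch (illustration only).
images = torch.rand(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))

model.train()
losses = []
for step in range(3):                        # three update steps on the same batch
    optimizer.zero_grad()                    # clear gradients from the previous step
    loss = criterion(model(images), labels)
    loss.backward()                          # backpropagate
    optimizer.step()                         # Adam parameter update
    losses.append(loss.item())

print(losses)
```

In real training, the loop would instead iterate over a `DataLoader` of MNIST batches once per epoch; the fixed random batch here only demonstrates the mechanics of a single step.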
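
## Usage

To load this model in your PyTorch project, make sure the `simple_nn` module (which defines `SimpleNN`) is importable and `model.pth` is in the current working directory:

```python
import torch
from simple_nn import SimpleNN

# 1. Initialize the model architecture
model = SimpleNN()

# 2. Load the state dictionary (map_location="cpu" also works on machines without a GPU)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))

# 3. Switch to evaluation mode before running inference
model.eval()
```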
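
Once the weights are loaded, classifying an image comes down to a forward pass under `torch.no_grad()` followed by an `argmax` over the 10 outputs. The sketch below uses a randomly initialized stand-in network and a random tensor so that it runs without `model.pth`; it assumes the model accepts input of shape `(N, 1, 28, 28)` and returns raw logits, neither of which is confirmed by this card.

```python
import torch
from torch import nn

# Stand-in with random weights; in practice use the loaded SimpleNN instead.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
model.eval()

image = torch.rand(1, 1, 28, 28)             # placeholder for one preprocessed image

with torch.no_grad():                        # gradients are not needed for inference
    logits = model(image)                    # shape (1, 10)
    probs = torch.softmax(logits, dim=1)     # convert logits to class probabilities
    predicted_digit = int(probs.argmax(dim=1).item())

print(predicted_digit, probs[0, predicted_digit].item())
```

Because the stand-in weights are random, the predicted digit here is meaningless; only the call pattern carries over to the trained model.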

## Dataset Information

The MNIST dataset consists of 28x28 grayscale images of handwritten digits (0-9), giving 10 classes. It is a classic benchmark for image classification.
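
As a rough illustration of how one such image becomes the 784-value input vector, the sketch below builds a synthetic 28x28 8-bit image, scales it, and flattens it. The random pixel values and the division by 255 are assumptions for illustration; the tutorial's actual preprocessing (for example, any mean/std normalization) is not stated here.

```python
import torch

# Synthetic stand-in for one MNIST image: 28x28 pixels, 8-bit grayscale (0-255).
raw_image = torch.randint(0, 256, (28, 28), dtype=torch.uint8)

scaled = raw_image.float() / 255.0           # scale intensities to [0, 1]
flat = scaled.reshape(-1)                    # 28 * 28 = 784 input features
batch = flat.unsqueeze(0)                    # add a batch dimension: (1, 784)

print(raw_image.shape, flat.shape, batch.shape)
```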