mnist-cnn / inference.py
Anni0401's picture
Add MNIST CNN model and inference code
4414e2c verified
raw
history blame contribute delete
593 Bytes
import torch
import numpy as np
from src.models.cnn import MNISTCNN
def load_model(model_path="mnist_cnn.pth", device="cpu") -> "MNISTCNN":
    """Build an ``MNISTCNN`` and restore its weights from a saved state dict.

    Args:
        model_path: Path to a checkpoint file holding the model's
            ``state_dict``.
        device: Passed to ``torch.load`` as ``map_location``, i.e. where the
            checkpoint tensors are placed while they are deserialized.

    Returns:
        The network with the loaded weights, switched to evaluation mode.

    Note:
        The model is not moved with ``.to(device)`` here; ``load_state_dict``
        copies the loaded values into the parameters that ``MNISTCNN()``
        already created. Callers that want to run on another device must
        move the returned model themselves.
    """
    model = MNISTCNN()

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Evaluation mode: disables training-only behaviour such as dropout.
    model.eval()
    return model
def _model_device(model):
    """Return the device of *model*'s first parameter, or None if unknown."""
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        # Plain callables (not nn.Module) carry no device information.
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def predict(image: np.ndarray, model) -> int:
    """Classify a single grayscale image and return the predicted class index.

    Args:
        image: numpy array of shape (28, 28), values in [0, 255].
        model: Callable that maps a float32 tensor of shape (1, 1, H, W) to
            logits of shape (1, num_classes), e.g. the network returned by
            ``load_model``.

    Returns:
        Index of the largest logit, as a Python ``int``.
    """
    # Rescale pixel values; inputs in [0, 255] end up in [0, 1].
    scaled = image / 255.0

    # Build the input on the same device as the model so that a model the
    # caller moved to a GPU does not fail with a device mismatch. When the
    # device cannot be determined, device=None keeps torch's default placement.
    batch = torch.tensor(scaled, dtype=torch.float32, device=_model_device(model))

    # Add the batch and channel dimensions: (H, W) -> (1, 1, H, W).
    batch = batch.unsqueeze(0).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    return logits.argmax(dim=1).item()