|
|
import numpy as np
|
|
|
import torch
|
|
|
import torchvision
|
|
|
from torchvision import datasets
|
|
|
from torchvision.transforms import ToTensor
|
|
|
from torch.utils.data import DataLoader
|
|
|
from torch.utils.data import DataLoader, Subset
|
|
|
from PIL import Image
|
|
|
from PyQt5.QtGui import QImage
|
|
|
from PyQt5.QtCore import QSize
|
|
|
from PyQt5.Qt import Qt
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torchvision
|
|
|
from torchvision import datasets
|
|
|
from torchvision.transforms import ToTensor
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed the global NumPy and PyTorch RNGs at import time so that any randomness
# drawn before a later explicit reseed is repeatable between runs.
_IMPORT_SEED = 123456
np.random.seed(_IMPORT_SEED)
torch.manual_seed(_IMPORT_SEED)
|
|
|
|
|
|
|
|
|
class MyModel(torch.nn.Module):
    """Small MNIST classifier: strided conv -> ReLU -> max-pool -> linear -> softmax.

    forward() returns per-class probabilities (softmax over dim 1), not logits.
    The attribute names conv1 / pool1 / fc1 are the state_dict keys written to
    checkpoints, so renaming them would break loading existing checkpoints.
    """

    # Spatial size after each stage for a 28x28 input:
    #   conv (kernel 5, stride 2): floor((28 - 5) / 2) + 1 = 12
    #   max-pool (kernel 2):       12 // 2 = 6
    # so fc1 sees 20 channels * 6 * 6 = 720 features.
    _FLAT_FEATURES = 20 * 6 * 6

    def __init__(self):
        super().__init__()
        # Layers are created in the original order so seeded weight init is unchanged.
        self.conv1 = torch.nn.Conv2d(1, 20, kernel_size=5, stride=2)
        self.pool1 = torch.nn.MaxPool2d(2)
        self.fc1 = torch.nn.Linear(self._FLAT_FEATURES, 10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) image tensor to (batch, 10) class probabilities."""
        features = self.pool1(torch.relu(self.conv1(x)))
        logits = self.fc1(features.flatten(1))
        return logits.softmax(dim=1)
|
|
|
|
|
|
|
|
|
def train_and_save(save_path='mnist_cnn.pth', *, nepoch=30, batch_size=200, lr=0.001,
                   data_root="data", log_every=20):
    """Train a fresh MyModel on the MNIST training split and save a checkpoint.

    The whole training split is materialised as one (images, labels) tensor
    pair. It is then trained with Adam on shuffled mini-batches for ``nepoch``
    epochs on the module-level ``device``. The checkpoint written by
    ``torch.save`` is a dict with keys ``'model_state'`` (the state_dict),
    ``'input_size'`` ((1, 28, 28)) and ``'output_size'`` (10).

    :param save_path: file path the checkpoint is written to.
    :param nepoch: number of passes over the training set.
    :param batch_size: mini-batch size (the last batch of an epoch may be smaller).
    :param lr: Adam learning rate.
    :param data_root: directory MNIST is downloaded to / read from.
    :param log_every: print the loss every ``log_every`` mini-batches (0 disables).
    :return: list of per-mini-batch training losses, in training order.
    """
    mnist = datasets.MNIST(
        root=data_root,
        train=True,
        download=True,
        transform=ToTensor(),
    )
    # A single shuffled batch covering the whole split gives all images and
    # labels as two tensors.
    loader = DataLoader(mnist, batch_size=len(mnist), shuffle=True)
    x, y = next(iter(loader))

    # Reseed so weight init and mini-batch order are reproducible.
    np.random.seed(123)
    torch.manual_seed(123)

    # Weights are initialised on the CPU (seeded), then moved to the target device.
    model = MyModel().to(device)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    lossfn = torch.nn.NLLLoss()  # loop-invariant; build it once
    # MyModel outputs softmax probabilities. One that underflows to exactly 0
    # would give log() = -inf and NaN gradients, so clamp before taking the log.
    # (Longer term: have the model expose logits and use cross_entropy.)
    eps = 1e-12
    losses = []

    n = x.shape[0]
    obs_id = np.arange(n, dtype=np.int64)

    for epoch in range(nepoch):
        np.random.shuffle(obs_id)

        for start in range(0, n, batch_size):
            batch_idx = torch.from_numpy(obs_id[start:start + batch_size])
            x_mini_batch = x[batch_idx].to(device)
            y_mini_batch = y[batch_idx].to(device)

            pred = model(x_mini_batch)
            loss = lossfn(torch.log(pred.clamp_min(eps)), y_mini_batch)

            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_value = loss.item()
            losses.append(loss_value)

            batch_no = start // batch_size
            if log_every and batch_no % log_every == 0:
                print(f"epoch {epoch}, batch {batch_no}, loss = {loss_value}")

    torch.save({
        'model_state': model.state_dict(),
        'input_size': (1, 28, 28),
        'output_size': 10,
    }, save_path)

    return losses
|
|
|
|
|
|
|
|
|
def load_trained_model(model_path='mnist_cnn.pth'):
    """Rebuild a MyModel from a checkpoint written by train_and_save.

    ``map_location=device`` only controls where the loaded tensors are placed.
    load_state_dict then copies them into the freshly built MyModel, whose
    parameters stay where MyModel() allocated them (it is not moved here).

    :param model_path: checkpoint file containing a ``'model_state'`` entry.
    :return: the model, switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net = MyModel()
    saved = torch.load(model_path, map_location=device)
    net.load_state_dict(saved['model_state'])
    return net.eval()  # Module.eval() returns the module itself
|
|
|
|
|
|
def predict_user_image(img_qimage, model):
    """Classify a digit drawn on the drawing board.

    :param img_qimage: QImage from the drawing board. A 28x28 image is used
        as-is; any other size is first rescaled to 28x28.
    :param model: classifier that maps a (1, 1, 28, 28) float tensor to
        (1, 10) class probabilities, presumably the model returned by
        load_trained_model().
    :return: (predicted class index, probability array of shape (1, 10)
        rounded to 3 decimals)
    :raises ValueError: if img_qimage is not a QImage.
    """
    # Check the type before calling any QImage method. Otherwise a non-QImage
    # fails with AttributeError on .format() instead of the intended ValueError.
    if not isinstance(img_qimage, QImage):
        raise ValueError("输入的图像必须是QImage对象")

    side = 28
    if img_qimage.width() != side or img_qimage.height() != side:
        # Rescale before converting: smooth scaling may return a colour format.
        img_qimage = img_qimage.scaled(side, side, Qt.IgnoreAspectRatio,
                                       Qt.SmoothTransformation)

    if img_qimage.format() != QImage.Format_Grayscale8:
        img_qimage = img_qimage.convertToFormat(QImage.Format_Grayscale8)

    height = img_qimage.height()
    width = img_qimage.width()
    stride = img_qimage.bytesPerLine()  # scanlines are padded to 4-byte boundaries
    ptr = img_qimage.bits()
    ptr.setsize(stride * height)
    img_bytes = bytes(ptr)

    # Drop any row-padding bytes, then scale pixel values to [0, 1].
    img_array = np.frombuffer(img_bytes, dtype=np.uint8).reshape(height, stride)[:, :width]
    img_array = img_array.astype(np.float32) / 255.0

    # Put the input on the same device as the model's parameters.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    tensor_img = torch.from_numpy(img_array).reshape(1, 1, side, side).to(model_device)

    with torch.no_grad():
        output = model(tensor_img)

    probs = np.round(output.cpu().numpy(), 3)
    pred = torch.argmax(output).item()

    return pred, probs
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: train from scratch, writing the checkpoint to the
    # default path, then reload that same checkpoint in eval mode.
    checkpoint_path = 'mnist_cnn.pth'
    train_and_save(save_path=checkpoint_path)
    model = load_trained_model(model_path=checkpoint_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|