import torch
from torch.utils.data import Subset
from PyQt5.QtGui import QImage
import numpy as np
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(123456)
torch.manual_seed(123456)


class MyModel(torch.nn.Module):
    """Small CNN classifying 1x28x28 images into 10 classes.

    ``forward`` returns per-class probabilities (softmax over dim 1).
    For training use :meth:`log_probs`, which applies ``log_softmax`` to the
    logits directly; ``torch.log(forward(x))`` yields ``-inf`` (and NaN
    gradients) whenever a softmax probability underflows to 0.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 20x12x12 (kernel 5, stride 2)
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=2)
        # 20x12x12 -> 20x6x6
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        # 20 * 6 * 6 = 720 flattened features -> 10 class scores
        self.fc1 = torch.nn.Linear(in_features=720, out_features=10)

    def logits(self, x):
        """Return unnormalized class scores of shape (batch, 10)."""
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)

    def log_probs(self, x):
        """Return class log-probabilities, computed stably via log_softmax."""
        return torch.nn.functional.log_softmax(self.logits(x), dim=1)

    def forward(self, x):
        """Return class probabilities of shape (batch, 10); each row sums to 1."""
        return torch.nn.functional.softmax(self.logits(x), dim=1)


def train_and_save(save_path='mnist_cnn.pth'):
    """Train MyModel on the MNIST training split and save a checkpoint.

    The checkpoint is a dict with keys ``'model_state'`` (state dict, saved as
    CPU tensors so it loads on any machine), ``'input_size'`` and
    ``'output_size'``.

    :param save_path: path the checkpoint is written to.
    :return: list of per-mini-batch training losses (floats).
    """
    # Load data: ToTensor() converts images to float tensors scaled to [0, 1].
    mnist = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    subset = Subset(mnist, indices=range(60000))
    # One batch containing every sample materializes the whole set as tensors.
    loader = DataLoader(subset, batch_size=60000, shuffle=True)
    x, y = next(iter(loader))
    x, y = x.to(device), y.to(device)

    nepoch = 30
    batch_size = 200
    lr = 0.001

    np.random.seed(123)
    torch.manual_seed(123)

    model = MyModel().to(device)
    model.train()
    losses = []
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # NLLLoss on log-probabilities is the cross-entropy loss; built once.
    lossfn = torch.nn.NLLLoss()

    n = x.shape[0]
    # Explicit int64 so the indices form a LongTensor on every platform.
    obs_id = np.arange(n, dtype=np.int64)  # [0, 1, ..., n-1]
    # Run over the whole data set `nepoch` times
    for i in range(nepoch):
        # Shuffle observation IDs
        np.random.shuffle(obs_id)
        # Update on mini-batches
        for j in range(0, n, batch_size):
            batch_ids = torch.as_tensor(obs_id[j:(j + batch_size)], device=device)
            x_mini_batch = x[batch_ids]
            y_mini_batch = y[batch_ids]

            # Compute loss from numerically stable log-probabilities
            loss = lossfn(model.log_probs(x_mini_batch), y_mini_batch)

            # Compute gradient and update parameters
            opt.zero_grad()
            loss.backward()
            opt.step()

            losses.append(loss.item())

            if (j // batch_size) % 20 == 0:
                print(f"epoch {i}, batch {j // batch_size}, loss = {loss.item()}")

    torch.save({
        'model_state': {k: v.detach().cpu() for k, v in model.state_dict().items()},
        'input_size': (1, 28, 28),
        'output_size': 10
    }, save_path)
    return losses


# Function: load a trained model
def load_trained_model(model_path='mnist_cnn.pth'):
    """Load a checkpoint written by ``train_and_save``.

    :param model_path: path of the checkpoint file.
    :return: a MyModel on the module-level ``device``, in eval mode.
    """
    model = MyModel()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    # Keep the model on the same device its weights were mapped to.
    model.to(device)
    model.eval()
    return model


def predict_user_image(img_qimage, model):
    """Classify a digit image coming from the drawing board.

    :param img_qimage: QImage from the drawing board; must be 28x28 pixels.
        Converted to Grayscale8 if needed. Pixel values are only scaled to
        [0, 1] (no inversion), so white strokes on a black background are
        expected, as in MNIST.
    :param model: a MyModel (e.g. from ``load_trained_model``) whose forward
        returns class probabilities.
    :return: (predicted class index, probability array of shape (1, 10)
        rounded to 3 decimals)
    :raises ValueError: if img_qimage is not a QImage or is not 28x28.
    """
    # Validate the type first; calling .format() on a non-QImage would raise
    # AttributeError before the intended ValueError could be reached.
    if not isinstance(img_qimage, QImage):
        raise ValueError("输入的图像必须是QImage对象")

    width, height = img_qimage.width(), img_qimage.height()
    if (width, height) != (28, 28):
        raise ValueError(f"输入的图像必须是28x28大小,当前为{width}x{height}")

    # Ensure the image is in Grayscale8 format (one byte per pixel).
    if img_qimage.format() != QImage.Format_Grayscale8:
        img_qimage = img_qimage.convertToFormat(QImage.Format_Grayscale8)

    # Qt may pad each scanline, so size the buffer by bytesPerLine rather than
    # width (and avoid the deprecated byteCount()).
    stride = img_qimage.bytesPerLine()
    ptr = img_qimage.bits()
    # PyQt5 returns an unsized sip.voidptr; it must be sized before copying.
    ptr.setsize(stride * height)
    img_bytes = bytes(ptr)  # copy out of Qt-owned memory

    # Drop any per-row padding, then convert to float32.
    img_array = (
        np.frombuffer(img_bytes, dtype=np.uint8)
        .reshape(height, stride)[:, :width]
        .astype(np.float32)
    )

    # Scale to [0, 1] like ToTensor() in training; shape (1, 1, H, W) on the
    # same device as the model's parameters.
    model_device = next(model.parameters()).device
    tensor_img = torch.tensor(img_array / 255.0, dtype=torch.float32)
    tensor_img = tensor_img.reshape(1, 1, height, width).to(model_device)

    with torch.no_grad():
        output = model(tensor_img)
    probs = np.round(output.cpu().numpy(), 3)
    pred = torch.argmax(output).item()
    return pred, probs


if __name__ == '__main__':
    # Train and save the model, then load it back from disk.
    train_and_save()
    model = load_trained_model()