import torch
from torch.utils.data import Subset
from PyQt5.QtGui import QImage
import numpy as np
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# Device used when loading checkpoints: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed both RNGs at import time so runs are repeatable.
np.random.seed(123456)
torch.manual_seed(123456)


class MyModel(torch.nn.Module):
    """Small CNN classifier: conv -> ReLU -> max-pool -> linear -> softmax.

    ``forward`` returns per-class probabilities (each row sums to 1);
    ``logits`` returns the raw scores before the softmax.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20,
                                     kernel_size=5, stride=2)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        # For a 1x28x28 input: conv -> 20x12x12, pool -> 20x6x6 = 720.
        self.fc1 = torch.nn.Linear(in_features=720, out_features=10)

    def logits(self, x):
        """Return the unnormalised class scores for a batch ``x``."""
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for a batch ``x``."""
        return torch.nn.functional.softmax(self.logits(x), dim=1)


def train_and_save(save_path='mnist_cnn.pth'):
    """Train ``MyModel`` on the MNIST training split and save a checkpoint.

    The checkpoint is a dict with keys ``'model_state'`` (the model's
    ``state_dict``), ``'input_size'`` and ``'output_size'``.

    Args:
        save_path: Destination file for the checkpoint.

    Returns:
        The list of per-mini-batch loss values, in training order.
    """
    # Data loading: materialise the whole training split as a single batch.
    mnist = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    subset = Subset(mnist, indices=range(60000))
    loader = DataLoader(subset, batch_size=60000, shuffle=True)
    x, y = next(iter(loader))

    nepoch = 30
    batch_size = 200
    lr = 0.001

    # Re-seed so weight initialisation and batch order are repeatable.
    np.random.seed(123)
    torch.manual_seed(123)

    model = MyModel()
    losses = []
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    n = x.shape[0]
    obs_id = np.arange(n)  # [0, 1, ..., n-1]

    # Run the whole data set `nepoch` times.
    for epoch in range(nepoch):
        # Shuffle observation IDs.
        np.random.shuffle(obs_id)

        # Update on mini-batches.
        for start in range(0, n, batch_size):
            batch_ids = obs_id[start:start + batch_size]
            x_mini_batch = x[batch_ids]
            y_mini_batch = y[batch_ids]

            # cross_entropy applies log-softmax to the logits internally.
            # This is the same objective as NLLLoss(log(softmax(...))) but
            # cannot hit log(0) = -inf when a probability underflows.
            loss = torch.nn.functional.cross_entropy(
                model.logits(x_mini_batch), y_mini_batch
            )

            # Compute gradient and update parameters.
            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_value = loss.item()
            losses.append(loss_value)

            batch_idx = start // batch_size
            if batch_idx % 20 == 0:
                print(f"epoch {epoch}, batch {batch_idx}, loss = {loss_value}")

    torch.save({
        'model_state': model.state_dict(),
        'input_size': (1, 28, 28),
        'output_size': 10
    }, save_path)
    return losses


# Function: load a trained model
def load_trained_model(model_path='mnist_cnn.pth'):
    """Rebuild ``MyModel`` from a checkpoint and return it ready for inference.

    Args:
        model_path: Path to a checkpoint dict containing a ``'model_state'``
            entry (a ``state_dict`` for ``MyModel``).

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.
    """
    model = MyModel()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    # load_state_dict copies values into the (CPU) model; move the model
    # itself so it lives on the same device the checkpoint was mapped to.
    model.to(device)
    model.eval()
    return model


def predict_user_image(grid_matrix, model, device):
    """Classify a single 28x28 image with ``model``.

    Args:
        grid_matrix: 28x28 array-like of pixel values; values are divided by
            255.0 before inference (presumably a 0-255 greyscale grid).
        model: Callable mapping a ``(1, 1, 28, 28)`` float tensor to a row of
            class scores. If it is a ``torch.nn.Module`` it is moved to
            ``device`` before the call.
        device: Device on which inference runs.

    Returns:
        ``(pred_class, probabilities)``: the arg-max class index as ``int``
        and the flattened model output as a NumPy array.

    Raises:
        ValueError: If ``grid_matrix`` is not of shape ``(28, 28)``.
    """
    grid = np.asarray(grid_matrix)

    # Input validation.
    if grid.shape != (28, 28):
        raise ValueError(f"输入形状应为(28,28),实际获取{grid.shape}")

    # Preprocessing: torch.from_numpy rejects negative-stride views (e.g. a
    # flipped grid), so normalise to a contiguous float32 array first.
    pixels = np.ascontiguousarray(grid, dtype=np.float32)
    input_tensor = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0) / 255.0
    input_tensor = input_tensor.to(device)

    # Keep model and input on the same device to avoid a device-mismatch error.
    if isinstance(model, torch.nn.Module):
        model.to(device)

    # Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(input_tensor)

    # Callers rely on both values being returned.
    probabilities = outputs.cpu().numpy().flatten()
    pred_class = int(np.argmax(probabilities))
    return pred_class, probabilities


if __name__ == '__main__':
    # Train and save the model, then load it back from disk.
    train_and_save()
    model = load_trained_model()