Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.utils.data import Subset | |
| from PyQt5.QtGui import QImage | |
| import numpy as np | |
| from torchvision import datasets | |
| from torchvision.transforms import ToTensor | |
| from torch.utils.data import DataLoader | |
# Default compute device for this module: the GPU when CUDA is usable, else the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed the global PyTorch and NumPy RNGs once, at import time, for reproducible runs.
torch.manual_seed(123456)
np.random.seed(123456)
| class MyModel(torch.nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=2) | |
| self.pool1 = torch.nn.MaxPool2d(kernel_size=2) | |
| self.fc1 = torch.nn.Linear(in_features=720, out_features=10) | |
| def forward(self, x): | |
| x = self.conv1(x) | |
| x = torch.relu(x) | |
| x = self.pool1(x) | |
| x = torch.flatten(x, start_dim=1) | |
| x = self.fc1(x) | |
| x = torch.nn.functional.softmax(x, dim=1) | |
| return x | |
def train_and_save(save_path='mnist_cnn.pth', nepoch=30, batch_size=200, lr=0.001):
    """Train a fresh ``MyModel`` on the MNIST training split and save a checkpoint.

    The whole training split is materialised as one in-memory tensor pair.
    The model is then optimised with Adam over ``nepoch`` passes of shuffled
    mini-batches, using negative log-likelihood on the model's softmax output.
    Model and data are placed on the module-level ``device``.

    Args:
        save_path: Path the checkpoint dict is written to with ``torch.save``.
        nepoch: Number of full passes over the training data.
        batch_size: Mini-batch size; must be positive.
        lr: Adam learning rate.

    Returns:
        list[float]: The loss of every mini-batch step, in training order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Load the MNIST training split (downloaded into ./data if missing).
    mnist = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    # One batch spanning the whole dataset loads every image/label at once.
    loader = DataLoader(mnist, batch_size=len(mnist), shuffle=True)
    x, y = next(iter(loader))
    x, y = x.to(device), y.to(device)

    # Re-seed so model initialisation and mini-batch order are reproducible.
    np.random.seed(123)
    torch.manual_seed(123)
    model = MyModel().to(device)
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    lossfn = torch.nn.NLLLoss()
    # The model returns probabilities, so we take the log here. Clamp first so
    # that a probability that underflowed to 0 cannot give -inf/NaN losses.
    eps = 1e-12

    n = x.shape[0]
    obs_id = np.arange(n)  # [0, 1, ..., n-1]
    losses = []
    # Run over the whole data set `nepoch` times.
    for i in range(nepoch):
        np.random.shuffle(obs_id)
        for j in range(0, n, batch_size):
            # Build the mini-batch; indices must live on the same device as x/y.
            batch_idx = torch.as_tensor(obs_id[j:j + batch_size], device=device)
            x_mini_batch = x[batch_idx]
            y_mini_batch = y[batch_idx]

            pred = model(x_mini_batch)
            loss = lossfn(torch.log(pred.clamp_min(eps)), y_mini_batch)

            # Compute gradients and update parameters.
            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_value = loss.item()
            losses.append(loss_value)
            if (j // batch_size) % 20 == 0:
                print(f"epoch {i}, batch {j // batch_size}, loss = {loss_value}")

    torch.save({
        'model_state': model.state_dict(),
        'input_size': (1, 28, 28),
        'output_size': 10,
    }, save_path)
    return losses
# Function: load a previously trained model.
def load_trained_model(model_path='mnist_cnn.pth'):
    """Rebuild a ``MyModel`` from a checkpoint written by ``train_and_save``.

    Args:
        model_path: Path to the checkpoint file; it must be a dict holding the
            weights under the ``'model_state'`` key.

    Returns:
        MyModel: The model with the loaded weights, switched to eval mode.
    """
    net = MyModel()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    net.load_state_dict(checkpoint['model_state'])
    # Module.eval() returns the module itself, so it can be returned directly.
    return net.eval()
def predict_user_image(img_qimage, model):
    """Classify a 28x28 drawing supplied as a QImage.

    Args:
        img_qimage: QImage from the drawing board; must be 28x28 pixels. It is
            converted to 8-bit grayscale first if it is in another format.
        model: Classifier called as ``model(tensor)`` whose output holds class
            probabilities (e.g. the model from ``load_trained_model``). It is
            not switched to eval mode here.

    Returns:
        tuple: ``(pred, probs)``. ``pred`` is the argmax class index and
        ``probs`` is the model output as a NumPy array rounded to 3 decimals
        (shape (1, num_classes) for this single-image batch).

    Raises:
        ValueError: If ``img_qimage`` is not a QImage or is not 28x28.
    """
    # Validate the type before touching QImage methods. Otherwise a non-QImage
    # would fail on .format() with an unrelated AttributeError.
    if not isinstance(img_qimage, QImage):
        raise ValueError("输入的图像必须是QImage对象")
    if img_qimage.width() != 28 or img_qimage.height() != 28:
        raise ValueError(
            f"expected a 28x28 image, got {img_qimage.width()}x{img_qimage.height()}"
        )

    # Ensure one byte per pixel (8-bit grayscale).
    if img_qimage.format() != QImage.Format_Grayscale8:
        img_qimage = img_qimage.convertToFormat(QImage.Format_Grayscale8)

    height, width = img_qimage.height(), img_qimage.width()
    # QImage rows may be padded, so read bytesPerLine * height bytes.
    # In PyQt, bits() returns a sip.voidptr whose size must be set before copying.
    stride = img_qimage.bytesPerLine()
    ptr = img_qimage.bits()
    ptr.setsize(stride * height)
    img_bytes = bytes(ptr)

    # Drop any per-row padding, keeping only the first `width` bytes of each row.
    img_array = (
        np.frombuffer(img_bytes, dtype=np.uint8)
        .reshape(height, stride)[:, :width]
        .astype(np.float32)
    )

    # Scale to [0, 1] and add batch and channel dimensions -> (1, 1, 28, 28).
    # The original author notes the board draws white on black, so no inversion.
    tensor_img = torch.tensor(img_array / 255.0).unsqueeze(0).unsqueeze(0).float()

    # Run on whatever device the model's parameters live on (CPU if unknown).
    params = model.parameters() if isinstance(model, torch.nn.Module) else iter(())
    first_param = next(params, None)
    if first_param is not None:
        tensor_img = tensor_img.to(first_param.device)

    with torch.no_grad():
        output = model(tensor_img)
    probs = np.round(output.detach().cpu().numpy(), 3)
    pred = torch.argmax(output).item()
    return pred, probs
if __name__ == "__main__":
    # Step 1: train the CNN on MNIST and write the checkpoint to disk.
    train_and_save()
    # Step 2: rebuild the model from the checkpoint that was just written.
    model = load_trained_model()