# Uploaded by WeekendZhou (commit b98d6e3, verified).
# Author's note (translated): ui.py contains the Python Qt (PyQt) interface; "I don't know JS."
import numpy as np
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Subset
from PIL import Image
from PyQt5.QtGui import QImage
from PyQt5.QtCore import QSize
from PyQt5.Qt import Qt
import numpy as np
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
# Prefer the GPU when one is available; otherwise run everything on CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed both RNGs so repeated runs are reproducible.
np.random.seed(123456)
torch.manual_seed(123456)
class MyModel(torch.nn.Module):
    """Small CNN for 28x28 grayscale digits: conv -> relu -> maxpool -> fc -> softmax.

    Output is a (batch, 10) tensor of class probabilities (each row sums to 1).
    """

    def __init__(self):
        super().__init__()
        # 28x28 input -> conv(k=5, s=2) gives 12x12x20 -> pool(2) gives 6x6x20 = 720 features.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=2)
        self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
        self.fc1 = torch.nn.Linear(in_features=720, out_features=10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to a (batch, 10) probability tensor."""
        hidden = self.pool1(torch.relu(self.conv1(x)))
        flat = hidden.flatten(start_dim=1)
        return torch.nn.functional.softmax(self.fc1(flat), dim=1)
def train_and_save(save_path='mnist_cnn.pth', nepoch=30, batch_size=200, lr=0.001):
    """Train MyModel on the MNIST training split and save a checkpoint.

    :param save_path: file path for the torch.save checkpoint dict
        (keys: 'model_state', 'input_size', 'output_size').
    :param nepoch: number of passes over the full training set.
    :param batch_size: mini-batch size for each optimizer step.
    :param lr: Adam learning rate.
    :return: list of per-mini-batch loss values (floats), for plotting/inspection.
    """
    # Data loading: grab the whole training split as one big tensor pair.
    # (Downloads to ./data on first use.)
    mnist = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )
    # The original wrapped this in Subset(mnist, range(60000)), which is the
    # full dataset anyway — load it directly in a single batch instead.
    loader = DataLoader(mnist, batch_size=len(mnist), shuffle=True)
    x, y = next(iter(loader))  # x: (N, 1, 28, 28) floats in [0,1]; y: (N,) labels

    # Fixed seeds so training is reproducible run-to-run.
    np.random.seed(123)
    torch.manual_seed(123)

    # Keep model and data on the same device (the original left both on CPU
    # even when the module-level `device` was CUDA).
    model = MyModel().to(device)
    x = x.to(device)
    y = y.to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    lossfn = torch.nn.NLLLoss()  # hoisted: no need to rebuild it per mini-batch
    losses = []
    n = x.shape[0]
    obs_id = np.arange(n)  # [0, 1, ..., n-1]

    # Run the whole data set `nepoch` times.
    for i in range(nepoch):
        # Shuffle observation IDs so mini-batches differ each epoch.
        np.random.shuffle(obs_id)
        # Update on mini-batches.
        for j in range(0, n, batch_size):
            idx = obs_id[j:j + batch_size]
            pred = model(x[idx])  # (batch, 10) softmax probabilities
            # Clamp before log: softmax can underflow to exactly 0, and
            # log(0) = -inf would poison the gradients with NaNs.
            loss = lossfn(torch.log(pred.clamp_min(1e-12)), y[idx])
            # Compute gradient and update parameters.
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
            if (j // batch_size) % 20 == 0:
                print(f"epoch {i}, batch {j // batch_size}, loss = {loss.item()}")

    torch.save({
        'model_state': model.state_dict(),
        'input_size': (1, 28, 28),
        'output_size': 10
    }, save_path)
    return losses
def load_trained_model(model_path='mnist_cnn.pth'):
    """Rebuild MyModel and restore weights from a checkpoint written by train_and_save.

    :param model_path: path to a torch.save dict containing key 'model_state'.
    :return: the model in eval mode, placed on the module-level `device`.
    """
    model = MyModel()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state'])
    # Fix: the original mapped the checkpoint tensors to `device` but never
    # moved the model itself, so GPU inputs would mismatch CPU weights.
    model.to(device)
    model.eval()
    return model
def predict_user_image(img_qimage, model):
    """Classify a user-drawn digit from the drawing board.

    :param img_qimage: QImage from the drawing board (must be 28x28;
        white strokes on a black background, matching MNIST polarity).
    :param model: trained MyModel in eval mode.
    :return: (predicted digit int, probability array of shape (1, 10) rounded to 3 dp)
    :raises ValueError: if the input is not a QImage.
    """
    # Validate BEFORE touching QImage methods: the original ran the isinstance
    # check after .format()/.convertToFormat(), so a non-QImage input crashed
    # with AttributeError and the ValueError branch was effectively dead code.
    if not isinstance(img_qimage, QImage):
        raise ValueError("输入的图像必须是QImage对象")
    # Ensure 8-bit grayscale so each pixel occupies exactly one byte.
    if img_qimage.format() != QImage.Format_Grayscale8:
        img_qimage = img_qimage.convertToFormat(QImage.Format_Grayscale8)
    # Extract the raw pixel buffer. PyQt's bits() returns a sip pointer that
    # needs setsize() before conversion (PySide exposes the buffer directly;
    # this code follows the PyQt convention used by the imports above).
    ptr = img_qimage.bits()
    ptr.setsize(img_qimage.byteCount())
    img_bytes = bytes(ptr)
    # 28-pixel Grayscale8 rows are already 4-byte aligned (28 % 4 == 0), so
    # the buffer is exactly 784 bytes with no per-row padding to strip.
    img_array = np.frombuffer(img_bytes, dtype=np.uint8).reshape(28, 28).astype(np.float32)
    # Scale to [0, 1]; black background / white digit needs no inversion.
    tensor_img = torch.tensor(img_array / 255.0).unsqueeze(0).unsqueeze(0).float()
    # Inference: model outputs (1, 10) softmax probabilities.
    with torch.no_grad():
        output = model(tensor_img)
        probs = np.round(output.detach().cpu().numpy(), 3)
        pred = torch.argmax(output).item()
    return pred, probs
if __name__ == '__main__':
    # Train and save the model, then reload it as a smoke test of the checkpoint.
    train_and_save()
    model = load_trained_model()