WeekendZhou commited on
Commit
b98d6e3
·
verified ·
1 Parent(s): fe83a34

ui.py里面是python QT的界面,我不会JS。

Browse files
Files changed (4) hide show
  1. gradio_ui.py +71 -0
  2. mnist_cnn.pth +3 -0
  3. test.py +147 -0
  4. ui.py +141 -0
gradio_ui.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import test # 假设 test 模块包含预测逻辑
6
+
7
+ # 加载模型 (与 Qt 版本保持一致)
8
+ model = test.load_trained_model()
9
+
10
+
11
+ def predict_interface(sketch_image):
12
+ """处理绘制图像的预测逻辑"""
13
+ if sketch_image is None:
14
+ return "请先绘制数字", {}
15
+
16
+ # 将 sketchpad 的 numpy 数组转换为模型需要的格式
17
+ img = Image.fromarray(sketch_image).convert('L') # 转换为灰度图
18
+
19
+ # 可能需要添加预处理步骤(根据 test.predict_user_image 的接口调整)
20
+ # 如果用原始 Qt 的预处理逻辑,这里可以复用 test 模块的函数
21
+ pred_class, probabilities = test.predict_user_image(img, model)
22
+
23
+ # 转换概率为字典供 Label 组件显示
24
+ prob_dict = {str(i): float(prob) for i, prob in enumerate(probabilities)}
25
+ return f"识别结果: {pred_class}", prob_dict
26
+
27
+
28
+ def clear_canvas():
29
+ """清空画布的函数"""
30
+ return None, "识别结果: ", {}
31
+
32
+
33
+ # 构建 Gradio 界面
34
+ with gr.Blocks(title="手写数字识别") as demo:
35
+ gr.Markdown("# 手写数字识别系统")
36
+
37
+ with gr.Row():
38
+ # 手写板组件 (调整尺寸匹配原 Qt 设计)
39
+ sketch = gr.Sketchpad(
40
+ label="绘制区域",
41
+ shape=(750, 750),
42
+ brush_radius=15, # 根据原 Qt 的笔刷大小调整
43
+ image_mode="L", # 灰度模式
44
+ invert_colors=True # 反转颜色(白底黑字)
45
+ )
46
+
47
+ # 结果显示区域
48
+ with gr.Column():
49
+ result_label = gr.Label(label="概率分布", num_top_classes=5)
50
+ output_text = gr.Markdown("识别结果: ")
51
+
52
+ # 按钮行
53
+ with gr.Row():
54
+ clear_btn = gr.Button("清除", variant="secondary")
55
+ submit_btn = gr.Button("识别", variant="primary")
56
+
57
+ # 绑定交互事件
58
+ submit_btn.click(
59
+ fn=predict_interface,
60
+ inputs=sketch,
61
+ outputs=[output_text, result_label]
62
+ )
63
+
64
+ clear_btn.click(
65
+ fn=lambda: [None, "识别结果: ", None], # 清空所有输出
66
+ outputs=[sketch, output_text, result_label]
67
+ )
68
+
69
+ # 启动应用
70
+ if __name__ == "__main__":
71
+ demo.launch()
mnist_cnn.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ffbe9357dc0bd7ccdc850bea88c2dd393ab02e691093550b8db97968417e13c
3
+ size 33056
test.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torchvision
4
+ from torchvision import datasets
5
+ from torchvision.transforms import ToTensor
6
+ from torch.utils.data import DataLoader
7
+ from torch.utils.data import DataLoader, Subset
8
+ from PIL import Image
9
+ from PyQt5.QtGui import QImage
10
+ from PyQt5.QtCore import QSize
11
+ from PyQt5.Qt import Qt
12
+ import numpy as np
13
+ import torch
14
+ import torchvision
15
+ from torchvision import datasets
16
+ from torchvision.transforms import ToTensor
17
+ from torch.utils.data import DataLoader
18
+
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+
21
+ np.random.seed(123456)
22
+ torch.manual_seed(123456)
23
+
24
+
25
+ class MyModel(torch.nn.Module):
26
+ def __init__(self):
27
+ super().__init__()
28
+ self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=2)
29
+ self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
30
+ self.fc1 = torch.nn.Linear(in_features=720, out_features=10)
31
+
32
+ def forward(self, x):
33
+ x = self.conv1(x)
34
+ x = torch.relu(x)
35
+ x = self.pool1(x)
36
+ x = torch.flatten(x, start_dim=1)
37
+ x = self.fc1(x)
38
+ x = torch.nn.functional.softmax(x, dim=1)
39
+ return x
40
+
41
+
42
+ def train_and_save(save_path='mnist_cnn.pth'):
43
+ # 数据加载
44
+ mnist = datasets.MNIST(
45
+ root="data",
46
+ train=True,
47
+ download=True,
48
+ transform=ToTensor()
49
+ )
50
+ subset = Subset(mnist, indices=range(60000))
51
+ loader = DataLoader(subset, batch_size=60000, shuffle=True)
52
+ x, y = next(iter(loader))
53
+
54
+ nepoch = 30
55
+ batch_size = 200
56
+ lr = 0.001
57
+
58
+ np.random.seed(123)
59
+ torch.manual_seed(123)
60
+
61
+ model = MyModel()
62
+ losses = []
63
+ opt = torch.optim.Adam(model.parameters(), lr=lr)
64
+
65
+ n = x.shape[0]
66
+ obs_id = np.arange(n) # [0, 1, ..., n-1]
67
+ # Run the whole data set `nepoch` times
68
+ for i in range(nepoch):
69
+ # Shuffle observation IDs
70
+ np.random.shuffle(obs_id)
71
+
72
+ # Update on mini-batches
73
+ for j in range(0, n, batch_size):
74
+ # Create mini-batch
75
+ x_mini_batch = x[obs_id[j:(j + batch_size)]]
76
+ y_mini_batch = y[obs_id[j:(j + batch_size)]]
77
+ # Compute loss
78
+ pred = model(x_mini_batch)
79
+ lossfn = torch.nn.NLLLoss()
80
+ loss = lossfn(torch.log(pred), y_mini_batch)
81
+ # Compute gradient and update parameters
82
+ opt.zero_grad()
83
+ loss.backward()
84
+ opt.step()
85
+ losses.append(loss.item())
86
+
87
+ if (j // batch_size) % 20 == 0:
88
+ print(f"epoch {i}, batch {j // batch_size}, loss = {loss.item()}")
89
+
90
+
91
+ torch.save({
92
+ 'model_state': model.state_dict(),
93
+ 'input_size': (1, 28, 28),
94
+ 'output_size': 10
95
+ }, save_path)
96
+
97
+ # 函数:加载已训练模型
98
+ def load_trained_model(model_path='mnist_cnn.pth'):
99
+ model = MyModel()
100
+ checkpoint = torch.load(model_path, map_location=device)
101
+ model.load_state_dict(checkpoint['model_state'])
102
+ model.eval()
103
+ return model
104
+
105
+ def predict_user_image(img_qimage,model):
106
+ """
107
+ :param img_qimage: 来自绘图板的QImage对象(需要是28x28大小)
108
+ :return: (预测结果, 概率分布数组)
109
+ """
110
+ # 确保图像是Grayscale8格式
111
+ if img_qimage.format() != QImage.Format_Grayscale8:
112
+ img_qimage = img_qimage.convertToFormat(QImage.Format_Grayscale8)
113
+
114
+ # 正确获取QImage二进制数据 (重要:PyQt和PySide的bits()方法差异)
115
+ # PyQt使用bits().tobytes(),PySide直接访问bits
116
+ if isinstance(img_qimage, QImage):
117
+ ptr = img_qimage.bits() # 获取内存指针
118
+ ptr.setsize(img_qimage.byteCount()) # 设置数据大小(PyQt需要)
119
+ img_bytes = bytes(ptr) # 转换为bytes
120
+ else:
121
+ raise ValueError("输入的图像必须是QImage对象")
122
+
123
+ # 转换为numpy数组 (注意dtype与数值范围)
124
+ img_array = np.frombuffer(img_bytes, dtype=np.uint8).reshape(28, 28).astype(np.float32)
125
+
126
+ # 转换为张量并归一化(黑底白字无需反转)
127
+ tensor_img = torch.tensor(img_array / 255.0).unsqueeze(0).unsqueeze(0).float()
128
+
129
+ # 预测逻辑
130
+ with torch.no_grad():
131
+ output = model(tensor_img)
132
+ probs = np.round(output.detach().cpu().numpy(), 3) # 修正概率计算
133
+ pred = torch.argmax(output).item()
134
+
135
+ return pred, probs
136
+
137
+ if __name__ == '__main__':
138
+ # 训练并保存模型
139
+ train_and_save()
140
+
141
+ model = load_trained_model()
142
+
143
+
144
+
145
+
146
+
147
+
ui.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QPlainTextEdit, QMainWindow,QHBoxLayout
3
+ from PyQt5.QtCore import Qt, QPoint
4
+ from PyQt5.QtGui import QPainter, QImage, QColor,QPen
5
+ import test
6
+ import numpy as np
7
+ import torch
8
+
9
+ # 在主程序(UI端)
10
+ from test import load_trained_model
11
+
12
+ model = load_trained_model()
13
+
14
+
15
+ class DrawingArea(QWidget):
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.setFixedSize(750 + 40, 750 + 40) # 增加边距显示网格
19
+ self.drawing = False
20
+ self.last_pos = QPoint()
21
+
22
+ # 实际绘图为28x28的画布
23
+ self.image = QImage(28, 28, QImage.Format_RGB888)
24
+ self.image.fill(Qt.black)
25
+
26
+ # 计算放缩比例
27
+ self.cell_size = 750 // 28
28
+
29
+ def paintEvent(self, event):
30
+ painter = QPainter(self)
31
+ painter.setRenderHint(QPainter.Antialiasing, False)
32
+
33
+ # 绘制放大后的图像
34
+ scaled_img = self.image.scaled(750, 750, Qt.KeepAspectRatio, Qt.FastTransformation)
35
+ painter.drawImage(20, 20, scaled_img)
36
+
37
+ # 绘制网格线
38
+ painter.setPen(QPen(Qt.gray, 1, Qt.SolidLine))
39
+ for i in range(29):
40
+ # 水平线
41
+ painter.drawLine(20, 20 + i * self.cell_size,
42
+ 20 + 750, 20 + i * self.cell_size)
43
+ # 垂直线
44
+ painter.drawLine(20 + i * self.cell_size, 20,
45
+ 20 + i * self.cell_size, 20 + 750)
46
+
47
+ def mousePressEvent(self, event):
48
+ if event.button() == Qt.LeftButton:
49
+ self.drawing = True
50
+ self.handleDrawing(event.pos())
51
+
52
+ def mouseMoveEvent(self, event):
53
+ if self.drawing:
54
+ self.handleDrawing(event.pos())
55
+
56
+ def mouseReleaseEvent(self, event):
57
+ if event.button() == Qt.LeftButton:
58
+ self.drawing = False
59
+
60
+ def handleDrawing(self, pos):
61
+ # 转换为画布坐标(减去边距)
62
+ x = pos.x() - 20
63
+ y = pos.y() - 20
64
+
65
+ # 当在画布范围内时进行处理
66
+ if 0 <= x < 750 and 0 <= y < 750:
67
+ # 转换到28x28坐标
68
+ col = x // self.cell_size
69
+ row = y // self.cell_size
70
+
71
+ # 防止重复绘制同一位置
72
+ if (col, row) != self.last_pos:
73
+ self.last_pos = (col, row)
74
+ painter = QPainter(self.image)
75
+ painter.setPen(Qt.white)
76
+ painter.drawPoint(col, row)
77
+ self.update()
78
+
79
+ def get_image(self):
80
+ return self.image.convertToFormat(QImage.Format_Grayscale8)
81
+
82
+ def clear_image(self):
83
+ self.image.fill(Qt.black)
84
+ self.update()
85
+
86
+
87
+ class MainWindow(QMainWindow):
88
+ def __init__(self):
89
+ super().__init__()
90
+ self.init_ui()
91
+
92
+ def init_ui(self):
93
+ # 窗口
94
+ self.setWindowTitle("手写识别")
95
+ self.setFixedSize(850, 950)
96
+
97
+ # 布局
98
+ main_widget = QWidget()
99
+ self.setCentralWidget(main_widget)
100
+ layout = QVBoxLayout(main_widget)
101
+
102
+ # 绘图
103
+ self.drawing_area = DrawingArea()
104
+ layout.addWidget(self.drawing_area)
105
+
106
+ # 按钮
107
+ btn_layout = QHBoxLayout()
108
+ self.clear_btn = QPushButton("清除")
109
+ self.recognize_btn = QPushButton("识别")
110
+ btn_layout.addWidget(self.clear_btn)
111
+ btn_layout.addWidget(self.recognize_btn)
112
+
113
+ # 结果
114
+ self.prob_label = QLabel("概率分布:")
115
+ self.result_label = QLabel("识别结果:")
116
+
117
+ # 组装
118
+ layout.addLayout(btn_layout)
119
+ layout.addWidget(self.prob_label)
120
+ layout.addWidget(self.result_label)
121
+ # 信号连接
122
+ self.clear_btn.clicked.connect(self.drawing_area.clear_image)
123
+ self.recognize_btn.clicked.connect(self.recognize)
124
+
125
+
126
+
127
+ def recognize(self):
128
+ # 获取原始绘图区图像
129
+ qimg = self.drawing_area.get_image()
130
+ #预测
131
+ pred_class, probabilities = test.predict_user_image(qimg,model)
132
+
133
+ self.prob_label.setText(f"概率分布: {probabilities}")
134
+ self.result_label.setText(f"识别结果: {pred_class}")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ app = QApplication(sys.argv)
139
+ window = MainWindow()
140
+ window.show()
141
+ sys.exit(app.exec_())