"""Handwritten-digit drawing pad: draw on a 28x28 canvas and run a trained model on it."""
import sys
import traceback

from PyQt5.QtWidgets import (
    QApplication,
    QWidget,
    QVBoxLayout,
    QPushButton,
    QLabel,
    QPlainTextEdit,
    QMainWindow,
    QHBoxLayout,
)
from PyQt5.QtCore import Qt, QPoint
from PyQt5.QtGui import QPainter, QImage, QColor, QPen
import numpy as np
import torch

# NOTE(review): the project module is named `test`, which is also the name of a
# standard-library package; this only resolves to the local file when the script's
# directory comes first on sys.path (e.g. when run directly). Consider renaming it.
import test

# In the main program (UI side): load the trained model once at import time so
# MainWindow.recognize can reuse it for every prediction.
from test import load_trained_model

model = load_trained_model()


class DrawingArea(QWidget):
    """A 28x28 drawing canvas shown enlarged with a pixel grid on top."""

    GRID_SIZE = 28  # resolution of the backing image (width == height)
    DISPLAY_SIZE = 750  # nominal on-screen size used to derive the cell size
    MARGIN = 20  # blank border around the canvas, in widget pixels

    def __init__(self):
        super().__init__()
        # Extra room around the canvas so the outer grid lines are visible.
        self.setFixedSize(self.DISPLAY_SIZE + 2 * self.MARGIN,
                          self.DISPLAY_SIZE + 2 * self.MARGIN)

        self.drawing = False
        # Last painted cell as (col, row), or None when no stroke is in progress.
        self.last_pos = None

        # The actual drawing surface is a 28x28 image.
        self.image = QImage(self.GRID_SIZE, self.GRID_SIZE, QImage.Format_RGB888)
        self.image.fill(Qt.black)

        # Screen pixels per image pixel. The displayed canvas is an exact multiple
        # of this so the grid, the hit-test and the scaled image all line up.
        self.cell_size = self.DISPLAY_SIZE // self.GRID_SIZE
        self.canvas_px = self.cell_size * self.GRID_SIZE

    def paintEvent(self, event):
        """Draw the enlarged image and overlay the cell grid."""
        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing, False)

        # Upscale by an integer factor with nearest-neighbour sampling so each
        # image pixel becomes one grid cell.
        scaled_img = self.image.scaled(self.canvas_px, self.canvas_px,
                                       Qt.KeepAspectRatio, Qt.FastTransformation)
        painter.drawImage(self.MARGIN, self.MARGIN, scaled_img)

        # Grid lines: GRID_SIZE + 1 horizontal and vertical lines.
        painter.setPen(QPen(Qt.gray, 1, Qt.SolidLine))
        far_edge = self.MARGIN + self.canvas_px
        for i in range(self.GRID_SIZE + 1):
            offset = self.MARGIN + i * self.cell_size
            painter.drawLine(self.MARGIN, offset, far_edge, offset)  # horizontal
            painter.drawLine(offset, self.MARGIN, offset, far_edge)  # vertical
        painter.end()

    def mousePressEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.drawing = True
            # A new stroke must not be connected to (or blocked by) the previous one.
            self.last_pos = None
            self.handleDrawing(event.pos())

    def mouseMoveEvent(self, event):
        if self.drawing:
            self.handleDrawing(event.pos())

    def mouseReleaseEvent(self, event):
        if event.button() == Qt.LeftButton:
            self.drawing = False

    def handleDrawing(self, pos):
        """Paint the image cell under widget position ``pos`` white.

        While a stroke is in progress, the previously painted cell is joined to
        the current one with a line so fast mouse movement leaves no gaps.
        """
        # Convert to canvas coordinates (remove the margin).
        x = pos.x() - self.MARGIN
        y = pos.y() - self.MARGIN

        # Outside the canvas: ignore, and start a fresh segment on re-entry.
        if not (0 <= x < self.canvas_px and 0 <= y < self.canvas_px):
            self.last_pos = None
            return

        # Map to 28x28 image coordinates; always within [0, GRID_SIZE).
        cell = (x // self.cell_size, y // self.cell_size)

        # Still inside the same cell: nothing new to paint.
        if cell == self.last_pos:
            return

        painter = QPainter(self.image)
        try:
            painter.setPen(Qt.white)
            if self.last_pos is None:
                painter.drawPoint(cell[0], cell[1])
            else:
                painter.drawLine(self.last_pos[0], self.last_pos[1], cell[0], cell[1])
        finally:
            painter.end()

        self.last_pos = cell
        self.update()

    def get_image(self):
        """Return the 28x28 canvas converted to an 8-bit grayscale QImage."""
        return self.image.convertToFormat(QImage.Format_Grayscale8)

    def clear_image(self):
        """Reset the canvas to black and forget the last painted cell."""
        self.image.fill(Qt.black)
        self.last_pos = None
        self.update()


class MainWindow(QMainWindow):
    """Main window: drawing area, clear/recognize buttons and result labels."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        # Window
        self.setWindowTitle("手写识别")
        self.setFixedSize(850, 950)

        # Layout
        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        # Drawing area
        self.drawing_area = DrawingArea()
        layout.addWidget(self.drawing_area)

        # Buttons
        btn_layout = QHBoxLayout()
        self.clear_btn = QPushButton("清除")
        self.recognize_btn = QPushButton("识别")
        btn_layout.addWidget(self.clear_btn)
        btn_layout.addWidget(self.recognize_btn)

        # Result labels
        self.prob_label = QLabel("概率分布:")
        self.result_label = QLabel("识别结果:")

        # Assemble
        layout.addLayout(btn_layout)
        layout.addWidget(self.prob_label)
        layout.addWidget(self.result_label)

        # Signal connections
        self.clear_btn.clicked.connect(self.drawing_area.clear_image)
        self.recognize_btn.clicked.connect(self.recognize)

    def recognize(self):
        """Run the module-level model on the current drawing and show the result."""
        # Grab the current canvas image.
        qimg = self.drawing_area.get_image()

        # Predict. This runs as a Qt slot: an unhandled exception here makes
        # PyQt5 abort the whole application, so report the failure instead.
        try:
            pred_class, probabilities = test.predict_user_image(qimg, model)
        except Exception as exc:
            traceback.print_exc()
            self.result_label.setText(f"识别结果: 识别失败 ({exc})")
            return

        self.prob_label.setText(f"概率分布: {probabilities}")
        self.result_label.setText(f"识别结果: {pred_class}")


if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())