|
|
import sys
|
|
|
from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QPushButton, QLabel, QPlainTextEdit, QMainWindow,QHBoxLayout
|
|
|
from PyQt5.QtCore import Qt, QPoint
|
|
|
from PyQt5.QtGui import QPainter, QImage, QColor,QPen
|
|
|
import test
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
|
|
|
from test import load_trained_model
|
|
|
|
|
|
# Load the trained classifier once, at import time (before any QApplication
# exists); MainWindow.recognize() passes this module-level object to
# test.predict_user_image for every prediction.
model = load_trained_model()
|
|
|
|
|
|
|
|
|
class DrawingArea(QWidget):
    """Canvas on which the user draws on a 28x28 pixel grid.

    The drawing lives in a 28x28 ``QImage`` (white strokes on a black
    background). It is shown magnified, with a grey grid overlay, inside a
    fixed margin. Each image pixel is one on-screen grid cell.
    """

    GRID_CELLS = 28      # width/height of the backing image, in pixels
    CANVAS_SIZE = 750    # nominal on-screen size reserved for the board
    MARGIN = 20          # gap between the widget edge and the board

    def __init__(self):
        super().__init__()
        self.setFixedSize(self.CANVAS_SIZE + 2 * self.MARGIN,
                          self.CANVAS_SIZE + 2 * self.MARGIN)

        self.drawing = False
        # Last cell painted in the current stroke as (col, row), or None when
        # no stroke is in progress (so the next point starts a new segment).
        self.last_pos = None

        self.image = QImage(self.GRID_CELLS, self.GRID_CELLS, QImage.Format_RGB888)
        self.image.fill(Qt.black)

        # Use an integer cell size so every image pixel covers a whole number of
        # screen pixels. The displayed image, the grid lines and the mouse
        # hit-test then all agree on where each cell is.
        self.cell_size = self.CANVAS_SIZE // self.GRID_CELLS
        # Side length of the area actually covered by cells (26 * 28 = 728),
        # slightly smaller than CANVAS_SIZE.
        self.board_size = self.cell_size * self.GRID_CELLS

    def paintEvent(self, event):
        """Draw the magnified image followed by the grid overlay."""
        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing, False)

            # Scale to exactly board_size so each image pixel becomes one grid
            # cell; FastTransformation applies no smoothing, keeping cells crisp.
            scaled_img = self.image.scaled(self.board_size, self.board_size,
                                           Qt.KeepAspectRatio, Qt.FastTransformation)
            painter.drawImage(self.MARGIN, self.MARGIN, scaled_img)

            painter.setPen(QPen(Qt.gray, 1, Qt.SolidLine))
            end = self.MARGIN + self.board_size
            for i in range(self.GRID_CELLS + 1):
                offset = self.MARGIN + i * self.cell_size
                painter.drawLine(self.MARGIN, offset, end, offset)  # horizontal
                painter.drawLine(offset, self.MARGIN, offset, end)  # vertical
        finally:
            painter.end()

    def mousePressEvent(self, event):
        """Start a new stroke when the left button is pressed."""
        if event.button() == Qt.LeftButton:
            self.drawing = True
            # A new stroke must not be joined to the end of the previous one.
            self.last_pos = None
            self.handleDrawing(event.pos())

    def mouseMoveEvent(self, event):
        """Extend the current stroke while the left button is held."""
        if self.drawing:
            self.handleDrawing(event.pos())

    def mouseReleaseEvent(self, event):
        """Finish the current stroke when the left button is released."""
        if event.button() == Qt.LeftButton:
            self.drawing = False
            self.last_pos = None

    def handleDrawing(self, pos):
        """Paint the cell under widget coordinate ``pos`` white.

        Mouse-move events are not delivered for every pixel, so a fast drag can
        jump several cells at once. The new cell is therefore joined to the
        previously painted cell with a line to keep the stroke continuous.
        """
        x = pos.x() - self.MARGIN
        y = pos.y() - self.MARGIN

        if not (0 <= x < self.board_size and 0 <= y < self.board_size):
            # Pointer left the board: break the stroke so re-entering elsewhere
            # does not draw a line across the board.
            self.last_pos = None
            return

        cell = (x // self.cell_size, y // self.cell_size)
        if cell == self.last_pos:
            return  # still inside the same cell; nothing new to paint

        painter = QPainter(self.image)
        try:
            painter.setPen(Qt.white)
            if self.last_pos is not None:
                painter.drawLine(QPoint(*self.last_pos), QPoint(*cell))
            painter.drawPoint(*cell)
        finally:
            painter.end()

        self.last_pos = cell
        self.update()

    def get_image(self):
        """Return an 8-bit grayscale copy of the 28x28 drawing."""
        return self.image.convertToFormat(QImage.Format_Grayscale8)

    def clear_image(self):
        """Erase the drawing back to solid black and repaint."""
        self.image.fill(Qt.black)
        # Forget the last painted cell so the first click after clearing is
        # never mistaken for "same cell as before" and skipped.
        self.last_pos = None
        self.update()
|
|
|
|
|
|
|
|
|
class MainWindow(QMainWindow):
    """Main window: drawing canvas, Clear/Recognize buttons and result labels."""

    def __init__(self):
        super().__init__()
        self.init_ui()

    def init_ui(self):
        """Build the widgets and layout and connect the button signals."""
        self.setWindowTitle("手写识别")  # "Handwriting recognition"
        self.setFixedSize(850, 950)

        main_widget = QWidget()
        self.setCentralWidget(main_widget)
        layout = QVBoxLayout(main_widget)

        self.drawing_area = DrawingArea()
        layout.addWidget(self.drawing_area)

        btn_layout = QHBoxLayout()
        self.clear_btn = QPushButton("清除")      # "Clear"
        self.recognize_btn = QPushButton("识别")  # "Recognize"
        btn_layout.addWidget(self.clear_btn)
        btn_layout.addWidget(self.recognize_btn)

        self.prob_label = QLabel("概率分布:")     # "Probability distribution:"
        # The probability text can be long; wrap it rather than letting it
        # overflow the fixed-size window.
        self.prob_label.setWordWrap(True)
        self.result_label = QLabel("识别结果:")   # "Recognition result:"

        layout.addLayout(btn_layout)
        layout.addWidget(self.prob_label)
        layout.addWidget(self.result_label)

        self.clear_btn.clicked.connect(self.drawing_area.clear_image)
        self.recognize_btn.clicked.connect(self.recognize)

    @staticmethod
    def _format_probabilities(probabilities):
        """Render a probability vector as ``"0: 0.012  1: 0.950 ..."``.

        Falls back to ``str(probabilities)`` when the value cannot be viewed
        as a flat float array.
        """
        try:
            if isinstance(probabilities, torch.Tensor):
                # Drop autograd tracking and move to host memory before NumPy
                # conversion; otherwise the conversion raises.
                probabilities = probabilities.detach().cpu()
            values = np.asarray(probabilities, dtype=float).ravel()
        except (TypeError, ValueError, RuntimeError):
            return str(probabilities)
        # NOTE(review): assumes entry i is the score of class index i; confirm
        # against test.predict_user_image.
        return "  ".join(f"{i}: {p:.3f}" for i, p in enumerate(values))

    def recognize(self):
        """Classify the current drawing and show the result in the labels."""
        qimg = self.drawing_area.get_image()
        try:
            pred_class, probabilities = test.predict_user_image(qimg, model)
        except Exception as exc:
            # This runs as a Qt slot: with PyQt5 >= 5.5 an exception escaping
            # it calls qFatal() and aborts the whole application. Log it and
            # report it in the UI instead.
            import traceback
            traceback.print_exc()
            self.prob_label.setText("概率分布:")
            self.result_label.setText(f"识别结果: {type(exc).__name__}: {exc}")
            return

        self.prob_label.setText(f"概率分布: {self._format_probabilities(probabilities)}")
        self.result_label.setText(f"识别结果: {pred_class}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # The QApplication must exist before any widget is constructed.
    qt_app = QApplication(sys.argv)

    main_window = MainWindow()
    main_window.show()

    # Run the Qt event loop; its return code becomes the process exit status.
    status = qt_app.exec_()
    sys.exit(status)
|
|
|
|