|
|
import gradio as gr
|
|
|
import numpy as np
|
|
|
from PIL import Image
|
|
|
import torch
|
|
|
import test
|
|
|
|
|
|
|
|
|
# Load the model once at import time; predict_interface reads this module-level
# `model` on every request instead of reloading it.
# NOTE(review): test.load_trained_model is defined outside this file; presumably
# it restores the trained digit-classifier weights — confirm.
model = test.load_trained_model()
|
|
|
|
|
|
|
|
|
def predict_interface(sketch_image):
    """Run digit recognition on the image drawn in the sketchpad.

    Args:
        sketch_image: The sketchpad value as delivered by Gradio, or ``None``
            when nothing has been submitted. Assumed to be a numpy array
            (Gradio 3.x ``Sketchpad`` default ``type="numpy"``) — confirm.

    Returns:
        A ``(markdown_text, prob_dict)`` tuple. ``markdown_text`` is the
        message shown to the user. ``prob_dict`` maps each class index (as
        ``str``) to its probability, in the form ``gr.Label`` expects. It is
        empty when there is nothing to classify.
    """
    if sketch_image is None:
        return "请先绘制数字", {}

    # An all-zero array has no strokes to classify. With invert_colors=True,
    # an untouched canvas is expected to arrive this way (confirm for the
    # installed Gradio version). Without this guard the model would "recognise"
    # a digit in an empty image.
    if not np.any(sketch_image):
        return "请先绘制数字", {}

    # Collapse to a single 8-bit grayscale channel before inference.
    img = Image.fromarray(sketch_image).convert('L')

    # test.predict_user_image is defined elsewhere. It returns the predicted
    # class and a per-class probability sequence.
    pred_class, probabilities = test.predict_user_image(img, model)

    # float() turns numpy/torch scalars into plain Python floats for gr.Label.
    prob_dict = {str(i): float(prob) for i, prob in enumerate(probabilities)}

    return f"识别结果: {pred_class}", prob_dict
|
|
|
|
|
|
|
|
|
def clear_canvas():
    """Produce the reset values for (sketchpad, result text, probability label).

    Returns:
        A 3-tuple: ``None`` to empty the sketchpad, the default result
        prompt, and a fresh empty dict to clear the probability label.
    """
    blank_sketch = None
    default_message = "识别结果: "
    no_probabilities = {}
    return blank_sketch, default_message, no_probabilities
|
|
|
|
|
|
|
|
|
|
|
|
# Build the UI: a drawing canvas on the left, the prediction outputs on the
# right, and a row of action buttons underneath.
with gr.Blocks(title="手写数字识别") as demo:
    gr.Markdown("# 手写数字识别系统")

    with gr.Row():
        # These are Gradio 3.x Sketchpad options. Per the Gradio 3.x docs,
        # `shape` is the (width, height) the drawing is cropped/resized to
        # before it reaches predict_interface.
        sketch = gr.Sketchpad(
            label="绘制区域",
            shape=(750, 750),
            brush_radius=15,
            image_mode="L",
            invert_colors=True,
        )

        with gr.Column():
            result_label = gr.Label(label="概率分布", num_top_classes=5)
            output_text = gr.Markdown("识别结果: ")

    with gr.Row():
        clear_btn = gr.Button("清除", variant="secondary")
        submit_btn = gr.Button("识别", variant="primary")

    # Recognise the current drawing and show the text plus probabilities.
    submit_btn.click(
        fn=predict_interface,
        inputs=sketch,
        outputs=[output_text, result_label],
    )

    # Reset canvas, text and label via the dedicated clear_canvas function
    # instead of an inline lambda that duplicated (and diverged from) it.
    clear_btn.click(
        fn=clear_canvas,
        outputs=[sketch, output_text, result_label],
    )


if __name__ == "__main__":
    demo.launch()
|
|
|
|