WeekendZhou commited on
Commit
6989988
·
1 Parent(s): e9ff88c

Add gradio ui

Browse files
Files changed (1) hide show
  1. gradio_ui.py +71 -0
gradio_ui.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ import test # 假设 test 模块包含预测逻辑
6
+
7
+ # 加载模型 (与 Qt 版本保持一致)
8
+ model = test.load_trained_model()
9
+
10
+
11
+ def predict_interface(sketch_image):
12
+ """处理绘制图像的预测逻辑"""
13
+ if sketch_image is None:
14
+ return "请先绘制数字", {}
15
+
16
+ # 将 sketchpad 的 numpy 数组转换为模型需要的格式
17
+ img = Image.fromarray(sketch_image).convert('L') # 转换为灰度图
18
+
19
+ # 可能需要添加预处理步骤(根据 test.predict_user_image 的接口调整)
20
+ # 如果用原始 Qt 的预处理逻辑,这里可以复用 test 模块的函数
21
+ pred_class, probabilities = test.predict_user_image(img, model)
22
+
23
+ # 转换概率为字典供 Label 组件显示
24
+ prob_dict = {str(i): float(prob) for i, prob in enumerate(probabilities)}
25
+ return f"识别结果: {pred_class}", prob_dict
26
+
27
+
28
+ def clear_canvas():
29
+ """清空画布的函数"""
30
+ return None, "识别结果: ", {}
31
+
32
+
33
+ # 构建 Gradio 界面
34
+ with gr.Blocks(title="手写数字识别") as demo:
35
+ gr.Markdown("# 手写数字识别系统")
36
+
37
+ with gr.Row():
38
+ # 手写板组件 (调整尺寸匹配原 Qt 设计)
39
+ sketch = gr.Sketchpad(
40
+ label="绘制区域",
41
+ shape=(750, 750),
42
+ brush_radius=15, # 根据原 Qt 的笔刷大小调整
43
+ image_mode="L", # 灰度模式
44
+ invert_colors=True # 反转颜色(白底黑字)
45
+ )
46
+
47
+ # 结果显示区域
48
+ with gr.Column():
49
+ result_label = gr.Label(label="概率分布", num_top_classes=5)
50
+ output_text = gr.Markdown("识别结果: ")
51
+
52
+ # 按钮行
53
+ with gr.Row():
54
+ clear_btn = gr.Button("清除", variant="secondary")
55
+ submit_btn = gr.Button("识别", variant="primary")
56
+
57
+ # 绑定交互事件
58
+ submit_btn.click(
59
+ fn=predict_interface,
60
+ inputs=sketch,
61
+ outputs=[output_text, result_label]
62
+ )
63
+
64
+ clear_btn.click(
65
+ fn=lambda: [None, "识别结果: ", None], # 清空所有输出
66
+ outputs=[sketch, output_text, result_label]
67
+ )
68
+
69
+ # 启动应用
70
+ if __name__ == "__main__":
71
+ demo.launch()