digit-recognition_0-9 / gradio_ui.py
WeekendZhou's picture
ui.py里面是python QT的界面,我不会JS。
b98d6e3 verified
import gradio as gr
import numpy as np
from PIL import Image
import torch
import test # 假设 test 模块包含预测逻辑
# 加载模型 (与 Qt 版本保持一致)
model = test.load_trained_model()
def predict_interface(sketch_image):
"""处理绘制图像的预测逻辑"""
if sketch_image is None:
return "请先绘制数字", {}
# 将 sketchpad 的 numpy 数组转换为模型需要的格式
img = Image.fromarray(sketch_image).convert('L') # 转换为灰度图
# 可能需要添加预处理步骤(根据 test.predict_user_image 的接口调整)
# 如果用原始 Qt 的预处理逻辑,这里可以复用 test 模块的函数
pred_class, probabilities = test.predict_user_image(img, model)
# 转换概率为字典供 Label 组件显示
prob_dict = {str(i): float(prob) for i, prob in enumerate(probabilities)}
return f"识别结果: {pred_class}", prob_dict
def clear_canvas():
"""清空画布的函数"""
return None, "识别结果: ", {}
# 构建 Gradio 界面
with gr.Blocks(title="手写数字识别") as demo:
gr.Markdown("# 手写数字识别系统")
with gr.Row():
# 手写板组件 (调整尺寸匹配原 Qt 设计)
sketch = gr.Sketchpad(
label="绘制区域",
shape=(750, 750),
brush_radius=15, # 根据原 Qt 的笔刷大小调整
image_mode="L", # 灰度模式
invert_colors=True # 反转颜色(白底黑字)
)
# 结果显示区域
with gr.Column():
result_label = gr.Label(label="概率分布", num_top_classes=5)
output_text = gr.Markdown("识别结果: ")
# 按钮行
with gr.Row():
clear_btn = gr.Button("清除", variant="secondary")
submit_btn = gr.Button("识别", variant="primary")
# 绑定交互事件
submit_btn.click(
fn=predict_interface,
inputs=sketch,
outputs=[output_text, result_label]
)
clear_btn.click(
fn=lambda: [None, "识别结果: ", None], # 清空所有输出
outputs=[sketch, output_text, result_label]
)
# 启动应用
if __name__ == "__main__":
demo.launch()