import gradio as gr import numpy as np from PIL import Image import torch import test # 假设 test 模块包含预测逻辑 # 加载模型 (与 Qt 版本保持一致) model = test.load_trained_model() def predict_interface(sketch_image): """处理绘制图像的预测逻辑""" if sketch_image is None: return "请先绘制数字", {} # 将 sketchpad 的 numpy 数组转换为模型需要的格式 img = Image.fromarray(sketch_image).convert('L') # 转换为灰度图 # 可能需要添加预处理步骤(根据 test.predict_user_image 的接口调整) # 如果用原始 Qt 的预处理逻辑,这里可以复用 test 模块的函数 pred_class, probabilities = test.predict_user_image(img, model) # 转换概率为字典供 Label 组件显示 prob_dict = {str(i): float(prob) for i, prob in enumerate(probabilities)} return f"识别结果: {pred_class}", prob_dict def clear_canvas(): """清空画布的函数""" return None, "识别结果: ", {} # 构建 Gradio 界面 with gr.Blocks(title="手写数字识别") as demo: gr.Markdown("# 手写数字识别系统") with gr.Row(): # 手写板组件 (调整尺寸匹配原 Qt 设计) sketch = gr.Sketchpad( label="绘制区域", shape=(750, 750), brush_radius=15, # 根据原 Qt 的笔刷大小调整 image_mode="L", # 灰度模式 invert_colors=True # 反转颜色(白底黑字) ) # 结果显示区域 with gr.Column(): result_label = gr.Label(label="概率分布", num_top_classes=5) output_text = gr.Markdown("识别结果: ") # 按钮行 with gr.Row(): clear_btn = gr.Button("清除", variant="secondary") submit_btn = gr.Button("识别", variant="primary") # 绑定交互事件 submit_btn.click( fn=predict_interface, inputs=sketch, outputs=[output_text, result_label] ) clear_btn.click( fn=lambda: [None, "识别结果: ", None], # 清空所有输出 outputs=[sketch, output_text, result_label] ) # 启动应用 if __name__ == "__main__": demo.launch()