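# WeekendZhou - "add cpu channel" - 711e925
# (apparently hosting-page header residue: author, commit message, commit id; not part of the program)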
import gradio as gr
import numpy as np
import os
import PIL
from PIL import ImageOps
import test_1
import torch
# Load the trained model once at import time; predict_digit reuses this
# module-level instance for every request.
model = test_1.load_trained_model()
def convert_to_28x28_grid(canvas_image):
    """Downsample a drawn canvas into a 28x28 grid of uint8 activations.

    The canvas is divided into 28x28 equally sized cells. The cell size is
    the image height/width floor-divided by 28, so leftover pixels on the
    bottom/right edge are ignored (a 750x750 canvas uses 26x26 cells).
    Each cell's activation is the mean of the inverted pixel values
    (``255 - pixel``), multiplied by 2.5 to boost contrast and clipped to
    the 0-255 range.

    Args:
        canvas_image: Anything ``np.asarray`` can turn into a 2-D
            (grayscale) or 3-D ``(height, width, channels)`` array with
            values in 0-255, e.g. a PIL image.

    Returns:
        A ``(28, 28)`` array of dtype ``np.uint8``; entry ``[i, j]`` is the
        activation of the cell in grid row ``i`` and grid column ``j``.

    Raises:
        ValueError: If the input is not 2-D/3-D or is smaller than 28
            pixels along either spatial axis.
    """
    GRID_SIZE = 28
    CONTRAST_GAIN = 2.5
    # ITU-R BT.601 luma weights. The previous all-zero weights turned every
    # colour image into a black one.
    LUMA_WEIGHTS = (0.299, 0.587, 0.114)

    pixels = np.asarray(canvas_image)

    # Reduce to a single 2-D float plane.
    if pixels.ndim == 3:
        if pixels.shape[2] >= 3:
            # RGB or RGBA: drop any alpha channel and apply the luma weights.
            gray = np.dot(pixels[..., :3].astype(np.float64), LUMA_WEIGHTS)
        else:
            # (H, W, 1) or (H, W, 2): the first channel is the intensity.
            gray = pixels[..., 0].astype(np.float64)
    elif pixels.ndim == 2:
        gray = pixels.astype(np.float64)
    else:
        raise ValueError(
            f"expected a 2-D or 3-D image array, got {pixels.ndim} dimension(s)"
        )

    # Derive the cell size from the actual image instead of assuming 750x750.
    height, width = gray.shape
    cell_h = height // GRID_SIZE
    cell_w = width // GRID_SIZE
    if cell_h == 0 or cell_w == 0:
        raise ValueError(
            f"image of size {height}x{width} is smaller than the "
            f"{GRID_SIZE}x{GRID_SIZE} target grid"
        )

    # View the trimmed image as (grid_row, cell_row, grid_col, cell_col) so a
    # single mean over the two cell axes replaces the nested Python loops.
    trimmed = gray[:cell_h * GRID_SIZE, :cell_w * GRID_SIZE]
    cells = trimmed.reshape(GRID_SIZE, cell_h, GRID_SIZE, cell_w)
    activation = (255.0 - cells).mean(axis=(1, 3))  # invert, then average

    boosted = np.clip(activation * CONTRAST_GAIN, 0.0, 255.0)
    # Stage through float32 (as the original grid buffer did) so tiny float64
    # rounding noise does not change the truncated uint8 value.
    return boosted.astype(np.float32).astype(np.uint8)
def predict_digit(input_image):
    """Run the model on the sketched image and format the result for the UI.

    Args:
        input_image: The sketch passed in by the UI, or ``None`` when nothing
            has been drawn. It is handed to ``ImageOps.invert``, so a PIL
            image is expected.

    Returns:
        A ``(message, probabilities)`` pair: a result string and a dict that
        maps each class index (as a string) to its probability as a float.
        When ``input_image`` is ``None`` the pair is a prompt message and an
        empty dict.
    """
    # Guard clause: nothing to classify yet.
    if input_image is None:
        return "请先绘制数字", {}

    # Use the GPU when one is available, otherwise fall back to the CPU.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model.to(device)

    # Invert the sketch, then reduce it to the 28x28 grid.
    inverted = ImageOps.invert(input_image)
    digit_matrix = convert_to_28x28_grid(inverted)

    # Model inference.
    pred_class, probabilities = test_1.predict_user_image(digit_matrix, model, device)

    # Key the probabilities by class index as a string.
    probs_dict = {}
    for class_idx, prob in enumerate(probabilities):
        probs_dict[str(class_idx)] = float(prob)

    return f"结果: {pred_class}", probs_dict
# Build the UI: a sketch canvas, two result labels and a button that runs
# predict_digit on the current drawing.
with gr.Blocks() as demo:
    gr.Markdown("""# digit-recognition""")
    with gr.Row():
        # The `shape` parameter is used when running locally.
        # NOTE(review): whether `shape` / `brush_color` are accepted depends on
        # the installed Gradio version - confirm against the deployed one.
        input_image = gr.Sketchpad(
            shape=(750, 750),
            brush_color="black",
            image_mode="L",  # make sure the output is single-channel
            type="pil"
        )
    with gr.Row():
        # Two labels: one for the result message, one for the probabilities.
        output_label = gr.Label(label="识别结果")
        output_prob = gr.Label(label="概率分布")
    with gr.Row():
        recognize_btn = gr.Button("识别", variant="primary")
    # Clicking the button sends the sketch to predict_digit and routes its
    # two return values to the two labels, in order.
    recognize_btn.click(
        predict_digit,
        inputs=input_image,
        outputs=[output_label, output_prob]
    )
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()