Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| import os | |
| import PIL | |
| from PIL import ImageOps | |
| import test_1 | |
| import torch | |
# Load the trained classifier once at import time; predict_digit reuses this
# module-level object for every request.
# NOTE(review): test_1 is a project-local module not shown here; since the
# result is later passed to .to(device), it is presumably a torch.nn.Module — confirm.
model = test_1.load_trained_model()
def convert_to_28x28_grid(canvas_image, grid_size=28, contrast=2.5):
    """Downsample a drawing canvas to a ``grid_size`` x ``grid_size`` uint8 matrix.

    The canvas is converted to grayscale, split into an evenly spaced grid of
    cells that covers the whole image, and each cell becomes the mean of its
    inverted intensity (``255 - gray``) multiplied by ``contrast``, clipped to
    [0, 255] and truncated to an integer.

    Args:
        canvas_image: Anything ``np.asarray`` accepts as an image (e.g. a PIL
            image or an ndarray): 2-D grayscale, or 3-D with channels last.
            With three or more channels only the first three (RGB) are used;
            with one or two channels the first one is taken as intensity.
        grid_size: Number of cells per side of the output grid (default 28).
        contrast: Multiplier applied to each cell's mean inverted intensity
            (default 2.5).

    Returns:
        ``np.ndarray`` of shape ``(grid_size, grid_size)`` and dtype ``uint8``.

    Raises:
        ValueError: If the input is not a 2-D/3-D image, or is smaller than
            ``grid_size`` pixels in either dimension (cells would be empty).
    """
    pixels = np.asarray(canvas_image)
    if pixels.ndim == 3 and pixels.shape[2] >= 3:
        # ITU-R BT.601 luma weights (the same ones PIL uses for convert("L")).
        # The previous weights [0, 0, 0] turned every colour image solid black.
        luma = np.array([0.299, 0.587, 0.114])
        gray = pixels[..., :3].astype(np.float64) @ luma
    elif pixels.ndim == 3:
        # One or two channels (e.g. L or L+alpha): first channel is intensity.
        gray = pixels[..., 0].astype(np.float64)
    else:
        gray = pixels.astype(np.float64)
    if gray.ndim != 2:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {pixels.shape}")

    height, width = gray.shape
    if height < grid_size or width < grid_size:
        raise ValueError(
            f"image of size {height}x{width} is smaller than the "
            f"{grid_size}x{grid_size} grid"
        )

    # Evenly spaced integer cell boundaries spanning the whole image. A fixed
    # cell size of 750 // 28 = 26 px silently dropped the last 22 rows/columns.
    row_edges = np.arange(grid_size + 1) * height // grid_size
    col_edges = np.arange(grid_size + 1) * width // grid_size

    # Invert so that dark pixels produce high activation.
    inverted = 255.0 - gray
    grid = np.empty((grid_size, grid_size), dtype=np.float64)
    for i in range(grid_size):
        for j in range(grid_size):
            cell = inverted[row_edges[i]:row_edges[i + 1],
                            col_edges[j]:col_edges[j + 1]]
            grid[i, j] = cell.mean()

    # Boost contrast, clamp to the byte range, then truncate to uint8 (0-255).
    return np.clip(grid * contrast, 0, 255).astype(np.uint8)
def predict_digit(input_image):
    """Classify the digit drawn on the sketchpad.

    Args:
        input_image: The Sketchpad value. A PIL image is expected (the UI
            requests ``type="pil"``). A dict carrying the image under
            ``"composite"`` (the editor-value format of newer Gradio releases)
            is also accepted. ``None`` means nothing has been drawn.

    Returns:
        Tuple ``(label_text, probs)``: a display string with the predicted
        class, and a dict mapping each class index (as ``str``) to its
        probability as ``float``. When there is no drawing, the text is a
        prompt and the dict is empty.
    """
    # Newer Gradio versions wrap the drawing in a dict; unwrap it so an empty
    # canvas falls through to the guard below instead of crashing.
    if isinstance(input_image, dict):
        input_image = input_image.get("composite")
    if input_image is None:
        return "请先绘制数字", {}

    # Use the GPU when available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # ImageOps.invert only handles modes L/RGB/1 (RGBA raises OSError, P raises
    # NotImplementedError), so normalise to single-channel "L" first.
    # NOTE(review): convert_to_28x28_grid inverts again (255 - gray), so the net
    # polarity equals the Sketchpad's own output. Whether that matches the
    # model's training data depends on the Sketchpad settings and test_1 — confirm.
    img = ImageOps.invert(input_image.convert("L"))

    # Downsample the canvas to the model's input grid.
    digit_matrix = convert_to_28x28_grid(img)

    # Model prediction.
    pred_class, probabilities = test_1.predict_user_image(digit_matrix, model, device)
    probs_dict = {str(i): float(p) for i, p in enumerate(probabilities)}
    return f"结果: {pred_class}", probs_dict
# Build the Gradio UI: a drawing pad, two result labels and a "recognize"
# button wired to predict_digit.
with gr.Blocks() as demo:
    gr.Markdown("""# digit-recognition""")
    with gr.Row():
        # The `shape` parameter is used for local runs.
        # NOTE(review): `shape` and `brush_color` appear to be Gradio 3.x
        # Sketchpad arguments that later Gradio releases removed (constructing
        # the component would then raise TypeError). That may explain the
        # Space's "Runtime error" — confirm the installed gradio version.
        input_image = gr.Sketchpad(
            shape=(750, 750),
            brush_color="black",
            image_mode="L",  # ensure the output is single-channel
            type="pil"
        )
    with gr.Row():
        # First label: predicted class text; second label: per-class probabilities.
        output_label = gr.Label(label="识别结果")
        output_prob = gr.Label(label="概率分布")
    with gr.Row():
        recognize_btn = gr.Button("识别", variant="primary")
    # Clicking the button runs predict_digit on the sketch; its two return
    # values populate the two labels, in order.
    recognize_btn.click(
        predict_digit,
        inputs=input_image,
        outputs=[output_label, output_prob]
    )
# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()