"""Gradio sketch-recognition demo backed by a small convolutional classifier.

At import time the script reads the class names from ``class_names.txt``,
builds the network, loads its weights from ``pytorch_model.bin`` and launches
a Gradio interface whose sketchpad input is classified by :func:`predict`.
"""

from pathlib import Path

import gradio as gr
import numpy as np
import torch
from torch import nn

print(gr.__version__)

# One class name per line; the line index is the class index of the model.
LABELS = Path("class_names.txt").read_text().splitlines()

# Height and width (in pixels) of the image fed to the network.
INPUT_SIZE = 28

# Maximum number of predictions reported to the UI.
TOP_K = 5

# Three conv/pool stages halve a 28x28 input to 14, 7 and 3 pixels per side,
# so the flattened feature vector has 128 * 3 * 3 = 1152 entries.
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source (consider weights_only=True where the installed torch
# supports it).
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# NOTE(review): strict=False tolerates missing/unexpected keys, so a checkpoint
# that does not match this architecture loads without raising.
model.load_state_dict(state_dict, strict=False)
model.eval()


def predict(im):
    """Classify a sketch and return its most probable labels.

    Args:
        im: The sketchpad payload. Either a NumPy array, a dict carrying the
            array under the ``"image"`` key (or, failing that, the
            ``"composite"`` key), or ``None``.

    Returns:
        A dict mapping label to softmax probability for the ``TOP_K`` most
        likely classes (fewer if there are fewer classes). An empty dict is
        returned when no usable image is supplied.
    """
    # Newer Gradio sketchpads hand over a dict instead of a bare array.
    if isinstance(im, dict):
        im = im["image"] if "image" in im else im.get("composite")

    # Covers None, non-array payloads and arrays that are not image shaped.
    if not isinstance(im, np.ndarray) or im.ndim not in (2, 3):
        return {}

    # Collapse the channel axis to obtain a single grayscale plane.
    if im.ndim == 3:
        im = np.mean(im, axis=2)

    # The network only accepts INPUT_SIZE x INPUT_SIZE inputs.
    if im.shape != (INPUT_SIZE, INPUT_SIZE):
        from PIL import Image

        resized = Image.fromarray(im.astype("uint8")).resize(
            (INPUT_SIZE, INPUT_SIZE)
        )
        im = np.array(resized)

    # Add batch and channel axes, then scale pixel values by 1/255.
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0

    with torch.no_grad():
        out = model(x)

    probabilities = torch.nn.functional.softmax(out[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(TOP_K, probabilities.numel())
    values, indices = torch.topk(probabilities, k)

    return {
        LABELS[index]: value
        for index, value in zip(indices.tolist(), values.tolist())
    }


# Build the Gradio interface.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        image_mode="L",  # grayscale mode
        canvas_size=(280, 280),  # canvas size
        brush=gr.Brush(default_size=10),  # brush settings
    ),
    outputs=gr.Label(num_top_classes=5),  # show the top 5 predictions
    title="Sketch Recognition",
    description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
    article="\n\nSketch Recognition | Demo Model\n\n",
)

interface.launch(share=True)