# Committer: 0xtimi
# Last commit: 515bcf8 ("upgrade gradio")
from pathlib import Path
import torch
import gradio as gr
from torch import nn
import numpy as np
print(gr.__version__)

# Class names, one per line; line i is the label for model output index i.
# Read as UTF-8 explicitly so the result does not depend on the host locale.
LABELS = Path("class_names.txt").read_text(encoding="utf-8").splitlines()

# Small CNN for single-channel 28x28 sketches.
# Spatial size per conv/pool stage: 28 -> 14 -> 7 -> 3 (MaxPool2d(2) floors
# odd sizes), so the flattened feature vector is 128 * 3 * 3 = 1152, which is
# why the first Linear layer takes 1152 inputs.
# The layer order determines the state_dict key names ("0.weight",
# "3.weight", ...), so it must stay in sync with pytorch_model.bin.
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# strict=False tolerates key mismatches. Without a check, any mismatched layer
# would silently keep its random initialisation, so report what did not line
# up instead of hiding it.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        "Warning: checkpoint does not match model -- "
        f"missing keys: {incompatible.missing_keys}, "
        f"unexpected keys: {incompatible.unexpected_keys}"
    )
# Switch to inference mode before serving predictions.
model.eval()
def _extract_array(im):
    """Return the sketch carried by a Sketchpad payload as an ndarray, or None.

    A dict payload (newer Gradio Sketchpad output) is unwrapped by taking its
    ``'image'`` entry if that key exists, otherwise its ``'composite'`` entry.
    Anything that does not end up as a NumPy array counts as "no input".
    """
    if isinstance(im, dict):
        im = im["image"] if "image" in im else im.get("composite")
    return im if isinstance(im, np.ndarray) else None


def _to_model_input(arr):
    """Convert a 2-D or 3-D image array to a (1, 1, 28, 28) float tensor scaled by 1/255."""
    if arr.ndim == 3:
        # Collapse the channel axis to a single grey channel by averaging.
        # NOTE(review): for RGBA input this also averages in the alpha
        # channel -- confirm that is intended for the Sketchpad's output.
        arr = np.mean(arr, axis=2)
    if arr.shape != (28, 28):
        from PIL import Image  # local import: only needed when resizing

        arr = np.array(Image.fromarray(arr.astype("uint8")).resize((28, 28)))
    # NOTE(review): pixels are scaled to [0, 1] without inversion; confirm
    # the stroke/background polarity matches what the checkpoint was trained on.
    return torch.tensor(arr, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.0


def predict(im):
    """Classify a sketch and return the most likely labels.

    Args:
        im: Sketchpad output -- an ndarray, a dict wrapping one under
            ``'image'`` or ``'composite'``, or None.

    Returns:
        A dict mapping label -> softmax probability for the top
        ``min(5, len(LABELS))`` classes, or ``{}`` when no usable image
        was provided.
    """
    arr = _extract_array(im)
    if arr is None:
        return {}
    x = _to_model_input(arr)
    with torch.no_grad():
        out = model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(5, probabilities.numel())
    values, indices = torch.topk(probabilities, k)
    return {LABELS[i]: v for i, v in zip(indices.tolist(), values.tolist())}
# Build the Gradio UI: a grayscale sketch canvas in, ranked labels out.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(
        image_mode="L",  # grayscale mode
        canvas_size=(280, 280),  # drawing surface size
        brush=gr.Brush(default_size=10),  # brush stroke width
    ),
    outputs=gr.Label(num_top_classes=5),  # show the top 5 predictions
    title="Sketch Recognition",
    description="Who wants to play Pictionary? Draw a common object like a shovel or a laptop, and the algorithm will guess in real time!",
    article="<p style='text-align: center'>Sketch Recognition | Demo Model</p>",
    # The description promises real-time guessing; live=True re-runs predict
    # whenever the input changes instead of waiting for a Submit click.
    live=True,
)
interface.launch(share=True)