Spaces:
Build error
Build error
mochuan zhan commited on
Commit ·
9159eeb
1
Parent(s): d78706f
iface fixed
Browse files
app.py
CHANGED
|
@@ -67,8 +67,8 @@ class ViT(nn.Module):
|
|
| 67 |
# 加载模型
|
| 68 |
model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
|
| 69 |
model_path = "vit_model.pth" # 模型权重文件名
|
| 70 |
-
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
|
| 71 |
-
|
| 72 |
|
| 73 |
# 定义图像预处理
|
| 74 |
transform = transforms.Compose([
|
|
@@ -106,8 +106,7 @@ def classify_image(image):
|
|
| 106 |
|
| 107 |
iface = gr.Interface(
|
| 108 |
fn=classify_image,
|
| 109 |
-
inputs=gr.
|
| 110 |
-
source="canvas",
|
| 111 |
tool="editor",
|
| 112 |
type="pil",
|
| 113 |
invert_colors=True,
|
|
@@ -116,7 +115,7 @@ iface = gr.Interface(
|
|
| 116 |
label="Draw a digit"
|
| 117 |
),
|
| 118 |
outputs=gr.Label(num_top_classes=1),
|
| 119 |
-
title="MNIST Digit Classification with ViT",
|
| 120 |
description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
|
| 121 |
)
|
| 122 |
|
|
|
|
| 67 |
# 加载模型
|
| 68 |
model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
|
| 69 |
model_path = "vit_model.pth" # 模型权重文件名
|
| 70 |
+
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
|
| 71 |
+
|
| 72 |
|
| 73 |
# 定义图像预处理
|
| 74 |
transform = transforms.Compose([
|
|
|
|
| 106 |
|
| 107 |
iface = gr.Interface(
|
| 108 |
fn=classify_image,
|
| 109 |
+
inputs=gr.Sketchpad(
|
|
|
|
| 110 |
tool="editor",
|
| 111 |
type="pil",
|
| 112 |
invert_colors=True,
|
|
|
|
| 115 |
label="Draw a digit"
|
| 116 |
),
|
| 117 |
outputs=gr.Label(num_top_classes=1),
|
| 118 |
+
title="MNIST Digit Classification with ViT",
|
| 119 |
description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
|
| 120 |
)
|
| 121 |
|