mochuan zhan commited on
Commit
9159eeb
·
1 Parent(s): d78706f

iface fixed

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -67,8 +67,8 @@ class ViT(nn.Module):
67
  # 加载模型
68
  model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
69
  model_path = "vit_model.pth" # 模型权重文件名
70
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
71
- model.eval()
72
 
73
  # 定义图像预处理
74
  transform = transforms.Compose([
@@ -106,8 +106,7 @@ def classify_image(image):
106
 
107
  iface = gr.Interface(
108
  fn=classify_image,
109
- inputs=gr.Image(
110
- source="canvas",
111
  tool="editor",
112
  type="pil",
113
  invert_colors=True,
@@ -116,7 +115,7 @@ iface = gr.Interface(
116
  label="Draw a digit"
117
  ),
118
  outputs=gr.Label(num_top_classes=1),
119
- title="MNIST Digit Classification with ViT",
120
  description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
121
  )
122
 
 
67
  # 加载模型
68
  model = ViT(num_classes=10) # 确保num_classes与你的MNIST任务一致
69
  model_path = "vit_model.pth" # 模型权重文件名
70
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
71
+
72
 
73
  # 定义图像预处理
74
  transform = transforms.Compose([
 
106
 
107
  iface = gr.Interface(
108
  fn=classify_image,
109
+ inputs=gr.Sketchpad(
 
110
  tool="editor",
111
  type="pil",
112
  invert_colors=True,
 
115
  label="Draw a digit"
116
  ),
117
  outputs=gr.Label(num_top_classes=1),
118
+ title="MNIST Digit Classification with ViT",
119
  description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
120
  )
121