mochuan zhan commited on
Commit
aa2c6eb
·
1 Parent(s): dfefec8

fix again again

Browse files
Files changed (1) hide show
  1. app.py +16 -12
app.py CHANGED
@@ -80,32 +80,30 @@ transform = transforms.Compose([
80
 
81
  # 定义预测函数
82
  def classify_image(image):
83
- # NumPy 数组转换为 PIL 图像
84
- image = Image.fromarray(image).convert("L")
 
85
 
86
  # 反转颜色
87
  image = ImageOps.invert(image)
88
 
89
- # 调整图像大小到模型需要的输入尺寸
90
  image = image.resize((224, 224))
91
 
92
- # 图像预处理(根据您的模型需要进行调整)
93
  img = transform(image).unsqueeze(0) # 添加批次维度
94
 
95
  # 模型预测
96
  with torch.no_grad():
97
  outputs = model(img)
98
- # 如果模型输出未经过 softmax,可以添加
99
  probabilities = F.softmax(outputs, dim=1)
100
 
101
  # 获取预测结果
102
  _, predicted = torch.max(outputs, 1)
 
103
 
104
- # 如果需要返回概率
105
- # return {str(predicted.item()): probabilities[0][predicted].item()}
106
-
107
- # 只返回预测的类别
108
- return str(predicted.item())
109
 
110
  # # 创建Gradio界面
111
  # iface = gr.Interface(
@@ -118,11 +116,17 @@ def classify_image(image):
118
 
119
  iface = gr.Interface(
120
  fn=classify_image,
121
- inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L', brush=gr.Brush()),
 
 
 
 
 
122
  outputs=gr.Label(num_top_classes=1),
123
  title="MNIST Digit Classification with ViT",
124
- description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
125
  )
126
 
127
 
 
128
  iface.launch()
 
80
 
81
  # 定义预测函数
82
  def classify_image(image):
83
+ # image 已经是一个 PIL 图像
84
+ # 将图像转换为灰度模式
85
+ image = image.convert("L")
86
 
87
  # 反转颜色
88
  image = ImageOps.invert(image)
89
 
90
+ # 调整图像大小
91
  image = image.resize((224, 224))
92
 
93
+ # 图像预处理
94
  img = transform(image).unsqueeze(0) # 添加批次维度
95
 
96
  # 模型预测
97
  with torch.no_grad():
98
  outputs = model(img)
 
99
  probabilities = F.softmax(outputs, dim=1)
100
 
101
  # 获取预测结果
102
  _, predicted = torch.max(outputs, 1)
103
+ confidence = probabilities[0][predicted].item()
104
 
105
+ # 返回结果字典,包含预测类别和置信度
106
+ return {str(predicted.item()): confidence}
 
 
 
107
 
108
  # # 创建Gradio界面
109
  # iface = gr.Interface(
 
116
 
117
  iface = gr.Interface(
118
  fn=classify_image,
119
+ inputs=gr.Sketchpad(
120
+ shape=(224, 224),
121
+ invert_colors=False,
122
+ label="Draw a digit",
123
+ type='pil' # 设置为返回 PIL 图像
124
+ ),
125
  outputs=gr.Label(num_top_classes=1),
126
  title="MNIST Digit Classification with ViT",
127
+ description="Use the mouse to hand draw a number and the model will predict the category it belongs to."
128
  )
129
 
130
 
131
+
132
  iface.launch()