mochuan zhan commited on
Commit
dfefec8
·
1 Parent(s): fceab91

fix again

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -1,8 +1,9 @@
1
  import gradio as gr
2
  import torch
3
  import torchvision.transforms as transforms
4
- from PIL import Image
5
  import torch.nn as nn
 
6
 
7
  # 如果你的模型结构与标准的torchvision模型不同,请确保在此处定义或导入你的模型结构
8
  # 例如,如果你有一个model.py文件:
@@ -79,19 +80,31 @@ transform = transforms.Compose([
79
 
80
  # 定义预测函数
81
  def classify_image(image):
82
- # 如果输入是灰度图,将其转换为RGB
83
- if image.mode != "RGB":
84
- image = image.convert("RGB")
85
 
86
- # 预处理图像
 
 
 
 
 
 
87
  img = transform(image).unsqueeze(0) # 添加批次维度
88
 
89
  # 模型预测
90
  with torch.no_grad():
91
  outputs = model(img)
 
 
92
 
93
  # 获取预测结果
94
  _, predicted = torch.max(outputs, 1)
 
 
 
 
 
95
  return str(predicted.item())
96
 
97
  # # 创建Gradio界面
@@ -103,16 +116,13 @@ def classify_image(image):
103
  # description="上传一张28x28的灰度图像,模型将预测其所属的数字类别。"
104
  # )
105
 
106
-
107
  iface = gr.Interface(
108
  fn=classify_image,
109
- inputs=gr.Sketchpad(
110
- shape=(224, 224),
111
- label="Draw a digit"
112
- ),
113
  outputs=gr.Label(num_top_classes=1),
114
  title="MNIST Digit Classification with ViT",
115
  description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
116
  )
117
 
 
118
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torchvision.transforms as transforms
4
+ from PIL import Image, ImageOps
5
  import torch.nn as nn
6
+ import torch.nn.functional as F
7
 
8
  # 如果你的模型结构与标准的torchvision模型不同,请确保在此处定义或导入你的模型结构
9
  # 例如,如果你有一个model.py文件:
 
80
 
81
  # 定义预测函数
82
  def classify_image(image):
83
+ # 将 NumPy 数组转换为 PIL 图像
84
+ image = Image.fromarray(image).convert("L")
 
85
 
86
+ # 反转颜色
87
+ image = ImageOps.invert(image)
88
+
89
+ # 调整图像大小到模型需要的输入尺寸
90
+ image = image.resize((224, 224))
91
+
92
+ # 图像预处理(根据您的模型需要进行调整)
93
  img = transform(image).unsqueeze(0) # 添加批次维度
94
 
95
  # 模型预测
96
  with torch.no_grad():
97
  outputs = model(img)
98
+ # 如果模型输出未经过 softmax,可以添加
99
+ probabilities = F.softmax(outputs, dim=1)
100
 
101
  # 获取预测结果
102
  _, predicted = torch.max(outputs, 1)
103
+
104
+ # 如果需要返回概率
105
+ # return {str(predicted.item()): probabilities[0][predicted].item()}
106
+
107
+ # 只返回预测的类别
108
  return str(predicted.item())
109
 
110
  # # 创建Gradio界面
 
116
  # description="上传一张28x28的灰度图像,模型将预测其所属的数字类别。"
117
  # )
118
 
 
119
  iface = gr.Interface(
120
  fn=classify_image,
121
+ inputs=gr.Sketchpad(crop_size=(256,256), type='numpy', image_mode='L', brush=gr.Brush()),
 
 
 
122
  outputs=gr.Label(num_top_classes=1),
123
  title="MNIST Digit Classification with ViT",
124
  description="使用鼠标手绘一个数字,模型将预测其所属的类别。"
125
  )
126
 
127
+
128
  iface.launch()