File size: 5,160 Bytes
78245fb
 
 
dfefec8
78245fb
dfefec8
78245fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d78706f
9159eeb
360f3af
78245fb
 
 
360f3af
 
78245fb
360f3af
78245fb
 
 
 
a2ce9c9
 
 
 
 
455f3a3
 
360f3af
455f3a3
360f3af
 
aa2c6eb
78245fb
360f3af
 
 
78245fb
 
 
 
dfefec8
78245fb
 
360f3af
aa2c6eb
dfefec8
aa2c6eb
360f3af
aa2c6eb
78245fb
 
 
 
 
 
 
 
 
 
 
 
360f3af
78245fb
9159eeb
455f3a3
78245fb
 
dfefec8
360f3af
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageOps
import torch.nn as nn
import torch.nn.functional as F

# 如果你的模型结构与标准的torchvision模型不同,请确保在此处定义或导入你的模型结构
# 例如,如果你有一个model.py文件:
# from model import ViTModel

# 示例:定义一个简单的ViT模型结构(请根据你的实际模型调整)
class ViT(nn.Module):
    def __init__(self, image_size=28, patch_size=7, num_classes=10, dim=128, depth=6, heads=8, mlp_dim=256, dropout=0.1):
        super(ViT, self).__init__()

        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 1 * patch_size ** 2

        # 定义线性层将图像分块并映射到嵌入空间
        self.patch_embedding = nn.Linear(patch_dim, dim)

        # 位置编码
        # nn.Parameter是Pytorch中的一个类,用于将一个张量注册为模型的参数
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))

        # Dropout层
        self.dropout = nn.Dropout(dropout)

        # Transformer编码器
        # 当 batch_first=True 时,输入和输出张量的形状为 (batch_size, seq_length, feature_dim)。当 batch_first=False 时,输入和输出张量的形状为 (seq_length, batch_size, feature_dim)。
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim
                # batch_first=True
            ),
            num_layers=depth
        )

        # 分类头
        # nn.Identity()是一个空的层,它不执行任何操作,只是返回输入
        # self.to_cls_token = nn.Identity()
        # self.mlp_head = nn.Linear(dim, num_classes)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )
    
    def forward(self, x):
        # x shape: [batch_size, 1, 28, 28]
        batch_size = x.size(0)
        x = x.view(batch_size, -1, 7*7)  # 将图像划分为7x7的Patch
        x = self.patch_embedding(x)  # [batch_size, num_patches, dim]
        x += self.pos_embedding  # 添加位置编码
        x = self.dropout(x)  # 应用Dropout

        x = x.permute(1, 0, 2)  # Transformer期望的输入形状:[seq_len, batch_size, embedding_dim]
        x = self.transformer(x)  # [序列长度, batch_size, dim]
        x = x.permute(1, 0, 2)  # 转回原来的形状:[batch_size, seq_len, dim]

        x = x.mean(dim=1)  # 对所有Patch取平均,x.mean(dim=1) 这一步是对所有 Patch 的特征向量取平均值,从而得到一个代表整个图像的全局特征向量。
        x = self.mlp_head(x)  # [batch_size, num_classes]
        return x

# 加载模型
model = ViT(num_classes=10)  # 确保num_classes与你的MNIST任务一致
model_path = "vit_model.pth"  # 模型权重文件名
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
model.eval()

# 定义图像预处理
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 转换为单通道
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 定义预测函数
def classify_image(image):
    # 检查是否包含 'composite' 数据
    if isinstance(image, dict) and 'composite' in image:
        image = image['composite']
    
    # 确保 image 是一个 PIL 图像
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected image to be PIL Image, but got {type(image)}")
    # 打印image的数组
    
    print(image)

    # 图像预处理
    img = transform(image).unsqueeze(0)  # 添加批次维度

    image_pil = Image.fromarray(img.numpy().squeeze() * 255).convert('L')
    image_pil.show()
    
    # 模型预测
    with torch.no_grad():
        outputs = model(img)
        probabilities = F.softmax(outputs, dim=1)
    
    # 获取预测结果
    _, predicted = torch.max(probabilities, 1)
    confidence = probabilities[0][predicted].item()
    
    # 返回结果字典,包含预测类别和置信度
    print(predicted, confidence)
    return {str(predicted.item()): confidence}

# # 创建Gradio界面
# iface = gr.Interface(
#     fn=classify_image,
#     inputs=gr.Image(shape=(28, 28), image_mode='L', source="upload", tool="editor"),
#     outputs=gr.Label(num_top_classes=1),
#     title="MNIST Classification with ViT",
#     description="上传一张28x28的灰度图像,模型将预测其所属的数字类别。"
# )

iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Sketchpad(type='pil', image_mode='L', brush=gr.Brush(default_size=18), crop_size=(600, 600)),
    outputs=gr.Label(num_top_classes=1),
    title="MNIST Digit Classification with ViT", 
    description="Use the mouse to hand draw a number and the model will predict the category it belongs to."
)


if __name__ == "__main__":
    iface.launch()