import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageOps
import torch.nn as nn
import torch.nn.functional as F
# 如果你的模型结构与标准的torchvision模型不同,请确保在此处定义或导入你的模型结构
# 例如,如果你有一个model.py文件:
# from model import ViTModel
# 示例:定义一个简单的ViT模型结构(请根据你的实际模型调整)
class ViT(nn.Module):
    """Minimal Vision Transformer for 28x28 single-channel (MNIST) images.

    The image is flattened into ``num_patches`` vectors of ``patch_size**2``
    pixels, linearly projected to ``dim``, combined with a learned positional
    embedding, passed through a TransformerEncoder, mean-pooled over patches,
    and classified by a LayerNorm + Linear head.
    """

    def __init__(self, image_size=28, patch_size=7, num_classes=10, dim=128,
                 depth=6, heads=8, mlp_dim=256, dropout=0.1):
        super(ViT, self).__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 1 * patch_size ** 2  # single input channel
        # Bug fix: remember the patch dimension so forward() does not
        # hard-code the literal 7*7 — the original broke for any
        # patch_size != 7 despite accepting it as a parameter.
        self.patch_dim = patch_dim
        # Linear projection of each flattened patch into the embedding space.
        self.patch_embedding = nn.Linear(patch_dim, dim)
        # Learned positional embedding, one vector per patch
        # (nn.Parameter registers the tensor as a trainable model parameter).
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(dropout)
        # Transformer encoder. batch_first defaults to False, so forward()
        # permutes to (seq_len, batch, dim) around this module.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim
            ),
            num_layers=depth
        )
        # Classification head over the mean-pooled patch features.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        # x: [batch_size, 1, 28, 28]
        batch_size = x.size(0)
        # NOTE(review): this is a plain reshape, not a spatial 7x7 patch
        # extraction — each "patch" is a run of patch_dim consecutive pixels
        # in row-major order. Kept as-is because the saved checkpoint was
        # trained with this layout.
        x = x.view(batch_size, -1, self.patch_dim)  # [batch, num_patches, patch_dim]
        x = self.patch_embedding(x)                 # [batch, num_patches, dim]
        x = x + self.pos_embedding                  # add positional encoding
        x = self.dropout(x)
        x = x.permute(1, 0, 2)                      # (seq_len, batch, dim) expected by the encoder
        x = self.transformer(x)                     # (seq_len, batch, dim)
        x = x.permute(1, 0, 2)                      # back to (batch, seq_len, dim)
        x = x.mean(dim=1)                           # mean over patches -> one global feature vector
        x = self.mlp_head(x)                        # [batch, num_classes]
        return x
# Load the trained model and its checkpoint (CPU only; weights_only=True
# avoids unpickling arbitrary objects from the checkpoint file).
model = ViT(num_classes=10)  # num_classes must match the MNIST task
model_path = "vit_model.pth"  # checkpoint file name
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
model.eval()  # inference mode: disables dropout
# Preprocessing pipeline: reproduce the MNIST training normalization.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # collapse to a single channel
    transforms.Resize((28, 28)),                  # model expects 28x28 input
    transforms.ToTensor(),                        # [0, 255] uint8 -> [0.0, 1.0] float
    transforms.Normalize((0.1307,), (0.3081,))    # standard MNIST mean / std
])
# Prediction function wired to the Gradio interface.
def classify_image(image):
    """Predict the MNIST digit in ``image``.

    Args:
        image: A PIL image, or a Gradio Sketchpad payload (a dict whose
            'composite' entry holds the merged drawing as a PIL image).

    Returns:
        A dict mapping the predicted digit (as a string) to its softmax
        confidence — the format expected by ``gr.Label``.

    Raises:
        TypeError: If the input cannot be resolved to a PIL image.
    """
    # Gradio's Sketchpad delivers a dict of layers; the merged drawing
    # is stored under 'composite'.
    if isinstance(image, dict) and 'composite' in image:
        image = image['composite']
    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected image to be PIL Image, but got {type(image)}")
    # Removed leftover debug code: `print(image)` and an
    # `Image.fromarray(...).show()` call that opened a blocking OS image
    # viewer on the server and fed normalized floats into fromarray,
    # yielding a garbage preview.
    img = transform(image).unsqueeze(0)  # add batch dimension -> [1, 1, 28, 28]
    # Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(img)
        probabilities = F.softmax(outputs, dim=1)
    _, predicted = torch.max(probabilities, 1)
    confidence = probabilities[0][predicted].item()
    return {str(predicted.item()): confidence}
# # 创建Gradio界面
# iface = gr.Interface(
# fn=classify_image,
# inputs=gr.Image(shape=(28, 28), image_mode='L', source="upload", tool="editor"),
# outputs=gr.Label(num_top_classes=1),
# title="MNIST Classification with ViT",
# description="上传一张28x28的灰度图像,模型将预测其所属的数字类别。"
# )
# Build the Gradio UI: the user hand-draws a digit on a sketchpad and the
# model's top prediction is shown as a label with its confidence.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Sketchpad(type='pil', image_mode='L', brush=gr.Brush(default_size=18), crop_size=(600, 600)),
    outputs=gr.Label(num_top_classes=1),
    title="MNIST Digit Classification with ViT",
    description="Use the mouse to hand draw a number and the model will predict the category it belongs to."
)
if __name__ == "__main__":
    iface.launch()