Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image, ImageOps | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # 如果你的模型结构与标准的torchvision模型不同,请确保在此处定义或导入你的模型结构 | |
| # 例如,如果你有一个model.py文件: | |
| # from model import ViTModel | |
| # 示例:定义一个简单的ViT模型结构(请根据你的实际模型调整) | |
class ViT(nn.Module):
    """Small Vision Transformer classifier for single-channel images.

    The flattened image is cut into consecutive chunks of ``patch_size ** 2``
    values. Each chunk is linearly embedded, receives a learned positional
    embedding, passes through a Transformer encoder, and the encoder outputs
    are mean-pooled and classified by a LayerNorm + Linear head.

    NOTE: the chunks are consecutive runs of the row-major flattened image,
    not square spatial tiles. This layout is kept on purpose; presumably the
    saved checkpoint was trained with it, so changing it would invalidate the
    weights.

    Args:
        image_size: Height/width of the (square) input image.
        patch_size: Side length used to derive the number of values per patch;
            must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Embedding width used throughout the encoder.
        depth: Number of Transformer encoder layers.
        heads: Number of attention heads per encoder layer.
        mlp_dim: Hidden width of each encoder layer's feed-forward network.
        dropout: Dropout probability applied after adding the positional
            embedding. The encoder layers are built without an explicit
            dropout argument and therefore use the library default.
    """

    def __init__(self, image_size=28, patch_size=7, num_classes=10, dim=128, depth=6, heads=8, mlp_dim=256, dropout=0.1):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        # One input channel, so a patch holds patch_size ** 2 values.
        patch_dim = 1 * patch_size ** 2
        # Remembered for forward(); a plain int attribute, so it does not
        # appear in the state_dict and existing checkpoints still load.
        self.patch_dim = patch_dim

        # Linear projection of each flattened patch into the embedding space.
        self.patch_embedding = nn.Linear(patch_dim, dim)
        # Learned positional embedding, registered as a model parameter.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder. batch_first is left at its default, so the
        # encoder works on (seq_len, batch, dim); forward() permutes around it.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim,
                nhead=heads,
                dim_feedforward=mlp_dim
            ),
            num_layers=depth
        )

        # Classification head.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``[batch_size, num_classes]``.

        Args:
            x: Image batch whose per-sample element count equals
                ``num_patches * patch_dim`` (e.g. ``[batch, 1, 28, 28]`` with
                the default settings).
        """
        batch_size = x.size(0)
        # Use the configured patch width instead of a hard-coded 7*7, and
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch_size, -1, self.patch_dim)
        x = self.patch_embedding(x)   # [batch_size, num_patches, dim]
        # Out-of-place add: avoids mutating the embedding output in place.
        x = x + self.pos_embedding
        x = self.dropout(x)
        x = x.permute(1, 0, 2)        # -> [seq_len, batch_size, dim] for the encoder
        x = self.transformer(x)
        x = x.permute(1, 0, 2)        # -> [batch_size, seq_len, dim]
        # Mean over all patches gives one global feature vector per image.
        x = x.mean(dim=1)
        return self.mlp_head(x)       # [batch_size, num_classes]
# Build the network and restore the trained weights on the CPU.
# num_classes must match the MNIST task the checkpoint was trained for.
model = ViT(num_classes=10)
model_path = "vit_model.pth"  # checkpoint file name
_cpu = torch.device('cpu')
_weights = torch.load(model_path, map_location=_cpu, weights_only=True)
model.load_state_dict(_weights)
model.eval()

# Preprocessing applied to every incoming picture before inference:
# single-channel grayscale, 28x28 resize, tensor conversion, normalisation.
_preprocessing_steps = [
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(_preprocessing_steps)
# Prediction function used as the Gradio callback.
def classify_image(image):
    """Classify a drawn digit and return the top class with its confidence.

    Args:
        image: Either a PIL image, or a dict that carries the PIL image under
            the ``'composite'`` key.

    Returns:
        A one-entry dict mapping the predicted class (as a string) to its
        softmax probability.

    Raises:
        TypeError: If no PIL image can be obtained from ``image``.
    """
    # Unwrap dict payloads that carry the picture under 'composite'.
    if isinstance(image, dict) and 'composite' in image:
        image = image['composite']

    if not isinstance(image, Image.Image):
        raise TypeError(f"Expected image to be PIL Image, but got {type(image)}")

    # Preprocess and add the batch dimension expected by the model.
    # (The former debug preview via PIL's .show() was removed: it spawned an
    # external image viewer on every request, which is wrong on a server.)
    img = transform(image).unsqueeze(0)

    # Inference without gradient tracking.
    with torch.no_grad():
        outputs = model(img)
        probabilities = F.softmax(outputs, dim=1)

    # Highest probability and the class index it belongs to.
    top_probability, predicted = torch.max(probabilities, 1)
    confidence = top_probability.item()

    print(predicted, confidence)
    return {str(predicted.item()): confidence}
| # # 创建Gradio界面 | |
| # iface = gr.Interface( | |
| # fn=classify_image, | |
| # inputs=gr.Image(shape=(28, 28), image_mode='L', source="upload", tool="editor"), | |
| # outputs=gr.Label(num_top_classes=1), | |
| # title="MNIST Classification with ViT", | |
| # description="上传一张28x28的灰度图像,模型将预测其所属的数字类别。" | |
| # ) | |
# Gradio UI: a sketchpad to draw on, wired to classify_image, with a label
# component that displays the single top prediction.
_brush = gr.Brush(default_size=18)
_digit_canvas = gr.Sketchpad(
    type='pil',
    image_mode='L',
    brush=_brush,
    crop_size=(600, 600),
)
_prediction_label = gr.Label(num_top_classes=1)

iface = gr.Interface(
    fn=classify_image,
    inputs=_digit_canvas,
    outputs=_prediction_label,
    title="MNIST Digit Classification with ViT",
    description="Use the mouse to hand draw a number and the model will predict the category it belongs to.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()