File size: 2,218 Bytes
35c9708
a34c6f1
35c9708
 
 
 
a34c6f1
35c9708
 
 
44a9406
a34c6f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44a9406
a34c6f1
 
35c9708
 
 
 
 
 
 
a34c6f1
35c9708
a34c6f1
35c9708
 
 
 
 
a34c6f1
35c9708
 
 
a34c6f1
 
35c9708
 
44a9406
a34c6f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr


# CIFAR-10 class labels, ordered by their numeric label index in the dataset.
classes = (
    "airplane automobile bird cat deer "
    "dog frog horse ship truck"
).split()


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A convolution whose kernel size equals its stride is equivalent to
    cutting the image into patch_size x patch_size tiles and applying a
    shared linear projection to every tile.
    """

    def __init__(self, in_channels=3, patch_size=4, embed_dim=64):
        super().__init__()
        # Keep the attribute name `proj` so saved checkpoints still load.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/ps, W/ps)
        patches = self.proj(x)
        # Collapse the spatial grid into a sequence axis:
        # (B, embed_dim, N) -> (B, N, embed_dim)
        return patches.flatten(2).transpose(1, 2)

class MultiHeadSelfAttention(nn.Module):
    """Self-attention block wrapping ``nn.MultiheadAttention``.

    Uses ``batch_first=True`` so tensors stay in (batch, seq, embed)
    layout throughout, removing the permute round-trip that the
    batch-second default required. ``batch_first`` only changes the
    expected tensor layout, not the module's parameters, so existing
    checkpoints remain loadable.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Keep the attribute name `attention` for state_dict compatibility.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # Query, key and value are all x: pure self-attention.
        # The second return value (attention weights) is unused.
        attn_output, _ = self.attention(x, x, x)
        return attn_output

class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline: patch embedding -> stack of residual self-attention
    layers -> mean-pool over patch tokens -> linear classification head.
    (No positional embeddings or feed-forward sublayers — this matches
    the trained checkpoint's architecture.)
    """

    def __init__(self, num_classes=10, embed_dim=64, num_heads=4, num_layers=2):
        super().__init__()
        self.patch_embed = PatchEmbedding(embed_dim=embed_dim)
        self.transformer_layers = nn.ModuleList(
            MultiHeadSelfAttention(embed_dim, num_heads) for _ in range(num_layers)
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.patch_embed(x)  # (B, N, embed_dim)
        for block in self.transformer_layers:
            # Residual connection around each attention layer.
            tokens = block(tokens) + tokens
        pooled = tokens.mean(dim=1)  # average over the patch sequence
        return self.classifier(pooled)  # (B, num_classes) logits


# Instantiate the model and restore the trained weights on CPU.
model = ViT()
# weights_only=True restricts torch.load to tensors/containers, guarding
# against arbitrary-code execution from a tampered checkpoint (torch.load
# unpickles by default). A plain state_dict loads fine under this mode.
state_dict = torch.load('model.pth', map_location=torch.device('cpu'), weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: fixed behavior for dropout/norm layers


# Preprocessing pipeline: resize to the CIFAR-10 resolution, convert the
# PIL image to a float tensor, and scale pixel values into [-1, 1].
_preprocess_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5 maps the [0, 1] range to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocess_steps)


def predict(image):
    """Classify a PIL image and return its CIFAR-10 class name."""
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():  # inference only — no gradient bookkeeping
        logits = model(batch)
        idx = logits.argmax(dim=1).item()
    return classes[idx]


# Gradio UI: upload an image, see the predicted CIFAR-10 class.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="CIFAR-10 Image Classification with ViT",
)


if __name__ == "__main__":
    # Only start the web server when run as a script, not when this
    # module is imported (e.g. by tests or a deployment wrapper).
    interface.launch()