hp1318 committed on
Commit
a34c6f1
·
verified ·
1 Parent(s): 98b596e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -12
app.py CHANGED
@@ -1,41 +1,77 @@
1
  import torch
2
- import torchvision.models as models # Replace with your ViT model if needed
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # CIFAR-10 class names
8
  classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
9
  'dog', 'frog', 'horse', 'ship', 'truck']
10
 
11
- # Define the model architecture (replace with your ViT if needed)
12
- model = models.resnet18(num_classes=10) # Use your custom model here
13
 
14
- # Load the model weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
16
- model.eval() # Set the model to evaluation mode
 
17
 
18
- # Define image transformations
19
  transform = transforms.Compose([
20
  transforms.Resize((32, 32)),
21
  transforms.ToTensor(),
22
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
23
  ])
24
 
25
- # Define the prediction function
26
  def predict(image):
27
- image = transform(image).unsqueeze(0) # Add batch dimension
28
  with torch.no_grad():
29
  output = model(image)
30
  _, predicted = torch.max(output, 1)
31
  return classes[predicted.item()]
32
 
33
- # Create Gradio interface
34
  interface = gr.Interface(fn=predict,
35
  inputs=gr.Image(type="pil"),
36
  outputs="label",
37
- title="CIFAR-10 Image Classification")
 
38
 
39
- # Launch the app
40
  interface.launch()
41
 
 
 
1
  import torch
2
+ import torch.nn as nn
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+
8
  classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
9
  'dog', 'frog', 'horse', 'ship', 'truck']
10
 
 
 
11
 
12
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride is exactly a per-patch
    linear projection: every patch_size x patch_size tile gets the same
    learned linear map into embed_dim dimensions.
    """

    def __init__(self, in_channels=3, patch_size=4, embed_dim=64):
        super().__init__()
        # NOTE: attribute name `proj` must stay as-is so saved weights load.
        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/patch, W/patch)
        embedded = self.proj(x)
        # Collapse the spatial grid into a token axis:
        # (B, embed_dim, n, n) -> (B, n*n, embed_dim)
        return embedded.flatten(2).transpose(1, 2)
21
+
22
class MultiHeadSelfAttention(nn.Module):
    """Self-attention layer that accepts batch-first tensors.

    ``nn.MultiheadAttention`` is constructed without ``batch_first`` and so
    expects (seq, batch, dim); the input is permuted on the way in and the
    output permuted back, keeping the caller's (batch, seq, dim) layout.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # NOTE: attribute name `attention` must stay as-is so saved weights load.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        seq_first = x.permute(1, 0, 2)  # (B, N, D) -> (N, B, D)
        # Query, key, and value are all the same sequence (self-attention);
        # the returned attention weights are unused.
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        return attended.permute(1, 0, 2)  # back to (B, N, D)
31
+
32
class ViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline: patch embedding -> a stack of residual self-attention layers
    (no LayerNorm/MLP in this compact variant) -> mean-pool over patch
    tokens -> linear classification head.
    """

    def __init__(self, num_classes=10, embed_dim=64, num_heads=4, num_layers=2):
        super().__init__()
        # NOTE: attribute names must stay as-is so saved weights load.
        self.patch_embed = PatchEmbedding(embed_dim=embed_dim)
        self.transformer_layers = nn.ModuleList(
            MultiHeadSelfAttention(embed_dim, num_heads)
            for _ in range(num_layers)
        )
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        tokens = self.patch_embed(x)  # (B, num_patches, embed_dim)
        for block in self.transformer_layers:
            tokens = tokens + block(tokens)  # residual connection
        pooled = tokens.mean(dim=1)  # global average pool over patch tokens
        return self.classifier(pooled)  # (B, num_classes) logits
47
+
48
+
49
# Instantiate the ViT and restore the trained weights; map_location keeps
# this working on CPU-only hosts (e.g. a free Spaces instance).
model = ViT()
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()  # inference mode: freezes any train-time stochastic behavior
52
+
53
 
 
54
# Preprocessing pipeline — must mirror whatever normalization was used at
# training time (presumably (0.5, 0.5, 0.5); verify against the training code).
transform = transforms.Compose([
    transforms.Resize((32, 32)),                             # CIFAR-10 native resolution
    transforms.ToTensor(),                                   # PIL -> float tensor in [0, 1]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))   # rescale each channel to [-1, 1]
])
59
 
60
+
61
def predict(image):
    """Classify a PIL image into one of the ten CIFAR-10 categories.

    Args:
        image: ``PIL.Image.Image`` supplied by the Gradio image component.

    Returns:
        str: the predicted class name from ``classes``.
    """
    # Gradio normally delivers RGB, but convert defensively so RGBA or
    # grayscale uploads don't break the 3-channel Normalize in `transform`.
    batch = transform(image.convert('RGB')).unsqueeze(0)  # add batch dimension
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = model(batch)
        _, predicted = torch.max(output, 1)  # index of the highest logit
    return classes[predicted.item()]
67
 
68
+
69
# Wire the classifier into a simple Gradio UI: upload an image, get a label.
interface = gr.Interface(fn=predict,
                         inputs=gr.Image(type="pil"),   # hand predict() a PIL image
                         outputs="label",               # Gradio label widget
                         title="CIFAR-10 Image Classification with ViT")


# Start the web server (blocks until the app is stopped).
interface.launch()
76
 
77
+