i4ata commited on
Commit
cce011e
·
1 Parent(s): 26a33b7
__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
 
app.py CHANGED
@@ -1,11 +1,13 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from typing import List, Dict, Union
4
 
5
  import torch
6
 
7
  from model import ClassifierModel
8
 
 
 
9
  class GradioApp:
10
 
11
  def __init__(self) -> None:
@@ -19,6 +21,7 @@ class GradioApp:
19
 
20
  def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
21
 
 
22
  if isinstance(self.models[model_name], str):
23
  self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
24
 
@@ -29,10 +32,22 @@ class GradioApp:
29
 
30
  def launch(self):
31
 
 
 
 
 
32
  demo = gr.Interface(
33
  fn=self.predict,
34
- inputs=[gr.Image(type='filepath'), gr.Radio(('Custom', 'Pretrained'))],
35
- outputs=gr.Label(num_top_classes=3),
 
 
 
 
 
 
 
 
36
  )
37
  demo.launch()
38
 
 
1
  import gradio as gr
2
  from PIL import Image
3
+ import os
4
 
5
  import torch
6
 
7
  from model import ClassifierModel
8
 
9
+ from typing import List, Dict, Union
10
+
11
  class GradioApp:
12
 
13
  def __init__(self) -> None:
 
21
 
22
  def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
23
 
24
+ # Lazy loading of models
25
  if isinstance(self.models[model_name], str):
26
  self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
27
 
 
32
 
33
  def launch(self):
34
 
35
+ dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
36
+ github_repo_url = 'https://github.com/i4ata/TransformerClassification'
37
+ examples_list = [['examples/' + example] for example in os.listdir('examples')]
38
+
39
  demo = gr.Interface(
40
  fn=self.predict,
41
+ inputs=[
42
+ gr.Image(type='filepath', label='Input image to classify'),
43
+ gr.Radio(choices=('Custom', 'Pretrained'), label='Available models')
44
+ ],
45
+ outputs=gr.Label(num_top_classes=3, label='Model predictions'),
46
+ title='Plants Diseases Classification',
47
+ description=f'This model performs classification on images of leaves that are either healthy, \
48
+ have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. \
49
+ The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).',
50
+ examples=examples_list
51
  )
52
  demo.launch()
53
 
custom_transformer/__pycache__/embedding.cpython-310.pyc ADDED
Binary file (3.07 kB). View file
 
custom_transformer/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (4.55 kB). View file
 
custom_transformer/__pycache__/vit.cpython-310.pyc ADDED
Binary file (1.87 kB). View file
 
custom_transformer/embedding.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import math
5
+
6
+ # Use that for fancy colored prints
7
+ from termcolor import colored
8
+
9
+ DEBUG = False
10
+
11
+ class PatchEmbedding(nn.Module):
12
+
13
+ def __init__(self, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
14
+
15
+ super().__init__()
16
+
17
+ # Linear projection:
18
+ self.linear_projection = nn.Conv2d(in_channels=in_channels, out_channels=embedding_dim, kernel_size=patch_size, stride=patch_size)
19
+
20
+ # Flattening:
21
+ self.flatten = nn.Flatten(start_dim=2)
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+
25
+ # Input: [batch_size, in_channels, H, W]
26
+ if DEBUG: print(f'Patch embedding input shape: {x.shape} [batch_size, in_channels, image_height, image_width]')
27
+
28
+ # Linear Projection: [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]
29
+ x = self.linear_projection(x)
30
+ if DEBUG: print(f'Linearly projected input: {x.shape} [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]')
31
+
32
+ # Flattening: [batch_size, embedding_dim, n_patches]
33
+ x = self.flatten(x)
34
+ if DEBUG: print(f'Flattening of last 2 dimensions of linear projection: {x.shape} [batch_size, embedding_dim, n_patches]')
35
+
36
+ # Transpose last 2 dimensions: [batch_size, n_patches, embedding_dim]
37
+ x = x.mT
38
+ if DEBUG: print(f'Transpose last 2 dimensions: {x.shape} [batch_size, n_patches, embedding_dim]')
39
+
40
+ return x
41
+
42
+ class Embedding(nn.Module):
43
+
44
+ def __init__(self, image_size: int = 224, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
45
+
46
+ super().__init__()
47
+
48
+ assert (image_size * image_size) % (patch_size * patch_size) == 0
49
+
50
+ self.n_patches = (image_size * image_size) // (patch_size * patch_size)
51
+ if DEBUG: print(f'Total number of patches: {self.n_patches}, i.e. {int(math.sqrt(self.n_patches))} x {int(math.sqrt(self.n_patches))}')
52
+
53
+ # Patch embedding defined above
54
+ self.patch_embedding = PatchEmbedding(in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)
55
+
56
+ # The class token x0, 1 for each embedding dim
57
+ self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
58
+
59
+ # The positional embedding, `n_patches` many for each embedding dim
60
+ self.position_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, embedding_dim))
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+
64
+ if DEBUG: print(f'Embedding input shape: {x.shape}: [batch_size, in_channels, height, width]')
65
+
66
+ x = self.patch_embedding(x)
67
+ if DEBUG: print(f'Patch embedding output: {x.shape}: [batch_size, n_patches, embedding_dim]')
68
+
69
+ x = torch.cat((self.class_token.expand(len(x), -1, -1), x), dim=1)
70
+ if DEBUG: print(f'Class token prepended: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')
71
+
72
+ x = x + self.position_embedding
73
+ if DEBUG: print(f'Positional embedding added: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')
74
+
75
+ return x
76
+
77
+ if __name__ == '__main__':
78
+ DEBUG = True
79
+ sample_image_batch = torch.rand(5,3,224,224)
80
+ embedding = Embedding()
81
+ out = embedding(sample_image_batch)
82
+ print(out)
custom_transformer/encoder.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
DEBUG = False

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Q, K, V are projected with separate linear layers, the embedding
    dimension is split across `num_heads` heads, attention is applied per
    head, the heads are merged back, and a final output projection is applied.
    """

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:

        super().__init__()

        # Each head gets an equal slice of the embedding dimension.
        if embedding_dim % num_heads != 0:
            raise ValueError('embedding_dim must be divisible by num_heads')

        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads

        # Separate projections for queries, keys, values, and the output.
        self.q_w, self.k_w, self.v_w, self.out_w = (nn.Linear(embedding_dim, embedding_dim) for _ in range(4))

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply multi-head attention.

        Args:
            q, k, v: Tensors of shape [batch_size, n_patches, embedding_dim].

        Returns:
            Tensor of shape [batch_size, n_patches, embedding_dim].
        """

        if DEBUG: print(f'MSA Input shape (Q, K, V): {q.shape}: [batch_size, n_patches, embedding_dim]')

        # Linear projections for Q, K, V
        if DEBUG: print(f'Linear projection for Q, K, V: {q.shape} [batch_size, n_patches, embedding_dim]')
        q = self.q_w(q).view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = self.k_w(k).view(*k.shape[:-1], self.num_heads, self.head_dim)
        # BUG FIX: values were previously projected with self.q_w, leaving
        # self.v_w unused; they must use their own projection.
        v = self.v_w(v).view(*v.shape[:-1], self.num_heads, self.head_dim)
        if DEBUG: print(f'Splitting the last dimension once for each head: {q.shape} [batch_size, n_patches, num_heads, head_dim]')

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if DEBUG: print(f'Swap patches and head to have the head come first: {q.shape} [batch_size, num_heads, n_patches, head_dim]')

        attention_scores = torch.matmul(q, k.mT) / (self.head_dim ** .5)
        if DEBUG: print(f'Compute attention scores for each head (scaled dot product): {attention_scores.shape} [batch_size, num_heads, n_patches, n_patches]')

        attention_weights = torch.softmax(attention_scores, dim=-1)
        # Debug label fixed: the second axis is num_heads, not "num_batches".
        if DEBUG: print(f'Softmax of attention scores: {attention_weights.shape} [batch_size, num_heads, n_patches, n_patches]')

        weighted_sum = torch.matmul(attention_weights, v)
        if DEBUG: print(f'Weighted sum of Values: {weighted_sum.shape} [batch_size, num_heads, n_patches, head_dim]')

        weighted_sum = weighted_sum.transpose(1, 2).contiguous()
        if DEBUG: print(f'Swap again the patches and the heads: {weighted_sum.shape} [batch_size, n_patches, num_heads, head_dim]')

        weighted_sum = weighted_sum.view(*weighted_sum.shape[:-2], -1)
        if DEBUG: print(f'Recover the original dimensions by merging the last 2: {weighted_sum.shape} [batch_size, n_patches, embedding_dim]')

        output = self.out_w(weighted_sum)
        # Debug label fixed: heads are already merged before the output projection.
        if DEBUG: print(f'(Output) Linear projection of the weighted sum: {output.shape} [batch_size, n_patches, embedding_dim]')

        return output
51
+
52
+
53
+ class MSABlock(nn.Module):
54
+
55
+ def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
56
+ super().__init__()
57
+ self.msa = MultiHeadSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads)
58
+ self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
59
+
60
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
61
+ x = self.layer_norm(x)
62
+ return self.msa(x, x, x)
63
+
64
+ class MLPBlock(nn.Module):
65
+
66
+ def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072) -> None:
67
+ super().__init__()
68
+ self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
69
+ self.mlp = nn.Sequential(
70
+ nn.Linear(in_features=embedding_dim, out_features=hidden_size),
71
+ nn.GELU(),
72
+ nn.Linear(in_features=hidden_size, out_features=embedding_dim)
73
+ )
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ return self.mlp(self.layer_norm(x))
77
+
78
+
79
+ class TransformerEncoderBlock(nn.Module):
80
+
81
+ def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072, num_heads: int = 12) -> None:
82
+ super().__init__()
83
+ self.msa = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads)
84
+ self.mlp = MLPBlock(embedding_dim=embedding_dim, hidden_size=hidden_size)
85
+
86
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
87
+ x = self.msa(x) + x
88
+ x = self.mlp(x) + x
89
+ return x
90
+
91
+ if __name__ == '__main__':
92
+
93
+ DEBUG = True
94
+ x = torch.rand(5, 197, 768)
95
+ msa = MultiHeadSelfAttention()
96
+ out = msa(x,x,x)
97
+ print(out.shape)
custom_transformer/vit.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import sys
5
+ sys.path.append('..')
6
+ from custom_transformer.embedding import Embedding
7
+ from custom_transformer.encoder import TransformerEncoderBlock
8
+
9
+ class ViT(nn.Module):
10
+
11
+ def __init__(self,
12
+ image_size: int = 224,
13
+ in_channels: int = 3,
14
+ patch_size: int = 16,
15
+ num_transformer_layers: int = 12,
16
+ embedding_dim: int = 768,
17
+ mlp_size: int = 3072,
18
+ num_heads: int = 12,
19
+ num_classes: int = 3) -> None:
20
+
21
+ super().__init__()
22
+
23
+ self.embedding = Embedding(image_size=image_size, in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)
24
+ self.transformer_encoders = nn.Sequential(
25
+ *[TransformerEncoderBlock(embedding_dim=embedding_dim, hidden_size=mlp_size, num_heads=num_heads)
26
+ for _ in range(num_transformer_layers)]
27
+ )
28
+
29
+ self.classifier = nn.Sequential(
30
+ nn.LayerNorm(normalized_shape=embedding_dim),
31
+ nn.Linear(in_features=embedding_dim, out_features=num_classes)
32
+ )
33
+
34
+ def forward(self, x):
35
+ x = self.embedding(x)
36
+ x = self.transformer_encoders(x)
37
+ x = self.classifier(x[:, 0])
38
+ return x
39
+
40
+ if __name__ == '__main__':
41
+ sample_image_batch = torch.rand(5,3,500,500)
42
+ vit = ViT(image_size=500, patch_size=50)
43
+ print(vit(sample_image_batch).shape)
examples/angular_leaf_spot_example.jpg ADDED
examples/bean_rust_example.jpg ADDED
examples/healthy_example.jpg ADDED
main.py DELETED
@@ -1,4 +0,0 @@
1
- import torch
2
-
3
- a = torch.load('models/pretrained_vit.pth', map_location='cpu')
4
- print(a)