done, maybe
Browse files- __pycache__/model.cpython-310.pyc +0 -0
- app.py +18 -3
- custom_transformer/__pycache__/embedding.cpython-310.pyc +0 -0
- custom_transformer/__pycache__/encoder.cpython-310.pyc +0 -0
- custom_transformer/__pycache__/vit.cpython-310.pyc +0 -0
- custom_transformer/embedding.py +82 -0
- custom_transformer/encoder.py +97 -0
- custom_transformer/vit.py +43 -0
- examples/angular_leaf_spot_example.jpg +0 -0
- examples/bean_rust_example.jpg +0 -0
- examples/healthy_example.jpg +0 -0
- main.py +0 -4
__pycache__/model.cpython-310.pyc
CHANGED
|
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
|
|
|
app.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
-
|
| 4 |
|
| 5 |
import torch
|
| 6 |
|
| 7 |
from model import ClassifierModel
|
| 8 |
|
|
|
|
|
|
|
| 9 |
class GradioApp:
|
| 10 |
|
| 11 |
def __init__(self) -> None:
|
|
@@ -19,6 +21,7 @@ class GradioApp:
|
|
| 19 |
|
| 20 |
def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
|
| 21 |
|
|
|
|
| 22 |
if isinstance(self.models[model_name], str):
|
| 23 |
self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
|
| 24 |
|
|
@@ -29,10 +32,22 @@ class GradioApp:
|
|
| 29 |
|
| 30 |
def launch(self):
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
demo = gr.Interface(
|
| 33 |
fn=self.predict,
|
| 34 |
-
inputs=[
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
)
|
| 37 |
demo.launch()
|
| 38 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
+
import os
|
| 4 |
|
| 5 |
import torch
|
| 6 |
|
| 7 |
from model import ClassifierModel
|
| 8 |
|
| 9 |
+
from typing import List, Dict, Union
|
| 10 |
+
|
| 11 |
class GradioApp:
|
| 12 |
|
| 13 |
def __init__(self) -> None:
|
|
|
|
| 21 |
|
| 22 |
def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
|
| 23 |
|
| 24 |
+
# Lazy loading of models
|
| 25 |
if isinstance(self.models[model_name], str):
|
| 26 |
self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
|
| 27 |
|
|
|
|
| 32 |
|
| 33 |
def launch(self):
    """Build and start the Gradio demo for the leaf-disease classifier.

    Wires ``self.predict`` to an image-upload + model-choice input, a
    top-3 label output, and the example images bundled under ``examples/``.
    """
    dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
    github_repo_url = 'https://github.com/i4ata/TransformerClassification'
    # One example row per file in examples/ — assumes the directory exists
    # relative to the current working directory; TODO confirm when deployed
    examples_list = [['examples/' + example] for example in os.listdir('examples')]

    demo = gr.Interface(
        fn=self.predict,
        inputs=[
            gr.Image(type='filepath', label='Input image to classify'),
            gr.Radio(choices=('Custom', 'Pretrained'), label='Available models')
        ],
        outputs=gr.Label(num_top_classes=3, label='Model predictions'),
        title='Plants Diseases Classification',
        description=f'This model performs classification on images of leaves that are either healthy, \
have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. \
The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).',
        examples=examples_list
    )
    demo.launch()
|
| 53 |
|
custom_transformer/__pycache__/embedding.cpython-310.pyc
ADDED
|
Binary file (3.07 kB). View file
|
|
|
custom_transformer/__pycache__/encoder.cpython-310.pyc
ADDED
|
Binary file (4.55 kB). View file
|
|
|
custom_transformer/__pycache__/vit.cpython-310.pyc
ADDED
|
Binary file (1.87 kB). View file
|
|
|
custom_transformer/embedding.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
# Use that for fancy colored prints
|
| 7 |
+
from termcolor import colored
|
| 8 |
+
|
| 9 |
+
DEBUG = False

class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of linearly-embedded patches.

    A single Conv2d whose kernel size equals its stride performs the
    per-patch linear projection in one pass over the image.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        super().__init__()

        # Strided convolution == linear projection of each non-overlapping patch
        self.linear_projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

        # Merge the two spatial axes into a single "patch" axis
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [batch, channels, H, W] to [batch, n_patches, embedding_dim]."""
        if DEBUG: print(f'Patch embedding input shape: {x.shape} [batch_size, in_channels, image_height, image_width]')

        # Project every patch at once with the strided convolution
        x = self.linear_projection(x)
        if DEBUG: print(f'Linearly projected input: {x.shape} [batch_size, embedding_dim, sqrt(n_patches), sqrt(n_patches)]')

        # Collapse the spatial grid into one patch dimension
        x = self.flatten(x)
        if DEBUG: print(f'Flattening of last 2 dimensions of linear projection: {x.shape} [batch_size, embedding_dim, n_patches]')

        # Put patches before embedding channels
        x = x.mT
        if DEBUG: print(f'Transpose last 2 dimensions: {x.shape} [batch_size, n_patches, embedding_dim]')

        return x
|
| 41 |
+
|
| 42 |
+
class Embedding(nn.Module):
    """Full ViT input embedding: patches + class token + positional embedding."""

    def __init__(self, image_size: int = 224, in_channels: int = 3, embedding_dim: int = 768, patch_size: int = 16) -> None:
        super().__init__()

        # The image area must be an exact multiple of the patch area
        assert (image_size * image_size) % (patch_size * patch_size) == 0

        self.n_patches = (image_size * image_size) // (patch_size * patch_size)
        if DEBUG: print(f'Total number of patches: {self.n_patches}, i.e. {int(math.sqrt(self.n_patches))} x {int(math.sqrt(self.n_patches))}')

        # Per-patch linear projection (defined above)
        self.patch_embedding = PatchEmbedding(in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)

        # Learnable class token x0, broadcast over the batch in forward()
        self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

        # Learnable positional embedding, one slot per patch plus the class token
        self.position_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, embedding_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [batch, channels, H, W] to [batch, n_patches + 1, embedding_dim]."""
        if DEBUG: print(f'Embedding input shape: {x.shape}: [batch_size, in_channels, height, width]')

        x = self.patch_embedding(x)
        if DEBUG: print(f'Patch embedding output: {x.shape}: [batch_size, n_patches, embedding_dim]')

        # Prepend one class token per batch element
        batch_token = self.class_token.expand(len(x), -1, -1)
        x = torch.cat((batch_token, x), dim=1)
        if DEBUG: print(f'Class token prepended: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        x = x + self.position_embedding
        if DEBUG: print(f'Positional embedding added: {x.shape}: [batch_size, n_patches + 1, embedding_dim]')

        return x
|
| 76 |
+
|
| 77 |
+
if __name__ == '__main__':
    # Smoke test: push a random image batch through the embedding with debug prints on
    DEBUG = True
    batch = torch.rand(5, 3, 224, 224)
    module = Embedding()
    result = module(batch)
    print(result)
|
custom_transformer/encoder.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
DEBUG = False

class MultiHeadSelfAttention(nn.Module):
    """Multi-head (self-)attention over a patch sequence.

    The embedding dimension is split across `num_heads` heads; each head
    performs scaled dot-product attention independently, and the head
    outputs are concatenated and linearly projected back to `embedding_dim`.
    """

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super().__init__()

        self.num_heads = num_heads
        # Per-head width; assumes embedding_dim is divisible by num_heads
        self.head_dim = embedding_dim // num_heads

        # Separate linear projections for queries, keys, values, and output
        self.q_w, self.k_w, self.v_w, self.out_w = (nn.Linear(embedding_dim, embedding_dim) for _ in range(4))

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Attend; q, k, v: [batch_size, n_patches, embedding_dim] -> same shape."""
        if DEBUG: print(f'MSA Input shape (Q, K, V): {q.shape}: [batch_size, n_patches, embedding_dim]')

        # Linear projections for Q, K, V, splitting the last dim into heads
        if DEBUG: print(f'Linear projection for Q, K, V: {q.shape} [batch_size, n_patches, embedding_dim]')
        q = self.q_w(q).view(*q.shape[:-1], self.num_heads, self.head_dim)
        k = self.k_w(k).view(*k.shape[:-1], self.num_heads, self.head_dim)
        # BUG FIX: values were projected with self.q_w — they must use self.v_w
        v = self.v_w(v).view(*v.shape[:-1], self.num_heads, self.head_dim)
        if DEBUG: print(f'Splitting the last dimension once for each head: {q.shape} [batch_size, n_patches, num_heads, head_dim]')

        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        if DEBUG: print(f'Swap patches and head to have the head come first: {q.shape} [batch_size, num_heads, n_patches, head_dim]')

        # Scaled dot-product attention scores
        attention_scores = torch.matmul(q, k.mT) / (self.head_dim ** .5)
        if DEBUG: print(f'Compute attention scores for each head (scaled dot product): {attention_scores.shape} [batch_size, num_heads, n_patches, n_patches]')

        attention_weights = torch.softmax(attention_scores, dim=-1)
        if DEBUG: print(f'Softmax of attention scores: {attention_weights.shape} [batch_size, num_heads, n_patches, n_patches]')

        weighted_sum = torch.matmul(attention_weights, v)
        if DEBUG: print(f'Weighted sum of Values: {weighted_sum.shape} [batch_size, num_heads, n_patches, head_dim]')

        weighted_sum = weighted_sum.transpose(1, 2).contiguous()
        if DEBUG: print(f'Swap again the patches and the heads: {weighted_sum.shape} [batch_size, n_patches, num_heads, head_dim]')

        weighted_sum = weighted_sum.view(*weighted_sum.shape[:-2], -1)
        if DEBUG: print(f'Recover the original dimensions by merging the last 2: {weighted_sum.shape} [batch_size, n_patches, embedding_dim]')

        output = self.out_w(weighted_sum)
        if DEBUG: print(f'(Output) Linear projection of the weighted sum: {output.shape} [batch_size, n_patches, embedding_dim]')

        return output
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class MSABlock(nn.Module):
    """Pre-norm self-attention block: LayerNorm, then multi-head attention."""

    def __init__(self, embedding_dim: int = 768, num_heads: int = 12) -> None:
        super().__init__()
        self.msa = MultiHeadSelfAttention(embedding_dim=embedding_dim, num_heads=num_heads)
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, then self-attend (query = key = value)
        normalized = self.layer_norm(x)
        return self.msa(normalized, normalized, normalized)
|
| 63 |
+
|
| 64 |
+
class MLPBlock(nn.Module):
    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear."""

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072) -> None:
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
        # Expand to hidden_size, apply GELU, project back to embedding_dim
        self.mlp = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_size),
            nn.GELU(),
            nn.Linear(in_features=hidden_size, out_features=embedding_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normalized = self.layer_norm(x)
        return self.mlp(normalized)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TransformerEncoderBlock(nn.Module):
    """One ViT encoder layer: residual MSA block followed by a residual MLP block."""

    def __init__(self, embedding_dim: int = 768, hidden_size: int = 3072, num_heads: int = 12) -> None:
        super().__init__()
        self.msa = MSABlock(embedding_dim=embedding_dim, num_heads=num_heads)
        self.mlp = MLPBlock(embedding_dim=embedding_dim, hidden_size=hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual (skip) connections around both sub-blocks
        x = x + self.msa(x)
        x = x + self.mlp(x)
        return x
|
| 90 |
+
|
| 91 |
+
if __name__ == '__main__':

    # Smoke test: run a random token batch through the attention module
    # with debug prints switched on
    DEBUG = True
    sample = torch.rand(5, 197, 768)
    attention = MultiHeadSelfAttention()
    result = attention(sample, sample, sample)
    print(result.shape)
|
custom_transformer/vit.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('..')
|
| 6 |
+
from custom_transformer.embedding import Embedding
|
| 7 |
+
from custom_transformer.encoder import TransformerEncoderBlock
|
| 8 |
+
|
| 9 |
+
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch + positional embedding -> stack of transformer encoder
    blocks -> classification head applied to the class token.
    """

    def __init__(self,
                 image_size: int = 224,
                 in_channels: int = 3,
                 patch_size: int = 16,
                 num_transformer_layers: int = 12,
                 embedding_dim: int = 768,
                 mlp_size: int = 3072,
                 num_heads: int = 12,
                 num_classes: int = 3) -> None:

        super().__init__()

        # Patch + class-token + positional embedding
        self.embedding = Embedding(image_size=image_size, in_channels=in_channels, embedding_dim=embedding_dim, patch_size=patch_size)

        # Stack of identical encoder layers applied in sequence
        encoder_layers = [
            TransformerEncoderBlock(embedding_dim=embedding_dim, hidden_size=mlp_size, num_heads=num_heads)
            for _ in range(num_transformer_layers)
        ]
        self.transformer_encoders = nn.Sequential(*encoder_layers)

        # Classification head on top of the class token
        self.classifier = nn.Sequential(
            nn.LayerNorm(normalized_shape=embedding_dim),
            nn.Linear(in_features=embedding_dim, out_features=num_classes),
        )

    def forward(self, x):
        tokens = self.embedding(x)
        encoded = self.transformer_encoders(tokens)
        # Only the class token (sequence position 0) feeds the classifier
        return self.classifier(encoded[:, 0])
|
| 39 |
+
|
| 40 |
+
if __name__ == '__main__':
    # Smoke test: classify a random batch of 500x500 images
    batch = torch.rand(5, 3, 500, 500)
    model = ViT(image_size=500, patch_size=50)
    print(model(batch).shape)
|
examples/angular_leaf_spot_example.jpg
ADDED
|
examples/bean_rust_example.jpg
ADDED
|
examples/healthy_example.jpg
ADDED
|
main.py
DELETED
|
@@ -1,4 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
|
| 3 |
-
a = torch.load('models/pretrained_vit.pth', map_location='cpu')
|
| 4 |
-
print(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|