small change

Browse files:
- app.py: +2 −4
- transforms.py: +0 −12 (deleted)
- utils.py: +29 −0 (added)
app.py
CHANGED
|
@@ -1,13 +1,12 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
-
from torchvision import models
|
| 4 |
import gradio as gr
|
| 5 |
from PIL import Image
|
| 6 |
import os
|
| 7 |
from typing import List, Dict, Union
|
| 8 |
|
| 9 |
from custom_transformer.vit import ViT
|
| 10 |
-
from
|
| 11 |
|
| 12 |
class GradioApp:
|
| 13 |
|
|
@@ -18,8 +17,7 @@ class GradioApp:
|
|
| 18 |
custom = ViT().to(device).eval()
|
| 19 |
custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
|
| 20 |
|
| 21 |
-
pretrained =
|
| 22 |
-
pretrained.heads = nn.Linear(768, 3)
|
| 23 |
pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
|
| 24 |
|
| 25 |
self.models: Dict[str, Union[str, nn.Module]] = {
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
from PIL import Image
|
| 5 |
import os
|
| 6 |
from typing import List, Dict, Union
|
| 7 |
|
| 8 |
from custom_transformer.vit import ViT
|
| 9 |
+
from utils import model_transforms, get_pretrained_vit
|
| 10 |
|
| 11 |
class GradioApp:
|
| 12 |
|
|
|
|
| 17 |
custom = ViT().to(device).eval()
|
| 18 |
custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
|
| 19 |
|
| 20 |
+
pretrained = get_pretrained_vit().to(device).eval()
|
|
|
|
| 21 |
pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
|
| 22 |
|
| 23 |
self.models: Dict[str, Union[str, nn.Module]] = {
|
transforms.py
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
from torchvision import transforms, models
|
| 2 |
-
from typing import Literal, Dict
|
| 3 |
-
|
| 4 |
-
_weights = models.ViT_B_16_Weights.DEFAULT
|
| 5 |
-
|
| 6 |
-
model_transforms: Dict[Literal['Custom', 'Pretrained'], transforms.Compose] = {
|
| 7 |
-
'Custom': transforms.Compose([
|
| 8 |
-
transforms.Resize((224, 224)),
|
| 9 |
-
transforms.ToTensor()
|
| 10 |
-
]),
|
| 11 |
-
'Pretrained': _weights.transforms()
|
| 12 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
from torchvision import transforms, models
|
| 3 |
+
from typing import Literal, Dict
|
| 4 |
+
|
# Default weights for torchvision's ViT-B/16; also supplies the exact
# preprocessing pipeline the pretrained model was trained with.
_weights = models.ViT_B_16_Weights.DEFAULT

# Image preprocessing pipelines, keyed first by model kind
# ('custom' vs 'pretrained') and then by split ('train' vs 'val').
# The custom model uses a plain 224x224 resize; its training pipeline
# additionally applies TrivialAugmentWide for data augmentation.
# The pretrained model reuses the transforms bundled with its weights
# for both splits, so inference matches its original training setup.
model_transforms: Dict[Literal['custom', 'pretrained'], Dict[Literal['train', 'val'], transforms.Compose]] = {
    'custom': {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.TrivialAugmentWide(),
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]),
    },
    'pretrained': {
        'train': _weights.transforms(),
        'val': _weights.transforms(),
    },
}
def get_pretrained_vit(num_classes: int = 3) -> models.VisionTransformer:
    """Build a frozen pretrained ViT-B/16 with a fresh classification head.

    The backbone is loaded with its default pretrained weights and fully
    frozen (feature extraction), and the classification head is replaced
    with a new trainable linear layer.

    Args:
        num_classes: Number of output classes for the new head. Defaults
            to 3 to match the checkpoints this project trains/loads.

    Returns:
        A ``torchvision.models.VisionTransformer`` ready for fine-tuning
        of the head only (all backbone parameters have
        ``requires_grad=False``).
    """
    # Use the module-level _weights constant so the loaded weights stay
    # consistent with the preprocessing in `model_transforms`.
    model = models.vit_b_16(weights=_weights)
    # Freeze the backbone so only the new head receives gradients.
    for parameter in model.parameters():
        parameter.requires_grad = False
    # 768 is the hidden dimension of ViT-B/16's final representation.
    model.heads = nn.Linear(in_features=768, out_features=num_classes)
    return model