i4ata committed on
Commit
563e5fb
·
1 Parent(s): 0548088

small change

Browse files
Files changed (3) hide show
  1. app.py +2 -4
  2. transforms.py +0 -12
  3. utils.py +29 -0
app.py CHANGED
@@ -1,13 +1,12 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision import models
4
  import gradio as gr
5
  from PIL import Image
6
  import os
7
  from typing import List, Dict, Union
8
 
9
  from custom_transformer.vit import ViT
10
- from transforms import model_transforms
11
 
12
  class GradioApp:
13
 
@@ -18,8 +17,7 @@ class GradioApp:
18
  custom = ViT().to(device).eval()
19
  custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
20
 
21
- pretrained = models.vit_b_16().to(device).eval()
22
- pretrained.heads = nn.Linear(768, 3)
23
  pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
24
 
25
  self.models: Dict[str, Union[str, nn.Module]] = {
 
1
  import torch
2
  import torch.nn as nn
 
3
  import gradio as gr
4
  from PIL import Image
5
  import os
6
  from typing import List, Dict, Union
7
 
8
  from custom_transformer.vit import ViT
9
+ from utils import model_transforms, get_pretrained_vit
10
 
11
  class GradioApp:
12
 
 
17
  custom = ViT().to(device).eval()
18
  custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))
19
 
20
+ pretrained = get_pretrained_vit().to(device).eval()
 
21
  pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=device))
22
 
23
  self.models: Dict[str, Union[str, nn.Module]] = {
transforms.py DELETED
@@ -1,12 +0,0 @@
1
- from torchvision import transforms, models
2
- from typing import Literal, Dict
3
-
4
- _weights = models.ViT_B_16_Weights.DEFAULT
5
-
6
- model_transforms: Dict[Literal['Custom', 'Pretrained'], transforms.Compose] = {
7
- 'Custom': transforms.Compose([
8
- transforms.Resize((224, 224)),
9
- transforms.ToTensor()
10
- ]),
11
- 'Pretrained': _weights.transforms()
12
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch.nn as nn
from torchvision import transforms, models
from typing import Dict, Literal

# Bundled ImageNet weights for torchvision's ViT-B/16; they also ship the
# exact preprocessing pipeline the pretrained model expects.
_weights = models.ViT_B_16_Weights.DEFAULT

# ViT-B/16 consumes fixed 224x224 inputs; Resize is stateless, so one
# instance can be shared safely between pipelines.
_resize = transforms.Resize((224, 224))

# Per-model, per-split preprocessing pipelines.
# 'custom'    : hand-built pipeline; 'train' additionally applies
#               TrivialAugmentWide for data augmentation.
# 'pretrained': the canonical preprocessing bundled with the weights,
#               identical for both splits.
model_transforms: Dict[Literal['custom', 'pretrained'], Dict[Literal['train', 'val'], transforms.Compose]] = {
    'custom': {
        'train': transforms.Compose([
            _resize,
            transforms.TrivialAugmentWide(),
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            _resize,
            transforms.ToTensor(),
        ]),
    },
    'pretrained': {
        # A fresh transforms() object per split, mirroring the original layout.
        split: _weights.transforms() for split in ('train', 'val')
    },
}
24
+
25
def get_pretrained_vit(num_classes: int = 3) -> models.VisionTransformer:
    """Build a frozen ViT-B/16 feature extractor with a fresh classification head.

    Loads torchvision's ImageNet-pretrained ViT-B/16, freezes every backbone
    parameter, then swaps the classification head for a new trainable
    ``nn.Linear`` emitting ``num_classes`` logits.

    Args:
        num_classes: Output dimension of the new head. Defaults to 3 to match
            this project's saved checkpoints.

    Returns:
        The modified ``models.VisionTransformer`` (backbone frozen, head trainable).
    """
    model = models.vit_b_16(weights='DEFAULT')
    # Freeze the pretrained backbone so only the new head receives gradients.
    for parameter in model.parameters():
        parameter.requires_grad = False
    # Replace the head AFTER freezing so its parameters stay trainable.
    # 768 is ViT-B/16's encoder output dimension.
    model.heads = nn.Linear(in_features=768, out_features=num_classes)
    return model