ninjals committed on
Commit
a08abea
·
verified ·
1 Parent(s): f4d351b

Update model.py

Browse files

corrected wrong torchvision import (`from torch import torchvision` → `import torchvision`)

Files changed (1) hide show
  1. model.py +30 -30
model.py CHANGED
@@ -1,31 +1,31 @@
1
import torch
import torchvision  # FIX: was `from torch import torchvision`, which raises ImportError — torchvision is a separate package, not a torch submodule
from torch import nn
from torchvision import transforms


def create_vit_model(num_classes: int = 3,
                     seed: int = 42):
    """Creates a ViT-B/16 feature extractor model and transforms.

    Loads the pretrained torchvision ViT-B/16, freezes every backbone
    parameter, and replaces the classifier head with a fresh trainable
    Linear layer sized for ``num_classes``.

    Args:
        num_classes (int, optional): number of target classes. Defaults to 3.
        seed (int, optional): random seed value for output layer. Defaults to 42.

    Returns:
        model (torch.nn.Module): ViT-B/16 feature extractor model.
        transforms (torchvision.transforms): ViT-B/16 image transforms
            matching the pretrained weights' preprocessing.
    """
    # Create ViT_B_16 pretrained weights, transforms and model
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze all layers in model so only the new head trains
    for param in model.parameters():
        param.requires_grad = False

    # Change classifier head to suit our needs (this will be trainable);
    # seed first so the head's random init is reproducible
    torch.manual_seed(seed)
    model.heads = nn.Sequential(nn.Linear(in_features=768,  # keep this the same as original model
                                          out_features=num_classes))  # update to reflect target number of classes

    return model, transforms
 
1
import torch
import torchvision
from torch import nn
from torchvision import transforms


def create_vit_model(num_classes: int = 3,
                     seed: int = 42):
    """Creates a ViT-B/16 feature extractor model and transforms.

    Loads the pretrained torchvision ViT-B/16, freezes every backbone
    parameter, and replaces the classifier head with a fresh trainable
    Linear layer sized for ``num_classes``.

    Args:
        num_classes (int, optional): number of target classes. Defaults to 3.
        seed (int, optional): random seed value for output layer. Defaults to 42.

    Returns:
        model (torch.nn.Module): ViT-B/16 feature extractor model.
        model_transforms (torchvision.transforms): ViT-B/16 image transforms
            matching the pretrained weights' preprocessing.
    """
    # Create ViT_B_16 pretrained weights, transforms and model.
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    # Named `model_transforms` (not `transforms`) so the module-level
    # `torchvision.transforms` import is not shadowed inside this function.
    model_transforms = weights.transforms()
    model = torchvision.models.vit_b_16(weights=weights)

    # Freeze all layers in model so only the new head trains.
    for param in model.parameters():
        param.requires_grad = False

    # Change classifier head to suit our needs (this will be trainable);
    # seed first so the head's random init is reproducible.
    torch.manual_seed(seed)
    # Read the head width from the loaded model (768 for ViT-B/16)
    # instead of hard-coding the constant.
    in_features = model.heads.head.in_features
    model.heads = nn.Sequential(nn.Linear(in_features=in_features,
                                          out_features=num_classes))

    return model, model_transforms