Youseff1987 commited on
Commit
74b427f
·
verified ·
1 Parent(s): 7b13f44

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +12 -0
model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ import torch.nn as nn
3
+ from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights
4
+
5
+ def load_model(device='cpu'):
6
+ weights = EfficientNet_B7_Weights.DEFAULT
7
+ model = efficientnet_b7(weights=weights)
8
+ model.classifier[1] = nn.Linear(2560, 15)
9
+ return model.to(device)
10
+
11
+ def get_preprocess():
12
+ return EfficientNet_B7_Weights.DEFAULT.transforms()