mmek commited on
Commit
46676c9
·
1 Parent(s): fe49efd

add model transforms

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  import timm
3
  import torch
4
-
5
  model = timm.create_model("mobileone_s2", pretrained = False)
6
  model.head.fc = torch.nn.Linear(model.head.fc.in_features,3)
7
- transforms = timm.data.create_transform(**timm.data.resolve_data_config(model.pretrained_cfg)).transforms
8
  model.load_state_dict(torch.load("olive-classifier.pth", map_location=torch.device('cpu'), weights_only=True))
9
  model.eval()
10
 
 
1
  import gradio as gr
2
  import timm
3
  import torch
4
+ from torchvision import transforms
5
  model = timm.create_model("mobileone_s2", pretrained = False)
6
  model.head.fc = torch.nn.Linear(model.head.fc.in_features,3)
7
+ transforms = transforms.Compose(timm.data.create_transform(**timm.data.resolve_data_config(model.pretrained_cfg)).transforms)
8
  model.load_state_dict(torch.load("olive-classifier.pth", map_location=torch.device('cpu'), weights_only=True))
9
  model.eval()
10