mmek commited on
Commit
9b544a2
·
1 Parent(s): e64e835

fix torch import

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
  import timm
3
- from torch import nn
4
 
5
  model = timm.create_model("mobileone_s2", pretrained = False)
6
- model.head.fc = nn.Linear(model.head.fc.in_features,3)
7
  model.load_state_dict(torch.load("olive_classifier.pth", weights_only=True))
8
  model.eval()
9
 
 
1
  import gradio as gr
2
  import timm
3
+ import torch
4
 
5
  model = timm.create_model("mobileone_s2", pretrained = False)
6
+ model.head.fc = torch.nn.Linear(model.head.fc.in_features,3)
7
  model.load_state_dict(torch.load("olive_classifier.pth", weights_only=True))
8
  model.eval()
9