saneshashank commited on
Commit
e305df2
·
verified ·
1 Parent(s): 45b5eaa

updated model class name

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -4,7 +4,7 @@ import torch.nn.functional as F
4
  from PIL import Image
5
  import json
6
  from torchvision import transforms
7
- from model import resnet50
8
 
9
  # Load class labels from local file
10
  with open("imagenet_classes.json", "r") as f:
@@ -12,7 +12,7 @@ with open("imagenet_classes.json", "r") as f:
12
 
13
  # Load model
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
- model = resnet50(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)
16
  model.load_state_dict(torch.load("best_resnet50_imagenet_1k.pt", map_location=device))
17
  model.to(device)
18
  model.eval()
 
4
  from PIL import Image
5
  import json
6
  from torchvision import transforms
7
+ from model import ResNet
8
 
9
  # Load class labels from local file
10
  with open("imagenet_classes.json", "r") as f:
 
12
 
13
  # Load model
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ model = ResNet(num_classes=1000, drop_path_rate=0.0, use_blurpool=True)
16
  model.load_state_dict(torch.load("best_resnet50_imagenet_1k.pt", map_location=device))
17
  model.to(device)
18
  model.eval()