Shivdutta commited on
Commit
9c84d1a
·
verified ·
1 Parent(s): 17771de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -11
app.py CHANGED
@@ -12,26 +12,20 @@ import re
12
  import matplotlib.pyplot as plt
13
  from io import BytesIO
14
 
15
- # inv_normalize = transforms.Normalize(
16
- # mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
17
- # std=[1/0.23, 1/0.23, 1/0.23]
18
- # )
19
-
20
- inv_normalize = transforms.Normalize(
21
- mean=[0.49139968, 0.48215827 ,0.44653124],
22
- std=[0.24703233, 0.24348505, 0.26158768]
23
- )
24
 
25
  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
26
 
27
  model = LITResNet(classes)
28
- model.load_state_dict(torch.load("model.pth", map_location=torch.device('cpu')), strict=False)
 
29
  modellayers = list(dict(model.named_modules()))
30
 
31
  def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3, num_misclassified_images=3):
32
  input_img = np.array(Image.fromarray(np.array(input_img)).resize((32, 32)))
33
  org_img = input_img
34
- transform = transforms.ToTensor()
35
  input_img = transform(input_img).unsqueeze(0)
36
  outputs = model(input_img)
37
  softmax = torch.nn.Softmax(dim=0)
 
12
  import matplotlib.pyplot as plt
13
  from io import BytesIO
14
 
15
+ transform = transforms.Compose([
16
+ transforms.ToTensor(),
17
+ transforms.Normalize((0.49139968, 0.48215827, 0.44653124), (0.24703233, 0.24348505, 0.26158768))])
 
 
 
 
 
 
18
 
19
  classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
20
 
21
  model = LITResNet(classes)
22
+ model.load_state_dict(torch.load("model.pth")["state_dict"])
23
+ model.eval()
24
  modellayers = list(dict(model.named_modules()))
25
 
26
  def inference(input_img, num_gradcam_images=1, target_layer_number=-1, transparency=0.5, show_misclassified=False, num_top_classes=3, num_misclassified_images=3):
27
  input_img = np.array(Image.fromarray(np.array(input_img)).resize((32, 32)))
28
  org_img = input_img
 
29
  input_img = transform(input_img).unsqueeze(0)
30
  outputs = model(input_img)
31
  softmax = torch.nn.Softmax(dim=0)