Work commited on
Commit
88bddca
·
1 Parent(s): 70d8d45

fixed iamge formatting bug

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -9,6 +9,7 @@ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cp
9
 
10
  model = Autoencoder()
11
  model.load_state_dict(torch.load('model.pt', map_location=device))
 
12
 
13
  resize = Resize((224))
14
  to_tensor = ToTensor()
@@ -18,7 +19,9 @@ transforms = [to_tensor, resize]
18
  def test(image):
19
  for transform in transforms:
20
  image = transform(image)
21
- return [image]
 
 
22
 
23
 
24
  interface = gr.Interface(
@@ -26,15 +29,11 @@ interface = gr.Interface(
26
  description = "Select a image",
27
  allow_flagging="never",
28
  fn = test,
29
- inputs = [
30
- gr.Image(label = "x", shape = [224, 224]),
31
- ],
32
- outputs = [
33
- gr.Image(label = "pred"),
34
- ],
35
  examples = [
36
  ["img.jpg"],
37
  ]
38
  )
39
 
40
- interface.launch(share = False)
 
9
 
10
  model = Autoencoder()
11
  model.load_state_dict(torch.load('model.pt', map_location=device))
12
+ model = model.eval()
13
 
14
  resize = Resize((224))
15
  to_tensor = ToTensor()
 
19
  def test(image):
20
  for transform in transforms:
21
  image = transform(image)
22
+ image = image.unsqueeze(0)
23
+ image = model(image).squeeze(0).permute(1,2,0).cpu().detach().numpy()
24
+ return image
25
 
26
 
27
  interface = gr.Interface(
 
29
  description = "Select a image",
30
  allow_flagging="never",
31
  fn = test,
32
+ inputs = gr.Image(label = "x", type='numpy'),
33
+ outputs = gr.Image(label = "pred"),
 
 
 
 
34
  examples = [
35
  ["img.jpg"],
36
  ]
37
  )
38
 
39
+ interface.launch()