ayush2003 commited on
Commit
013f5dd
·
1 Parent(s): e9c0736

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -508,8 +508,9 @@ model = SimpleCNN()
508
  model.load_state_dict(torch.load("model.pth"))
509
  model.eval()
510
 
511
- def predict_pose(image):
512
- img = cv2.resize(image, (32,32))
 
513
  convert_tensor = transforms.ToTensor()
514
  tensor_img = convert_tensor(img)
515
  tensor_img = tensor_img[None,:,:,:]
@@ -522,11 +523,11 @@ def predict_pose(image):
522
 
523
  # predict_pose(test_image)
524
  input_image = [
525
- gr.components.Image(type = 'pil'),
526
  ]
527
 
528
  output_image = [
529
- gr.components.Image(type = 'numpy'),
530
  ]
531
  pose_detector = gr.Interface(fn = predict_pose, inputs = input_image , outputs = output_image )
532
 
 
508
  model.load_state_dict(torch.load("model.pth"))
509
  model.eval()
510
 
511
+ def predict_pose(path):
512
+ img = cv2.imread(str(path))
513
+ img= cv2.resize(img, (32,32))
514
  convert_tensor = transforms.ToTensor()
515
  tensor_img = convert_tensor(img)
516
  tensor_img = tensor_img[None,:,:,:]
 
523
 
524
  # predict_pose(test_image)
525
  input_image = [
526
+ gr.components.Image(type = "filepath"),
527
  ]
528
 
529
  output_image = [
530
+ gr.components.Image(type = "numpy"),
531
  ]
532
  pose_detector = gr.Interface(fn = predict_pose, inputs = input_image , outputs = output_image )
533