drankush-ai commited on
Commit
723b107
·
verified ·
1 Parent(s): d8db3a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -81,6 +81,10 @@ def get_fused_image(img, pred_mask, view, alpha=0.8):
81
  # Define the inference function
82
  def inference(learn, reorder, resample, org_img, input_img, org_size):
83
  """Perform segmentation using the loaded model."""
 
 
 
 
84
  # Perform the segmentation
85
  with torch.no_grad():
86
  pred = learn.predict(input_img)
@@ -122,6 +126,10 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
122
  # Infer org_size from org_img
123
  org_size = org_img.shape[1:] # Assuming org_img has a shape attribute
124
 
 
 
 
 
125
  mask_data = inference(learn, reorder=reorder, resample=resample,
126
  org_img=org_img, input_img=input_img,
127
  org_size=org_size)
 
81
  # Define the inference function
82
  def inference(learn, reorder, resample, org_img, input_img, org_size):
83
  """Perform segmentation using the loaded model."""
84
+ # Ensure input_img is a torch.Tensor
85
+ if not isinstance(input_img, torch.Tensor):
86
+ raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
87
+
88
  # Perform the segmentation
89
  with torch.no_grad():
90
  pred = learn.predict(input_img)
 
126
  # Infer org_size from org_img
127
  org_size = org_img.shape[1:] # Assuming org_img has a shape attribute
128
 
129
+ # Ensure input_img is a torch.Tensor
130
+ if not isinstance(input_img, torch.Tensor):
131
+ raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
132
+
133
  mask_data = inference(learn, reorder=reorder, resample=resample,
134
  org_img=org_img, input_img=input_img,
135
  org_size=org_size)