Update app.py
Browse files
app.py
CHANGED
|
@@ -81,6 +81,10 @@ def get_fused_image(img, pred_mask, view, alpha=0.8):
|
|
| 81 |
# Define the inference function
|
| 82 |
def inference(learn, reorder, resample, org_img, input_img, org_size):
|
| 83 |
"""Perform segmentation using the loaded model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
# Perform the segmentation
|
| 85 |
with torch.no_grad():
|
| 86 |
pred = learn.predict(input_img)
|
|
@@ -122,6 +126,10 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
|
|
| 122 |
# Infer org_size from org_img
|
| 123 |
org_size = org_img.shape[1:] # Assuming org_img has a shape attribute
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
mask_data = inference(learn, reorder=reorder, resample=resample,
|
| 126 |
org_img=org_img, input_img=input_img,
|
| 127 |
org_size=org_size)
|
|
|
|
| 81 |
# Define the inference function
|
| 82 |
def inference(learn, reorder, resample, org_img, input_img, org_size):
|
| 83 |
"""Perform segmentation using the loaded model."""
|
| 84 |
+
# Ensure input_img is a torch.Tensor
|
| 85 |
+
if not isinstance(input_img, torch.Tensor):
|
| 86 |
+
raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
|
| 87 |
+
|
| 88 |
# Perform the segmentation
|
| 89 |
with torch.no_grad():
|
| 90 |
pred = learn.predict(input_img)
|
|
|
|
| 126 |
# Infer org_size from org_img
|
| 127 |
org_size = org_img.shape[1:] # Assuming org_img has a shape attribute
|
| 128 |
|
| 129 |
+
# Ensure input_img is a torch.Tensor
|
| 130 |
+
if not isinstance(input_img, torch.Tensor):
|
| 131 |
+
raise ValueError(f"Expected input_img to be a torch.Tensor, but got {type(input_img)}")
|
| 132 |
+
|
| 133 |
mask_data = inference(learn, reorder=reorder, resample=resample,
|
| 134 |
org_img=org_img, input_img=input_img,
|
| 135 |
org_size=org_size)
|