Update app.py
Browse files
app.py
CHANGED
|
@@ -78,6 +78,18 @@ def get_fused_image(img, pred_mask, view, alpha=0.8):
|
|
| 78 |
rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
|
| 79 |
return rotated
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# Function for Gradio image segmentation
|
| 82 |
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
|
| 83 |
"""Predict function using the learner and other resources."""
|
|
@@ -112,7 +124,7 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
|
|
| 112 |
|
| 113 |
mask_data = inference(learn, reorder=reorder, resample=resample,
|
| 114 |
org_img=org_img, input_img=input_img,
|
| 115 |
-
org_size=org_size)
|
| 116 |
|
| 117 |
if "".join(org_img.orientation) == "LSA":
|
| 118 |
mask_data = mask_data.permute(0,1,3,2)
|
|
|
|
| 78 |
rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
|
| 79 |
return rotated
|
| 80 |
|
| 81 |
+
# Define the inference function
|
| 82 |
+
def inference(learn, reorder, resample, org_img, input_img, org_size):
|
| 83 |
+
"""Perform segmentation using the loaded model."""
|
| 84 |
+
# Perform the segmentation
|
| 85 |
+
with torch.no_grad():
|
| 86 |
+
pred = learn.predict(input_img)
|
| 87 |
+
|
| 88 |
+
# Process the prediction if necessary
|
| 89 |
+
mask_data = pred[0] # Assuming the first element of the prediction is the mask
|
| 90 |
+
|
| 91 |
+
return mask_data
|
| 92 |
+
|
| 93 |
# Function for Gradio image segmentation
|
| 94 |
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
|
| 95 |
"""Predict function using the learner and other resources."""
|
|
|
|
| 124 |
|
| 125 |
mask_data = inference(learn, reorder=reorder, resample=resample,
|
| 126 |
org_img=org_img, input_img=input_img,
|
| 127 |
+
org_size=org_size)
|
| 128 |
|
| 129 |
if "".join(org_img.orientation) == "LSA":
|
| 130 |
mask_data = mask_data.permute(0,1,3,2)
|