drankush-ai commited on
Commit
d8db3a4
·
verified ·
1 Parent(s): ef80d68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -78,6 +78,18 @@ def get_fused_image(img, pred_mask, view, alpha=0.8):
78
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
  return rotated
80
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # Function for Gradio image segmentation
82
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
83
  """Predict function using the learner and other resources."""
@@ -112,7 +124,7 @@ def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)
112
 
113
  mask_data = inference(learn, reorder=reorder, resample=resample,
114
  org_img=org_img, input_img=input_img,
115
- org_size=org_size).data
116
 
117
  if "".join(org_img.orientation) == "LSA":
118
  mask_data = mask_data.permute(0,1,3,2)
 
78
  rotated = cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
79
  return rotated
80
 
81
+ # Define the inference function
82
+ def inference(learn, reorder, resample, org_img, input_img, org_size):
83
+ """Perform segmentation using the loaded model."""
84
+ # Perform the segmentation
85
+ with torch.no_grad():
86
+ pred = learn.predict(input_img)
87
+
88
+ # Process the prediction if necessary
89
+ mask_data = pred[0] # Assuming the first element of the prediction is the mask
90
+
91
+ return mask_data
92
+
93
  # Function for Gradio image segmentation
94
  def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
95
  """Predict function using the learner and other resources."""
 
124
 
125
  mask_data = inference(learn, reorder=reorder, resample=resample,
126
  org_img=org_img, input_img=input_img,
127
+ org_size=org_size)
128
 
129
  if "".join(org_img.orientation) == "LSA":
130
  mask_data = mask_data.permute(0,1,3,2)