Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,7 @@ numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
|
|
| 25 |
species: List[str] = df["Species"].unique().tolist()
|
| 26 |
species.sort()
|
| 27 |
|
|
|
|
| 28 |
app_ui = ui.page_fillable(
|
| 29 |
shinyswatch.theme.minty(),
|
| 30 |
ui.layout_sidebar(
|
|
@@ -62,6 +63,7 @@ app_ui = ui.page_fillable(
|
|
| 62 |
),
|
| 63 |
)
|
| 64 |
|
|
|
|
| 65 |
def tif_bytes_to_pil_image(tif_bytes):
|
| 66 |
# Create a BytesIO object from the TIFF bytes
|
| 67 |
bytes_io = io.BytesIO(tif_bytes)
|
|
@@ -89,6 +91,7 @@ def load_model():
|
|
| 89 |
|
| 90 |
return model, processor, device
|
| 91 |
|
|
|
|
| 92 |
def server(input: Inputs, output: Outputs, session: Session):
|
| 93 |
|
| 94 |
# set model, processor, device once
|
|
@@ -163,11 +166,14 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
| 163 |
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
| 164 |
|
| 165 |
""" Get Image """
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
""" Prepare Inputs """
|
| 173 |
# get input points prompt (grid of points)
|
|
@@ -176,10 +182,64 @@ def server(input: Inputs, output: Outputs, session: Session):
|
|
| 176 |
# prepare image and prompt for the model
|
| 177 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
| 178 |
|
| 179 |
-
# remove batch dimension which the processor adds by default
|
| 180 |
-
inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
""" Get Predictions """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
# Evaluate the image with the model
|
| 184 |
# Example: predictions = model.predict(image_array)
|
| 185 |
|
|
|
|
| 25 |
species: List[str] = df["Species"].unique().tolist()
|
| 26 |
species.sort()
|
| 27 |
|
| 28 |
+
### UI ###
|
| 29 |
app_ui = ui.page_fillable(
|
| 30 |
shinyswatch.theme.minty(),
|
| 31 |
ui.layout_sidebar(
|
|
|
|
| 63 |
),
|
| 64 |
)
|
| 65 |
|
| 66 |
+
### HELPER FUNCTIONS ###
|
| 67 |
def tif_bytes_to_pil_image(tif_bytes):
|
| 68 |
# Create a BytesIO object from the TIFF bytes
|
| 69 |
bytes_io = io.BytesIO(tif_bytes)
|
|
|
|
| 91 |
|
| 92 |
return model, processor, device
|
| 93 |
|
| 94 |
+
### SERVER ###
|
| 95 |
def server(input: Inputs, output: Outputs, session: Session):
|
| 96 |
|
| 97 |
# set model, processor, device once
|
|
|
|
| 166 |
"""Processes the uploaded image, loads the model, and evaluates to get predictions"""
|
| 167 |
|
| 168 |
""" Get Image """
|
| 169 |
+
img_src = uploaded_image_path()
|
| 170 |
+
|
| 171 |
+
# Read the image bytes from the file
|
| 172 |
+
with open(img_src, 'rb') as f:
|
| 173 |
+
image_bytes = f.read()
|
| 174 |
+
|
| 175 |
+
# Convert the image bytes to a PIL Image
|
| 176 |
+
image = tif_bytes_to_pil_image(image_bytes)
|
| 177 |
|
| 178 |
""" Prepare Inputs """
|
| 179 |
# get input points prompt (grid of points)
|
|
|
|
| 182 |
# prepare image and prompt for the model
|
| 183 |
inputs = processor(image, input_points=input_points, return_tensors="pt")
|
| 184 |
|
| 185 |
+
# # remove batch dimension which the processor adds by default
|
| 186 |
+
# inputs = {k:v.squeeze(0) for k,v in inputs.items()}
|
| 187 |
|
| 188 |
+
# Move the input tensor to the GPU if it's not already there
|
| 189 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 190 |
+
|
| 191 |
""" Get Predictions """
|
| 192 |
+
# forward pass
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
outputs = model(**inputs, multimask_output=False)
|
| 195 |
+
|
| 196 |
+
# apply sigmoid
|
| 197 |
+
prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
|
| 198 |
+
# convert soft mask to hard mask
|
| 199 |
+
prob = prob.cpu().numpy().squeeze()
|
| 200 |
+
prediction = (prob > 0.5).astype(np.uint8)
|
| 201 |
+
|
| 202 |
+
# fig, axes = plt.subplots(1, 5, figsize=(15, 5))
|
| 203 |
+
|
| 204 |
+
# # Extract the image data from the batch
|
| 205 |
+
# image_data = batch['image'].cpu().detach().numpy()[0] # Assuming batch size is 1
|
| 206 |
+
|
| 207 |
+
# # Plot the first image on the left
|
| 208 |
+
# axes[0].imshow(image_data)
|
| 209 |
+
# axes[0].set_title("Image")
|
| 210 |
+
|
| 211 |
+
# # Plot the second image on the right
|
| 212 |
+
# axes[1].imshow(prob)
|
| 213 |
+
# axes[1].set_title("Probability Map")
|
| 214 |
+
|
| 215 |
+
# # Plot the prediction image on the right
|
| 216 |
+
# axes[2].imshow(prediction)
|
| 217 |
+
# axes[2].set_title("Prediction")
|
| 218 |
+
|
| 219 |
+
# # Plot the predicted mask on the right
|
| 220 |
+
# axes[3].imshow(image_data)
|
| 221 |
+
# show_mask(prediction, axes[3])
|
| 222 |
+
# axes[3].set_title("Predicted Mask")
|
| 223 |
+
|
| 224 |
+
# # Extract the ground truth mask data from the batch
|
| 225 |
+
# ground_truth_mask_data = inputs['ground_truth_mask'].cpu().detach().numpy()[0] # Assuming batch size is 1
|
| 226 |
+
|
| 227 |
+
# # Plot the ground truth mask on the right
|
| 228 |
+
# axes[4].imshow(image_data)
|
| 229 |
+
# axes[4].imshow(ground_truth_mask_data)
|
| 230 |
+
# #show_mask(inputs['ground_truth_mask'], axes[4])
|
| 231 |
+
# axes[4].set_title("Ground Truth Mask")
|
| 232 |
+
|
| 233 |
+
# # Hide axis ticks and labels
|
| 234 |
+
# for ax in axes:
|
| 235 |
+
# ax.set_xticks([])
|
| 236 |
+
# ax.set_yticks([])
|
| 237 |
+
# ax.set_xticklabels([])
|
| 238 |
+
# ax.set_yticklabels([])
|
| 239 |
+
|
| 240 |
+
# # Display the images side by side
|
| 241 |
+
# plt.show()
|
| 242 |
+
|
| 243 |
# Evaluate the image with the model
|
| 244 |
# Example: predictions = model.predict(image_array)
|
| 245 |
|