chore: reshape the image to (100, 100, 3) if not and add a check for rgb format
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""A local gradio app that filters images using FHE."""
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
import subprocess
|
|
@@ -191,6 +191,18 @@ def encrypt(user_id, input_image, filter_name):
|
|
| 191 |
|
| 192 |
if input_image is None:
|
| 193 |
raise gr.Error("Please choose an image first.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
# Retrieve the client API
|
| 196 |
client = get_client(user_id, filter_name)
|
|
@@ -482,7 +494,7 @@ with demo:
|
|
| 482 |
)
|
| 483 |
|
| 484 |
output_image = gr.Image(
|
| 485 |
-
label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
|
| 486 |
interactive=False,
|
| 487 |
height=256,
|
| 488 |
width=256,
|
|
@@ -513,7 +525,7 @@ with demo:
|
|
| 513 |
# Button to send the encodings to the server using post method
|
| 514 |
get_output_button.click(
|
| 515 |
get_output,
|
| 516 |
-
inputs=[user_id, filter_name],
|
| 517 |
outputs=[encrypted_output_representation]
|
| 518 |
)
|
| 519 |
|
|
|
|
| 1 |
"""A local gradio app that filters images using FHE."""
|
| 2 |
+
from PIL import Image
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
import subprocess
|
|
|
|
| 191 |
|
| 192 |
if input_image is None:
|
| 193 |
raise gr.Error("Please choose an image first.")
|
| 194 |
+
|
| 195 |
+
if input_image.shape[-1] != 3:
|
| 196 |
+
raise ValueError(f"Input image must have 3 channels (RGB). Current shape: {input_image.shape}")
|
| 197 |
+
|
| 198 |
+
# Resize the image if it hasn't the shape (100, 100, 3)
|
| 199 |
+
if input_image.shape != (100 , 100, 3):
|
| 200 |
+
print(f"Before: {type(input_image)=}, {input_image.shape=}")
|
| 201 |
+
input_image_pil = Image.fromarray(input_image)
|
| 202 |
+
# Resize the image
|
| 203 |
+
input_image_pil = input_image_pil.resize((100, 100))
|
| 204 |
+
input_image = numpy.array(input_image_pil)
|
| 205 |
+
print(f"After: {type(input_image)=}, {input_image.shape=}")
|
| 206 |
|
| 207 |
# Retrieve the client API
|
| 208 |
client = get_client(user_id, filter_name)
|
|
|
|
| 494 |
)
|
| 495 |
|
| 496 |
output_image = gr.Image(
|
| 497 |
+
label=f"Output image ({INPUT_SHAPE[0]}x{INPUT_SHAPE[1]}):",
|
| 498 |
interactive=False,
|
| 499 |
height=256,
|
| 500 |
width=256,
|
|
|
|
| 525 |
# Button to send the encodings to the server using post method
|
| 526 |
get_output_button.click(
|
| 527 |
get_output,
|
| 528 |
+
inputs=[user_id, filter_name],
|
| 529 |
outputs=[encrypted_output_representation]
|
| 530 |
)
|
| 531 |
|