ishworrsubedii commited on
Commit
fda1d3e
Β·
1 Parent(s): 45ef0f3

add: base64 conversion

Browse files
Files changed (1) hide show
  1. app.py +39 -15
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
3
  import os
4
  import gradio as gr
5
  import numpy as np
@@ -8,6 +7,9 @@ from PIL.ImageOps import grayscale
8
  import gc
9
  import spaces
10
  import cv2
 
 
 
11
 
12
  model_id = "stabilityai/stable-diffusion-2-inpainting"
13
  pipeline = StableDiffusionInpaintPipeline.from_pretrained(
@@ -15,26 +17,26 @@ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
15
  )
16
  pipeline = pipeline.to("cuda")
17
 
 
18
  def clear_func():
19
- """Clear GPU memory cache."""
20
  torch.cuda.empty_cache()
21
  gc.collect()
22
 
 
23
  def process_mask(mask):
24
- """Convert mask to binary format (black and white) for inpainting."""
25
- mask = mask.convert("L") # Convert to grayscale
26
  mask = np.array(mask)
27
 
28
- # Convert to binary: 0 (black) -> keep, 255 (white) -> modify
29
  mask = np.where(mask > 128, 255, 0).astype(np.uint8)
30
 
31
  return Image.fromarray(mask)
32
 
 
33
  @spaces.GPU
34
  def clothing_try_on(image, mask):
35
  jewellery_mask = Image.fromarray(
36
- np.bitwise_and(np.array(mask), np.array(image))
37
- )
38
  arr_orig = np.array(grayscale(mask))
39
 
40
  image = cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
@@ -80,19 +82,41 @@ def clothing_try_on(image, mask):
80
  clear_func()
81
  return results[0]
82
 
83
- def launch_interface():
84
- """Launch the Gradio interface."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  with gr.Blocks() as interface:
86
  with gr.Row():
87
- inputImage = gr.Image(label="Input Image", type="pil", image_mode="RGB", interactive=True)
88
- maskImage = gr.Image(label="Input Mask", type="pil", image_mode="RGB", interactive=True)
89
- outputOne = gr.Image(label="Output", interactive=False)
90
 
91
  submit = gr.Button("Apply")
92
 
93
- submit.click(fn=clothing_try_on, inputs=[inputImage, maskImage], outputs=[outputOne])
 
 
94
 
95
- interface.launch(debug=True)
96
 
97
  if __name__ == "__main__":
98
- launch_interface()
 
1
  import torch
 
2
  import os
3
  import gradio as gr
4
  import numpy as np
 
7
  import gc
8
  import spaces
9
  import cv2
10
+ import base64
11
+ from io import BytesIO
12
+ from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
13
 
14
  model_id = "stabilityai/stable-diffusion-2-inpainting"
15
  pipeline = StableDiffusionInpaintPipeline.from_pretrained(
 
17
  )
18
  pipeline = pipeline.to("cuda")
19
 
20
+
21
  def clear_func():
 
22
  torch.cuda.empty_cache()
23
  gc.collect()
24
 
25
+
26
  def process_mask(mask):
27
+ mask = mask.convert("L")
 
28
  mask = np.array(mask)
29
 
 
30
  mask = np.where(mask > 128, 255, 0).astype(np.uint8)
31
 
32
  return Image.fromarray(mask)
33
 
34
+
35
  @spaces.GPU
36
  def clothing_try_on(image, mask):
37
  jewellery_mask = Image.fromarray(
38
+ np.bitwise_and(np.array(mask), np.array(image))
39
+ )
40
  arr_orig = np.array(grayscale(mask))
41
 
42
  image = cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
 
82
  clear_func()
83
  return results[0]
84
 
85
+
86
+ def base64_to_image(base64_str):
87
+ image_data = base64.b64decode(base64_str)
88
+ image = Image.open(BytesIO(image_data))
89
+ return image
90
+
91
+
92
+ def image_to_base64(image):
93
+ buffered = BytesIO()
94
+ image.save(buffered, format="PNG")
95
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
96
+
97
+
98
+ def clothing_try_on_base64(input_image_base64, mask_image_base64):
99
+ image = base64_to_image(input_image_base64)
100
+ mask = base64_to_image(mask_image_base64)
101
+
102
+ output_image = clothing_try_on(image, mask)
103
+
104
+ return image_to_base64(output_image)
105
+
106
+
107
+ def launch_interface_base64():
108
  with gr.Blocks() as interface:
109
  with gr.Row():
110
+ inputImage = gr.Textbox(label="Input Image (Base64)", lines=4)
111
+ maskImage = gr.Textbox(label="Input Mask (Base64)", lines=4)
112
+ outputOne = gr.Textbox(label="Output (Base64)", lines=4)
113
 
114
  submit = gr.Button("Apply")
115
 
116
+ submit.click(fn=clothing_try_on_base64, inputs=[inputImage, maskImage], outputs=[outputOne])
117
+
118
+ interface.launch(debug=True)
119
 
 
120
 
121
  if __name__ == "__main__":
122
+ launch_interface_base64()