ishworrsubedii commited on
Commit
4f0c4b6
Β·
verified Β·
1 Parent(s): 79061d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -81
app.py CHANGED
@@ -1,89 +1,81 @@
1
- import cv2
2
  import torch
3
- import gc
4
- import spaces
5
  import gradio as gr
6
  import numpy as np
7
  from PIL import Image
8
  from PIL.ImageOps import grayscale
9
- from diffusers import StableDiffusionInpaintPipeline
10
-
11
-
12
- class NecklaceTryOn:
13
- def __init__(self):
14
- self.model_id = "stabilityai/stable-diffusion-2-inpainting"
15
- self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
16
- self.model_id, torch_dtype=torch.float16
17
- )
18
- self.pipeline = self.pipeline.to("cuda")
19
-
20
- def clear_func(self):
21
- torch.cuda.empty_cache()
22
- gc.collect()
23
-
24
- @spaces.GPU
25
- def clothing_try_on_n_necklace_try_on(self, result_image, mask):
26
- """Main method for clothing and necklace try-on."""
27
- jewellery_mask = Image.fromarray(
28
- np.bitwise_and(np.array(mask), np.array(result_image))
29
- )
30
- arr_orig = np.array(grayscale(mask))
31
-
32
- result_image = cv2.inpaint(np.array(result_image), arr_orig, 15, cv2.INPAINT_TELEA)
33
- result_image = Image.fromarray(result_image)
34
-
35
- arr = arr_orig.copy()
36
- mask_y = np.where(arr == arr[arr != 0][0])[0][0]
37
- arr[mask_y:, :] = 255
38
-
39
- new_mask = Image.fromarray(arr)
40
- mask = new_mask.copy()
41
-
42
- orig_size = result_image.size
43
- result_image = result_image.resize((512, 512))
44
- mask = mask.resize((512, 512))
45
-
46
- results = []
47
- prompt = f" South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple"
48
- negative_prompt = "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly"
49
-
50
- output = self.pipeline(
51
- prompt=prompt,
52
- negative_prompt=negative_prompt,
53
- image=result_image,
54
- mask_image=mask,
55
- strength=0.95,
56
- guidance_score=9,
57
- ).images[0]
58
-
59
- output = output.resize(orig_size)
60
- temp_generated = np.bitwise_and(
61
- np.array(output),
62
- np.bitwise_not(np.array(Image.fromarray(arr_orig).convert("RGB"))),
63
- )
64
- results.append(temp_generated)
65
-
66
- results = [
67
- Image.fromarray(np.bitwise_or(x, np.array(jewellery_mask))) for x in results
68
- ]
69
- self.clear_func()
70
- return results[0]
71
-
72
- def launch_interface(self):
73
- with gr.Blocks() as interface:
74
- with gr.Row():
75
- inputImage = gr.Image(label="Input Image", type="pil", image_mode="RGB", interactive=True)
76
- mask_image = gr.Image(label="Mask Image", type="pil", image_mode="RGB", interactive=True)
77
- outputOne = gr.Image(label="Output", interactive=False)
78
-
79
- submit = gr.Button("Apply")
80
-
81
- submit.click(fn=self.clothing_try_on_n_necklace_try_on, inputs=[inputImage, mask_image],
82
- outputs=[outputOne])
83
-
84
- interface.launch(debug=True)
85
 
 
86
 
87
  if __name__ == "__main__":
88
- app = NecklaceTryOn()
89
- app.launch_interface()
 
 
1
  import torch
2
+ from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
3
+ import os
4
  import gradio as gr
5
  import numpy as np
6
  from PIL import Image
7
  from PIL.ImageOps import grayscale
8
+ import gc
9
+ import spaces
10
+
11
+ model_id = "stabilityai/stable-diffusion-2-inpainting"
12
+ pipeline = StableDiffusionInpaintPipeline.from_pretrained(
13
+ model_id, torch_dtype=torch.float16
14
+ )
15
+ pipeline = pipeline.to("cuda")
16
+
17
+ def clear_func():
18
+ """Clear GPU memory cache."""
19
+ torch.cuda.empty_cache()
20
+ gc.collect()
21
+
22
+ def process_mask(mask):
23
+ """Convert mask to binary format (black and white) for inpainting."""
24
+ mask = mask.convert("L") # Convert to grayscale
25
+ mask = np.array(mask)
26
+
27
+ # Convert to binary: 0 (black) -> keep, 255 (white) -> modify
28
+ mask = np.where(mask > 128, 255, 0).astype(np.uint8)
29
+
30
+ return Image.fromarray(mask)
31
+
32
+ @spaces.GPU
33
+ def clothing_try_on(image, mask):
34
+ """Perform clothing try-on using the provided image and binary mask."""
35
+ orig_size = image.size
36
+
37
+ # Process and ensure mask is binary
38
+ mask = process_mask(mask)
39
+
40
+ # Resize image and mask for Stable Diffusion
41
+ image = image.resize((512, 512))
42
+ mask = mask.resize((512, 512))
43
+
44
+ # Prompt and negative prompt
45
+ prompt = f"South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple"
46
+ negative_prompt = "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly"
47
+
48
+ # Perform the inpainting using the Stable Diffusion pipeline
49
+ output = pipeline(
50
+ prompt=prompt,
51
+ negative_prompt=negative_prompt,
52
+ image=image,
53
+ mask_image=mask,
54
+ strength=0.95,
55
+ guidance_score=9,
56
+ ).images[0]
57
+
58
+ # Resize the output back to the original size
59
+ output = output.resize(orig_size)
60
+
61
+ # Clean GPU memory
62
+ clear_func()
63
+
64
+ return output
65
+
66
+ def launch_interface():
67
+ """Launch the Gradio interface."""
68
+ with gr.Blocks() as interface:
69
+ with gr.Row():
70
+ inputImage = gr.Image(label="Input Image", type="pil", image_mode="RGB", interactive=True)
71
+ maskImage = gr.Image(label="Input Mask", type="pil", image_mode="RGB", interactive=True)
72
+ outputOne = gr.Image(label="Output", interactive=False)
73
+
74
+ submit = gr.Button("Apply")
75
+
76
+ submit.click(fn=clothing_try_on, inputs=[inputImage, maskImage], outputs=[outputOne])
 
 
 
 
 
 
 
77
 
78
+ interface.launch(debug=True)
79
 
80
  if __name__ == "__main__":
81
+ launch_interface()