Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,89 +1,81 @@
|
|
| 1 |
-
import cv2
|
| 2 |
import torch
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
import gradio as gr
|
| 6 |
import numpy as np
|
| 7 |
from PIL import Image
|
| 8 |
from PIL.ImageOps import grayscale
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
submit = gr.Button("Apply")
|
| 80 |
-
|
| 81 |
-
submit.click(fn=self.clothing_try_on_n_necklace_try_on, inputs=[inputImage, mask_image],
|
| 82 |
-
outputs=[outputOne])
|
| 83 |
-
|
| 84 |
-
interface.launch(debug=True)
|
| 85 |
|
|
|
|
| 86 |
|
| 87 |
if __name__ == "__main__":
|
| 88 |
-
|
| 89 |
-
app.launch_interface()
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
|
| 3 |
+
import os
|
| 4 |
import gradio as gr
|
| 5 |
import numpy as np
|
| 6 |
from PIL import Image
|
| 7 |
from PIL.ImageOps import grayscale
|
| 8 |
+
import gc
|
| 9 |
+
import spaces
|
| 10 |
+
|
| 11 |
+
model_id = "stabilityai/stable-diffusion-2-inpainting"
|
| 12 |
+
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 13 |
+
model_id, torch_dtype=torch.float16
|
| 14 |
+
)
|
| 15 |
+
pipeline = pipeline.to("cuda")
|
| 16 |
+
|
| 17 |
+
def clear_func():
|
| 18 |
+
"""Clear GPU memory cache."""
|
| 19 |
+
torch.cuda.empty_cache()
|
| 20 |
+
gc.collect()
|
| 21 |
+
|
| 22 |
+
def process_mask(mask):
|
| 23 |
+
"""Convert mask to binary format (black and white) for inpainting."""
|
| 24 |
+
mask = mask.convert("L") # Convert to grayscale
|
| 25 |
+
mask = np.array(mask)
|
| 26 |
+
|
| 27 |
+
# Convert to binary: 0 (black) -> keep, 255 (white) -> modify
|
| 28 |
+
mask = np.where(mask > 128, 255, 0).astype(np.uint8)
|
| 29 |
+
|
| 30 |
+
return Image.fromarray(mask)
|
| 31 |
+
|
| 32 |
+
@spaces.GPU
|
| 33 |
+
def clothing_try_on(image, mask):
|
| 34 |
+
"""Perform clothing try-on using the provided image and binary mask."""
|
| 35 |
+
orig_size = image.size
|
| 36 |
+
|
| 37 |
+
# Process and ensure mask is binary
|
| 38 |
+
mask = process_mask(mask)
|
| 39 |
+
|
| 40 |
+
# Resize image and mask for Stable Diffusion
|
| 41 |
+
image = image.resize((512, 512))
|
| 42 |
+
mask = mask.resize((512, 512))
|
| 43 |
+
|
| 44 |
+
# Prompt and negative prompt
|
| 45 |
+
prompt = f"South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple"
|
| 46 |
+
negative_prompt = "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly"
|
| 47 |
+
|
| 48 |
+
# Perform the inpainting using the Stable Diffusion pipeline
|
| 49 |
+
output = pipeline(
|
| 50 |
+
prompt=prompt,
|
| 51 |
+
negative_prompt=negative_prompt,
|
| 52 |
+
image=image,
|
| 53 |
+
mask_image=mask,
|
| 54 |
+
strength=0.95,
|
| 55 |
+
guidance_score=9,
|
| 56 |
+
).images[0]
|
| 57 |
+
|
| 58 |
+
# Resize the output back to the original size
|
| 59 |
+
output = output.resize(orig_size)
|
| 60 |
+
|
| 61 |
+
# Clean GPU memory
|
| 62 |
+
clear_func()
|
| 63 |
+
|
| 64 |
+
return output
|
| 65 |
+
|
| 66 |
+
def launch_interface():
|
| 67 |
+
"""Launch the Gradio interface."""
|
| 68 |
+
with gr.Blocks() as interface:
|
| 69 |
+
with gr.Row():
|
| 70 |
+
inputImage = gr.Image(label="Input Image", type="pil", image_mode="RGB", interactive=True)
|
| 71 |
+
maskImage = gr.Image(label="Input Mask", type="pil", image_mode="RGB", interactive=True)
|
| 72 |
+
outputOne = gr.Image(label="Output", interactive=False)
|
| 73 |
+
|
| 74 |
+
submit = gr.Button("Apply")
|
| 75 |
+
|
| 76 |
+
submit.click(fn=clothing_try_on, inputs=[inputImage, maskImage], outputs=[outputOne])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
interface.launch(debug=True)
|
| 79 |
|
| 80 |
if __name__ == "__main__":
|
| 81 |
+
launch_interface()
|
|
|