Spaces:
Runtime error
Runtime error
update: remove image resize
Browse files
app.py
CHANGED
|
@@ -1,171 +1,170 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
|
| 3 |
-
import os
|
| 4 |
-
import gradio as gr
|
| 5 |
-
import numpy as np
|
| 6 |
-
from PIL import Image
|
| 7 |
-
from PIL.ImageOps import grayscale
|
| 8 |
-
import cv2
|
| 9 |
-
import torch
|
| 10 |
-
import gc
|
| 11 |
-
import math
|
| 12 |
-
import cvzone
|
| 13 |
-
from cvzone.PoseModule import PoseDetector
|
| 14 |
-
from cvzone.FaceMeshModule import FaceMeshDetector
|
| 15 |
-
import spaces
|
| 16 |
-
|
| 17 |
-
choker_images = [Image.open(os.path.join("short_necklace", x)) for x in os.listdir("short_necklace")]
|
| 18 |
-
person_images = [Image.open(os.path.join("without_necklace", x)) for x in os.listdir("without_necklace")]
|
| 19 |
-
|
| 20 |
-
model_id = "stabilityai/stable-diffusion-2-inpainting"
|
| 21 |
-
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 22 |
-
model_id, torch_dtype=torch.float16
|
| 23 |
-
)
|
| 24 |
-
pipeline = pipeline.to("cuda")
|
| 25 |
-
|
| 26 |
-
detector = PoseDetector()
|
| 27 |
-
meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def clear_func():
|
| 31 |
-
torch.cuda.empty_cache()
|
| 32 |
-
gc.collect()
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
@spaces.GPU
|
| 36 |
-
def clothing_try_on_n_necklace_try_on(image, jewellery):
|
| 37 |
-
image
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
pixel_value =
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
binaryMask =
|
| 100 |
-
binaryMask[binaryMask
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
image =
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
mask_y
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
np.array(
|
| 146 |
-
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
with gr.
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
gr.Examples(examples=
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
interface.launch(debug=True)
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
|
| 3 |
+
import os
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from PIL.ImageOps import grayscale
|
| 8 |
+
import cv2
|
| 9 |
+
import torch
|
| 10 |
+
import gc
|
| 11 |
+
import math
|
| 12 |
+
import cvzone
|
| 13 |
+
from cvzone.PoseModule import PoseDetector
|
| 14 |
+
from cvzone.FaceMeshModule import FaceMeshDetector
|
| 15 |
+
import spaces
|
| 16 |
+
|
| 17 |
+
# Example-gallery assets: necklace overlays and model photos.
choker_images = [Image.open(os.path.join("short_necklace", fname)) for fname in os.listdir("short_necklace")]
person_images = [Image.open(os.path.join("without_necklace", fname)) for fname in os.listdir("without_necklace")]

# Stable Diffusion inpainting model used to redraw the neckline region.
model_id = "stabilityai/stable-diffusion-2-inpainting"
pipeline = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

# Body-pose and face-mesh detectors (cvzone wrappers); one static face at most.
detector = PoseDetector()
meshDetector = FaceMeshDetector(staticMode=True, maxFaces=1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def clear_func():
    """Free memory between requests: collect garbage, then release CUDA cache.

    Fix: run ``gc.collect()`` *before* ``torch.cuda.empty_cache()`` so that
    tensors freed by the collection can actually be returned from PyTorch's
    CUDA caching allocator (the original order released the cache first).
    """
    gc.collect()
    torch.cuda.empty_cache()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@spaces.GPU
def clothing_try_on_n_necklace_try_on(image, jewellery):
    """Composite *jewellery* (RGBA necklace PNG, PIL) onto *image* (person
    photo, PIL) and inpaint the neckline so the result looks natural.

    Pipeline:
      1. Anchor box + tilt angle from face-mesh jaw landmarks.
      2. Resize/rotate the necklace and overlay it on the photo.
      3. Build a binary mask of the necklace, diffusion-inpaint the region
         below it, then paste the exact necklace pixels back on top.

    Returns:
        PIL.Image: the final composited image.

    Raises:
        gr.Error: when no face is detected in the input photo.
    """
    image = np.array(image)
    copy_image = image.copy()
    jewellery = np.array(jewellery)

    # findPose must run before findPosition (cvzone caches detection results).
    image = detector.findPose(image)
    lmList, _ = detector.findPosition(image, bboxWithHands=False, draw=False)

    img, faces = meshDetector.findFaceMesh(image, draw=False)
    if not faces:
        # Fail with a user-visible message instead of an IndexError below.
        raise gr.Error("No face detected in the input image. Please use a clear, front-facing photo.")

    # Jawline landmark indices (MediaPipe face-mesh numbering).
    leftLandmarkIndex = 172
    rightLandmarkIndex = 397

    leftLandmark, rightLandmark = faces[0][leftLandmarkIndex], faces[0][rightLandmarkIndex]
    landmarksDistance = int(
        ((leftLandmark[0] - rightLandmark[0]) ** 2 + (leftLandmark[1] - rightLandmark[1]) ** 2) ** 0.5)

    # Widen the anchor slightly beyond the jaw and drop it below the chin.
    avg_x1 = int(leftLandmark[0] - landmarksDistance * 0.12)
    avg_x2 = int(rightLandmark[0] + landmarksDistance * 0.12)
    avg_y1 = int(leftLandmark[1] + landmarksDistance * 0.5)
    avg_y2 = int(rightLandmark[1] + landmarksDistance * 0.5)

    # Head-tilt angle. The original if/else computed the identical angle in
    # both branches and only differed by the sign, so compute once and negate
    # when the right anchor sits at or below the left one.
    angle = math.ceil(
        detector.findAngle(
            p1=(avg_x2, avg_y2), p2=(avg_x1, avg_y1), p3=(avg_x2, avg_y1)
        )[0]
    )
    if avg_y2 >= avg_y1:
        angle = angle * -1

    xdist = avg_x2 - avg_x1
    origImgRatio = xdist / jewellery.shape[1]
    ydist = jewellery.shape[0] * origImgRatio

    # Scan the necklace's top row for the first pixel that is neither pure
    # white nor pure black; the overlay is shifted up proportionally so the
    # clasp hugs the neck. If none is found, the last column index is used.
    image_gray = cv2.cvtColor(jewellery, cv2.COLOR_BGRA2GRAY)
    for offset_orig in range(image_gray.shape[1]):
        pixel_value = image_gray[0, offset_orig]
        if (pixel_value != 255) & (pixel_value != 0):
            break
    offset = int(0.8 * xdist * (offset_orig / jewellery.shape[1]))

    jewellery = cv2.resize(
        jewellery, (int(xdist), int(ydist)), interpolation=cv2.INTER_AREA
    )
    jewellery = cvzone.rotateImage(jewellery, angle)
    y_coordinate = avg_y1 - offset
    result = cvzone.overlayPNG(copy_image, jewellery, (avg_x1, y_coordinate))

    # Render the necklace alone on black to obtain its silhouette mask.
    blackedNecklace = np.zeros(shape=copy_image.shape)
    cvzone.overlayPNG(blackedNecklace, jewellery, (avg_x1, y_coordinate))
    blackedNecklace = cv2.cvtColor(blackedNecklace.astype(np.uint8), cv2.COLOR_BGR2GRAY)
    # Simple threshold: any pixel brighter than 5 belongs to the necklace.
    # (Replaces the obfuscated multiply-then-clamp that was equivalent.)
    binaryMask = np.where(blackedNecklace > 5, 255, 0)

    gc.collect()

    image = Image.fromarray(result.astype(np.uint8))
    mask = Image.fromarray(binaryMask.astype(np.uint8)).convert("RGB")

    # Exact copy of the necklace pixels, pasted back after inpainting.
    jewellery_mask = Image.fromarray(
        np.bitwise_and(np.array(mask), np.array(image))
    )
    arr_orig = np.array(grayscale(mask))

    # Classical inpaint first, so diffusion starts from a necklace-free photo.
    image = cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
    image = Image.fromarray(image)

    # Extend the mask downward from the necklace's topmost row so everything
    # below it is regenerated by the diffusion model.
    arr = arr_orig.copy()
    mask_y = np.where(arr == arr[arr != 0][0])[0][0]
    arr[mask_y:, :] = 255

    mask = Image.fromarray(arr)

    orig_size = image.size
    image = image.resize((512, 512))
    mask = mask.resize((512, 512))

    results = []
    prompt = " South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple"
    negative_prompt = "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly"

    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=image,
        mask_image=mask,
        strength=0.95,
        # BUG FIX: the diffusers parameter is `guidance_scale`; the original
        # `guidance_score=9` is not a valid argument and broke the call.
        guidance_scale=9,
        # generator=torch.Generator("cuda").manual_seed(42),  # for reproducibility
    ).images[0]

    output = output.resize(orig_size)
    # Zero out the (original) necklace region of the generated image …
    temp_generated = np.bitwise_and(
        np.array(output),
        np.bitwise_not(np.array(Image.fromarray(arr_orig).convert("RGB"))),
    )
    results.append(temp_generated)

    # … then OR the preserved necklace pixels back in.
    results = [
        Image.fromarray(np.bitwise_or(x, np.array(jewellery_mask))) for x in results
    ]
    clear_func()
    return results[0]
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# Gradio UI: choose a model photo and a necklace from the example galleries,
# then run the try-on pipeline on "Apply".
with gr.Blocks() as interface:
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil", image_mode="RGB", interactive=True)
        # Hidden holder populated by the necklace examples gallery below.
        selected_necklace = gr.Image(label="Selected Necklace", type="pil", image_mode="RGBA", visible=False)
        output_image = gr.Image(label="Output", interactive=False)

    with gr.Row():
        gr.Examples(examples=choker_images, inputs=[selected_necklace], label="Select Necklace")
        gr.Examples(examples=person_images, inputs=[input_image], label="Select Model")

    submit = gr.Button("Apply")
    submit.click(fn=clothing_try_on_n_necklace_try_on, inputs=[input_image, selected_necklace], outputs=[output_image])

interface.launch(debug=True)
|
|
|