Spaces:
Runtime error
Runtime error
| import hashlib | |
| import os | |
| from io import BytesIO | |
| import gradio as gr | |
| import grpc | |
| from PIL import Image | |
| from cachetools import LRUCache | |
| from inference_pb2 import HairSwapRequest, HairSwapResponse | |
| from inference_pb2_grpc import HairSwapServiceStub | |
| from utils.shape_predictor import align_face | |
def get_bytes(img):
    """Encode an image as JPEG bytes.

    Args:
        img: A PIL-style image (must expose ``mode``, ``convert`` and ``save``),
            or None.

    Returns:
        The JPEG-encoded bytes, or None when ``img`` is None.

    JPEG has no alpha channel or palette, so Pillow refuses to save images in
    modes such as RGBA, P, LA or I;16 as JPEG. Those are converted to RGB first.
    Modes Pillow can already write as JPEG are encoded unchanged.
    """
    if img is None:
        return None
    # Modes Pillow's JPEG encoder accepts directly; anything else would raise OSError.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Decode encoded image bytes (such as the gRPC response payload) into a PIL image."""
    return Image.open(BytesIO(image))
def center_crop(img):
    """Crop the largest centered square out of an image.

    Args:
        img: A PIL-style image exposing ``size`` (width, height) and ``crop(box)``.

    Returns:
        The result of ``img.crop`` on a square box whose side is
        ``min(width, height)``.

    Integer offsets are used so the box is always exactly side x side. With the
    previous float box, Pillow's half-even rounding could shrink one edge by a
    pixel (e.g. a 6x3 image cropped to 2x3). That skewed the aspect ratio of the
    subsequent 1024x1024 resize.
    """
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize(name):
    """Build a Gradio upload handler for one of the three image inputs.

    Args:
        name: The checkbox label ('Face', 'Shape' or 'Color') that decides
            whether this input is face-aligned.

    Returns:
        A callable ``(img, align) -> image``:
        - If ``name`` is among the checked ``align`` options, the image is passed
          through ``align_face``. The first result is memoized in the global
          ``align_cache``, keyed by the MD5 hex digest of the image's JPEG
          encoding.
        - Otherwise, an image that is not already 1024x1024 is center-cropped to
          a square and resized to 1024x1024 with LANCZOS resampling.
    """
    def resize_inner(img, align):
        if name not in align:
            if img.size == (1024, 1024):
                return img
            square = center_crop(img)
            return square.resize((1024, 1024), Image.Resampling.LANCZOS)

        key = hashlib.md5(get_bytes(img)).hexdigest()
        if key in align_cache:
            return align_cache[key]
        aligned = align_face(img, return_tensors=False)[0]
        align_cache[key] = aligned
        return aligned

    return resize_inner
def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Send the uploaded images to the hair-swap gRPC server and display the result.

    Args:
        face: Image of the person to restyle (required), or None.
        shape: Image with the desired hairstyle, or None.
        color: Image with the desired hair color, or None.
        blending: Color-encoder version name chosen in the UI.
        poisson_iters: Poisson blending iteration count from the UI slider.
        poisson_erosion: Poisson blending erosion from the UI slider.

    Returns:
        A pair of ``gr.update`` objects for (result image, error textbox).
        - On success: the image is shown and the error box is hidden.
        - On invalid input, a missing ``SERVER`` environment variable, or a gRPC
          failure: the image is hidden and the error box shows a message instead
          of the exception propagating.
    """
    def _error(message):
        return gr.update(visible=False), gr.update(value=message, visible=True)

    if face is None and shape is None and color is None:
        return _error("Need to upload a face and at least a shape or color ❗")
    if face is None:
        return _error("Need to upload a face ❗")
    if shape is None and color is None:
        return _error("Need to upload at least a shape or color ❗")

    server = os.environ.get('SERVER')
    if not server:
        return _error("Hair swap server is not configured (SERVER environment variable is unset) ❗")

    face_bytes, shape_bytes, color_bytes = map(get_bytes, (face, shape, color))
    # NOTE(review): b'face' / b'shape' appear to be placeholder payloads that the
    # server interprets as "reuse the face image" / "reuse the shape image" —
    # confirm against the server implementation.
    if shape_bytes is None:
        shape_bytes = b'face'
    if color_bytes is None:
        color_bytes = b'shape'

    request = HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
                              poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
    try:
        with grpc.insecure_channel(server) as channel:
            response: HairSwapResponse = HairSwapServiceStub(channel).swap(request)
    except grpc.RpcError as err:
        # Failed unary calls raise errors that also implement grpc.Call, which
        # carries a human-readable status message.
        details = err.details() if isinstance(err, grpc.Call) else str(err)
        return _error(f"Hair swap server request failed: {details} ❗")

    return gr.update(value=bytes_to_image(response.image), visible=True), gr.update(visible=False)
def get_demo():
    """Assemble the Gradio Blocks UI for the hair-swap demo.

    Layout:
    - Left column: the source/shape/color uploads; an "Advanced Options"
      accordion (color encoder, Poisson settings, which images to face-crop);
      the run button.
    - Right column: the result image and an initially hidden error box.
    - Below both: example image triplets.

    Each upload is normalized in place by the matching ``resize`` handler, and
    the button calls ``swap_hair``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    example_triplets = [
        ["input/0.png", "input/1.png", "input/2.png"],
        ["input/6.png", "input/7.png", None],
        ["input/10.jpg", None, "input/11.jpg"],
    ]

    with gr.Blocks() as app:
        gr.Markdown("## Hair")
        with gr.Row():
            with gr.Column():
                face_input = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape_input = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color_input = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                with gr.Accordion("Advanced Options", open=False):
                    encoder_choice = gr.Radio(
                        ["Article", "Alternative_v1", "Alternative_v2"],
                        value='Article',
                        label="Color Encoder version",
                        info="Selects a model for hair color transfer.",
                    )
                    iters_slider = gr.Slider(
                        0, 2500, value=0, step=1,
                        label="Poisson iters",
                        info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.",
                    )
                    erosion_slider = gr.Slider(
                        1, 100, value=15, step=1,
                        label="Poisson erosion",
                        info="Smooths out the blending area.",
                    )
                    crop_choice = gr.CheckboxGroup(
                        ["Face", "Shape", "Color"],
                        value=["Face", "Shape", "Color"],
                        label="Image cropping [recommended]",
                        info="Selects which images to crop by face",
                    )
                run_button = gr.Button("Get the haircut")
            with gr.Column():
                result_image = gr.Image(label="Your result")
                error_box = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        gr.Examples(examples=example_triplets,
                    inputs=[face_input, shape_input, color_input], outputs=result_image)

        # Normalize each upload in place; the tag is matched against the crop checkbox choices.
        for component, tag in ((face_input, 'Face'), (shape_input, 'Shape'), (color_input, 'Color')):
            component.upload(fn=resize(tag), inputs=[component, crop_choice], outputs=component)

        run_button.click(
            fn=swap_hair,
            inputs=[face_input, shape_input, color_input, encoder_choice, iters_slider, erosion_slider],
            outputs=[result_image, error_box],
        )
    return app
# LRU cache (10 entries) of face-aligned uploads, read and written by resize().
# Created at module level so it also exists when this module is imported
# instead of being run as a script; otherwise aligned uploads hit NameError.
align_cache = LRUCache(maxsize=10)

if __name__ == '__main__':
    demo = get_demo()
    demo.launch(server_name="0.0.0.0", server_port=7860)
| # import hashlib | |
| # import os | |
| # from io import BytesIO | |
| # import gradio as gr | |
| # import grpc | |
| # from PIL import Image | |
| # from cachetools import LRUCache | |
| # from inference_pb2 import HairSwapRequest, HairSwapResponse | |
| # from inference_pb2_grpc import HairSwapServiceStub | |
| # from utils.shape_predictor import align_face | |
| # def get_bytes(img): | |
| # if img is None: | |
| # return img | |
| # buffered = BytesIO() | |
| # img.save(buffered, format="JPEG") | |
| # return buffered.getvalue() | |
| # def bytes_to_image(image: bytes) -> Image.Image: | |
| # image = Image.open(BytesIO(image)) | |
| # return image | |
| # def center_crop(img): | |
| # width, height = img.size | |
| # side = min(width, height) | |
| # left = (width - side) / 2 | |
| # top = (height - side) / 2 | |
| # right = (width + side) / 2 | |
| # bottom = (height + side) / 2 | |
| # img = img.crop((left, top, right, bottom)) | |
| # return img | |
| # def resize(name): | |
| # def resize_inner(img, align): | |
| # global align_cache | |
| # if name in align: | |
| # img_hash = hashlib.md5(get_bytes(img)).hexdigest() | |
| # if img_hash not in align_cache: | |
| # img = align_face(img, return_tensors=False)[0] | |
| # align_cache[img_hash] = img | |
| # else: | |
| # img = align_cache[img_hash] | |
| # elif img.size != (1024, 1024): | |
| # img = center_crop(img) | |
| # img = img.resize((1024, 1024), Image.Resampling.LANCZOS) | |
| # return img | |
| # return resize_inner | |
| # def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion): | |
| # if not face and not shape and not color: | |
| # return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True) | |
| # elif not face: | |
| # return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True) | |
| # elif not shape and not color: | |
| # return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True) | |
| # face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color)) | |
| # if shape_bytes is None: | |
| # shape_bytes = b'face' | |
| # if color_bytes is None: | |
| # color_bytes = b'shape' | |
| # with grpc.insecure_channel(os.environ['SERVER']) as channel: | |
| # stub = HairSwapServiceStub(channel) | |
| # output: HairSwapResponse = stub.swap( | |
| # HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending, | |
| # poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True) | |
| # ) | |
| # output = bytes_to_image(output.image) | |
| # return gr.update(value=output, visible=True), gr.update(visible=False) | |
| # def get_demo(): | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("## Hair Designer") | |
| # with gr.Row(): | |
| # with gr.Column(): | |
| # with gr.Row(): | |
| # source = gr.Image(label="Source photo to try on the hairstyle", type="pil") | |
| # shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil") | |
| # color = gr.Image(label="Color photo with desired hair color (optional)", type="pil") | |
| # with gr.Row(): | |
| # with gr.Accordion("Advanced Options", open=False): | |
| # blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article', | |
| # label="Color Encoder version", info="Selects a model for hair color transfer.") | |
| # poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters", | |
| # info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.") | |
| # poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion", | |
| # info="Smooths out the blending area.") | |
| # align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"], | |
| # label="Image cropping [recommended]", | |
| # info="Selects which images to crop by face") | |
| # btn = gr.Button("Design Now!") | |
| # with gr.Column(): | |
| # output = gr.Image(label="Your result") | |
| # error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=True, elem_classes="error-message") | |
| # gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], | |
| # ["input/6.png", "input/7.png", None], | |
| # ["input/10.jpg", None, "input/11.jpg"]], | |
| # inputs=[source, shape, color], outputs=output) | |
| # source.upload(fn=resize('Face'), inputs=[source, align], outputs=source) | |
| # shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape) | |
| # color.upload(fn=resize('Color'), inputs=[color, align], outputs=color) | |
| # btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion], | |
| # outputs=[output, error_message]) | |
| # return demo | |
| # if __name__ == '__main__': | |
| # align_cache = LRUCache(maxsize=10) | |
| # demo = get_demo() | |
| # demo.launch(server_name="0.0.0.0", server_port=7860, debug=True) | |