0oAstro committed on
Commit
96904d7
·
unverified ·
1 Parent(s): f9581ed
Files changed (2) hide show
  1. app.py +155 -112
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,20 +1,74 @@
1
  import hashlib
2
  import os
 
3
  from io import BytesIO
 
4
 
5
- import gradio as gr
6
  import grpc
 
7
  from PIL import Image
8
  from cachetools import LRUCache
 
 
9
 
10
  from inference_pb2 import HairSwapRequest, HairSwapResponse
11
  from inference_pb2_grpc import HairSwapServiceStub
12
  from utils.shape_predictor import align_face
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def get_bytes(img):
16
  if img is None:
17
- return img
18
 
19
  buffered = BytesIO()
20
  img.save(buffered, format="JPEG")
@@ -39,113 +93,102 @@ def center_crop(img):
39
  return img
40
 
41
 
42
- def resize(name):
43
- def resize_inner(img, align):
44
- global align_cache
45
-
46
- if name in align:
47
- img_hash = hashlib.md5(get_bytes(img)).hexdigest()
48
-
49
- if img_hash not in align_cache:
50
- img = align_face(img, return_tensors=False)[0]
51
- align_cache[img_hash] = img
52
- else:
53
- img = align_cache[img_hash]
54
-
55
- elif img.size != (1024, 1024):
56
- img = center_crop(img)
57
- img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
58
-
59
- return img
60
-
61
- return resize_inner
62
-
63
-
64
- def swap_hair(face, shape, color, blending, poisson_iters, poisson_erosion):
65
- if not face and not shape and not color:
66
- return gr.update(visible=False), gr.update(value="Need to upload a face and at least a shape or color ❗", visible=True)
67
- elif not face:
68
- return gr.update(visible=False), gr.update(value="Need to upload a face ❗", visible=True)
69
- elif not shape and not color:
70
- return gr.update(visible=False), gr.update(value="Need to upload at least a shape or color ❗", visible=True)
71
-
72
- face_bytes, shape_bytes, color_bytes = map(lambda item: get_bytes(item), (face, shape, color))
73
-
74
- if shape_bytes is None:
75
- shape_bytes = b'face'
76
- if color_bytes is None:
77
- color_bytes = b'shape'
78
-
79
- with grpc.insecure_channel(os.environ['SERVER']) as channel:
80
- stub = HairSwapServiceStub(channel)
81
-
82
- output: HairSwapResponse = stub.swap(
83
- HairSwapRequest(face=face_bytes, shape=shape_bytes, color=color_bytes, blending=blending,
84
- poisson_iters=poisson_iters, poisson_erosion=poisson_erosion, use_cache=True)
85
- )
86
-
87
- output = bytes_to_image(output.image)
88
- return gr.update(value=output, visible=True), gr.update(visible=False)
89
-
90
-
91
- def get_demo():
92
- with gr.Blocks() as demo:
93
- gr.Markdown("## HairFastGan")
94
- gr.Markdown(
95
- '<div style="display: flex; align-items: center; gap: 10px;">'
96
- '<span>Official HairFastGAN Gradio demo:</span>'
97
- '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
98
- '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
99
- '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
100
- '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
101
- '</div>'
102
- )
103
- with gr.Row():
104
- with gr.Column():
105
- source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
106
- with gr.Row():
107
- shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
108
- color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
109
- with gr.Accordion("Advanced Options", open=False):
110
- blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
111
- label="Color Encoder version", info="Selects a model for hair color transfer.")
112
- poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters",
113
- info="The power of blending with the original image, helps to recover more details. Not included in the article, disabled by default.")
114
- poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion",
115
- info="Smooths out the blending area.")
116
- align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
117
- label="Image cropping [recommended]",
118
- info="Selects which images to crop by face")
119
- btn = gr.Button("Get the haircut")
120
- with gr.Column():
121
- output = gr.Image(label="Your result")
122
- error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
123
-
124
- gr.Examples(examples=[["input/0.png", "input/1.png", "input/2.png"], ["input/6.png", "input/7.png", None],
125
- ["input/10.jpg", None, "input/11.jpg"]],
126
- inputs=[source, shape, color], outputs=output)
127
-
128
- source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
129
- shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
130
- color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
131
-
132
- btn.click(fn=swap_hair, inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
133
- outputs=[output, error_message])
134
-
135
- gr.Markdown('''To cite the paper by the authors
136
- ```
137
- @article{nikolaev2024hairfastgan,
138
- title={HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach},
139
- author={Nikolaev, Maxim and Kuznetsov, Mikhail and Vetrov, Dmitry and Alanov, Aibek},
140
- journal={arXiv preprint arXiv:2404.01094},
141
- year={2024}
142
- }
143
- ```
144
- ''')
145
- return demo
146
-
147
-
148
- if __name__ == '__main__':
149
- align_cache = LRUCache(maxsize=10)
150
- demo = get_demo()
151
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import hashlib
2
  import os
3
+ import base64
4
  from io import BytesIO
5
+ from typing import Optional
6
 
 
7
  import grpc
8
+ import uvicorn
9
  from PIL import Image
10
  from cachetools import LRUCache
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
 
14
  from inference_pb2 import HairSwapRequest, HairSwapResponse
15
  from inference_pb2_grpc import HairSwapServiceStub
16
  from utils.shape_predictor import align_face
17
 
18
# FastAPI application exposing the HairFastGAN hair-transfer service over HTTP.
app = FastAPI(
    title="HairFastGAN API",
    description="API for HairFastGAN: Realistic and Robust Hair Transfer "
                "with a Fast Encoder-Based Approach",
    version="1.0.0",
)

# LRU cache of face-aligned images, keyed by the MD5 of the raw image bytes.
align_cache = LRUCache(maxsize=10)
26
+
27
+
28
class HairSwapRequest(BaseModel):
    """Request payload for the /swap-hair endpoint.

    NOTE(review): this pydantic model shadows the protobuf ``HairSwapRequest``
    imported from ``inference_pb2`` at the top of the file — any code building
    the gRPC message must use an aliased import of the protobuf class.
    """
    face: str  # Base64 encoded image (required)
    shape: Optional[str] = None  # Base64 encoded image
    color: Optional[str] = None  # Base64 encoded image
    blending: str = "Article"  # Color Encoder version (see endpoint docstring)
    poisson_iters: int = 0  # power of blending with the original image (0-2500)
    poisson_erosion: int = 15  # smooths out the blending area (1-100)
    align_face_img: bool = True  # whether to face-align the source image
    align_shape_img: bool = True  # whether to face-align the shape image
    align_color_img: bool = True  # whether to face-align the color image
38
+
39
+
40
class HairSwapResponse(BaseModel):
    """Response payload for the /swap-hair endpoint.

    NOTE(review): shadows the protobuf ``HairSwapResponse`` imported from
    ``inference_pb2`` above — the gRPC result annotation must not use this name.
    """
    image: str  # Base64 encoded image (data-URI JPEG)
42
+
43
+
44
def base64_to_image(base64_str: Optional[str]) -> "Optional[Image.Image]":
    """Decode a base64 string (optionally with a data-URI header) into a PIL Image.

    Args:
        base64_str: Base64-encoded image data, with or without a
            ``data:image/...;base64,`` prefix. May be None or empty.

    Returns:
        The decoded PIL Image, or None when the input is empty/None.
        (Fixed annotation: the original claimed ``Image.Image`` but returns
        None for falsy input.)

    Raises:
        binascii.Error: If the payload is not valid base64.
        PIL.UnidentifiedImageError: If the decoded bytes are not a known image.
    """
    if not base64_str:
        return None

    # Strip the data-URI header ("data:image/png;base64,....") if present.
    if "base64," in base64_str:
        base64_str = base64_str.split("base64,")[1]

    image_bytes = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_bytes))
56
+
57
+
58
def image_to_base64(img: "Optional[Image.Image]", format: str = "JPEG") -> Optional[str]:
    """Encode a PIL Image as a base64 data-URI string.

    Args:
        img: Image to encode, or None.
        format: Format name passed to ``PIL.Image.save`` (default "JPEG").

    Returns:
        A ``data:image/<format>;base64,<payload>`` string, or None when *img*
        is None. (Fixed annotation: the original claimed ``str`` but returns
        None for None input.)
    """
    if img is None:
        return None

    buffered = BytesIO()
    img.save(buffered, format=format)
    payload = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{payload}"
67
+
68
 
69
  def get_bytes(img):
70
  if img is None:
71
+ return None
72
 
73
  buffered = BytesIO()
74
  img.save(buffered, format="JPEG")
 
93
  return img
94
 
95
 
96
def process_image(img, should_align=True):
    """Prepare an input image for the model.

    When *should_align* is True, run face alignment (results are memoized in
    the module-level ``align_cache`` keyed by the MD5 of the image bytes).
    Otherwise, center-crop and resize to 1024x1024 only if the image is not
    already that size.

    Args:
        img: PIL image to prepare.
        should_align: Whether to face-align instead of crop/resize.

    Returns:
        The aligned (or size-normalized) PIL image.
    """
    global align_cache

    if not should_align:
        # Only normalize the size when it differs from the expected 1024x1024.
        if img.size != (1024, 1024):
            cropped = center_crop(img)
            img = cropped.resize((1024, 1024), Image.Resampling.LANCZOS)
        return img

    key = hashlib.md5(get_bytes(img)).hexdigest()
    if key in align_cache:
        return align_cache[key]

    aligned = align_face(img, return_tensors=False)[0]
    align_cache[key] = aligned
    return aligned
113
+
114
+
115
@app.post("/swap-hair", response_model=HairSwapResponse)
async def swap_hair(request: HairSwapRequest):
    """
    Swap hair in the source face image with the shape and/or color from provided images.

    - face: Source image as base64 string (required)
    - shape: Image with desired hairstyle shape as base64 string (optional, but either shape or color is required)
    - color: Image with desired hair color as base64 string (optional, but either shape or color is required)
    - blending: Color Encoder version ("Article", "Alternative_v1", or "Alternative_v2")
    - poisson_iters: Power of blending with original image (0-2500)
    - poisson_erosion: Smooths out blending area (1-100)
    - align_face_img: Whether to align the face image
    - align_shape_img: Whether to align the shape image
    - align_color_img: Whether to align the color image

    Returns the processed image as a base64-encoded JPEG.
    """
    # BUGFIX: the pydantic models defined above shadow the protobuf messages
    # imported from inference_pb2, so the original `stub.swap(HairSwapRequest(...))`
    # built a pydantic model instead of the gRPC message. Re-import the protobuf
    # request type under an alias for the stub call.
    from inference_pb2 import HairSwapRequest as PbHairSwapRequest

    # Validate inputs
    if not request.face:
        raise HTTPException(status_code=400, detail="Need to provide a face image")
    if not request.shape and not request.color:
        raise HTTPException(status_code=400, detail="Need to provide at least a shape or color image")

    # Decode base64 payloads and align/resize each provided image.
    try:
        face_img = base64_to_image(request.face)

        shape_img = None
        if request.shape:
            shape_img = base64_to_image(request.shape)
            shape_img = process_image(shape_img, request.align_shape_img)

        color_img = None
        if request.color:
            color_img = base64_to_image(request.color)
            color_img = process_image(color_img, request.align_color_img)

        # Process face image (always required)
        face_img = process_image(face_img, request.align_face_img)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing images: {str(e)}")

    # Convert images to bytes; the sentinel payloads b'face' / b'shape' signal
    # "not provided" to the gRPC server (same convention as the original client).
    face_bytes = get_bytes(face_img)
    shape_bytes = get_bytes(shape_img) if shape_img else b'face'
    color_bytes = get_bytes(color_img) if color_img else b'shape'

    # Call gRPC service.
    # NOTE(review): this stub call is blocking inside an async endpoint; if
    # concurrency matters, consider offloading it to a thread executor.
    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)

            output = stub.swap(
                PbHairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=request.blending,
                    poisson_iters=request.poisson_iters,
                    poisson_erosion=request.poisson_erosion,
                    use_cache=True,
                )
            )

        # Convert the protobuf result to an image, then to a data-URI string.
        output_img = bytes_to_image(output.image)
        base64_img = image_to_base64(output_img)

        return HairSwapResponse(image=base64_img)

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during hair swapping: {str(e)}")
190
+
191
+
192
if __name__ == "__main__":
    # Serve the API; the PORT environment variable overrides the default 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,6 +5,7 @@ git+https://github.com/openai/CLIP.git
5
  gdown==3.12.2
6
  grpcio==1.63.0
7
  grpcio_tools==1.63.0
8
- gradio==4.31.5
 
9
  cachetools==5.3.3
10
  dlib==19.24.1
 
5
  gdown==3.12.2
6
  grpcio==1.63.0
7
  grpcio_tools==1.63.0
8
+ fastapi==0.104.1
9
+ uvicorn==0.23.2
10
  cachetools==5.3.3
11
  dlib==19.24.1