File size: 6,315 Bytes
3d3e65e
23faa2e
96904d7
23faa2e
96904d7
23faa2e
 
96904d7
23faa2e
 
96904d7
 
23faa2e
 
 
 
 
96904d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23faa2e
 
 
96904d7
23faa2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96904d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import hashlib
import os
import base64
from io import BytesIO
from typing import Optional

import grpc
import uvicorn
from PIL import Image
from cachetools import LRUCache
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face

# FastAPI application exposing the HairFastGAN hair-transfer endpoint.
app = FastAPI(
    title="HairFastGAN API",
    description=(
        "API for HairFastGAN: Realistic and Robust Hair Transfer "
        "with a Fast Encoder-Based Approach"
    ),
    version="1.0.0",
)

# Process-wide LRU cache of aligned faces (at most 10 entries), keyed by the
# MD5 hex digest of the JPEG-encoded input image; filled by process_image.
align_cache = LRUCache(maxsize=10)


# JSON request body accepted by POST /swap-hair.
# NOTE(review): this class rebinds the module-level name `HairSwapRequest`,
# shadowing the protobuf message of the same name imported from inference_pb2.
# After this definition, a bare `HairSwapRequest(...)` builds this pydantic
# model, not a gRPC message.
class HairSwapRequest(BaseModel):
    face: str  # Base64 encoded image (a "data:...;base64," prefix is also accepted)
    shape: Optional[str] = None  # Base64 encoded image providing the desired hairstyle shape
    color: Optional[str] = None  # Base64 encoded image providing the desired hair color
    blending: str = "Article"  # Color encoder variant; forwarded to the gRPC server unchanged
    poisson_iters: int = 0  # Blending strength; documented range 0-2500 (not enforced by this model)
    poisson_erosion: int = 15  # Smooths the blending area; documented range 1-100 (not enforced by this model)
    align_face_img: bool = True  # True: align via align_face; False: center-crop + resize to 1024x1024
    align_shape_img: bool = True  # Same choice for the shape image
    align_color_img: bool = True  # Same choice for the color image


# JSON response body returned by POST /swap-hair.
# NOTE(review): this class also shadows the protobuf `HairSwapResponse`
# imported from inference_pb2.
class HairSwapResponse(BaseModel):
    image: str  # Base64 encoded image, as a "data:image/jpeg;base64,..." data URL


def base64_to_image(base64_str: str) -> Optional["Image.Image"]:
    """Decode a base64 string (optionally a data URL) into an RGB PIL image.

    Args:
        base64_str: Raw base64 payload, or a data URL such as
            ``data:image/png;base64,<payload>``. Everything up to and including
            the first ``base64,`` marker is discarded.

    Returns:
        The decoded image converted to RGB, or ``None`` when ``base64_str``
        is empty or ``None``.

    Raises:
        binascii.Error: if the payload is not valid base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    if not base64_str:
        return None

    # Strip a data-URL header ("data:image/...;base64,") if present.
    _, marker, payload = base64_str.partition("base64,")
    if marker:
        base64_str = payload

    image = Image.open(BytesIO(base64.b64decode(base64_str)))
    # Image.open is lazy; convert() forces decoding now, so corrupt uploads
    # fail here. It also normalises alpha/palette/grayscale inputs to RGB.
    # The rest of this module encodes images as JPEG, which cannot store
    # modes like RGBA or P.
    return image.convert("RGB")


def image_to_base64(img: Optional["Image.Image"], format="JPEG") -> Optional[str]:
    """Encode a PIL image as a ``data:image/<format>;base64,...`` URL string.

    Args:
        img: Image to encode; ``None`` is passed through unchanged.
        format: Pillow format name used for encoding. Its lower-cased form is
            also used as the MIME subtype of the data URL.

    Returns:
        The data-URL string, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # JPEG can only store these modes. Pillow raises OSError for others
    # (e.g. RGBA, LA, P), so flatten such images to RGB before encoding.
    if format.upper() == "JPEG" and img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"


def get_bytes(img) -> Optional[bytes]:
    """Encode a PIL image as JPEG and return the encoded bytes.

    Within this module the result is used both as the input to the MD5 cache
    key in process_image and as the payload sent to the gRPC server.

    Args:
        img: Image to encode, or ``None``.

    Returns:
        The JPEG bytes, or ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None

    # JPEG can only store these modes. Pillow raises OSError ("cannot write
    # mode RGBA as JPEG") for others, so flatten e.g. RGBA/LA/P to RGB first.
    # Images already in a supported mode are encoded exactly as before.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes (e.g. JPEG or PNG) as a PIL image.

    ``Image.open`` is lazy: the pixel data is decoded on first access.
    """
    return Image.open(BytesIO(image))


def center_crop(img):
    """Crop the largest centred square out of ``img``.

    The crop box is computed with integer arithmetic, so the result is
    exactly ``side x side`` pixels, where ``side = min(width, height)``.
    Half-pixel float coordinates would be rounded by Pillow, and
    round-half-to-even can make the crop narrower than it is tall
    (e.g. a 6x3 image would come out 2x3).

    Args:
        img: A PIL image (anything exposing ``.size`` and ``.crop(box)``).

    Returns:
        The result of ``img.crop`` for the centred square box.
    """
    width, height = img.size
    side = min(width, height)

    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))


def process_image(img, should_align=True):
    """Prepare an input image for the hair-swap model.

    When ``should_align`` is truthy, the face is aligned with ``align_face``.
    The aligned result is memoised in ``align_cache``, keyed by the MD5 hex
    digest of the image's JPEG encoding. Otherwise, images that are not
    already 1024x1024 are center-cropped to a square and resized to
    1024x1024 with LANCZOS; 1024x1024 images are returned unchanged.
    """
    if not should_align:
        if img.size != (1024, 1024):
            squared = center_crop(img)
            img = squared.resize((1024, 1024), Image.Resampling.LANCZOS)
        return img

    cache_key = hashlib.md5(get_bytes(img)).hexdigest()
    if cache_key in align_cache:
        return align_cache[cache_key]

    aligned = align_face(img, return_tensors=False)[0]
    align_cache[cache_key] = aligned
    return aligned


@app.post("/swap-hair", response_model=HairSwapResponse)
async def swap_hair(request: HairSwapRequest):
    """
    Swap hair in the source face image with the shape and/or color from provided images.

    - face: Source image as base64 string (required)
    - shape: Image with desired hairstyle shape as base64 string (optional, but either shape or color is required)
    - color: Image with desired hair color as base64 string (optional, but either shape or color is required)
    - blending: Color Encoder version ("Article", "Alternative_v1", or "Alternative_v2")
    - poisson_iters: Power of blending with original image (0-2500)
    - poisson_erosion: Smooths out blending area (1-100)
    - align_face_img: Whether to align the face image
    - align_shape_img: Whether to align the shape image
    - align_color_img: Whether to align the color image

    Returns the processed image as a base64-encoded JPEG.
    """
    # The module-level names HairSwapRequest / HairSwapResponse are rebound by
    # the pydantic API models defined in this file, shadowing the protobuf
    # classes imported from inference_pb2. Reference the protobuf module
    # directly. A bare HairSwapRequest(...) would build the pydantic request
    # model instead of a protobuf message, and the gRPC call could never succeed.
    import inference_pb2

    # Validate inputs
    if not request.face:
        raise HTTPException(status_code=400, detail="Need to provide a face image")
    if not request.shape and not request.color:
        raise HTTPException(status_code=400, detail="Need to provide at least a shape or color image")

    # Decode, align/crop and JPEG-encode every input inside one try block, so
    # any failure caused by the uploaded images is reported as a 400.
    try:
        face_img = base64_to_image(request.face)

        shape_img = None
        if request.shape:
            shape_img = process_image(base64_to_image(request.shape), request.align_shape_img)

        color_img = None
        if request.color:
            color_img = process_image(base64_to_image(request.color), request.align_color_img)

        # Process face image (always required)
        face_img = process_image(face_img, request.align_face_img)

        face_bytes = get_bytes(face_img)
        # NOTE(review): b'face' / b'shape' look like placeholders asking the
        # inference server to reuse the face (resp. shape) image when that
        # input is omitted. Confirm against the server implementation.
        shape_bytes = get_bytes(shape_img) if shape_img is not None else b'face'
        color_bytes = get_bytes(color_img) if color_img is not None else b'shape'
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing images: {e}") from e

    server = os.environ.get("SERVER")
    if not server:
        raise HTTPException(status_code=500, detail="SERVER environment variable is not set")

    # NOTE(review): stub.swap is a synchronous gRPC call inside an `async def`
    # endpoint, so it blocks the event loop until the server answers. Moving
    # it off the loop (or making this a plain `def` endpoint) would also need
    # a lock around align_cache, because cachetools caches are not thread-safe.
    try:
        with grpc.insecure_channel(server) as channel:
            stub = HairSwapServiceStub(channel)
            output = stub.swap(
                inference_pb2.HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=request.blending,
                    poisson_iters=request.poisson_iters,
                    poisson_erosion=request.poisson_erosion,
                    use_cache=True,
                )
            )

        # The server replies with encoded image bytes; re-encode as a JPEG data URL.
        base64_img = image_to_base64(bytes_to_image(output.image))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during hair swapping: {e}") from e

    return HairSwapResponse(image=base64_img)


if __name__ == "__main__":
    # Serve on all interfaces; the port comes from $PORT (default 8000).
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)