# HairFastGAN REST gateway: FastAPI service that forwards hair-swap requests
# to the gRPC inference backend.
import base64
import hashlib
import os
from io import BytesIO
from typing import Optional

import grpc
import uvicorn
from PIL import Image
from cachetools import LRUCache
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2 import HairSwapRequest as GrpcHairSwapRequest
from inference_pb2 import HairSwapResponse as GrpcHairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
# FastAPI application exposing the HairFastGAN hair-transfer model over REST.
app = FastAPI(
    title="HairFastGAN API",
    description="API for HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach",
    version="1.0.0"
)

# Global cache for aligned faces, keyed by the MD5 of the image's JPEG bytes
# (see process_image). Bounded to 10 entries so memory stays flat under load.
align_cache = LRUCache(maxsize=10)
class HairSwapRequest(BaseModel):
    """Request body for POST /swap-hair.

    NOTE(review): this Pydantic model shadows the protobuf message of the same
    name imported from inference_pb2 above — any gRPC code in this module must
    be careful to construct the protobuf class, not this one.
    """

    face: str  # Base64 encoded image (required)
    shape: Optional[str] = None  # Base64 encoded image
    color: Optional[str] = None  # Base64 encoded image
    blending: str = "Article"  # Color encoder: "Article", "Alternative_v1", or "Alternative_v2"
    poisson_iters: int = 0  # Power of blending with the original image (0-2500)
    poisson_erosion: int = 15  # Smooths out the blending area (1-100)
    align_face_img: bool = True  # Run face alignment on the face image
    align_shape_img: bool = True  # Run face alignment on the shape image
    align_color_img: bool = True  # Run face alignment on the color image
class HairSwapResponse(BaseModel):
    """Response body for POST /swap-hair.

    NOTE(review): shadows the protobuf HairSwapResponse imported from
    inference_pb2 above.
    """

    image: str  # Base64 encoded image (data-URI JPEG)
def base64_to_image(base64_str: str) -> "Optional[Image.Image]":
    """Decode a base64 string (optionally a data URI) into a PIL Image.

    Args:
        base64_str: raw base64 payload, or a full ``data:image/...;base64,``
            data URI. Falsy input (empty string, None) is allowed.

    Returns:
        The decoded PIL image, or None when the input is falsy (the original
        annotation claimed ``Image.Image`` but the None path exists and is
        relied on by callers).
    """
    if not base64_str:
        return None
    # Strip a "data:image/...;base64," data-URI header if present. maxsplit=1
    # keeps everything after the first marker intact.
    if "base64," in base64_str:
        base64_str = base64_str.split("base64,", 1)[1]
    image_bytes = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_bytes))
def image_to_base64(img: "Optional[Image.Image]", format="JPEG") -> "Optional[str]":
    """Encode a PIL Image as a base64 data URI.

    Args:
        img: image to encode, or None.
        format: PIL save format; also used (lowercased) in the data-URI MIME
            subtype. (Name shadows the builtin but is kept for caller
            compatibility.)

    Returns:
        ``data:image/<format>;base64,<payload>`` string, or None when img is
        None (the original ``-> str`` annotation missed the None path).
    """
    if img is None:
        return None
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def get_bytes(img):
    """Serialize a PIL image to JPEG bytes; None passes straight through."""
    if img is None:
        return None
    sink = BytesIO()
    img.save(sink, format="JPEG")
    return sink.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Deserialize encoded image bytes (e.g. JPEG) into a PIL Image."""
    return Image.open(BytesIO(image))
def center_crop(img):
    """Crop the largest centered square out of img and return it."""
    w, h = img.size
    side = min(w, h)
    # Box coordinates may be half-integers for odd margins; PIL accepts floats.
    x0 = (w - side) / 2
    y0 = (h - side) / 2
    return img.crop((x0, y0, x0 + side, y0 + side))
def process_image(img, should_align=True):
    """Normalize an input image for the model.

    When should_align is True, run face alignment, memoized in the global
    align_cache keyed by the MD5 of the image's JPEG bytes. Otherwise,
    center-crop and resize to 1024x1024 if it is not already that size.
    """
    if should_align:
        cache_key = hashlib.md5(get_bytes(img)).hexdigest()
        try:
            img = align_cache[cache_key]
        except KeyError:
            img = align_face(img, return_tensors=False)[0]
            align_cache[cache_key] = img
    elif img.size != (1024, 1024):
        img = center_crop(img)
        img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
    return img
@app.post("/swap-hair", response_model=HairSwapResponse)
async def swap_hair(request: HairSwapRequest):
    """
    Swap hair in the source face image with the shape and/or color from provided images.

    - face: Source image as base64 string (required)
    - shape: Image with desired hairstyle shape as base64 string (optional, but either shape or color is required)
    - color: Image with desired hair color as base64 string (optional, but either shape or color is required)
    - blending: Color Encoder version ("Article", "Alternative_v1", or "Alternative_v2")
    - poisson_iters: Power of blending with original image (0-2500)
    - poisson_erosion: Smooths out blending area (1-100)
    - align_face_img: Whether to align the face image
    - align_shape_img: Whether to align the shape image
    - align_color_img: Whether to align the color image

    Returns the processed image as a base64-encoded JPEG.
    """
    # Validate inputs before doing any decoding work.
    if not request.face:
        raise HTTPException(status_code=400, detail="Need to provide a face image")
    if not request.shape and not request.color:
        raise HTTPException(status_code=400, detail="Need to provide at least a shape or color image")

    # Decode base64 payloads and (optionally) align/crop each image.
    try:
        face_img = base64_to_image(request.face)

        shape_img = None
        if request.shape:
            shape_img = base64_to_image(request.shape)
            shape_img = process_image(shape_img, request.align_shape_img)

        color_img = None
        if request.color:
            color_img = base64_to_image(request.color)
            color_img = process_image(color_img, request.align_color_img)

        # Process face image (always required).
        face_img = process_image(face_img, request.align_face_img)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing images: {str(e)}") from e

    # Serialize to JPEG bytes for the gRPC payload. b'face' / b'shape' are
    # sentinel values for missing images — presumably the server interprets
    # them as "reuse the previous image"; TODO confirm against the server.
    face_bytes = get_bytes(face_img)
    shape_bytes = get_bytes(shape_img) if shape_img else b'face'
    color_bytes = get_bytes(color_img) if color_img else b'shape'

    # Call the gRPC inference service. BUG FIX: the protobuf message classes
    # must be used here via their Grpc* aliases — the Pydantic models defined
    # above shadow the names imported from inference_pb2, so the unaliased
    # HairSwapRequest would build the Pydantic model with bytes and fail
    # validation before the request ever reached the server.
    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)
            output: GrpcHairSwapResponse = stub.swap(
                GrpcHairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=request.blending,
                    poisson_iters=request.poisson_iters,
                    poisson_erosion=request.poisson_erosion,
                    use_cache=True
                )
            )

        # Convert the protobuf image bytes back to a base64 data-URI JPEG.
        output_img = bytes_to_image(output.image)
        base64_img = image_to_base64(output_img)
        return HairSwapResponse(image=base64_img)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during hair swapping: {str(e)}") from e
if __name__ == "__main__":
    # Local/dev entry point; the hosting platform may override the port.
    serve_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=serve_port)