# HairFastGAN / app.py
# 0oAstro
# api
# 96904d7 unverified
import base64
import hashlib
import os
from io import BytesIO
from typing import Optional

import grpc
import uvicorn
from cachetools import LRUCache
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel

import inference_pb2
from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
# FastAPI application exposing the HairFastGAN hair-swap endpoint.
app = FastAPI(
title="HairFastGAN API",
description="API for HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach",
version="1.0.0"
)
# Global cache for aligned faces, keyed by the MD5 of the image's JPEG
# bytes (see process_image); bounded to 10 entries to cap memory use.
align_cache = LRUCache(maxsize=10)
class HairSwapRequest(BaseModel):
    """Request schema for the /swap-hair endpoint (JSON body)."""
    # NOTE(review): this Pydantic model shadows the protobuf message
    # `HairSwapRequest` imported from inference_pb2 above; gRPC code must
    # reference the protobuf type through the module, not this name.
    face: str # Base64 encoded image
    shape: Optional[str] = None # Base64 encoded image
    color: Optional[str] = None # Base64 encoded image
    blending: str = "Article"
    poisson_iters: int = 0
    poisson_erosion: int = 15
    align_face_img: bool = True
    align_shape_img: bool = True
    align_color_img: bool = True
class HairSwapResponse(BaseModel):
    """Response schema for /swap-hair: the result as a base64 data URI."""
    # NOTE(review): shadows the protobuf message `HairSwapResponse`
    # imported from inference_pb2 above; gRPC code must reference the
    # protobuf type through the module, not this name.
    image: str # Base64 encoded image
def base64_to_image(base64_str: str) -> Optional[Image.Image]:
    """Decode a base64 string (optionally a full data URI) into a PIL Image.

    Args:
        base64_str: Raw base64 payload, or a data URI such as
            "data:image/jpeg;base64,<payload>".

    Returns:
        The decoded image, or None when base64_str is empty/None.
        (Return annotation fixed to Optional: the falsy path returns None.)
    """
    if not base64_str:
        return None
    # Strip a "data:<mime>;base64," header if present. maxsplit=1 so a
    # payload that happens to contain "base64," again is not truncated.
    if "base64," in base64_str:
        base64_str = base64_str.split("base64,", 1)[1]
    image_bytes = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_bytes))
def image_to_base64(img: Image.Image, format="JPEG") -> Optional[str]:
    """Encode a PIL image as a base64 data URI.

    Args:
        img: Image to encode; None passes through as None.
        format: PIL save format; also used (lower-cased) as the MIME
            subtype in the returned data URI.

    Returns:
        A "data:image/<fmt>;base64,..." string, or None when img is None.
        (Return annotation fixed to Optional: the None path returns None.)
    """
    if img is None:
        return None
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def get_bytes(img):
    """Serialize a PIL image to JPEG-encoded bytes; None passes through."""
    if img is None:
        return None
    with BytesIO() as buf:
        img.save(buf, format="JPEG")
        return buf.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Deserialize raw encoded image bytes into a PIL Image."""
    return Image.open(BytesIO(image))
def center_crop(img):
    """Crop the largest centered square from *img*.

    The square's side is min(width, height). Integer arithmetic (``//``)
    is used so the crop box always has whole-pixel coordinates and the
    result is exactly side x side — the previous ``/`` produced fractional
    coordinates whenever width - side (or height - side) was odd.

    Returns:
        The cropped square image.
    """
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def process_image(img, should_align=True):
    """Normalize an input image for inference.

    With should_align true, run face alignment and memoize the result in
    the module-level LRU cache, keyed by the MD5 digest of the image's
    JPEG bytes. Otherwise, centre-crop and resize to 1024x1024 only when
    the image is not already that size.
    """
    global align_cache
    if not should_align:
        # No alignment requested: only normalise the size when needed.
        if img.size == (1024, 1024):
            return img
        squared = center_crop(img)
        return squared.resize((1024, 1024), Image.Resampling.LANCZOS)
    key = hashlib.md5(get_bytes(img)).hexdigest()
    if key in align_cache:
        return align_cache[key]
    aligned = align_face(img, return_tensors=False)[0]
    align_cache[key] = aligned
    return aligned
@app.post("/swap-hair", response_model=HairSwapResponse)
async def swap_hair(request: HairSwapRequest):
    """
    Swap hair in the source face image with the shape and/or color from provided images.
    - face: Source image as base64 string (required)
    - shape: Image with desired hairstyle shape as base64 string (optional, but either shape or color is required)
    - color: Image with desired hair color as base64 string (optional, but either shape or color is required)
    - blending: Color Encoder version ("Article", "Alternative_v1", or "Alternative_v2")
    - poisson_iters: Power of blending with original image (0-2500)
    - poisson_erosion: Smooths out blending area (1-100)
    - align_face_img: Whether to align the face image
    - align_shape_img: Whether to align the shape image
    - align_color_img: Whether to align the color image
    Returns the processed image as a base64-encoded JPEG.
    """
    # --- Validate inputs -------------------------------------------------
    if not request.face:
        raise HTTPException(status_code=400, detail="Need to provide a face image")
    if not request.shape and not request.color:
        raise HTTPException(status_code=400, detail="Need to provide at least a shape or color image")

    # --- Decode and normalise the input images ---------------------------
    try:
        face_img = base64_to_image(request.face)
        shape_img = None
        if request.shape:
            shape_img = process_image(base64_to_image(request.shape), request.align_shape_img)
        color_img = None
        if request.color:
            color_img = process_image(base64_to_image(request.color), request.align_color_img)
        # Process face image (always required)
        face_img = process_image(face_img, request.align_face_img)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing images: {str(e)}") from e

    # Serialize to JPEG bytes for the gRPC payload. The b'face' / b'shape'
    # placeholders are presumably sentinels the server interprets as
    # "reuse the previous image" — confirm against the gRPC service.
    face_bytes = get_bytes(face_img)
    shape_bytes = get_bytes(shape_img) if shape_img else b'face'
    color_bytes = get_bytes(color_img) if color_img else b'shape'

    # --- Call the gRPC inference service ---------------------------------
    try:
        with grpc.insecure_channel(os.environ['SERVER']) as channel:
            stub = HairSwapServiceStub(channel)
            # BUGFIX: the Pydantic schemas defined above shadow the protobuf
            # messages of the same names imported from inference_pb2, so the
            # original code handed the stub a Pydantic model (str fields, no
            # use_cache) instead of a protobuf message. Reference the real
            # protobuf types through the module.
            output: inference_pb2.HairSwapResponse = stub.swap(
                inference_pb2.HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=request.blending,
                    poisson_iters=request.poisson_iters,
                    poisson_erosion=request.poisson_erosion,
                    use_cache=True
                )
            )
        # Convert the raw result bytes back to an image, then to a data URI.
        output_img = bytes_to_image(output.image)
        return HairSwapResponse(image=image_to_base64(output_img))
    except HTTPException:
        # Never wrap an intentional HTTP error in a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during hair swapping: {str(e)}") from e
if __name__ == "__main__":
    # Serve the API with uvicorn; honour the PORT env var (default 8000).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))