import base64
import hashlib
import os
from io import BytesIO
from typing import Optional

import grpc
import uvicorn
from PIL import Image
from cachetools import LRUCache
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# The protobuf message classes have the same names as the pydantic request /
# response models defined below. Import them under aliases: otherwise the
# pydantic class definitions shadow them and the gRPC stub would be handed a
# pydantic object instead of a protobuf message.
from inference_pb2 import HairSwapRequest as GrpcHairSwapRequest
from inference_pb2 import HairSwapResponse as GrpcHairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face

app = FastAPI(
    title="HairFastGAN API",
    description="API for HairFastGAN: Realistic and Robust Hair Transfer with a Fast Encoder-Based Approach",
    version="1.0.0",
)

# Side length (pixels) that non-aligned images are center-cropped and resized to.
TARGET_SIZE = 1024

# Payloads sent to the inference service in place of an omitted optional image.
# NOTE(review): presumably the server reads b"face" as "reuse the face image as
# the shape reference" and b"shape" as "reuse the shape image as the color
# reference" — confirm against the server implementation.
MISSING_SHAPE_PLACEHOLDER = b"face"
MISSING_COLOR_PLACEHOLDER = b"shape"

# Global cache of aligned images, keyed by the MD5 hex digest of the JPEG-encoded
# input, so a repeated input image does not re-run align_face.
align_cache = LRUCache(maxsize=10)


class HairSwapRequest(BaseModel):
    """JSON body accepted by POST /swap-hair."""

    face: str  # Base64 encoded image
    shape: Optional[str] = None  # Base64 encoded image
    color: Optional[str] = None  # Base64 encoded image
    blending: str = "Article"
    poisson_iters: int = 0
    poisson_erosion: int = 15
    align_face_img: bool = True
    align_shape_img: bool = True
    align_color_img: bool = True


class HairSwapResponse(BaseModel):
    """JSON body returned by POST /swap-hair."""

    image: str  # Base64 encoded image


def base64_to_image(base64_str: str) -> Optional[Image.Image]:
    """Decode a base64 string (optionally a ``data:`` URL) into an RGB PIL image.

    Args:
        base64_str: Base64 image data, with or without a ``...base64,`` prefix.

    Returns:
        The decoded image converted to RGB, or None when ``base64_str`` is empty/None.

    Raises:
        binascii.Error: if the base64 payload is malformed.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    if not base64_str:
        return None
    # Strip a data-URL header such as "data:image/png;base64," if present.
    if "base64," in base64_str:
        base64_str = base64_str.split("base64,")[1]
    image_bytes = base64.b64decode(base64_str)
    # Normalise to RGB: images are later re-encoded as JPEG (get_bytes), and
    # JPEG cannot store RGBA / P / LA images such as PNGs with transparency.
    return Image.open(BytesIO(image_bytes)).convert("RGB")


def image_to_base64(img: Image.Image, format="JPEG") -> Optional[str]:
    """Encode ``img`` as a base64 data URL (``data:image/<format>;base64,...``).

    Returns None when ``img`` is None.
    """
    if img is None:
        return None
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"


def get_bytes(img: Optional[Image.Image]) -> Optional[bytes]:
    """Return ``img`` encoded as JPEG bytes, or None when ``img`` is None."""
    if img is None:
        return None
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL image."""
    return Image.open(BytesIO(image))


def center_crop(img: Image.Image) -> Image.Image:
    """Crop the largest square centered in ``img``."""
    width, height = img.size
    side = min(width, height)
    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2
    return img.crop((left, top, right, bottom))


def process_image(img: Image.Image, should_align: bool = True) -> Image.Image:
    """Prepare an input image for the inference service.

    Args:
        img: Decoded input image.
        should_align: If true, run ``align_face`` on the image (results are
            cached in ``align_cache`` by MD5 of the JPEG-encoded input).
            Otherwise, center-crop and LANCZOS-resize to
            TARGET_SIZE x TARGET_SIZE unless the image already has that size.

    Returns:
        The aligned or cropped/resized image.
    """
    if should_align:
        img_hash = hashlib.md5(get_bytes(img)).hexdigest()
        aligned = align_cache.get(img_hash)
        if aligned is None:
            # align_face returns a sequence; the first entry is used.
            aligned = align_face(img, return_tensors=False)[0]
            align_cache[img_hash] = aligned
        return aligned
    if img.size != (TARGET_SIZE, TARGET_SIZE):
        img = center_crop(img)
        img = img.resize((TARGET_SIZE, TARGET_SIZE), Image.Resampling.LANCZOS)
    return img


def _load_optional_image(base64_str: Optional[str], should_align: bool) -> Optional[Image.Image]:
    """Decode and prepare an optional base64 image; returns None if it was not provided."""
    if not base64_str:
        return None
    return process_image(base64_to_image(base64_str), should_align)


# NOTE(review): this handler is `async`, but align_face, JPEG encoding and the
# synchronous gRPC call all run directly on the event loop, so requests are
# serialised within a worker. Consider a plain `def` endpoint (FastAPI's
# threadpool) once align_face and align_cache are confirmed thread-safe.
@app.post("/swap-hair", response_model=HairSwapResponse)
async def swap_hair(request: HairSwapRequest):
    """
    Swap hair in the source face image with the shape and/or color from provided images.

    - face: Source image as base64 string (required)
    - shape: Image with desired hairstyle shape as base64 string (optional, but either shape or color is required)
    - color: Image with desired hair color as base64 string (optional, but either shape or color is required)
    - blending: Color Encoder version ("Article", "Alternative_v1", or "Alternative_v2")
    - poisson_iters: Power of blending with original image (0-2500)
    - poisson_erosion: Smooths out blending area (1-100)
    - align_face_img: Whether to align the face image
    - align_shape_img: Whether to align the shape image
    - align_color_img: Whether to align the color image

    Returns the processed image as a base64-encoded JPEG.
    """
    # Validate inputs
    if not request.face:
        raise HTTPException(status_code=400, detail="Need to provide a face image")
    if not request.shape and not request.color:
        raise HTTPException(status_code=400, detail="Need to provide at least a shape or color image")

    # Decode and prepare images; any failure here is the client's input.
    try:
        face_img = base64_to_image(request.face)
        shape_img = _load_optional_image(request.shape, request.align_shape_img)
        color_img = _load_optional_image(request.color, request.align_color_img)
        face_img = process_image(face_img, request.align_face_img)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing images: {str(e)}") from e

    # Convert images to bytes; omitted images are replaced by server-side placeholders.
    face_bytes = get_bytes(face_img)
    shape_bytes = get_bytes(shape_img) if shape_img is not None else MISSING_SHAPE_PLACEHOLDER
    color_bytes = get_bytes(color_img) if color_img is not None else MISSING_COLOR_PLACEHOLDER

    server = os.environ.get("SERVER")
    if not server:
        raise HTTPException(
            status_code=500,
            detail="Error during hair swapping: SERVER environment variable is not set",
        )

    # Call gRPC service
    try:
        with grpc.insecure_channel(server) as channel:
            stub = HairSwapServiceStub(channel)
            output: GrpcHairSwapResponse = stub.swap(
                GrpcHairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=request.blending,
                    poisson_iters=request.poisson_iters,
                    poisson_erosion=request.poisson_erosion,
                    use_cache=True,
                )
            )

        # Re-encode the service's image bytes as a base64 JPEG data URL.
        output_img = bytes_to_image(output.image)
        return HairSwapResponse(image=image_to_base64(output_img))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during hair swapping: {str(e)}") from e


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)