# HairFastGANendpoint / api_server.py
# Commit 57b5fc8 ("Create api_server.py", verified) by jefrisuparjanaAI.
import os
import hashlib
from io import BytesIO
import base64
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import grpc
from cachetools import LRUCache
from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
# Flask application serving the HTTP API. CORS(app) enables cross-origin
# requests on every route using flask_cors defaults, so browser clients
# served from other origins can call the API.
app = Flask(__name__)
CORS(app)
# Process-wide LRU cache of face-alignment results (at most 10 entries),
# keyed by the MD5 hex digest of an input image's JPEG bytes (see
# resize_image). NOTE(review): cachetools caches are not synchronized; if
# requests are served concurrently, consider guarding access with a lock.
align_cache = LRUCache(maxsize=10)
def get_bytes(img):
    """Return the JPEG encoding of a PIL image, or ``None`` if *img* is None.

    The bytes are used both to derive the face-alignment cache key and as the
    image payload sent to the hair-swap gRPC service.

    JPEG can only store a few Pillow modes. Any other mode (e.g. RGBA or
    palette images decoded from PNG uploads) is converted to RGB first;
    otherwise ``save(..., format="JPEG")`` raises ``OSError``.
    """
    if img is None:
        return None
    # Modes Pillow's JPEG encoder accepts directly; leave these untouched so
    # the output bytes for already-supported images are unchanged.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Open encoded image bytes (any format Pillow can identify) as a PIL image.

    Pillow reads lazily: only the header is parsed here, and pixel data is
    decoded from the wrapped in-memory stream on first access.
    """
    stream = BytesIO(image_bytes)
    return Image.open(stream)
def base64_to_image(base64_string: str) -> Image.Image:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Everything up to the last comma is discarded, so both raw base64 and
    ``data:image/...;base64,<payload>`` strings are accepted.

    Raises:
        ValueError: if the input is not a string, is not valid base64, or does
            not contain an image Pillow can identify and fully decode. All of
            these are bad client input, so they share one exception type.
    """
    if not isinstance(base64_string, str):
        raise ValueError("image must be a base64-encoded string")
    # b64decode raises binascii.Error (a ValueError subclass) on bad padding.
    image_data = base64.b64decode(base64_string.split(',')[-1])
    try:
        img = Image.open(BytesIO(image_data))
        # Image.open only parses the header; load() forces a full decode so a
        # truncated or corrupt payload fails here, not later in the pipeline.
        img.load()
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise ValueError(f"could not decode image: {err}") from err
    return img
def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a ``data:image/jpeg;base64,...`` URL string.

    JPEG can only store a few Pillow modes. Any other mode (e.g. RGBA or
    palette output) is converted to RGB first; otherwise
    ``save(..., format="JPEG")`` raises ``OSError``.
    """
    # Leave JPEG-compatible modes untouched so their output is unchanged.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"
def center_crop(img):
    """Crop the largest centred square out of *img*.

    The box is computed with integer arithmetic so the result is exactly
    ``min(width, height)`` pixels on each side. Half-pixel float coordinates
    are rounded half-to-even by Pillow, which can shrink one edge by a pixel
    and yield a non-square crop.

    Args:
        img: object exposing ``size`` as ``(width, height)`` and ``crop(box)``
            (a PIL image in this module).

    Returns:
        Whatever ``img.crop`` returns for the ``(left, top, right, bottom)`` box.
    """
    width, height = img.size
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize_image(img, should_align=True):
    """Prepare an input image for the hair-swap service.

    With ``should_align`` the image goes through ``align_face`` and its first
    result is used. Results are memoised in ``align_cache``, keyed by the MD5
    of the image's JPEG bytes, so re-submitting the same picture skips
    alignment. Without alignment, an image that is not already 1024x1024 is
    centre-cropped to a square and resized to 1024x1024 with Lanczos
    resampling.

    Raises:
        ValueError: if ``align_face`` returns no result for the image.
    """
    if not should_align:
        if img.size != (1024, 1024):
            img = center_crop(img)
            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
        return img

    img_hash = hashlib.md5(get_bytes(img)).hexdigest()
    # EAFP: a single lookup, so an eviction between an `in` check and the
    # subsequent index can no longer surface as a KeyError.
    # NOTE(review): cachetools caches are not thread-safe in general.
    try:
        return align_cache[img_hash]
    except KeyError:
        pass

    faces = align_face(img, return_tensors=False)
    # presumably an empty result means no face was detected — TODO confirm
    # against utils.shape_predictor; previously this surfaced as IndexError.
    if len(faces) == 0:
        raise ValueError("no face detected in image")
    aligned = faces[0]
    align_cache[img_hash] = aligned
    return aligned
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe reporting that this Flask process is serving requests.

    It does not contact the gRPC backend, so a healthy response says nothing
    about whether hair swaps will succeed.
    """
    payload = {"status": "healthy", "service": "HairFastGAN API"}
    return jsonify(payload), 200
_BLENDING_OPTIONS = ('Article', 'Alternative_v1', 'Alternative_v2')


def _as_bool(value) -> bool:
    """Interpret a JSON flag value as a boolean.

    JSON booleans, numbers and null keep their usual truthiness. Strings such
    as "false", "0", "no", "off" and "" are treated as False, so a client
    sending a string flag does not silently get the opposite behaviour.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ('', 'false', '0', 'no', 'off')
    return bool(value)


def _int_option(data: dict, key: str, default: int, low: int, high: int) -> int:
    """Read integer option *key* from *data* (or *default*) and check it is in [low, high].

    Raises:
        ValueError: if the value is not integer-convertible or is out of range.
    """
    raw = data.get(key, default)
    try:
        value = int(raw)
    except (TypeError, ValueError):
        # TypeError covers null/list/dict values, which int() rejects.
        raise ValueError(f"{key} must be an integer") from None
    if not low <= value <= high:
        raise ValueError(f"{key} must be between {low} and {high}")
    return value


def _decode_image_field(data: dict, key: str):
    """Decode the base64 image under *key*; return None if it is absent or empty.

    Raises:
        ValueError: if the field is not a string or cannot be decoded as an image.
    """
    value = data.get(key)
    if not value:
        return None
    if not isinstance(value, str):
        raise ValueError(f"'{key}' must be a base64-encoded string")
    try:
        return base64_to_image(value)
    except OSError as err:
        # PIL.UnidentifiedImageError subclasses OSError; an undecodable upload
        # is a client error, so report it as invalid input (400).
        raise ValueError(f"could not decode '{key}' image: {err}") from err


@app.route('/api/swap-hair', methods=['POST'])
def swap_hair():
    """
    Hair swap endpoint
    Expected JSON payload:
    {
        "face": "base64_encoded_image",
        "shape": "base64_encoded_image (optional)",
        "color": "base64_encoded_image (optional)",
        "blending": "Article|Alternative_v1|Alternative_v2 (default: Article)",
        "poisson_iters": 0-2500 (default: 0),
        "poisson_erosion": 1-100 (default: 15),
        "align_face": true|false (default: true),
        "align_shape": true|false (default: true),
        "align_color": true|false (default: true)
    }
    At least one of "shape" / "color" must be a non-empty image. Image fields
    may be raw base64 or ``data:`` URLs. Flags also accept the strings
    "true"/"false".

    Returns:
        200 with {"success", "result" (JPEG data URL), "message"} on success;
        400 with {"error"} for invalid input; 500 with {"error"} for backend
        (gRPC) or unexpected failures, which are also logged.
    """
    try:
        # silent=True: malformed JSON or a non-JSON Content-Type yields None
        # (-> 400 below) instead of raising an HTTPException that the generic
        # handler would report as a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        # Validate required fields; null / empty strings count as missing.
        if not data.get('face'):
            return jsonify({"error": "Face image is required"}), 400
        if not data.get('shape') and not data.get('color'):
            return jsonify({"error": "At least shape or color image is required"}), 400

        # Validate options before doing any image work.
        blending = data.get('blending', 'Article')
        if blending not in _BLENDING_OPTIONS:
            return jsonify({"error": "Invalid blending option"}), 400
        poisson_iters = _int_option(data, 'poisson_iters', 0, 0, 2500)
        poisson_erosion = _int_option(data, 'poisson_erosion', 15, 1, 100)
        align_face_flag = _as_bool(data.get('align_face', True))
        align_shape_flag = _as_bool(data.get('align_shape', True))
        align_color_flag = _as_bool(data.get('align_color', True))

        # Decode and normalise images.
        face_img = resize_image(_decode_image_field(data, 'face'), align_face_flag)
        shape_img = _decode_image_field(data, 'shape')
        if shape_img is not None:
            shape_img = resize_image(shape_img, align_shape_flag)
        color_img = _decode_image_field(data, 'color')
        if color_img is not None:
            color_img = resize_image(color_img, align_color_flag)

        face_bytes = get_bytes(face_img)
        # When a reference image is omitted the service receives the literal
        # markers b'face' / b'shape' instead of image bytes.
        # NOTE(review): presumably "reuse the face / shape image" — confirm
        # against the gRPC server.
        shape_bytes = get_bytes(shape_img) if shape_img is not None else b'face'
        color_bytes = get_bytes(color_img) if color_img is not None else b'shape'

        # Call gRPC service. NOTE(review): no deadline is set, so a hung
        # backend blocks this request indefinitely.
        with grpc.insecure_channel(os.environ.get('SERVER', 'localhost:50051')) as channel:
            stub = HairSwapServiceStub(channel)
            output: HairSwapResponse = stub.swap(
                HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=blending,
                    poisson_iters=poisson_iters,
                    poisson_erosion=poisson_erosion,
                    use_cache=True
                )
            )

        # Convert result to base64
        result_base64 = image_to_base64(bytes_to_image(output.image))
        return jsonify({
            "success": True,
            "result": result_base64,
            "message": "Hair swap completed successfully"
        }), 200
    except ValueError as e:
        return jsonify({"error": f"Invalid input: {str(e)}"}), 400
    except grpc.RpcError as e:
        app.logger.exception("gRPC hair-swap call failed")
        return jsonify({"error": f"gRPC error: {str(e)}"}), 500
    except Exception as e:
        app.logger.exception("Unexpected error in /api/swap-hair")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route('/', methods=['GET'])
def index():
    """API documentation endpoint.

    Returns a static JSON summary of the routes this module registers and the
    request fields accepted by ``/api/swap-hair``. Only routes that actually
    exist are listed; ``/test`` was advertised but no such route is defined.
    """
    return jsonify({
        "service": "HairFastGAN API",
        "version": "1.0",
        "endpoints": {
            "/": "GET - API documentation",
            "/health": "GET - Health check",
            "/api/swap-hair": "POST - Hair swap endpoint"
        },
        "documentation": {
            "swap_hair": {
                "method": "POST",
                "content_type": "application/json",
                "required_fields": ["face", "shape or color"],
                "optional_fields": {
                    "blending": "Article (default), Alternative_v1, Alternative_v2",
                    "poisson_iters": "0-2500 (default: 0)",
                    "poisson_erosion": "1-100 (default: 15)",
                    "align_face": "true (default) or false",
                    "align_shape": "true (default) or false",
                    "align_color": "true (default) or false"
                }
            }
        }
    }), 200
if __name__ == '__main__':
    # Development entry point: listen on all interfaces, on the port given by
    # the PORT environment variable (5000 when unset), with debug mode off.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 5000)),
        debug=False,
    )