"""Flask HTTP front-end that forwards hair-swap requests to a gRPC service."""

import base64
import hashlib
import os
from io import BytesIO

import grpc
from cachetools import LRUCache
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image, UnidentifiedImageError

from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face

app = Flask(__name__)
CORS(app)

# Global cache of aligned images, keyed by the MD5 of the JPEG-encoded input.
align_cache = LRUCache(maxsize=10)

# Image modes Pillow's JPEG encoder can write without conversion; anything
# else (RGBA, P, LA, ...) must be converted first or `save` raises OSError.
_JPEG_WRITABLE_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

_VALID_BLENDINGS = ('Article', 'Alternative_v1', 'Alternative_v2')

_TARGET_SIZE = (1024, 1024)


def _jpeg_ready(img):
    """Return *img* unchanged if JPEG can encode its mode, else an RGB copy."""
    if img.mode in _JPEG_WRITABLE_MODES:
        return img
    return img.convert("RGB")


def _int_option(data, key, default):
    """Read ``data[key]`` as an int, raising ValueError on any bad value.

    ``int(None)`` raises TypeError, which would otherwise surface as a 500;
    normalising to ValueError lets the endpoint answer with a 400.
    """
    raw = data.get(key, default)
    try:
        return int(raw)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"{key} must be an integer") from exc


def get_bytes(img):
    """Encode a PIL image as JPEG bytes; return None when *img* is None."""
    if img is None:
        return None
    buffered = BytesIO()
    _jpeg_ready(img).save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Open encoded image bytes as a PIL Image."""
    return Image.open(BytesIO(image_bytes))


def base64_to_image(base64_string: str) -> Image.Image:
    """Convert a base64 string (optionally a ``data:`` URI) to a PIL Image.

    Raises:
        ValueError: if the input is not a string, is not valid base64, or
            does not decode to an image Pillow can identify.
    """
    if not isinstance(base64_string, str):
        raise ValueError("image must be a base64-encoded string")
    # Drop a leading "data:image/...;base64," prefix if one is present.
    image_data = base64.b64decode(base64_string.split(',')[-1])
    try:
        return Image.open(BytesIO(image_data))
    except UnidentifiedImageError as exc:
        # UnidentifiedImageError is an OSError; re-raise as ValueError so the
        # endpoint reports a client error instead of a 500.
        raise ValueError("decoded data is not a recognizable image") from exc


def image_to_base64(img: Image.Image) -> str:
    """Convert a PIL Image to a JPEG ``data:`` URI string."""
    buffered = BytesIO()
    _jpeg_ready(img).save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


def center_crop(img):
    """Crop the largest centered square out of *img*."""
    width, height = img.size
    side = min(width, height)
    # Integer arithmetic keeps the box exactly `side` x `side`; float halves
    # are rounded by Pillow and can yield a box one pixel too wide.
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))


def resize_image(img, should_align=True):
    """Resize and optionally align image.

    With *should_align* true the image is passed through ``align_face`` and
    the result is memoised in ``align_cache``. Otherwise an image that is not
    already 1024x1024 is center-cropped and resized to that size.
    """
    if should_align:
        img_hash = hashlib.md5(get_bytes(img)).hexdigest()
        if img_hash not in align_cache:
            img = align_face(img, return_tensors=False)[0]
            align_cache[img_hash] = img
        else:
            img = align_cache[img_hash]
    elif img.size != _TARGET_SIZE:
        img = center_crop(img)
        img = img.resize(_TARGET_SIZE, Image.Resampling.LANCZOS)
    return img


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy", "service": "HairFastGAN API"}), 200


@app.route('/api/swap-hair', methods=['POST'])
def swap_hair():
    """
    Hair swap endpoint

    Expected JSON payload:
    {
        "face": "base64_encoded_image",
        "shape": "base64_encoded_image (optional)",
        "color": "base64_encoded_image (optional)",
        "blending": "Article|Alternative_v1|Alternative_v2 (default: Article)",
        "poisson_iters": 0-2500 (default: 0),
        "poisson_erosion": 1-100 (default: 15),
        "align_face": true|false (default: true),
        "align_shape": true|false (default: true),
        "align_color": true|false (default: true)
    }

    Returns a JSON body with "success", "result" (JPEG data URI) and
    "message" on success, or {"error": ...} with status 400 for invalid
    input and 500 for gRPC or unexpected failures.
    """
    try:
        # silent=True: a malformed body or wrong content type yields None
        # instead of raising BadRequest, which the blanket handler below
        # would otherwise report as a 500.
        data = request.get_json(silent=True)

        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            return jsonify({"error": "JSON payload must be an object"}), 400

        # Validate required fields (empty strings count as missing).
        if not data.get('face'):
            return jsonify({"error": "Face image is required"}), 400

        if not data.get('shape') and not data.get('color'):
            return jsonify({"error": "At least shape or color image is required"}), 400

        # Parse images
        face_img = base64_to_image(data['face'])
        shape_img = base64_to_image(data['shape']) if data.get('shape') else None
        color_img = base64_to_image(data['color']) if data.get('color') else None

        # Get options
        blending = data.get('blending', 'Article')
        poisson_iters = _int_option(data, 'poisson_iters', 0)
        poisson_erosion = _int_option(data, 'poisson_erosion', 15)
        align_face_flag = data.get('align_face', True)
        align_shape_flag = data.get('align_shape', True)
        align_color_flag = data.get('align_color', True)

        # Validate blending option
        if blending not in _VALID_BLENDINGS:
            return jsonify({"error": "Invalid blending option"}), 400

        # Resize images
        face_img = resize_image(face_img, align_face_flag)
        if shape_img is not None:
            shape_img = resize_image(shape_img, align_shape_flag)
        if color_img is not None:
            color_img = resize_image(color_img, align_color_flag)

        # Convert to bytes.
        # NOTE(review): b'face' / b'shape' are sentinel payloads sent when an
        # image is absent; presumably the server treats them as "reuse the
        # face" / "reuse the shape" -- confirm against the gRPC service.
        face_bytes = get_bytes(face_img)
        shape_bytes = get_bytes(shape_img) if shape_img is not None else b'face'
        color_bytes = get_bytes(color_img) if color_img is not None else b'shape'

        # Call gRPC service
        with grpc.insecure_channel(os.environ.get('SERVER', 'localhost:50051')) as channel:
            stub = HairSwapServiceStub(channel)
            output: HairSwapResponse = stub.swap(
                HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=blending,
                    poisson_iters=poisson_iters,
                    poisson_erosion=poisson_erosion,
                    use_cache=True
                )
            )

        # Convert result to base64
        output_img = bytes_to_image(output.image)
        result_base64 = image_to_base64(output_img)

        return jsonify({
            "success": True,
            "result": result_base64,
            "message": "Hair swap completed successfully"
        }), 200

    except ValueError as e:
        return jsonify({"error": f"Invalid input: {str(e)}"}), 400
    except grpc.RpcError as e:
        return jsonify({"error": f"gRPC error: {str(e)}"}), 500
    except Exception as e:
        # Top-level boundary: report anything unexpected as a 500.
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500


@app.route('/', methods=['GET'])
def index():
    """API documentation endpoint"""
    return jsonify({
        "service": "HairFastGAN API",
        "version": "1.0",
        "endpoints": {
            "/health": "GET - Health check",
            "/api/swap-hair": "POST - Hair swap endpoint",
            "/test": "GET - Test HTML interface"
        },
        "documentation": {
            "swap_hair": {
                "method": "POST",
                "content_type": "application/json",
                "required_fields": ["face", "shape or color"],
                "optional_fields": {
                    "blending": "Article (default), Alternative_v1, Alternative_v2",
                    "poisson_iters": "0-2500 (default: 0)",
                    "poisson_erosion": "1-100 (default: 15)",
                    "align_face": "true (default) or false",
                    "align_shape": "true (default) or false",
                    "align_color": "true (default) or false"
                }
            }
        }
    }), 200


if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port, debug=False)