Spaces:
Build error
Build error
| import os | |
| import hashlib | |
| from io import BytesIO | |
| import base64 | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from PIL import Image | |
| import grpc | |
| from cachetools import LRUCache | |
| from inference_pb2 import HairSwapRequest, HairSwapResponse | |
| from inference_pb2_grpc import HairSwapServiceStub | |
| from utils.shape_predictor import align_face | |
app = Flask(__name__)
CORS(app)  # enable cross-origin requests for every route of this app

# Process-wide LRU cache of aligned images, keyed by the MD5 hex digest of
# the incoming image's JPEG encoding (filled by resize_image); at most 10
# entries are retained.
# NOTE(review): cachetools caches are not thread-safe and Flask may serve
# requests on several threads — confirm whether a lock is needed.
align_cache = LRUCache(maxsize=10)
def get_bytes(img):
    """Encode a PIL image as JPEG and return the raw bytes.

    Args:
        img: A PIL image, or ``None``.

    Returns:
        The JPEG-encoded bytes, or ``None`` when *img* is ``None``.
    """
    if img is None:
        return None
    # Pillow's JPEG writer only handles these modes; anything else (RGBA,
    # P, LA, ...) makes ``save`` raise OSError, so flatten it to RGB first.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Open encoded image bytes (any format Pillow recognises) as a PIL image.

    ``Image.open`` reads lazily, so the returned image holds on to the
    in-memory stream created here.
    """
    stream = BytesIO(image_bytes)
    return Image.open(stream)
def base64_to_image(base64_string: str) -> Image.Image:
    """Decode a base64 payload (optionally a ``data:`` URL) into a PIL image.

    Args:
        base64_string: Base64-encoded image. Everything up to the last comma
            (e.g. a ``data:image/png;base64,`` prefix) is discarded.

    Returns:
        The decoded image, with its pixel data already loaded.

    Raises:
        ValueError: If the input is not a string, is not valid base64, or is
            not an image Pillow can decode.
    """
    if not isinstance(base64_string, str):
        raise ValueError("image must be a base64-encoded string")
    # Malformed base64 raises binascii.Error, which subclasses ValueError.
    image_data = base64.b64decode(base64_string.split(',')[-1])
    try:
        img = Image.open(BytesIO(image_data))
        # Image.open only parses the header; decode now so truncated or
        # corrupt payloads are reported as bad input here, not later.
        img.load()
    except OSError as err:
        # PIL.UnidentifiedImageError is an OSError subclass.
        raise ValueError(f"could not decode image: {err}") from err
    return img
def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a JPEG ``data:`` URL string.

    Args:
        img: Image to encode. Modes JPEG cannot store (RGBA, P, LA, ...)
            are converted to RGB first rather than letting Pillow raise.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    # Same mode guard as get_bytes: Pillow's JPEG writer rejects other modes.
    if img.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"
def center_crop(img):
    """Crop the largest centred square out of *img*.

    Args:
        img: A PIL image.

    Returns:
        A ``side x side`` crop, where ``side = min(width, height)``.
    """
    width, height = img.size
    side = min(width, height)
    # Integer offsets guarantee a square box. The former float box (x.5
    # offsets) was rounded half-to-even by Pillow, which could make the crop
    # one pixel wider or taller than ``side`` (e.g. 4x3 -> 4x3 box).
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize_image(img, should_align=True):
    """Prepare an input image for the hair-swap backend.

    When *should_align* is truthy the image is passed through ``align_face``
    and the first returned image is used; results are memoised in
    ``align_cache`` keyed by the MD5 of the input's JPEG bytes. Otherwise an
    image that is not already 1024x1024 is centre-cropped to a square and
    resampled to 1024x1024 with Lanczos filtering.
    """
    if not should_align:
        if img.size == (1024, 1024):
            return img
        square = center_crop(img)
        return square.resize((1024, 1024), Image.Resampling.LANCZOS)

    cache_key = hashlib.md5(get_bytes(img)).hexdigest()
    if cache_key in align_cache:
        return align_cache[cache_key]
    aligned = align_face(img, return_tensors=False)[0]
    align_cache[cache_key] = aligned
    return aligned
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Registered at ``/health`` (the path advertised by ``index``); without a
    route decorator the view was never reachable.

    Returns:
        A JSON status payload with HTTP 200.
    """
    return jsonify({"status": "healthy", "service": "HairFastGAN API"}), 200
def _parse_flag(data, field, default=True):
    """Read a boolean option, parsing string spellings such as "false"/"0".

    Non-string values keep plain Python truthiness, as before; strings are
    interpreted explicitly because e.g. ``"false"`` would otherwise be truthy.

    Raises:
        ValueError: If a string value is not a recognised boolean spelling.
    """
    value = data.get(field, default)
    if not isinstance(value, str):
        return bool(value)
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "on"):
        return True
    if lowered in ("false", "0", "no", "off", ""):
        return False
    raise ValueError(f"'{field}' must be true or false")


def _parse_int_option(data, field, default, low, high):
    """Read an integer option and check it lies within ``[low, high]``.

    Raises:
        ValueError: If the value is not an integer or is out of range.
    """
    try:
        value = int(data.get(field, default))
    except (TypeError, ValueError):
        # TypeError covers JSON null / objects, which used to surface as a 500.
        raise ValueError(f"'{field}' must be an integer") from None
    if not low <= value <= high:
        raise ValueError(f"'{field}' must be between {low} and {high}")
    return value


def _decode_image_field(data, field):
    """Decode the base64 image under *field*; return None if absent or empty.

    Raises:
        ValueError: If the value is not a string or cannot be decoded.
    """
    value = data.get(field)
    if not value:
        return None
    if not isinstance(value, str):
        raise ValueError(f"'{field}' must be a base64-encoded string")
    try:
        return base64_to_image(value)
    except OSError as err:
        # PIL.UnidentifiedImageError subclasses OSError; this is bad input.
        raise ValueError(f"'{field}' is not a decodable image ({err})") from err


@app.route('/api/swap-hair', methods=['POST'])
def swap_hair():
    """
    Hair swap endpoint
    Expected JSON payload:
    {
        "face": "base64_encoded_image",
        "shape": "base64_encoded_image (optional)",
        "color": "base64_encoded_image (optional)",
        "blending": "Article|Alternative_v1|Alternative_v2 (default: Article)",
        "poisson_iters": 0-2500 (default: 0),
        "poisson_erosion": 1-100 (default: 15),
        "align_face": true|false (default: true),
        "align_shape": true|false (default: true),
        "align_color": true|false (default: true)
    }

    Returns a JSON body with HTTP 200 on success, 400 for invalid input and
    500 for backend (gRPC) or unexpected failures.
    """
    try:
        # silent=True: a missing or non-JSON body yields None instead of an
        # HTTP exception that the generic handler below would turn into a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        # Validate required fields; empty strings and null count as missing.
        if not data.get('face'):
            return jsonify({"error": "Face image is required"}), 400
        if not data.get('shape') and not data.get('color'):
            return jsonify({"error": "At least shape or color image is required"}), 400

        # Validate the cheap options before doing any image decoding.
        blending = data.get('blending', 'Article')
        if blending not in ('Article', 'Alternative_v1', 'Alternative_v2'):
            return jsonify({"error": "Invalid blending option"}), 400
        poisson_iters = _parse_int_option(data, 'poisson_iters', 0, 0, 2500)
        poisson_erosion = _parse_int_option(data, 'poisson_erosion', 15, 1, 100)
        align_face_flag = _parse_flag(data, 'align_face')
        align_shape_flag = _parse_flag(data, 'align_shape')
        align_color_flag = _parse_flag(data, 'align_color')

        # Parse images
        face_img = _decode_image_field(data, 'face')
        shape_img = _decode_image_field(data, 'shape')
        color_img = _decode_image_field(data, 'color')

        # Resize / align images
        face_img = resize_image(face_img, align_face_flag)
        if shape_img is not None:
            shape_img = resize_image(shape_img, align_shape_flag)
        if color_img is not None:
            color_img = resize_image(color_img, align_color_flag)

        # Convert to bytes. NOTE(review): the b'face' / b'shape' placeholders
        # presumably tell the backend to reuse the face (for shape) or the
        # shape (for color) when a reference is missing — confirm server-side.
        face_bytes = get_bytes(face_img)
        shape_bytes = get_bytes(shape_img) if shape_img is not None else b'face'
        color_bytes = get_bytes(color_img) if color_img is not None else b'shape'

        # Call gRPC service (a new channel per request; closed by `with`).
        with grpc.insecure_channel(os.environ.get('SERVER', 'localhost:50051')) as channel:
            stub = HairSwapServiceStub(channel)
            output: HairSwapResponse = stub.swap(
                HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=blending,
                    poisson_iters=poisson_iters,
                    poisson_erosion=poisson_erosion,
                    use_cache=True
                )
            )

        # Convert result to base64
        output_img = bytes_to_image(output.image)
        result_base64 = image_to_base64(output_img)
        return jsonify({
            "success": True,
            "result": result_base64,
            "message": "Hair swap completed successfully"
        }), 200
    except ValueError as e:
        return jsonify({"error": f"Invalid input: {str(e)}"}), 400
    except grpc.RpcError as e:
        app.logger.exception("gRPC call to the hair swap backend failed")
        return jsonify({"error": f"gRPC error: {str(e)}"}), 500
    except Exception as e:
        # Top-level boundary: record the traceback server-side.
        app.logger.exception("Unexpected error while handling hair swap")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route('/', methods=['GET'])
def index():
    """API documentation endpoint.

    Registered at ``/``; without a route decorator the view was never
    reachable. Returns a static JSON description of the service with HTTP 200.
    """
    # NOTE(review): "/test" is advertised below but no handler for it is
    # visible in this file — confirm it exists or drop it from the listing.
    return jsonify({
        "service": "HairFastGAN API",
        "version": "1.0",
        "endpoints": {
            "/health": "GET - Health check",
            "/api/swap-hair": "POST - Hair swap endpoint",
            "/test": "GET - Test HTML interface"
        },
        "documentation": {
            "swap_hair": {
                "method": "POST",
                "content_type": "application/json",
                "required_fields": ["face", "shape or color"],
                "optional_fields": {
                    "blending": "Article (default), Alternative_v1, Alternative_v2",
                    "poisson_iters": "0-2500 (default: 0)",
                    "poisson_erosion": "1-100 (default: 15)",
                    "align_face": "true (default) or false",
                    "align_shape": "true (default) or false",
                    "align_color": "true (default) or false"
                }
            }
        }
    }), 200
if __name__ == '__main__':
    # Listen on all interfaces; the port comes from $PORT (default 5000).
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', '5000')), debug=False)