Spaces:
Build error
Build error
File size: 6,853 Bytes
57b5fc8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | import os
import hashlib
from io import BytesIO
import base64
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import grpc
from cachetools import LRUCache
from inference_pb2 import HairSwapRequest, HairSwapResponse
from inference_pb2_grpc import HairSwapServiceStub
from utils.shape_predictor import align_face
# Flask application object; flask_cors is applied with its default settings.
app = Flask(__name__)
CORS(app)

# Global cache of aligned face images, keyed by the MD5 hex digest of the
# JPEG-encoded input image (filled in by resize_image). Holds at most
# _ALIGN_CACHE_SIZE entries; least-recently-used entries are evicted first.
_ALIGN_CACHE_SIZE = 10
align_cache = LRUCache(maxsize=_ALIGN_CACHE_SIZE)
def get_bytes(img):
    """Encode a PIL image as JPEG and return the encoded bytes.

    Returns None when *img* is None. Images in modes other than RGB or L
    (e.g. RGBA, LA, P from uploaded PNGs) are converted to RGB before
    encoding; RGB and L images are encoded unchanged.
    """
    if img is None:
        return None
    # Pillow refuses to write alpha/palette modes as JPEG (OSError
    # "cannot write mode RGBA as JPEG"), so normalise those to RGB first.
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image_bytes: bytes) -> Image.Image:
    """Open encoded image bytes (e.g. a JPEG/PNG payload) as a PIL Image.

    The bytes are wrapped in an in-memory stream and handed to ``Image.open``.
    """
    stream = BytesIO(image_bytes)
    return Image.open(stream)
def base64_to_image(base64_string: str) -> Image.Image:
    """Convert a base64 string (optionally a ``data:`` URL) to a PIL Image.

    Everything up to and including the last comma is discarded, so both bare
    base64 payloads and ``data:image/...;base64,<payload>`` strings work.

    Raises:
        ValueError: if *base64_string* is not a string, is not valid base64,
            or does not decode to an image Pillow can read (including images
            rejected as decompression bombs).
    """
    if not isinstance(base64_string, str):
        # Values come from client JSON and may be numbers, lists, objects...
        raise ValueError("image must be a base64-encoded string")
    # Malformed base64 raises binascii.Error, which is a ValueError subclass.
    image_data = base64.b64decode(base64_string.split(',')[-1])
    try:
        img = Image.open(BytesIO(image_data))
        # Image.open only parses the header; load() forces full decoding so
        # corrupt or truncated data fails here rather than in later processing.
        img.load()
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unrecognised data) is an OSError subclass.
        raise ValueError(f"could not decode image: {err}") from err
    return img
def image_to_base64(img: Image.Image) -> str:
    """Encode a PIL image as a JPEG ``data:`` URL.

    Images in modes other than RGB or L (e.g. RGBA, LA, P) are converted to
    RGB first; RGB and L images are encoded unchanged.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    if img.mode not in ("RGB", "L"):
        # Pillow cannot write alpha/palette modes as JPEG (raises OSError).
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{img_str}"
def center_crop(img):
    """Crop the largest centered square out of *img*.

    The result is exactly ``side`` x ``side`` pixels, where ``side`` is the
    shorter image dimension. When the excess along the longer axis is odd,
    the extra pixel is trimmed from the right/bottom edge.
    """
    width, height = img.size
    side = min(width, height)
    # Integer offsets: with float coordinates Pillow rounds each edge
    # independently (round-half-to-even), which can round the two edges in
    # opposite directions and yield a non-square box (e.g. 102x101 input).
    left = (width - side) // 2
    top = (height - side) // 2
    return img.crop((left, top, left + side, top + side))
def resize_image(img, should_align=True):
    """Resize and optionally align an image before sending it to the service.

    Args:
        img: PIL image to prepare.
        should_align: when truthy, run ``align_face`` on the image and use its
            first result, memoised in ``align_cache`` under the MD5 digest of
            the image's JPEG encoding. Otherwise, an image that is not already
            1024x1024 is center-cropped to a square and resized to 1024x1024
            with LANCZOS resampling.

    Returns:
        The prepared PIL image. Aligned images are returned straight from the
        cache and shared between calls, so callers must not modify them in
        place.

    Raises:
        ValueError: if ``align_face`` produces no result for the image.
    """
    if should_align:
        # MD5 serves only as a cache key here, not for security.
        img_hash = hashlib.md5(get_bytes(img)).hexdigest()
        # NOTE(review): cachetools caches are not thread-safe; if the server
        # handles requests concurrently, consider guarding align_cache with
        # a lock.
        try:
            return align_cache[img_hash]
        except KeyError:
            pass
        faces = align_face(img, return_tensors=False)
        # NOTE(review): assumes align_face returns an empty (or None) result
        # when no face is found -- confirm against utils.shape_predictor.
        if faces is None or len(faces) == 0:
            raise ValueError("No face detected in image")
        aligned = faces[0]
        align_cache[img_hash] = aligned
        return aligned
    if img.size != (1024, 1024):
        img = center_crop(img)
        img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
    return img
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the API process is up.

    Always answers HTTP 200 with a fixed JSON body; it does not contact the
    gRPC backend.
    """
    payload = {"status": "healthy", "service": "HairFastGAN API"}
    return jsonify(payload), 200
# Accepted values for the "blending" option of /api/swap-hair.
_BLENDING_OPTIONS = ('Article', 'Alternative_v1', 'Alternative_v2')


def _parse_int_option(data, key, default, *, low, high):
    """Return ``int(data.get(key, default))``, checked to lie in [low, high].

    Raises:
        ValueError: if the value cannot be converted to int or is out of range.
    """
    raw = data.get(key, default)
    try:
        value = int(raw)
    except (TypeError, ValueError, OverflowError) as err:
        # TypeError: JSON null/list/object; OverflowError: JSON Infinity.
        raise ValueError(f"{key} must be an integer") from err
    if not low <= value <= high:
        raise ValueError(f"{key} must be between {low} and {high}")
    return value


def _parse_flag(data, key, default=True):
    """Return the boolean option *key* from *data* (missing -> *default*).

    Strings are parsed explicitly ("true"/"1"/"yes" vs "false"/"0"/"no"/"",
    case-insensitive) because bool("false") is True; any non-string value is
    interpreted by Python truthiness.

    Raises:
        ValueError: for a string that is none of the recognised spellings.
    """
    value = data.get(key, default)
    if not isinstance(value, str):
        return bool(value)
    normalized = value.strip().lower()
    if normalized in ('true', '1', 'yes'):
        return True
    if normalized in ('false', '0', 'no', ''):
        return False
    raise ValueError(f"{key} must be true or false")


@app.route('/api/swap-hair', methods=['POST'])
def swap_hair():
    """
    Hair swap endpoint
    Expected JSON payload:
    {
    "face": "base64_encoded_image",
    "shape": "base64_encoded_image (optional)",
    "color": "base64_encoded_image (optional)",
    "blending": "Article|Alternative_v1|Alternative_v2 (default: Article)",
    "poisson_iters": 0-2500 (default: 0),
    "poisson_erosion": 1-100 (default: 15),
    "align_face": true|false (default: true),
    "align_shape": true|false (default: true),
    "align_color": true|false (default: true)
    }
    At least one of "shape"/"color" must be a non-empty string. The align
    flags also accept the strings "true"/"false".

    Responses:
        200 with {"success", "result" (JPEG data URL), "message"} on success;
        400 with {"error"} for missing/invalid input (including any ValueError
        raised while parsing or preparing the images);
        500 with {"error"} for gRPC failures or unexpected errors (logged).
    """
    try:
        # silent=True: a non-JSON or malformed body yields None (-> 400 below)
        # instead of an HTTPException that the generic handler would turn
        # into a 500.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"error": "No JSON data provided"}), 400
        if not isinstance(data, dict):
            return jsonify({"error": "JSON body must be an object"}), 400

        # Validate required fields; empty strings and nulls count as missing.
        if not data.get('face'):
            return jsonify({"error": "Face image is required"}), 400
        if not data.get('shape') and not data.get('color'):
            return jsonify({"error": "At least shape or color image is required"}), 400

        # Validate options before the comparatively expensive image decoding.
        blending = data.get('blending', 'Article')
        if blending not in _BLENDING_OPTIONS:
            return jsonify({"error": "Invalid blending option"}), 400
        poisson_iters = _parse_int_option(data, 'poisson_iters', 0, low=0, high=2500)
        poisson_erosion = _parse_int_option(data, 'poisson_erosion', 15, low=1, high=100)
        align_face_flag = _parse_flag(data, 'align_face')
        align_shape_flag = _parse_flag(data, 'align_shape')
        align_color_flag = _parse_flag(data, 'align_color')

        # Parse images
        face_img = base64_to_image(data['face'])
        shape_img = base64_to_image(data['shape']) if data.get('shape') else None
        color_img = base64_to_image(data['color']) if data.get('color') else None

        # Align or crop/resize images
        face_img = resize_image(face_img, align_face_flag)
        if shape_img is not None:
            shape_img = resize_image(shape_img, align_shape_flag)
        if color_img is not None:
            color_img = resize_image(color_img, align_color_flag)

        # Convert to bytes. A missing shape/color is sent as the placeholder
        # payload b'face' / b'shape' respectively.
        # NOTE(review): presumably the server reads these as "reuse the face /
        # shape image" -- confirm against the gRPC service.
        face_bytes = get_bytes(face_img)
        shape_bytes = get_bytes(shape_img) if shape_img is not None else b'face'
        color_bytes = get_bytes(color_img) if color_img is not None else b'shape'

        # Call gRPC service (address from the SERVER env var)
        with grpc.insecure_channel(os.environ.get('SERVER', 'localhost:50051')) as channel:
            stub = HairSwapServiceStub(channel)
            output: HairSwapResponse = stub.swap(
                HairSwapRequest(
                    face=face_bytes,
                    shape=shape_bytes,
                    color=color_bytes,
                    blending=blending,
                    poisson_iters=poisson_iters,
                    poisson_erosion=poisson_erosion,
                    use_cache=True
                )
            )

        # Convert result to base64
        output_img = bytes_to_image(output.image)
        result_base64 = image_to_base64(output_img)
        return jsonify({
            "success": True,
            "result": result_base64,
            "message": "Hair swap completed successfully"
        }), 200
    except ValueError as e:
        return jsonify({"error": f"Invalid input: {str(e)}"}), 400
    except grpc.RpcError as e:
        app.logger.exception("gRPC call to the hair swap service failed")
        return jsonify({"error": f"gRPC error: {str(e)}"}), 500
    except Exception as e:
        app.logger.exception("Unhandled error in /api/swap-hair")
        return jsonify({"error": f"Internal server error: {str(e)}"}), 500
@app.route('/', methods=['GET'])
def index():
    """Return a static JSON description of the API with HTTP 200.

    Lists the service name and version, the advertised endpoints, and the
    request fields of the swap-hair endpoint.
    """
    # NOTE(review): "/test" is advertised below, but no such route is defined
    # in this file -- confirm it exists or drop it from the listing.
    endpoints = {
        "/health": "GET - Health check",
        "/api/swap-hair": "POST - Hair swap endpoint",
        "/test": "GET - Test HTML interface",
    }
    optional_fields = {
        "blending": "Article (default), Alternative_v1, Alternative_v2",
        "poisson_iters": "0-2500 (default: 0)",
        "poisson_erosion": "1-100 (default: 15)",
        "align_face": "true (default) or false",
        "align_shape": "true (default) or false",
        "align_color": "true (default) or false",
    }
    swap_hair_doc = {
        "method": "POST",
        "content_type": "application/json",
        "required_fields": ["face", "shape or color"],
        "optional_fields": optional_fields,
    }
    body = {
        "service": "HairFastGAN API",
        "version": "1.0",
        "endpoints": endpoints,
        "documentation": {"swap_hair": swap_hair_doc},
    }
    return jsonify(body), 200
if __name__ == '__main__':
    # Run Flask's built-in server on all interfaces; the port comes from the
    # PORT environment variable and defaults to 5000.
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port, debug=False)