Spaces:
Paused
Paused
| import hashlib | |
| import os | |
| from io import BytesIO | |
| import base64 | |
| import requests | |
| from pathlib import Path | |
| import subprocess | |
| import shutil | |
| import gc | |
| import time | |
| import json | |
| import threading | |
| import gradio as gr | |
| from PIL import Image | |
| from cachetools import LRUCache | |
| import torch | |
| import numpy as np | |
| import torchvision.transforms.functional as F | |
| # FastAPI imports for enhanced frontend integration | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional | |
| import uvicorn | |
# T4 Medium GPU Optimizations
# Let cuDNN benchmark and pick the fastest conv kernels (inputs are fixed 1024x1024).
torch.backends.cudnn.benchmark = True
# NOTE(review): torch.backends.cuda exposes no `max_split_size_mb` attribute, so this
# line just sets an inert module attribute. The allocator is actually controlled via
# the PYTORCH_CUDA_ALLOC_CONF env var (set later in optimize_for_t4) — confirm and
# consider removing this line.
torch.backends.cuda.max_split_size_mb = 512
# API Models for FastAPI
class ImageRequest(BaseModel):
    """Request payload for the hair-transfer API (images are base64-encoded strings)."""
    source_image: str  # base64 string (required): the face to modify
    shape_image: Optional[str] = None  # base64 string: hairstyle reference
    color_image: Optional[str] = None  # base64 string: hair-color reference
    blending: str = "Article"  # color-encoder variant name (UI parity)
    poisson_iters: int = 0
    poisson_erosion: int = 15
class ImageResponse(BaseModel):
    """API response: `result_image` is a PNG data-URL on success, None otherwise."""
    success: bool
    result_image: Optional[str] = None  # base64 data-URL string
    message: str
# Download working face landmarks
def download_face_landmarks():
    """Download the dlib 68-point face landmarks predictor if not already present.

    The .bz2 archive is fetched from the dlib-models GitHub repo and decompressed
    in memory before writing.

    Returns:
        bool: True if a valid-looking file exists or was downloaded, False on failure.
    """
    landmarks_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat'
    # Decompressed predictor is ~95MB; treat anything over 50MB as intact.
    if os.path.exists(landmarks_path) and os.path.getsize(landmarks_path) > 50000000:
        print("Face landmarks already exists and appears valid")
        return True
    try:
        print("Downloading working face landmarks predictor...")
        url = 'https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat.bz2'
        os.makedirs(os.path.dirname(landmarks_path), exist_ok=True)
        # The whole archive is decompressed in memory, so no streaming is needed.
        response = requests.get(url, timeout=300)
        response.raise_for_status()
        import bz2
        with open(landmarks_path, 'wb') as f:
            f.write(bz2.decompress(response.content))
        print(f"Face landmarks downloaded successfully ({os.path.getsize(landmarks_path)/1024/1024:.1f}MB)")
        return True
    except Exception as e:
        print(f"Failed to download face landmarks: {e}")
        # Drop any partial/corrupt file so the size check can't pass spuriously later.
        try:
            if os.path.exists(landmarks_path):
                os.remove(landmarks_path)
        except OSError:
            pass
        return False
# Comprehensive model download for full accuracy
def download_all_missing_models():
    """Download ALL required models for full accuracy.

    Checks each expected checkpoint on disk and downloads missing ones from the
    HairFastGAN Hugging Face repository, streaming to disk in 1MB chunks.

    Returns:
        bool: True when every model is present after the call, False if any
        download failed (the original always returned True).
    """
    all_required_models = {
        'pretrained_models/ArcFace/backbone_ir50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/backbone_ir50.pth',
        'pretrained_models/ArcFace/ir_se50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/ir_se50.pth',
        'pretrained_models/BiSeNet/face_parsing_79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/BiSeNet/face_parsing_79999_iter.pth',
        'pretrained_models/FeatureStyleEncoder/backbone.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/backbone.pth',
        'pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt',
        'pretrained_models/FeatureStyleEncoder/79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/79999_iter.pth',
        'pretrained_models/FeatureStyleEncoder/143_enc.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/143_enc.pth',
        'pretrained_models/encoder4editing/e4e_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
        'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
        'pretrained_models/PostProcess/pp_model.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/pp_model.pth',
        'pretrained_models/PostProcess/latent_avg.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/latent_avg.pt',
    }
    print("Checking for missing models for full accuracy...")
    existing_models = [os.path.basename(p) for p in all_required_models if os.path.exists(p)]
    missing_models = [(p, u) for p, u in all_required_models.items() if not os.path.exists(p)]
    if existing_models:
        print(f"Found existing models: {', '.join(existing_models)}")
    if not missing_models:
        print("All models already present!")
        return True
    print(f"Downloading {len(missing_models)} missing models for full accuracy...")
    failures = 0
    for i, (model_path, url) in enumerate(missing_models, 1):
        # Bind the name before the try so the except clause can always report it.
        model_name = os.path.basename(model_path)
        try:
            print(f"[{i}/{len(missing_models)}] Downloading: {model_name}")
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            start_time = time.time()
            response = requests.get(url, stream=True, timeout=600)
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            # Report every ~50MB for files >100MB. The original tested
            # `downloaded % (50MB) == 0`, which almost never fires because
            # iter_content chunks are not guaranteed to be exact multiples;
            # a running threshold is reliable.
            next_report = 50 * 1024 * 1024
            with open(model_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024*1024):
                    if chunk:
                        f.write(chunk)
                        downloaded += len(chunk)
                        if total_size > 100*1024*1024 and downloaded >= next_report:
                            next_report += 50 * 1024 * 1024
                            progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                            print(f"  Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
            elapsed = time.time() - start_time
            print(f"  Downloaded: {model_name} ({downloaded/1024/1024:.1f}MB in {elapsed:.1f}s)")
        except Exception as e:
            failures += 1
            print(f"  Failed to download {model_name}: {e}")
            # Remove the partial file so the existence check retries it next run.
            try:
                if os.path.exists(model_path):
                    os.remove(model_path)
            except OSError:
                pass
            continue
    return failures == 0
def download_mask_generator_separately():
    """Download large mask_generator.pth file separately.

    This checkpoint (~919MB) is streamed to disk in 2MB chunks with periodic
    progress output.

    Returns:
        bool: True if the file exists or was fully downloaded, False on failure.
    """
    mask_path = 'pretrained_models/ShapeAdaptor/mask_generator.pth'
    if os.path.exists(mask_path):
        print("Mask generator already exists")
        return True
    try:
        print("Downloading mask_generator.pth (919MB)...")
        url = 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ShapeAdaptor/mask_generator.pth'
        os.makedirs(os.path.dirname(mask_path), exist_ok=True)
        response = requests.get(url, stream=True, timeout=900)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0
        # Report every ~100MB. The original's `downloaded % 100MB == 0` test
        # almost never fires because chunk sizes vary; track a threshold instead.
        next_report = 100 * 1024 * 1024
        with open(mask_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=2*1024*1024):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if downloaded >= next_report:
                        next_report += 100 * 1024 * 1024
                        progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                        print(f"  Mask Generator Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
        print(f"Successfully downloaded mask_generator.pth ({downloaded/1024/1024:.1f}MB)")
        return True
    except Exception as e:
        print(f"Failed to download mask_generator.pth: {e}")
        # Critical: the guard above is a bare os.path.exists, so a partial file
        # from a failed run would permanently block re-download. Remove it.
        try:
            if os.path.exists(mask_path):
                os.remove(mask_path)
        except OSError:
            pass
        return False
# Download all models
# Module-import side effect: fetch every checkpoint before the HairFast modules
# below are imported, since they expect the files on disk at import/construction.
download_all_missing_models()
download_mask_generator_separately()
download_face_landmarks()
# Direct HairFast imports
# Project-local modules; import failures are tolerated so the app can still
# start and surface the problem instead of crashing at import time.
try:
    from hair_swap import HairFast, get_parser
    HAIRFAST_AVAILABLE = True
    print("HairFast successfully imported!")
except ImportError as e:
    print(f"HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False
try:
    from utils.shape_predictor import align_face
    ALIGN_AVAILABLE = True
    print("Face alignment available!")
except ImportError as e:
    print(f"Face alignment not available: {e}")
    ALIGN_AVAILABLE = False
# Global variables
hair_fast_model = None  # lazily-built HairFast instance (see initialize_hairfast_original)
align_cache = LRUCache(maxsize=10)  # md5(JPEG bytes) -> aligned 1024x1024 PIL image
def get_gpu_memory():
    """Return the total memory of CUDA device 0 in gigabytes, or 0 without CUDA."""
    if not torch.cuda.is_available():
        return 0
    props = torch.cuda.get_device_properties(0)
    return props.total_memory / 1e9
def optimize_for_t4():
    """T4 GPU specific optimizations.

    On GPUs with less than 20GB of memory, flush the CUDA cache and cap the
    allocator split size via PYTORCH_CUDA_ALLOC_CONF. No-op without CUDA.
    """
    if not torch.cuda.is_available():
        return
    gpu_memory = get_gpu_memory()
    print(f"GPU Memory: {gpu_memory:.1f}GB")
    if gpu_memory < 20:
        torch.cuda.empty_cache()
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
        print("T4 optimizations applied")
def check_model_completeness():
    """Check which critical models are available.

    Returns:
        dict[str, bool]: pipeline-stage name -> whether its checkpoint file exists.
    """
    critical_models = {
        'StyleGAN': 'pretrained_models/StyleGAN/ffhq.pt',
        'Blending': 'pretrained_models/Blending/checkpoint.pth',
        'Rotate': 'pretrained_models/Rotate/rotate_best.pth',
        'PostProcess': 'pretrained_models/PostProcess/pp_model.pth',
        'FeatureEncoder': 'pretrained_models/FeatureStyleEncoder/143_enc.pth',
        'E4E': 'pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
        'SEAN': 'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
        'ShapeAdaptor': 'pretrained_models/ShapeAdaptor/mask_generator.pth',
    }
    return {stage: os.path.exists(path) for stage, path in critical_models.items()}
def initialize_hairfast_original():
    """Initialize HairFast exactly like original - use pure defaults.

    Builds the global `hair_fast_model` from the reference parser's default
    arguments (parse_args([])). Returns True on success, False when the
    project import failed earlier or construction raised.
    """
    global hair_fast_model
    if not HAIRFAST_AVAILABLE:
        print("HairFast not available")
        return False
    try:
        print("Initializing HairFast with original default arguments...")
        optimize_for_t4()
        available_models = check_model_completeness()
        # Log only the checkpoints that were actually found.
        print("Available models:", {k: v for k, v in available_models.items() if v})
        # parse_args([]) yields pure defaults, matching the reference pipeline.
        parser = get_parser()
        args = parser.parse_args([])
        hair_fast_model = HairFast(args)
        # "accuracy" here = fraction of critical checkpoints present on disk.
        available_count = sum(available_models.values())
        total_count = len(available_models)
        accuracy_percentage = (available_count / total_count) * 100
        print(f"HairFast initialized with original defaults ({accuracy_percentage:.1f}% accuracy)")
        torch.cuda.empty_cache()
        gc.collect()
        return True
    except Exception as e:
        print(f"HairFast initialization failed: {e}")
        hair_fast_model = None
        torch.cuda.empty_cache()
        return False
def get_bytes(img):
    """Serialize an image to JPEG bytes (used for cache hashing); None passes through.

    JPEG cannot encode alpha or palette modes, so anything other than RGB/L is
    converted to RGB first — the original raised OSError on RGBA input.

    Args:
        img: PIL image (or None).

    Returns:
        bytes of the JPEG encoding, or None when img is None.
    """
    if img is None:
        return img
    # JPEG supports only a few modes; convert the rest (e.g. RGBA, P) to RGB.
    if getattr(img, "mode", "RGB") not in ("RGB", "L"):
        img = img.convert("RGB")
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Decode raw encoded-image bytes into a PIL Image."""
    return Image.open(BytesIO(image))
def base64_to_image(base64_string):
    """Decode a base64 string (optionally a data-URL) into an RGB PIL Image.

    Returns None on any decoding failure instead of raising.
    """
    try:
        # Strip a "data:image/...;base64," prefix when present.
        if base64_string.startswith('data:image'):
            base64_string = base64_string.split(',')[1]
        decoded = base64.b64decode(base64_string)
        img = Image.open(BytesIO(decoded))
        return img if img.mode == 'RGB' else img.convert('RGB')
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None
def image_to_base64(image):
    """Encode a PIL image as a lossless PNG data-URL string; None maps to None."""
    if image is None:
        return None
    buffer = BytesIO()
    # PNG keeps full quality; optimize=False trades a few bytes for encode speed.
    image.save(buffer, format="PNG", optimize=False)
    encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"
def center_crop(img):
    """Crop the largest centered square out of *img* (matches the reference helper)."""
    width, height = img.size
    side = min(width, height)
    # Float box coordinates, exactly as the reference implementation computed them.
    x0 = (width - side) / 2
    y0 = (height - side) / 2
    return img.crop((x0, y0, x0 + side, y0 + side))
def resize(name):
    """Factory returning a Gradio upload handler bound to one image slot.

    `name` is 'Face', 'Shape' or 'Color'. The returned handler aligns the face
    when that slot is ticked in the `align` checkbox group and dlib alignment
    imported, and always returns a 1024x1024 image. Aligned results are
    memoized in the module-level `align_cache`, keyed by md5 of the image's
    JPEG bytes.
    """
    def resize_inner(img, align):
        global align_cache
        if name in align and ALIGN_AVAILABLE:
            # Cache key is content-based so re-uploads of the same image hit.
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            if img_hash not in align_cache:
                try:
                    aligned_imgs = align_face(img, return_tensors=False)
                    if aligned_imgs and len(aligned_imgs) > 0:
                        img = aligned_imgs[0]
                        if img.size != (1024, 1024):
                            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
                        align_cache[img_hash] = img
                    else:
                        # Alignment found no face: fall back to a center crop.
                        img = center_crop(img)
                        img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
                        align_cache[img_hash] = img
                except Exception as e:
                    # Alignment raised: same center-crop fallback, also cached.
                    print(f"Face alignment failed for {name}, using center crop: {e}")
                    img = center_crop(img)
                    img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
                    align_cache[img_hash] = img
            else:
                img = align_cache[img_hash]
        else:
            # Alignment off/unavailable: plain center crop + resize, uncached.
            if img.size != (1024, 1024):
                img = center_crop(img)
                img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
        return img
    return resize_inner
def swap_hair_selective(face, shape, color, blending, poisson_iters, poisson_erosion):
    """
    Enhanced swap logic with selective transfer:
    - If only shape provided: change hairstyle only
    - If only color provided: change hair color only
    - If both provided: change both hairstyle and color

    Args:
        face: source PIL image (required).
        shape: PIL image giving the target hairstyle, or None.
        color: PIL image giving the target hair color, or None.
        blending, poisson_iters, poisson_erosion: accepted for UI/API parity;
            not forwarded to the model here (it runs with its init defaults).

    Returns:
        (PIL.Image, success message) on success, (None, error message) otherwise.
    """
    global hair_fast_model
    if hair_fast_model is None:
        # Lazy (re)initialization on first use.
        if not initialize_hairfast_original():
            return None, "HairFast model not available. Please check if all model files are uploaded."
    # NOTE(review): PIL Images appear to always be truthy, so these guards
    # effectively trip only on None inputs — confirm.
    if not face and not shape and not color:
        return None, "Need to upload a face and at least a shape or color ❗"
    elif not face:
        return None, "Need to upload a face ❗"
    elif not shape and not color:
        return None, "Need to upload at least a shape or color ❗"
    try:
        print("Starting selective hair transfer...")
        torch.cuda.empty_cache()
        gc.collect()
        def validate_size(img, name):
            # Force every input to the 1024x1024 resolution the model expects.
            if img is not None:
                if img.size != (1024, 1024):
                    print(f"Resizing {name} from {img.size} to (1024, 1024)")
                    img = center_crop(img)
                    img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
            return img
        face = validate_size(face, "face")
        shape = validate_size(shape, "shape") if shape else None
        color = validate_size(color, "color") if color else None
        # Determine transfer mode; a missing reference is substituted with the
        # source face so that attribute is preserved.
        has_shape = shape is not None
        has_color = color is not None
        if has_shape and has_color:
            transfer_mode = "both"
            print("Transfer mode: Both hairstyle and color")
        elif has_shape and not has_color:
            transfer_mode = "shape_only"
            color = face  # Use original face for color reference
            print("Transfer mode: Hairstyle only (preserving original color)")
        elif has_color and not has_shape:
            transfer_mode = "color_only"
            shape = face  # Use original face for shape reference
            print("Transfer mode: Color only (preserving original hairstyle)")
        print(f"Final sizes - Face: {face.size}, Shape: {shape.size if shape else 'None'}, Color: {color.size if color else 'None'}")
        with torch.no_grad():
            start_time = time.time()
            # Use the HairFast model's swap method with proper parameters
            final_image = hair_fast_model.swap(face, shape, color)
            inference_time = time.time() - start_time
            print(f"Inference completed in {inference_time:.2f} seconds")
        result_image = F.to_pil_image(final_image)
        torch.cuda.empty_cache()
        gc.collect()
        success_message = f"Hair transfer ({transfer_mode}) completed successfully in {inference_time:.2f}s"
        print(success_message)
        return result_image, success_message
    except Exception as e:
        torch.cuda.empty_cache()
        gc.collect()
        error_msg = f"Hair transfer failed: {str(e)}"
        print(f"Detailed error: {e}")
        return None, error_msg
def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """Enhanced API function for frontend integration with quality preservation.

    Accepts base64 strings (or already-decoded PIL images), runs the selective
    swap, and returns (base64 data-URL or None, status message).
    """
    try:
        print("API call received - processing images...")
        # Decode any base64 payloads into PIL images.
        if isinstance(source_image, str):
            source_image = base64_to_image(source_image)
            print(f"Source image loaded: {source_image.size if source_image else 'Failed'}")
        if isinstance(shape_image, str) and shape_image:
            shape_image = base64_to_image(shape_image)
            print(f"Shape image loaded: {shape_image.size if shape_image else 'Failed'}")
        if isinstance(color_image, str) and color_image:
            color_image = base64_to_image(color_image)
            print(f"Color image loaded: {color_image.size if color_image else 'Failed'}")
        if not source_image:
            return None, "Failed to process source image"
        result_image, status_message = swap_hair_selective(
            source_image, shape_image, color_image,
            blending, poisson_iters, poisson_erosion
        )
        if result_image is None:
            print("Result image is None")
            return None, status_message
        print(f"Result image generated: {result_image.size}")
        result_base64 = image_to_base64(result_image)
        print("Result converted to base64 successfully")
        return result_base64, status_message
    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"API Error details: {e}")
        return None, error_msg
def get_demo():
    """Enhanced Gradio interface with selective transfer options.

    Builds the Blocks UI: source/shape/color uploads (auto-aligned on upload),
    a live transfer-mode indicator, advanced sliders, and the swap button wired
    to swap_hair_selective (also exported as the `predict` API endpoint).
    """
    with gr.Blocks(
        title="HairFastGAN - Selective Transfer",
        theme=gr.themes.Soft(),
        css="""
        .error-message {
            color: red !important;
            background-color: #ffebee !important;
            padding: 10px !important;
            border-radius: 5px !important;
        }
        .transfer-info {
            background-color: #e3f2fd !important;
            padding: 10px !important;
            border-radius: 5px !important;
            margin: 10px 0 !important;
        }
        """
    ) as demo:
        gr.Markdown("## HairFastGan - Selective Transfer")
        # Badge row linking paper / repo / weights / notebook.
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Enhanced HairFastGAN with selective transfer:</span>'
            '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
            '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        gr.Markdown(
            """
            <div class="transfer-info">
            <b>🎯 Selective Transfer Modes:</b><br>
            • <b>Shape Only:</b> Upload only shape image → Changes hairstyle while preserving original color<br>
            • <b>Color Only:</b> Upload only color image → Changes hair color while preserving original hairstyle<br>
            • <b>Both:</b> Upload both images → Changes both hairstyle and color
            </div>
            """,
            elem_classes="transfer-info"
        )
        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                # Transfer mode indicator
                transfer_status = gr.Textbox(label="Transfer Mode", interactive=False,
                                             value="Upload images to see transfer mode")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion")
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]")
                btn = gr.Button("Get the haircut", variant="primary")
            with gr.Column():
                output = gr.Image(label="Your result")
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")
        # Function to update transfer mode display
        def update_transfer_mode(shape_img, color_img):
            if shape_img is not None and color_img is not None:
                return "🎯 Both hairstyle and color will be transferred"
            elif shape_img is not None and color_img is None:
                return "🎨 Only hairstyle will be transferred (color preserved)"
            elif shape_img is None and color_img is not None:
                return "🌈 Only hair color will be transferred (style preserved)"
            else:
                return "Upload shape and/or color reference images"
        # Update transfer mode when images change
        shape.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)
        color.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)
        # Auto-align/crop each image slot on upload (see resize()).
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
        btn.click(
            fn=swap_hair_selective,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict"
        )
        gr.Markdown('''
        ### How to use:
        1. **Upload your source photo** (the face you want to modify)
        2. **Choose your transfer mode:**
        - For **hairstyle only**: Upload only a shape reference image
        - For **color only**: Upload only a color reference image
        - For **both**: Upload both shape and color reference images
        3. **Click "Get the haircut"** and wait for results!
        ''')
    return demo
def create_api_app():
    """Create FastAPI app for better frontend integration.

    Exposes:
        POST /api/transfer-hair : run the hair transfer on base64 images.
        GET  /health            : report service status and model readiness.

    Returns:
        FastAPI: the configured application.
    """
    app = FastAPI(title="HairFastGAN Selective Transfer API", version="2.0")
    # Wide-open CORS so any frontend origin can call the API.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # BUG FIX: these handlers were defined but never registered with the app
    # (no route decorators), so the API exposed no endpoints at all.
    @app.post("/api/transfer-hair", response_model=ImageResponse)
    async def transfer_hair(request: ImageRequest):
        """Enhanced API endpoint for hair transfer"""
        try:
            print(f"API request received: {len(request.source_image)} chars source")
            result_base64, message = hair_transfer_api(
                source_image=request.source_image,
                shape_image=request.shape_image,
                color_image=request.color_image,
                blending=request.blending,
                poisson_iters=request.poisson_iters,
                poisson_erosion=request.poisson_erosion
            )
            if result_base64:
                return ImageResponse(
                    success=True,
                    result_image=result_base64,
                    message=message or "Hair transfer completed successfully"
                )
            else:
                return ImageResponse(
                    success=False,
                    message=message or "Hair transfer failed"
                )
        except Exception as e:
            print(f"API endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {"status": "healthy", "model_loaded": hair_fast_model is not None}

    return app
if __name__ == '__main__':
    optimize_for_t4()
    # Start each session with a fresh alignment cache.
    align_cache = LRUCache(maxsize=10)
    if not initialize_hairfast_original():
        print("Failed to initialize HairFast model")
        exit(1)
    gradio_demo = get_demo()
    fastapi_app = create_api_app()
    def run_fastapi():
        # REST API for programmatic clients.
        uvicorn.run(fastapi_app, host="0.0.0.0", port=8000, log_level="info")
    def run_gradio():
        # Web UI; the queue caps concurrent jobs to protect GPU memory.
        gradio_demo.queue(max_size=10, default_concurrency_limit=2)
        gradio_demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
    print("Starting FastAPI server on port 8000...")
    # Daemon thread so the process exits when the (blocking) Gradio server stops.
    fastapi_thread = threading.Thread(target=run_fastapi)
    fastapi_thread.daemon = True
    fastapi_thread.start()
    print("Starting Gradio server on port 7860...")
    run_gradio()