import base64
import bz2
import contextlib
import gc
import hashlib
import json
import os
import shutil
import subprocess
import sys
import threading
import time
from io import BytesIO
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import requests
import torch
import torchvision.transforms.functional as F
import uvicorn
from cachetools import LRUCache

# FastAPI imports for enhanced frontend integration
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel

# T4 Medium GPU Optimizations.
# PYTORCH_CUDA_ALLOC_CONF is read by PyTorch's CUDA caching allocator when CUDA
# initialises, so it has to be in the environment before the first CUDA call
# (optimize_for_t4() triggers one via get_device_properties). Setting it later,
# or assigning a made-up attribute such as torch.backends.cuda.max_split_size_mb,
# has no effect. setdefault keeps a value supplied by the deployment environment.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:512')
torch.backends.cudnn.benchmark = True

# All HairFastGAN checkpoints live under this URL prefix, mirroring the local
# 'pretrained_models/...' layout.
HAIRFAST_HF_BASE_URL = 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/'


# API Models for FastAPI
class ImageRequest(BaseModel):
    """Request body of POST /api/hair-transfer (images as base64 strings)."""

    source_image: str  # base64 string
    shape_image: Optional[str] = None  # base64 string
    color_image: Optional[str] = None  # base64 string
    blending: str = "Article"
    poisson_iters: int = 0
    poisson_erosion: int = 15


class ImageResponse(BaseModel):
    """Response body of POST /api/hair-transfer."""

    success: bool
    result_image: Optional[str] = None  # base64 string
    message: str


def _download_to_file(url, dest_path, *, timeout, chunk_size=1024 * 1024,
                      report_every=0, progress_label="  Progress"):
    """Stream ``url`` into ``dest_path`` and return the number of bytes written.

    The body is written to ``dest_path + ".part"`` and only renamed onto
    ``dest_path`` once the full response has been received. An interrupted or
    failed download therefore never leaves a truncated file behind that the
    existence checks in the callers would mistake for a finished model.

    Args:
        url: HTTP(S) URL to fetch.
        dest_path: Final file path; its parent directory is created if needed.
        timeout: ``requests`` timeout in seconds.
        chunk_size: Streaming chunk size in bytes.
        report_every: Print a progress line roughly every this many bytes;
            ``0`` disables progress output.
        progress_label: Prefix of each progress line.

    Raises:
        requests.RequestException: On network or HTTP errors.
        OSError: On filesystem errors, or when fewer bytes than the advertised
            ``Content-Length`` were received.
    """
    os.makedirs(os.path.dirname(dest_path) or ".", exist_ok=True)
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            next_report = report_every
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if not chunk:
                        continue
                    f.write(chunk)
                    downloaded += len(chunk)
                    # Compare against a threshold rather than testing
                    # `downloaded % step == 0`: iter_content may yield short
                    # chunks, so exact multiples are not guaranteed to occur.
                    if report_every and downloaded >= next_report:
                        progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                        print(f"{progress_label}: {progress:.1f}% "
                              f"({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
                        next_report = (downloaded // report_every + 1) * report_every
        if total_size and downloaded < total_size:
            raise OSError(f"incomplete download: received {downloaded} of {total_size} bytes")
        os.replace(tmp_path, dest_path)
    except BaseException:
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise
    return downloaded


# Download working face landmarks
def download_face_landmarks():
    """Download and decompress the dlib 68-point face landmarks predictor.

    The download is skipped when the file already exists and is larger than
    50 MB. A smaller file is treated as a broken earlier download and fetched
    again.

    Returns:
        True if the predictor is present afterwards. False if downloading or
        decompressing failed; the error is printed, not raised.
    """
    landmarks_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat'
    if os.path.exists(landmarks_path) and os.path.getsize(landmarks_path) > 50000000:
        print("Face landmarks already exists and appears valid")
        return True
    compressed_path = landmarks_path + '.bz2'
    tmp_path = landmarks_path + '.part'
    try:
        print("Downloading working face landmarks predictor...")
        url = 'https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat.bz2'
        _download_to_file(url, compressed_path, timeout=300)
        # Stream-decompress into a temp file, then publish it atomically.
        with bz2.open(compressed_path, 'rb') as src, open(tmp_path, 'wb') as dst:
            shutil.copyfileobj(src, dst, 1024 * 1024)
        os.replace(tmp_path, landmarks_path)
        print(f"Face landmarks downloaded successfully ({os.path.getsize(landmarks_path)/1024/1024:.1f}MB)")
        return True
    except Exception as e:
        print(f"Failed to download face landmarks: {e}")
        return False
    finally:
        for leftover in (compressed_path, tmp_path):
            with contextlib.suppress(OSError):
                os.remove(leftover)


def _is_downloaded(path):
    """Return True if ``path`` is an existing, non-empty regular file."""
    return os.path.isfile(path) and os.path.getsize(path) > 0


# Comprehensive model download for full accuracy
def download_all_missing_models():
    """Download every missing HairFastGAN checkpoint from Hugging Face.

    Each local path is fetched from ``HAIRFAST_HF_BASE_URL + path``. A
    download failure is printed and the remaining models are still attempted.

    Returns:
        True if every required model is present afterwards, False if at least
        one download failed.
    """
    all_required_models = (
        'pretrained_models/ArcFace/backbone_ir50.pth',
        'pretrained_models/ArcFace/ir_se50.pth',
        'pretrained_models/BiSeNet/face_parsing_79999_iter.pth',
        'pretrained_models/FeatureStyleEncoder/backbone.pth',
        'pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt',
        'pretrained_models/FeatureStyleEncoder/79999_iter.pth',
        'pretrained_models/FeatureStyleEncoder/143_enc.pth',
        'pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
        'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
        'pretrained_models/PostProcess/pp_model.pth',
        'pretrained_models/PostProcess/latent_avg.pt',
    )

    print("Checking for missing models for full accuracy...")
    existing_models = [os.path.basename(p) for p in all_required_models if _is_downloaded(p)]
    missing_models = [p for p in all_required_models if not _is_downloaded(p)]

    if existing_models:
        print(f"Found existing models: {', '.join(existing_models)}")

    if not missing_models:
        print("All models already present!")
        return True

    print(f"Downloading {len(missing_models)} missing models for full accuracy...")
    failed = []
    for i, model_path in enumerate(missing_models, 1):
        model_name = os.path.basename(model_path)
        print(f"[{i}/{len(missing_models)}] Downloading: {model_name}")
        start_time = time.time()
        try:
            downloaded = _download_to_file(
                HAIRFAST_HF_BASE_URL + model_path, model_path,
                timeout=600, report_every=50 * 1024 * 1024,
            )
        except Exception as e:
            print(f"  Failed to download {model_name}: {e}")
            failed.append(model_name)
            continue
        elapsed = time.time() - start_time
        print(f"  Downloaded: {model_name} ({downloaded/1024/1024:.1f}MB in {elapsed:.1f}s)")
    return not failed


def download_mask_generator_separately():
    """Download the large ShapeAdaptor ``mask_generator.pth`` checkpoint.

    Returns:
        True if the file is present afterwards. False if the download failed;
        the error is printed, not raised.
    """
    mask_path = 'pretrained_models/ShapeAdaptor/mask_generator.pth'
    if os.path.exists(mask_path):
        print("Mask generator already exists")
        return True
    try:
        print("Downloading mask_generator.pth (919MB)...")
        downloaded = _download_to_file(
            HAIRFAST_HF_BASE_URL + mask_path, mask_path,
            timeout=900, chunk_size=2 * 1024 * 1024,
            report_every=100 * 1024 * 1024,
            progress_label=" Mask Generator Progress",
        )
        print(f"Successfully downloaded mask_generator.pth ({downloaded/1024/1024:.1f}MB)")
        return True
    except Exception as e:
        print(f"Failed to download mask_generator.pth: {e}")
        return False


# Download all models (runs at import time)
download_all_missing_models()
download_mask_generator_separately()
download_face_landmarks()

# Direct HairFast imports
try:
    from hair_swap import HairFast, get_parser
    HAIRFAST_AVAILABLE = True
    print("HairFast successfully imported!")
except ImportError as e:
    print(f"HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False

try:
    from utils.shape_predictor import align_face
    ALIGN_AVAILABLE = True
    print("Face alignment available!")
except ImportError as e:
    print(f"Face alignment not available: {e}")
    ALIGN_AVAILABLE = False

# Global variables
hair_fast_model = None
# Maps md5(JPEG bytes of an upload) -> its aligned/cropped 1024x1024 image.
align_cache = LRUCache(maxsize=10)


def get_gpu_memory():
    """Return total memory of CUDA device 0 in GB, or 0 without CUDA."""
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(0).total_memory / 1e9
    return 0


def optimize_for_t4():
    """Log GPU memory and free the CUDA cache on GPUs with under 20 GB.

    The allocator configuration (PYTORCH_CUDA_ALLOC_CONF) is set at import
    time instead. It cannot take effect here because get_gpu_memory() has
    already initialised CUDA.
    """
    if torch.cuda.is_available():
        gpu_memory = get_gpu_memory()
        print(f"GPU Memory: {gpu_memory:.1f}GB")
        if gpu_memory < 20:
            torch.cuda.empty_cache()
            print("T4 optimizations applied")


def check_model_completeness():
    """Return a mapping of critical model name -> whether its file exists."""
    critical_models = {
        'StyleGAN': 'pretrained_models/StyleGAN/ffhq.pt',
        'Blending': 'pretrained_models/Blending/checkpoint.pth',
        'Rotate': 'pretrained_models/Rotate/rotate_best.pth',
        'PostProcess': 'pretrained_models/PostProcess/pp_model.pth',
        'FeatureEncoder': 'pretrained_models/FeatureStyleEncoder/143_enc.pth',
        'E4E': 'pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
        'SEAN': 'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
        'ShapeAdaptor': 'pretrained_models/ShapeAdaptor/mask_generator.pth',
    }
    return {name: os.path.exists(path) for name, path in critical_models.items()}


def initialize_hairfast_original():
    """Create the global HairFast model from the parser's default arguments.

    Returns:
        True on success. False if HairFast could not be imported or
        construction raised; the error is printed and the model stays None.
    """
    global hair_fast_model

    if not HAIRFAST_AVAILABLE:
        print("HairFast not available")
        return False

    try:
        print("Initializing HairFast with original default arguments...")
        optimize_for_t4()

        available_models = check_model_completeness()
        print("Available models:", {k: v for k, v in available_models.items() if v})

        parser = get_parser()
        args = parser.parse_args([])
        hair_fast_model = HairFast(args)

        available_count = sum(available_models.values())
        total_count = len(available_models)
        accuracy_percentage = (available_count / total_count) * 100
        print(f"HairFast initialized with original defaults ({accuracy_percentage:.1f}% accuracy)")

        torch.cuda.empty_cache()
        gc.collect()
        return True
    except Exception as e:
        print(f"HairFast initialization failed: {e}")
        hair_fast_model = None
        torch.cuda.empty_cache()
        return False


def get_bytes(img):
    """EXACT copy of original get_bytes function"""
    if img is None:
        return img
    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    return buffered.getvalue()


def bytes_to_image(image: bytes) -> Image.Image:
    """EXACT copy of original bytes_to_image function"""
    image = Image.open(BytesIO(image))
    return image


def base64_to_image(base64_string):
    """Decode a (optionally ``data:image/...;base64,``-prefixed) string to RGB.

    Returns:
        The decoded PIL image converted to RGB, or None if decoding failed.
    """
    try:
        if base64_string.startswith('data:image'):
            base64_string = base64_string.split(',')[1]
        image_bytes = base64.b64decode(base64_string)
        image = Image.open(BytesIO(image_bytes))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None


def image_to_base64(image):
    """Encode a PIL image as a lossless PNG data URL, or return None for None."""
    if image is None:
        return None
    buffered = BytesIO()
    # Use PNG for lossless quality
    image.save(buffered, format="PNG", optimize=False)
    img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{img_base64}"


def center_crop(img):
    """EXACT copy of original center_crop function"""
    width, height = img.size
    side = min(width, height)
    left = (width - side) / 2
    top = (height - side) / 2
    right = (width + side) / 2
    bottom = (height + side) / 2
    img = img.crop((left, top, right, bottom))
    return img


def _crop_resize_1024(img):
    """Center-crop ``img`` to a square and LANCZOS-resize it to 1024x1024."""
    return center_crop(img).resize((1024, 1024), Image.Resampling.LANCZOS)


def resize(name):
    """Build a Gradio upload handler that normalises an image to 1024x1024.

    Args:
        name: Checkbox label ("Face", "Shape" or "Color") that enables face
            alignment for this input when it is ticked.

    Returns:
        ``resize_inner(img, align)``. When ``name`` is in ``align`` and
        alignment is available, it returns the aligned face, memoised in
        ``align_cache`` and falling back to center crop. Otherwise it
        center-crops and resizes any image that is not already 1024x1024.
    """
    def resize_inner(img, align):
        if img is None:
            return None
        # CheckboxGroup can deliver None when nothing is selected.
        if not (ALIGN_AVAILABLE and name in (align or [])):
            return img if img.size == (1024, 1024) else _crop_resize_1024(img)

        img_hash = hashlib.md5(get_bytes(img)).hexdigest()
        cached = align_cache.get(img_hash)
        if cached is not None:
            return cached

        try:
            aligned_imgs = align_face(img, return_tensors=False)
            if aligned_imgs:
                result = aligned_imgs[0]
                if result.size != (1024, 1024):
                    result = result.resize((1024, 1024), Image.Resampling.LANCZOS)
            else:
                result = _crop_resize_1024(img)
        except Exception as e:
            print(f"Face alignment failed for {name}, using center crop: {e}")
            result = _crop_resize_1024(img)

        align_cache[img_hash] = result
        return result

    return resize_inner


def swap_hair_selective(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Transfer hairstyle and/or hair colour onto ``face``.

    Modes:
      - only ``shape`` given: change hairstyle, keep the face's own colour;
      - only ``color`` given: change colour, keep the face's own hairstyle;
      - both given: change both.

    NOTE(review): ``blending``, ``poisson_iters`` and ``poisson_erosion`` are
    accepted for UI/API compatibility but are not forwarded to
    ``HairFast.swap``; they currently have no effect.

    Returns:
        ``(PIL.Image, status message)`` on success, ``(None, error message)``
        otherwise.
    """
    global hair_fast_model

    if hair_fast_model is None:
        if not initialize_hairfast_original():
            return None, "HairFast model not available. Please check if all model files are uploaded."

    if face is None and shape is None and color is None:
        return None, "Need to upload a face and at least a shape or color ❗"
    if face is None:
        return None, "Need to upload a face ❗"
    if shape is None and color is None:
        return None, "Need to upload at least a shape or color ❗"

    try:
        print("Starting selective hair transfer...")
        torch.cuda.empty_cache()
        gc.collect()

        def validate_size(img, label):
            if img is not None and img.size != (1024, 1024):
                print(f"Resizing {label} from {img.size} to (1024, 1024)")
                img = _crop_resize_1024(img)
            return img

        face = validate_size(face, "face")
        shape = validate_size(shape, "shape")
        color = validate_size(color, "color")

        # Determine transfer mode; the face itself stands in for the missing
        # reference so that attribute is preserved.
        if shape is not None and color is not None:
            transfer_mode = "both"
            print("Transfer mode: Both hairstyle and color")
        elif shape is not None:
            transfer_mode = "shape_only"
            color = face  # Use original face for color reference
            print("Transfer mode: Hairstyle only (preserving original color)")
        else:
            transfer_mode = "color_only"
            shape = face  # Use original face for shape reference
            print("Transfer mode: Color only (preserving original hairstyle)")

        print(f"Final sizes - Face: {face.size}, Shape: {shape.size}, Color: {color.size}")

        with torch.no_grad():
            start_time = time.time()
            final_image = hair_fast_model.swap(face, shape, color)
            inference_time = time.time() - start_time
            print(f"Inference completed in {inference_time:.2f} seconds")

        result_image = F.to_pil_image(final_image)

        torch.cuda.empty_cache()
        gc.collect()

        success_message = f"Hair transfer ({transfer_mode}) completed successfully in {inference_time:.2f}s"
        print(success_message)
        return result_image, success_message

    except Exception as e:
        torch.cuda.empty_cache()
        gc.collect()
        error_msg = f"Hair transfer failed: {str(e)}"
        print(f"Detailed error: {e}")
        return None, error_msg


def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """API entry point: accept base64 strings or PIL images, return base64 PNG.

    Returns:
        ``(data URL or None, status message)``.
    """
    try:
        print("API call received - processing images...")

        if isinstance(source_image, str):
            source_image = base64_to_image(source_image)
            print(f"Source image loaded: {source_image.size if source_image else 'Failed'}")

        if isinstance(shape_image, str) and shape_image:
            shape_image = base64_to_image(shape_image)
            print(f"Shape image loaded: {shape_image.size if shape_image else 'Failed'}")

        if isinstance(color_image, str) and color_image:
            color_image = base64_to_image(color_image)
            print(f"Color image loaded: {color_image.size if color_image else 'Failed'}")

        if source_image is None:
            return None, "Failed to process source image"

        result_image, status_message = swap_hair_selective(
            source_image, shape_image, color_image, blending, poisson_iters, poisson_erosion
        )

        if result_image is None:
            print("Result image is None")
            return None, status_message

        print(f"Result image generated: {result_image.size}")
        result_base64 = image_to_base64(result_image)
        print("Result converted to base64 successfully")
        return result_base64, status_message

    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"API Error details: {e}")
        return None, error_msg


def get_demo():
    """Build the Gradio Blocks UI for selective hair transfer.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` demo.
    """
    with gr.Blocks(
        title="HairFastGAN - Selective Transfer",
        theme=gr.themes.Soft(),
        css="""
        .error-message {
            color: red !important;
            background-color: #ffebee !important;
            padding: 10px !important;
            border-radius: 5px !important;
        }
        .transfer-info {
            background-color: #e3f2fd !important;
            padding: 10px !important;
            border-radius: 5px !important;
            margin: 10px 0 !important;
        }
        """,
    ) as demo:
        gr.Markdown("## HairFastGan - Selective Transfer")
        # NOTE(review): this header was originally wrapped in HTML markup
        # (container/badge tags) that did not survive in this copy of the
        # file; only the visible text is kept.
        gr.Markdown('Enhanced HairFastGAN with selective transfer:')

        gr.Markdown(
            """
**🎯 Selective Transfer Modes:**

- **Shape Only:** Upload only shape image → Changes hairstyle while preserving original color
- **Color Only:** Upload only color image → Changes hair color while preserving original hairstyle
- **Both:** Upload both images → Changes both hairstyle and color
""",
            elem_classes="transfer-info",
        )

        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")

                # Transfer mode indicator
                transfer_status = gr.Textbox(
                    label="Transfer Mode",
                    interactive=False,
                    value="Upload images to see transfer mode",
                )

                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion")
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]")
                btn = gr.Button("Get the haircut", variant="primary")

            with gr.Column():
                output = gr.Image(label="Your result")
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        def update_transfer_mode(shape_img, color_img):
            """Describe which attributes will be transferred for the current uploads."""
            if shape_img is not None and color_img is not None:
                return "🎯 Both hairstyle and color will be transferred"
            if shape_img is not None:
                return "🎨 Only hairstyle will be transferred (color preserved)"
            if color_img is not None:
                return "🌈 Only hair color will be transferred (style preserved)"
            return "Upload shape and/or color reference images"

        def swap_for_ui(face, shape_img, color_img, blending_value, iters, erosion):
            """Run the swap and reveal the (initially hidden) error box only on failure."""
            result, message = swap_hair_selective(face, shape_img, color_img, blending_value, iters, erosion)
            # The textbox is created with visible=False, so returning a bare
            # string kept it hidden and failures were never shown. The message
            # value is still returned so "predict" API clients receive it.
            return result, gr.update(value=message, visible=result is None)

        # Update transfer mode when images change
        shape.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)
        color.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)

        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)

        btn.click(
            fn=swap_for_ui,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict",
        )

        gr.Markdown('''
### How to use:
1. **Upload your source photo** (the face you want to modify)
2. **Choose your transfer mode:**
   - For **hairstyle only**: Upload only a shape reference image
   - For **color only**: Upload only a color reference image
   - For **both**: Upload both shape and color reference images
3. **Click "Get the haircut"** and wait for results!
''')

    return demo


def create_api_app():
    """Create the FastAPI app exposing /api/hair-transfer and /api/health."""
    app = FastAPI(title="HairFastGAN Selective Transfer API", version="2.0")

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Plain `def` (not `async def`): inference is blocking GPU work, and
    # FastAPI runs sync endpoints in its threadpool instead of stalling the
    # event loop (which previously also blocked /api/health).
    @app.post("/api/hair-transfer", response_model=ImageResponse)
    def transfer_hair(request: ImageRequest):
        """Run a hair transfer for base64-encoded images."""
        try:
            print(f"API request received: {len(request.source_image)} chars source")

            result_base64, message = hair_transfer_api(
                source_image=request.source_image,
                shape_image=request.shape_image,
                color_image=request.color_image,
                blending=request.blending,
                poisson_iters=request.poisson_iters,
                poisson_erosion=request.poisson_erosion,
            )

            if result_base64:
                return ImageResponse(
                    success=True,
                    result_image=result_base64,
                    message=message or "Hair transfer completed successfully",
                )
            return ImageResponse(
                success=False,
                message=message or "Hair transfer failed",
            )

        except Exception as e:
            print(f"API endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/api/health")
    async def health_check():
        """Health check endpoint"""
        return {"status": "healthy", "model_loaded": hair_fast_model is not None}

    return app


if __name__ == '__main__':
    # initialize_hairfast_original() applies optimize_for_t4() itself.
    if not initialize_hairfast_original():
        print("Failed to initialize HairFast model")
        sys.exit(1)

    gradio_demo = get_demo()
    fastapi_app = create_api_app()

    def run_fastapi():
        uvicorn.run(fastapi_app, host="0.0.0.0", port=8000, log_level="info")

    def run_gradio():
        gradio_demo.queue(max_size=10, default_concurrency_limit=2)
        gradio_demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

    print("Starting FastAPI server on port 8000...")
    fastapi_thread = threading.Thread(target=run_fastapi, daemon=True)
    fastapi_thread.start()

    print("Starting Gradio server on port 7860...")
    run_gradio()