import hashlib import os from io import BytesIO import base64 import requests from pathlib import Path import subprocess import shutil import gc import time import json import threading import gradio as gr from PIL import Image from cachetools import LRUCache import torch import numpy as np import torchvision.transforms.functional as F # FastAPI imports for enhanced frontend integration from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional import uvicorn # T4 Medium GPU Optimizations torch.backends.cudnn.benchmark = True torch.backends.cuda.max_split_size_mb = 512 # API Models for FastAPI class ImageRequest(BaseModel): source_image: str # base64 string shape_image: Optional[str] = None # base64 string color_image: Optional[str] = None # base64 string blending: str = "Article" poisson_iters: int = 0 poisson_erosion: int = 15 class ImageResponse(BaseModel): success: bool result_image: Optional[str] = None # base64 string message: str # Download working face landmarks def download_face_landmarks(): """Download working dlib face landmarks predictor""" landmarks_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat' if os.path.exists(landmarks_path) and os.path.getsize(landmarks_path) > 50000000: print("Face landmarks already exists and appears valid") return True try: print("Downloading working face landmarks predictor...") url = 'https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat.bz2' os.makedirs(os.path.dirname(landmarks_path), exist_ok=True) response = requests.get(url, stream=True, timeout=300) response.raise_for_status() import bz2 compressed_data = response.content with open(landmarks_path, 'wb') as f: f.write(bz2.decompress(compressed_data)) print(f"Face landmarks downloaded successfully ({os.path.getsize(landmarks_path)/1024/1024:.1f}MB)") return True except Exception as e: print(f"Failed to download face landmarks: {e}") return False # Comprehensive model download for full accuracy def download_all_missing_models(): """Download ALL required models for full accuracy""" all_required_models = { 'pretrained_models/ArcFace/backbone_ir50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/backbone_ir50.pth', 'pretrained_models/ArcFace/ir_se50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/ir_se50.pth', 'pretrained_models/BiSeNet/face_parsing_79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/BiSeNet/face_parsing_79999_iter.pth', 'pretrained_models/FeatureStyleEncoder/backbone.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/backbone.pth', 'pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt', 'pretrained_models/FeatureStyleEncoder/79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/79999_iter.pth', 'pretrained_models/FeatureStyleEncoder/143_enc.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/143_enc.pth', 'pretrained_models/encoder4editing/e4e_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/encoder4editing/e4e_ffhq_encode.pt', 'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth', 'pretrained_models/PostProcess/pp_model.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/pp_model.pth', 'pretrained_models/PostProcess/latent_avg.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/latent_avg.pt', } print("Checking for missing models for full accuracy...") missing_models = [] existing_models = [] for model_path, url in all_required_models.items(): if os.path.exists(model_path): existing_models.append(os.path.basename(model_path)) else: missing_models.append((model_path, url)) if existing_models: print(f"Found existing models: {', '.join(existing_models)}") if missing_models: print(f"Downloading {len(missing_models)} missing models for full accuracy...") for i, (model_path, url) in enumerate(missing_models, 1): try: model_name = os.path.basename(model_path) print(f"[{i}/{len(missing_models)}] Downloading: {model_name}") os.makedirs(os.path.dirname(model_path), exist_ok=True) start_time = time.time() response = requests.get(url, stream=True, timeout=600) response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) downloaded = 0 with open(model_path, 'wb') as f: for chunk in response.iter_content(chunk_size=1024*1024): if chunk: f.write(chunk) downloaded += len(chunk) if total_size > 100*1024*1024 and downloaded % (50*1024*1024) == 0: progress = (downloaded / total_size) * 100 if total_size > 0 else 0 print(f" Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)") elapsed = time.time() - start_time print(f" Downloaded: {model_name} ({downloaded/1024/1024:.1f}MB in {elapsed:.1f}s)") except Exception as e: print(f" Failed to download {model_name}: {e}") continue else: print("All models already present!") return True def download_mask_generator_separately(): """Download large mask_generator.pth file separately""" mask_path = 'pretrained_models/ShapeAdaptor/mask_generator.pth' if os.path.exists(mask_path): print("Mask generator already exists") return True try: print("Downloading mask_generator.pth (919MB)...") url = 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ShapeAdaptor/mask_generator.pth' os.makedirs(os.path.dirname(mask_path), exist_ok=True) response = requests.get(url, stream=True, timeout=900) response.raise_for_status() total_size = int(response.headers.get('content-length', 0)) downloaded = 0 with open(mask_path, 'wb') as f: for chunk in response.iter_content(chunk_size=2*1024*1024): if chunk: f.write(chunk) downloaded += len(chunk) if downloaded % (100*1024*1024) == 0: progress = (downloaded / total_size) * 100 if total_size > 0 else 0 print(f" Mask Generator Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)") print(f"Successfully downloaded mask_generator.pth ({downloaded/1024/1024:.1f}MB)") return True except Exception as e: print(f"Failed to download mask_generator.pth: {e}") return False # Download all models download_all_missing_models() download_mask_generator_separately() download_face_landmarks() # Direct HairFast imports try: from hair_swap import HairFast, get_parser HAIRFAST_AVAILABLE = True print("HairFast successfully imported!") except ImportError as e: print(f"HairFast import failed: {e}") HAIRFAST_AVAILABLE = False try: from utils.shape_predictor import align_face ALIGN_AVAILABLE = True print("Face alignment available!") except ImportError as e: print(f"Face alignment not available: {e}") ALIGN_AVAILABLE = False # Global variables hair_fast_model = None align_cache = LRUCache(maxsize=10) def get_gpu_memory(): """Check GPU memory for optimization""" if torch.cuda.is_available(): return torch.cuda.get_device_properties(0).total_memory / 1e9 return 0 def optimize_for_t4(): """T4 GPU specific optimizations""" if torch.cuda.is_available(): gpu_memory = get_gpu_memory() print(f"GPU Memory: {gpu_memory:.1f}GB") if gpu_memory < 20: torch.cuda.empty_cache() os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512' print("T4 optimizations applied") def check_model_completeness(): """Check which critical models are available""" critical_models = { 'StyleGAN': 'pretrained_models/StyleGAN/ffhq.pt', 'Blending': 'pretrained_models/Blending/checkpoint.pth', 'Rotate': 'pretrained_models/Rotate/rotate_best.pth', 'PostProcess': 'pretrained_models/PostProcess/pp_model.pth', 'FeatureEncoder': 'pretrained_models/FeatureStyleEncoder/143_enc.pth', 'E4E': 'pretrained_models/encoder4editing/e4e_ffhq_encode.pt', 'SEAN': 'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth', 'ShapeAdaptor': 'pretrained_models/ShapeAdaptor/mask_generator.pth', } available_models = {} for name, path in critical_models.items(): available_models[name] = os.path.exists(path) return available_models def initialize_hairfast_original(): """Initialize HairFast exactly like original - use pure defaults""" global hair_fast_model if not HAIRFAST_AVAILABLE: print("HairFast not available") return False try: print("Initializing HairFast with original default arguments...") optimize_for_t4() available_models = check_model_completeness() print("Available models:", {k: v for k, v in available_models.items() if v}) parser = get_parser() args = parser.parse_args([]) hair_fast_model = HairFast(args) available_count = sum(available_models.values()) total_count = len(available_models) accuracy_percentage = (available_count / total_count) * 100 print(f"HairFast initialized with original defaults ({accuracy_percentage:.1f}% accuracy)") torch.cuda.empty_cache() gc.collect() return True except Exception as e: print(f"HairFast initialization failed: {e}") hair_fast_model = None torch.cuda.empty_cache() return False def get_bytes(img): """EXACT copy of original get_bytes function""" if img is None: return img buffered = BytesIO() img.save(buffered, format="JPEG") return buffered.getvalue() def bytes_to_image(image: bytes) -> Image.Image: """EXACT copy of original bytes_to_image function""" image = Image.open(BytesIO(image)) return image def base64_to_image(base64_string): """Convert base64 string to PIL Image with error handling""" try: if base64_string.startswith('data:image'): base64_string = base64_string.split(',')[1] image_bytes = base64.b64decode(base64_string) image = Image.open(BytesIO(image_bytes)) if image.mode != 'RGB': image = image.convert('RGB') return image except Exception as e: print(f"Error converting base64 to image: {e}") return None def image_to_base64(image): """Convert PIL Image to base64 string with maximum quality""" if image is None: return None buffered = BytesIO() # Use PNG for lossless quality image.save(buffered, format="PNG", optimize=False) img_bytes = buffered.getvalue() img_base64 = base64.b64encode(img_bytes).decode('utf-8') return f"data:image/png;base64,{img_base64}" def center_crop(img): """EXACT copy of original center_crop function""" width, height = img.size side = min(width, height) left = (width - side) / 2 top = (height - side) / 2 right = (width + side) / 2 bottom = (height + side) / 2 img = img.crop((left, top, right, bottom)) return img def resize(name): """Fixed resize function with proper size handling""" def resize_inner(img, align): global align_cache if name in align and ALIGN_AVAILABLE: img_hash = hashlib.md5(get_bytes(img)).hexdigest() if img_hash not in align_cache: try: aligned_imgs = align_face(img, return_tensors=False) if aligned_imgs and len(aligned_imgs) > 0: img = aligned_imgs[0] if img.size != (1024, 1024): img = img.resize((1024, 1024), Image.Resampling.LANCZOS) align_cache[img_hash] = img else: img = center_crop(img) img = img.resize((1024, 1024), Image.Resampling.LANCZOS) align_cache[img_hash] = img except Exception as e: print(f"Face alignment failed for {name}, using center crop: {e}") img = center_crop(img) img = img.resize((1024, 1024), Image.Resampling.LANCZOS) align_cache[img_hash] = img else: img = align_cache[img_hash] else: if img.size != (1024, 1024): img = center_crop(img) img = img.resize((1024, 1024), Image.Resampling.LANCZOS) return img return resize_inner def swap_hair_selective(face, shape, color, blending, poisson_iters, poisson_erosion): """ Enhanced swap logic with selective transfer: - If only shape provided: change hairstyle only - If only color provided: change hair color only - If both provided: change both hairstyle and color """ global hair_fast_model if hair_fast_model is None: if not initialize_hairfast_original(): return None, "HairFast model not available. Please check if all model files are uploaded." if not face and not shape and not color: return None, "Need to upload a face and at least a shape or color ❗" elif not face: return None, "Need to upload a face ❗" elif not shape and not color: return None, "Need to upload at least a shape or color ❗" try: print("Starting selective hair transfer...") torch.cuda.empty_cache() gc.collect() def validate_size(img, name): if img is not None: if img.size != (1024, 1024): print(f"Resizing {name} from {img.size} to (1024, 1024)") img = center_crop(img) img = img.resize((1024, 1024), Image.Resampling.LANCZOS) return img face = validate_size(face, "face") shape = validate_size(shape, "shape") if shape else None color = validate_size(color, "color") if color else None # Determine transfer mode has_shape = shape is not None has_color = color is not None if has_shape and has_color: transfer_mode = "both" print("Transfer mode: Both hairstyle and color") elif has_shape and not has_color: transfer_mode = "shape_only" color = face # Use original face for color reference print("Transfer mode: Hairstyle only (preserving original color)") elif has_color and not has_shape: transfer_mode = "color_only" shape = face # Use original face for shape reference print("Transfer mode: Color only (preserving original hairstyle)") print(f"Final sizes - Face: {face.size}, Shape: {shape.size if shape else 'None'}, Color: {color.size if color else 'None'}") with torch.no_grad(): start_time = time.time() # Use the HairFast model's swap method with proper parameters final_image = hair_fast_model.swap(face, shape, color) inference_time = time.time() - start_time print(f"Inference completed in {inference_time:.2f} seconds") result_image = F.to_pil_image(final_image) torch.cuda.empty_cache() gc.collect() success_message = f"Hair transfer ({transfer_mode}) completed successfully in {inference_time:.2f}s" print(success_message) return result_image, success_message except Exception as e: torch.cuda.empty_cache() gc.collect() error_msg = f"Hair transfer failed: {str(e)}" print(f"Detailed error: {e}") return None, error_msg def hair_transfer_api(source_image, shape_image=None, color_image=None, blending="Article", poisson_iters=0, poisson_erosion=15): """Enhanced API function for frontend integration with quality preservation""" try: print("API call received - processing images...") if isinstance(source_image, str): source_image = base64_to_image(source_image) print(f"Source image loaded: {source_image.size if source_image else 'Failed'}") if isinstance(shape_image, str) and shape_image: shape_image = base64_to_image(shape_image) print(f"Shape image loaded: {shape_image.size if shape_image else 'Failed'}") if isinstance(color_image, str) and color_image: color_image = base64_to_image(color_image) print(f"Color image loaded: {color_image.size if color_image else 'Failed'}") if not source_image: return None, "Failed to process source image" result_image, status_message = swap_hair_selective( source_image, shape_image, color_image, blending, poisson_iters, poisson_erosion ) if result_image is not None: print(f"Result image generated: {result_image.size}") result_base64 = image_to_base64(result_image) print("Result converted to base64 successfully") return result_base64, status_message else: print("Result image is None") return None, status_message except Exception as e: error_msg = f"API Error: {str(e)}" print(f"API Error details: {e}") return None, error_msg def get_demo(): """Enhanced Gradio interface with selective transfer options""" with gr.Blocks( title="HairFastGAN - Selective Transfer", theme=gr.themes.Soft(), css=""" .error-message { color: red !important; background-color: #ffebee !important; padding: 10px !important; border-radius: 5px !important; } .transfer-info { background-color: #e3f2fd !important; padding: 10px !important; border-radius: 5px !important; margin: 10px 0 !important; } """ ) as demo: gr.Markdown("## HairFastGan - Selective Transfer") gr.Markdown( '
' ) gr.Markdown( """