# HairSwapModel / app.py
# (HuggingFace Space file header: uploaded by miguelmuzo, commit 38b38c8 verified)
import hashlib
import os
from io import BytesIO
import base64
import requests
from pathlib import Path
import subprocess
import shutil
import gc
import time
import json
import threading
import gradio as gr
from PIL import Image
from cachetools import LRUCache
import torch
import numpy as np
import torchvision.transforms.functional as F
# FastAPI imports for enhanced frontend integration
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import uvicorn
# T4 Medium GPU Optimizations
# cudnn.benchmark lets cuDNN auto-tune conv algorithms for the fixed
# 1024x1024 inputs this app always feeds the model.
torch.backends.cudnn.benchmark = True
# Kept for backward compatibility, but NOTE: this is not a real PyTorch
# setting — assigning an arbitrary attribute on the torch.backends.cuda
# module silently does nothing.
torch.backends.cuda.max_split_size_mb = 512
# The supported way to cap CUDA allocator split size is the
# PYTORCH_CUDA_ALLOC_CONF env var (read lazily on first CUDA allocation).
# setdefault so an operator-provided value is not overridden.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'max_split_size_mb:512')
# API Models for FastAPI
# API Models for FastAPI
class ImageRequest(BaseModel):
    """Request payload for POST /api/hair-transfer.

    Every image travels as a base64 string; a ``data:image/...;base64,``
    prefix is accepted and stripped server-side (see base64_to_image).
    """
    source_image: str  # base64 string — required face photo to modify
    shape_image: Optional[str] = None  # base64 string — hairstyle reference (optional)
    color_image: Optional[str] = None  # base64 string — hair-color reference (optional)
    blending: str = "Article"  # color-encoder variant; mirrors the Gradio radio choices
    poisson_iters: int = 0  # forwarded to swap_hair_selective; not passed to the model's swap() call
    poisson_erosion: int = 15  # forwarded to swap_hair_selective; not passed to the model's swap() call
class ImageResponse(BaseModel):
    """Response payload for POST /api/hair-transfer."""
    success: bool  # True only when a result image was produced
    result_image: Optional[str] = None  # base64 string (data:image/png;base64,... URL) or None
    message: str  # human-readable status or error description
# Download working face landmarks
def download_face_landmarks():
"""Download working dlib face landmarks predictor"""
landmarks_path = 'pretrained_models/ShapeAdaptor/shape_predictor_68_face_landmarks.dat'
if os.path.exists(landmarks_path) and os.path.getsize(landmarks_path) > 50000000:
print("Face landmarks already exists and appears valid")
return True
try:
print("Downloading working face landmarks predictor...")
url = 'https://github.com/davisking/dlib-models/raw/master/shape_predictor_68_face_landmarks.dat.bz2'
os.makedirs(os.path.dirname(landmarks_path), exist_ok=True)
response = requests.get(url, stream=True, timeout=300)
response.raise_for_status()
import bz2
compressed_data = response.content
with open(landmarks_path, 'wb') as f:
f.write(bz2.decompress(compressed_data))
print(f"Face landmarks downloaded successfully ({os.path.getsize(landmarks_path)/1024/1024:.1f}MB)")
return True
except Exception as e:
print(f"Failed to download face landmarks: {e}")
return False
# Comprehensive model download for full accuracy
def download_all_missing_models():
"""Download ALL required models for full accuracy"""
all_required_models = {
'pretrained_models/ArcFace/backbone_ir50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/backbone_ir50.pth',
'pretrained_models/ArcFace/ir_se50.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ArcFace/ir_se50.pth',
'pretrained_models/BiSeNet/face_parsing_79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/BiSeNet/face_parsing_79999_iter.pth',
'pretrained_models/FeatureStyleEncoder/backbone.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/backbone.pth',
'pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/psp_ffhq_encode.pt',
'pretrained_models/FeatureStyleEncoder/79999_iter.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/79999_iter.pth',
'pretrained_models/FeatureStyleEncoder/143_enc.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/FeatureStyleEncoder/143_enc.pth',
'pretrained_models/encoder4editing/e4e_ffhq_encode.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
'pretrained_models/PostProcess/pp_model.pth': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/pp_model.pth',
'pretrained_models/PostProcess/latent_avg.pt': 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/PostProcess/latent_avg.pt',
}
print("Checking for missing models for full accuracy...")
missing_models = []
existing_models = []
for model_path, url in all_required_models.items():
if os.path.exists(model_path):
existing_models.append(os.path.basename(model_path))
else:
missing_models.append((model_path, url))
if existing_models:
print(f"Found existing models: {', '.join(existing_models)}")
if missing_models:
print(f"Downloading {len(missing_models)} missing models for full accuracy...")
for i, (model_path, url) in enumerate(missing_models, 1):
try:
model_name = os.path.basename(model_path)
print(f"[{i}/{len(missing_models)}] Downloading: {model_name}")
os.makedirs(os.path.dirname(model_path), exist_ok=True)
start_time = time.time()
response = requests.get(url, stream=True, timeout=600)
response.raise_for_status()
total_size = int(response.headers.get('content-length', 0))
downloaded = 0
with open(model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024*1024):
if chunk:
f.write(chunk)
downloaded += len(chunk)
if total_size > 100*1024*1024 and downloaded % (50*1024*1024) == 0:
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
print(f" Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
elapsed = time.time() - start_time
print(f" Downloaded: {model_name} ({downloaded/1024/1024:.1f}MB in {elapsed:.1f}s)")
except Exception as e:
print(f" Failed to download {model_name}: {e}")
continue
else:
print("All models already present!")
return True
def download_mask_generator_separately():
    """Download the large (~919MB) ShapeAdaptor mask_generator.pth.

    Handled separately from the batch download because of its size.  The
    file is written to a ``.part`` temp and renamed on success so an
    interrupted run never leaves a truncated checkpoint that the existence
    check would accept on restart (defect in the original).

    Returns:
        bool: True if the file is present after the call, else False.
    """
    mask_path = 'pretrained_models/ShapeAdaptor/mask_generator.pth'
    if os.path.exists(mask_path):
        print("Mask generator already exists")
        return True
    tmp_path = mask_path + '.part'
    try:
        print("Downloading mask_generator.pth (919MB)...")
        url = 'https://huggingface.co/AIRI-Institute/HairFastGAN/resolve/main/pretrained_models/ShapeAdaptor/mask_generator.pth'
        os.makedirs(os.path.dirname(mask_path), exist_ok=True)
        response = requests.get(url, stream=True, timeout=900)
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        downloaded = 0
        with open(tmp_path, 'wb') as f:
            # 2MB chunks keep the 100MB progress modulo exact.
            for chunk in response.iter_content(chunk_size=2 * 1024 * 1024):
                if chunk:
                    f.write(chunk)
                    downloaded += len(chunk)
                    if downloaded % (100 * 1024 * 1024) == 0:
                        progress = (downloaded / total_size) * 100 if total_size > 0 else 0
                        print(f" Mask Generator Progress: {progress:.1f}% ({downloaded/1024/1024:.1f}MB/{total_size/1024/1024:.1f}MB)")
        os.replace(tmp_path, mask_path)
        print(f"Successfully downloaded mask_generator.pth ({downloaded/1024/1024:.1f}MB)")
        return True
    except Exception as e:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print(f"Failed to download mask_generator.pth: {e}")
        return False
# Download all models
# Module-level side effect: fetch any missing checkpoints at import time,
# before the HairFast package imports below.
download_all_missing_models()
download_mask_generator_separately()
download_face_landmarks()

# Direct HairFast imports
# Guarded imports: the app degrades gracefully via availability flags when
# the project-local HairFastGAN code or its alignment utility is missing.
try:
    from hair_swap import HairFast, get_parser
    HAIRFAST_AVAILABLE = True
    print("HairFast successfully imported!")
except ImportError as e:
    print(f"HairFast import failed: {e}")
    HAIRFAST_AVAILABLE = False
try:
    from utils.shape_predictor import align_face
    ALIGN_AVAILABLE = True
    print("Face alignment available!")
except ImportError as e:
    print(f"Face alignment not available: {e}")
    ALIGN_AVAILABLE = False

# Global variables
hair_fast_model = None  # lazily built by initialize_hairfast_original()
align_cache = LRUCache(maxsize=10)  # md5(JPEG bytes) -> aligned 1024x1024 PIL image (see resize())
def get_gpu_memory():
    """Total memory of CUDA device 0 in GB, or 0 when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return 0
    props = torch.cuda.get_device_properties(0)
    return props.total_memory / 1e9
def optimize_for_t4():
    """Apply T4-specific CUDA allocator tweaks; a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    # Total device memory in GB (inlined from get_gpu_memory()).
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {gpu_memory:.1f}GB")
    if gpu_memory < 20:
        # A T4 has ~16GB: release cached blocks and cap allocator splits.
        torch.cuda.empty_cache()
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
    print("T4 optimizations applied")
def check_model_completeness():
    """Map each critical checkpoint's short name to whether its file exists.

    Returns:
        dict[str, bool]: model name -> file-on-disk flag.
    """
    critical_models = {
        'StyleGAN': 'pretrained_models/StyleGAN/ffhq.pt',
        'Blending': 'pretrained_models/Blending/checkpoint.pth',
        'Rotate': 'pretrained_models/Rotate/rotate_best.pth',
        'PostProcess': 'pretrained_models/PostProcess/pp_model.pth',
        'FeatureEncoder': 'pretrained_models/FeatureStyleEncoder/143_enc.pth',
        'E4E': 'pretrained_models/encoder4editing/e4e_ffhq_encode.pt',
        'SEAN': 'pretrained_models/sean_checkpoints/CelebA-HQ_pretrained/latest_net_G.pth',
        'ShapeAdaptor': 'pretrained_models/ShapeAdaptor/mask_generator.pth',
    }
    return {name: os.path.exists(path) for name, path in critical_models.items()}
def initialize_hairfast_original():
    """Build the global HairFast model with the upstream default arguments.

    Returns:
        bool: True on success; on failure ``hair_fast_model`` stays None.
    """
    global hair_fast_model
    if not HAIRFAST_AVAILABLE:
        print("HairFast not available")
        return False
    try:
        print("Initializing HairFast with original default arguments...")
        optimize_for_t4()
        available_models = check_model_completeness()
        print("Available models:", {k: v for k, v in available_models.items() if v})
        # parse_args([]) yields pure defaults, matching the upstream demo.
        hair_fast_model = HairFast(get_parser().parse_args([]))
        present = sum(available_models.values())
        accuracy_percentage = (present / len(available_models)) * 100
        print(f"HairFast initialized with original defaults ({accuracy_percentage:.1f}% accuracy)")
        torch.cuda.empty_cache()
        gc.collect()
        return True
    except Exception as e:
        print(f"HairFast initialization failed: {e}")
        hair_fast_model = None
        torch.cuda.empty_cache()
        return False
def get_bytes(img):
    """Serialize a PIL image to JPEG bytes; None passes through unchanged."""
    if img is None:
        return img
    out = BytesIO()
    img.save(out, format="JPEG")
    return out.getvalue()
def bytes_to_image(image: bytes) -> Image.Image:
    """Decode raw image bytes into a PIL image."""
    return Image.open(BytesIO(image))
def base64_to_image(base64_string):
    """Decode a base64 string (optionally a data: URL) into an RGB PIL image.

    Returns None on any decoding error instead of raising.
    """
    try:
        # Strip a "data:image/...;base64," prefix when present.
        if base64_string.startswith('data:image'):
            payload = base64_string.split(',')[1]
        else:
            payload = base64_string
        img = Image.open(BytesIO(base64.b64decode(payload)))
        return img if img.mode == 'RGB' else img.convert('RGB')
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None
def image_to_base64(image):
    """Encode a PIL image as a lossless PNG data URL; None maps to None."""
    if image is None:
        return None
    buf = BytesIO()
    # PNG keeps the payload lossless for the frontend.
    image.save(buf, format="PNG", optimize=False)
    encoded = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f"data:image/png;base64,{encoded}"
def center_crop(img):
    """Crop the largest centered square out of an image."""
    width, height = img.size
    side = min(width, height)
    # Float offsets intentionally preserved — PIL accepts them in crop boxes.
    x0 = (width - side) / 2
    y0 = (height - side) / 2
    return img.crop((x0, y0, x0 + side, y0 + side))
def resize(name):
    """Return a Gradio upload callback that aligns/crops images to 1024x1024.

    Args:
        name: Slot this callback serves ('Face', 'Shape' or 'Color');
            face alignment runs only when the user keeps that slot ticked
            in the "Image cropping" checkbox group.

    The inner callback memoizes alignment results in the module-level
    ``align_cache``, keyed by the md5 of the JPEG-serialized input.
    """
    def resize_inner(img, align):
        global align_cache
        if name in align and ALIGN_AVAILABLE:
            img_hash = hashlib.md5(get_bytes(img)).hexdigest()
            if img_hash in align_cache:
                return align_cache[img_hash]
            try:
                aligned_imgs = align_face(img, return_tensors=False)
                if aligned_imgs and len(aligned_imgs) > 0:
                    img = aligned_imgs[0]
                    if img.size != (1024, 1024):
                        img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
                else:
                    # No face detected: fall back to a plain center crop.
                    img = center_crop(img)
                    img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
            except Exception as e:
                print(f"Face alignment failed for {name}, using center crop: {e}")
                img = center_crop(img)
                img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
            # BUGFIX: the original only cached when the aligned image needed
            # resizing; an already-1024x1024 alignment result was never cached,
            # forcing a full re-alignment on every re-upload of the same photo.
            # Cache every computed result.
            align_cache[img_hash] = img
            return img
        # Alignment disabled or unavailable: center-crop + resize only.
        if img.size != (1024, 1024):
            img = center_crop(img)
            img = img.resize((1024, 1024), Image.Resampling.LANCZOS)
        return img
    return resize_inner
def swap_hair_selective(face, shape, color, blending, poisson_iters, poisson_erosion):
    """Run a selective hair transfer on ``face``.

    Mode is decided by which references are supplied:
      * shape only -> new hairstyle, original color kept
      * color only -> new color, original hairstyle kept
      * both       -> hairstyle and color transferred

    Returns:
        tuple: (result PIL image or None, status/error message).
    """
    global hair_fast_model
    # Lazy model bootstrap on first call.
    if hair_fast_model is None and not initialize_hairfast_original():
        return None, "HairFast model not available. Please check if all model files are uploaded."
    # Input guards, most specific message first.
    if not face and not shape and not color:
        return None, "Need to upload a face and at least a shape or color ❗"
    if not face:
        return None, "Need to upload a face ❗"
    if not shape and not color:
        return None, "Need to upload at least a shape or color ❗"
    try:
        print("Starting selective hair transfer...")
        torch.cuda.empty_cache()
        gc.collect()

        def _ensure_1024(img, label):
            # Defensive resize in case an input bypassed the upload hooks.
            if img is not None and img.size != (1024, 1024):
                print(f"Resizing {label} from {img.size} to (1024, 1024)")
                img = center_crop(img).resize((1024, 1024), Image.Resampling.LANCZOS)
            return img

        face = _ensure_1024(face, "face")
        shape = _ensure_1024(shape, "shape") if shape else None
        color = _ensure_1024(color, "color") if color else None

        # Pick the transfer mode; a missing reference is substituted with the
        # original face so that attribute is preserved.
        if shape is not None and color is not None:
            transfer_mode = "both"
            print("Transfer mode: Both hairstyle and color")
        elif shape is not None:
            transfer_mode = "shape_only"
            color = face  # original face supplies the color reference
            print("Transfer mode: Hairstyle only (preserving original color)")
        else:
            transfer_mode = "color_only"
            shape = face  # original face supplies the shape reference
            print("Transfer mode: Color only (preserving original hairstyle)")

        print(f"Final sizes - Face: {face.size}, Shape: {shape.size if shape else 'None'}, Color: {color.size if color else 'None'}")
        with torch.no_grad():
            t0 = time.time()
            # HairFast swap takes (face, shape, color) aligned 1024x1024 images.
            final_image = hair_fast_model.swap(face, shape, color)
            inference_time = time.time() - t0
        print(f"Inference completed in {inference_time:.2f} seconds")
        result_image = F.to_pil_image(final_image)
        torch.cuda.empty_cache()
        gc.collect()
        success_message = f"Hair transfer ({transfer_mode}) completed successfully in {inference_time:.2f}s"
        print(success_message)
        return result_image, success_message
    except Exception as e:
        torch.cuda.empty_cache()
        gc.collect()
        error_msg = f"Hair transfer failed: {str(e)}"
        print(f"Detailed error: {e}")
        return None, error_msg
def hair_transfer_api(source_image, shape_image=None, color_image=None,
                      blending="Article", poisson_iters=0, poisson_erosion=15):
    """Decode base64 inputs, run the selective swap, re-encode the result.

    Each image argument may be a PIL image or a base64 string (data URLs
    accepted).  Never raises: all failures are reported via the message.

    Returns:
        tuple: (base64 PNG data URL or None, status/error message).
    """
    try:
        print("API call received - processing images...")
        if isinstance(source_image, str):
            source_image = base64_to_image(source_image)
            print(f"Source image loaded: {source_image.size if source_image else 'Failed'}")
        if isinstance(shape_image, str) and shape_image:
            shape_image = base64_to_image(shape_image)
            print(f"Shape image loaded: {shape_image.size if shape_image else 'Failed'}")
        if isinstance(color_image, str) and color_image:
            color_image = base64_to_image(color_image)
            print(f"Color image loaded: {color_image.size if color_image else 'Failed'}")
        if not source_image:
            return None, "Failed to process source image"
        result_image, status_message = swap_hair_selective(
            source_image, shape_image, color_image,
            blending, poisson_iters, poisson_erosion,
        )
        if result_image is None:
            print("Result image is None")
            return None, status_message
        print(f"Result image generated: {result_image.size}")
        encoded = image_to_base64(result_image)
        print("Result converted to base64 successfully")
        return encoded, status_message
    except Exception as e:
        error_msg = f"API Error: {str(e)}"
        print(f"API Error details: {e}")
        return None, error_msg
def get_demo():
    """Build the Gradio Blocks UI for selective hair transfer.

    Returns:
        The assembled (not yet launched) gr.Blocks demo.
    """
    with gr.Blocks(
        title="HairFastGAN - Selective Transfer",
        theme=gr.themes.Soft(),
        # Custom CSS: red error box + blue info banner.
        css="""
        .error-message {
            color: red !important;
            background-color: #ffebee !important;
            padding: 10px !important;
            border-radius: 5px !important;
        }
        .transfer-info {
            background-color: #e3f2fd !important;
            padding: 10px !important;
            border-radius: 5px !important;
            margin: 10px 0 !important;
        }
        """
    ) as demo:
        gr.Markdown("## HairFastGan - Selective Transfer")
        # Badge row linking to the paper, repo, model card and Colab notebook.
        gr.Markdown(
            '<div style="display: flex; align-items: center; gap: 10px;">'
            '<span>Enhanced HairFastGAN with selective transfer:</span>'
            '<a href="https://arxiv.org/abs/2404.01094"><img src="https://img.shields.io/badge/arXiv-2404.01094-b31b1b.svg" height=22.5></a>'
            '<a href="https://github.com/AIRI-Institute/HairFastGAN"><img src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white" height=22.5></a>'
            '<a href="https://huggingface.co/AIRI-Institute/HairFastGAN"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/model-on-hf-md.svg" height=22.5></a>'
            '<a href="https://colab.research.google.com/#fileId=https://huggingface.co/AIRI-Institute/HairFastGAN/blob/main/notebooks/HairFast_inference.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" height=22.5></a>'
            '</div>'
        )
        # Info banner explaining the three transfer modes.
        gr.Markdown(
            """
            <div class="transfer-info">
            <b>🎯 Selective Transfer Modes:</b><br>
            • <b>Shape Only:</b> Upload only shape image → Changes hairstyle while preserving original color<br>
            • <b>Color Only:</b> Upload only color image → Changes hair color while preserving original hairstyle<br>
            • <b>Both:</b> Upload both images → Changes both hairstyle and color
            </div>
            """,
            elem_classes="transfer-info"
        )
        with gr.Row():
            with gr.Column():
                source = gr.Image(label="Source photo to try on the hairstyle", type="pil")
                with gr.Row():
                    shape = gr.Image(label="Shape photo with desired hairstyle (optional)", type="pil")
                    color = gr.Image(label="Color photo with desired hair color (optional)", type="pil")
                # Transfer mode indicator
                transfer_status = gr.Textbox(label="Transfer Mode", interactive=False,
                                             value="Upload images to see transfer mode")
                with gr.Accordion("Advanced Options", open=False):
                    blending = gr.Radio(["Article", "Alternative_v1", "Alternative_v2"], value='Article',
                                        label="Color Encoder version")
                    poisson_iters = gr.Slider(0, 2500, value=0, step=1, label="Poisson iters")
                    poisson_erosion = gr.Slider(1, 100, value=15, step=1, label="Poisson erosion")
                    align = gr.CheckboxGroup(["Face", "Shape", "Color"], value=["Face", "Shape", "Color"],
                                             label="Image cropping [recommended]")
                btn = gr.Button("Get the haircut", variant="primary")
            with gr.Column():
                output = gr.Image(label="Your result")
                # NOTE(review): swap_hair_selective's status message is routed into
                # this textbox, but visible=False means even success messages are
                # never displayed — confirm whether that is intended.
                error_message = gr.Textbox(label="⚠️ Error ⚠️", visible=False, elem_classes="error-message")

        # Function to update transfer mode display
        def update_transfer_mode(shape_img, color_img):
            # Live hint describing which mode the current uploads imply.
            if shape_img is not None and color_img is not None:
                return "🎯 Both hairstyle and color will be transferred"
            elif shape_img is not None and color_img is None:
                return "🎨 Only hairstyle will be transferred (color preserved)"
            elif shape_img is None and color_img is not None:
                return "🌈 Only hair color will be transferred (style preserved)"
            else:
                return "Upload shape and/or color reference images"

        # Update transfer mode when images change
        shape.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)
        color.change(fn=update_transfer_mode, inputs=[shape, color], outputs=transfer_status)
        # Auto-align/crop each upload to 1024x1024 via resize(); respects the
        # "Image cropping" checkbox group selection.
        source.upload(fn=resize('Face'), inputs=[source, align], outputs=source)
        shape.upload(fn=resize('Shape'), inputs=[shape, align], outputs=shape)
        color.upload(fn=resize('Color'), inputs=[color, align], outputs=color)
        btn.click(
            fn=swap_hair_selective,
            inputs=[source, shape, color, blending, poisson_iters, poisson_erosion],
            outputs=[output, error_message],
            api_name="predict"
        )
        gr.Markdown('''
        ### How to use:
        1. **Upload your source photo** (the face you want to modify)
        2. **Choose your transfer mode:**
           - For **hairstyle only**: Upload only a shape reference image
           - For **color only**: Upload only a color reference image
           - For **both**: Upload both shape and color reference images
        3. **Click "Get the haircut"** and wait for results!
        ''')
    return demo
def create_api_app():
    """Build the FastAPI application exposing the hair-transfer REST API.

    Routes:
      * POST /api/hair-transfer — run a transfer from base64 images
      * GET  /api/health        — liveness probe + model-loaded flag
    """
    app = FastAPI(title="HairFastGAN Selective Transfer API", version="2.0")
    # Wide-open CORS so any frontend origin can call the Space directly.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.post("/api/hair-transfer", response_model=ImageResponse)
    async def transfer_hair(request: ImageRequest):
        """Run a hair transfer and wrap the outcome in an ImageResponse."""
        try:
            print(f"API request received: {len(request.source_image)} chars source")
            encoded, message = hair_transfer_api(
                source_image=request.source_image,
                shape_image=request.shape_image,
                color_image=request.color_image,
                blending=request.blending,
                poisson_iters=request.poisson_iters,
                poisson_erosion=request.poisson_erosion,
            )
            if not encoded:
                return ImageResponse(
                    success=False,
                    message=message or "Hair transfer failed"
                )
            return ImageResponse(
                success=True,
                result_image=encoded,
                message=message or "Hair transfer completed successfully"
            )
        except Exception as e:
            print(f"API endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    @app.get("/api/health")
    async def health_check():
        """Liveness probe used by the frontend."""
        return {"status": "healthy", "model_loaded": hair_fast_model is not None}

    return app
if __name__ == '__main__':
    # Warm-up: allocator tweaks, a fresh alignment cache, then the model.
    optimize_for_t4()
    align_cache = LRUCache(maxsize=10)  # rebinds the module-level cache
    if not initialize_hairfast_original():
        print("Failed to initialize HairFast model")
        exit(1)
    gradio_demo = get_demo()
    fastapi_app = create_api_app()

    def run_fastapi():
        # REST API on :8000 (runs on a daemon thread below).
        uvicorn.run(fastapi_app, host="0.0.0.0", port=8000, log_level="info")

    def run_gradio():
        # Gradio UI on :7860; the queue limits concurrent jobs to 2.
        gradio_demo.queue(max_size=10, default_concurrency_limit=2)
        gradio_demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

    print("Starting FastAPI server on port 8000...")
    fastapi_thread = threading.Thread(target=run_fastapi)
    fastapi_thread.daemon = True  # don't block process exit on the API thread
    fastapi_thread.start()
    print("Starting Gradio server on port 7860...")
    run_gradio()  # blocks in the main thread