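"""Cool Avatar Creator.

Gradio app that removes the background from a portrait, adds a white border
around the person, places the cutout on a colored background centered on the
detected face, and optionally masks the result to a circle.
"""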
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageDraw
import torch
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import hashlib
import re
from loguru import logger
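# Third-party dependencies: gradio, numpy, opencv-python, pillow, torch,
# torchvision, transformers, loguru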
# Set up model and transformations
def get_background_removal_model():
    """Load the BiRefNet segmentation model used for background removal."""
    try:
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        # Fall back to CPU if CUDA is not available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model, device
    except Exception as e:
        logger.error(f"Error loading background removal model: {e}")
        return None, None
# Image transformation expected by BiRefNet: 1024x1024 input,
# normalized with the standard ImageNet mean/std
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# In-memory cache for background removal results, keyed by image hash
bg_removal_cache = {}
def get_image_hash(image):
    """Generate a hash for an image to use as a cache key."""
    if image is None:
        return None
    # Hash the raw pixel bytes
    img_byte_arr = image.tobytes()
    img_hash = hashlib.md5(img_byte_arr).hexdigest()
    # Include the dimensions, since images of different shapes can share bytes
    return f"{img_hash}_{image.width}_{image.height}"
def remove_background(image, model_data):
    """Segment the person; return (RGBA image with alpha, mask as uint8 array)."""
    if model_data[0] is None:
        return None, None

    # Check the cache first, keyed by a hash of the image
    img_hash = get_image_hash(image)
    if img_hash in bg_removal_cache:
        logger.info("Using cached background removal result")
        return bg_removal_cache[img_hash]

    model, device = model_data
    try:
        logger.info("Starting background removal process")
        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Store the original size so the mask can be resized back
        image_size = image.size

        # Apply transformations and move to device
        input_images = transform_image(image).unsqueeze(0).to(device)

        # Run prediction; the last entry of the model output is used as the mask
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Convert the prediction to a PIL image and resize to the original size
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)

        # Apply the mask as the alpha channel of a copy of the original image
        result_image = image.copy()
        result_image.putalpha(mask)

        # Cache and return the result
        result = (result_image, np.array(mask))
        bg_removal_cache[img_hash] = result
        logger.info("Background removal process completed")
        return result
    except Exception as e:
        logger.error(f"Error during background removal: {e}")
        return None, None
def parse_color(color_str):
    """Parse different color formats (tuple, hex, rgba() string, named color)."""
    if isinstance(color_str, tuple):
        # Already a tuple; make sure it has an alpha channel
        if len(color_str) == 3:
            return color_str + (255,)
        return color_str
    if isinstance(color_str, str):
        # Handle hex colors
        if color_str.startswith("#"):
            if len(color_str) == 7:  # #RRGGBB format
                r = int(color_str[1:3], 16)
                g = int(color_str[3:5], 16)
                b = int(color_str[5:7], 16)
                return (r, g, b, 255)
            # Fall back to white for unexpected hex formats
            return (255, 255, 255, 255)
        # Handle the rgb()/rgba() format produced by the Gradio color picker
        rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
        if rgba_match:
            values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
            r = min(255, int(values[0]))
            g = min(255, int(values[1]))
            b = min(255, int(values[2]))
            # The alpha component, if present, is a 0-1 float
            a = 255
            if len(values) > 3:
                a = min(255, int(values[3] * 255))
            return (r, g, b, a)
        # For named colors, return the string as-is and let PIL resolve it
        return color_str
    # Default fallback: opaque white
    return (255, 255, 255, 255)
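# A few examples of what parse_color returns:
#   parse_color("#ff0000")                -> (255, 0, 0, 255)
#   parse_color("rgba(0, 128, 255, 0.5)") -> (0, 128, 255, 127)
#   parse_color((10, 20, 30))             -> (10, 20, 30, 255)
#   parse_color("white")                  -> "white" (resolved later by PIL)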
def add_person_border(image, mask, border_size, border_color="white"):
    """Add a border around the person based on the segmentation mask."""
    if border_size == 0:
        return image

    # Binarize the mask (low threshold keeps faint edge pixels)
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255

    # Dilate the mask; a (2*border_size + 1) square kernel grows the person
    # shape by border_size pixels in every direction, forming the border
    kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
    dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    border_mask_pil = Image.fromarray(dilated_mask)

    # Fill the dilated shape (person + border ring) with the border color
    border_img = Image.new("RGBA", image.size, color=parse_color(border_color))

    # Compose on a transparent canvas: border shape first, then the original
    # person pasted on top so only the ring of border color remains visible
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))
    result.paste(border_img, (0, 0), border_mask_pil)
    result.paste(image, (0, 0), Image.fromarray(binary_mask))
    return result
def detect_face(image):
    """Detect the largest face in the image and return its bounding box."""
    logger.info("Starting face detection")
    # Convert the PIL image to OpenCV's BGR format
    img_cv = np.array(image.convert("RGB"))
    img_cv = img_cv[:, :, ::-1].copy()

    # Load the Haar cascade bundled with OpenCV for frontal face detection
    face_cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    face_cascade = cv2.CascadeClassifier(face_cascade_path)

    # Detect faces on the grayscale image
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    faces = face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=4)

    if len(faces) == 0:
        logger.warning("No faces detected")
        return None

    # Keep the face with the largest bounding-box area
    largest_face = None
    max_area = 0
    for x, y, w, h in faces:
        if w * h > max_area:
            max_area = w * h
            largest_face = (x, y, w, h)

    logger.info(f"Largest face detected at: {largest_face}")
    return largest_face
def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
    """Center the portrait on the detected face; zoom by cropping rather than
    upscaling, which avoids blurriness."""
    if face_box is None:
        # No face detected: keep the portrait as-is, anchored at the top-left
        return portrait.crop((0, 0, target_width, target_height)), (0, 0)

    x, y, w, h = face_box
    # Center of the detected face
    face_center_x = x + w // 2
    face_center_y = y + h // 2

    # Zooming in means cropping a smaller region around the face
    crop_width = int(target_width / zoom_level)
    crop_height = int(target_height / zoom_level)

    # Keep the crop box inside the image bounds
    left = max(0, face_center_x - crop_width // 2)
    top = max(0, face_center_y - crop_height // 2)
    right = min(portrait.width, left + crop_width)
    bottom = min(portrait.height, top + crop_height)

    # Shift the box back if it was clipped at the right or bottom edge
    left = max(0, right - crop_width)
    top = max(0, bottom - crop_height)

    cropped_img = portrait.crop((left, top, right, bottom))

    # Paste the crop, centered, onto a transparent canvas of the target size
    centered_img = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
    offset_x = (target_width - cropped_img.width) // 2
    offset_y = (target_height - cropped_img.height) // 2
    centered_img.paste(cropped_img, (offset_x, offset_y), cropped_img)
    return centered_img, (offset_x, offset_y)
def process_portrait(
    input_image,
    border_size=10,
    bg_color="#0000FF",
    zoom_level=1.0,
    erode_size=5,
    circular_overlay=False,
):
    """Full pipeline: segment, tighten the mask, add a border, recolor the
    background, center on the face, zoom, square-crop, and optionally mask
    the result to a circle."""
    if input_image is None:
        return None

    # Lazily load the model once and reuse it across calls
    global model_instance
    if "model_instance" not in globals():
        logger.info("Loading background removal model...")
        model_instance = get_background_removal_model()

    logger.info("Processing image...")
    result = remove_background(input_image, model_instance)
    if result[0] is None:
        logger.warning("Failed to remove background, returning original image")
        return input_image
    person_img, mask = result

    # Detect the face before any transformations
    face_box = detect_face(input_image)
    if face_box:
        logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, will center the entire portrait")

    # Erode the mask by erode_size pixels to tighten the cutout edge
    eroded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    mask = Image.fromarray(eroded_mask)

    logger.info("Adding white border...")
    # Add a white border only around the person
    bordered_img = add_person_border(person_img, mask, border_size, "white")

    logger.info(f"Creating colored background with color: {bg_color}")
    bg_color_rgba = parse_color(bg_color)
    width, height = bordered_img.size
    bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)

    # Center the portrait on the face and apply the zoom
    logger.info(f"Applying zoom level: {zoom_level}")
    centered_portrait, offset = center_portrait(
        bordered_img, face_box, width, height, zoom_level
    )

    # Composite the portrait over the colored background
    final_image = Image.alpha_composite(bg_image, centered_portrait)

    # Crop the composite down to the zoomed region
    crop_width = int(width / zoom_level)
    crop_height = int(height / zoom_level)
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    final_image = final_image.crop((left, top, left + crop_width, top + crop_height))

    # Convert back to RGB for display
    final_image = final_image.convert("RGB")

    # Crop to a centered square
    width, height = final_image.size
    square_size = min(width, height)
    left = (width - square_size) // 2
    top = (height - square_size) // 2
    final_image = final_image.crop((left, top, left + square_size, top + square_size))

    if circular_overlay:
        # Mask the avatar to a circle via the alpha channel
        circle_mask = Image.new("L", (square_size, square_size), 0)
        draw = ImageDraw.Draw(circle_mask)
        draw.ellipse((0, 0, square_size, square_size), fill=255)
        final_image.putalpha(circle_mask)

    logger.info(
        f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
    )
    return final_image
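# Example of calling the pipeline directly, bypassing the UI
# ("portrait.jpg" / "avatar.png" are placeholder paths):
#   img = Image.open("portrait.jpg")
#   avatar = process_portrait(img, border_size=20, bg_color="#fdc915", zoom_level=1.5)
#   avatar.save("avatar.png")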
# Build the Gradio interface
with gr.Blocks(title="Cool Avatar Creator") as app:
    gr.Markdown("# Cool Avatar Creator")
    gr.Markdown(
        "Upload a portrait image to remove the background, add a white border, "
        "and place it on a colored background."
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            border_slider = gr.Slider(
                minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
            )
            bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
            zoom_slider = gr.Slider(
                minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
            )
            erode_slider = gr.Slider(
                minimum=1, maximum=30, value=15, step=1, label="Erode Size"
            )
            circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
            process_button = gr.Button("Process Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Processed Image")
    # Example images (URL, border size, background color, zoom level)
    examples = [
        [
            "https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
            26,
            "#fdc915",
            1.85,
        ],
        [
            "https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
            23,
            "#00FF00",
            1.4,
        ],
        ["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, border_slider, bg_color, zoom_slider],
        outputs=output_image,
        fn=process_portrait,
        cache_examples=False,
    )

    process_button.click(
        fn=process_portrait,
        inputs=[
            input_image,
            border_slider,
            bg_color,
            zoom_slider,
            erode_slider,
            circular_overlay_toggle,
        ],
        outputs=output_image,
    )
if __name__ == "__main__":
    app.launch(share=False)  # Set share=True to create a public link