import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageDraw
import torch
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import hashlib
import re
from loguru import logger


# Set up model and transformations
def get_background_removal_model():
    try:
        # Using the BiRefNet model for background removal
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        # Use CPU if CUDA is not available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model, device
    except Exception as e:
        logger.error(f"Error loading background removal model: {e}")
        return None, None


# Set up image transformation
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Cache for storing background removal results
bg_removal_cache = {}


def get_image_hash(image):
    """Generate a hash for an image to use as a cache key."""
    if image is None:
        return None
    # Convert to bytes and generate a hash
    img_byte_arr = image.tobytes()
    img_hash = hashlib.md5(img_byte_arr).hexdigest()
    # Include the image dimensions in the key to ensure uniqueness
    return f"{img_hash}_{image.width}_{image.height}"


def remove_background(image, model_data):
    if model_data[0] is None:
        return None, None

    # Generate a hash for the image to use as a cache key
    img_hash = get_image_hash(image)

    # Check if the result is already in the cache
    if img_hash in bg_removal_cache:
        logger.info("Using cached background removal result")
        return bg_removal_cache[img_hash]

    model, device = model_data
    try:
        logger.info("Starting background removal process")

        # Convert the image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Store the original size for later resizing
        image_size = image.size

        # Apply transformations and move to the device
        input_images = transform_image(image).unsqueeze(0).to(device)

        # Run prediction
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        # Convert the prediction to a PIL image
        pred_pil = transforms.ToPILImage()(pred)

        # Resize the mask back to the original image size
        mask = pred_pil.resize(image_size)

        # Create a copy of the original image and apply the alpha channel
        result_image = image.copy()
        result_image.putalpha(mask)

        # Cache the result
        result = (result_image, np.array(mask))
        bg_removal_cache[img_hash] = result

        logger.info("Background removal process completed")
        return result
    except Exception as e:
        logger.error(f"Error during background removal: {e}")
        return None, None
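
# Optional helper (an addition, not called anywhere in this script): the cache
# above keeps full-resolution RGBA results in memory for the lifetime of the
# process, so a long-running deployment may want a way to reclaim that memory.
def clear_bg_removal_cache():
    """Drop all cached background removal results to free memory."""
    bg_removal_cache.clear()
    logger.info("Background removal cache cleared")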

def parse_color(color_str):
    """Parse different color formats, including hex and rgba() strings."""
    if isinstance(color_str, tuple):
        # If it's already a tuple, make sure it has an alpha channel
        if len(color_str) == 3:
            return color_str + (255,)
        return color_str

    if isinstance(color_str, str):
        # Handle hex color format
        if color_str.startswith("#"):
            if len(color_str) == 7:  # #RRGGBB format
                r = int(color_str[1:3], 16)
                g = int(color_str[3:5], 16)
                b = int(color_str[5:7], 16)
                return (r, g, b, 255)
            # Fall back to white if the format is unexpected
            return (255, 255, 255, 255)

        # Handle the rgba() format produced by the Gradio color picker
        rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
        if rgba_match:
            values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
            r = min(255, int(values[0]))
            g = min(255, int(values[1]))
            b = min(255, int(values[2]))
            # Handle alpha if present (rgba alpha is a 0-1 float)
            a = 255
            if len(values) > 3:
                a = min(255, int(values[3] * 255))
            return (r, g, b, a)

        # For named colors, return as-is for PIL to handle
        return color_str

    # Default fallback: white
    return (255, 255, 255, 255)


def add_person_border(image, mask, border_size, border_color="white"):
    """Add a border around the person based on the segmentation mask."""
    if border_size == 0:
        return image

    # Convert the soft mask to binary (the low threshold keeps faint edge pixels)
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255

    # Dilate the mask to create the border
    kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
    dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)

    # The dilated mask covers both the person area and the border area
    border_mask_pil = Image.fromarray(dilated_mask)

    # Create an image filled with the border color
    border_color_rgba = parse_color(border_color)
    border_img = Image.new("RGBA", image.size, color=border_color_rgba)

    # Create a transparent image for the result
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))

    # First paste the border shape (border and person areas together)
    result.paste(border_img, (0, 0), border_mask_pil)

    # Then paste the original image on top, masked to the person area only,
    # so the person shows over the border-colored region
    result.paste(image, (0, 0), Image.fromarray(binary_mask))

    return result


def detect_face(image):
    """Detect the largest face in the image and return its bounding box."""
    logger.info("Starting face detection")

    # Convert the PIL image to OpenCV format (RGB -> BGR)
    img_cv = np.array(image.convert("RGB"))
    img_cv = img_cv[:, :, ::-1].copy()

    # Load the Haar cascade for face detection
    face_cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    face_cascade = cv2.CascadeClassifier(face_cascade_path)

    # Convert to grayscale for face detection
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # Detect faces
    faces = face_cascade.detectMultiScale(gray, 1.1, 4)

    if len(faces) == 0:
        logger.warning("No faces detected")
        return None

    # Find the largest face
    largest_face = None
    max_area = 0
    for x, y, w, h in faces:
        if w * h > max_area:
            max_area = w * h
            largest_face = (x, y, w, h)

    logger.info(f"Largest face detected at: {largest_face}")
    return largest_face


def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
    """Center the portrait based on the face position and crop to avoid blurriness."""
    if face_box is None:
        # If no face was detected, fall back to a centered crop of the portrait
        left = max(0, (portrait.width - target_width) // 2)
        top = max(0, (portrait.height - target_height) // 2)
        cropped = portrait.crop((left, top, left + target_width, top + target_height))
        return cropped, (0, 0)

    x, y, w, h = face_box

    # Calculate the face center
    face_center_x = x + w // 2
    face_center_y = y + h // 2

    # Calculate the crop box dimensions
    crop_width = int(target_width / zoom_level)
    crop_height = int(target_height / zoom_level)

    # Ensure the crop box stays within the image bounds
    left = max(0, face_center_x - crop_width // 2)
    top = max(0, face_center_y - crop_height // 2)
    right = min(portrait.width, left + crop_width)
    bottom = min(portrait.height, top + crop_height)

    # Adjust left and top if the crop box is smaller than the target dimensions
    left = max(0, right - crop_width)
    top = max(0, bottom - crop_height)

    # Crop the image
    cropped_img = portrait.crop((left, top, right, bottom))

    # Center the cropped image on a transparent canvas
    centered_img = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
    offset_x = (target_width - cropped_img.width) // 2
    offset_y = (target_height - cropped_img.height) // 2
    centered_img.paste(cropped_img, (offset_x, offset_y), cropped_img)

    return centered_img, (offset_x, offset_y)
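
# Worked example of the crop arithmetic above (illustrative values): for a
# 1000x1000 portrait with a face centered at (500, 400) and zoom_level=2.0,
# crop_width = crop_height = 500, so the crop box becomes (250, 150, 750, 650)
# and the 500x500 crop is pasted at offset (250, 250) on the transparent canvas.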

def process_portrait(
    input_image,
    border_size=10,
    bg_color="#0000FF",
    zoom_level=1.0,
    erode_size=5,
    circular_overlay=False,
):
    if input_image is None:
        return None

    # Global model instance to avoid reloading on every call
    global model_instance
    if "model_instance" not in globals():
        logger.info("Loading background removal model...")
        model_instance = get_background_removal_model()

    logger.info("Processing image...")
    result = remove_background(input_image, model_instance)

    if result[0] is None:
        logger.warning("Failed to remove background, returning original image")
        return input_image

    person_img, mask = result

    # Detect the face before any transformations
    face_box = detect_face(input_image)
    if face_box:
        logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, will center the entire portrait")

    # Shrink the mask by erode_size pixels to trim halo artifacts at the edges
    eroded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    mask = Image.fromarray(eroded_mask)

    logger.info("Adding white border...")
    # Add a white border only around the person
    bordered_img = add_person_border(person_img, mask, border_size, "white")

    logger.info(f"Creating colored background with color: {bg_color}")
    # Parse the background color
    bg_color_rgba = parse_color(bg_color)

    # Create the colored background
    width, height = bordered_img.size
    bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)

    # Center the portrait based on the face location and apply zoom
    logger.info(f"Applying zoom level: {zoom_level}")
    centered_portrait, offset = center_portrait(
        bordered_img, face_box, width, height, zoom_level
    )

    # Create the final composite
    final_image = Image.alpha_composite(bg_image, centered_portrait)

    # Crop the final image to the target dimensions
    crop_width = int(width / zoom_level)
    crop_height = int(height / zoom_level)
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    final_image = final_image.crop((left, top, right, bottom))

    # Convert back to RGB for display
    final_image = final_image.convert("RGB")

    # Ensure the final image is square
    width, height = final_image.size
    square_size = min(width, height)
    left = (width - square_size) // 2
    top = (height - square_size) // 2
    right = left + square_size
    bottom = top + square_size
    final_image = final_image.crop((left, top, right, bottom))

    if circular_overlay:
        # Create a circular mask and apply it as the alpha channel
        # (putalpha converts the RGB image to RGBA in place)
        circle_mask = Image.new("L", (square_size, square_size), 0)
        draw = ImageDraw.Draw(circle_mask)
        draw.ellipse((0, 0, square_size, square_size), fill=255)
        final_image.putalpha(circle_mask)

    logger.info(
        f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
    )
    return final_image
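
# Minimal programmatic usage sketch (hypothetical file names; bypasses the UI):
#
#   img = Image.open("portrait.jpg")
#   avatar = process_portrait(img, border_size=10, bg_color="#fdc915",
#                             zoom_level=1.2, erode_size=15,
#                             circular_overlay=True)
#   avatar.save("avatar.png")  # PNG preserves the circular overlay's alpha channel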

# Create the Gradio interface
with gr.Blocks(title="Cool Avatar Creator") as app:
    gr.Markdown("# Cool Avatar Creator")
    gr.Markdown(
        "Upload a portrait image to remove the background, add a white border, "
        "and place it on a colored background."
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            border_slider = gr.Slider(
                minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
            )
            bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
            zoom_slider = gr.Slider(
                minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
            )
            erode_slider = gr.Slider(
                minimum=1, maximum=30, value=15, step=1, label="Erode Size"
            )
            circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
            process_button = gr.Button("Process Image")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Processed Image")

    # Add example images
    examples = [
        [
            "https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
            26,
            "#fdc915",
            1.85,
        ],
        [
            "https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
            23,
            "#00FF00",
            1.4,
        ],
        ["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
    ]

    gr.Examples(
        examples=examples,
        inputs=[input_image, border_slider, bg_color, zoom_slider],
        outputs=output_image,
        fn=process_portrait,
        cache_examples=False,
    )

    process_button.click(
        fn=process_portrait,
        inputs=[
            input_image,
            border_slider,
            bg_color,
            zoom_slider,
            erode_slider,
            circular_overlay_toggle,
        ],
        outputs=output_image,
    )

if __name__ == "__main__":
    app.launch(share=False)  # share=True would create a public link
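
# Note: the model is loaded lazily on the first call to process_portrait, so the
# first request can stall for several seconds. A possible warm-up (an assumption,
# not part of the original flow) is to assign
# model_instance = get_background_removal_model() before app.launch().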