Spaces:
Running
Running
File size: 13,448 Bytes
4550fcf 26fdb35 4550fcf c28e15d a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 26fdb35 a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 4550fcf a49d644 c28e15d a49d644 26fdb35 a49d644 4550fcf c28e15d a49d644 4550fcf a49d644 4550fcf a49d644 26fdb35 4550fcf a49d644 4550fcf a49d644 c28e15d a49d644 c28e15d 93b60fd c28e15d a49d644 4550fcf 26fdb35 a49d644 4550fcf a49d644 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 
159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 | import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps, ImageDraw
import os
import torch
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import hashlib
import re
import urllib.request as urllib2
from loguru import logger
# Set up model and transformations
def get_background_removal_model():
    """Load the BiRefNet segmentation model used for background removal.

    Returns:
        tuple: ``(model, device)`` where ``device`` is ``"cuda"`` or ``"cpu"``,
        or ``(None, None)`` if loading failed (callers check for this).
    """
    try:
        # BiRefNet is a salient-object segmentation model; the predicted
        # foreground mask is used as an alpha matte downstream.
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        # Fall back to CPU when CUDA is not available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model, device
    except Exception as e:
        # Consistency fix: the rest of this file logs through loguru's
        # `logger`, so report load failures the same way instead of print()
        logger.error(f"Error loading background removal model: {e}")
        return None, None
# Preprocessing pipeline applied to every input image before inference.
transform_image = transforms.Compose(
    [
        # The model is fed fixed-size 1024x1024 inputs; the predicted mask
        # is resized back to the original image size afterwards.
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std normalization
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
# In-memory cache of background-removal results, keyed by get_image_hash().
# NOTE(review): unbounded — grows for the lifetime of the process.
bg_removal_cache = {}
def get_image_hash(image):
    """Return a cache key for *image*: md5 of its raw bytes plus its size.

    Returns None when *image* is None.
    """
    if image is None:
        return None
    digest = hashlib.md5(image.tobytes()).hexdigest()
    # Append width/height: two images with identical raw bytes but different
    # dimensions must not collide on the same cache entry.
    return f"{digest}_{image.width}_{image.height}"
def remove_background(image, model_data):
    """Cut the person out of *image* using the segmentation model.

    Args:
        image: input PIL image.
        model_data: (model, device) pair from get_background_removal_model().

    Returns:
        (rgba_image, mask_array) on success, (None, None) on failure.
        Results are memoized in ``bg_removal_cache`` keyed by image hash.
    """
    model, device = model_data
    if model is None:
        return None, None

    # Serve a cached result if we have already processed this exact image
    img_hash = get_image_hash(image)
    if img_hash in bg_removal_cache:
        logger.info("Using cached background removal result")
        return bg_removal_cache[img_hash]

    try:
        logger.info("Starting background removal process")
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        original_size = rgb_image.size

        # Predict the foreground mask at the model's native resolution
        batch = transform_image(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            prediction = model(batch)[-1].sigmoid().cpu()[0].squeeze()

        # Scale the mask back to the input size and use it as the alpha channel
        mask = transforms.ToPILImage()(prediction).resize(original_size)
        cutout = rgb_image.copy()
        cutout.putalpha(mask)

        result = (cutout, np.array(mask))
        bg_removal_cache[img_hash] = result
        logger.info("Background removal process completed")
        return result
    except Exception as e:
        logger.error(f"Error during background removal: {e}")
        return None, None
def parse_color(color_str):
    """Parse a color given as a tuple, hex string, rgb()/rgba() string, or name.

    Accepts:
      - ``(r, g, b)`` or ``(r, g, b, a)`` tuples (alpha 255 added when missing)
      - ``#RGB``, ``#RRGGBB``, ``#RRGGBBAA`` hex strings (``#RGB`` and
        ``#RRGGBBAA`` are a backward-compatible generalization; they
        previously fell through to the white fallback)
      - ``rgb(r, g, b)`` / ``rgba(r, g, b, a)`` strings, with alpha as a
        0.0-1.0 float, as emitted by the Gradio color picker
      - named colors, returned unchanged for PIL to resolve

    Returns:
        An ``(r, g, b, a)`` tuple, the original named-color string, or opaque
        white ``(255, 255, 255, 255)`` as a fallback for unrecognized input.
    """
    if isinstance(color_str, tuple):
        # Already numeric; just guarantee an alpha channel is present
        return color_str + (255,) if len(color_str) == 3 else color_str

    if isinstance(color_str, str):
        if color_str.startswith("#"):
            hex_digits = color_str[1:]
            if len(hex_digits) == 3:
                # #RGB shorthand: duplicate each digit (#F00 -> FF0000)
                hex_digits = "".join(ch * 2 for ch in hex_digits)
            if len(hex_digits) == 6:
                # #RRGGBB has no alpha -> treat as fully opaque
                hex_digits += "FF"
            if len(hex_digits) == 8:
                try:
                    r, g, b, a = (
                        int(hex_digits[i : i + 2], 16) for i in range(0, 8, 2)
                    )
                    return (r, g, b, a)
                except ValueError:
                    pass  # non-hex characters -> fall through to white
            # Fallback to white if format is unexpected
            return (255, 255, 255, 255)

        # Handle rgb()/rgba() format from the Gradio color picker
        rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
        if rgba_match:
            values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
            # Clamp channels into the valid 0-255 range (also guards negatives)
            r, g, b = (max(0, min(255, int(v))) for v in values[:3])
            # Alpha, when present, arrives as a 0.0-1.0 float
            a = max(0, min(255, int(values[3] * 255))) if len(values) > 3 else 255
            return (r, g, b, a)

        # For named colors, return as is for PIL to handle
        return color_str

    # Default fallback
    return (255, 255, 255, 255)  # White
def add_person_border(image, mask, border_size, border_color="white"):
    """Add a colored border around the person based on the segmentation mask.

    Args:
        image: RGBA cut-out of the person (transparent background).
        mask: grayscale person mask (PIL image or array-like).
        border_size: border thickness in pixels; 0 returns *image* unchanged.
        border_color: any format parse_color understands (default "white",
            matching the previous hard-coded behavior).

    Returns:
        RGBA image: border shape with the original person pasted on top.
    """
    if border_size == 0:
        return image

    # Binarize the soft mask (threshold 4 drops the near-transparent fringe)
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255

    # Dilate the mask outward by border_size pixels: the dilated region
    # covers both the person and the border ring around them
    kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
    dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)
    border_mask_pil = Image.fromarray(dilated_mask)

    # BUGFIX: honor the border_color argument instead of always using white
    border_img = Image.new("RGBA", image.size, color=parse_color(border_color))

    # Composite on a transparent canvas: border shape first, person on top
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))
    result.paste(border_img, (0, 0), border_mask_pil)
    result.paste(image, (0, 0), Image.fromarray(binary_mask))
    return result
def detect_face(image):
    """Detect the largest face in the image and return its bounding box"""
    logger.info("Starting face detection")

    # PIL (RGB) -> OpenCV (BGR) pixel array
    frame = np.array(image.convert("RGB"))[:, :, ::-1].copy()

    # Haar cascade frontal-face detector bundled with OpenCV
    cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )

    # Detection runs on the grayscale image
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    detections = cascade.detectMultiScale(gray, 1.1, 4)

    if len(detections) == 0:
        logger.warning("No faces detected")
        return None

    # Keep only the detection with the largest area (w * h)
    largest_face = tuple(max(detections, key=lambda box: box[2] * box[3]))
    logger.info(f"Largest face detected at: {largest_face}")
    return largest_face
def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
    """Crop a face-centered window from *portrait* and paste it centered on a
    transparent target-sized canvas.

    Returns (canvas, (offset_x, offset_y)). With no face box, returns the
    top-left target-sized crop with a (0, 0) offset.
    """
    if face_box is None:
        return portrait.crop((0, 0, target_width, target_height)), (0, 0)

    x, y, w, h = face_box
    face_cx = x + w // 2
    face_cy = y + h // 2

    # Zooming in means cropping a smaller window (it is scaled up by the caller)
    crop_w = int(target_width / zoom_level)
    crop_h = int(target_height / zoom_level)

    # Center the window on the face, then clamp it inside the portrait bounds
    left = max(0, face_cx - crop_w // 2)
    top = max(0, face_cy - crop_h // 2)
    right = min(portrait.width, left + crop_w)
    bottom = min(portrait.height, top + crop_h)
    # If the window ran past the right/bottom edge, slide it back left/up
    left = max(0, right - crop_w)
    top = max(0, bottom - crop_h)

    cropped = portrait.crop((left, top, right, bottom))

    # Paste the crop centered on a transparent canvas of the target size
    canvas = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
    offset = (
        (target_width - cropped.width) // 2,
        (target_height - cropped.height) // 2,
    )
    canvas.paste(cropped, offset, cropped)
    return canvas, offset
def process_portrait(
    input_image, border_size=10, bg_color="#0000FF", zoom_level=1.0, erode_size=5, circular_overlay=False
):
    """Full avatar pipeline.

    Steps: remove background -> erode mask edges -> add border -> place on a
    colored background centered on the detected face with zoom -> square-crop
    -> optionally apply a circular alpha mask.

    Returns the processed PIL image, the original image if background removal
    failed, or None when no input was given.
    """
    if input_image is None:
        return None
    # Lazily create a single module-level model instance so the model is
    # loaded only once per process, not on every invocation
    global model_instance
    if "model_instance" not in globals():
        logger.info("Loading background removal model...")
        model_instance = get_background_removal_model()
    logger.info("Processing image...")
    result = remove_background(input_image, model_instance)
    if result[0] is None:
        logger.warning("Failed to remove background, returning original image")
        return input_image
    person_img, mask = result
    # Detect face on the untouched input, before any transformations
    face_box = detect_face(input_image)
    if face_box:
        logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, will center the entire portrait")
    # Shrink (erode) the mask by erode_size pixels to trim halo artifacts
    # left at the matte edge by background removal
    expanded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    expanded_mask_pil = Image.fromarray(expanded_mask)
    mask = expanded_mask_pil
    logger.info("Adding white border...")
    # Add a white border only around the person silhouette
    bordered_img = add_person_border(person_img, mask, border_size, "white")
    logger.info(f"Creating colored background with color: {bg_color}")
    bg_color_rgba = parse_color(bg_color)
    # Solid-color backdrop the same size as the bordered cut-out
    width, height = bordered_img.size
    bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)
    # Center the portrait on the face position, applying the zoom window
    logger.info(f"Applying zoom level: {zoom_level}")
    centered_portrait, offset = center_portrait(
        bordered_img, face_box, width, height, zoom_level
    )
    # Composite the (transparent-background) portrait over the backdrop
    final_image = Image.alpha_composite(bg_image, centered_portrait)
    # Crop to the zoomed window so the subject fills the frame
    crop_width = int(width / zoom_level)
    crop_height = int(height / zoom_level)
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    final_image = final_image.crop((left, top, right, bottom))
    # Flatten to RGB for display
    final_image = final_image.convert("RGB")
    # Center-crop to a square avatar
    width, height = final_image.size
    square_size = min(width, height)
    left = (width - square_size) // 2
    top = (height - square_size) // 2
    right = left + square_size
    bottom = top + square_size
    final_image = final_image.crop((left, top, right, bottom))
    if circular_overlay:
        # Build an L-mode mask with a filled circle and use it as alpha,
        # making everything outside the circle transparent
        mask = Image.new("L", (square_size, square_size), 0)
        draw = ImageDraw.Draw(mask)
        draw.ellipse((0, 0, square_size, square_size), fill=255)
        final_image.putalpha(mask)
    logger.info(
        f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
    )
    return final_image
# Create Gradio interface: controls on the left, processed avatar on the right.
with gr.Blocks(title="Cool Avatar Creator") as app:
    gr.Markdown("# Cool Avatar Creator")
    gr.Markdown(
        "Upload a portrait image to remove the background, add a white border, and place on a colored background."
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            border_slider = gr.Slider(
                minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
            )
            bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
            zoom_slider = gr.Slider(
                minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
            )
            erode_slider = gr.Slider(
                minimum=1, maximum=30, value=15, step=1, label="Erode Size"
            )
            circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
            process_button = gr.Button("Process Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Processed Image")
    # Example rows supply only the first four inputs (image, border, color,
    # zoom); erode_size and circular_overlay fall back to their defaults.
    examples = [
        [
            "https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
            26,
            "#fdc915",
            1.85,
        ],
        [
            "https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
            23,
            "#00FF00",
            1.4,
        ],
        ["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, border_slider, bg_color, zoom_slider],
        outputs=output_image,
        fn=process_portrait,
        cache_examples=False
    )
    # The button passes all six controls through to process_portrait
    process_button.click(
        fn=process_portrait,
        inputs=[input_image, border_slider, bg_color, zoom_slider, erode_slider, circular_overlay_toggle],
        outputs=output_image,
    )
# Script entry point: start the Gradio server.
if __name__ == "__main__":
    app.launch(share=False)  # share=True would create a public Gradio link
|