File size: 13,448 Bytes
4550fcf
 
 
26fdb35
4550fcf
 
 
 
 
 
c28e15d
a49d644
 
4550fcf
 
 
 
a49d644
 
 
4550fcf
 
 
 
 
 
 
 
 
a49d644
 
 
 
 
 
 
4550fcf
 
 
 
 
 
 
 
a49d644
4550fcf
 
 
a49d644
4550fcf
 
 
 
 
 
a49d644
4550fcf
 
a49d644
4550fcf
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
a49d644
 
 
4550fcf
 
a49d644
4550fcf
 
a49d644
4550fcf
 
 
 
 
 
 
 
 
 
 
a49d644
4550fcf
 
 
a49d644
 
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
 
 
 
 
 
 
 
a49d644
4550fcf
 
a49d644
4550fcf
 
 
 
 
 
 
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
 
a49d644
4550fcf
 
 
 
a49d644
4550fcf
a49d644
4550fcf
 
a49d644
4550fcf
 
 
a49d644
4550fcf
 
 
a49d644
4550fcf
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
a49d644
4550fcf
a49d644
 
 
4550fcf
a49d644
 
4550fcf
 
a49d644
4550fcf
 
a49d644
 
4550fcf
 
 
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
a49d644
4550fcf
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
 
a49d644
 
4550fcf
 
 
a49d644
 
4550fcf
 
 
 
 
 
 
a49d644
4550fcf
a49d644
4550fcf
 
 
a49d644
4550fcf
 
 
a49d644
4550fcf
 
 
 
 
a49d644
4550fcf
 
 
a49d644
4550fcf
 
a49d644
4550fcf
a49d644
4550fcf
 
 
a49d644
4550fcf
 
a49d644
26fdb35
a49d644
4550fcf
 
a49d644
4550fcf
 
a49d644
 
4550fcf
a49d644
 
4550fcf
 
a49d644
4550fcf
a49d644
4550fcf
a49d644
4550fcf
 
 
a49d644
4550fcf
a49d644
 
4550fcf
a49d644
 
 
4550fcf
a49d644
4550fcf
a49d644
 
4550fcf
a49d644
 
 
4550fcf
 
a49d644
4550fcf
 
a49d644
 
4550fcf
a49d644
 
 
 
 
4550fcf
 
a49d644
4550fcf
 
 
 
 
 
 
 
a49d644
4550fcf
a49d644
 
c28e15d
 
 
 
 
 
 
 
a49d644
26fdb35
 
 
 
 
 
 
 
 
a49d644
 
 
4550fcf
 
 
c28e15d
 
a49d644
 
 
 
4550fcf
 
 
a49d644
 
 
4550fcf
a49d644
 
 
 
 
 
26fdb35
4550fcf
a49d644
4550fcf
 
a49d644
c28e15d
 
a49d644
 
 
 
 
 
 
 
 
 
 
 
c28e15d
 
 
 
 
 
 
93b60fd
c28e15d
a49d644
4550fcf
 
26fdb35
a49d644
4550fcf
 
 
a49d644
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
import gradio as gr
import numpy as np
import cv2
from PIL import Image, ImageOps, ImageDraw
import os
import torch
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
import hashlib
import re
import urllib.request as urllib2
from loguru import logger

# Set up model and transformations
def get_background_removal_model():
    """Load the BiRefNet background-segmentation model.

    Returns:
        tuple: ``(model, device)`` on success, where ``device`` is ``"cuda"``
        when available and ``"cpu"`` otherwise; ``(None, None)`` if the model
        fails to download or initialise.
    """
    try:
        # Using BiRefNet model for background removal
        model = AutoModelForImageSegmentation.from_pretrained(
            "ZhengPeng7/BiRefNet", trust_remote_code=True
        )
        # Use CPU if CUDA is not available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model, device
    except Exception as e:
        # Report through the module-wide loguru logger for consistency with
        # the rest of the file (was a bare print()).
        logger.error(f"Error loading background removal model: {e}")
        return None, None

# Set up image transformation
# Preprocessing pipeline applied before inference: resize to the model's
# fixed 1024x1024 input, convert to a CHW float tensor, and normalize with
# the standard ImageNet mean/std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Cache for storing background removal results
# Keyed by get_image_hash(); grows unbounded for the lifetime of the process.
bg_removal_cache = {}

def get_image_hash(image):
    """Build a cache key for *image*: an MD5 digest of its raw pixel bytes
    combined with its width and height (dimensions disambiguate images whose
    byte streams happen to collide). Returns None for a None image."""
    if image is None:
        return None

    # Digest the raw pixel data, then tag on the dimensions for uniqueness.
    digest = hashlib.md5(image.tobytes()).hexdigest()
    return f"{digest}_{image.width}_{image.height}"

def remove_background(image, model_data):
    """Remove the background from *image* using the BiRefNet model.

    Args:
        image: input PIL image (any mode; converted to RGB internally).
        model_data: ``(model, device)`` tuple from
            get_background_removal_model().

    Returns:
        tuple: ``(rgba_image, mask_array)`` where the image carries the
        predicted matte as its alpha channel and the mask is a numpy array,
        or ``(None, None)`` when the model is unavailable or inference fails.
        Successful results are memoised in ``bg_removal_cache``.
    """
    if model_data[0] is None:
        return None, None

    # Generate a hash for the image to use as cache key
    img_hash = get_image_hash(image)

    # Check if result is already in cache
    if img_hash in bg_removal_cache:
        logger.info("Using cached background removal result")
        return bg_removal_cache[img_hash]

    model, device = model_data

    try:
        logger.info("Starting background removal process")
        # Convert image to RGB if needed
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Store original size for later resizing
        image_size = image.size

        # Apply transformations and move to device
        input_images = transform_image(image).unsqueeze(0).to(device)

        # Run prediction
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()
            pred = preds[0].squeeze()
            # Convert prediction to PIL image
            pred_pil = transforms.ToPILImage()(pred)
            # Resize mask back to original image size
            mask = pred_pil.resize(image_size)
        # Create a copy of the original image and apply alpha channel
        result_image = image.copy()
        result_image.putalpha(mask)

        # Cache the result
        result = (result_image, np.array(mask))
        bg_removal_cache[img_hash] = result

        logger.info("Background removal process completed")
        return result

    except Exception as e:
        # Best-effort: callers treat (None, None) as "removal failed" and
        # fall back to the original image.
        logger.error(f"Error during background removal: {e}")
        return None, None

def parse_color(color_str):
    """Parse different color formats including rgba strings.

    Accepts:
        * RGB/RGBA tuples (alpha 255 is appended when missing)
        * hex strings: #RGB (shorthand), #RRGGBB, #RRGGBBAA
        * CSS-style "rgb(...)" / "rgba(...)" strings (Gradio color picker)
        * named colors, which are returned unchanged for PIL to resolve

    Returns:
        An ``(r, g, b, a)`` tuple, the original named-color string, or
        opaque white ``(255, 255, 255, 255)`` as the fallback for anything
        unrecognised.
    """
    if isinstance(color_str, tuple):
        # If it's already a tuple, make sure it has alpha
        if len(color_str) == 3:
            return color_str + (255,)
        return color_str

    if isinstance(color_str, str):
        # Handle hex color formats
        if color_str.startswith("#"):
            hex_digits = color_str[1:]
            if len(hex_digits) == 3:  # #RGB shorthand -> duplicate each digit
                hex_digits = "".join(c * 2 for c in hex_digits)
            if len(hex_digits) == 6:  # RRGGBB
                r, g, b = (int(hex_digits[i:i + 2], 16) for i in (0, 2, 4))
                return (r, g, b, 255)
            if len(hex_digits) == 8:  # RRGGBBAA
                r, g, b, a = (
                    int(hex_digits[i:i + 2], 16) for i in (0, 2, 4, 6)
                )
                return (r, g, b, a)
            # Fallback to white if format is unexpected
            return (255, 255, 255, 255)

        # Handle rgba() format from Gradio color picker
        rgba_match = re.match(r"rgba?\(([^)]+)\)", color_str)
        if rgba_match:
            values = [float(x.strip()) for x in rgba_match.group(1).split(",")]
            # Clamp to the valid 0-255 byte range (the original only capped
            # the upper bound).
            r = max(0, min(255, int(values[0])))
            g = max(0, min(255, int(values[1])))
            b = max(0, min(255, int(values[2])))

            # Alpha arrives as a 0.0-1.0 float when present
            a = 255
            if len(values) > 3:
                a = max(0, min(255, int(values[3] * 255)))

            return (r, g, b, a)

        # For named colors, return as is for PIL to handle
        return color_str

    # Default fallback
    return (255, 255, 255, 255)  # White

def add_person_border(image, mask, border_size, border_color="white"):
    """Add a border around the person based on the segmentation mask.

    Args:
        image: RGBA PIL image of the cut-out person.
        mask: segmentation mask (PIL image or array-like) of the person.
        border_size: border thickness in pixels; 0 returns *image* unchanged.
        border_color: any format accepted by parse_color() (default white).

    Returns:
        A new RGBA PIL image with the colored border behind the person.
    """
    if border_size == 0:
        return image

    # Convert mask to binary (low threshold of 4 keeps faint matte edges)
    binary_mask = (np.array(mask) > 4).astype(np.uint8) * 255

    # Dilate the mask to create the border
    kernel = np.ones((border_size * 2 + 1, border_size * 2 + 1), np.uint8)
    dilated_mask = cv2.dilate(binary_mask, kernel, iterations=1)

    # Create border mask (includes both the person area and border area)
    border_mask_pil = Image.fromarray(dilated_mask)

    # Create an image filled with the requested border color.
    # BUG FIX: honour the border_color argument instead of hard-coding
    # "white" (the default keeps existing callers' behavior unchanged).
    border_color_rgba = parse_color(border_color)
    border_img = Image.new("RGBA", image.size, color=border_color_rgba)

    # Create transparent image for result
    result = Image.new("RGBA", image.size, (0, 0, 0, 0))

    # First paste the border shape (which covers both border and person area)
    result.paste(border_img, (0, 0), border_mask_pil)

    # Then paste the original image on top, but only the non-transparent parts
    # This will show the original person on top of the border-colored area
    result.paste(image, (0, 0), Image.fromarray(binary_mask))

    return result

def detect_face(image):
    """Detect the largest face in the image and return its bounding box
    as (x, y, w, h), or None when no face is found."""
    logger.info("Starting face detection")
    # PIL (RGB) -> OpenCV (BGR) pixel array
    img_cv = np.array(image.convert("RGB"))[:, :, ::-1].copy()

    # Load OpenCV's bundled frontal-face Haar cascade
    cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )

    # Haar cascades operate on grayscale input
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)

    # Detect faces (scaleFactor=1.1, minNeighbors=4)
    faces = cascade.detectMultiScale(gray, 1.1, 4)

    if len(faces) == 0:
        logger.warning("No faces detected")
        return None

    # Keep the detection with the largest bounding-box area
    best = max(faces, key=lambda box: box[2] * box[3])
    largest_face = tuple(best)

    logger.info(f"Largest face detected at: {largest_face}")
    return largest_face

def center_portrait(portrait, face_box, target_width, target_height, zoom_level=1.0):
    """Center the portrait based on face position and crop to avoid blurriness.

    Args:
        portrait: RGBA PIL image to place.
        face_box: (x, y, w, h) face bounding box, or None.
        target_width, target_height: output canvas size in pixels.
        zoom_level: >1.0 zooms in by shrinking the crop window.

    Returns:
        tuple: (RGBA image of size target_width x target_height,
        (offset_x, offset_y) where the crop was pasted).
    """
    if face_box is None:
        # No face detected: fall back to the image centre so the crop is
        # actually centred and zoom_level is still honoured.
        # BUG FIX: the original cropped the top-left corner here and
        # ignored the zoom level, contradicting its own comment.
        face_center_x = portrait.width // 2
        face_center_y = portrait.height // 2
    else:
        x, y, w, h = face_box

        # Calculate face center
        face_center_x = x + w // 2
        face_center_y = y + h // 2

    # Calculate crop box dimensions (smaller box => stronger zoom)
    crop_width = int(target_width / zoom_level)
    crop_height = int(target_height / zoom_level)

    # Ensure the crop box stays within the image bounds
    left = max(0, face_center_x - crop_width // 2)
    top = max(0, face_center_y - crop_height // 2)
    right = min(portrait.width, left + crop_width)
    bottom = min(portrait.height, top + crop_height)

    # Adjust left and top if the crop box is smaller than the target dimensions
    left = max(0, right - crop_width)
    top = max(0, bottom - crop_height)

    # Crop the image
    cropped_img = portrait.crop((left, top, right, bottom))

    # Center the cropped image on a transparent canvas
    centered_img = Image.new("RGBA", (target_width, target_height), (0, 0, 0, 0))
    offset_x = (target_width - cropped_img.width) // 2
    offset_y = (target_height - cropped_img.height) // 2
    centered_img.paste(cropped_img, (offset_x, offset_y), cropped_img)

    return centered_img, (offset_x, offset_y)

def process_portrait(
    input_image, border_size=10, bg_color="#0000FF", zoom_level=1.0, erode_size=5, circular_overlay=False
):
    """Full avatar pipeline: remove the background, add a white border
    around the person, place the result on a colored background centered on
    the detected face, crop square, and optionally apply a circular mask.

    Args:
        input_image: source PIL image, or None (returns None).
        border_size: white border thickness in pixels.
        bg_color: background color in any format parse_color() accepts.
        zoom_level: >1.0 zooms in on the face.
        erode_size: pixels to shrink the matte by, trimming edge halos.
        circular_overlay: when True, alpha-mask the result to a circle.

    Returns:
        The processed PIL image, or the unmodified input if background
        removal fails, or None for a None input.
    """
    if input_image is None:
        return None

    # Global model instance to avoid reloading
    global model_instance
    if "model_instance" not in globals():
        logger.info("Loading background removal model...")
        model_instance = get_background_removal_model()

    logger.info("Processing image...")
    result = remove_background(input_image, model_instance)
    if result[0] is None:
        logger.warning("Failed to remove background, returning original image")
        return input_image

    person_img, mask = result

    # Detect face before any transformations
    face_box = detect_face(input_image)
    if face_box:
        logger.info(f"Face detected at: {face_box}")
    else:
        logger.warning("No face detected, will center the entire portrait")

    # Shrink (erode) the mask by erode_size pixels to trim halo artifacts
    # at the matte edge
    expanded_mask = cv2.erode(
        np.array(mask), np.ones((erode_size, erode_size), np.uint8), iterations=1
    )
    expanded_mask_pil = Image.fromarray(expanded_mask)

    mask = expanded_mask_pil

    logger.info("Adding white border...")
    # Add white border only around the person
    bordered_img = add_person_border(person_img, mask, border_size, "white")

    logger.info(f"Creating colored background with color: {bg_color}")
    # Parse the background color
    bg_color_rgba = parse_color(bg_color)

    # Create colored background
    width, height = bordered_img.size
    bg_image = Image.new("RGBA", (width, height), color=bg_color_rgba)

    # Center the portrait based on face location and apply zoom
    logger.info(f"Applying zoom level: {zoom_level}")
    centered_portrait, offset = center_portrait(
        bordered_img, face_box, width, height, zoom_level
    )

    # Create the final composite
    final_image = Image.alpha_composite(bg_image, centered_portrait)

    # Crop the final image to the target dimensions
    crop_width = int(width / zoom_level)
    crop_height = int(height / zoom_level)
    left = (width - crop_width) // 2
    top = (height - crop_height) // 2
    right = left + crop_width
    bottom = top + crop_height
    final_image = final_image.crop((left, top, right, bottom))

    # Convert back to RGB for display
    final_image = final_image.convert("RGB")

    # Ensure the final image is square (center-crop to the shorter side)
    width, height = final_image.size
    square_size = min(width, height)
    left = (width - square_size) // 2
    top = (height - square_size) // 2
    right = left + square_size
    bottom = top + square_size
    final_image = final_image.crop((left, top, right, bottom))

    if circular_overlay:
        # Create a circular mask
        mask = Image.new("L", (square_size, square_size), 0)
        draw = ImageDraw.Draw(mask)
        draw.ellipse((0, 0, square_size, square_size), fill=255)

        # Apply the circular mask to the final image
        final_image.putalpha(mask)

    logger.info(
        f"Processing complete (portrait offset by {offset}, zoom: {zoom_level})"
    )
    return final_image

# Create Gradio interface
with gr.Blocks(title="Cool Avatar Creator") as app:
    gr.Markdown("# Cool Avatar Creator")
    gr.Markdown(
        "Upload a portrait image to remove the background, add a white border, and place on a colored background."
    )

    with gr.Row():
        # Left column: input image and all processing controls
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            border_slider = gr.Slider(
                minimum=0, maximum=50, value=10, step=1, label="Border Size (pixels)"
            )
            bg_color = gr.ColorPicker(value="#fdc915", label="Background Color")
            zoom_slider = gr.Slider(
                minimum=0.5, maximum=4.0, value=1.2, step=0.1, label="Zoom Level"
            )
            erode_slider = gr.Slider(
                minimum=1, maximum=30, value=15, step=1, label="Erode Size"
            )
            circular_overlay_toggle = gr.Checkbox(label="Enable Circular Overlay")
            process_button = gr.Button("Process Image")

        # Right column: processed output
        with gr.Column():
            output_image = gr.Image(type="pil", label="Processed Image")

    # Add example images
    # Each example supplies only the first four inputs (image, border,
    # color, zoom); erode size and the circular overlay fall back to
    # process_portrait's defaults.
    examples = [
        [
            "https://brobible.com/wp-content/uploads/2019/11/istock-153696622.jpg",
            26,
            "#fdc915",
            1.85,
        ],
        [
            "https://as1.ftcdn.net/jpg/00/26/35/66/1000_F_26356634_6hC5kmcoRfysvavKTZdDQwsk5CMZwwDs.jpg",
            23,
            "#00FF00",
            1.4,
        ],
        ["https://i.imgflip.com/1freth.jpg?a483936", 29, "#FF0000", 1.4],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, border_slider, bg_color, zoom_slider],
        outputs=output_image,
        fn=process_portrait,
        cache_examples=False
    )

    # Wire the button to the full pipeline with all six inputs
    process_button.click(
        fn=process_portrait,
        inputs=[input_image, border_slider, bg_color, zoom_slider, erode_slider, circular_overlay_toggle],
        outputs=output_image,
    )

if __name__ == "__main__":
    app.launch(share=False)  # Share=True creates a public link