import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
import logging
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
print("Loading SegFormer face-parsing model...")
processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
model.to(device)
model.eval()
logger.info("Model loaded!")
hair_class_id = 13
ear_class_ids = [7, 8]
def make_realistic_bald(input_image: Image.Image) -> tuple[Image.Image, Image.Image, Image.Image]:
    """Segment the hair, inpaint it away, and blend for a natural bald look.

    Returns (bald result, before/after comparison, hair-mask overlay).
    """
try:
orig_w, orig_h = input_image.size
original_np = np.array(input_image)
original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
logger.info(f"Processing: {orig_w}x{orig_h}")
MAX_PROCESS_DIM = 2048
scale_factor = 1.0
working_np = original_np
working_bgr = original_bgr
working_h, working_w = orig_h, orig_w
if max(orig_w, orig_h) > MAX_PROCESS_DIM:
logger.info(f"Downscaling to max {MAX_PROCESS_DIM}px")
scale_factor = MAX_PROCESS_DIM / max(orig_w, orig_h)
working_w = int(orig_w * scale_factor)
working_h = int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
pil_working = Image.fromarray(working_np)
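        # Run the SegFormer parser, then upsample the low-resolution logits back
        # to the working size before taking the per-pixel argmax.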
inputs = processor(images=pil_working, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
upsampled_logits = torch.nn.functional.interpolate(
logits, size=(working_h, working_w), mode="bilinear", align_corners=False
)
parsing = upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()
hair_mask = (parsing == hair_class_id).astype(np.uint8)
ears_mask = np.zeros_like(hair_mask)
for cls in ear_class_ids:
ears_mask[parsing == cls] = 1
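        # Ear protection: dilate the ear regions (plus a limited band just above
        # them) so inpainting never eats into the ears, but drop that protection
        # wherever clear hair sits well above the ear line.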
ear_y, ear_x = np.where(ears_mask)
if len(ear_y) > 0:
ear_top_y = ear_y.min()
ear_height = ear_y.max() - ear_top_y + 1
kernel_v = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 30))
ears_protected = cv2.dilate(ears_mask, kernel_v, iterations=2)
top_margin = max(8, int(ear_height * 0.12))
top_start = max(0, ear_top_y - top_margin)
ear_x_min, ear_x_max = ear_x.min(), ear_x.max()
ear_width = ear_x_max - ear_x_min + 1
x_margin = int(ear_width * 0.35)
protected_left = max(0, ear_x_min - x_margin)
protected_right = min(working_w, ear_x_max + x_margin)
limited_top_mask = np.zeros_like(ears_mask)
limited_top_mask[top_start:ear_top_y + 8, protected_left:protected_right] = 1
kernel_h = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (17, 5))
limited_top_mask = cv2.dilate(limited_top_mask, kernel_h, iterations=1)
ears_protected = np.logical_or(ears_protected, limited_top_mask).astype(np.uint8)
hair_above_ears = np.zeros_like(hair_mask)
above_ear_line = max(0, ear_top_y - int(ear_height * 0.65))
hair_above_ears[:above_ear_line, :] = hair_mask[:above_ear_line, :]
ears_protected[hair_above_ears == 1] = 0
else:
ears_protected = np.zeros_like(hair_mask)
hair_mask_final = hair_mask.copy()
hair_mask_final[ears_protected == 1] = 0
if hair_mask[:int(working_h * 0.25), :].sum() > 60:
hair_mask_final[:int(working_h * 0.25), :] = np.maximum(
hair_mask_final[:int(working_h * 0.25), :], hair_mask[:int(working_h * 0.25), :]
)
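        # Mask cleanup: close small holes, dilate, then feather the edge by
        # blurring and re-thresholding so the inpaint boundary stays soft.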
kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
hair_mask_final = cv2.morphologyEx(hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=2)
hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)
blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (9, 9), 3)
hair_mask_final = (blurred > 0.28).astype(np.uint8)
kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
hair_mask_final = cv2.dilate(hair_mask_final, kernel_edge, iterations=1)
hair_pixels = np.sum(hair_mask_final)
logger.info(f"Hair pixels (resized): {hair_pixels:,}")
final_mask = hair_mask_final.copy()
use_extended_mask = False
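        # Heuristic for big hairstyles: a very large hair mask gets a generous
        # dilation plus the whole upper part of the frame, and on dark images an
        # HSV dark-pixel mask too, so stray strands are removed as well.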
if hair_pixels > 380000:
logger.info("Large hair β†’ extended mask")
use_extended_mask = True
big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
upper = np.zeros_like(hair_mask_final)
upper_end = int(working_h * 0.48)
upper[:upper_end, :] = 1
extended = np.logical_or(extended, upper).astype(np.uint8)
extended[ears_protected == 1] = 0
if np.mean(working_np) < 110:
hsv = cv2.cvtColor(working_np, cv2.COLOR_RGB2HSV)
dark_lower = np.array([0, 0, 0])
dark_upper = np.array([180, 70, 90])
dark_mask = cv2.inRange(hsv, dark_lower, dark_upper)
extended = np.logical_or(extended, (dark_mask > 127)).astype(np.uint8)
extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
extended[int(working_h * 0.75):, :] = 0
final_mask = extended
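        # Pick the inpainting method by mask size: TELEA with a larger radius
        # for big areas, Navier-Stokes with a smaller radius for light trims.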
if use_extended_mask or hair_pixels > 420000:
radius, flag = 18, cv2.INPAINT_TELEA
elif hair_pixels > 220000:
radius, flag = 15, cv2.INPAINT_TELEA
else:
radius, flag = 10, cv2.INPAINT_NS
logger.info(f"Inpainting radius={radius}")
inpainted_bgr = cv2.inpaint(working_bgr, final_mask * 255, inpaintRadius=radius, flags=flag)
inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
result_small = working_np.copy()
result_small[final_mask == 1] = inpainted_rgb[final_mask == 1]
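        # Skin tone correction: sample median colors from two central face
        # regions and shift the inpainted pixels toward that tone, scaling the
        # strength with overall brightness.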
if use_extended_mask or hair_pixels > 280000:
logger.info("Skin color correction")
regions = [(0.18, 0.30, 0.34, 0.66), (0.32, 0.47, 0.35, 0.65)]
colors = []
for y1r, y2r, x1r, x2r in regions:
y1, y2 = int(working_h * y1r), int(working_h * y2r)
x1, x2 = int(working_w * x1r), int(working_w * x2r)
if y2 > y1 + 40 and x2 > x1 + 80:
crop = working_np[y1:y2, x1:x2]
if crop.size > 0:
colors.append(np.median(crop, axis=(0,1)).astype(np.float32))
if colors:
target_color = np.mean(colors, axis=0)
brightness = np.mean(target_color)
strength = 0.82 if brightness > 145 else 0.62 if brightness < 85 else 0.74
bald_area = result_small[final_mask == 1].astype(np.float32)
if len(bald_area) > 200:
current_mean = bald_area.mean(axis=0)
diff = target_color - current_mean
corrected = np.clip(bald_area + diff * strength, 0, 255).astype(np.uint8)
result_small[final_mask == 1] = corrected
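        # A light blur pass over the bald area hides inpainting texture seams.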
if hair_pixels > 90000 or use_extended_mask:
blurred_bald = cv2.GaussianBlur(result_small, (5, 5), 0.8)
result_small[final_mask == 1] = cv2.addWeighted(
result_small[final_mask == 1], 0.65, blurred_bald[final_mask == 1], 0.35, 0
)
if scale_factor < 1.0:
logger.info("Upscaling to original size")
result = cv2.resize(result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
else:
result = result_small
result_pil = Image.fromarray(result)
comparison = np.hstack((original_np, result))
comparison_pil = Image.fromarray(comparison)
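        # Build the red mask overlay at the original resolution for the UI.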
        final_mask_big = cv2.resize(final_mask.astype(np.uint8) * 255, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST) > 127
mask_vis = np.zeros_like(original_np)
mask_vis[final_mask_big] = [255, 70, 70]
mask_overlay = cv2.addWeighted(original_np, 0.78, mask_vis, 0.22, 0)
mask_pil = Image.fromarray(mask_overlay)
return result_pil, comparison_pil, mask_pil
except Exception as e:
logger.error(f"Error: {str(e)}", exc_info=True)
        raise gr.Error(f"Processing failed: {str(e)}. Try a smaller image.")
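# Gradio UI: photo in; bald result, before/after comparison, and mask overlay out.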
with gr.Blocks(title="Make Me Bald 🧑‍🦲", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Realistic Bald Maker 🔥")
    gr.Markdown("Upload a face photo → get a bald version with natural skin blending. Ears protected, no weird halos!")
with gr.Row():
input_img = gr.Image(type="pil", label="Your Photo", sources=["upload", "webcam"])
output_bald = gr.Image(label="Bald Version")
with gr.Row():
comparison = gr.Image(label="Before vs After")
mask_overlay = gr.Image(label="Hair Mask Overlay (red = removed area)")
    btn = gr.Button("Make Bald 😎", variant="primary")
btn.click(
fn=make_realistic_bald,
inputs=input_img,
outputs=[output_bald, comparison, mask_overlay],
api_name="make_bald"
)
gr.Examples(
examples=[["example1.jpg"], ["example2.jpg"]], # agar examples folder mein daale to
inputs=input_img,
label="Try these examples"
)
gr.Markdown("""
**Tips:**
    - Best results on clear, front-facing photos.
    - Large images are auto-resized for speed, then upscaled back.
    - If no hair is detected → try another photo.
""")
# NO demo.launch() here - HF Spaces handles it automatically!