Spaces:
Running
Running
Update bald_processor.py
Browse files- bald_processor.py +29 -54
bald_processor.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import cv2
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
|
@@ -23,88 +24,65 @@ except Exception as e:
|
|
| 23 |
raise RuntimeError("SegFormer model load failed!")
|
| 24 |
|
| 25 |
hair_class_id = 13
|
| 26 |
-
ear_class_ids = [
|
| 27 |
-
skin_class_id = 1 # Added for color correction reference
|
| 28 |
|
| 29 |
def make_realistic_bald(input_image: Image.Image) -> Image.Image:
|
| 30 |
if input_image is None:
|
| 31 |
raise ValueError("No input image provided!")
|
| 32 |
-
|
| 33 |
try:
|
| 34 |
orig_w, orig_h = input_image.size
|
| 35 |
original_np = np.array(input_image)
|
| 36 |
original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
|
| 37 |
-
|
| 38 |
MAX_DIM = 2048
|
| 39 |
scale_factor = 1.0
|
| 40 |
working_np = original_np.copy()
|
| 41 |
working_bgr = original_bgr.copy()
|
| 42 |
working_h, working_w = orig_h, orig_w
|
| 43 |
-
|
| 44 |
if max(orig_w, orig_h) > MAX_DIM:
|
| 45 |
scale_factor = MAX_DIM / max(orig_w, orig_h)
|
| 46 |
-
working_w, working_h = int(orig_w
|
| 47 |
working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
|
| 48 |
working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
|
| 49 |
-
|
| 50 |
# Segmentation
|
| 51 |
pil_working = Image.fromarray(working_np)
|
| 52 |
inputs = processor(images=pil_working, return_tensors="pt").to(device)
|
| 53 |
with torch.no_grad():
|
| 54 |
outputs = model(**inputs)
|
| 55 |
logits = outputs.logits
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
| 61 |
# Hair mask
|
| 62 |
hair_mask = (parsing == hair_class_id).astype(np.uint8)
|
| 63 |
ears_mask = np.zeros_like(hair_mask)
|
| 64 |
for cls in ear_class_ids:
|
| 65 |
ears_mask[parsing == cls] = 1
|
| 66 |
-
hair_mask[ears_mask
|
| 67 |
-
|
| 68 |
# Morphology to clean
|
| 69 |
-
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13,
|
| 70 |
hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
|
| 71 |
hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
|
| 72 |
-
hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5,
|
| 73 |
-
|
| 74 |
hair_pixels = np.sum(hair_mask)
|
| 75 |
if hair_pixels < 50:
|
| 76 |
raise ValueError("NO_HAIR_DETECTED")
|
| 77 |
-
|
| 78 |
# Inpainting (no extra noise, no blur)
|
| 79 |
radius = 15 if hair_pixels > 220000 else 10
|
| 80 |
flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
|
| 81 |
-
inpainted_bgr = cv2.inpaint(working_bgr, hair_mask
|
| 82 |
-
|
| 83 |
-
# Conditional color correction for large hair areas
|
| 84 |
-
if hair_pixels > 220000:
|
| 85 |
-
# Skin mask for reference
|
| 86 |
-
skin_mask = (parsing == skin_class_id).astype(np.uint8)
|
| 87 |
-
|
| 88 |
-
# Reference mask: skin excluding hair area
|
| 89 |
-
ref_mask = skin_mask.copy()
|
| 90 |
-
ref_mask[hair_mask == 1] = 0
|
| 91 |
-
ref_mask = cv2.dilate(ref_mask, np.ones((5, 5), np.uint8), iterations=1) # Slight expand
|
| 92 |
-
|
| 93 |
-
# Mean colors (BGR)
|
| 94 |
-
ref_mean = cv2.mean(working_bgr, mask=ref_mask * 255)[:3]
|
| 95 |
-
inpainted_mean = cv2.mean(inpainted_bgr, mask=hair_mask * 255)[:3]
|
| 96 |
-
|
| 97 |
-
# Color difference
|
| 98 |
-
color_diff = np.array(ref_mean) - np.array(inpainted_mean)
|
| 99 |
-
|
| 100 |
-
# Adjust inpainted area
|
| 101 |
-
hair_mask_3ch = np.repeat(hair_mask[:, :, np.newaxis], 3, axis=2)
|
| 102 |
-
inpainted_bgr[hair_mask_3ch == 1] = np.clip(
|
| 103 |
-
inpainted_bgr[hair_mask_3ch == 1] + color_diff, 0, 255
|
| 104 |
-
).astype(np.uint8)
|
| 105 |
-
|
| 106 |
inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
|
| 107 |
-
|
| 108 |
# Upscale bald area if needed
|
| 109 |
if scale_factor < 1.0:
|
| 110 |
bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
|
|
@@ -112,18 +90,15 @@ def make_realistic_bald(input_image: Image.Image) -> Image.Image:
|
|
| 112 |
else:
|
| 113 |
bald_up = inpainted_rgb
|
| 114 |
mask_up = hair_mask
|
| 115 |
-
|
| 116 |
-
#
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
result = (1 - mask_up_float[..., None]) * original_np + mask_up_float[..., None] * bald_up
|
| 121 |
-
result = np.clip(result, 0, 255).astype(np.uint8)
|
| 122 |
-
|
| 123 |
return Image.fromarray(result)
|
| 124 |
-
|
| 125 |
except UnidentifiedImageError:
|
| 126 |
raise ValueError("Invalid image format or corrupt image!")
|
| 127 |
except Exception as e:
|
| 128 |
logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
|
| 129 |
-
raise RuntimeError(f"Bald processing failed: {str(e)}")
|
|
|
|
| 1 |
+
# bald_processor_clean.py
|
| 2 |
import cv2
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
| 24 |
raise RuntimeError("SegFormer model load failed!")
|
# Face-parsing label ids produced by the SegFormer segmentation model.
# These index the per-pixel `parsing` map computed in make_realistic_bald.
hair_class_id = 13       # "hair" class — the region to remove/inpaint
ear_class_ids = [7, 8]   # left/right ear — subtracted from the hair mask so ears survive
|
| 29 |
def make_realistic_bald(input_image: Image.Image) -> Image.Image:
    """Remove the hair from a portrait and return a realistically bald version.

    Pipeline: (1) optionally downscale to MAX_DIM for segmentation speed,
    (2) run the SegFormer face-parsing model to get a per-pixel class map,
    (3) build a cleaned hair mask (ears excluded, morphology-smoothed),
    (4) inpaint the hair region, (5) upscale the inpainted patch back to the
    original resolution and composite it over the untouched original.

    Args:
        input_image: source portrait as a PIL Image (RGB assumed — TODO confirm
            callers never pass RGBA/grayscale).

    Returns:
        A new PIL Image at the original resolution with the hair area inpainted.

    Raises:
        ValueError: if ``input_image`` is None or the image is invalid/corrupt.
        RuntimeError: wrapping any other failure during processing.
            NOTE(review): a ValueError("NO_HAIR_DETECTED") raised inside the
            try block is caught by the blanket ``except Exception`` below and
            re-surfaced as RuntimeError — confirm whether callers expect to
            catch ValueError for the no-hair case.
    """
    if input_image is None:
        raise ValueError("No input image provided!")

    try:
        orig_w, orig_h = input_image.size
        original_np = np.array(input_image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Cap the working resolution: segmentation + inpainting on very large
        # images is slow; the result is upscaled back at the end.
        MAX_DIM = 2048
        scale_factor = 1.0
        working_np = original_np.copy()
        working_bgr = original_bgr.copy()
        working_h, working_w = orig_h, orig_w

        if max(orig_w, orig_h) > MAX_DIM:
            scale_factor = MAX_DIM / max(orig_w, orig_h)
            working_w, working_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Segmentation: SegFormer face parsing on the working-size image.
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        # Logits come out at the model's internal resolution; upsample to the
        # working size before taking the per-pixel argmax class map.
        upsampled = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w),
            mode="bilinear", align_corners=False
        )
        parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()

        # Hair mask: hair class minus the ear classes, so ears are preserved.
        hair_mask = (parsing == hair_class_id).astype(np.uint8)
        ears_mask = np.zeros_like(hair_mask)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1
        hair_mask[ears_mask == 1] = 0

        # Morphology to clean: close small holes, dilate to cover hair edges,
        # then blur + threshold (0.28) to soften the mask boundary.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
        hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
        hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)

        hair_pixels = np.sum(hair_mask)
        if hair_pixels < 50:
            raise ValueError("NO_HAIR_DETECTED")

        # Inpainting (no extra noise, no blur). Large hair areas get a bigger
        # radius and TELEA; smaller ones use Navier-Stokes.
        radius = 15 if hair_pixels > 220000 else 10
        flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
        inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag)

        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # Upscale the inpainted result back to the original resolution.
        if scale_factor < 1.0:
            bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
            # FIX: the mask must be upscaled alongside the image, otherwise the
            # boolean index below is working-resolution against a
            # full-resolution array and raises a shape mismatch. NEAREST keeps
            # the mask strictly binary (LANCZOS would introduce in-between values).
            mask_up = cv2.resize(hair_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
        else:
            bald_up = inpainted_rgb
            mask_up = hair_mask

        # Composite only the bald area; every other pixel stays byte-identical
        # to the original input.
        result = original_np.copy()
        result[mask_up == 1] = bald_up[mask_up == 1]

        return Image.fromarray(result)

    except UnidentifiedImageError:
        raise ValueError("Invalid image format or corrupt image!")
    except Exception as e:
        logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
        raise RuntimeError(f"Bald processing failed: {str(e)}")