Seniordev22 committed on
Commit
0141af9
·
verified ·
1 Parent(s): 31c788a

Update bald_processor.py

Browse files
Files changed (1) hide show
  1. bald_processor.py +29 -54
bald_processor.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import cv2
2
  import torch
3
  import numpy as np
@@ -23,88 +24,65 @@ except Exception as e:
23
  raise RuntimeError("SegFormer model load failed!")
24
 
25
  hair_class_id = 13
26
- ear_class_ids = [8, 9] # Corrected: 8 for left_ear, 9 for right_ear
27
- skin_class_id = 1 # Added for color correction reference
28
 
29
  def make_realistic_bald(input_image: Image.Image) -> Image.Image:
30
  if input_image is None:
31
  raise ValueError("No input image provided!")
32
-
33
  try:
34
  orig_w, orig_h = input_image.size
35
  original_np = np.array(input_image)
36
  original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
37
-
38
  MAX_DIM = 2048
39
  scale_factor = 1.0
40
  working_np = original_np.copy()
41
  working_bgr = original_bgr.copy()
42
  working_h, working_w = orig_h, orig_w
43
-
44
  if max(orig_w, orig_h) > MAX_DIM:
45
  scale_factor = MAX_DIM / max(orig_w, orig_h)
46
- working_w, working_h = int(orig_w * scale_factor), int(orig_h * scale_factor)
47
  working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
48
  working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
49
-
50
  # Segmentation
51
  pil_working = Image.fromarray(working_np)
52
  inputs = processor(images=pil_working, return_tensors="pt").to(device)
53
  with torch.no_grad():
54
  outputs = model(**inputs)
55
  logits = outputs.logits
56
- upsampled = torch.nn.functional.interpolate(
57
- logits, size=(working_h, working_w), mode="bilinear", align_corners=False
58
- )
59
- parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
60
-
 
 
61
  # Hair mask
62
  hair_mask = (parsing == hair_class_id).astype(np.uint8)
63
  ears_mask = np.zeros_like(hair_mask)
64
  for cls in ear_class_ids:
65
  ears_mask[parsing == cls] = 1
66
- hair_mask[ears_mask == 1] = 0
67
-
68
  # Morphology to clean
69
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
70
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
71
  hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
72
- hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5, 5), 0) > 0.28).astype(np.uint8)
73
-
74
  hair_pixels = np.sum(hair_mask)
75
  if hair_pixels < 50:
76
  raise ValueError("NO_HAIR_DETECTED")
77
-
78
  # Inpainting (no extra noise, no blur)
79
  radius = 15 if hair_pixels > 220000 else 10
80
  flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
81
- inpainted_bgr = cv2.inpaint(working_bgr, hair_mask * 255, inpaintRadius=radius, flags=flag)
82
-
83
- # Conditional color correction for large hair areas
84
- if hair_pixels > 220000:
85
- # Skin mask for reference
86
- skin_mask = (parsing == skin_class_id).astype(np.uint8)
87
-
88
- # Reference mask: skin excluding hair area
89
- ref_mask = skin_mask.copy()
90
- ref_mask[hair_mask == 1] = 0
91
- ref_mask = cv2.dilate(ref_mask, np.ones((5, 5), np.uint8), iterations=1) # Slight expand
92
-
93
- # Mean colors (BGR)
94
- ref_mean = cv2.mean(working_bgr, mask=ref_mask * 255)[:3]
95
- inpainted_mean = cv2.mean(inpainted_bgr, mask=hair_mask * 255)[:3]
96
-
97
- # Color difference
98
- color_diff = np.array(ref_mean) - np.array(inpainted_mean)
99
-
100
- # Adjust inpainted area
101
- hair_mask_3ch = np.repeat(hair_mask[:, :, np.newaxis], 3, axis=2)
102
- inpainted_bgr[hair_mask_3ch == 1] = np.clip(
103
- inpainted_bgr[hair_mask_3ch == 1] + color_diff, 0, 255
104
- ).astype(np.uint8)
105
-
106
  inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
107
-
108
  # Upscale bald area if needed
109
  if scale_factor < 1.0:
110
  bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
@@ -112,18 +90,15 @@ def make_realistic_bald(input_image: Image.Image) -> Image.Image:
112
  else:
113
  bald_up = inpainted_rgb
114
  mask_up = hair_mask
115
-
116
- # Soft blending for edges
117
- mask_up_float = cv2.GaussianBlur(mask_up.astype(np.float32) * 255, (21, 21), 0) / 255.0
118
-
119
- # Composite with alpha blend
120
- result = (1 - mask_up_float[..., None]) * original_np + mask_up_float[..., None] * bald_up
121
- result = np.clip(result, 0, 255).astype(np.uint8)
122
-
123
  return Image.fromarray(result)
124
-
125
  except UnidentifiedImageError:
126
  raise ValueError("Invalid image format or corrupt image!")
127
  except Exception as e:
128
  logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
129
- raise RuntimeError(f"Bald processing failed: {str(e)}")
 
1
+ # bald_processor_clean.py
2
  import cv2
3
  import torch
4
  import numpy as np
 
24
  raise RuntimeError("SegFormer model load failed!")
25
 
26
  hair_class_id = 13
27
+ ear_class_ids = [7, 8]
 
28
 
29
  def make_realistic_bald(input_image: Image.Image) -> Image.Image:
30
  if input_image is None:
31
  raise ValueError("No input image provided!")
32
+
33
  try:
34
  orig_w, orig_h = input_image.size
35
  original_np = np.array(input_image)
36
  original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
37
+
38
  MAX_DIM = 2048
39
  scale_factor = 1.0
40
  working_np = original_np.copy()
41
  working_bgr = original_bgr.copy()
42
  working_h, working_w = orig_h, orig_w
43
+
44
  if max(orig_w, orig_h) > MAX_DIM:
45
  scale_factor = MAX_DIM / max(orig_w, orig_h)
46
+ working_w, working_h = int(orig_w*scale_factor), int(orig_h*scale_factor)
47
  working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
48
  working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
49
+
50
  # Segmentation
51
  pil_working = Image.fromarray(working_np)
52
  inputs = processor(images=pil_working, return_tensors="pt").to(device)
53
  with torch.no_grad():
54
  outputs = model(**inputs)
55
  logits = outputs.logits
56
+
57
+ upsampled = torch.nn.functional.interpolate(
58
+ logits, size=(working_h, working_w),
59
+ mode="bilinear", align_corners=False
60
+ )
61
+ parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
62
+
63
  # Hair mask
64
  hair_mask = (parsing == hair_class_id).astype(np.uint8)
65
  ears_mask = np.zeros_like(hair_mask)
66
  for cls in ear_class_ids:
67
  ears_mask[parsing == cls] = 1
68
+ hair_mask[ears_mask==1] = 0
69
+
70
  # Morphology to clean
71
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13,13))
72
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
73
  hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
74
+ hair_mask = (cv2.GaussianBlur(hair_mask.astype(np.float32), (5,5), 0) > 0.28).astype(np.uint8)
75
+
76
  hair_pixels = np.sum(hair_mask)
77
  if hair_pixels < 50:
78
  raise ValueError("NO_HAIR_DETECTED")
79
+
80
  # Inpainting (no extra noise, no blur)
81
  radius = 15 if hair_pixels > 220000 else 10
82
  flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
83
+ inpainted_bgr = cv2.inpaint(working_bgr, hair_mask*255, inpaintRadius=radius, flags=flag)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
85
+
86
  # Upscale bald area if needed
87
  if scale_factor < 1.0:
88
  bald_up = cv2.resize(inpainted_rgb, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
 
90
  else:
91
  bald_up = inpainted_rgb
92
  mask_up = hair_mask
93
+
94
+ # Composite only bald area, rest untouched
95
+ result = original_np.copy()
96
+ result[mask_up==1] = bald_up[mask_up==1]
97
+
 
 
 
98
  return Image.fromarray(result)
99
+
100
  except UnidentifiedImageError:
101
  raise ValueError("Invalid image format or corrupt image!")
102
  except Exception as e:
103
  logger.error(f"Bald processing failed: {str(e)}", exc_info=True)
104
+ raise RuntimeError(f"Bald processing failed: {str(e)}")