Seniordev22 committed on
Commit
e8a4bf5
·
verified ·
1 Parent(s): 1bac48a

Update bald_processor.py

Browse files
Files changed (1) hide show
  1. bald_processor.py +21 -27
bald_processor.py CHANGED
@@ -10,11 +10,9 @@ import logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
13
- # ---------------- DEVICE ----------------
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  logger.info(f"Using device: {device}")
16
 
17
- # ---------------- MODEL ----------------
18
  try:
19
  logger.info("Loading SegFormer face-parsing model...")
20
  processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
@@ -26,39 +24,38 @@ except Exception as e:
26
  logger.error(f"Failed to load model: {e}", exc_info=True)
27
  raise RuntimeError("SegFormer model load failed!")
28
 
29
- # ---------------- CLASS IDS ----------------
30
  hair_class_id = 13
31
  ear_class_ids = [7, 8]
32
  skin_class_id = 1
33
  nose_class_id = 2 # fallback
34
 
35
- # ---------------- CORE FUNCTION ----------------
36
  def make_realistic_bald(input_image: Image.Image) -> Image.Image:
37
  """
38
  Takes PIL Image, returns PIL Image bald version.
 
39
  """
40
  if input_image is None:
41
  raise ValueError("No input image provided!")
42
 
43
  try:
44
- # -------- ORIGINAL IMAGE & RESIZE --------
45
  orig_w, orig_h = input_image.size
46
  original_np = np.array(input_image)
47
  original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
48
 
 
49
  MAX_DIM = 2048
50
  scale_factor = 1.0
51
- working_np = original_np
52
- working_bgr = original_bgr
53
  working_h, working_w = orig_h, orig_w
54
 
55
  if max(orig_w, orig_h) > MAX_DIM:
56
  scale_factor = MAX_DIM / max(orig_w, orig_h)
57
  working_w, working_h = int(orig_w*scale_factor), int(orig_h*scale_factor)
58
- working_np = cv2.resize(original_np, (working_w, working_h), cv2.INTER_AREA)
59
  working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
60
 
61
- # -------- SEGMENTATION --------
62
  pil_working = Image.fromarray(working_np)
63
  inputs = processor(images=pil_working, return_tensors="pt").to(device)
64
  with torch.no_grad():
@@ -71,15 +68,15 @@ def make_realistic_bald(input_image: Image.Image) -> Image.Image:
71
  )
72
  parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
73
 
74
- # -------- HAIR & EARS MASK --------
75
  hair_mask = (parsing == hair_class_id).astype(np.uint8)
76
  ears_mask = np.zeros_like(hair_mask)
77
  for cls in ear_class_ids:
78
  ears_mask[parsing == cls] = 1
79
 
80
- hair_mask[ears_mask == 1] = 0
81
 
82
- # Smooth & clean hair mask
83
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13,13))
84
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
85
  hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
@@ -89,30 +86,27 @@ def make_realistic_bald(input_image: Image.Image) -> Image.Image:
89
  if hair_pixels < 50:
90
  raise ValueError("NO_HAIR_DETECTED")
91
 
92
- # -------- INPAINTING --------
93
  radius = 15 if hair_pixels > 220000 else 10
94
  flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
95
  inpainted_bgr = cv2.inpaint(working_bgr, hair_mask*255, inpaintRadius=radius, flags=flag)
96
  inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
97
 
98
- # -------- ADD SUBTLE SKIN TEXTURE --------
99
  noise = np.random.normal(0,12,(working_h, working_w,3)).astype(np.float32)
100
- blended = np.clip(inpainted_rgb + noise*0.7, 0,255).astype(np.uint8)
101
 
102
- # -------- PREPARE RESULT SMALL & FINAL MASK --------
103
- result_small = working_np.copy()
104
- final_mask = hair_mask.copy()
105
- result_small[final_mask == 1] = blended[final_mask == 1]
106
-
107
- # -------- FINAL COMPOSITING (NO BLUR FIX) --------
108
  if scale_factor < 1.0:
109
- # Upscale ONLY bald area
110
- bald_up = cv2.resize(result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
111
- mask_up = cv2.resize(final_mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
112
- result = original_np.copy()
113
- result[mask_up == 1] = bald_up[mask_up == 1]
114
  else:
115
- result = result_small
 
 
 
 
116
 
117
  return Image.fromarray(result)
118
 
 
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
  logger.info(f"Using device: {device}")
15
 
 
16
  try:
17
  logger.info("Loading SegFormer face-parsing model...")
18
  processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
 
24
  logger.error(f"Failed to load model: {e}", exc_info=True)
25
  raise RuntimeError("SegFormer model load failed!")
26
 
 
27
  hair_class_id = 13
28
  ear_class_ids = [7, 8]
29
  skin_class_id = 1
30
  nose_class_id = 2 # fallback
31
 
 
32
  def make_realistic_bald(input_image: Image.Image) -> Image.Image:
33
  """
34
  Takes PIL Image, returns PIL Image bald version.
35
+ Only bald area is modified; rest of image stays sharp.
36
  """
37
  if input_image is None:
38
  raise ValueError("No input image provided!")
39
 
40
  try:
 
41
  orig_w, orig_h = input_image.size
42
  original_np = np.array(input_image)
43
  original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)
44
 
45
+ # ---------------- RESIZE FOR PROCESSING ----------------
46
  MAX_DIM = 2048
47
  scale_factor = 1.0
48
+ working_np = original_np.copy()
49
+ working_bgr = original_bgr.copy()
50
  working_h, working_w = orig_h, orig_w
51
 
52
  if max(orig_w, orig_h) > MAX_DIM:
53
  scale_factor = MAX_DIM / max(orig_w, orig_h)
54
  working_w, working_h = int(orig_w*scale_factor), int(orig_h*scale_factor)
55
+ working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
56
  working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)
57
 
58
+ # ---------------- SEGMENTATION ----------------
59
  pil_working = Image.fromarray(working_np)
60
  inputs = processor(images=pil_working, return_tensors="pt").to(device)
61
  with torch.no_grad():
 
68
  )
69
  parsing = upsampled.argmax(dim=1).squeeze(0).cpu().numpy()
70
 
71
+ # ---------------- HAIR MASK ----------------
72
  hair_mask = (parsing == hair_class_id).astype(np.uint8)
73
  ears_mask = np.zeros_like(hair_mask)
74
  for cls in ear_class_ids:
75
  ears_mask[parsing == cls] = 1
76
 
77
+ hair_mask[ears_mask==1] = 0
78
 
79
+ # Morphology
80
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13,13))
81
  hair_mask = cv2.morphologyEx(hair_mask, cv2.MORPH_CLOSE, kernel, iterations=2)
82
  hair_mask = cv2.dilate(hair_mask, kernel, iterations=1)
 
86
  if hair_pixels < 50:
87
  raise ValueError("NO_HAIR_DETECTED")
88
 
89
+ # ---------------- INPAINT ----------------
90
  radius = 15 if hair_pixels > 220000 else 10
91
  flag = cv2.INPAINT_TELEA if hair_pixels > 220000 else cv2.INPAINT_NS
92
  inpainted_bgr = cv2.inpaint(working_bgr, hair_mask*255, inpaintRadius=radius, flags=flag)
93
  inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)
94
 
95
+ # ---------------- ADD SUBTLE SKIN TEXTURE ----------------
96
  noise = np.random.normal(0,12,(working_h, working_w,3)).astype(np.float32)
97
+ bald_area = np.clip(inpainted_rgb + noise*0.7, 0,255).astype(np.uint8)
98
 
99
+ # ---------------- COMPOSITE BACK ON ORIGINAL IMAGE ----------------
 
 
 
 
 
100
  if scale_factor < 1.0:
101
+ # Upscale bald area mask and content separately
102
+ bald_up = cv2.resize(bald_area, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
103
+ mask_up = cv2.resize(hair_mask.astype(np.uint8), (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
 
 
104
  else:
105
+ bald_up = bald_area
106
+ mask_up = hair_mask
107
+
108
+ result = original_np.copy()
109
+ result[mask_up==1] = bald_up[mask_up==1]
110
 
111
  return Image.fromarray(result)
112