Seniordev22 commited on
Commit
e2eb2c5
·
verified ·
1 Parent(s): 1fe9ffa

Create bald_processor.py

Browse files
Files changed (1) hide show
  1. bald_processor.py +241 -0
bald_processor.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# bald_processor.py
import os
import cv2
import torch
import numpy as np
from PIL import Image, UnidentifiedImageError
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
import io
import traceback

# Global model load with error handling
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device} | CUDA available: {torch.cuda.is_available()}")

print("Loading SegFormer face-parsing model...")
# BUG FIX: the original except-branch printed the error and fell through,
# leaving `processor` / `model` undefined — the first call to
# make_realistic_bald() then failed with a confusing NameError. Pre-binding
# both names to None makes the failure mode explicit and inspectable.
processor = None
model = None
try:
    processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
    model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
    model.to(device)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    print(f"CRITICAL: Model loading failed! {str(e)}")
    traceback.print_exc()

# Face-parsing class ids used below (label map of jonathandinu/face-parsing).
hair_class_id = 13
ear_class_ids = [8, 9]  # l_ear=8, r_ear=9
skin_class_id = 1
nose_class_id = 2  # Reliable fallback for clean skin tone
31
def make_realistic_bald(image_bytes: bytes) -> bytes:
    """Remove the hair from a portrait photo and return a realistic bald render.

    Pipeline: decode -> optional downscale to a working resolution ->
    SegFormer face parsing -> estimate a skin tone from the forehead
    (nose / full-skin fallbacks) -> build a hair mask while protecting the
    ears -> inpaint the hair region -> add skin texture -> composite ->
    upscale back -> JPEG-encode.

    Args:
        image_bytes: Raw encoded image bytes in any PIL-readable format.

    Returns:
        JPEG-encoded bytes of the processed image at the original resolution.

    Raises:
        ValueError: If the bytes cannot be decoded as an image.
        RuntimeError: If any later processing step fails.
    """
    try:
        # Open image safely
        try:
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except UnidentifiedImageError:
            raise ValueError("Invalid image format or corrupt bytes")
        except Exception as e:
            raise ValueError(f"Image open failed: {str(e)}")

        orig_w, orig_h = image.size
        original_np = np.array(image)
        original_bgr = cv2.cvtColor(original_np, cv2.COLOR_RGB2BGR)

        # Resize if large; every step below runs at working resolution.
        MAX_PROCESS_DIM = 2048
        scale_factor = 1.0
        working_np = original_np
        working_bgr = original_bgr
        working_h, working_w = orig_h, orig_w

        if max(orig_w, orig_h) > MAX_PROCESS_DIM:
            scale_factor = MAX_PROCESS_DIM / max(orig_w, orig_h)
            working_w = int(orig_w * scale_factor)
            working_h = int(orig_h * scale_factor)
            working_np = cv2.resize(original_np, (working_w, working_h), interpolation=cv2.INTER_AREA)
            working_bgr = cv2.cvtColor(working_np, cv2.COLOR_RGB2BGR)

        # Segmentation: per-pixel face-part labels, upsampled to working size.
        pil_working = Image.fromarray(working_np)
        inputs = processor(images=pil_working, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        upsampled_logits = torch.nn.functional.interpolate(
            logits, size=(working_h, working_w), mode="bilinear", align_corners=False
        )
        parsing = upsampled_logits.argmax(dim=1).squeeze(0).cpu().numpy()

        # Skin mask
        skin_mask = (parsing == skin_class_id).astype(np.uint8)

        # IMPROVED Forehead region (better pixel coverage)
        forehead_fraction_top = 0.25     # slightly lower
        forehead_fraction_bottom = 0.38  # more coverage
        forehead_fraction_left = 0.38
        forehead_fraction_right = 0.62   # wider center

        h, w = parsing.shape
        forehead_y_start = max(0, int(h * forehead_fraction_top))
        forehead_y_end = min(h, int(h * forehead_fraction_bottom))
        forehead_x_start = max(0, int(w * forehead_fraction_left))
        forehead_x_end = min(w, int(w * forehead_fraction_right))

        # BUG FIX: sample colors from working_np, not original_np. The
        # forehead coordinates and all parsing masks are working-resolution;
        # indexing the full-resolution original sampled misaligned pixels
        # and made the boolean-mask fallbacks below raise IndexError
        # whenever the image had been downscaled.
        forehead_region = working_np[forehead_y_start:forehead_y_end, forehead_x_start:forehead_x_end]
        forehead_skin_mask = skin_mask[forehead_y_start:forehead_y_end, forehead_x_start:forehead_x_end]

        mean_color_rgb = np.array([210, 185, 170])  # Lighter neutral fallback

        try:
            if forehead_region.size > 0 and np.sum(forehead_skin_mask) > 80:
                skin_pixels = forehead_region[forehead_skin_mask == 1]
                if len(skin_pixels) > 30:
                    # Prefer the brightest ~30% of forehead skin to avoid shadows.
                    brightness = np.mean(skin_pixels.astype(float), axis=1)
                    thresh = np.percentile(brightness, 70)
                    bright_pixels = skin_pixels[brightness > thresh]
                    if len(bright_pixels) > 20:
                        mean_color_rgb = np.mean(bright_pixels, axis=0).astype(int)
                    else:
                        mean_color_rgb = np.mean(skin_pixels, axis=0).astype(int)
                else:
                    mean_color_rgb = np.mean(forehead_region, axis=(0, 1)).astype(int)
            else:
                # Fallback 1: Nose
                nose_mask = (parsing == nose_class_id).astype(np.uint8)
                nose_pixels = working_np[nose_mask == 1]
                if len(nose_pixels) > 50:
                    mean_color_rgb = np.mean(nose_pixels, axis=0).astype(int)
                else:
                    # Fallback 2: Full skin
                    skin_pixels_full = working_np[skin_mask == 1]
                    if len(skin_pixels_full) > 100:
                        mean_color_rgb = np.mean(skin_pixels_full, axis=0).astype(int)
        except Exception as skin_err:
            print(f"Skin detection error (fallback used): {str(skin_err)}")

        # Make detected skin color 30% brighter
        mean_color_rgb = np.array(mean_color_rgb, dtype=float)
        brightness_factor = 1.30  # 30% brighter (change to 1.20 / 1.40 if needed)
        mean_color_rgb = np.clip(mean_color_rgb * brightness_factor, 0, 255).astype(int)

        # Clear forehead color print (updated one)
        hex_color = '#%02x%02x%02x' % tuple(mean_color_rgb)
        print(f"Adjusted (30% brighter) skin color → RGB: {mean_color_rgb.tolist()} | Hex: {hex_color}")

        hair_mask = (parsing == hair_class_id).astype(np.uint8)

        ears_mask = np.zeros_like(hair_mask, dtype=np.uint8)
        for cls in ear_class_ids:
            ears_mask[parsing == cls] = 1

        # Slightly dilated ear mask: pixels the inpaint must never touch.
        ears_protected = np.zeros_like(hair_mask, dtype=np.uint8)
        ear_y, ear_x = np.where(ears_mask > 0)

        left, right = 0, 0
        if len(ear_y) > 0:
            ear_top_y = ear_y.min()
            ear_x_min = ear_x.min()
            ear_x_max = ear_x.max()
            ear_width = ear_x_max - ear_x_min + 1

            kernel_protect = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 9))
            ears_protected = cv2.dilate(ears_mask, kernel_protect, iterations=1)

            # Keep hair above the ear line removable.
            if ear_top_y > 10:
                ears_protected[:ear_top_y - 8, :] = 0

            x_margin = int(ear_width * 0.25)
            left = max(0, ear_x_min - x_margin)
            right = min(working_w, ear_x_max + x_margin)

        hair_mask_final = hair_mask.copy()
        hair_mask_final[ears_protected == 1] = 0

        # The top quarter of the head is always hair territory — restore any
        # hair pixels the ear protection removed up there.
        top_quarter = int(working_h * 0.25)
        if hair_mask[:top_quarter, :].sum() > 60:
            hair_mask_final[:top_quarter, :] = np.maximum(
                hair_mask_final[:top_quarter, :], hair_mask[:top_quarter, :]
            )

        # Close holes, then expand so the inpaint covers hair fringes.
        kernel_s = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
        hair_mask_final = cv2.morphologyEx(hair_mask_final, cv2.MORPH_CLOSE, kernel_s, iterations=2)
        hair_mask_final = cv2.dilate(hair_mask_final, kernel_s, iterations=1)

        # Feather and re-threshold to smooth the mask boundary.
        blurred = cv2.GaussianBlur(hair_mask_final.astype(np.float32), (9, 9), 3)
        hair_mask_final = (blurred > 0.28).astype(np.uint8)

        kernel_edge = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        hair_mask_final = cv2.dilate(hair_mask_final, kernel_edge, iterations=1)

        hair_pixels = np.sum(hair_mask_final)

        # Very large hair areas (long hair) get an extended mask covering
        # the whole upper head, minus the protected ears.
        final_mask = hair_mask_final.copy()
        use_extended_mask = False
        if hair_pixels > 380000:
            use_extended_mask = True
            big_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (25, 25))
            extended = cv2.dilate(hair_mask_final, big_kernel, iterations=1)
            upper = np.zeros_like(hair_mask_final)
            upper_end = int(working_h * 0.48)
            upper[:upper_end, :] = 1
            extended = np.logical_or(extended, upper).astype(np.uint8)
            extended[ears_protected == 1] = 0
            extended = cv2.morphologyEx(extended, cv2.MORPH_CLOSE, kernel_s, iterations=1)
            extended[int(working_h * 0.75):, :] = 0
            final_mask = extended

        # BUG FIX: this ear-side clean-up originally ran AFTER result_small
        # was composited, so it never affected the output. It must clear the
        # protected ear columns from final_mask BEFORE the mask is used.
        if len(ear_x) > 0:
            side_clean_left = max(0, left - 30)
            side_clean_right = min(working_w, right + 30)
            final_mask[:, side_clean_left:side_clean_right] = np.minimum(
                final_mask[:, side_clean_left:side_clean_right],
                1 - ears_protected[:, side_clean_left:side_clean_right]
            )

        # Bigger masks need a larger inpaint radius; TELEA copes better with
        # large regions, NS gives smoother results on small ones.
        if use_extended_mask or hair_pixels > 420000:
            radius = 18
            inpaint_flag = cv2.INPAINT_TELEA
        elif hair_pixels > 220000:
            radius = 15
            inpaint_flag = cv2.INPAINT_TELEA
        else:
            radius = 10
            inpaint_flag = cv2.INPAINT_NS

        inpainted_bgr = cv2.inpaint(working_bgr, final_mask * 255, inpaintRadius=radius, flags=inpaint_flag)
        inpainted_rgb = cv2.cvtColor(inpainted_bgr, cv2.COLOR_BGR2RGB)

        # ==================== Add realistic bald head skin texture ====================
        # Fine pore noise plus a low-frequency component, clipped and blended
        # over the smooth inpaint so the scalp does not look airbrushed.
        pores_noise = np.random.normal(0, 12, (working_h, working_w, 3)).astype(np.float32)
        large_kernel = cv2.getGaussianKernel(61, 20)
        large_var = cv2.filter2D(pores_noise, -1, large_kernel) * 0.5
        texture_noise = pores_noise * 0.7 + large_var
        texture_noise = np.clip(texture_noise, -25, 25)

        textured_area = inpainted_rgb.astype(np.float32) + texture_noise
        textured_area = np.clip(textured_area, 0, 255).astype(np.uint8)

        blend_factor = 0.75  # 75% textured, 25% smooth inpaint
        blended_bald = (blend_factor * textured_area + (1 - blend_factor) * inpainted_rgb).astype(np.uint8)
        # ==============================================================================

        # Composite: replace only the masked (hair) pixels.
        result_small = working_np.copy()
        result_small[final_mask == 1] = blended_bald[final_mask == 1]

        # Upscale back to the original resolution if we worked on a smaller copy.
        if scale_factor < 1.0:
            result = cv2.resize(result_small, (orig_w, orig_h), interpolation=cv2.INTER_LANCZOS4)
        else:
            result = result_small

        output_bytes = io.BytesIO()
        Image.fromarray(result).save(output_bytes, format="JPEG")
        output_bytes.seek(0)
        return output_bytes.read()

    except Exception as main_err:
        print("ERROR in make_realistic_bald:")
        traceback.print_exc()
        raise RuntimeError(f"Bald processing failed: {str(main_err)}")