File size: 12,810 Bytes
fc19502
0eccb9d
fc19502
 
c2b8898
34948ba
474ecad
0eccb9d
 
7d3069e
fc19502
 
474ecad
 
a8e2058
55edead
a49c797
6606c41
a8e2058
a49c797
e273778
0eccb9d
34948ba
0eccb9d
a50a9ef
 
59f1e7f
a50a9ef
34948ba
 
 
 
c2b8898
34948ba
a49c797
6606c41
36879ea
a50a9ef
34948ba
a49c797
0eccb9d
 
 
a49c797
34948ba
5c17bdd
 
 
 
 
 
 
 
 
 
 
a50a9ef
0eccb9d
c7400ce
0eccb9d
e86bb8e
a50a9ef
0eccb9d
 
 
 
e273778
0eccb9d
 
 
 
a50a9ef
0eccb9d
 
 
23fab7f
e5bf4cd
 
d57afcc
 
 
e5bf4cd
d57afcc
 
e5bf4cd
d57afcc
 
e5bf4cd
d57afcc
 
 
e5bf4cd
 
 
 
 
 
 
d57afcc
 
 
e5bf4cd
 
 
 
8aae84a
e86bb8e
0eccb9d
e86bb8e
8938c96
e5bf4cd
0eccb9d
8938c96
a50a9ef
c7400ce
a50a9ef
34948ba
69a1926
e273778
cfcb992
69a1926
7d3069e
34948ba
d57afcc
0eccb9d
 
a49c797
0eccb9d
 
8938c96
a50a9ef
257adde
34948ba
 
 
0eccb9d
34948ba
d57afcc
34948ba
cfcb992
 
 
 
 
cdbee3f
 
69a1926
 
6e09aed
8aae84a
e5bf4cd
 
 
 
9a1f8d8
8aae84a
 
e5bf4cd
69a1926
e5bf4cd
0eccb9d
324b7e6
cfcb992
8aae84a
0eccb9d
e86bb8e
e5bf4cd
e273778
0eccb9d
34948ba
 
 
 
d57afcc
cdbee3f
34948ba
69a1926
34948ba
cdbee3f
34948ba
 
a49c797
324b7e6
 
0eccb9d
 
d57afcc
324b7e6
 
 
 
 
 
34948ba
e5bf4cd
34948ba
324b7e6
 
12a8a6d
 
 
 
cfcb992
8938c96
 
12a8a6d
 
 
 
8aae84a
 
12a8a6d
 
 
c0d286a
12a8a6d
 
 
34948ba
8aae84a
324b7e6
a49c797
324b7e6
5c17bdd
a50a9ef
b8655b7
 
 
 
a50a9ef
34948ba
 
7d3069e
b8655b7
 
7d3069e
b8655b7
 
 
 
59f1e7f
b8655b7
a50a9ef
b8655b7
 
1ee3050
b8655b7
8938c96
e5bf4cd
34948ba
6e09aed
d57afcc
0eccb9d
 
324b7e6
5c17bdd
0eccb9d
36879ea
 
e5bf4cd
36879ea
34948ba
 
e5bf4cd
36879ea
34948ba
8aae84a
324b7e6
e5bf4cd
324b7e6
 
1a83ff0
 
 
 
324b7e6
b74a2d5
36879ea
 
34948ba
36879ea
 
 
34948ba
36879ea
0eccb9d
257adde
a49c797
5c17bdd
a49c797
64cb831
e86bb8e
af0c0e0
0eccb9d
 
257adde
34948ba
257adde
 
 
a50a9ef
36c4483
 
a50a9ef
474ecad
a49c797
8938c96
a49c797
 
8938c96
e86bb8e
7d3069e
e86bb8e
8938c96
a50a9ef
5c17bdd
8938c96
e86bb8e
a49c797
e19c16f
 
5c17bdd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import asyncio
import functools
import gc
import io
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor

import torch
import numpy as np
import cv2
from PIL import Image, ImageFilter
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
from ultralytics import YOLO
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# Module-wide logging at INFO so the per-stage timings from `timed` are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ====================== CONFIG ======================
# Weights file for the YOLO segmentation model loaded by load_beard_model().
BEARD_MODEL_PATH = "models/best_hair_117_epoch_v4.pt"
# Upper bound on the side length of the square working copy built in
# process_face_whitening().
SAFE_IMG_SIZE = 384
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info(f"Using Device: {DEVICE}")
logger.info(f"CUDA Available: {torch.cuda.is_available()}")

# Thread tuning. On CPU, allow 4 intra-op threads for torch and OpenCV and a
# single inter-op thread. On GPU, keep torch's CPU-side intra-op pool at 1.
# NOTE(review): per the PyTorch docs, set_num_interop_threads must be called
# once, before any inter-op parallel work starts. Keep this at import time.
if DEVICE.type == "cpu":
    torch.set_num_threads(4)
    torch.set_num_interop_threads(1)
    cv2.setNumThreads(4)
else:
    torch.set_num_threads(1)

# Writable cache/config locations under /tmp.
# NOTE(review): these are assigned *after* transformers and ultralytics have
# already been imported at the top of the file. Both libraries appear to read
# HF_HOME / YOLO_CONFIG_DIR when they are imported, so these assignments may
# have no effect on them. Verify this; if true, move them above the
# third-party imports.
os.environ["HF_HOME"] = "/tmp/hf_cache"
os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"

# Shared pool for blocking work. startup() and age_face() hand work to it
# through run_in_executor.
executor = ThreadPoolExecutor(max_workers=2)

# Model singletons. They start as None and are filled lazily by
# load_face_parser() / load_beard_model().
face_processor = None
face_parser = None
beard_model = None

# ====================== TIMED DECORATOR ======================
def timed(name: str):
    """Build a decorator that logs how long each call of a function takes.

    Args:
        name: Label written before the elapsed time in the log line.

    Returns:
        A decorator. Every call of the decorated function logs
        ``"<name>: <elapsed> ms"`` at INFO level on the module ``logger``.
        The timing is logged even when the wrapped call raises. The
        wrapped function keeps its ``__name__``, ``__doc__`` and
        ``__wrapped__`` attributes.
    """
    def decorator(func):
        # functools.wraps copies the original function's metadata, so logs,
        # tracebacks and introspection see the real function, not "wrapper".
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Log on both success and failure, so slow failing stages
                # are still visible.
                elapsed = (time.perf_counter() - start) * 1000
                logger.info("%s: %.1f ms", name, elapsed)
        return wrapper
    return decorator

# ====================== MODEL LOADING ======================
def load_face_parser():
    """Load the Segformer face-parsing processor and model into module globals.

    The first call downloads or loads the checkpoint. It fills
    ``face_processor`` and ``face_parser``, moves the model to ``DEVICE``,
    and puts it in eval mode. Once ``face_parser`` is set, later calls do
    nothing. Returns None.
    """
    global face_processor, face_parser
    if face_parser is None:
        checkpoint = "jonathandinu/face-parsing"
        logger.info("Loading Segformer Face Parser...")
        face_processor = SegformerImageProcessor.from_pretrained(checkpoint)
        face_parser = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
        # .to() returns the module itself, so eval() applies to the same object.
        face_parser.to(DEVICE).eval()
        logger.info("✅ Face parser loaded")

def load_beard_model():
    """Return the YOLO beard model, building it from BEARD_MODEL_PATH on first use.

    The instance is cached in the module-level ``beard_model`` global.
    Later calls return that cached object.
    """
    global beard_model
    if beard_model is not None:
        return beard_model
    logger.info("Loading YOLO Beard Model...")
    beard_model = YOLO(BEARD_MODEL_PATH)
    return beard_model

# ====================== MUSTACHE MASK ======================
@timed("Mustache Mask")
def get_mustache_mask(probs, orig_w, orig_h, exclude_mask):
    """Build a binary mustache-region mask at the original image resolution.

    The region is seeded from face-parser class probabilities: channels 11
    and 12 (the upper / lower lip in this code's naming) above 0.13, and
    channel 10 (mouth) above 0.18, down-weighted. The seed is grown
    horizontally with morphology, shifted down by one pixel, and upscaled.
    Then half of ``exclude_mask`` is subtracted and the result is
    re-binarised.

    Args:
        probs: Per-class probability maps indexed by class on the first
            axis. Each channel must support ``.numpy()``; the caller passes
            a CPU torch tensor.
        orig_w: Output mask width.
        orig_h: Output mask height.
        exclude_mask: Float mask of shape (orig_h, orig_w) to suppress.

    Returns:
        float32 array of shape (orig_h, orig_w) holding 0.0 / 1.0.
    """
    def above(channel, threshold):
        return (probs[channel].numpy() > threshold).astype(np.float32)

    upper_lip = above(11, 0.13)
    lower_lip = above(12, 0.13)
    mouth_region = above(10, 0.18)

    # The upper lip is weighted up and the mouth interior weighted down.
    region = np.maximum(np.maximum(upper_lip * 1.15, lower_lip), mouth_region * 0.45)

    ellipse_5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    wide_rect = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 3))
    region = cv2.dilate(region, ellipse_5, iterations=1)
    # A wide, short closing kernel merges the region horizontally.
    region = cv2.morphologyEx(region, cv2.MORPH_CLOSE, wide_rect, iterations=2)
    region = cv2.GaussianBlur(region, (7, 5), 1.2)

    # Translate the whole map one pixel downwards; vacated rows become 0.
    rows, cols = region.shape[:2]
    shift_down_one = np.float32([[1, 0, 0], [0, 1, 1]])
    region = cv2.warpAffine(region, shift_down_one, (cols, rows),
                            flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, borderValue=0)

    region = cv2.resize(region, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    region = np.maximum(region - exclude_mask * 0.5, 0)
    region = cv2.GaussianBlur(region, (5, 5), 1.0)
    return (region > 0.15).astype(np.float32)

# ====================== HAIR + EXCLUDE + LIP MASK ======================
# Side length of the square image fed to the face parser and of the
# probability maps derived from it.
_FACE_PARSE_SIZE = 128


def _class_union_mask(probs, classes, threshold):
    """Return a float32 0/1 map set wherever any of ``classes`` exceeds ``threshold``.

    ``probs`` is indexed by class on its first axis (a CPU torch tensor here).
    """
    stacked = probs[list(classes)].numpy()
    return np.any(stacked > threshold, axis=0).astype(np.float32)


@timed("Hair + Exclude + Lip Mask")
def get_hair_and_exclude_masks(pil_image: Image.Image):
    """Run the face parser once and derive hair, exclusion, mustache and lip masks.

    The image is downscaled to a ``_FACE_PARSE_SIZE`` square and parsed.
    Each mask is built at that resolution and then resized back to the
    input size.

    Args:
        pil_image: Image to parse; the face processor receives it after the
            downscale.

    Returns:
        ``(hair, exclude, mustache, lip_mask)``, each a float32 array of
        shape (orig_h, orig_w):

        * hair: soft mask; class 13 probability above 0.035, blurred, with
          parsed face classes removed.
        * exclude: soft mask; union of classes 10, 11, 12, 4 and 5 above
          0.35, dilated and blurred.
        * mustache: binary mask from ``get_mustache_mask``.
        * lip_mask: binary mask; union of classes 10, 11 and 12 above 0.42,
          dilated.
    """
    load_face_parser()
    orig_w, orig_h = pil_image.size
    size = _FACE_PARSE_SIZE

    img_small = pil_image.resize((size, size), Image.BILINEAR)
    # NOTE(review): face_processor applies its own resize per its config. If
    # that upsamples, the downscale above only costs detail — confirm the
    # processor's configured size.
    inputs = face_processor(images=img_small, return_tensors="pt").to(DEVICE)

    with torch.inference_mode():
        logits = face_parser(**inputs).logits
        up = torch.nn.functional.interpolate(logits, size=(size, size), mode="bilinear", align_corners=False)
        probs = torch.softmax(up, dim=1)[0].cpu()

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

    # Hair mask: deliberately low threshold, then a light blur for soft edges.
    hair = (probs[13].numpy() > 0.035).astype(np.float32)
    hair = cv2.GaussianBlur(hair, (3, 3), 1.0)

    # Face mask from the arg-max parse, used to carve face pixels out of hair.
    parsing = up.argmax(dim=1).squeeze(0).cpu().numpy()
    face_cls = [1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 17, 18]
    face_m = np.isin(parsing, face_cls).astype(np.float32)
    face_m = cv2.dilate(face_m, kernel, iterations=1)

    # In the top 32% of rows the face mask only removes 45% of its weight,
    # so hair over the forehead is kept partly.
    h = face_m.shape[0]
    forehead = np.zeros_like(face_m, dtype=np.float32)
    forehead[:int(h * 0.32)] = 1.0
    face_m = face_m * (1 - forehead * 0.45)
    hair = hair * (1 - face_m)
    hair = cv2.resize(hair, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # Exclude mask: regions the beard/mustache masks must not cover.
    exclude = _class_union_mask(probs, (10, 11, 12, 4, 5), 0.35)
    exclude = cv2.dilate(exclude, kernel, iterations=2)
    exclude = cv2.GaussianBlur(exclude, (5, 5), 1.2)
    exclude = cv2.resize(exclude, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # Lip mask: hard mask of mouth/lip pixels, kept binary after resizing.
    lip_mask = _class_union_mask(probs, (10, 11, 12), 0.42)
    lip_mask = cv2.dilate(lip_mask, kernel, iterations=1)
    lip_mask = cv2.resize(lip_mask, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
    lip_mask = (lip_mask > 0.5).astype(np.float32)

    mustache = get_mustache_mask(probs, orig_w, orig_h, exclude)

    return hair, exclude, mustache, lip_mask

# ====================== BEARD MASK (returns mask and beard_present flag) ======================
def _refine_beard_shape(mask):
    """Clean up a beard mask: morphology, contour simplification, then re-binarise."""
    def ellipse(diameter):
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (diameter, diameter))

    mask = cv2.erode(mask, ellipse(7), iterations=2)
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, ellipse(13), iterations=3)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, ellipse(5), iterations=1)

    outlines, _ = cv2.findContours((mask > 0.1).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    if outlines:
        # Redraw only outer contours larger than 50 px², as simplified filled polygons.
        filled = np.zeros_like(mask, dtype=np.float32)
        for outline in outlines:
            if cv2.contourArea(outline) <= 50:
                continue
            tolerance = 0.008 * cv2.arcLength(outline, True)
            simplified = cv2.approxPolyDP(outline, tolerance, True)
            cv2.drawContours(filled, [simplified], -1, 1.0, thickness=cv2.FILLED)
        mask = filled

    mask = cv2.GaussianBlur(mask, (9, 9), 2.0)
    mask = (mask > 0.28).astype(np.float32)
    return cv2.erode(mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)), iterations=1)


@timed("Beard Mask")
def get_beard_mask_fast(pil_image: Image.Image, exclude_mask: np.ndarray, lip_mask: np.ndarray):
    """Segment the beard with the YOLO model.

    The image is downscaled to 128x128 for inference. Class-0 instance masks
    with confidence above 0.25 are unioned at the input size. ``exclude_mask``
    is then suppressed, the shape is refined (only when a beard was found),
    and lip pixels are cleared.

    Args:
        pil_image: Image to segment.
        exclude_mask: Float mask of shape (H, W); 0.6 of it is subtracted.
        lip_mask: Mask of shape (H, W); every pixel > 0 is forced to 0.

    Returns:
        ``(mask, beard_present)``. ``mask`` is a float32 (H, W) array.
        ``beard_present`` is True when at least one class-0 detection passed
        the 0.25 confidence check.
    """
    model = load_beard_model()
    width, height = pil_image.size

    small_rgb = np.array(pil_image.resize((128, 128), Image.BILINEAR))
    prediction = model.predict(
        small_rgb,
        device=DEVICE.type,
        conf=0.18,
        iou=0.45,
        imgsz=128,
        half=False,
        verbose=False,
        max_det=8,
    )[0]

    beard = np.zeros((height, width), dtype=np.float32)
    beard_present = False

    if prediction.masks is not None:
        boxes = prediction.boxes
        for idx, cls_id in enumerate(boxes.cls):
            if int(cls_id) != 0:
                continue
            # Stricter than the 0.18 floor passed to predict(); written as
            # `not (x > t)` so a NaN score is skipped as well.
            if not boxes.conf[idx].item() > 0.25:
                continue
            beard_present = True
            seg = prediction.masks.data[idx].cpu().numpy()
            seg = cv2.resize(seg, (width, height), interpolation=cv2.INTER_LINEAR)
            beard = np.maximum(beard, (seg > 0.25).astype(np.float32))

    beard = np.maximum(beard - exclude_mask * 0.6, 0)

    # Morphological refinement only when a beard was detected and enough
    # pixels survive the exclusion.
    if beard_present and beard.sum() > 25:
        beard = _refine_beard_shape(beard)

    beard[lip_mask > 0] = 0
    return beard, beard_present

# ====================== COLOR TRANSFER - BEARD SAME AS HAIR ======================
@timed("Color Transfer")
def apply_strong_grey_hair(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray):
    """Desaturate and brighten masked hair/beard pixels to simulate grey hair.

    Hair and beard are merged and processed identically. Inside the mask:

    * saturation is reduced by up to 78%;
    * value is brightened and clamped to [110, 210];
    * the result is alpha-blended with the original using the mask;
    * a faint warm tint is added.

    An unsharp mask is applied to the whole output image.

    Args:
        image: RGB PIL image. The conversion uses ``COLOR_RGB2HSV``, so
            channels must be in RGB order.
        hair_mask: Float mask (H, W), same size as ``image``, values in [0, 1].
        beard_mask: Float mask (H, W), same size as ``image``, values in [0, 1].

    Returns:
        A new RGB PIL image of the same size.
    """
    mask = np.maximum(hair_mask, beard_mask)
    if mask.sum() < 100:
        # Very small masks are feathered (presumably to avoid a hard-edged
        # patch).
        mask = cv2.GaussianBlur(mask, (5, 5), 1.5)

    # Convert the original uint8 pixels directly. The previous float32
    # x/255*255 round-trip followed by a truncating astype(uint8) could
    # bring some channel values back as x-1.
    rgb_u8 = np.array(image)
    img = rgb_u8.astype(np.float32) / 255.0
    hsv = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2HSV).astype(np.float32)

    out_hsv = hsv.copy()
    out_hsv[..., 1] = out_hsv[..., 1] * (1 - 0.78 * mask)
    v = hsv[..., 2]
    # The clamp applies everywhere, but unmasked pixels get weight 0 in the
    # blend below, so they are unaffected.
    out_hsv[..., 2] = np.clip(v + 89 * mask - (v * 0.35 * mask), 110, 210)
    grey_rgb = cv2.cvtColor(out_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32) / 255.0

    # (H, W, 1) alpha broadcasts over the three colour channels.
    alpha = mask[..., None]
    final = grey_rgb * alpha + img * (1 - alpha)

    # Faint warm tint in masked areas (at most 0.18 * 9/255 on red).
    final = final + (np.array([9, 7, 5], dtype=np.float32) / 255.0 * alpha * 0.18)

    final = np.clip(final * 255, 0, 255).astype(np.uint8)
    result = Image.fromarray(final)
    return result.filter(ImageFilter.UnsharpMask(radius=0.8, percent=75, threshold=1))

# ====================== MAIN PROCESSING (mustache applied only when a beard is detected) ======================
@timed("Total Processing")
def process_face_whitening(input_image: Image.Image):
    """Grey the hair and beard (plus mustache when a beard is found) of a face photo.

    Steps:

    1. Build a square ``target x target`` working copy. ``target`` is the
       larger input side, capped at ``SAFE_IMG_SIZE``, rounded down to an
       even number, and at least 2.
    2. Compute the masks on that copy and apply the grey colour transfer.
    3. Resize the result back to the original size.

    Args:
        input_image: Any PIL image; it is converted to RGB first.

    Returns:
        A new RGB PIL image with the same size as ``input_image``.
    """
    orig = input_image.convert("RGB")
    ow, oh = orig.size

    target = min(SAFE_IMG_SIZE, max(ow, oh))
    target -= target % 2
    # For a 1-px-wide/tall input the even rounding gives 0, and PIL cannot
    # resize to 0x0.
    target = max(target, 2)

    # NOTE(review): the square working copy distorts non-square inputs. Also,
    # the whole output comes from this <= SAFE_IMG_SIZE copy, so large inputs
    # lose detail everywhere, not only in hair. Consider compositing only the
    # masked region onto the full-resolution original.
    img_resized = orig.resize((target, target), Image.BILINEAR)

    hair_mask, exclude_mask, mustache_mask, lip_mask = get_hair_and_exclude_masks(img_resized)
    beard_mask, beard_present = get_beard_mask_fast(img_resized, exclude_mask, lip_mask)

    # The mustache is merged only when the beard model found a beard;
    # otherwise it is ignored entirely.
    if beard_present:
        beard_mask = np.maximum(beard_mask, mustache_mask * 0.98)
        # Safeguard: lift weak mustache pixels to at least 0.75. With a 0/1
        # mustache mask this selects nothing, because every mustache pixel is
        # already >= 0.98 after the line above. It only matters for a soft
        # mask.
        weak_mustache = (mustache_mask > 0.18) & (beard_mask < 0.48)
        beard_mask[weak_mustache] = np.maximum(beard_mask[weak_mustache], 0.75)
        # The mustache may reach into the lips, so clear them again.
        beard_mask[lip_mask > 0] = 0

    final_resized = apply_strong_grey_hair(img_resized, hair_mask, beard_mask)
    final_img = final_resized.resize((ow, oh), Image.LANCZOS)

    # Drop per-request temporaries, and on GPU release cached allocator blocks.
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    return final_img

# ====================== FASTAPI ======================
app = FastAPI()
# Fully permissive CORS: any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.on_event("startup")
async def startup():
    """Load both models and run one warm-up inference before serving.

    All blocking work runs on the shared ``executor``, including the warm-up
    pass that previously ran on the event-loop thread. The event loop
    therefore stays responsive during startup.
    """
    # get_running_loop() is the documented way to get the loop from inside
    # a coroutine.
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, load_face_parser)
    await loop.run_in_executor(executor, load_beard_model)
    logger.info("✅ Models loaded")
    logger.info("Running light warmup...")
    dummy = Image.new("RGB", (256, 256))
    await loop.run_in_executor(executor, process_face_whitening, dummy)
    logger.info("✅ Server Ready!")

def _decode_rgb(data: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image; raises OSError if undecodable."""
    # UnidentifiedImageError and truncated-file errors are both OSError
    # subclasses.
    with Image.open(io.BytesIO(data)) as im:
        return im.convert("RGB")


def _encode_jpeg(img: Image.Image) -> io.BytesIO:
    """Encode ``img`` as an optimised quality-92 JPEG into a rewound buffer."""
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=92, optimize=True)
    buf.seek(0)
    return buf


@app.post("/age-face")
async def age_face(file: UploadFile = File(...)):
    """Return the uploaded photo with hair and beard greyed, as a JPEG.

    Decoding, processing and JPEG encoding all run on ``executor``, keeping
    the event loop free.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    start_total = time.perf_counter()
    contents = await file.read()
    loop = asyncio.get_running_loop()

    try:
        img = await loop.run_in_executor(executor, _decode_rgb, contents)
    except OSError as exc:
        # A client error, not a server crash.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc

    result = await loop.run_in_executor(executor, process_face_whitening, img)
    buf = await loop.run_in_executor(executor, _encode_jpeg, result)

    total_time = (time.perf_counter() - start_total) * 1000
    logger.info(f"✅ Total Request Time: {total_time:.1f} ms")

    return StreamingResponse(buf, media_type="image/jpeg")

if __name__ == "__main__":
    # Serve the app object directly with one worker process. The models
    # live in this process's module globals.
    from uvicorn import run as serve_app

    serve_app(app, workers=1, host="0.0.0.0", port=7860)