Seniordev22 commited on
Commit
474ecad
Β·
verified Β·
1 Parent(s): a8e2058

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -86
app.py CHANGED
@@ -1,143 +1,247 @@
 
 
 
1
  import os
2
  import torch
3
  import numpy as np
4
  import cv2
5
- import time
6
- import asyncio
7
- import io
8
  import traceback
9
  import gc
10
- from PIL import Image
 
11
  from transformers import SegformerImageProcessor
12
  from fastapi import FastAPI, File, UploadFile, HTTPException
13
  from fastapi.responses import StreamingResponse
14
  from fastapi.middleware.cors import CORSMiddleware
 
 
15
  from concurrent.futures import ThreadPoolExecutor
16
  import logging
17
  import onnxruntime as ort
18
 
19
- logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
- PROCESS_SIZE = 256
23
- ONNX_PATH = "models/segformer_face_parsing.onnx"
24
  os.makedirs("models", exist_ok=True)
25
 
26
  executor = ThreadPoolExecutor(max_workers=1)
27
  face_processor = None
28
  ort_session = None
29
 
30
- def ensure_onnx_exists():
31
- if os.path.exists(ONNX_PATH):
32
- return True
33
- logger.info("Exporting ONNX (once, takes ~30s)...")
34
  from transformers import SegformerForSemanticSegmentation
35
  model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
36
  model.eval()
37
- dummy = torch.randn(1, 3, 192, 192)
38
- torch.onnx.export(
39
- model, dummy, ONNX_PATH,
40
- input_names=["pixel_values"],
41
- output_names=["logits"],
42
- dynamic_axes={"pixel_values": {0: "batch", 2: "height", 3: "width"},
43
- "logits": {0: "batch", 2: "height", 3: "width"}},
44
- opset_version=14,
45
- do_constant_folding=False
46
- )
47
- logger.info("ONNX ready")
48
- return True
49
 
50
  def load_face_parser():
51
  global face_processor, ort_session
52
  if ort_session is not None:
53
  return
54
- ensure_onnx_exists()
 
55
  face_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
56
- opts = ort.SessionOptions()
57
- opts.intra_op_num_threads = 1
58
- ort_session = ort.InferenceSession(ONNX_PATH, opts, providers=['CPUExecutionProvider'])
59
- logger.info("ONNX loaded")
 
60
 
61
- def get_hair_mask(pil_image):
62
  load_face_parser()
63
  orig_w, orig_h = pil_image.size
64
- # Always feed 192x192 to ONNX
65
- img_model = pil_image.resize((192,192), Image.LANCZOS)
66
- inputs = face_processor(images=img_model, return_tensors="pt")
67
- pixel_vals = inputs["pixel_values"].numpy().astype(np.float32)
68
- logits = torch.from_numpy(ort_session.run(["logits"], {"pixel_values": pixel_vals})[0])
69
- up = torch.nn.functional.interpolate(logits, size=(192,192), mode="bilinear")
 
 
 
 
 
 
 
 
 
70
  probs = torch.softmax(up, dim=1)[0]
71
- strong = (probs[13].cpu().numpy() > 0.055).astype(np.float32)
72
- soft = (probs[13].cpu().numpy() > 0.022).astype(np.float32)
73
- hair = np.maximum(strong, soft * 0.68)
74
- # face subtraction
 
 
 
75
  parsing = up.argmax(dim=1).squeeze(0).cpu().numpy()
76
- face_cls = list(range(1,6)) + list(range(8,13)) + [17,18]
77
- face_mask = np.isin(parsing, face_cls).astype(np.float32)
78
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
79
- face_mask = cv2.dilate(face_mask, kernel, iterations=1)
80
- h,w = face_mask.shape
81
- forehead = np.zeros_like(face_mask)
82
  forehead[:int(h*0.3), :] = 1.0
83
- face_mask = face_mask * (1 - forehead*0.45)
84
- hair = hair * (1 - face_mask)
85
- hair = cv2.morphologyEx(hair, cv2.MORPH_CLOSE, kernel, iterations=1)
86
- hair = cv2.GaussianBlur(hair, (3,3), 0.8)
87
- hair = cv2.resize(hair, (orig_w, orig_h))
88
- return hair
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- def apply_grey_hair(image, hair_mask):
91
- comb = hair_mask
92
- comb = cv2.GaussianBlur(comb, (5,5), 1)
93
- img = np.array(image).astype(np.float32)/255.0
 
 
 
94
  hsv = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
95
- hsv[:,:,1] = hsv[:,:,1] * (1 - 0.7*comb)
96
- hsv[:,:,2] = np.clip(hsv[:,:,2] + (70*comb), 100, 230)
97
- result = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32)/255.0
98
- comb_3ch = np.stack([comb,comb,comb], axis=2)
99
- final = result*comb_3ch + img*(1-comb_3ch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  final = np.clip(final*255, 0, 255).astype(np.uint8)
101
- return Image.fromarray(final)
 
 
 
102
 
103
- def process_image(input_image):
 
 
 
 
 
104
  orig = input_image.convert("RGB")
105
  ow, oh = orig.size
106
  img_resized = orig.resize((PROCESS_SIZE, PROCESS_SIZE), Image.LANCZOS)
107
- hair = get_hair_mask(img_resized)
108
- result = apply_grey_hair(img_resized, hair)
109
- final = result.resize((ow, oh), Image.LANCZOS)
110
- return final
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- app = FastAPI()
 
 
 
113
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
114
 
115
- sem = asyncio.Semaphore(1)
116
-
117
  @app.on_event("startup")
118
  async def startup():
 
119
  loop = asyncio.get_event_loop()
120
  await loop.run_in_executor(executor, load_face_parser)
121
- logger.info("Server ready")
122
 
123
  @app.post("/age-face")
124
  async def age_face(file: UploadFile = File(...)):
125
- await sem.acquire()
126
- try:
127
- data = await file.read()
128
- img = Image.open(io.BytesIO(data)).convert("RGB")
129
- loop = asyncio.get_event_loop()
130
- out = await loop.run_in_executor(executor, process_image, img)
131
- buf = io.BytesIO()
132
- out.save(buf, format="JPEG", quality=90)
133
- buf.seek(0)
134
- return StreamingResponse(buf, media_type="image/jpeg")
135
- except Exception as e:
136
- logger.error(traceback.format_exc())
137
- raise HTTPException(500, str(e))
138
- finally:
139
- sem.release()
140
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
  if __name__ == "__main__":
143
  import uvicorn
 
1
+ # ================================================
2
+ # END-TO-END TIMING LOGGING (NO HARDCODE)
3
+ # ================================================
4
  import os
5
  import torch
6
  import numpy as np
7
  import cv2
 
 
 
8
  import traceback
9
  import gc
10
+ import time
11
+ from PIL import Image, ImageFilter
12
  from transformers import SegformerImageProcessor
13
  from fastapi import FastAPI, File, UploadFile, HTTPException
14
  from fastapi.responses import StreamingResponse
15
  from fastapi.middleware.cors import CORSMiddleware
16
+ import io
17
+ import asyncio
18
  from concurrent.futures import ThreadPoolExecutor
19
  import logging
20
  import onnxruntime as ort
21
 
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
  logger = logging.getLogger(__name__)
24
 
25
+ PROCESS_SIZE = 384
26
+ onnx_path = "models/segformer_face_parsing.onnx"
27
  os.makedirs("models", exist_ok=True)
28
 
29
  executor = ThreadPoolExecutor(max_workers=1)
30
  face_processor = None
31
  ort_session = None
32
 
33
+ def convert_to_onnx():
34
+ if os.path.exists(onnx_path):
35
+ return
36
+ logger.info("βš™οΈ Converting Segformer to ONNX (first time only)")
37
  from transformers import SegformerForSemanticSegmentation
38
  model = SegformerForSemanticSegmentation.from_pretrained("jonathandinu/face-parsing")
39
  model.eval()
40
+ dummy_input = torch.randn(1, 3, 192, 192)
41
+ torch.onnx.export(model, dummy_input, onnx_path,
42
+ input_names=["pixel_values"],
43
+ output_names=["logits"],
44
+ dynamic_axes={"pixel_values": {0: "batch", 2: "height", 3: "width"}},
45
+ opset_version=14, do_constant_folding=True)
46
+ logger.info("βœ… ONNX conversion done")
 
 
 
 
 
47
 
48
  def load_face_parser():
49
  global face_processor, ort_session
50
  if ort_session is not None:
51
  return
52
+ t0 = time.time()
53
+ convert_to_onnx()
54
  face_processor = SegformerImageProcessor.from_pretrained("jonathandinu/face-parsing")
55
+ sess_options = ort.SessionOptions()
56
+ sess_options.intra_op_num_threads = 1
57
+ sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
58
+ ort_session = ort.InferenceSession(onnx_path, sess_options, providers=['CPUExecutionProvider'])
59
+ logger.info(f"βœ… ONNX loaded in {time.time()-t0:.2f} sec")
60
 
61
+ def get_hair_and_exclude_masks(pil_image: Image.Image):
62
  load_face_parser()
63
  orig_w, orig_h = pil_image.size
64
+
65
+ t0 = time.time()
66
+ img_small = pil_image.resize((192, 192), Image.LANCZOS)
67
+ inputs = face_processor(images=img_small, return_tensors="pt")
68
+ pixel_values = inputs["pixel_values"].numpy().astype(np.float32)
69
+ logger.debug(f" - Preprocess: {time.time()-t0:.3f}s")
70
+
71
+ t0 = time.time()
72
+ ort_inputs = {"pixel_values": pixel_values}
73
+ ort_outs = ort_session.run(["logits"], ort_inputs)
74
+ logits = torch.from_numpy(ort_outs[0])
75
+ logger.debug(f" - ONNX inference: {time.time()-t0:.3f}s")
76
+
77
+ t0 = time.time()
78
+ up = torch.nn.functional.interpolate(logits, size=(192, 192), mode="bilinear", align_corners=False)
79
  probs = torch.softmax(up, dim=1)[0]
80
+ logger.debug(f" - Softmax+upsample: {time.time()-t0:.3f}s")
81
+
82
+ # Hair mask
83
+ t0 = time.time()
84
+ strong_hair = (probs[13].cpu().numpy() > 0.055).astype(np.float32)
85
+ soft_hair = (probs[13].cpu().numpy() > 0.022).astype(np.float32)
86
+ hair = np.maximum(strong_hair, soft_hair * 0.68)
87
  parsing = up.argmax(dim=1).squeeze(0).cpu().numpy()
88
+ face_cls = list(range(1, 6)) + list(range(8, 13)) + [17, 18]
89
+ face_m = np.isin(parsing, face_cls).astype(np.float32)
90
+ kernel_face = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
91
+ face_m = cv2.dilate(face_m, kernel_face, iterations=1)
92
+ h, w = face_m.shape
93
+ forehead = np.zeros_like(face_m)
94
  forehead[:int(h*0.3), :] = 1.0
95
+ face_m = face_m * (1 - forehead * 0.45)
96
+ hair = hair * (1 - face_m)
97
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
98
+ hair = cv2.morphologyEx(hair, cv2.MORPH_CLOSE, kernel, iterations=2)
99
+ hair = cv2.GaussianBlur(hair, (5,5), 1.5)
100
+ hair = cv2.resize(hair, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
101
+ logger.debug(f" - Hair postprocess: {time.time()-t0:.3f}s")
102
+
103
+ # Exclude mask
104
+ t0 = time.time()
105
+ nose = (probs[2].cpu().numpy() > 0.5).astype(np.float32)
106
+ lip_up = (probs[11].cpu().numpy() > 0.5).astype(np.float32)
107
+ lip_low = (probs[12].cpu().numpy() > 0.5).astype(np.float32)
108
+ exclude = np.clip(nose + lip_up + lip_low, 0, 1)
109
+ exclude = cv2.resize(exclude, (orig_w, orig_h), interpolation=cv2.INTER_NEAREST)
110
+ kernel_ex = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
111
+ exclude = cv2.dilate(exclude, kernel_ex, iterations=2)
112
+ logger.debug(f" - Exclude mask: {time.time()-t0:.3f}s")
113
+
114
+ return hair, exclude
115
 
116
+ def apply_strong_grey_hair(image: Image.Image, hair_mask: np.ndarray, beard_mask: np.ndarray) -> Image.Image:
117
+ t0 = time.time()
118
+ comb = np.maximum(hair_mask, beard_mask)
119
+ if np.sum(comb) < 100:
120
+ logger.warning("⚠️ Small mask area")
121
+ comb = cv2.GaussianBlur(comb, (7,7), 2)
122
+ img = np.array(image).astype(np.float32) / 255.0
123
  hsv = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2HSV).astype(np.float32)
124
+ hsv_hair = hsv.copy()
125
+ saturation_factor = 0.8
126
+ brightness_boost = 90
127
+ hsv_hair[:,:,1] = hsv_hair[:,:,1] * (1 - saturation_factor * hair_mask)
128
+ hsv_hair[:,:,2] = hsv_hair[:,:,2] + (brightness_boost * hair_mask)
129
+ hsv_hair[:,:,2] = np.clip(hsv_hair[:,:,2], 100, 200)
130
+ hair_grey = cv2.cvtColor(hsv_hair.astype(np.uint8), cv2.COLOR_HSV2RGB).astype(np.float32) / 255.0
131
+ hair_lab = cv2.cvtColor((hair_grey*255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
132
+ img_lab = cv2.cvtColor((img*255).astype(np.uint8), cv2.COLOR_RGB2LAB).astype(np.float32)
133
+ hair_mask_binary = (hair_mask > 0.5)
134
+ if np.sum(hair_mask_binary) > 100:
135
+ mean_hair_lab = np.mean(hair_lab[hair_mask_binary], axis=0)
136
+ std_hair_lab = np.std(hair_lab[hair_mask_binary], axis=0)
137
+ else:
138
+ mean_hair_lab = np.array([128,0,0])
139
+ std_hair_lab = np.array([30,10,10])
140
+ beard_mask_binary = (beard_mask > 0.5)
141
+ if np.sum(beard_mask_binary) > 0:
142
+ beard_pixels_lab = img_lab[beard_mask_binary]
143
+ mean_beard_lab = np.mean(beard_pixels_lab, axis=0)
144
+ std_beard_lab = np.std(beard_pixels_lab, axis=0)
145
+ std_beard_lab = np.maximum(std_beard_lab, 1e-5)
146
+ beard_norm = (beard_pixels_lab - mean_beard_lab) / std_beard_lab
147
+ beard_transfer = beard_norm * std_hair_lab + mean_hair_lab
148
+ beard_transfer = np.clip(beard_transfer, 0, 255)
149
+ img_lab_transfer = img_lab.copy()
150
+ img_lab_transfer[beard_mask_binary] = beard_transfer
151
+ else:
152
+ img_lab_transfer = img_lab
153
+ final = cv2.cvtColor(img_lab_transfer.astype(np.uint8), cv2.COLOR_LAB2RGB).astype(np.float32) / 255.0
154
+ hair_mask_3ch = np.stack([hair_mask, hair_mask, hair_mask], axis=2)
155
+ final = hair_grey * hair_mask_3ch + final * (1 - hair_mask_3ch)
156
+ comb_3ch = np.stack([comb, comb, comb], axis=2)
157
+ final = final * comb_3ch + img * (1 - comb_3ch)
158
+ warm = np.array([5,3,0], dtype=np.float32)/255.0
159
+ final = final + (warm * comb[..., None] * 0.2)
160
  final = np.clip(final*255, 0, 255).astype(np.uint8)
161
+ result = Image.fromarray(final)
162
+ result = result.filter(ImageFilter.UnsharpMask(radius=0.5, percent=50, threshold=0))
163
+ logger.debug(f" - Color transfer: {time.time()-t0:.3f}s")
164
+ return result
165
 
166
+ def process_face_whitening(input_image: Image.Image):
167
+ total_start = time.time()
168
+ logger.info("="*50)
169
+ logger.info("πŸ–ΌοΈ Processing new image")
170
+
171
+ t0 = time.time()
172
  orig = input_image.convert("RGB")
173
  ow, oh = orig.size
174
  img_resized = orig.resize((PROCESS_SIZE, PROCESS_SIZE), Image.LANCZOS)
175
+ logger.info(f"πŸ“₯ Step 1 - Load & resize: {time.time()-t0:.2f}s (to {PROCESS_SIZE}x{PROCESS_SIZE})")
176
+
177
+ t0 = time.time()
178
+ hair_mask, exclude_mask = get_hair_and_exclude_masks(img_resized)
179
+ logger.info(f"🎭 Step 2 - Mask generation: {time.time()-t0:.2f}s (hair sum: {np.sum(hair_mask):.0f})")
180
+
181
+ t0 = time.time()
182
+ beard_mask = np.zeros_like(hair_mask)
183
+ logger.info(f"πŸ§” Step 3 - Beard mask (skipped): {time.time()-t0:.2f}s")
184
+
185
+ t0 = time.time()
186
+ result_resized = apply_strong_grey_hair(img_resized, hair_mask, beard_mask)
187
+ logger.info(f"🎨 Step 4 - Color transfer: {time.time()-t0:.2f}s")
188
+
189
+ t0 = time.time()
190
+ final_img = result_resized.resize((ow, oh), Image.LANCZOS)
191
+ logger.info(f"πŸ“€ Step 5 - Resize to original: {time.time()-t0:.2f}s ({ow}x{oh})")
192
+
193
+ processing_time = time.time() - total_start
194
+ logger.info(f"βš™οΈ Core processing time: {processing_time:.2f} seconds")
195
+ return final_img, processing_time
196
 
197
+ # ================================================
198
+ # FASTAPI APP WITH END-TO-END TIMING
199
+ # ================================================
200
+ app = FastAPI(title="Grey Hair API (Accurate Timing)")
201
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
202
 
 
 
203
  @app.on_event("startup")
204
  async def startup():
205
+ t0 = time.time()
206
  loop = asyncio.get_event_loop()
207
  await loop.run_in_executor(executor, load_face_parser)
208
+ logger.info(f"πŸ”₯ Server ready in {time.time()-t0:.2f} seconds")
209
 
210
  @app.post("/age-face")
211
  async def age_face(file: UploadFile = File(...)):
212
+ request_start = time.time()
213
+ logger.info("πŸš€ Request received")
214
+
215
+ # Step A: Read file
216
+ t0 = time.time()
217
+ contents = await file.read()
218
+ read_time = time.time() - t0
219
+ logger.info(f"πŸ“ File read: {read_time:.3f}s ({len(contents)} bytes)")
220
+
221
+ # Step B: Decode image
222
+ t0 = time.time()
223
+ img = Image.open(io.BytesIO(contents)).convert("RGB")
224
+ decode_time = time.time() - t0
225
+ logger.info(f"πŸ–ΌοΈ Image decode: {decode_time:.3f}s ({img.size[0]}x{img.size[1]})")
226
+
227
+ # Step C: Process (core)
228
+ loop = asyncio.get_event_loop()
229
+ result_img, core_time = await loop.run_in_executor(executor, process_face_whitening, img)
230
+
231
+ # Step D: Encode to JPEG
232
+ t0 = time.time()
233
+ buf = io.BytesIO()
234
+ result_img.save(buf, format="JPEG", quality=92)
235
+ encode_time = time.time() - t0
236
+ logger.info(f"πŸ’Ύ JPEG encode: {encode_time:.3f}s")
237
+ buf.seek(0)
238
+
239
+ # Step E: Send response
240
+ total_time = time.time() - request_start
241
+ logger.info(f"πŸ“‘ Total end-to-end time: {total_time:.2f} seconds (core: {core_time:.2f}, overhead: {total_time-core_time:.2f})")
242
+ logger.info("="*50)
243
+
244
+ return StreamingResponse(buf, media_type="image/jpeg")
245
 
246
  if __name__ == "__main__":
247
  import uvicorn