tesalonikahtp committed on
Commit
588e92b
·
1 Parent(s): c602e28

feat: image gen

Browse files
app/util/passport_photo_engine/haar_face_detector.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+
3
+
4
class HaarFaceDetector:
    """Frontal-face detector built on OpenCV's bundled Haar cascade."""

    def __init__(self):
        # The cascade XML ships with the OpenCV Python package.
        cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        self.detector = cv2.CascadeClassifier(cascade_path)

    def detect(self, img_rgb):
        """Return the largest detected face, or None when nothing is found.

        Args:
            img_rgb: Image array in RGB channel order.

        Returns:
            One detection row ``(x, y, w, h)`` as produced by
            ``detectMultiScale`` (the one with the largest ``w * h``), or
            ``None`` if no face was detected.
        """
        grayscale = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
        candidates = self.detector.detectMultiScale(grayscale, 1.1, 5, minSize=(80, 80))
        if len(candidates) == 0:
            return None

        def box_area(rect):
            # rect is (x, y, w, h); area is width times height.
            return rect[2] * rect[3]

        return max(candidates, key=box_area)
app/util/passport_photo_engine/manual_face_extractor.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+
4
class ManualFaceExtractor:
    """Estimate head tilt from the lower edge of a foreground mask."""

    def extract_face(self, img_rgb, mask):
        """Estimate the chin tilt angle (in degrees) from a segmentation mask.

        A straight line is least-squares fitted through the lowest 20% of the
        foreground pixels (mask value > 30); the slope of that line is
        converted to an angle and clamped to [-30, 30].

        Args:
            img_rgb: Source image. Accepted for interface compatibility; it is
                not read by this method.
            mask: 2-D array of mask intensities, or ``None``.

        Returns:
            ``{"chin_angle": <float>}``. The angle is ``0.0`` when the mask is
            missing, has fewer than 100 foreground pixels, or the fit fails or
            yields a non-finite slope.
        """
        if mask is None:
            return {"chin_angle": 0.0}

        y_idxs, x_idxs = np.where(np.asarray(mask) > 30)
        if len(y_idxs) < 100:
            return {"chin_angle": 0.0}

        # Indices of the 20% of foreground pixels with the largest row index.
        lowest = np.argsort(y_idxs)[-int(len(y_idxs) * 0.2):]

        try:
            slope, _intercept = np.polyfit(x_idxs[lowest], y_idxs[lowest], 1)
        except (np.linalg.LinAlgError, ValueError, TypeError):
            # Best-effort: a failed fit means "no tilt", but do not swallow
            # KeyboardInterrupt/SystemExit the way a bare ``except`` would.
            return {"chin_angle": 0.0}

        angle = math.degrees(math.atan(slope))
        if not math.isfinite(angle):
            # A NaN slope would otherwise propagate into the response.
            return {"chin_angle": 0.0}
        return {"chin_angle": float(np.clip(angle, -30, 30))}
app/util/passport_photo_engine/passport_cropper.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
class PassportCropper:
    """Composite, straighten and crop a portrait to a fixed output size.

    Image work is done on OpenCV-style BGR arrays; ``bg_color`` is applied to
    those arrays as given.
    """

    def __init__(self, output_size=(600,800), bg_color=(255,255,255)):
        """Store the output size (width, height), background colour and aspect."""
        width, height = output_size
        self.out_w = width
        self.out_h = height
        self.bg_color = tuple(map(int, bg_color))
        self.target_aspect = width / height

    def composite(self, img_rgb, mask):
        """Blend the RGB image over a solid background using ``mask`` as alpha.

        The mask is lightly blurred (5x5 Gaussian) before being scaled to
        [0, 1]. Returns a uint8 BGR image.
        """
        bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        softened = cv2.GaussianBlur(mask, (5, 5), 0)
        alpha = (softened.astype(float) / 255.0)[:, :, None]
        backdrop = np.full_like(bgr, np.array(self.bg_color, dtype=np.uint8))
        blended = bgr.astype(float) * alpha + backdrop.astype(float) * (1.0 - alpha)
        return blended.astype(np.uint8)

    def rotate_and_expand_face(self, img_bgr, angle_deg, raw_face_box):
        """Rotate the image by ``-angle_deg`` and map an expanded face box into it.

        Args:
            img_bgr: Image to rotate.
            angle_deg: Tilt angle in degrees; the image is rotated by its negation.
            raw_face_box: Face box as ``(x1, y1, x2, y2)``.

        Returns:
            ``(rotated_image, (x1, y1, x2, y2))`` where the box is the
            axis-aligned bounds of the expanded face box after rotation.
        """
        src_h, src_w = img_bgr.shape[:2]
        left, top, right, bottom = raw_face_box

        # Expansion Logic: 0.4 sides, 0.6 top, 1.2 bottom
        face_w = right - left
        face_h = bottom - top
        hx1 = max(0, left - int(face_w * 0.4))
        hy1 = max(0, top - int(face_h * 0.6))
        hx2 = min(src_w - 1, right + int(face_w * 0.4))
        hy2 = min(src_h - 1, bottom + int(face_h * 1.2))

        # Rotation about the image centre, with the canvas enlarged so the
        # rotated image is not clipped.
        M = cv2.getRotationMatrix2D((src_w / 2, src_h / 2), -angle_deg, 1.0)
        cos_a = np.abs(M[0, 0])
        sin_a = np.abs(M[0, 1])
        new_w = int((src_h * sin_a) + (src_w * cos_a))
        new_h = int((src_h * cos_a) + (src_w * sin_a))
        M[0, 2] += (new_w / 2) - src_w / 2
        M[1, 2] += (new_h / 2) - src_h / 2

        rotated = cv2.warpAffine(img_bgr, M, (new_w, new_h), borderValue=self.bg_color)

        # Push the four box corners (homogeneous coordinates) through the
        # same affine transform and take their bounding box.
        corners = np.array([[hx1, hy1, 1], [hx2, hy1, 1], [hx2, hy2, 1], [hx1, hy2, 1]]).T
        mapped = M @ corners
        xs = mapped[0, :]
        ys = mapped[1, :]
        rotated_box = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
        return rotated, rotated_box

    def crop_to_ratio(self, img, box):
        """Grow ``box`` to the target aspect ratio, crop, pad and resize.

        Parts of the grown box that fall outside ``img`` are filled with the
        background colour. Returns an image of size ``(out_w, out_h)``.
        """
        bx1, by1, bx2, by2 = box
        box_w = bx2 - bx1
        box_h = by2 - by1

        if box_w / box_h > self.target_aspect:
            # Too wide: grow vertically around the box centre.
            grown_h = int(box_w / self.target_aspect)
            by1 -= (grown_h - box_h) // 2
            by2 = by1 + grown_h
        else:
            # Too tall: grow horizontally around the box centre.
            grown_w = int(box_h * self.target_aspect)
            bx1 -= (grown_w - box_w) // 2
            bx2 = bx1 + grown_w

        # Canvas Crop
        img_h, img_w = img.shape[:2]
        canvas = np.full((by2 - by1, bx2 - bx1, 3), self.bg_color, dtype=np.uint8)

        # Source region: the part of the box that lies inside the image.
        sx1 = max(0, bx1)
        sy1 = max(0, by1)
        sx2 = min(img_w, bx2)
        sy2 = min(img_h, by2)

        # Destination offset of that region inside the canvas.
        dx1 = max(0, sx1 - bx1)
        dy1 = max(0, sy1 - by1)

        if sx2 > sx1 and sy2 > sy1:
            span_h = sy2 - sy1
            span_w = sx2 - sx1
            canvas[dy1:dy1 + span_h, dx1:dx1 + span_w] = img[sy1:sy2, sx1:sx2]

        return cv2.resize(canvas, (self.out_w, self.out_h), interpolation=cv2.INTER_AREA)
57
+
app/util/passport_photo_engine/segmenter_rmbg.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision.transforms.functional import normalize
4
+ from transformers import AutoModelForImageSegmentation
5
+
6
class SegmenterRMBG:
    """Foreground segmenter wrapping the ``briaai/RMBG-1.4`` model."""

    def __init__(self, device=None, model_input_size=(1024,1024)):
        """Load the model onto ``device`` and put it in eval mode.

        Args:
            device: Anything accepted by ``torch.device``. When falsy, CUDA is
                used if available, otherwise CPU.
            model_input_size: ``(height, width)`` the input is resized to
                before being fed to the model.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            "briaai/RMBG-1.4", trust_remote_code=True
        ).to(self.device)
        # Was ``self.modenl.eval()``, which raised AttributeError on construction.
        self.model.eval()
        self.model_input_size = list(model_input_size)

    def _preprocess(self, img_np):
        """Convert an HxWx3 array to a normalised 1x3xHxW tensor on the device."""
        t = torch.tensor(img_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        t = F.interpolate(t, size=self.model_input_size, mode="bilinear") / 255.0
        # Mean 0.5 / std 1.0 shifts the [0, 1] pixel range to [-0.5, 0.5].
        return normalize(t, [0.5] * 3, [1.0] * 3).to(self.device)

    def _postprocess(self, result, orig_size):
        """Turn a raw model output into a uint8 mask of size ``orig_size``.

        Args:
            result: Tensor with 2 to 4 dimensions; leading singleton
                dimensions are added until it is 4-D for interpolation.
            orig_size: ``(height, width)`` of the mask to return.

        Returns:
            ``numpy.ndarray`` of dtype uint8 and shape ``(height, width)``
            with values in [0, 255].
        """
        H, W = orig_size
        # F.interpolate expects (N, C, H, W).
        while result.dim() < 4:
            result = result.unsqueeze(0)
        result = F.interpolate(result, size=(H, W), mode="bilinear").squeeze(0).squeeze(0)

        # Min-max normalise to [0, 1]; the epsilon guards a constant output.
        r_min, r_max = result.min(), result.max()
        result = (result - r_min) / (r_max - r_min + 1e-8)

        # Gamma correction for hair
        result = torch.pow(result, 2.5)
        result[result < 0.05] = 0

        # Cast in torch: the original used ``np.uint8`` although numpy was
        # never imported in this module, which raised NameError.
        return (result * 255).to(torch.uint8).cpu().numpy()

    def segment(self, img_rgb):
        """Return a uint8 foreground mask with the same height/width as ``img_rgb``."""
        inp = self._preprocess(img_rgb)
        with torch.no_grad():
            out = self.model(inp)
        # NOTE(review): assumes the model returns a nested sequence whose first
        # element holds the mask tensors - confirm against the model's remote code.
        if isinstance(out, (list, tuple)):
            out = out[0]
        return self._postprocess(out[0], img_rgb.shape[:2])
requirements.txt CHANGED
@@ -17,4 +17,13 @@ pandas
17
  SQLAlchemy
18
  psycopg2-binary
19
 
20
- boto3
 
 
 
 
 
 
 
 
 
 
17
  SQLAlchemy
18
  psycopg2-binary
19
 
20
+ boto3
21
+
22
+ numpy
23
+ opencv-python-headless
24
+ pillow
25
+ torch
26
+ torchvision
27
+ transformers
28
+ huggingface_hub
29
+ accelerate
server.py CHANGED
@@ -3,23 +3,44 @@ import tempfile
3
  os.environ["PLAYWRIGHT_BROWSERS_PATH"] = "/home/user/.cache/ms-playwright"
4
 
5
  import logging
6
- from flask import Flask, request, jsonify, send_file
7
  from dotenv import load_dotenv
 
 
 
 
8
  import json
9
  import requests
10
  import uuid
11
  import importlib
12
  import io
 
 
13
 
14
  from app.util.gen_ai_base import GenAIBaseClient
15
  from app.util.browser_agent import BrowserAgent
16
- # from app.util.japan_multientry_visa_letter_generator import JapanMultiEntryVisaLetterGenerator
 
 
 
17
  from app.util.parameter_utils import init_secret
18
  import sys
19
  sys.stdout.reconfigure(line_buffering=True)
20
  API = "https://api-dev.spun.global"
21
 
22
-
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  def create_app() -> Flask:
25
  load_dotenv()
@@ -206,6 +227,53 @@ def create_app() -> Flask:
206
  except Exception as e:
207
  print(f"Error in /generate/{visa_type}: {e}")
208
  return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  @app.route('/', methods=['GET'])
211
  def hello_world():
 
3
  os.environ["PLAYWRIGHT_BROWSERS_PATH"] = "/home/user/.cache/ms-playwright"
4
 
5
  import logging
 
6
  from dotenv import load_dotenv
7
+ import io
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
  import json
12
  import requests
13
  import uuid
14
  import importlib
15
  import io
16
+ from flask import Flask, request, jsonify, send_file
17
+
18
 
19
  from app.util.gen_ai_base import GenAIBaseClient
20
  from app.util.browser_agent import BrowserAgent
21
+ from app.util.passport_photo_engine.haar_face_detector import HaarFaceDetector
22
+ from app.util.passport_photo_engine.manual_face_extractor import ManualFaceExtractor
23
+ from app.util.passport_photo_engine.passport_cropper import PassportCropper
24
+ from app.util.passport_photo_engine.segmenter_rmbg import SegmenterRMBG
25
  from app.util.parameter_utils import init_secret
26
  import sys
27
  sys.stdout.reconfigure(line_buffering=True)
28
  API = "https://api-dev.spun.global"
29
 
30
# The passport pipeline components are built once at import time and shared
# by every request.
print("--- Loading Passport AI Models (This happens once) ---")
passport_models = dict(
    segmenter=SegmenterRMBG(),       # Heavy model (GPU/CPU)
    detector=HaarFaceDetector(),     # Fast model
    extractor=ManualFaceExtractor(),
)
print("--- Passport Models Ready ---")
37
+
38
# Named background colours for generated passport photos.
# Tuples are in OpenCV's BGR channel order (e.g. "id_red" is (0, 0, 219));
# they are handed to PassportCropper as ``bg_color`` unchanged.
PASSPORT_COLORS = {
    "white":      (255, 255, 255),
    "id_red":     (0, 0, 219),
    "id_blue":    (219, 0, 0),
    "light_blue": (235, 206, 135),
}
44
 
45
  def create_app() -> Flask:
46
  load_dotenv()
 
227
  except Exception as e:
228
  print(f"Error in /generate/{visa_type}: {e}")
229
  return jsonify({"error": str(e)}), 500
230
+
231
@app.route("/generate-passport-photo", methods=["POST"])
def generate_passport():
    """Download a photo and return it as a passport-style JPEG.

    Expects a JSON body with:
        raw_photo: URL of the source image (required).
        bg_color_name: Key of ``PASSPORT_COLORS``; unknown or missing names
            fall back to white.

    Returns:
        The encoded JPEG on success; otherwise a JSON ``{"error": ...}`` body
        with status 400 (bad request, download failure, no face detected) or
        500 (processing failure).
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data.get('raw_photo'):
        return jsonify({"error": "Request body must be JSON with a 'raw_photo' URL"}), 400
    bg_color_name = data.get('bg_color_name', 'white')

    # NOTE(review): 'raw_photo' is a caller-supplied URL fetched server-side.
    # Consider restricting it to trusted hosts to avoid SSRF.
    try:
        # A timeout keeps a stalled download from tying up the worker forever.
        response = requests.get(data['raw_photo'], timeout=30)
    except requests.RequestException as e:
        print(f"Passport Error: {e}")
        return jsonify({"error": f"Failed to download image: {e}"}), 400
    if response.status_code != 200:
        return jsonify({"error": f"Failed to download image from S3. Status: {response.status_code}"}), 400

    try:
        # Read image
        pil_image = Image.open(io.BytesIO(response.content)).convert("RGB")
        img_rgb = np.array(pil_image)

        # Get models
        seg = passport_models["segmenter"]
        det = passport_models["detector"]
        ext = passport_models["extractor"]

        # 1. Detect first: it is cheap, so a photo without a face is rejected
        #    before the heavy segmentation model runs.
        face_rect = det.detect(img_rgb)
        if face_rect is None:
            return jsonify({"error": "No face detected"}), 400
        x, y, w, h = (int(v) for v in face_rect)

        # 2. Segment
        mask = seg.segment(img_rgb)

        # 3. Angle
        info = ext.extract_face(img_rgb, mask)
        angle = info.get("chin_angle", 0.0)

        # 4. Process
        selected_bg = PASSPORT_COLORS.get(bg_color_name, (255, 255, 255))
        cropper = PassportCropper(output_size=(600, 800), bg_color=selected_bg)

        img_clean = cropper.composite(img_rgb, mask)
        img_rot, rot_box = cropper.rotate_and_expand_face(img_clean, angle, (x, y, x + w, y + h))
        final_passport = cropper.crop_to_ratio(img_rot, rot_box)

        # Return result
        is_success, buffer = cv2.imencode(".jpg", final_passport)
        if not is_success:
            return jsonify({"error": "Failed to encode passport photo"}), 500
        return send_file(io.BytesIO(buffer.tobytes()), mimetype='image/jpeg')

    except Exception as e:
        # Top-level boundary for the request: report the failure as JSON.
        print(f"Passport Error: {e}")
        return jsonify({"error": str(e)}), 500
277
 
278
  @app.route('/', methods=['GET'])
279
  def hello_world():