Spaces:
Sleeping
Sleeping
tesalonikahtp committed on
Commit ·
588e92b
1
Parent(s): c602e28
feat: image gen
Browse files
app/util/passport_photo_engine/haar_face_detector.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class HaarFaceDetector:
    """Find the largest frontal face in an image using an OpenCV Haar cascade."""

    def __init__(self):
        """Load the default frontal-face cascade from OpenCV's data directory."""
        cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        self.detector = cv2.CascadeClassifier(cascade_path)

    def detect(self, img_rgb):
        """Return the largest detected face as ``(x, y, w, h)``, or ``None``.

        The RGB input is converted to grayscale before detection; candidates
        smaller than 80x80 pixels are ignored by the cascade.
        """
        gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
        candidates = self.detector.detectMultiScale(gray, 1.1, 5, minSize=(80, 80))
        if len(candidates) == 0:
            return None
        # Pick the candidate with the greatest area (width * height).
        return max(candidates, key=lambda rect: rect[2] * rect[3])
|
app/util/passport_photo_engine/manual_face_extractor.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
class ManualFaceExtractor:
    """Estimate a tilt angle from the lower edge of a foreground mask."""

    def extract_face(self, img_rgb, mask):
        """Return ``{"chin_angle": degrees}`` estimated from ``mask``.

        A straight line is least-squares fitted through the lowest 20% of the
        mask pixels whose value exceeds 30. The slope of that line, converted
        to degrees and clamped to [-30, 30], is reported as the angle.

        Args:
            img_rgb: Not used by this implementation; kept for the interface.
            mask: 2-D array of mask intensities, or ``None``.

        Returns:
            A dict with a single float under ``"chin_angle"``. The angle is
            0.0 when the mask is ``None``, has fewer than 100 pixels above the
            threshold, or when the line fit fails or is not finite.
        """
        if mask is None:
            return {"chin_angle": 0.0}

        y_idxs, x_idxs = np.where(mask > 30)
        if len(y_idxs) < 100:
            return {"chin_angle": 0.0}

        # Indices of the 20% of pixels with the largest row coordinate,
        # i.e. the bottom of the mask. At least 20 points given the check above.
        bottom = np.argsort(y_idxs)[-int(len(y_idxs) * 0.2):]

        try:
            slope, _intercept = np.polyfit(x_idxs[bottom], y_idxs[bottom], 1)
            angle = math.degrees(math.atan(slope))
        except (np.linalg.LinAlgError, ArithmeticError, TypeError, ValueError):
            # Best effort: a failed fit means "no measurable tilt". Unlike a
            # bare ``except``, this lets KeyboardInterrupt/SystemExit through.
            return {"chin_angle": 0.0}

        if not math.isfinite(angle):
            # Never hand NaN/inf to the caller's rotation step.
            return {"chin_angle": 0.0}

        return {"chin_angle": float(np.clip(angle, -30, 30))}
|
app/util/passport_photo_engine/passport_cropper.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
class PassportCropper:
    """Composite, straighten and crop a portrait into a fixed-size photo."""

    def __init__(self, output_size=(600,800), bg_color=(255,255,255)):
        """Remember the output ``(width, height)`` and the background colour.

        ``bg_color`` components are coerced to plain ``int`` values.
        """
        width, height = output_size
        self.out_w = width
        self.out_h = height
        self.bg_color = tuple(map(int, bg_color))
        self.target_aspect = self.out_w / self.out_h

    def composite(self, img_rgb, mask):
        """Blend the masked subject over the solid background colour.

        The image is converted from RGB to BGR and the mask is softened with
        a 5x5 Gaussian blur before being used as the blend weight. Returns a
        ``uint8`` BGR image.
        """
        foreground = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
        soft_mask = cv2.GaussianBlur(mask, (5, 5), 0)
        # Trailing axis lets the single-channel weight broadcast over colours.
        alpha = (soft_mask.astype(float) / 255.0)[:, :, None]
        backdrop = np.full_like(foreground, np.array(self.bg_color, dtype=np.uint8))
        blended = foreground.astype(float) * alpha + backdrop.astype(float) * (1.0 - alpha)
        return blended.astype(np.uint8)

    def rotate_and_expand_face(self, img_bgr, angle_deg, raw_face_box):
        """Rotate the image by ``-angle_deg`` and map an enlarged face box.

        The face box ``(x1, y1, x2, y2)`` is grown by 0.4 of its width on each
        side, 0.6 of its height above and 1.2 of its height below, limited to
        the image bounds. The output canvas is enlarged so the rotated image
        fits; uncovered areas are filled with the background colour.

        Returns:
            ``(rotated_image, (min_x, min_y, max_x, max_y))`` where the box is
            the axis-aligned bounds of the rotated, enlarged face box.
        """
        src_h, src_w = img_bgr.shape[:2]
        left, top, right, bottom = raw_face_box
        face_w = right - left
        face_h = bottom - top

        # Enlarged "head" box, clamped to the source image.
        head_left = max(0, left - int(face_w * 0.4))
        head_top = max(0, top - int(face_h * 0.6))
        head_right = min(src_w - 1, right + int(face_w * 0.4))
        head_bottom = min(src_h - 1, bottom + int(face_h * 1.2))

        # Rotation about the image centre, then shift into the larger canvas.
        rotation = cv2.getRotationMatrix2D((src_w / 2, src_h / 2), -angle_deg, 1.0)
        abs_cos = np.abs(rotation[0, 0])
        abs_sin = np.abs(rotation[0, 1])
        new_w = int((src_h * abs_sin) + (src_w * abs_cos))
        new_h = int((src_h * abs_cos) + (src_w * abs_sin))
        rotation[0, 2] += (new_w / 2) - src_w / 2
        rotation[1, 2] += (new_h / 2) - src_h / 2

        rotated = cv2.warpAffine(img_bgr, rotation, (new_w, new_h), borderValue=self.bg_color)

        # Push the four box corners (homogeneous coordinates) through the
        # same affine transform and take their bounding rectangle.
        corners = np.array([
            [head_left, head_top, 1],
            [head_right, head_top, 1],
            [head_right, head_bottom, 1],
            [head_left, head_bottom, 1],
        ]).T
        mapped = rotation @ corners
        xs = mapped[0, :]
        ys = mapped[1, :]
        bounds = (int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()))
        return rotated, bounds

    def crop_to_ratio(self, img, box):
        """Crop ``box`` out of ``img`` at the target aspect ratio and resize.

        The box is widened or heightened to match the target aspect ratio.
        Parts of the adjusted box that fall outside the image are filled with
        the background colour. The result is resized to the output size.
        """
        bx1, by1, bx2, by2 = box
        box_w = bx2 - bx1
        box_h = by2 - by1

        if box_w / box_h > self.target_aspect:
            # Too wide for the target: add height, split above and below.
            grown_h = int(box_w / self.target_aspect)
            by1 -= (grown_h - box_h) // 2
            by2 = by1 + grown_h
        else:
            # Too tall for the target: add width, split left and right.
            grown_w = int(box_h * self.target_aspect)
            bx1 -= (grown_w - box_w) // 2
            bx2 = bx1 + grown_w

        img_h, img_w = img.shape[:2]
        canvas = np.full((by2 - by1, bx2 - bx1, 3), self.bg_color, dtype=np.uint8)

        # Source window: the part of the box that lies inside the image.
        src_x1 = max(0, bx1)
        src_y1 = max(0, by1)
        src_x2 = min(img_w, bx2)
        src_y2 = min(img_h, by2)
        # Where that window lands on the canvas.
        dst_x = max(0, src_x1 - bx1)
        dst_y = max(0, src_y1 - by1)

        if src_x2 > src_x1 and src_y2 > src_y1:
            copy_h = src_y2 - src_y1
            copy_w = src_x2 - src_x1
            canvas[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = img[src_y1:src_y2, src_x1:src_x2]

        return cv2.resize(canvas, (self.out_w, self.out_h), interpolation=cv2.INTER_AREA)
|
| 57 |
+
|
app/util/passport_photo_engine/segmenter_rmbg.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torchvision.transforms.functional import normalize
|
| 4 |
+
from transformers import AutoModelForImageSegmentation
|
| 5 |
+
|
| 6 |
+
class SegmenterRMBG:
    """Foreground segmentation using the ``briaai/RMBG-1.4`` model."""

    def __init__(self, device=None, model_input_size=(1024,1024)):
        """Load the model onto ``device`` and put it in evaluation mode.

        Args:
            device: Anything accepted by ``torch.device``. When falsy, CUDA is
                used if available, otherwise the CPU.
            model_input_size: Size the image is resized to before inference.
        """
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForImageSegmentation.from_pretrained(
            "briaai/RMBG-1.4", trust_remote_code=True
        ).to(self.device)
        # The attribute is ``model``; the previous ``self.modenl`` spelling
        # raised AttributeError as soon as the class was constructed.
        self.model.eval()
        self.model_input_size = list(model_input_size)

    def _preprocess(self, img_np):
        """Turn an image array into a normalised batch tensor on the device.

        The channel axis is moved first and a batch axis is added, so the
        input is expected to be height x width x 3. Values are resized to
        ``model_input_size``, scaled by 1/255 and shifted by -0.5 per channel.
        """
        t = torch.tensor(img_np, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
        t = F.interpolate(t, size=self.model_input_size, mode="bilinear") / 255.0
        return normalize(t, [0.5] * 3, [1.0] * 3).to(self.device)

    def _postprocess(self, result, orig_size):
        """Convert a raw prediction into a ``uint8`` mask of ``orig_size``.

        Args:
            result: 3-D or 4-D prediction tensor; a 3-D tensor gets a leading
                batch axis added.
            orig_size: ``(height, width)`` of the mask to produce.

        Returns:
            A NumPy ``uint8`` array with values in 0..255.
        """
        H, W = orig_size
        if result.dim() == 3:
            result = result.unsqueeze(0)
        result = F.interpolate(result, size=(H, W), mode="bilinear").squeeze(0).squeeze(0)

        # Min-max normalise; the epsilon avoids 0/0 on a constant prediction.
        r_min, r_max = result.min(), result.max()
        result = (result - r_min) / (r_max - r_min + 1e-8)

        # Gamma correction for hair, then drop very faint residue.
        result = torch.pow(result, 2.5)
        result[result < 0.05] = 0

        # Cast with torch rather than ``np.uint8``: this file does not import
        # numpy, so the old ``astype(np.uint8)`` raised NameError.
        return (result * 255).to(torch.uint8).cpu().numpy()

    def segment(self, img_rgb):
        """Run the model on ``img_rgb`` and return a mask of the same size.

        Gradients are disabled during inference. If the model returns a list
        or tuple, its first element is used, and the first item of that is
        post-processed.
        """
        inp = self._preprocess(img_rgb)
        with torch.no_grad():
            out = self.model(inp)
        if isinstance(out, (list, tuple)):
            out = out[0]
        return self._postprocess(out[0], img_rgb.shape[:2])
|
requirements.txt
CHANGED
|
@@ -17,4 +17,13 @@ pandas
|
|
| 17 |
SQLAlchemy
|
| 18 |
psycopg2-binary
|
| 19 |
|
| 20 |
-
boto3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
SQLAlchemy
|
| 18 |
psycopg2-binary
|
| 19 |
|
| 20 |
+
boto3
|
| 21 |
+
|
| 22 |
+
numpy
|
| 23 |
+
opencv-python-headless
|
| 24 |
+
pillow
|
| 25 |
+
torch
|
| 26 |
+
torchvision
|
| 27 |
+
transformers
|
| 28 |
+
huggingface_hub
|
| 29 |
+
accelerate
|
server.py
CHANGED
|
@@ -3,23 +3,44 @@ import tempfile
|
|
| 3 |
os.environ["PLAYWRIGHT_BROWSERS_PATH"] = "/home/user/.cache/ms-playwright"
|
| 4 |
|
| 5 |
import logging
|
| 6 |
-
from flask import Flask, request, jsonify, send_file
|
| 7 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
import json
|
| 9 |
import requests
|
| 10 |
import uuid
|
| 11 |
import importlib
|
| 12 |
import io
|
|
|
|
|
|
|
| 13 |
|
| 14 |
from app.util.gen_ai_base import GenAIBaseClient
|
| 15 |
from app.util.browser_agent import BrowserAgent
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
from app.util.parameter_utils import init_secret
|
| 18 |
import sys
|
| 19 |
sys.stdout.reconfigure(line_buffering=True)
|
| 20 |
API = "https://api-dev.spun.global"
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
def create_app() -> Flask:
|
| 25 |
load_dotenv()
|
|
@@ -206,6 +227,53 @@ def create_app() -> Flask:
|
|
| 206 |
except Exception as e:
|
| 207 |
print(f"Error in /generate/{visa_type}: {e}")
|
| 208 |
return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
|
| 210 |
@app.route('/', methods=['GET'])
|
| 211 |
def hello_world():
|
|
|
|
| 3 |
os.environ["PLAYWRIGHT_BROWSERS_PATH"] = "/home/user/.cache/ms-playwright"
|
| 4 |
|
| 5 |
import logging
|
|
|
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
+
import io
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
import json
|
| 12 |
import requests
|
| 13 |
import uuid
|
| 14 |
import importlib
|
| 15 |
import io
|
| 16 |
+
from flask import Flask, request, jsonify, send_file
|
| 17 |
+
|
| 18 |
|
| 19 |
from app.util.gen_ai_base import GenAIBaseClient
|
| 20 |
from app.util.browser_agent import BrowserAgent
|
| 21 |
+
from app.util.passport_photo_engine.haar_face_detector import HaarFaceDetector
|
| 22 |
+
from app.util.passport_photo_engine.manual_face_extractor import ManualFaceExtractor
|
| 23 |
+
from app.util.passport_photo_engine.passport_cropper import PassportCropper
|
| 24 |
+
from app.util.passport_photo_engine.segmenter_rmbg import SegmenterRMBG
|
| 25 |
from app.util.parameter_utils import init_secret
|
| 26 |
import sys
|
| 27 |
sys.stdout.reconfigure(line_buffering=True)
|
| 28 |
API = "https://api-dev.spun.global"
|
| 29 |
|
| 30 |
+
# Instantiate the passport-photo helpers once, at import time, so requests
# can reuse them instead of reloading the models on every call.
print("--- Loading Passport AI Models (This happens once) ---")
passport_models = dict(
    segmenter=SegmenterRMBG(),        # heavy model, runs on GPU or CPU
    detector=HaarFaceDetector(),      # fast Haar-cascade detector
    extractor=ManualFaceExtractor(),
)
print("--- Passport Models Ready ---")
|
| 37 |
+
|
| 38 |
+
# Background colours selectable by name in the passport-photo endpoint.
# The tuples are passed to PassportCropper, which blends them into an image
# already converted to BGR, so they are written in blue-green-red order.
PASSPORT_COLORS = dict(
    white=(255, 255, 255),
    id_red=(0, 0, 219),
    id_blue=(219, 0, 0),
    light_blue=(235, 206, 135),
)
|
| 44 |
|
| 45 |
def create_app() -> Flask:
|
| 46 |
load_dotenv()
|
|
|
|
| 227 |
except Exception as e:
|
| 228 |
print(f"Error in /generate/{visa_type}: {e}")
|
| 229 |
return jsonify({"error": str(e)}), 500
|
| 230 |
+
|
| 231 |
+
@app.route("/generate-passport-photo", methods=["POST"])
def generate_passport():
    """Build a passport-style JPEG from a photo URL.

    Expects a JSON body with:
        raw_photo: URL of the source image (required).
        bg_color_name: key of ``PASSPORT_COLORS``; unknown or missing names
            fall back to white.

    Returns:
        The JPEG image on success. A JSON ``{"error": ...}`` with status 400
        when the request is malformed, the download fails or no face is
        found, and status 500 when processing raises.
    """
    # silent=True yields None instead of raising on a missing/invalid body,
    # so every malformed request gets the same JSON error shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or not data.get('raw_photo'):
        return jsonify({"error": "Request body must be JSON with a 'raw_photo' URL"}), 400
    bg_color_name = data.get('bg_color_name', 'white')

    # NOTE(review): raw_photo is a caller-supplied URL fetched server-side;
    # consider restricting it to trusted hosts.
    try:
        # A timeout keeps a stalled download from hanging the worker. The
        # body is read in full below, so streaming is not needed (and an
        # unread streamed response would hold its connection open).
        response = requests.get(data['raw_photo'], timeout=30)
    except requests.RequestException as e:
        print(f"Passport Error: {e}")
        return jsonify({"error": f"Failed to download image: {e}"}), 400
    if response.status_code != 200:
        return jsonify({"error": f"Failed to download image from S3. Status: {response.status_code}"}), 400

    try:
        # Read image
        in_memory_file = io.BytesIO(response.content)
        pil_image = Image.open(in_memory_file).convert("RGB")
        img_rgb = np.array(pil_image)

        # Get models
        seg = passport_models["segmenter"]
        det = passport_models["detector"]
        ext = passport_models["extractor"]

        # 1. Segment
        mask = seg.segment(img_rgb)

        # 2. Detect
        face_rect = det.detect(img_rgb)
        if face_rect is None:
            return jsonify({"error": "No face detected"}), 400
        x, y, w, h = face_rect

        # 3. Angle
        info = ext.extract_face(img_rgb, mask)
        angle = info.get("chin_angle", 0.0)

        # 4. Process
        selected_bg = PASSPORT_COLORS.get(bg_color_name, (255, 255, 255))
        cropper = PassportCropper(output_size=(600, 800), bg_color=selected_bg)

        img_clean = cropper.composite(img_rgb, mask)
        img_rot, rot_box = cropper.rotate_and_expand_face(img_clean, angle, (x, y, x + w, y + h))
        final_passport = cropper.crop_to_ratio(img_rot, rot_box)

        # Return result
        is_success, buffer = cv2.imencode(".jpg", final_passport)
        if not is_success:
            return jsonify({"error": "Failed to encode passport photo"}), 500
        return send_file(io.BytesIO(buffer.tobytes()), mimetype='image/jpeg')

    except Exception as e:
        print(f"Passport Error: {e}")
        return jsonify({"error": str(e)}), 500
|
| 277 |
|
| 278 |
@app.route('/', methods=['GET'])
|
| 279 |
def hello_world():
|