|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
import torchvision.models as models |
|
|
from PIL import Image |
|
|
import onnxruntime as ort |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository hosting both model artifacts.
REPO_ID = "biometric-ai-lab/Face_Recognition"

# Recognition-model checkpoint loaded with torch.load; expected to hold a
# state_dict, optionally nested under a 'model' key (see FaceAnalysis).
RECOG_FILENAME = "pytorch_model.bin"

# YOLOv8-small face-detection model exported to ONNX (640x640 input).
YOLO_FILENAME = "yolov8s-face-lindevs.onnx"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FaceRecognitionModel(nn.Module):
    """Wide-ResNet-101-2 backbone with a 512-d embedding head.

    The forward pass L2-normalizes its output, so cosine similarity
    between two embeddings reduces to a plain dot product.
    """

    def __init__(self):
        super().__init__()

        # Backbone with the ImageNet classification head removed; weights
        # are expected to be loaded from a checkpoint afterwards.
        backbone = models.wide_resnet101_2(weights=None)
        backbone.fc = nn.Identity()
        self.backbone = backbone

        # Project the 2048-d backbone features down to a 512-d embedding.
        self.embed = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, img):
        """Map a batch of images to unit-length 512-d embeddings."""
        embedding = self.embed(self.backbone(img))
        return F.normalize(embedding, p=2, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class YOLOFaceDetector:
    """Single-face detector backed by a YOLOv8 face model in ONNX format."""

    def __init__(self, model_path, conf_threshold=0.5):
        """
        Args:
            model_path: Path to the YOLOv8 face-detection ONNX file.
            conf_threshold: Minimum confidence for a detection to count.
        """
        # onnxruntime falls back to CPU automatically when CUDA is unavailable.
        self.session = ort.InferenceSession(
            model_path,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
        )
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]
        self.conf_threshold = conf_threshold
        self.input_size = 640  # square resolution the ONNX model was exported with

    def detect_extract_face(self, image_pil, expand_ratio=0.0):
        """
        Detect the largest face in an image and return it cropped.

        Args:
            image_pil: RGB PIL Image.
            expand_ratio: Fraction of the larger box side to pad on every
                edge before cropping (0 disables padding).

        Returns:
            PIL Image of the cropped face, or the unmodified input image
            when no face clears the confidence threshold.
        """
        image_np = np.array(image_pil)
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        img_height, img_width = image_bgr.shape[:2]

        img_batch = self._preprocess(image_bgr)
        outputs = self.session.run(self.output_names, {self.input_name: img_batch})
        predictions = outputs[0]

        # YOLOv8 emits (1, C, N); flatten to (N, C) rows of [cx, cy, w, h, conf].
        if len(predictions.shape) == 3:
            predictions = predictions[0].T

        best_face = self._largest_face(predictions, img_width, img_height)

        if best_face:
            x1, y1, x2, y2 = best_face
            if expand_ratio != 0:
                # Pad each edge by a fraction of the larger box side,
                # clamped to the image bounds.
                pad = int(expand_ratio * max(x2 - x1, y2 - y1))
                x1 = max(0, x1 - pad)
                y1 = max(0, y1 - pad)
                x2 = min(img_width, x2 + pad)
                y2 = min(img_height, y2 + pad)
            return image_pil.crop((x1, y1, x2, y2))

        print("⚠️ Warning: No face detected. Using full image.")
        return image_pil

    def _preprocess(self, image_bgr):
        """Resize/normalize a BGR image into a (1, 3, S, S) float32 batch."""
        img_resized = cv2.resize(image_bgr, (self.input_size, self.input_size))
        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
        img_normalized = img_rgb.astype(np.float32) / 255.0
        return np.expand_dims(np.transpose(img_normalized, (2, 0, 1)), axis=0)

    def _largest_face(self, predictions, img_width, img_height):
        """Pick the largest confident detection, vectorized over all rows.

        Args:
            predictions: (N, >=5) array of [cx, cy, w, h, conf] in
                model-input (input_size) coordinates.
            img_width, img_height: Original image dimensions in pixels.

        Returns:
            (x1, y1, x2, y2) ints clipped to the image, or None when no
            detection clears the threshold with a positive clipped area.
        """
        confident = predictions[predictions[:, 4] > self.conf_threshold]
        if confident.shape[0] == 0:
            return None

        # Scale centers/sizes from model-input space back to original pixels.
        sx = img_width / self.input_size
        sy = img_height / self.input_size
        cx = confident[:, 0] * sx
        cy = confident[:, 1] * sy
        w = confident[:, 2] * sx
        h = confident[:, 3] * sy

        # int() truncation matches the original per-box conversion; clip to image.
        x1 = np.maximum(0, (cx - w / 2).astype(int))
        y1 = np.maximum(0, (cy - h / 2).astype(int))
        x2 = np.minimum(img_width, (cx + w / 2).astype(int))
        y2 = np.minimum(img_height, (cy + h / 2).astype(int))

        areas = (x2 - x1) * (y2 - y1)
        best = int(np.argmax(areas))  # first maximum, same as the old loop
        if areas[best] <= 0:
            # Every confident box collapsed to zero area after clipping.
            return None
        return int(x1[best]), int(y1[best]), int(x2[best]), int(y2[best])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FaceAnalysis:
    """End-to-end face pipeline: fetch models, detect a face, embed it,
    and compare embeddings by cosine similarity."""

    def __init__(self, device=None):
        """
        Args:
            device: Torch device string; when None, CUDA is used if
                available, otherwise CPU.
        """
        if device:
            self.device = device
        else:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"🚀 Initializing Face Analysis on {self.device}...")

        # Fetch (or reuse locally cached) model files from the HF Hub.
        try:
            print(f"📥 Checking models from {REPO_ID}...")
            recog_path = hf_hub_download(repo_id=REPO_ID, filename=RECOG_FILENAME)
            yolo_path = hf_hub_download(repo_id=REPO_ID, filename=YOLO_FILENAME)
        except Exception as e:
            raise RuntimeError(f"❌ Failed to download models. Check internet or Repo ID.\nError: {e}")

        self.yolo = YOLOFaceDetector(yolo_path, conf_threshold=0.5)

        self.model = FaceRecognitionModel().to(self.device)
        checkpoint = torch.load(recog_path, map_location=self.device)
        # Some checkpoints nest the weights under a 'model' key.
        state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Standard ImageNet preprocessing for the recognition backbone.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        print("✅ System Ready!")

    def process_image(self, image_source, expand_ratio=0.0):
        """Load an image, crop the detected face, and return its embedding.

        Args:
            image_source: File path, PIL Image, or BGR numpy array.
            expand_ratio: Padding fraction forwarded to the detector.

        Returns:
            (1, 512) L2-normalized embedding tensor on self.device.

        Raises:
            FileNotFoundError: When a string path does not exist.
            ValueError: When the input type is unsupported.
        """
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found: {image_source}")
            img_pil = Image.open(image_source).convert('RGB')
        elif isinstance(image_source, Image.Image):
            img_pil = image_source.convert('RGB')
        elif isinstance(image_source, np.ndarray):
            # OpenCV arrays are BGR; convert before building the PIL image.
            img_pil = Image.fromarray(cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB))
        else:
            raise ValueError("Input must be filepath, PIL Image, or Numpy Array")

        face_crop = self.yolo.detect_extract_face(img_pil, expand_ratio=expand_ratio)
        face_tensor = self.transform(face_crop).unsqueeze(0).to(self.device)

        with torch.no_grad():
            return self.model(face_tensor)

    def compare(self, img1, img2, threshold=0.45, expand_ratio=0.01):
        """Compare two face images.

        Args:
            img1, img2: Anything process_image accepts.
            threshold: Cosine-similarity cutoff for a same-person verdict.
            expand_ratio: Face-crop padding fraction (0.01 matches the demo).

        Returns:
            (similarity, is_same) — the cosine similarity as a float and
            whether it exceeds the threshold.
        """
        emb1 = self.process_image(img1, expand_ratio)
        emb2 = self.process_image(img2, expand_ratio)

        similarity = F.cosine_similarity(emb1, emb2).item()
        return similarity, similarity > threshold