Upload full package: YOLO + ArcFace + Scripts
Browse files- README.md +114 -3
- config.json +8 -0
- finetune.py +260 -0
- inference.py +210 -0
- pytorch_model.bin +3 -0
- yolov8s-face-lindevs.onnx +3 -0
- yolov8s-face-lindevs.onnx:Zone.Identifier +0 -0
README.md
CHANGED
|
@@ -1,3 +1,114 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
readme_content = """---
|
| 4 |
+
tags:
|
| 5 |
+
- face-recognition
|
| 6 |
+
- yolo
|
| 7 |
+
- pytorch
|
| 8 |
+
- computer-vision
|
| 9 |
+
- arcface
|
| 10 |
+
- metric-learning
|
| 11 |
+
- biometrics
|
| 12 |
+
- 100m-parameters
|
| 13 |
+
library_name: generic
|
| 14 |
+
license: mit
|
| 15 |
+
pipeline_tag: image-feature-extraction
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
# 🧠 Face Recognition System (ArcFace + YOLOv8)
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+

|
| 22 |
+

|
| 23 |
+

|
| 24 |
+
|
| 25 |
+
## 📖 Overview
|
| 26 |
+
|
| 27 |
+
This repository hosts a production-ready **Face Recognition Pipeline** designed for high-accuracy biometric identification. Unlike standard recognizers, this system integrates **YOLOv8** for robust face detection and alignment before feature extraction.
|
| 28 |
+
|
| 29 |
+
The core recognition model is built upon a **Wide ResNet-101-2** backbone, trained with a hybrid loss function (**ArcFace + Center Loss**) to generate highly discriminative 512-dimensional embeddings.
|
| 30 |
+
|
| 31 |
+
### 🌟 Key Features
|
| 32 |
+
- **Robust Detection**: Uses **YOLOv8 (ONNX)** to detect faces even in challenging lighting or angles.
|
| 33 |
+
- **High Accuracy**: Achieves **90.5%** accuracy on the LFW (Labeled Faces in the Wild) dataset and 90% on Validation.
|
| 34 |
+
- **Discriminative Embeddings**: 512-dim vectors optimized for Cosine Similarity.
|
| 35 |
+
- **Easy-to-Use API**: Includes a wrapper (`inference.py`) for 3-line code implementation.
|
| 36 |
+
- **Fine-tuning Ready**: Includes scripts to retrain the model on your custom dataset.
|
| 37 |
+
|
| 38 |
+
---
|
| 39 |
+
|
| 40 |
+
## 🛠️ Installation
|
| 41 |
+
|
| 42 |
+
To run the pipeline, you need to install the necessary dependencies. We recommend using a virtual environment.
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # For CUDA support
|
| 46 |
+
pip install opencv-python onnxruntime-gpu huggingface_hub pillow tqdm numpy
|
| 47 |
+
```
|
| 48 |
+
## Step 1: Download the Wrapper
|
| 49 |
+
- **Download our helper script inference.py which handles model downloading and YOLO detection automatically.**
|
| 50 |
+
```bash
|
| 51 |
+
wget https://huggingface.co/biometric-ai-lab/Face_Recognition/resolve/main/inference.py
|
| 52 |
+
```
|
| 53 |
+
---
|
| 54 |
+
## Step 2: Create & Run Python Script
|
| 55 |
+
- **Create a new file named run_demo.py.**
|
| 56 |
+
- **Copy and paste the code below into it.**
|
| 57 |
+
- **Make sure you have 2 images to test (e.g., face1.jpg and face2.jpg).**
|
| 58 |
+
```python
|
| 59 |
+
# File: run_demo.py
|
| 60 |
+
from inference import FaceAnalysis
|
| 61 |
+
|
| 62 |
+
# 1. Initialize the AI (Downloads models automatically on first run)
|
| 63 |
+
print("⏳ Initializing models...")
|
| 64 |
+
app = FaceAnalysis()
|
| 65 |
+
|
| 66 |
+
# 2. Define your images
|
| 67 |
+
img1_path = "face1.jpg" # <--- Change this to your image path
|
| 68 |
+
img2_path = "face2.jpg" # <--- Change this to your image path
|
| 69 |
+
|
| 70 |
+
# 3. Run Comparison
|
| 71 |
+
print(f"🔍 Comparing {img1_path} vs {img2_path}...")
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
# Get similarity score and boolean result
|
| 75 |
+
similarity, is_same = app.compare(img1_path, img2_path)
|
| 76 |
+
|
| 77 |
+
print("-" * 30)
|
| 78 |
+
print(f"🔹 Similarity Score: {similarity:.4f}")
|
| 79 |
+
print("-" * 30)
|
| 80 |
+
|
| 81 |
+
if is_same:
|
| 82 |
+
print("✅ RESULT: SAME PERSON")
|
| 83 |
+
else:
|
| 84 |
+
print("❌ RESULT: DIFFERENT PERSON")
|
| 85 |
+
|
| 86 |
+
except Exception as e:
|
| 87 |
+
print(f"Error: {e}")
|
| 88 |
+
print("Tip: Make sure the image paths are correct!")
|
| 89 |
+
```
|
| 90 |
+
---
|
| 91 |
+
## 🎓 Training Guide
|
| 92 |
+
Option: Full Training (Advanced): Use finetune.py to train the model from scratch (ImageNet weights) on a large dataset.
|
| 93 |
+
**Step 1: Prepare Dataset**
|
| 94 |
+
- **Organize images in ImageFolder format**
|
| 95 |
+
```bash
|
| 96 |
+
dataset/
|
| 97 |
+
├── person_1/
|
| 98 |
+
│ ├── img1.jpg
|
| 99 |
+
│ └── ...
|
| 100 |
+
└── person_2/
|
| 101 |
+
└── img1.jpg
|
| 102 |
+
```
|
| 103 |
+
**Step 2: Run Training**
|
| 104 |
+
```bash
|
| 105 |
+
python finetune.py \\
|
| 106 |
+
--data_dir ./dataset \\
|
| 107 |
+
--output_dir ./checkpoints \\
|
| 108 |
+
--epochs 50 \\
|
| 109 |
+
--batch_size 64 \\
|
| 110 |
+
--lr_backbone 8e-6 \\
|
| 111 |
+
--lr_head 8e-5
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
|
config.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "face_recognition",
|
| 3 |
+
"backbone": "wide_resnet101_2",
|
| 4 |
+
"embedding_size": 512,
|
| 5 |
+
"num_classes": 100,
|
| 6 |
+
"test_accuracy": 90.5,
|
| 7 |
+
"test_dataset": "LFW"
|
| 8 |
+
}
|
finetune.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torchvision import transforms, datasets
|
| 6 |
+
import torchvision.models as models
|
| 7 |
+
from torch.utils.data import DataLoader, random_split
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ==========================================
|
| 14 |
+
# 1. MODEL ARCHITECTURE
|
| 15 |
+
# ==========================================
|
| 16 |
+
class FaceRecognitionModel(nn.Module):
    """Wide ResNet-101-2 backbone with a projection head producing
    L2-normalized 512-d face embeddings."""

    def __init__(self):
        super().__init__()
        # ImageNet-pretrained backbone; the classifier layer is replaced by
        # an identity so the forward pass yields raw 2048-d pooled features.
        print("🏗️ Loading Backbone: Wide ResNet-101-2...")
        self.backbone = models.wide_resnet101_2(weights='IMAGENET1K_V2')
        self.backbone.fc = nn.Identity()

        # Projection head: 2048-d backbone features -> 512-d embedding space.
        self.embed = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, img):
        """Return unit-norm 512-d embeddings for a batch of images."""
        pooled = self.backbone(img)
        projected = self.embed(pooled)
        # Project onto the unit hypersphere so downstream cosine similarity
        # reduces to a plain dot product.
        return F.normalize(projected, p=2, dim=1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ==========================================
|
| 39 |
+
# 2. LOSS FUNCTIONS
|
| 40 |
+
# ==========================================
|
| 41 |
+
class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace) logit layer.

    Holds one learnable prototype vector per class and produces scaled
    class logits in which the ground-truth class logit has an angular
    margin added before scaling.
    """

    def __init__(self, num_classes, embedding_size=512, margin=0.5, scale=64):
        super().__init__()
        self.margin = margin
        self.scale = scale
        # One prototype per identity, Xavier-initialized.
        self.weight = nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        """Return margin-adjusted logits of shape (batch, num_classes)."""
        prototypes = F.normalize(self.weight, dim=1)
        feats = F.normalize(embeddings, dim=1)

        # Cosine similarity of every sample against every class prototype,
        # clamped away from ±1 so acos stays numerically stable.
        cos_sim = torch.matmul(feats, prototypes.t())
        cos_sim = cos_sim.clamp(-1 + 1e-7, 1 - 1e-7)

        # cos(theta + m): the additive angular margin, applied to all entries
        # here but only selected for the target class below.
        margin_logits = torch.cos(torch.acos(cos_sim) + self.margin)

        # One-hot mask selecting each sample's ground-truth class.
        target_mask = torch.zeros_like(cos_sim)
        target_mask.scatter_(1, labels.view(-1, 1), 1.0)

        # Target entries get the margin-adjusted value; all others keep
        # the plain cosine. Finally scale to sharpen the softmax.
        blended = margin_logits * target_mask + cos_sim * (1 - target_mask)
        return blended * self.scale
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CenterLoss(nn.Module):
    """Cosine-distance center loss: pulls each embedding toward a learnable
    per-class center on the unit hypersphere."""

    def __init__(self, num_classes, embedding_size=512):
        super().__init__()
        # One learnable center per class; re-normalized on every forward pass.
        self.centers = nn.Parameter(torch.randn(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings, labels):
        """Return mean (1 - cosine similarity) between each embedding and its
        class center.

        NOTE(review): the row-wise dot product equals cosine similarity only
        if `embeddings` are already unit-norm — FaceRecognitionModel in this
        file produces normalized outputs, so that holds for this pipeline.
        """
        unit_centers = F.normalize(self.centers, p=2, dim=1)
        matched_centers = unit_centers[labels]
        cos = torch.sum(embeddings * matched_centers, dim=1)
        return torch.mean(1.0 - cos)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# ==========================================
|
| 82 |
+
# 3. DATA LOADER
|
| 83 |
+
# ==========================================
|
| 84 |
+
def get_dataloader(data_dir, batch_size=64, num_workers=4, split_ratio=0.9):
    """Build train/val DataLoaders from an ImageFolder-style directory.

    Args:
        data_dir: Root folder laid out as ``data_dir/<class_name>/<image>``.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per DataLoader.
        split_ratio: Fraction of images assigned to the training split.

    Returns:
        (train_loader, val_loader, num_classes)
    """
    print(f"📂 Loading Data from: {data_dir}")

    # Strong Augmentation for Training
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.25, hue=0.08),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=0, translate=(0.08, 0.08), scale=(0.92, 1.08)),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    ])

    # Standard Transform for Validation
    transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # BUGFIX: the previous implementation created ONE ImageFolder and then set
    # `val_set.dataset.transform = transform_val`. random_split subsets share
    # the same underlying dataset object, so that assignment silently replaced
    # the training transform too, disabling every augmentation above. Using
    # two ImageFolder instances over the same directory lets each split carry
    # its own transform.
    train_base = datasets.ImageFolder(root=data_dir, transform=transform_train)
    val_base = datasets.ImageFolder(root=data_dir, transform=transform_val)
    num_classes = len(train_base.classes)

    # Split Train/Val. Identical generator seeds guarantee both splits use the
    # exact same index partition, so no image appears in both loaders.
    train_size = int(split_ratio * len(train_base))
    val_size = len(train_base) - train_size
    train_set, _ = random_split(train_base, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    _, val_set = random_split(val_base, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    print(f"   ✅ Classes: {num_classes}")
    print(f"   ✅ Train Images: {len(train_set)}")
    print(f"   ✅ Val Images: {len(val_set)}")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, num_classes
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ==========================================
|
| 132 |
+
# 4. TRAINING ENGINE
|
| 133 |
+
# ==========================================
|
| 134 |
+
def evaluate(model, arcface, val_loader, criterion, device):
    """Run one validation pass.

    Returns:
        (mean batch loss, top-1 accuracy in percent) over `val_loader`.
    """
    model.eval()
    arcface.eval()
    running_loss = 0
    n_correct = 0
    n_seen = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(val_loader, desc=" 🧪 Evaluating"):
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            logits = arcface(model(batch_imgs), batch_labels)
            running_loss += criterion(logits, batch_labels).item()

            # Top-1 prediction per sample.
            preds = torch.max(logits.data, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (preds == batch_labels).sum().item()

    return running_loss / len(val_loader), 100 * n_correct / n_seen
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def main(args):
    """Train the face-recognition model with ArcFace + Center Loss.

    Expects `args` with: data_dir, batch_size, num_workers, resume,
    lr_backbone, lr_head, lambda_center, epochs, output_dir.
    Saves `last_checkpoint.bin` every epoch and `pytorch_model.bin`
    whenever validation accuracy improves.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🚀 Device: {device}")

    # Data
    train_loader, val_loader, num_classes = get_dataloader(args.data_dir, args.batch_size, args.num_workers)

    # Models: backbone+head, ArcFace classifier head, and the auxiliary
    # center-loss module (both loss modules hold learnable parameters).
    model = FaceRecognitionModel().to(device)
    arcface = ArcFaceLoss(num_classes=num_classes).to(device)
    center_loss = CenterLoss(num_classes=num_classes).to(device)

    # Load Checkpoint (Resume).
    # NOTE(review): optimizer state is neither saved nor restored, so a
    # resumed run restarts Adam's moment estimates from scratch — confirm
    # this is intended before long multi-stage trainings.
    start_epoch = 0
    if args.resume and os.path.exists(args.resume):
        print(f"🔄 Resuming from {args.resume}...")
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        arcface.load_state_dict(checkpoint['arcface_state_dict'])
        if 'center_loss_state_dict' in checkpoint:
            center_loss.load_state_dict(checkpoint['center_loss_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)

    # Optimizer: discriminative learning rates — small for the pretrained
    # backbone, larger for the fresh head/ArcFace weights, fixed 1e-4 for
    # the center-loss centers.
    optimizer = torch.optim.Adam([
        {'params': model.backbone.parameters(), 'lr': args.lr_backbone},
        {'params': model.embed.parameters(), 'lr': args.lr_head},
        {'params': arcface.parameters(), 'lr': args.lr_head},
        {'params': center_loss.parameters(), 'lr': 1e-4}
    ], weight_decay=1e-3)

    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0

    # Training Loop
    print("\n🔥 START TRAINING...")
    for epoch in range(start_epoch, args.epochs):
        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}")
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            # Forward
            embeddings = model(imgs)
            logits = arcface(embeddings, labels)

            # Loss: cross-entropy over ArcFace logits plus the center loss
            # weighted by lambda_center.
            loss_ce = criterion(logits, labels)
            loss_center = center_loss(embeddings, labels)
            loss = loss_ce + (args.lambda_center * loss_center)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'CE': f"{loss_ce.item():.4f}"})

        # Checkpoint payload: all three parameter sets plus metadata needed
        # to rebuild the ArcFace head on resume.
        save_dict = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'arcface_state_dict': arcface.state_dict(),
            'center_loss_state_dict': center_loss.state_dict(),
            'num_classes': num_classes
        }

        # Save Last (overwritten every epoch)
        torch.save(save_dict, os.path.join(args.output_dir, "last_checkpoint.bin"))

        # Evaluate & Save Best
        val_loss, val_acc = evaluate(model, arcface, val_loader, criterion, device)
        print(f"   🏆 Epoch {epoch + 1} | Val Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            print(f"   💾 Saving New Best Model (Acc: {best_acc:.2f}%)")
            torch.save(save_dict, os.path.join(args.output_dir, "pytorch_model.bin"))
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
if __name__ == "__main__":
    # CLI entry point: parse hyperparameters and launch training.
    parser = argparse.ArgumentParser(description="Train Face Recognition Model (ArcFace + CenterLoss)")

    # Required
    parser.add_argument('--data_dir', type=str, required=True, help="Path to ImageFolder dataset")

    # Optional
    parser.add_argument('--output_dir', type=str, default=".", help="Where to save .bin files")
    parser.add_argument('--resume', type=str, default=None, help="Path to checkpoint to resume")
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)

    # Hyperparameters: low backbone LR preserves pretrained features;
    # lambda_center weights the auxiliary center loss.
    parser.add_argument('--lr_backbone', type=float, default=8e-6)
    parser.add_argument('--lr_head', type=float, default=8e-5)
    parser.add_argument('--lambda_center', type=float, default=0.18)

    args = parser.parse_args()

    # Ensure the checkpoint directory exists before training starts.
    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
|
inference.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
import torchvision.models as models
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import onnxruntime as ort
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
| 12 |
+
|
| 13 |
+
# ==========================================
|
| 14 |
+
# CẤU HÌNH REPO
|
| 15 |
+
# ==========================================
|
| 16 |
+
REPO_ID = "biometric-ai-lab/Face_Recognition"
|
| 17 |
+
RECOG_FILENAME = "pytorch_model.bin"
|
| 18 |
+
YOLO_FILENAME = "yolov8s-face-lindevs.onnx"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# ==========================================
|
| 22 |
+
# 1. MODEL ARCHITECTURE (Giống hệt code bạn)
|
| 23 |
+
# ==========================================
|
| 24 |
+
class FaceRecognitionModel(nn.Module):
    """Inference copy of the recognition network: Wide ResNet-101-2 backbone
    plus a 512-d projection head. Must mirror the training architecture in
    finetune.py so checkpoints load cleanly."""

    def __init__(self):
        super().__init__()
        # weights=None: the trained checkpoint supplies all parameters, so
        # downloading ImageNet weights here would be wasted work.
        self.backbone = models.wide_resnet101_2(weights=None)
        self.backbone.fc = nn.Identity()
        self.embed = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, img):
        """Return unit-norm 512-d embeddings for a batch of images."""
        backbone_out = self.backbone(img)
        head_out = self.embed(backbone_out)
        return F.normalize(head_out, p=2, dim=1)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ==========================================
|
| 43 |
+
# 2. YOLO DETECTOR (Logic chuẩn của bạn)
|
| 44 |
+
# ==========================================
|
| 45 |
+
class YOLOFaceDetector:
    """Single-face detector backed by a YOLOv8-face ONNX model.

    Runs the ONNX session on a 640x640 letterbox-free resize and returns the
    largest detected face crop from the original image.
    """

    def __init__(self, model_path, conf_threshold=0.5):
        # Prefer CUDA when the onnxruntime-gpu build is installed; ORT falls
        # back to CPU automatically otherwise.
        self.session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]
        self.conf_threshold = conf_threshold
        # Fixed square network input size (pixels).
        self.input_size = 640

    def detect_extract_face(self, image_pil, expand_ratio=0.0):
        """
        Input: PIL Image
        Output: PIL Image (Cropped Face)

        If no face clears `conf_threshold`, the full input image is returned
        unchanged (with a console warning). `expand_ratio` pads the box by
        that fraction of its larger side, clamped to image bounds.
        """
        # Convert PIL -> OpenCV (BGR) to match the original pipeline's logic.
        image_np = np.array(image_pil)
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        img_height, img_width = image_bgr.shape[:2]

        # Preprocess (Resize -> RGB -> Norm -> Transpose).
        # NOTE: plain resize (no aspect-preserving letterbox), so boxes are
        # rescaled per-axis below.
        img_resized = cv2.resize(image_bgr, (self.input_size, self.input_size))
        # YOLO models are typically trained on RGB input.
        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
        img_normalized = img_rgb.astype(np.float32) / 255.0
        img_transposed = np.transpose(img_normalized, (2, 0, 1))
        img_batch = np.expand_dims(img_transposed, axis=0)

        # Inference
        outputs = self.session.run(self.output_names, {self.input_name: img_batch})
        predictions = outputs[0]

        # YOLOv8 exports (1, attrs, boxes); transpose to (boxes, attrs).
        if len(predictions.shape) == 3:
            predictions = predictions[0].T

        best_face = None
        max_area = 0

        # Post-process: each row is [x_center, y_center, w, h, conf].
        for pred in predictions:
            conf = pred[4]
            if conf > self.conf_threshold:
                x_center, y_center, w, h = pred[:4]

                # Scale box back to original image coordinates (per axis,
                # since the resize above does not preserve aspect ratio).
                x_center = x_center * img_width / self.input_size
                y_center = y_center * img_height / self.input_size
                w = w * img_width / self.input_size
                h = h * img_height / self.input_size

                x1 = int(x_center - w / 2)
                y1 = int(y_center - h / 2)
                x2 = int(x_center + w / 2)
                y2 = int(y_center + h / 2)

                # Clamp to image bounds.
                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(img_width, x2)
                y2 = min(img_height, y2)

                area = (x2 - x1) * (y2 - y1)

                # Keep only the largest face (no NMS needed for one box).
                if area > max_area:
                    max_area = area
                    best_face = (x1, y1, x2, y2)

        # Crop the winning box.
        if best_face:
            x1, y1, x2, y2 = best_face

            # Apply expand_ratio padding (if requested).
            if expand_ratio != 0:
                w_box = x2 - x1
                h_box = y2 - y1
                pad = int(expand_ratio * max(w_box, h_box))
                x1 = max(0, x1 - pad)
                y1 = max(0, y1 - pad)
                x2 = min(img_width, x2 + pad)
                y2 = min(img_height, y2 + pad)

            # Crop from the original PIL image to preserve full quality.
            return image_pil.crop((x1, y1, x2, y2))

        print("⚠️ Warning: No face detected. Using full image.")
        return image_pil
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ==========================================
|
| 132 |
+
# 3. FACE ANALYSIS WRAPPER
|
| 133 |
+
# ==========================================
|
| 134 |
+
class FaceAnalysis:
    """End-to-end face comparison pipeline.

    Downloads the YOLO detector and recognition checkpoint from the Hub,
    crops the largest face from each input, embeds it, and compares
    embeddings by cosine similarity.
    """

    def __init__(self, device=None):
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🚀 Initializing Face Analysis on {self.device}...")

        # 1. Fetch model files from the Hub (cached locally after first run).
        try:
            print(f"📥 Checking models from {REPO_ID}...")
            recog_path = hf_hub_download(repo_id=REPO_ID, filename=RECOG_FILENAME)
            yolo_path = hf_hub_download(repo_id=REPO_ID, filename=YOLO_FILENAME)
        except Exception as e:
            raise RuntimeError(f"❌ Failed to download models. Check internet or Repo ID.\nError: {e}")

        # 2. Init YOLO face detector (ONNX Runtime).
        self.yolo = YOLOFaceDetector(yolo_path, conf_threshold=0.5)

        # 3. Init recognition network.
        self.model = FaceRecognitionModel().to(self.device)

        # Load weights. BUGFIX: the training script (finetune.py) saves the
        # backbone under 'model_state_dict', but this loader previously only
        # checked 'model' and then fell through to loading the WHOLE
        # checkpoint dict as a state_dict, which fails on training
        # checkpoints. Accept both keys, with a bare state_dict fallback.
        checkpoint = torch.load(recog_path, map_location=self.device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        elif isinstance(checkpoint, dict) and 'model' in checkpoint:
            self.model.load_state_dict(checkpoint['model'])
        else:
            # Fallback: the file contains a plain state_dict.
            self.model.load_state_dict(checkpoint)

        self.model.eval()

        # 4. Preprocessing — identical to the validation transform used at
        # training time, so embeddings match training statistics.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        print("✅ System Ready!")

    def process_image(self, image_source, expand_ratio=0.0):
        """Detect, crop, and embed a single face.

        Args:
            image_source: file path, PIL Image, or BGR numpy array.
            expand_ratio: extra padding around the detected face box.

        Returns:
            (1, 512) L2-normalized embedding tensor on `self.device`.
        """
        # Normalize the input to an RGB PIL image.
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found: {image_source}")
            img_pil = Image.open(image_source).convert('RGB')
        elif isinstance(image_source, Image.Image):
            img_pil = image_source.convert('RGB')
        elif isinstance(image_source, np.ndarray):
            # numpy inputs are assumed BGR (OpenCV convention).
            img_pil = Image.fromarray(cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB))
        else:
            raise ValueError("Input must be filepath, PIL Image, or Numpy Array")

        # 1. YOLO detect & crop (falls back to the full image if no face).
        face_crop = self.yolo.detect_extract_face(img_pil, expand_ratio=expand_ratio)

        # 2. Transform & embed.
        img_tensor = self.transform(face_crop).unsqueeze(0).to(self.device)

        with torch.no_grad():
            embedding = self.model(img_tensor)

        return embedding

    def compare(self, img1, img2, threshold=0.45, expand_ratio=0.01):
        """Compare two face images.

        Args:
            img1, img2: any input accepted by `process_image`.
            threshold: cosine-similarity decision boundary.
            expand_ratio: face-box padding passed to detection.

        Returns:
            (similarity, is_same): cosine similarity in [-1, 1] and whether
            it exceeds `threshold`.
        """
        emb1 = self.process_image(img1, expand_ratio)
        emb2 = self.process_image(img2, expand_ratio)

        # Cosine similarity between the two unit-norm embeddings.
        similarity = F.cosine_similarity(emb1, emb2).item()
        is_same = similarity > threshold

        return similarity, is_same
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64d4b0af1cbe947a0bcf2996d362cf973bd30f277f55e67a748a409fd733385a
|
| 3 |
+
size 529070510
|
yolov8s-face-lindevs.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0a6d19f2f68d7f0cc8104ab5c9eaa54b63e298f91dcfefd4be897f94a1561d02
|
| 3 |
+
size 44731626
|
yolov8s-face-lindevs.onnx:Zone.Identifier
ADDED
|
Binary file (25 Bytes). View file
|
|
|