le312113 committed on
Commit
83d5d1c
·
verified ·
1 Parent(s): e83a2ae

Upload full package: YOLO + ArcFace + Scripts

README.md CHANGED
@@ -1,3 +1,114 @@
- ---
- license: apache-2.0
- ---
---
tags:
- face-recognition
- yolo
- pytorch
- computer-vision
- arcface
- metric-learning
- biometrics
- 100m-parameters
library_name: generic
license: mit
pipeline_tag: image-feature-extraction
---
# 🧠 Face Recognition System (ArcFace + YOLOv8)

![Python](https://img.shields.io/badge/Python-3.8%2B-blue)
![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-orange)
![Status](https://img.shields.io/badge/Status-Stable-green)
![License](https://img.shields.io/badge/License-MIT-yellow)

## 📖 Overview

This repository hosts a production-ready **Face Recognition Pipeline** designed for high-accuracy biometric identification. Unlike standard recognizers, this system integrates **YOLOv8** for robust face detection and alignment before feature extraction.

The core recognition model is built upon a **Wide ResNet-101-2** backbone, trained with a hybrid loss function (**ArcFace + Center Loss**) to generate highly discriminative 512-dimensional embeddings.
### 🌟 Key Features
- **Robust Detection**: Uses **YOLOv8 (ONNX)** to detect faces even under challenging lighting or angles.
- **High Accuracy**: Achieves **90.5%** accuracy on the LFW (Labeled Faces in the Wild) benchmark and 90% on the held-out validation set.
- **Discriminative Embeddings**: 512-dim vectors optimized for cosine similarity.
- **Easy-to-Use API**: Includes a wrapper (`inference.py`) for a 3-line integration.
- **Fine-tuning Ready**: Includes scripts to retrain the model on your custom dataset.

---
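Identity decisions in this pipeline reduce to a cosine-similarity check between two embeddings. A minimal, dependency-free sketch of the metric (toy 2-D vectors stand in for the real 512-dim embeddings; the 0.45 threshold matches `compare()` in `inference.py`):

```python
import math

def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|); for the model's L2-normalized embeddings
    # the denominator is 1, so this reduces to a plain dot product
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    return dot / (norm_a * norm_b)

same = cosine_similarity([0.6, 0.8], [0.6, 0.8])       # ~1.0, above threshold
different = cosine_similarity([1.0, 0.0], [0.0, 1.0])  # 0.0, below threshold
```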
## 🛠️ Installation

To run the pipeline, install the necessary dependencies. We recommend using a virtual environment.

```bash
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118  # For CUDA support
pip install opencv-python onnxruntime-gpu huggingface_hub pillow tqdm numpy
```
## Step 1: Download the Wrapper

Download the helper script `inference.py`, which handles model downloading and YOLO detection automatically.

```bash
wget https://huggingface.co/biometric-ai-lab/Face_Recognition/resolve/main/inference.py
```

---
## Step 2: Create & Run Python Script

- Create a new file named `run_demo.py`.
- Copy and paste the code below into it.
- Make sure you have two images to test (e.g., `face1.jpg` and `face2.jpg`).

```python
# File: run_demo.py
from inference import FaceAnalysis

# 1. Initialize the pipeline (downloads models automatically on first run)
print("⏳ Initializing models...")
app = FaceAnalysis()

# 2. Define your images
img1_path = "face1.jpg"  # <--- Change this to your image path
img2_path = "face2.jpg"  # <--- Change this to your image path

# 3. Run comparison
print(f"🔍 Comparing {img1_path} vs {img2_path}...")

try:
    # Get similarity score and boolean result
    similarity, is_same = app.compare(img1_path, img2_path)

    print("-" * 30)
    print(f"🔹 Similarity Score: {similarity:.4f}")
    print("-" * 30)

    if is_same:
        print("✅ RESULT: SAME PERSON")
    else:
        print("❌ RESULT: DIFFERENT PERSON")

except Exception as e:
    print(f"Error: {e}")
    print("Tip: Make sure the image paths are correct!")
```
---

## 🎓 Training Guide

Full training (advanced): use `train.py` to train the model from scratch (starting from ImageNet weights) on a large dataset.

**Step 1: Prepare Dataset**

Organize images in ImageFolder format:

```
dataset/
├── person_1/
│   ├── img1.jpg
│   └── ...
└── person_2/
    └── img1.jpg
```
**Step 2: Run Training**

```bash
python train.py \
  --data_dir ./dataset \
  --output_dir ./checkpoints \
  --epochs 50 \
  --batch_size 64 \
  --lr_backbone 8e-6 \
  --lr_head 8e-5
```
config.json ADDED
@@ -0,0 +1,8 @@
{
  "model_type": "face_recognition",
  "backbone": "wide_resnet101_2",
  "embedding_size": 512,
  "num_classes": 100,
  "test_accuracy": 90.5,
  "test_dataset": "LFW"
}
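Because the config is plain JSON, downstream code can read it with the standard library alone. A small sketch (the inline string mirrors the file above; loading from disk with `open("config.json")` works the same way):

```python
import json

config_text = """
{
  "model_type": "face_recognition",
  "backbone": "wide_resnet101_2",
  "embedding_size": 512,
  "num_classes": 100,
  "test_accuracy": 90.5,
  "test_dataset": "LFW"
}
"""

config = json.loads(config_text)
# embedding_size must agree with the Linear(2048, 512) head in the model code
embedding_size = config["embedding_size"]
backbone = config["backbone"]
```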
finetune.py ADDED
@@ -0,0 +1,260 @@
import argparse
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import torchvision.models as models
from tqdm import tqdm


# ==========================================
# 1. MODEL ARCHITECTURE
# ==========================================
class FaceRecognitionModel(nn.Module):
    def __init__(self):
        super(FaceRecognitionModel, self).__init__()
        # Load backbone
        print("🏗️ Loading Backbone: Wide ResNet-101-2...")
        self.backbone = models.wide_resnet101_2(weights='IMAGENET1K_V2')
        self.backbone.fc = nn.Identity()

        # Embedding head
        self.embed = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True)
        )

    def forward(self, img):
        features = self.backbone(img)
        embedding = self.embed(features)
        # Normalize to the unit hypersphere
        return F.normalize(embedding, p=2, dim=1)


# ==========================================
# 2. LOSS FUNCTIONS
# ==========================================
class ArcFaceLoss(nn.Module):
    def __init__(self, num_classes, embedding_size=512, margin=0.5, scale=64):
        super(ArcFaceLoss, self).__init__()
        self.margin = margin
        self.scale = scale
        self.weight = nn.Parameter(torch.Tensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, embeddings, labels):
        W = F.normalize(self.weight, dim=1)
        x = F.normalize(embeddings, dim=1)

        cosine = torch.matmul(x, W.t())
        cosine = cosine.clamp(-1 + 1e-7, 1 - 1e-7)

        # Add the angular margin to the target-class logits only
        theta = torch.acos(cosine)
        target_logits = torch.cos(theta + self.margin)

        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)

        output = cosine * (1 - one_hot) + target_logits * one_hot
        output = output * self.scale
        return output


class CenterLoss(nn.Module):
    def __init__(self, num_classes, embedding_size=512):
        super(CenterLoss, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.centers)

    def forward(self, embeddings, labels):
        centers_norm = F.normalize(self.centers, p=2, dim=1)
        centers_batch = centers_norm[labels]
        cosine_sim = (embeddings * centers_batch).sum(dim=1)
        loss = (1.0 - cosine_sim).mean()
        return loss


# ==========================================
# 3. DATA LOADER
# ==========================================
def get_dataloader(data_dir, batch_size=64, num_workers=4, split_ratio=0.9):
    print(f"📂 Loading Data from: {data_dir}")

    # Strong augmentation for training
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.25, hue=0.08),
        transforms.RandomGrayscale(p=0.1),
        transforms.RandomRotation(degrees=10),
        transforms.RandomAffine(degrees=0, translate=(0.08, 0.08), scale=(0.92, 1.08)),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))], p=0.3),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    ])

    # Deterministic transform for validation
    transform_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Two ImageFolder instances so each split keeps its own transform;
    # mutating a shared dataset.transform would silently change the train set too.
    full_train = datasets.ImageFolder(root=data_dir, transform=transform_train)
    full_val = datasets.ImageFolder(root=data_dir, transform=transform_val)
    num_classes = len(full_train.classes)

    # Split Train/Val; the same seed yields the same index permutation,
    # so the two splits stay disjoint
    train_size = int(split_ratio * len(full_train))
    val_size = len(full_train) - train_size
    train_set, _ = random_split(full_train, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    _, val_set = random_split(full_val, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    print(f" ✅ Classes: {num_classes}")
    print(f" ✅ Train Images: {len(train_set)}")
    print(f" ✅ Val Images: {len(val_set)}")

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, num_classes


# ==========================================
# 4. TRAINING ENGINE
# ==========================================
def evaluate(model, arcface, val_loader, criterion, device):
    model.eval()
    arcface.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=" 🧪 Evaluating"):
            imgs, labels = imgs.to(device), labels.to(device)
            embeddings = model(imgs)
            logits = arcface(embeddings, labels)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return total_loss / len(val_loader), 100 * correct / total


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🚀 Device: {device}")

    # Data
    train_loader, val_loader, num_classes = get_dataloader(args.data_dir, args.batch_size, args.num_workers)

    # Models
    model = FaceRecognitionModel().to(device)
    arcface = ArcFaceLoss(num_classes=num_classes).to(device)
    center_loss = CenterLoss(num_classes=num_classes).to(device)

    # Load checkpoint (resume)
    start_epoch = 0
    if args.resume and os.path.exists(args.resume):
        print(f"🔄 Resuming from {args.resume}...")
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        arcface.load_state_dict(checkpoint['arcface_state_dict'])
        if 'center_loss_state_dict' in checkpoint:
            center_loss.load_state_dict(checkpoint['center_loss_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)

    # Optimizer: lower LR for the pretrained backbone, higher for the new heads
    optimizer = torch.optim.Adam([
        {'params': model.backbone.parameters(), 'lr': args.lr_backbone},
        {'params': model.embed.parameters(), 'lr': args.lr_head},
        {'params': arcface.parameters(), 'lr': args.lr_head},
        {'params': center_loss.parameters(), 'lr': 1e-4}
    ], weight_decay=1e-3)

    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0

    # Training loop
    print("\n🔥 START TRAINING...")
    for epoch in range(start_epoch, args.epochs):
        model.train()
        total_loss = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}")
        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            # Forward
            embeddings = model(imgs)
            logits = arcface(embeddings, labels)

            # Loss calculation
            loss_ce = criterion(logits, labels)
            loss_center = center_loss(embeddings, labels)
            loss = loss_ce + (args.lambda_center * loss_center)

            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'CE': f"{loss_ce.item():.4f}"})

        # Save checkpoint
        save_dict = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'arcface_state_dict': arcface.state_dict(),
            'center_loss_state_dict': center_loss.state_dict(),
            'num_classes': num_classes
        }

        # Save last
        torch.save(save_dict, os.path.join(args.output_dir, "last_checkpoint.bin"))

        # Evaluate & save best
        val_loss, val_acc = evaluate(model, arcface, val_loader, criterion, device)
        print(f" 🏆 Epoch {epoch + 1} | Val Loss: {val_loss:.4f} | Accuracy: {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            print(f" 💾 Saving New Best Model (Acc: {best_acc:.2f}%)")
            torch.save(save_dict, os.path.join(args.output_dir, "pytorch_model.bin"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Face Recognition Model (ArcFace + CenterLoss)")

    # Required
    parser.add_argument('--data_dir', type=str, required=True, help="Path to ImageFolder dataset")

    # Optional
    parser.add_argument('--output_dir', type=str, default=".", help="Where to save .bin files")
    parser.add_argument('--resume', type=str, default=None, help="Path to checkpoint to resume")
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=4)

    # Hyperparameters
    parser.add_argument('--lr_backbone', type=float, default=8e-6)
    parser.add_argument('--lr_head', type=float, default=8e-5)
    parser.add_argument('--lambda_center', type=float, default=0.18)

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    main(args)
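The angular-margin step inside `ArcFaceLoss.forward` can be sanity-checked with plain Python for a single target-class entry: adding the margin before re-taking the cosine strictly lowers the target logit, which is the pressure that tightens class clusters. The sample cosine value is arbitrary.

```python
import math

def arcface_target_logit(cosine, margin=0.5, scale=64):
    # mirrors ArcFaceLoss.forward for one target-class entry:
    # clamp -> theta = acos(cos) -> cos(theta + margin) -> * scale
    cosine = max(-1 + 1e-7, min(1 - 1e-7, cosine))
    theta = math.acos(cosine)
    return scale * math.cos(theta + margin)

plain_logit = 64 * 0.8                    # unmodified logit for cos(theta) = 0.8
margin_logit = arcface_target_logit(0.8)  # margin-penalized target logit
```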
inference.py ADDED
@@ -0,0 +1,210 @@
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
from PIL import Image
import onnxruntime as ort
from huggingface_hub import hf_hub_download

# ==========================================
# REPO CONFIGURATION
# ==========================================
REPO_ID = "biometric-ai-lab/Face_Recognition"
RECOG_FILENAME = "pytorch_model.bin"
YOLO_FILENAME = "yolov8s-face-lindevs.onnx"


# ==========================================
# 1. MODEL ARCHITECTURE (must match the training code)
# ==========================================
class FaceRecognitionModel(nn.Module):
    def __init__(self):
        super(FaceRecognitionModel, self).__init__()
        # weights=None: the trained checkpoint is loaded afterwards
        self.backbone = models.wide_resnet101_2(weights=None)
        self.backbone.fc = nn.Identity()
        self.embed = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

    def forward(self, img):
        features = self.backbone(img)
        embedding = self.embed(features)
        return F.normalize(embedding, p=2, dim=1)


# ==========================================
# 2. YOLO DETECTOR
# ==========================================
class YOLOFaceDetector:
    def __init__(self, model_path, conf_threshold=0.5):
        self.session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.input_name = self.session.get_inputs()[0].name
        self.output_names = [output.name for output in self.session.get_outputs()]
        self.conf_threshold = conf_threshold
        self.input_size = 640

    def detect_extract_face(self, image_pil, expand_ratio=0.0):
        """
        Input: PIL Image
        Output: PIL Image (cropped face)
        """
        # Convert PIL -> OpenCV (BGR)
        image_np = np.array(image_pil)
        image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
        img_height, img_width = image_bgr.shape[:2]

        # Preprocess (resize -> RGB -> normalize -> transpose)
        img_resized = cv2.resize(image_bgr, (self.input_size, self.input_size))
        # Note: YOLO models are typically trained on RGB input
        img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
        img_normalized = img_rgb.astype(np.float32) / 255.0
        img_transposed = np.transpose(img_normalized, (2, 0, 1))
        img_batch = np.expand_dims(img_transposed, axis=0)

        # Inference
        outputs = self.session.run(self.output_names, {self.input_name: img_batch})
        predictions = outputs[0]

        if len(predictions.shape) == 3:
            predictions = predictions[0].T

        best_face = None
        max_area = 0

        # Post-process
        for pred in predictions:
            conf = pred[4]
            if conf > self.conf_threshold:
                x_center, y_center, w, h = pred[:4]

                # Scale back to the original image size
                x_center = x_center * img_width / self.input_size
                y_center = y_center * img_height / self.input_size
                w = w * img_width / self.input_size
                h = h * img_height / self.input_size

                x1 = int(x_center - w / 2)
                y1 = int(y_center - h / 2)
                x2 = int(x_center + w / 2)
                y2 = int(y_center + h / 2)

                x1 = max(0, x1)
                y1 = max(0, y1)
                x2 = min(img_width, x2)
                y2 = min(img_height, y2)

                area = (x2 - x1) * (y2 - y1)

                # Keep the largest detected face
                if area > max_area:
                    max_area = area
                    best_face = (x1, y1, x2, y2)

        # Crop the face
        if best_face:
            x1, y1, x2, y2 = best_face

            # Apply expand_ratio padding (if requested)
            if expand_ratio != 0:
                w_box = x2 - x1
                h_box = y2 - y1
                pad = int(expand_ratio * max(w_box, h_box))
                x1 = max(0, x1 - pad)
                y1 = max(0, y1 - pad)
                x2 = min(img_width, x2 + pad)
                y2 = min(img_height, y2 + pad)

            # Crop from the original PIL image to preserve quality
            return image_pil.crop((x1, y1, x2, y2))

        print("⚠️ Warning: No face detected. Using full image.")
        return image_pil


# ==========================================
# 3. FACE ANALYSIS WRAPPER
# ==========================================
class FaceAnalysis:
    def __init__(self, device=None):
        self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"🚀 Initializing Face Analysis on {self.device}...")

        # 1. Download models
        try:
            print(f"📥 Checking models from {REPO_ID}...")
            recog_path = hf_hub_download(repo_id=REPO_ID, filename=RECOG_FILENAME)
            yolo_path = hf_hub_download(repo_id=REPO_ID, filename=YOLO_FILENAME)
        except Exception as e:
            raise RuntimeError(f"❌ Failed to download models. Check internet or Repo ID.\nError: {e}")

        # 2. Init YOLO
        self.yolo = YOLOFaceDetector(yolo_path, conf_threshold=0.5)

        # 3. Init recognition model
        self.model = FaceRecognitionModel().to(self.device)

        # Load weights (checkpoints from finetune.py store 'model_state_dict')
        checkpoint = torch.load(recog_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # Fallback if the file contains only a raw state dict
            self.model.load_state_dict(checkpoint)

        self.model.eval()

        # 4. Transform (must match the validation transform used in training)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])
        print("✅ System Ready!")

    def process_image(self, image_source, expand_ratio=0.0):
        # Load the image
        if isinstance(image_source, str):
            if not os.path.exists(image_source):
                raise FileNotFoundError(f"Image not found: {image_source}")
            img_pil = Image.open(image_source).convert('RGB')
        elif isinstance(image_source, Image.Image):
            img_pil = image_source.convert('RGB')
        elif isinstance(image_source, np.ndarray):
            img_pil = Image.fromarray(cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB))
        else:
            raise ValueError("Input must be a file path, PIL Image, or NumPy array")

        # 1. YOLO detect & crop
        face_crop = self.yolo.detect_extract_face(img_pil, expand_ratio=expand_ratio)

        # 2. Transform & embedding
        img_tensor = self.transform(face_crop).unsqueeze(0).to(self.device)

        with torch.no_grad():
            embedding = self.model(img_tensor)

        return embedding

    def compare(self, img1, img2, threshold=0.45, expand_ratio=0.01):
        """
        Compare two face images.
        expand_ratio=0.01 adds a small padding around the detected face box.
        """
        emb1 = self.process_image(img1, expand_ratio)
        emb2 = self.process_image(img2, expand_ratio)

        # Cosine similarity
        similarity = F.cosine_similarity(emb1, emb2).item()
        is_same = similarity > threshold

        return similarity, is_same
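The coordinate rescaling inside `detect_extract_face` is easy to verify in isolation. A sketch of the same math for one (cx, cy, w, h) prediction; the sample box and image dimensions are arbitrary:

```python
def yolo_box_to_xyxy(pred, img_w, img_h, input_size=640):
    # rescale a (cx, cy, w, h) box from the 640x640 YOLO input back to the
    # original image, then clamp to the image bounds, as detect_extract_face does
    cx = pred[0] * img_w / input_size
    cy = pred[1] * img_h / input_size
    w = pred[2] * img_w / input_size
    h = pred[3] * img_h / input_size
    x1 = max(0, int(cx - w / 2))
    y1 = max(0, int(cy - h / 2))
    x2 = min(img_w, int(cx + w / 2))
    y2 = min(img_h, int(cy + h / 2))
    return x1, y1, x2, y2

box = yolo_box_to_xyxy([320, 320, 160, 160], 1280, 720)  # -> (480, 270, 800, 450)
```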
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64d4b0af1cbe947a0bcf2996d362cf973bd30f277f55e67a748a409fd733385a
size 529070510
yolov8s-face-lindevs.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0a6d19f2f68d7f0cc8104ab5c9eaa54b63e298f91dcfefd4be897f94a1561d02
size 44731626
yolov8s-face-lindevs.onnx:Zone.Identifier ADDED
Binary file (25 Bytes).