|
|
import cv2 |
|
|
import numpy as np |
|
|
from insightface.app import FaceAnalysis |
|
|
import argparse |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
|
|
|
class FaceRecognitionPipeline:
    """Detect faces, extract embeddings and match them by cosine similarity.

    Detection and embedding extraction are delegated to the wrapped
    ``insightface.app.FaceAnalysis`` instance (``self.app``); comparison and
    gallery matching are done here with NumPy.
    """

    def __init__(self, model_path='./models/axmodel', ctx_id=0, det_size=(640, 640)):
        """Load the face models and prepare the detector.

        :param model_path: model directory, passed to ``FaceAnalysis`` as ``root``.
        :param ctx_id: device context id forwarded to ``FaceAnalysis.prepare``
            (upstream insightface uses a negative id for CPU — confirm for this
            model build).
        :param det_size: detector input size forwarded to ``FaceAnalysis.prepare``.
        """
        self.app = FaceAnalysis(root=model_path)
        self.app.prepare(ctx_id=ctx_id, det_size=det_size)

    @staticmethod
    def _l2_normalize(vectors):
        """L2-normalize along the last axis; all-zero vectors stay zero instead of becoming NaN."""
        vectors = np.asarray(vectors, dtype=np.float64)
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return np.divide(vectors, norms, out=np.zeros_like(vectors), where=norms > 0)

    def extract_features(self, image):
        """Detect every face in ``image`` and return them with their embeddings.

        :param image: BGR image (H, W, 3).
        :return: the list produced by ``FaceAnalysis.get``. Each item is dict-like
            and is read elsewhere in this class via ``'bbox'`` and ``'embedding'``.
        """
        return self.app.get(image)

    def compare(self, emb1, emb2, threshold=0.25):
        """Cosine similarity between two embedding vectors.

        :param emb1, emb2: embedding vectors of equal length (512-d for the
            InsightFace recognition models); any shape is flattened first.
        :param threshold: similarity above which the two faces are considered the
            same person (antelopev2 recommends 0.35, buffalo_l recommends 0.25).
        :return: ``(similarity, is_same)`` as a Python ``float`` and ``bool``.
            A zero-norm embedding yields similarity 0.0 rather than NaN.
        """
        a = self._l2_normalize(np.ravel(emb1))
        b = self._l2_normalize(np.ravel(emb2))
        similarity = float(np.dot(a, b))
        return similarity, similarity > threshold

    def recognize(self, query_image, gallery_embeddings, gallery_names, threshold=0.25):
        """Identify every face in ``query_image`` against a gallery.

        :param query_image: BGR image.
        :param gallery_embeddings: sequence of n embedding vectors (n x 512).
        :param gallery_names: sequence of n names, aligned with ``gallery_embeddings``.
        :param threshold: minimum cosine similarity for a named match; the best
            match below it is reported as ``"Unknown"``.
        :return: ``[{'name', 'similarity', 'face'}, ...]``, one entry per detected
            face, where ``face`` is the detection object from ``extract_features``.
            With an empty gallery every face is ``"Unknown"`` with similarity -1.0.
        :raises ValueError: if the numbers of embeddings and names differ.
        """
        gallery_embeddings = list(gallery_embeddings)
        gallery_names = list(gallery_names)
        if len(gallery_embeddings) != len(gallery_names):
            raise ValueError(
                f"gallery_embeddings ({len(gallery_embeddings)}) and gallery_names "
                f"({len(gallery_names)}) must have the same length"
            )

        # Normalize the gallery once per call instead of once per (face, entry) pair.
        gallery = None
        if gallery_embeddings:
            gallery = self._l2_normalize(np.stack([np.ravel(e) for e in gallery_embeddings]))

        results = []
        for face in self.extract_features(query_image):
            best_name, best_sim = "Unknown", -1.0
            if gallery is not None:
                sims = gallery @ self._l2_normalize(np.ravel(face['embedding']))
                best_idx = int(np.argmax(sims))  # first maximum, as a strict '>' scan would pick
                best_sim = float(sims[best_idx])
                if best_sim > threshold:
                    best_name = gallery_names[best_idx]
            results.append({'name': best_name, 'similarity': best_sim, 'face': face})
        return results

    def draw_results(self, image, results):
        """Return a copy of ``image`` with boxes, landmarks and name/similarity labels drawn.

        :param image: BGR image; it is not modified.
        :param results: output of :meth:`recognize`.
        """
        # Draw all detections in a single call rather than once per face.
        img_draw = self.app.draw_on(image.copy(), [res['face'] for res in results])
        img_height = img_draw.shape[0]

        for res in results:
            x1, y1, x2, y2 = res['face']['bbox'].astype(int)
            color = (0, 255, 0) if res['name'] != "Unknown" else (0, 0, 255)
            # Prefer the label under the box; move it above when it would fall off the image.
            text_y = y2 + 15
            if text_y >= img_height:
                text_y = max(y1 - 5, 15)
            cv2.putText(img_draw, f"{res['name']}: {res['similarity']:.2f}",
                        (int(x1), int(text_y)), cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 1)
        return img_draw
|
|
|
|
|
|
|
|
|
|
|
def _build_arg_parser():
    """Build the command-line interface of this example script."""
    parser = argparse.ArgumentParser(description="Face Recognition Pipeline Example")
    parser.add_argument("--model_path", "-m", type=str, default="./models/buffalo_l",
                        help="Path to the model directory")
    # The code dispatches on 0 (1v1) and 1 (1vN); the help used to say 1 and 2.
    parser.add_argument("--type", "-t", type=int, default=0, choices=(0, 1),
                        help="Type of operation: 0: 1v1 compare, 1: 1vN recognize")
    parser.add_argument("--gallery_path", "-g", type=str, default=None,
                        help="Gallery image (type 0) or directory of gallery images (type 1)")
    parser.add_argument("--query_path", "-q", type=str, default=None,
                        help="Path to the query image (type 0)")
    parser.add_argument("--threshold", type=float, default=0.25,
                        help="Cosine-similarity threshold for a match (default: 0.25)")
    parser.add_argument("--draw", "-d", action='store_true',
                        help="Whether to draw results on the image")
    return parser


def _load_image(path):
    """Read an image with OpenCV; print a warning and return None if it cannot be decoded."""
    # cv2.imread does not raise on failure; it returns None.
    image = cv2.imread(path)
    if image is None:
        print(f"警告: 无法读取图像 {path}")
    return image


def _save_output(image, source_path):
    """Write ``image`` to ./output/ under the basename of ``source_path``."""
    os.makedirs("./output", exist_ok=True)
    file_name = os.path.basename(source_path)
    cv2.imwrite(f"./output/{file_name}", image)
    print(f"结果已保存到 ./output/{file_name}")


def _run_compare(pipeline, args):
    """1v1 mode: compare the first face of the gallery image with the first face of the query image.

    :return: process exit code (0 on success, 1 if an image is unreadable or has no face).
    """
    gallery_img = _load_image(args.gallery_path)
    if gallery_img is None:
        return 1
    faces1 = pipeline.extract_features(gallery_img)
    if not faces1:
        print(f"警告: {args.gallery_path} 未检测到人脸")
        return 1

    query_img = _load_image(args.query_path)
    if query_img is None:
        return 1
    faces2 = pipeline.extract_features(query_img)
    if not faces2:
        print(f"警告: {args.query_path} 未检测到人脸")
        return 1

    sim, is_same = pipeline.compare(faces1[0]['embedding'], faces2[0]['embedding'],
                                    threshold=args.threshold)
    print(f"相似度: {sim:.4f}, 是否同一人: {is_same}")

    if args.draw:
        _save_output(pipeline.app.draw_on(query_img, [faces2[0]]), args.query_path)
    return 0


def _build_gallery(pipeline, gallery_dir):
    """Embed the first detected face of every readable image file in ``gallery_dir``.

    The person name is the file name without its extension.

    :return: ``(names, embeddings)`` as parallel lists.
    """
    gallery_names = []
    gallery_embeddings = []
    # Sorted for a deterministic gallery order across platforms.
    for fname in tqdm(sorted(os.listdir(gallery_dir))):
        fpath = os.path.join(gallery_dir, fname)
        if not os.path.isfile(fpath):
            continue
        gallery_img = _load_image(fpath)
        if gallery_img is None:
            continue
        faces = pipeline.extract_features(gallery_img)
        if faces:
            gallery_names.append(os.path.splitext(fname)[0])
            gallery_embeddings.append(faces[0]['embedding'])
        else:
            print(f"警告: {fname} 未检测到人脸")
    return gallery_names, gallery_embeddings


def _run_recognize(pipeline, args):
    """1vN mode: build a gallery, then interactively identify faces in query images.

    :return: process exit code (0 on normal exit, 1 if the gallery is empty).
    """
    gallery_names, gallery_embeddings = _build_gallery(pipeline, args.gallery_path)
    print("特征库构建完成,包含以下人员:", gallery_names)
    if not gallery_names:
        print(f"警告: 特征库为空, 未能从 {args.gallery_path} 中提取到任何人脸")
        return 1

    while True:
        print("请输入查询图像路径 (输入 'exit' 退出): ")
        try:
            user_input = input().strip()
        except EOFError:  # stdin closed (e.g. piped input exhausted)
            break
        if user_input.lower() == 'exit':
            break
        if not os.path.isfile(user_input):
            print("输入的路径不是有效的文件,请重新输入。")
            continue

        query_img = _load_image(user_input)
        if query_img is None:
            continue
        results = pipeline.recognize(query_img, gallery_embeddings, gallery_names,
                                     threshold=args.threshold)
        if not results:
            print(f"{user_input} 未检测到人脸")
            continue

        for res in results:
            print(f"识别结果: {res['name']}, 相似度(0-1): {res['similarity']:.4f}")
        if args.draw:
            _save_output(pipeline.draw_results(query_img, results), user_input)
    return 0


def main():
    """Parse CLI arguments, validate them before loading models, and run the selected mode."""
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate with parser.error (exit code 2) instead of assert, which -O strips.
    if args.type == 0:
        if args.gallery_path is None or args.query_path is None:
            parser.error("请提供 gallery_path 和 query_path")
    else:
        if args.gallery_path is None:
            parser.error("请提供 gallery_path")
        if not os.path.isdir(args.gallery_path):
            parser.error(f"gallery_path 必须是目录: {args.gallery_path}")

    pipeline = FaceRecognitionPipeline(model_path=args.model_path)
    if args.type == 0:
        return _run_compare(pipeline, args)
    return _run_recognize(pipeline, args)


if __name__ == "__main__":
    raise SystemExit(main())