File size: 7,059 Bytes
d56c551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import cv2
import numpy as np
from insightface.app import FaceAnalysis
import argparse
import os
from tqdm import tqdm

class FaceRecognitionPipeline:
    """Face detection, embedding extraction and gallery matching on top of ``FaceAnalysis``."""

    def __init__(self, model_path='./models/axmodel', ctx_id=0, det_size=(640, 640)):
        """
        Initialize the face recognition pipeline.

        :param model_path: model root directory, passed to ``FaceAnalysis(root=...)``
        :param ctx_id: context/device id forwarded to ``FaceAnalysis.prepare``
        :param det_size: detector input size forwarded to ``FaceAnalysis.prepare``
        """
        # NOTE: per the original author's note, FaceAnalysis already bundles the
        # detection, landmark and recognition models, so no separate ArcFace
        # model is loaded here.
        self.app = FaceAnalysis(root=model_path)
        self.app.prepare(ctx_id=ctx_id, det_size=det_size)

    def extract_features(self, image):
        """
        Detect every face in an image and return the analysis results.

        :param image: BGR image (H, W, 3), e.g. the result of ``cv2.imread``;
                      ``None`` (what ``cv2.imread`` returns for an unreadable
                      file) is accepted and treated as "no faces"
        :return: list of face objects as returned by ``FaceAnalysis.get``; the
                 rest of this class reads ``face['embedding']`` and
                 ``face['bbox']`` from each of them
        """
        # cv2.imread signals failure by returning None instead of raising;
        # report "no faces" rather than crashing inside the detector.
        if image is None:
            return []
        return self.app.get(image)

    def compare(self, emb1, emb2, threshold=0.25):
        """
        Compute the cosine similarity between two embedding vectors.

        :param emb1, emb2: 1-D embedding vectors of equal length
                           (512-d according to the original documentation)
        :param threshold: similarity threshold (the original author recommends
                          0.35 for antelopev2 and 0.25 for buffalo_l)
        :return: (similarity, is_same) where ``is_same`` is ``similarity > threshold``;
                 a zero-norm embedding yields ``(0.0, False)``
        """
        norm_product = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        # Cosine similarity is undefined for a zero vector; dividing would give
        # NaN plus a RuntimeWarning, so treat it as "no similarity" instead.
        if norm_product == 0:
            return 0.0, False
        similarity = np.dot(emb1, emb2) / norm_product
        return similarity, similarity > threshold

    def recognize(self, query_image, gallery_embeddings, gallery_names, threshold=0.25):
        """
        Identify every face of the query image against a gallery.

        :param query_image: BGR image
        :param gallery_embeddings: sequence of n gallery embeddings
        :param gallery_names: sequence of n names, aligned with ``gallery_embeddings``
        :param threshold: minimum similarity for a gallery entry to count as a match
        :return: one dict per detected face with keys ``'name'`` (best matching
                 gallery name, or ``"Unknown"`` when the best similarity does not
                 exceed ``threshold``), ``'similarity'`` (best similarity, ``-1``
                 for an empty gallery) and ``'face'`` (the detected face object)
        """
        results = []
        for face in self.extract_features(query_image):
            query_emb = face['embedding']
            best_sim = -1
            best_name = "Unknown"
            for emb, name in zip(gallery_embeddings, gallery_names):
                # Use the caller's threshold so compare() and this method
                # always agree on what counts as a match.
                sim, is_same = self.compare(query_emb, emb, threshold=threshold)
                if sim > best_sim:
                    best_sim = sim
                    best_name = name if is_same else "Unknown"
            results.append({
                'name': best_name,
                'similarity': best_sim,
                'face': face
            })
        return results

    def draw_results(self, image, results):
        """
        Draw recognition results on a copy of the image.

        :param image: BGR image the results were computed from
        :param results: list of dicts as returned by ``recognize``
        :return: annotated copy of ``image``; the input image is not modified here
        """
        img_draw = image.copy()
        for res in results:
            face = res['face']
            img_draw = self.app.draw_on(img_draw, [face])

            x1, _, _, y2 = face['bbox'].astype(int)
            # Green label for a recognized identity, red for "Unknown".
            color = (0, 255, 0) if res['name'] != "Unknown" else (0, 0, 255)
            # Label goes just below the bottom-left corner of the face box.
            cv2.putText(img_draw, f"{res['name']}: {res['similarity']:.2f}",
                        (int(x1), int(y2) + 15), cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 1)
        return img_draw

# ==================== Usage example ====================

def build_arg_parser():
    """Build the command-line parser for the demo script."""
    parser = argparse.ArgumentParser(description="Face Recognition Pipeline Example")
    parser.add_argument("--model_path", "-m", type=str, default="./models/buffalo_l",
                        help="Path to the model directory")
    # The script dispatches on 0 and 1; any other value is rejected up front.
    parser.add_argument("--type", "-t", type=int, default=0, choices=[0, 1],
                        help="Type of operation: 0: 1v1 compare, 1: 1vN recognize")
    parser.add_argument("--gallery_path", "-g", type=str, default=None,
                        help="Path to the gallery image (type 0) or gallery image directory (type 1)")
    parser.add_argument("--query_path", "-q", type=str, default=None,
                        help="Path to the query image (type 0 only)")
    parser.add_argument("--draw", "-d", action='store_true',
                        help="Whether to draw results on the image")
    parser.add_argument("--threshold", type=float, default=0.25,
                        help="Cosine similarity threshold for a match")
    return parser


def _read_image(path):
    """Read an image with OpenCV; print a warning and return None if it cannot be decoded."""
    img = cv2.imread(path)
    if img is None:
        print(f"警告: 无法读取图像 {path}")
    return img


def _run_compare(pipeline, args):
    """1:1 comparison between the first face of the gallery image and of the query image."""
    gallery_img = _read_image(args.gallery_path)
    if gallery_img is None:
        return
    gallery_faces = pipeline.extract_features(gallery_img)
    if not gallery_faces:
        print(f"警告: {args.gallery_path} 未检测到人脸")
        return

    query_img = _read_image(args.query_path)
    if query_img is None:
        return
    query_faces = pipeline.extract_features(query_img)
    if not query_faces:
        print(f"警告: {args.query_path} 未检测到人脸")
        return

    # Only the first detected face of each image takes part in the comparison.
    sim, is_same = pipeline.compare(gallery_faces[0]['embedding'],
                                    query_faces[0]['embedding'],
                                    threshold=args.threshold)
    print(f"相似度: {sim:.4f}, 是否同一人: {is_same}")

    if args.draw:
        os.makedirs("./output", exist_ok=True)
        out_path = f"./output/{os.path.basename(args.query_path)}"
        output_img = pipeline.app.draw_on(query_img, [query_faces[0]])
        cv2.imwrite(out_path, output_img)
        print(f"结果已保存到 {out_path}")


def _run_recognize(pipeline, args):
    """1:N recognition: build a gallery from a directory, then answer interactive queries."""
    # Build the gallery: one embedding per readable image, named after the file stem.
    gallery_names = []
    gallery_embeddings = []
    # Sorted so the gallery order does not depend on the filesystem's listing order.
    for fname in tqdm(sorted(os.listdir(args.gallery_path))):
        fpath = os.path.join(args.gallery_path, fname)
        if not os.path.isfile(fpath):
            continue  # skip sub-directories
        gallery_img = cv2.imread(fpath)
        if gallery_img is None:
            # Not a decodable image (e.g. a stray text file in the folder).
            print(f"警告: 无法读取图像 {fname}")
            continue
        faces = pipeline.extract_features(gallery_img)
        if faces:
            gallery_names.append(os.path.splitext(fname)[0])
            gallery_embeddings.append(faces[0]['embedding'])
        else:
            print(f"警告: {fname} 未检测到人脸")

    print("特征库构建完成,包含以下人员:", gallery_names)
    if args.draw:
        os.makedirs("./output", exist_ok=True)

    # Interactive query loop.
    while True:
        print("请输入查询图像路径 (输入 'exit' 退出): ")
        try:
            user_input = input().strip()
        except EOFError:
            # stdin closed (Ctrl-D or end of piped input): leave the loop cleanly.
            break
        if user_input.lower() == 'exit':
            break
        if not os.path.isfile(user_input):
            print("输入的路径不是有效的文件,请重新输入。")
            continue
        query_img = _read_image(user_input)
        if query_img is None:
            continue
        results = pipeline.recognize(query_img, gallery_embeddings, gallery_names,
                                     threshold=args.threshold)
        if not results:
            print(f"{user_input} 未检测到人脸")
            continue
        for res in results:
            print(f"识别结果: {res['name']}, 相似度(0-1): {res['similarity']:.4f}")
        if args.draw:
            out_path = f"./output/{os.path.basename(user_input)}"
            cv2.imwrite(out_path, pipeline.draw_results(query_img, results))
            print(f"结果已保存到 {out_path}")


def main(argv=None):
    """Parse the command line, validate it, then run the selected demo mode."""
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Validate before the (expensive) model load; parser.error exits with a
    # usage message and, unlike assert, is not stripped under ``python -O``.
    if args.type == 0:
        if args.gallery_path is None or args.query_path is None:
            parser.error("请提供 gallery_path 和 query_path")
    else:
        if args.gallery_path is None:
            parser.error("请提供 gallery_path")
        if not os.path.isdir(args.gallery_path):
            parser.error(f"gallery_path 不是有效的目录: {args.gallery_path}")

    pipeline = FaceRecognitionPipeline(model_path=args.model_path)

    if args.type == 0:
        _run_compare(pipeline, args)
    else:
        _run_recognize(pipeline, args)


if __name__ == "__main__":
    main()