# Insightface / insightface_pipeline.py
# wli1995's picture
# Upload folder using huggingface_hub
# d56c551 verified
import cv2
import numpy as np
from insightface.app import FaceAnalysis
import argparse
import os
from tqdm import tqdm
class FaceRecognitionPipeline:
    """Detect faces, extract embeddings and match them against a gallery.

    The heavy lifting is delegated to an ``insightface`` ``FaceAnalysis``
    instance stored in ``self.app``; this class only adds cosine-similarity
    comparison, a 1:N gallery search and a drawing helper.
    """

    def __init__(self, model_path='./models/axmodel'):
        """Create and prepare the underlying ``FaceAnalysis`` instance.

        :param model_path: forwarded to ``FaceAnalysis`` as its ``root``
            argument (where the model files are looked up).
        """
        # FaceAnalysis bundles detection + landmarks (and, per the original
        # author's note, the recognition model as well), so no separate
        # ArcFace model is loaded here.
        self.app = FaceAnalysis(root=model_path)
        # det_size is the detector input size and can be tuned.
        self.app.prepare(ctx_id=0, det_size=(640, 640))

    def extract_features(self, image):
        """Run the analyser on an image and return every detected face.

        :param image: BGR image array of shape (H, W, 3), e.g. the result of
            ``cv2.imread``.
        :return: whatever ``self.app.get`` returns; the rest of this class
            reads ``face['embedding']`` and ``face['bbox']`` from each item.
        :raises ValueError: if ``image`` is ``None`` (``cv2.imread`` returns
            ``None`` for unreadable files), instead of failing deep inside
            the analyser with an unrelated error.
        """
        if image is None:
            raise ValueError(
                "image is None; check that the file exists and is a readable image"
            )
        return self.app.get(image)

    def compare(self, emb1, emb2, threshold=0.25):
        """Compute the cosine similarity of two embedding vectors.

        :param emb1: first embedding (1-D array; 512-d per the original notes).
        :param emb2: second embedding, same length as ``emb1``.
        :param threshold: similarity above which the two faces are treated as
            the same person (original author's recommendation: 0.35 for
            antelopev2, 0.25 for buffalo_l).
        :return: ``(similarity, is_same)``. A degenerate embedding (zero or
            non-finite norm) yields ``(0.0, False)`` rather than NaN.
        """
        norm_product = np.linalg.norm(emb1) * np.linalg.norm(emb2)
        # Guard the division: an all-zero (or NaN) embedding has no direction,
        # so it cannot be "similar" to anything.
        if not np.isfinite(norm_product) or norm_product == 0:
            return 0.0, False
        similarity = np.dot(emb1, emb2) / norm_product
        is_same = similarity > threshold
        return similarity, is_same

    def recognize(self, query_image, gallery_embeddings, gallery_names, threshold=0.25):
        """Identify every face of the query image against a gallery.

        For each detected face the gallery entry with the highest cosine
        similarity is selected; its name is reported only when that
        similarity exceeds ``threshold``, otherwise the face is "Unknown".

        :param query_image: BGR image array.
        :param gallery_embeddings: sequence of n embeddings.
        :param gallery_names: sequence of n names, aligned with
            ``gallery_embeddings`` (pairs are formed with ``zip``, so extra
            items in the longer sequence are ignored).
        :param threshold: recognition threshold.
        :return: one dict per detected face with the keys ``'name'``,
            ``'similarity'`` (``-1`` when the gallery is empty) and
            ``'face'`` (the face object returned by ``extract_features``).
        """
        query_faces = self.extract_features(query_image)
        results = []
        for face in query_faces:
            best_sim = -1
            best_name = "Unknown"
            for emb, name in zip(gallery_embeddings, gallery_names):
                # Use the caller's threshold and compare()'s own verdict so
                # both methods always agree on what counts as a match.
                sim, is_same = self.compare(face['embedding'], emb, threshold=threshold)
                if sim > best_sim:
                    best_sim = sim
                    best_name = name if is_same else "Unknown"
            results.append({
                'name': best_name,
                'similarity': best_sim,
                'face': face,
            })
        return results

    def draw_results(self, image, results):
        """Return a copy of ``image`` annotated with the recognition results.

        :param image: BGR image array; it is not modified.
        :param results: list produced by ``recognize``.
        :return: annotated copy of the image.
        """
        img_draw = image.copy()
        for res in results:
            img_draw = self.app.draw_on(img_draw, [res['face']])
            # Only the left edge and the bottom edge of the box are needed to
            # anchor the label just below the face.
            x1, _, _, y2 = res['face']['bbox'].astype(int)
            # Green label for a known person, red for "Unknown" (BGR tuples).
            color = (0, 255, 0) if res['name'] != "Unknown" else (0, 0, 255)
            cv2.putText(img_draw, f"{res['name']}: {res['similarity']:.2f}",
                        (x1, y2 + 15), cv2.FONT_HERSHEY_COMPLEX, 0.7, color, 1)
        return img_draw
# ==================== Usage example ====================
def _build_arg_parser():
    """Build the command-line parser of the demo script."""
    parser = argparse.ArgumentParser(description="Face Recognition Pipeline Example")
    parser.add_argument("--model_path", "-m", type=str, default="./models/buffalo_l",
                        help="Path to the model directory")
    # The code dispatches on 0 and 1; restrict the choices so that any other
    # value is rejected instead of silently doing nothing.
    parser.add_argument("--type", "-t", type=int, default=0, choices=(0, 1),
                        help="Type of operation: 0: 1v1 compare, 1: 1vN recognize")
    parser.add_argument("--gallery_path", "-g", type=str, default=None,
                        help="Path to the gallery image (type 0) or to the "
                             "directory of gallery images (type 1)")
    parser.add_argument("--query_path", "-q", type=str, default=None,
                        help="Path to the query image")
    parser.add_argument("--threshold", type=float, default=0.25,
                        help="Cosine-similarity threshold for a match")
    parser.add_argument("--draw", "-d", action='store_true',
                        help="Whether to draw results on the image")
    return parser


def _read_image(path):
    """Read an image with OpenCV; report and return None when it is unreadable."""
    image = cv2.imread(path)
    if image is None:
        print(f"错误: 无法读取图像 {path}")
    return image


def _save_image(image, source_path):
    """Write ``image`` to ./output/<basename of source_path> and report the outcome."""
    os.makedirs("./output", exist_ok=True)
    out_path = f"./output/{os.path.basename(source_path)}"
    # cv2.imwrite signals failure through its return value.
    if cv2.imwrite(out_path, image):
        print(f"结果已保存到 {out_path}")
    else:
        print(f"错误: 无法保存结果到 {out_path}")


def _run_compare(pipeline, args):
    """1:1 mode: compare the first face of the gallery image with the first face of the query image."""
    gallery_img = _read_image(args.gallery_path)
    if gallery_img is None:
        return 1
    gallery_faces = pipeline.extract_features(gallery_img)
    if not gallery_faces:
        print(f"警告: {args.gallery_path} 未检测到人脸")
        return 0

    query_img = _read_image(args.query_path)
    if query_img is None:
        return 1
    query_faces = pipeline.extract_features(query_img)
    if not query_faces:
        print(f"警告: {args.query_path} 未检测到人脸")
        return 0

    sim, is_same = pipeline.compare(gallery_faces[0]['embedding'],
                                    query_faces[0]['embedding'],
                                    threshold=args.threshold)
    print(f"相似度: {sim:.4f}, 是否同一人: {is_same}")
    if args.draw:
        output_img = pipeline.app.draw_on(query_img, [query_faces[0]])
        _save_image(output_img, args.query_path)
    return 0


def _run_recognize(pipeline, args):
    """1:N mode: build a gallery from a directory, then identify images entered interactively."""
    if not os.path.isdir(args.gallery_path):
        print(f"错误: {args.gallery_path} 不是有效的目录")
        return 1

    # Build the gallery: one embedding per image, named after the file stem.
    gallery_names = []
    gallery_embeddings = []
    # Sorted so that the gallery order does not depend on the file system.
    for fname in tqdm(sorted(os.listdir(args.gallery_path))):
        fpath = os.path.join(args.gallery_path, fname)
        if not os.path.isfile(fpath):
            continue  # skip sub-directories
        gallery_img = cv2.imread(fpath)
        if gallery_img is None:
            # Not an image OpenCV can decode; skipping avoids a crash.
            print(f"警告: {fname} 无法读取为图像,已跳过")
            continue
        faces = pipeline.extract_features(gallery_img)
        if faces:
            gallery_names.append(os.path.splitext(os.path.basename(fname))[0])
            gallery_embeddings.append(faces[0]['embedding'])
        else:
            print(f"警告: {fname} 未检测到人脸")
    print("特征库构建完成,包含以下人员:", gallery_names)

    # Interactive query loop; ends on 'exit' or when stdin is closed.
    while True:
        print("请输入查询图像路径 (输入 'exit' 退出): ")
        try:
            user_input = input().strip()
        except EOFError:
            break
        if user_input.lower() == 'exit':
            break
        if not os.path.isfile(user_input):
            print("输入的路径不是有效的文件,请重新输入。")
            continue
        query_img = _read_image(user_input)
        if query_img is None:
            continue
        results = pipeline.recognize(query_img, gallery_embeddings, gallery_names,
                                     threshold=args.threshold)
        if not results:
            print(f"{user_input} 未检测到人脸")
            continue
        for res in results:
            print(f"识别结果: {res['name']}, 相似度(0-1): {res['similarity']:.4f}")
        if args.draw:
            output_img = pipeline.draw_results(query_img, results)
            _save_image(output_img, user_input)
    return 0


def main(argv=None):
    """Entry point of the demo script.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: process exit status (0 on success or "no face" warnings,
        1 when an input image or the gallery directory cannot be read).
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    # Validate before loading the models so that usage errors fail fast.
    # parser.error is used instead of assert, which is stripped under -O.
    if args.type == 0 and (args.gallery_path is None or args.query_path is None):
        parser.error("请提供 gallery_path 和 query_path")
    if args.type == 1 and args.gallery_path is None:
        parser.error("请提供 gallery_path")

    pipeline = FaceRecognitionPipeline(model_path=args.model_path)
    if args.type == 0:
        return _run_compare(pipeline, args)
    return _run_recognize(pipeline, args)


if __name__ == "__main__":
    raise SystemExit(main())