# NOTE(review): the lines below were web-viewer extraction residue (build
# status, commit hashes, and a line-number gutter) that made the file
# unparseable as Python; kept here only as a comment marker.
import PIL
from PIL import ImageOps
import xml.etree.ElementTree as ET
import shutil
import os
import numpy as np
from glob import glob
import tensorflow as tf
from tensorflow.keras import backend as K
import cv2
import matplotlib.pyplot as plt
import importlib
from face_recognition import config
from face_recognition.aligner import Aligner
class face_recognition:
    """Face-recognition head: matches aligned face crops against a database of
    per-person feature vectors and writes the matched names (plus a
    ``<distance>`` tag) back into Pascal-VOC style XML trees.

    NOTE(review): despite the "distance" naming, the metric returned by
    :meth:`new_distance` is a *similarity* in [-1, 1] (higher = better match),
    so locals below are named ``score``. ``model_config.large_distance`` is a
    sentinel meaning "guaranteed non-match" and must therefore be a very low
    value -- TODO confirm against the model config.
    """

    def __init__(self, model_path, thres=None, min_aligner_confidence=None):
        """Load the Keras feature extractor and its per-model config module.

        Parameters
        ----------
        model_path : str
            Model directory, e.g. ``"face_recognition/Models/v1"``; must
            contain ``model.h5`` and an importable ``config`` module.
        thres : float, optional
            Match threshold; defaults to the model config's ``d_thres``.
        min_aligner_confidence : float, optional
            Aligner confidence floor; defaults to
            ``config.min_aligner_confidence``.
        """
        # Turn "a/b/c" into the module path "a.b.c.config" for importlib.
        config_file_path = '.'.join(model_path.split("/")) + ".config"
        self.model_config = importlib.import_module(config_file_path)
        self.thres = thres if thres is not None else self.model_config.d_thres
        self.aligner = Aligner(
            min_aligner_confidence
            if min_aligner_confidence is not None
            else config.min_aligner_confidence
        )
        self.feature_extractor = tf.keras.models.load_model(
            model_path + "/model.h5", compile=False
        )

    def new_distance(self, vectors):
        '''Similarity score in [-1, 1]: close to 1 when matching and close to
        -1 when not. Implemented as the dot product along the last axis
        (presumably the feature vectors are L2-normalised -- TODO confirm).
        '''
        return (vectors[0] * vectors[1]).sum(-1)

    def calculate_distance(self, crop_img, db_faces_features, mode='avg'):
        """Return one similarity score per database face for ``crop_img``.

        mode='avg'  : compare against the mean feature vector of each face.
        mode='best' : compare against every stored vector, keep the best.

        Scores not exceeding ``self.thres`` are replaced with the
        ``large_distance`` sentinel (guaranteed non-match).
        """
        valid_modes = ['avg', 'best']
        if mode not in valid_modes:
            raise ValueError(f"Unknown mode:{mode} \nMode should be one of these:{valid_modes}")
        crop_img_features = self.feature_extractor.predict(crop_img[None, :, :, :], verbose=0)
        all_distances = []  # score of this crop w.r.t. every face in the database
        for face_idx in range(len(self.faces)):
            if mode == 'avg':
                # Single averaged template per face.
                db_face_features = db_faces_features[face_idx].mean(axis=0, keepdims=True)
                new_crop_img_features = crop_img_features.copy()
            else:
                # Compare against every stored sample; tile the crop to match.
                db_face_features = db_faces_features[face_idx]
                new_crop_img_features = np.tile(crop_img_features, [db_face_features.shape[0], 1])
            # Explicit check (the original wrapped an assert in a bare except
            # only to re-raise); exception type and message are unchanged.
            if db_face_features.shape != new_crop_img_features.shape:
                raise AssertionError(
                    f"db_face_features shape{db_face_features.shape} does not match "
                    f"crop_img_features shape{new_crop_img_features.shape}"
                )
            score = np.max(self.new_distance([db_face_features, new_crop_img_features]), axis=0)
            if score > self.thres:
                all_distances.append(score)
            else:
                all_distances.append(self.model_config.large_distance)  # not this person, guaranteed
        return all_distances

    def repeat_allowed_face_recognition(self, distance_dict):
        """Label each object with its best-matching face; the same face may be
        assigned to several objects.

        Returns ``{face_index: (xml_object, score)}`` for labelled objects
        (last writer wins when several objects share a face).
        """
        faceidx_to_obj_dict = dict()
        for obj, raw_scores in distance_dict.items():
            scores = np.array(raw_scores)
            best_score, best_idx = scores.max(), scores.argmax()
            if best_score > self.thres:
                obj.find('name').text = self.faces[best_idx]
                distance_tag = ET.Element("distance")
                distance_tag.text = "{:.2f}".format(best_score)
                obj.append(distance_tag)
                # BUG FIX: this mapping was built but never populated, so the
                # method always returned an empty dict.
                faceidx_to_obj_dict[best_idx] = (obj, best_score)
        return faceidx_to_obj_dict

    def no_repeat_allowed_face_recognition(self, distance_dict):
        """Greedy one-to-one assignment of database faces to objects.

        Conflicts are resolved recursively: when two objects claim the same
        face, the higher-scoring one keeps it and the loser's score for that
        face is knocked out with the ``large_distance`` sentinel before it
        retries with its next-best face. Mutates ``distance_dict`` values in
        place during conflict resolution.

        Returns ``{face_index: (xml_object, score)}``.
        """
        def assign_face_label(obj):
            best_score, best_idx = distance_dict[obj].max(), distance_dict[obj].argmax()
            if best_score < self.thres:
                return  # nothing above threshold remains -> leave unlabelled
            elif best_idx not in faceidx_to_obj_dict:
                faceidx_to_obj_dict[best_idx] = (obj, best_score)
            else:
                if best_score < faceidx_to_obj_dict[best_idx][1]:
                    # Current object is the weaker claimant: knock this face
                    # out for it and let it retry with its next-best match.
                    distance_dict[obj][best_idx] = self.model_config.large_distance
                    assign_face_label(obj)
                else:
                    # Current object wins: displace the previous holder and
                    # make the displaced object retry.
                    prev_obj, _prev_score = faceidx_to_obj_dict[best_idx]
                    faceidx_to_obj_dict[best_idx] = (obj, best_score)
                    distance_dict[prev_obj][best_idx] = self.model_config.large_distance
                    assign_face_label(prev_obj)

        faceidx_to_obj_dict = dict()
        for obj in distance_dict.keys():
            assign_face_label(obj)
        # Write the final assignments back into the XML objects.
        for idx, (obj, score) in faceidx_to_obj_dict.items():
            obj.find('name').text = self.faces[idx]
            distance_tag = ET.Element("distance")
            distance_tag.text = "{:.2f}".format(score)
            obj.append(distance_tag)
        return faceidx_to_obj_dict

    def forward_pass(self, img, tree, mode="repeat"):
        '''Crop, align, and score every <object> in the VOC ``tree``, then
        label the tree according to ``mode`` ("repeat" or "no-repeat").
        '''
        root = tree.getroot()
        self.distance_dict = dict()
        size = root.find('size')
        # Read but currently unused; also validates the <size> element exists.
        w, h = int(size.find("width").text), int(size.find("height").text)
        for obj in root.findall("object"):
            bndbox = obj.find("bndbox")
            xmin = int(bndbox.find('xmin').text)
            ymin = int(bndbox.find('ymin').text)
            xmax = int(bndbox.find('xmax').text)
            ymax = int(bndbox.find('ymax').text)
            crop_img = img[ymin:ymax, xmin:xmax]
            # cv2.resize takes a (width, height) tuple.
            crop_img = cv2.resize(crop_img, (self.model_config.input_size, self.model_config.input_size))
            crop_img = self.aligner.align((crop_img,))[0]
            # The aligner returns None for crops it rejects; skip those.
            if crop_img is not None:
                self.distance_dict[obj] = np.array(
                    self.calculate_distance(crop_img, self.db_faces_features, mode=self.distance_mode)
                )
        if mode == "repeat":
            self.repeat_allowed_face_recognition(self.distance_dict)
        elif mode == "no-repeat":
            self.no_repeat_allowed_face_recognition(self.distance_dict)
        return tree

    def predict(self, img, tree):
        """Label the objects in ``tree`` using the configured face database.

        Raises ValueError unless :meth:`set_face_db_and_mode` was called first.
        """
        if (not hasattr(self, "distance_mode")) or (not hasattr(self, "recognition_mode")):
            raise ValueError("Call set_face_db_and_mode method first!")
        tree = self.forward_pass(img, tree, mode=self.recognition_mode)
        return tree

    def set_face_db_and_mode(self, faces, db_faces_features, distance_mode="avg", recognition_mode="repeat"):
        """Register the face database and matching strategy.

        faces : sequence of face names, parallel to ``db_faces_features``.
        db_faces_features : one (n_samples, n_features) array per face.
        distance_mode : "avg" or "best" (see :meth:`calculate_distance`).
        recognition_mode : "repeat" or "no-repeat".
        """
        valid_distance_modes = ['avg', 'best']
        if distance_mode not in valid_distance_modes:
            raise ValueError(f"Unknown mode:{distance_mode} \nMode should be one of these:{valid_distance_modes}")
        valid_recognition_modes = ['repeat', 'no-repeat']
        if recognition_mode not in valid_recognition_modes:
            raise ValueError(f"Unknown mode:{recognition_mode} \nMode should be one of these:{valid_recognition_modes}")
        self.distance_mode = distance_mode
        self.recognition_mode = recognition_mode
        self.faces = faces
        self.db_faces_features = db_faces_features
# for xml_file in ["/content/images - Copy/IMG20221124131734.xml"]:
if __name__ == "__main__":
    from helper import *

    img_dir = config.img_dir
    save_dir = config.save_dir

    # Each sub-directory of the database dir is one known face.
    _, faces, _ = next(os.walk(config.db_dir))
    db_faces_features = [
        np.loadtxt(f"{config.db_dir}/{face_dir}/features.npy", ndmin=2)
        for face_dir in faces
    ]
    for name, feats in zip(faces, db_faces_features):
        print(name, ":", feats.shape)

    # Start from a clean output directory.
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)
    os.mkdir(save_dir)

    fr = face_recognition("face_recognition/Models/v1")
    fr.set_face_db_and_mode(faces, db_faces_features, distance_mode="best", recognition_mode="no-repeat")

    # Process every annotation file alongside its image.
    for xml_file in glob(f"{img_dir}/*.xml"):
        tree = ET.parse(xml_file)
        root = tree.getroot()
        img_name = img_dir + '/' + root.find("filename").text
        img = PIL.Image.open(img_name).convert("RGB")
        img = ImageOps.exif_transpose(img)  # honour EXIF orientation before cropping
        img = np.array(img)
        tree = fr.predict(img, tree)
        img = show_pred_image(tree, img)
        print(xml_to_objs_found(tree))
        cv2.imwrite(save_dir + "/" + root.find("filename").text, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# (end of file; stray viewer gutter character removed)