Spaces:
Runtime error
Runtime error
| from transformers import pipeline, ViTModel, AutoImageProcessor | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
| import os | |
# Hugging Face checkpoint shared by the ViT encoder and its image processor.
_VIT_CHECKPOINT = "google/vit-base-patch16-224"

# Zero-shot object detector; the functions below query it with the text
# label "human face".
detector = pipeline(
    task="zero-shot-object-detection",
    model="google/owlvit-base-patch32",
)

# ViT encoder whose "pooler_output" is used as the face embedding, together
# with the preprocessor that turns images into "pixel_values" tensors.
image_processor = AutoImageProcessor.from_pretrained(_VIT_CHECKPOINT)
model = ViTModel.from_pretrained(_VIT_CHECKPOINT)

# Module-level cache of known identities; replaced by
# load_candidates_in_cache() and read by face_recognition().
candidates = []
def extract_face(input_image):
    """Crop the highest-scoring "human face" detection out of an image.

    Args:
        input_image: Image handed to the module-level ``detector``. It must
            also provide ``crop(box)`` (a PIL image elsewhere in this file).

    Returns:
        Whatever ``input_image.crop`` returns for the best-scoring box,
        passed as ``(xmin, ymin, xmax, ymax)``.

    Raises:
        ValueError: If the detector returns no detections.
    """
    predictions = detector(
        input_image,
        candidate_labels=["human face"],
    )
    if not predictions:
        raise ValueError("No human face detected in the input image.")

    # Select by score explicitly instead of relying on the order in which
    # the detector happens to return its results.
    best = max(predictions, key=lambda prediction: prediction["score"])

    # Read the coordinates by key so the crop box does not depend on the
    # insertion order of the "box" dict.
    box = best["box"]
    return input_image.crop((box["xmin"], box["ymin"], box["xmax"], box["ymax"]))
def load_candidates(candidate_dir):
    """Load reference images grouped by identity from a directory tree.

    The expected layout is ``candidate_dir/<label>/<image file>``. Every
    sub-directory becomes one candidate whose label is the directory name.
    Entries of ``candidate_dir`` that are not directories are ignored.

    Args:
        candidate_dir: Path of the root directory holding one folder per
            identity.

    Returns:
        A list of dicts ``{"label": str, "images": list}`` ordered by label.
        ``images`` holds the RGB-converted images found in that folder and
        may be empty when the folder contains no supported image files.

    Raises:
        FileNotFoundError: If ``candidate_dir`` does not exist.
    """
    # Raise explicitly: an ``assert`` would be stripped under ``python -O``.
    if not os.path.exists(candidate_dir):
        raise FileNotFoundError(f"Path candidate_dir {candidate_dir} is not exist.")

    image_suffixes = (".jpg", ".png", ".jpeg", ".bmp")
    candidates = []
    # Sort for a deterministic order; os.listdir() order is arbitrary.
    for candidate_label in sorted(os.listdir(candidate_dir)):
        label_dir = os.path.join(candidate_dir, candidate_label)
        if not os.path.isdir(label_dir):
            # Stray files next to the label folders are not candidates.
            continue

        images = []
        for image_name in sorted(os.listdir(label_dir)):
            # Compare case-insensitively so files such as "A.JPG" are kept.
            if not image_name.lower().endswith(image_suffixes):
                continue
            # convert() returns an independent in-memory image, so the file
            # handle can be released as soon as the conversion is done.
            with Image.open(os.path.join(label_dir, image_name)) as image_file:
                images.append(image_file.convert("RGB"))

        candidates.append(dict(label=candidate_label, images=images))
    return candidates
def extract_faces(candidates):
    """Attach a list of cropped faces to every candidate.

    For each candidate dict, ``extract_face`` is applied to every entry of
    its ``"images"`` list and the results are stored under ``"faces"``.

    Args:
        candidates: Iterable of dicts that each carry an ``"images"`` list.

    Returns:
        The same ``candidates`` object, updated in place.
    """
    for entry in candidates:
        entry["faces"] = [extract_face(picture) for picture in entry["images"]]
    return candidates
def extract_featrue(candidates, target):
    """Compute one averaged embedding per candidate.

    For every candidate, the images stored under ``candidate[target]`` are
    run through the module-level ``image_processor`` and ``model``; the
    ``"pooler_output"`` rows are averaged over the first dimension and the
    result is stored under ``candidate["feature"]``.

    Args:
        candidates: Iterable of candidate dicts.
        target: Key of the image list to embed (``"faces"`` in this file).

    Returns:
        The same ``candidates`` object, updated in place.

    Raises:
        ValueError: If a candidate has no images under ``target``.
    """
    for candidate in candidates:
        target_images = candidate[target]
        if not target_images:
            raise ValueError(
                f"Candidate {candidate.get('label')!r} has no {target!r} "
                "to extract features from."
            )
        # Inference only: without no_grad() every cached feature would keep
        # its autograd graph (and the activations behind it) alive.
        with torch.no_grad():
            pixel_values = image_processor(target_images, return_tensors="pt")["pixel_values"]
            features = model(pixel_values)["pooler_output"]
        candidate["feature"] = features.mean(0)
    return candidates
def load_candidates_face_feature(candidates):
    """Run face extraction followed by feature extraction on candidates.

    Args:
        candidates: Candidate collection passed to ``extract_faces``.

    Returns:
        The value returned by ``extract_featrue`` when called on the
        face-extraction result with ``"faces"`` as the target key.
    """
    with_faces = extract_faces(candidates)
    return extract_featrue(with_faces, "faces")
def compare_with_candidates(detectd_face, candidates):
    """Return the label of the candidate most similar to a face crop.

    The crop is embedded with the module-level ``image_processor`` and
    ``model`` (``"pooler_output"``) and compared to every candidate's
    ``"feature"`` by cosine similarity.

    Args:
        detectd_face: Face image to identify.
        candidates: Sequence of dicts carrying ``"label"`` and ``"feature"``.

    Returns:
        The ``"label"`` of the candidate with the highest similarity; on a
        tie the earliest candidate wins.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("No candidates to compare against; load candidates first.")

    # Inference only, so skip building an autograd graph.
    with torch.no_grad():
        pixel_values = image_processor(detectd_face, return_tensors="pt")["pixel_values"]
        detectd_feature = model(pixel_values)["pooler_output"].squeeze(0)

    def similarity(candidate):
        return torch.cosine_similarity(detectd_feature, candidate["feature"], dim=0).item()

    return max(candidates, key=similarity)["label"]
def face_recognition(detected_image):
    """Detect faces in an image and label each with the closest candidate.

    Args:
        detected_image: Image handed to the module-level ``detector``; it
            must also provide ``crop(box)``.

    Returns:
        A tuple ``(detected_image, labels)`` where ``labels`` is a list of
        ``((xmin, ymin, xmax, ymax), label)`` pairs, one per detection.

    Raises:
        ValueError: If faces were detected but the module-level
            ``candidates`` cache is still empty.
    """
    predictions = detector(
        detected_image,
        candidate_labels=["human face"],
    )
    if predictions and not candidates:
        raise ValueError(
            "No candidates loaded; load a candidate directory before running face recognition."
        )

    labels = []
    for prediction in predictions:
        # Read the coordinates by key so the box does not depend on the
        # insertion order of the "box" dict.
        box = prediction["box"]
        bbox = (box["xmin"], box["ymin"], box["xmax"], box["ymax"])
        label = compare_with_candidates(detected_image.crop(bbox), candidates)
        labels.append((bbox, label))
    return detected_image, labels
def load_candidates_in_cache(candidate_dir):
    """Load candidates from disk and publish them in the module-level cache.

    Args:
        candidate_dir: Directory passed to ``load_candidates``.

    Returns:
        None. On success the global ``candidates`` is replaced with the
        result of ``load_candidates_face_feature``.
    """
    global candidates
    # Build the new cache in a local first: if loading or feature extraction
    # fails, the previously published cache stays intact instead of being
    # replaced by entries that lack a "feature".
    loaded = load_candidates(candidate_dir)
    loaded = load_candidates_face_feature(loaded)
    candidates = loaded
def main():
    """Build the Gradio UI and launch it.

    The interface has an input image, an annotated output image, a textbox
    plus "Load" button that calls ``load_candidates_in_cache``, and a
    "Face Recognition" button that calls ``face_recognition``.
    """
    # NOTE(review): the original indentation was lost, so the grouping of
    # widgets into rows below is a best-effort reconstruction — confirm.
    with gr.Blocks() as demo:
        with gr.Row():
            detected_image = gr.Image(type="pil", label="detected_image")
            # AnnotatedImage takes an (image, annotations) value and has no
            # ``type`` parameter, so none is passed here.
            output_image = gr.AnnotatedImage(label="output_image")
        with gr.Row():
            candidate_dir = gr.Textbox(label="candidate_dir")
            load_candidates_btn = gr.Button("Load", variant="secondary", size="sm")
        btn = gr.Button("Face Recognition", variant="primary")

        # Event listeners must be registered inside the Blocks context.
        load_candidates_btn.click(fn=load_candidates_in_cache, inputs=[candidate_dir])
        btn.click(fn=face_recognition, inputs=[detected_image], outputs=[output_image])
    demo.launch(debug=True)


if __name__ == "__main__":
    main()