# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
import facer
from PIL import Image
from torchvision import transforms
from torch.nn import functional as F

from eval.tools.face_utils.face import tight_warp_face
from eval.tools.face_utils.face_recg import Backbone
from src.utils.data_utils import pad_to_square, pad_to_target, json_dump, json_load, split_grid


def expand_bounding_box(x_min, y_min, x_max, y_max, factor=1.3):
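    """Scale a bounding box about its center by `factor`.

    The center stays fixed; the expanded box is not clipped to the image
    bounds, so coordinates may fall outside the image.
    """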
    # Center of the original box
    x_center = (x_min + x_max) / 2
    y_center = (y_min + y_max) / 2
    # Original width and height
    width = x_max - x_min
    height = y_max - y_min
    # Scaled width and height
    new_width = factor * width
    new_height = factor * height
    # Coordinates of the expanded box
    x_min_new = x_center - new_width / 2
    x_max_new = x_center + new_width / 2
    y_min_new = y_center - new_height / 2
    y_max_new = y_center + new_height / 2
    return x_min_new, y_min_new, x_max_new, y_max_new


class FaceID:
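    """Face-identity similarity scorer.

    Detects faces with a facer RetinaFace detector, tightly warps the face
    crops, and embeds them with an IR-SE50 recognition backbone; similarity
    is reported as cosine similarity scaled by 100.
    """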
    def __init__(self, device):
        self.device = torch.device(device)
        self.T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.detector = facer.face_detector("retinaface/resnet50", device=device)
        # Path to the IR-SE50 recognition weights; override via FACE_ID_MODEL_PATH.
        face_model_path = os.getenv("FACE_ID_MODEL_PATH", "./checkpoints/model_ir_se50.pth")
        self.model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.model.load_state_dict(torch.load(face_model_path, map_location=device))
        self.model.to(device)
        self.model.eval()

    def detect(self, image, expand_scale=1.3):
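        """Detect faces and return expanded (x0, y0, x1, y1) boxes.

        `image` is expected in the format the detector consumes here:
        a uint8 tensor of shape [B, C, H, W].
        """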
        with torch.no_grad():
            faces = self.detector(image.to(self.device))
            bboxes = faces['rects'].detach().cpu().tolist()
            bboxes = [expand_bounding_box(*x, factor=expand_scale) for x in bboxes]
        return bboxes

    def __call__(self, image_x, image_y, normalize=False):
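        """Return the identity similarity (cosine similarity * 100) between
        the face in `image_x` and the face in `image_y`.

        Only one face per image is supported. Returns 0 when no face is
        detected in either image.
        """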
        try:
            warp_x = tight_warp_face(image_x, self.detector)['cropped_face_masked']
            warp_y = tight_warp_face(image_y, self.detector)['cropped_face_masked']
        except Exception:
            # Detection/warping failed: no face found.
            return 0
        if warp_x is None or warp_y is None:
            # No face detected in one of the images.
            return 0
        feature_x = self.model(self.T(warp_x).unsqueeze(0).to(self.device))[0]  # [512]
        feature_y = self.model(self.T(warp_y).unsqueeze(0).to(self.device))[0]  # [512]
        if normalize:
            feature_x = feature_x / feature_x.norm(p=2, dim=-1, keepdim=True)
            feature_y = feature_y / feature_y.norm(p=2, dim=-1, keepdim=True)
        return F.cosine_similarity(feature_x, feature_y, dim=0).item() * 100


if __name__ == "__main__":
    # Online demo for manual comparison: https://dun.163.com/trial/face/compare
    import numpy as np  # used by the alternative tensor conversion below
    from src.train.data.data_utils import pil2tensor

    faceid = FaceID("cuda")
    real_image_path = "assets/bengio_bengio.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/1_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/7_A man is wearing green headphones standi.png"
    # gen_image_path = "runs/0303-2034_flux100k_mod-t_oc-sblocks_multi-0.5_fs-lora8_cond192_res384_bs48_resume/eval/ckpt/40000/0304_cond192_tar512/198_A man wearing a black white suit and a w.png"
    gen_image_path = "data/tmp/GCG/florence-sam_phrase-grounding_S3L-two-20k_v41_wds/00000/qwen10M_00000-00010_00000_000000008_vis.png"

    # Checkpoint eval outputs are image grids; split them into individual images.
    if "eval/ckpt" in gen_image_path:
        gen_images = split_grid(Image.open(gen_image_path))
    else:
        gen_images = [Image.open(gen_image_path)]

    for i, gen_img in enumerate(gen_images):
        # Convert the PIL image to the uint8 [1, C, H, W] tensor the detector expects.
        # img_tensor = torch.from_numpy(np.array(gen_img)).unsqueeze(0).permute(0, 3, 1, 2)
        img_tensor = (pil2tensor(gen_img).unsqueeze(0) * 255).to(torch.uint8)
        bboxes = faceid.detect(img_tensor)
        for j, bbox in enumerate(bboxes):
            face_img = gen_img.crop(bbox)
            face_img.save(f"tmp_{j}.png")
            print(faceid(real_image_path, face_img))