# anonymize-faces / utils.py
# Author: KYM384
# Last commit: e0b4e79 "bfloat16 to float16" (verified)
import torch
import torchvision
import numpy as np
import argparse
import copy
import cv2
import os
from contextlib import nullcontext
from huggingface_hub import hf_hub_download
from facenet_pytorch import MTCNN
from models import MobileGenerator, MobileNetV3MultiTask
class Face:
    """A detected face described by five facial landmarks.

    From the landmarks it derives a square, axis-aligned bounding box
    ``bbox = (x, y, w, h)`` (always ``w == h``) and ``theta``, the angle in
    radians of the vector from the first eye landmark to the second.
    """

    def __init__(self, keypoint: list):
        """Create a face from its landmarks.

        Args:
            keypoint: Five ``(x, y)`` points ordered left eye, right eye, nose,
                left mouth corner, right mouth corner. Each point may be a
                NumPy array (as MTCNN yields) or a plain tuple/list.
        """
        self._set_geometry(keypoint)

    def _set_geometry(self, keypoint: list) -> None:
        """Store *keypoint* and recompute ``bbox`` and ``theta`` from it."""
        self.keypoint = keypoint
        # Coerce each point so vector arithmetic works for tuples as well as arrays.
        e0, e1, _nose, m0, m1 = (np.asarray(p) for p in keypoint)
        eye_vec = e1 - e0
        eye_mid = 0.5 * (e0 + e1)
        # Vector from the mouth midpoint to the eye midpoint.
        mouth_to_eyes = eye_mid - 0.5 * (m0 + m1)
        # Box centre: eye midpoint moved 10% of the eye-mouth distance towards the mouth.
        centre = eye_mid - 0.1 * mouth_to_eyes
        cx, cy = int(centre[0]), int(centre[1])
        self.theta = np.arctan2(eye_vec[1], eye_vec[0])
        # Box side: the larger of 4x the eye distance and 3.6x the eye-mouth distance.
        s = int(max(4.0 * np.linalg.norm(eye_vec), 3.6 * np.linalg.norm(mouth_to_eyes)))
        # bbox: (x, y, w, h)
        self.bbox = (cx - s // 2, cy - s // 2, s, s)

    def get_center(self) -> tuple[int, int]:
        """Return the integer centre ``(cx, cy)`` of the bounding box."""
        x, y, w, h = self.bbox
        return x + w // 2, y + h // 2

    def get_size(self) -> int:
        """Return the side length of the (square) bounding box."""
        return self.bbox[2]

    def set_attributes(self, age, gender):
        """Attach classifier output to this face; values are stored as given.

        NOTE(review): within this file, FaceSet passes ``[value, confidence]``
        pairs here rather than a bare int/str.
        """
        self.age = age
        self.gender = gender

    def update(self, keypoint: list):
        """Move the face to new landmarks, keeping attributes such as ``age``."""
        self._set_geometry(keypoint)

    def calc_iou(self, other: "Face") -> float:
        """Return the intersection-over-union of this face's bbox and *other*'s.

        Returns 0.0 when both boxes are empty (zero union area).
        """
        ax, ay, aw, ah = self.bbox
        bx, by, bw, bh = other.bbox
        inter_w = max(0, min(ax + aw, bx + bw) - max(ax, bx))
        inter_h = max(0, min(ay + ah, by + bh) - max(ay, by))
        inter_area = inter_w * inter_h
        union_area = aw * ah + bw * bh - inter_area
        if union_area == 0:
            return 0.0
        return inter_area / union_area
class FaceSet:
    """Collection of tracked Face objects with per-face staleness counters.

    ``nonused_counter[k]`` counts how many consecutive :meth:`update` calls
    ``faces[k]`` went without being matched to a new detection; faces whose
    counter reaches ``reset_nonused_threshold`` are discarded.
    """

    # Identity latents looked up with keys of the form "<age>_<gender>_jp"
    # (see set_attributes). NOTE(review): this downloads from the Hugging Face
    # Hub when the class body executes, i.e. at module import; the repo id and
    # token come from the HF_GEN_REPO_ID / HF_HUB_TOKEN environment variables.
    latent_ids = np.load(
        hf_hub_download(
            repo_id=os.getenv("HF_GEN_REPO_ID"),
            filename="latent_ids.npz",
            token=os.getenv("HF_HUB_TOKEN")
        )
    )

    def __init__(self):
        self.faces = []
        self.nonused_counter = []

    def append(self, face):
        """Add *face* as a freshly seen track (staleness counter 0)."""
        self.faces.append(face)
        self.nonused_counter.append(0)

    def set_attributes(self, i: int, age, gender):
        """Store age/gender on face *i* and pick its identity latent.

        Args:
            i: Index into ``self.faces``.
            age: Sequence whose first element is the age bucket (e.g. 30, 80).
            gender: Sequence whose first element is ``"F"`` or ``"M"``.
                Presumably the per-face entries returned by
                ``FaceSwapper.classify`` — verify against the caller.

        The arguments are no longer mutated: the 80/M -> 70 remap only affects
        the latent lookup key, not the stored attributes.
        """
        self.faces[i].set_attributes(age, gender)
        age_bucket, gender_label = age[0], gender[0]
        # Presumably the latent bank has no "80_M_jp" entry, so fall back to 70.
        if age_bucket == 80 and gender_label == "M":
            age_bucket = 70
        self.faces[i].latent_id = self.latent_ids[f"{age_bucket}_{gender_label}_jp"]

    def __len__(self) -> int:
        return len(self.faces)

    def __getitem__(self, idx: int):
        return self.faces[idx]

    def __iter__(self):
        return iter(self.faces)

    def _match(self, detections: list, n_tracks: int, threshold: float) -> dict:
        """Pair detections with the first *n_tracks* faces one-to-one, highest IoU first."""
        scored = [
            (det.calc_iou(self.faces[j]), i, j)
            for i, det in enumerate(detections)
            for j in range(n_tracks)
        ]
        # Stable sort: ties keep detection/track order, so results are deterministic.
        scored.sort(key=lambda item: item[0], reverse=True)
        det_to_track = {}
        used_tracks = set()
        for iou, i, j in scored:
            if iou <= threshold:
                break
            if i in det_to_track or j in used_tracks:
                continue
            det_to_track[i] = j
            used_tracks.add(j)
        return det_to_track

    def update(self, other, reset_nonused_threshold: int, match_iou_threshold: float = 0.3):
        """Merge the detections in *other* into the tracked faces.

        Each detection is matched to at most one existing face, and each
        existing face to at most one detection, greedily by descending IoU
        (only IoU > ``match_iou_threshold`` counts). A matched face takes the
        detection's landmarks and its counter resets to 0. An unmatched
        detection becomes a new face. Detections are matched only against faces
        that existed before this call, so two detections never merge into one
        track and no detection is silently dropped. Every face not matched in
        this call gets its counter incremented. Finally, faces are reordered by
        counter (descending) and those with counter >= ``reset_nonused_threshold``
        are removed.

        Args:
            other: Iterable of Face (e.g. a FaceSet from the detector).
            reset_nonused_threshold: Staleness count at which a face is dropped.
            match_iou_threshold: Minimum IoU (exclusive) for a match.
        """
        detections = list(other)
        n_tracks = len(self.faces)
        det_to_track = self._match(detections, n_tracks, match_iou_threshold)
        matched = set(det_to_track.values())
        for i, det in enumerate(detections):
            j = det_to_track.get(i)
            if j is None:
                self.append(det)
                matched.add(len(self.faces) - 1)
            else:
                self.faces[j].update(det.keypoint)
                self.nonused_counter[j] = 0
        for j in range(len(self.faces)):
            if j not in matched:
                self.nonused_counter[j] += 1
        order = np.argsort(self.nonused_counter)[::-1]
        kept = [j for j in order if self.nonused_counter[j] < reset_nonused_threshold]
        self.faces = [self.faces[j] for j in kept]
        self.nonused_counter = [self.nonused_counter[j] for j in kept]
class FaceCropper:
    """Detects faces and maps them to and from aligned square crops.

    ``crop_and_resize`` warps a face into a ``crop_size`` x ``crop_size`` patch,
    rotated by ``face.theta`` and scaled by ``size / s * 1.14``.
    ``invert_image`` applies the inverse warp to a (processed) patch and
    feather-blends it back into the source image.
    """

    def __init__(self):
        self.size = 256
        self.crop_size = 224
        self.detector = MTCNN(select_largest=False, keep_all=True, device="cuda" if torch.cuda.is_available() else "cpu")
        # Blend mask: opaque except for an 8-px border, then Gaussian-blurred so
        # the pasted patch fades smoothly into the surrounding image.
        mask = np.zeros((self.crop_size, self.crop_size), dtype=np.uint8)
        mask[8:-8, 8:-8] = 255
        self.mask = cv2.GaussianBlur(mask, (31, 31), 0)

    def detect_keypoints(self, image: np.ndarray) -> FaceSet:
        """Run MTCNN on *image* and wrap each detection's landmarks in a Face.

        Returns an empty FaceSet when nothing is detected.
        """
        _, _, points = self.detector.detect(image, landmarks=True)
        faces = FaceSet()
        if points is None:
            return faces
        for pts in points:
            # Landmark order: left eye, right eye, nose, left mouth, right mouth.
            faces.append(Face(keypoint=[pts[k] for k in range(5)]))
        return faces

    def _crop_matrix(self, cx, cy, theta, s):
        """Return the 2x3 affine matrix mapping point (cx, cy) to the crop centre.

        Rotates by *theta* (radians) around (cx, cy) and scales by
        ``size / s * 1.14``. NOTE(review): meaning of the 1.14 factor is not
        documented in this file.
        """
        M = cv2.getRotationMatrix2D((cx, cy), np.degrees(theta), self.size / s * 1.14)
        M[0, 2] += self.crop_size // 2 - cx
        M[1, 2] += self.crop_size // 2 - cy
        return M

    def crop_and_resize(self, image: np.ndarray, face: Face) -> np.ndarray:
        """Return the aligned ``crop_size`` x ``crop_size`` patch for *face*."""
        cx, cy = face.get_center()
        M = self._crop_matrix(cx, cy, face.theta, face.get_size())
        return cv2.warpAffine(image, M, (self.crop_size, self.crop_size), flags=cv2.INTER_LINEAR)

    def invert_image(self, image: np.ndarray, cropped: np.ndarray, face: Face) -> np.ndarray:
        """Paste *cropped* back onto a copy of *image* at *face*'s location.

        Only a window of +/- the face size around the centre is processed.
        The patch is blended through the feathered mask, and the blend is
        rounded (not truncated) to uint8. Returns *image* unchanged when the
        window is empty.
        """
        cx, cy = face.get_center()
        s = face.get_size()
        x0 = max(0, int(np.floor(cx - s)))
        y0 = max(0, int(np.floor(cy - s)))
        x1 = min(image.shape[1], int(np.ceil(cx + s)))
        y1 = min(image.shape[0], int(np.ceil(cy + s)))
        if x0 >= x1 or y0 >= y1:
            return image
        roi = image[y0:y1, x0:x1]
        # Build the forward transform in window-local coordinates, then invert it.
        M_inv = cv2.invertAffineTransform(self._crop_matrix(cx - x0, cy - y0, face.theta, s))
        window = (x1 - x0, y1 - y0)
        warped_face = cv2.warpAffine(cropped, M_inv, window, flags=cv2.INTER_LINEAR)
        alpha = cv2.warpAffine(self.mask, M_inv, window)
        alpha = alpha.astype(np.float32)[:, :, None] / 255.0
        blended = roi.astype(np.float32) * (1 - alpha) + warped_face.astype(np.float32) * alpha
        result = image.copy()
        # Round instead of truncating: float error would otherwise turn e.g. 199.9999 into 199.
        result[y0:y1, x0:x1] = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        return result
class FaceSwapper:
    """Wraps the face generator and the age/gender classifier.

    Images are exchanged as HxWx3 (or batched NxHxWx3) uint8 NumPy arrays and
    normalised with the ImageNet mean/std before being fed to either network.
    """

    def __init__(self, model_path: str, classifier_checkpoint: str):
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        self.generator = MobileGenerator(input_nc=3, output_nc=3, latent_dim=512, n_blocks=6)
        # NOTE(review): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources (prefer weights_only=True
        # if the files contain only tensors/plain containers).
        self.generator.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"), weights_only=False))
        self.generator.to(self.device).eval()

        self.classifier = MobileNetV3MultiTask(model_name="mobilenetv3_small_100", num_age_classes=10, num_gender_classes=2)
        self.classifier.to(self.device).eval()
        checkpoint = torch.load(classifier_checkpoint, map_location=torch.device("cpu"), weights_only=False)
        self.classifier.load_state_dict(checkpoint["model_state_dict"])

        self.mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

    def _autocast(self):
        """Return a float16 autocast context on CUDA, or a no-op context otherwise."""
        if self.device.type == "cuda":
            return torch.autocast("cuda", torch.float16)
        return nullcontext()

    def np2tensor(self, imgs: "np.ndarray | list[np.ndarray]") -> torch.Tensor:
        """Convert uint8 HWC image(s) to a normalised float32 NCHW CPU tensor.

        Accepts one HxWx3 array, an NxHxWx3 batch array, or a list/tuple of
        HxWx3 arrays.
        """
        if isinstance(imgs, np.ndarray) and imgs.ndim == 4:
            batch = imgs
        elif isinstance(imgs, (list, tuple)):
            batch = np.stack(imgs, axis=0)
        else:
            batch = np.stack([imgs], axis=0)
        tensor = torch.from_numpy(batch.astype(np.float32) / 255).permute(0, 3, 1, 2)
        return (tensor - self.mean) / self.std

    def tensor2np(self, imgs: torch.Tensor) -> np.ndarray:
        """Undo np2tensor: normalised NCHW tensor -> uint8 NxHxWx3 array (rounded)."""
        imgs = imgs.detach().float().cpu()
        imgs = imgs * self.std + self.mean
        arr = np.clip(imgs.permute(0, 2, 3, 1).numpy(), 0, 1)
        return np.rint(arr * 255).astype(np.uint8)

    def classify(self, img) -> list[tuple[list, list]]:
        """Predict an age bucket and gender for each image.

        Returns one ``([age, age_prob], [gender, gender_prob])`` per image.
        ``age`` is ``argmax * 10`` over 10 age classes, ``gender`` is ``"F"``
        (class 0) or ``"M"``, and the ``*_prob`` values are the winning softmax
        probabilities.
        """
        with torch.no_grad(), self._autocast():
            img_tensor = self.np2tensor(img).to(self.device)
            age_logits, gender_logits = self.classifier(img_tensor)
            age_probs = torch.softmax(age_logits, dim=1)
            gender_probs = torch.softmax(gender_logits, dim=1)
            attributes = []
            for age_p, gender_p in zip(age_probs, gender_probs):
                age = age_p.argmax().item() * 10
                gender = "F" if gender_p.argmax().item() == 0 else "M"
                attributes.append(([age, age_p.max().item()], [gender, gender_p.max().item()]))
        return attributes

    def swap(self, img_att: np.ndarray, latent_ids: list[np.ndarray]) -> np.ndarray:
        """Run the generator on *img_att* with the stacked identity latents.

        Latents are cast to float32 so float64/float16 arrays (e.g. from an .npz)
        do not cause a dtype mismatch with the float32 model on CPU.
        Returns the generated images as a uint8 NxHxWx3 array.
        """
        with torch.no_grad(), self._autocast():
            att = self.np2tensor(img_att).to(self.device)
            ids = torch.from_numpy(np.vstack(latent_ids)).float().to(self.device)
            output = self.generator(att, ids)
        return self.tensor2np(output.to("cpu"))