| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
|
|
| from .log import log |
| import numpy as np |
| import torch |
| from pytorch_retinaface.data import cfg_re50 |
| from pytorch_retinaface.layers.functions.prior_box import PriorBox |
| from pytorch_retinaface.models.retinaface import RetinaFace |
| from torch.utils.data import DataLoader, TensorDataset |
| from tqdm import tqdm |
|
|
| from .guardrail_core import GuardrailRunner, PostprocessingGuardrail |
| from .guardrail_io_utils import get_video_filepaths, read_video, save_video |
| from .blur_utils import pixelate_face |
| from .retinaface_utils import decode_batch, filter_detected_boxes, load_model |
| from .misc import misc, Color, timer |
|
|
# Default location of the RetinaFace ResNet-50 weights used for face detection.
DEFAULT_RETINAFACE_CHECKPOINT = "checkpoints/Cosmos-1.0-Guardrail/face_blur_filter/Resnet50_Final.pth"

# Maximum number of raw detections considered before non-maximum suppression.
TOP_K = 5_000
# Maximum number of detections retained after non-maximum suppression.
KEEP_TOP_K = 750
# IoU threshold used by NMS to merge overlapping candidate boxes.
NMS_THRESHOLD = 0.4
|
|
|
|
class RetinaFaceFilter(PostprocessingGuardrail):
    """Postprocessing guardrail that detects faces with RetinaFace and pixelates them."""

    def __init__(
        self,
        checkpoint: str = DEFAULT_RETINAFACE_CHECKPOINT,
        batch_size: int = 1,
        confidence_threshold: float = 0.7,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Initialize the RetinaFace model for face detection and blurring.

        Args:
            checkpoint: Path to the RetinaFace checkpoint file
            batch_size: Batch size for RetinaFace inference and processing
            confidence_threshold: Minimum confidence score to consider a face detection
            device: Torch device string. NOTE: the default is evaluated once at
                import time, so a GPU hot-plugged later is not picked up.
        """
        self.cfg = cfg_re50
        self.batch_size = batch_size
        self.confidence_threshold = confidence_threshold
        self.device = device
        self.dtype = torch.float32

        # Disable downloading ImageNet-pretrained backbone weights; the full
        # RetinaFace checkpoint is loaded explicitly below.
        self.cfg["pretrain"] = False
        self.net = RetinaFace(cfg=self.cfg, phase="test")
        cpu = self.device == "cpu"

        # Load checkpoint weights, then move the network to the target device
        # and switch to inference mode.
        self.net = load_model(self.net, checkpoint, cpu)
        self.net.to(self.device, dtype=self.dtype).eval()

    def preprocess_frames(self, frames: np.ndarray) -> torch.Tensor:
        """Preprocess a sequence of frames for face detection.

        Converts (N, H, W, C) frames to a channel-swapped, mean-subtracted
        (N, C, H, W) float tensor as expected by RetinaFace. The [2, 1, 0]
        index swap assumes the input frames are RGB — TODO confirm against
        the upstream video reader.

        Args:
            frames: Input frames

        Returns:
            Preprocessed frames tensor
        """
        with torch.no_grad():
            frames_tensor = torch.from_numpy(frames).to(self.device, dtype=self.dtype)
            frames_tensor = frames_tensor.permute(0, 3, 1, 2)  # NHWC -> NCHW
            frames_tensor = frames_tensor[:, [2, 1, 0], :, :]  # swap R and B channels
            # Per-channel means from the original RetinaFace training pipeline.
            means = torch.tensor([104.0, 117.0, 123.0], device=self.device, dtype=self.dtype).view(1, 3, 1, 1)
            frames_tensor = frames_tensor - means
            return frames_tensor

    def blur_detected_faces(
        self,
        frames: np.ndarray,
        batch_loc: torch.Tensor,
        batch_conf: torch.Tensor,
        prior_data: torch.Tensor,
        scale: torch.Tensor,
        min_size: tuple[int, int] = (20, 20),
    ) -> list[np.ndarray]:
        """Blur detected faces in a batch of frames using RetinaFace predictions.

        NOTE: frames are modified in place; the returned list holds views of
        the input array.

        Args:
            frames: Input frames
            batch_loc: Batched location predictions
            batch_conf: Batched confidence scores
            prior_data: Prior boxes for the video
            scale: Scale factor for resizing detections
            min_size: Minimum (width, height) of a detected face region in pixels

        Returns:
            Processed frames with pixelated faces
        """
        with torch.no_grad():
            # Decode the raw location regressions into pixel-space boxes.
            batch_boxes = decode_batch(batch_loc, prior_data, self.cfg["variance"])
            batch_boxes = batch_boxes * scale

        blurred_frames = []
        for i, boxes in enumerate(batch_boxes):
            boxes = boxes.detach().cpu().numpy()
            scores = batch_conf[i, :, 1].detach().cpu().numpy()

            # Confidence filtering + NMS to keep only strong, distinct faces.
            filtered_boxes = filter_detected_boxes(
                boxes,
                scores,
                confidence_threshold=self.confidence_threshold,
                nms_threshold=NMS_THRESHOLD,
                top_k=TOP_K,
                keep_top_k=KEEP_TOP_K,
            )

            frame = frames[i]
            for box in filtered_boxes:
                x1, y1, x2, y2 = map(int, box)
                # Ignore detections smaller than min_size — too small to be
                # a meaningful face and prone to degenerate ROIs.
                if x2 - x1 < min_size[0] or y2 - y1 < min_size[1]:
                    continue
                # Clamp the box to the frame bounds before slicing.
                max_h, max_w = frame.shape[:2]
                face_roi = frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)]
                blurred_face = pixelate_face(face_roi)
                frame[max(y1, 0) : min(y2, max_h), max(x1, 0) : min(x2, max_w)] = blurred_face
            blurred_frames.append(frame)

        return blurred_frames

    def postprocess(self, frames: np.ndarray) -> np.ndarray:
        """Blur faces in a sequence of frames.

        Args:
            frames: Input frames

        Returns:
            Processed frames with pixelated faces
        """
        frames_tensor = self.preprocess_frames(frames)
        dataset = TensorDataset(frames_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
        processed_batches = []

        # Priors and scale depend only on the frame size, so they are computed
        # once on the first batch and reused (all frames of a video share H x W).
        prior_data, scale = None, None
        for i, batch in enumerate(dataloader):
            batch = batch[0]
            h, w = batch.shape[-2:]

            with torch.no_grad():
                if prior_data is None:
                    priorbox = PriorBox(self.cfg, image_size=(h, w))
                    priors = priorbox.forward()
                    priors = priors.to(self.device, dtype=self.dtype)
                    prior_data = priors.data

                if scale is None:
                    # [w, h, w, h] maps normalized (x1, y1, x2, y2) to pixels.
                    scale = torch.tensor([w, h, w, h], device=self.device, dtype=self.dtype)

                batch_loc, batch_conf, _ = self.net(batch)

            # Map this batch back to its slice of the original frame array.
            start_idx = i * self.batch_size
            end_idx = min(start_idx + self.batch_size, len(frames))
            processed_batches.append(
                self.blur_detected_faces(frames[start_idx:end_idx], batch_loc, batch_conf, prior_data, scale)
            )

        processed_frames = [frame for batch in processed_batches for frame in batch]
        return np.array(processed_frames)
|
|
|
|
def parse_args():
    """Build the CLI parser for this script and return the parsed namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dir", required=True, type=str, help="Path containing input videos")
    parser.add_argument("--output_dir", required=True, type=str, help="Path for saving processed videos")
    parser.add_argument(
        "--checkpoint",
        default=DEFAULT_RETINAFACE_CHECKPOINT,
        type=str,
        help="Path to the RetinaFace checkpoint file",
    )
    return parser.parse_args()
|
|
|
|
def main(args):
    """Blur faces in every video under args.input_dir, writing results to args.output_dir."""
    filepaths = get_video_filepaths(args.input_dir)
    if not filepaths:
        log.error(f"No video files found in directory: {args.input_dir}")
        return

    # One RetinaFace guardrail shared across all videos.
    face_blur = RetinaFaceFilter(checkpoint=args.checkpoint)
    runner = GuardrailRunner(postprocessors=[face_blur])
    os.makedirs(args.output_dir, exist_ok=True)

    for filepath in tqdm(filepaths):
        video_data = read_video(filepath)
        with timer("face blur filter"):
            frames = runner.postprocess(video_data.frames)

        # Keep the original file name in the output directory.
        output_path = os.path.join(args.output_dir, os.path.basename(filepath))
        save_video(output_path, frames, video_data.fps)
|
|
|
|
# Script entry point: parse CLI arguments and run the face-blur pipeline.
if __name__ == "__main__":
    args = parse_args()
    main(args)
|
|