# NOTE(review): the lines "Spaces:" / "Runtime error" above this file's imports
# were paste/rendering residue from the source viewer, not program code.
| from pathlib import Path | |
| from typing import Union, Optional | |
| import numpy as np | |
| import torch | |
| import tops | |
| import torchvision.transforms.functional as F | |
| from motpy import Detection, MultiObjectTracker | |
| from dp2.utils import load_config | |
| from dp2.infer import build_trained_generator | |
| from dp2.detection.structures import CSEPersonDetection, FaceDetection, PersonDetection, VehicleDetection | |
| def load_generator_from_cfg_path(cfg_path: Union[str, Path]): | |
| cfg = load_config(cfg_path) | |
| G = build_trained_generator(cfg) | |
| tops.logger.log(f"Loaded generator from: {cfg_path}") | |
| return G | |
class Anonymizer:
    """Replaces detected regions of an image with generator-synthesized content.

    One generator may be configured per detection type (face, person,
    CSE person, vehicle); detection types without a generator are left
    untouched. When a tracker is initialized (video use), each tracked
    object is assigned a fixed latent seed so its synthesized identity
    stays stable across frames.
    """

    def __init__(
        self,
        detector,
        load_cache: bool = False,
        person_G_cfg: Optional[Union[str, Path]] = None,
        cse_person_G_cfg: Optional[Union[str, Path]] = None,
        face_G_cfg: Optional[Union[str, Path]] = None,
        car_G_cfg: Optional[Union[str, Path]] = None,
    ) -> None:
        """
        Args:
            detector: callable returning a list of detection objects for an image;
                also expected to expose ``forward_and_cache``.
            load_cache: if True, detections are read from the detector's cache.
            person_G_cfg / cse_person_G_cfg / face_G_cfg / car_G_cfg: path to the
                generator config for each detection type; None disables
                anonymization for that type.
        """
        self.detector = detector
        # One optional generator per detection type; None => that type is skipped.
        self.generators = {
            k: None
            for k in [CSEPersonDetection, PersonDetection, FaceDetection, VehicleDetection]
        }
        self.load_cache = load_cache
        if cse_person_G_cfg is not None:
            self.generators[CSEPersonDetection] = load_generator_from_cfg_path(cse_person_G_cfg)
        if person_G_cfg is not None:
            self.generators[PersonDetection] = load_generator_from_cfg_path(person_G_cfg)
        if face_G_cfg is not None:
            self.generators[FaceDetection] = load_generator_from_cfg_path(face_G_cfg)
        if car_G_cfg is not None:
            self.generators[VehicleDetection] = load_generator_from_cfg_path(car_G_cfg)

    def initialize_tracker(self, fps: float):
        """Create a multi-object tracker for a video with the given frame rate."""
        self.tracker = MultiObjectTracker(dt=1 / fps)
        # Maps a track id -> latent seed, so a tracked object keeps the same
        # synthesized identity over consecutive frames.
        self.track_to_z_idx = {}

    def reset_tracker(self):
        """Forget all previously assigned per-track latent seeds."""
        self.track_to_z_idx = {}

    def forward_G(
        self,
        G,
        batch,
        multi_modal_truncation: bool,
        amp: bool,
        # Fixed annotation: z_idx is indexed as z_idx[idx], so it is an
        # indexable sequence of int seeds (or None), not a single int.
        z_idx: Optional[Sequence[int]],
        truncation_value: float,
        idx: int,
        all_styles=None,
    ):
        """Synthesize one cropped detection with generator ``G``.

        Args:
            G: trained generator for this detection type.
            batch: crop dict from ``detection.get_crop``; must contain "img"
                (uint8-range tensor) and "mask".
            multi_modal_truncation: use ``G.multi_modal_truncate`` instead of
                ``G.sample`` (ignored when ``all_styles`` is given).
            amp: run the generator under CUDA automatic mixed precision.
            z_idx: per-detection latent seeds, or None for a random latent.
            truncation_value: truncation strength passed to the generator.
            idx: index of the current detection within ``z_idx``/``all_styles``.
            all_styles: optional precomputed styles; takes precedence.

        Returns:
            Synthesized image tensor scaled to [0, 255] (float).
        """
        # Normalize the uint8-range image to [-1, 1]. F.normalize already
        # returns a float tensor, so the original's extra .float() cast
        # after it was redundant and has been dropped.
        mean_std = [0.5 * 255] * 3
        batch["img"] = F.normalize(batch["img"].float(), mean_std, mean_std)
        # Condition = image with masked-out (mask == 0) regions zeroed;
        # those are the regions the generator must in-paint.
        batch["condition"] = batch["mask"].float() * batch["img"]
        with torch.cuda.amp.autocast(amp):
            z = None
            if z_idx is not None:
                # Deterministic latent from the per-track seed => stable identity.
                state = np.random.RandomState(seed=z_idx[idx])
                z = state.normal(size=(1, G.z_channels)).astype(np.float32)
                z = tops.to_cuda(torch.from_numpy(z))
            if all_styles is not None:
                anonymized_im = G(**batch, s=iter(all_styles[idx]))["img"]
            elif multi_modal_truncation:
                w_indices = None
                if z_idx is not None:
                    # Pick the style-center deterministically from the seed.
                    w_indices = [z_idx[idx] % len(G.style_net.w_centers)]
                anonymized_im = G.multi_modal_truncate(
                    **batch, truncation_value=truncation_value,
                    w_indices=w_indices,
                    z=z
                )["img"]
            else:
                anonymized_im = G.sample(**batch, truncation_value=truncation_value, z=z)["img"]
        # Map generator output from [-1, 1] to [0, 255].
        anonymized_im = (anonymized_im + 1).div(2).clamp(0, 1).mul(255)
        return anonymized_im

    def anonymize_detections(
        self,
        im,
        detection,
        update_identity=None,
        **synthesis_kwargs
    ):
        """Anonymize every instance of one detection type in ``im``.

        Args:
            im: uint8 CHW image tensor; modified in place and returned.
            detection: a detection-structure object (crop provider).
            update_identity: optional per-instance bools; False skips that
                instance (its previous pixels are kept).
            **synthesis_kwargs: forwarded to :meth:`forward_G`.

        Returns:
            ``im`` with anonymized pixels pasted into each detection box.
        """
        G = self.generators[type(detection)]
        if G is None:
            # No generator configured for this detection type: leave untouched.
            return im
        C, H, W = im.shape
        if update_identity is None:
            update_identity = [True] * len(detection)
        for idx in range(len(detection)):
            if not update_identity[idx]:
                continue
            batch = detection.get_crop(idx, im)
            x0, y0, x1, y1 = batch.pop("boxes")[0]
            batch = {k: tops.to_cuda(v) for k, v in batch.items()}
            anonymized_im = self.forward_G(G, batch, **synthesis_kwargs, idx=idx)
            # Resize the synthesized crop and its mask back to the box size.
            gim = F.resize(anonymized_im[0], (y1 - y0, x1 - x0),
                           interpolation=F.InterpolationMode.BICUBIC, antialias=True)
            mask = F.resize(batch["mask"][0], (y1 - y0, x1 - x0),
                            interpolation=F.InterpolationMode.NEAREST).squeeze(0)
            # The crop may extend past the image border; compute the padding
            # added on each side (left, top, right, bottom) and strip it.
            pad = [max(-x0, 0), max(-y0, 0)]
            pad = [*pad, max(x1 - W, 0), max(y1 - H, 0)]

            def remove_pad(x):
                return x[..., pad[1]:x.shape[-2] - pad[3], pad[0]:x.shape[-1] - pad[2]]

            gim = remove_pad(gim)
            mask = remove_pad(mask) > 0.5
            # Clamp the box to the image bounds before pasting.
            x0, y0 = max(x0, 0), max(y0, 0)
            x1, y1 = min(x1, W), min(y1, H)
            # Paste only where the mask is 0 (the in-painted region, hence
            # logical_not), replicated across all 3 channels. Indexing a basic
            # slice view writes through to `im`.
            mask = mask.logical_not()[None].repeat(3, 1, 1)
            im[:, y0:y1, x0:x1][mask] = gim[mask].round().clamp(0, 255).byte()
        return im

    def visualize_detection(self, im: torch.Tensor, cache_id: Optional[str] = None) -> torch.Tensor:
        """Return a CPU copy of ``im`` with all detections drawn on it."""
        all_detections = self.detector.forward_and_cache(im, cache_id, load_cache=self.load_cache)
        im = im.cpu()
        for det in all_detections:
            im = det.visualize(im)
        return im

    def forward(self, im: torch.Tensor, cache_id: Optional[str] = None, track=True,
                detections=None, **synthesis_kwargs) -> torch.Tensor:
        """Anonymize all detections in a uint8 CHW image tensor.

        Args:
            im: uint8 image tensor of shape (C, H, W).
            cache_id: detector-cache key (used only when ``load_cache`` is set).
            track: when a tracker has been initialized, keep identities stable
                across frames via per-track latent seeds.
            detections: precomputed detections; when None the detector is run.
            **synthesis_kwargs: forwarded to :meth:`anonymize_detections`.

        Returns:
            The anonymized image, moved to the CPU.
        """
        assert im.dtype == torch.uint8
        im = tops.to_cuda(im)
        all_detections = detections
        if detections is None:
            # NOTE(review): unlike visualize_detection, this call does not pass
            # load_cache= to forward_and_cache — confirm the intended default.
            if self.load_cache:
                all_detections = self.detector.forward_and_cache(im, cache_id)
            else:
                all_detections = self.detector(im)
        if hasattr(self, "tracker") and track:
            # Match current boxes to existing tracks and give each new track a
            # fixed random seed so its synthesized identity persists over time.
            for det in all_detections:
                det.pre_process()
            boxes = np.concatenate([det.boxes for det in all_detections])
            boxes = [Detection(box) for box in boxes]
            self.tracker.step(boxes)
            track_ids = self.tracker.detections_matched_ids
            z_idx = []
            for track_id in track_ids:
                if track_id not in self.track_to_z_idx:
                    self.track_to_z_idx[track_id] = np.random.randint(0, 2**32 - 1)
                z_idx.append(self.track_to_z_idx[track_id])
            z_idx = np.array(z_idx)
        idx_offset = 0
        for detection in all_detections:
            zs = None
            if hasattr(self, "tracker") and track:
                # Slice this detection type's seeds out of the flat seed array.
                zs = z_idx[idx_offset:idx_offset + len(detection)]
                idx_offset += len(detection)
            im = self.anonymize_detections(im, detection, z_idx=zs, **synthesis_kwargs)
        return im.cpu()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)