Spaces:
Sleeping
Sleeping
| import os | |
| from contextlib import nullcontext | |
| import numpy as np | |
| import cv2 | |
| from scipy.stats import skew, kurtosis | |
| try: | |
| import torch | |
| except ImportError: | |
| torch = None | |
| def _get_default_drunet(): | |
| """Load DRUNet grayscale model from default weights path.""" | |
| if torch is None: | |
| return None | |
| try: | |
| from .drunet import load_drunet_gray | |
| weights_path = os.path.join( | |
| os.path.dirname(__file__), | |
| 'drunet', 'weights', 'drunet_gray.pth' | |
| ) | |
| if os.path.exists(weights_path): | |
| return load_drunet_gray(weights_path, noise_level=15) | |
| else: | |
| print(f"Warning: DRUNet weights not found at {weights_path}") | |
| return None | |
| except Exception as e: | |
| print(f"Warning: Failed to load DRUNet: {e}") | |
| return None | |
class ResidualExtractor:
    """Compute summary statistics of an image's denoising residual.

    The residual is ``gray - denoise(gray)``. A torch denoiser is used when one
    is available; otherwise a 5x5 Gaussian blur stands in for it. Large images
    are denoised tile by tile, and very large ones are optionally downscaled
    first.
    """

    def __init__(self, denoiser='auto', use_gaussian_fallback=True,
                 max_tile_size=1024, tile_overlap=64, max_image_size=4096,
                 auto_downscale=True):
        """
        Args:
            denoiser: Denoiser model. Options:
                - 'auto': Load DRUNet automatically (recommended)
                - None: No denoiser, use fallback
                - torch.nn.Module: Custom denoiser that takes (B,1,H,W) in [0,1]
            use_gaussian_fallback (bool): If True and denoiser unavailable,
                use Gaussian blur as fallback.
                NOTE: this flag currently only affects the start-up message;
                ``extract_features`` uses the Gaussian fallback whenever no
                denoiser is loaded, regardless of this value.
            max_tile_size (int): Maximum tile size for tiled processing (default: 1024)
            tile_overlap (int): Overlap between tiles to avoid boundary artifacts
                (default: 64). Must be less than half of ``max_tile_size``.
            max_image_size (int): Maximum image dimension before auto-downscaling (default: 4096)
            auto_downscale (bool): Automatically downscale very large images (default: True)

        Raises:
            RuntimeError: If a denoiser is supplied but PyTorch is not installed.
        """
        self.use_gaussian_fallback = use_gaussian_fallback
        self.max_tile_size = max_tile_size
        self.tile_overlap = tile_overlap
        self.max_image_size = max_image_size
        self.auto_downscale = auto_downscale

        # isinstance guard: never run `==` against an arbitrary denoiser object.
        if isinstance(denoiser, str) and denoiser == 'auto':
            self.denoiser = _get_default_drunet()
            if self.denoiser is not None:
                print("ResidualExtractor: Using DRUNet denoiser")
            elif use_gaussian_fallback:
                print("ResidualExtractor: DRUNet unavailable, using Gaussian fallback")
        else:
            self.denoiser = denoiser

        if self.denoiser is not None and torch is None:
            raise RuntimeError("PyTorch not available but denoiser was provided.")

        if self.denoiser is not None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.denoiser.to(self.device).eval()
        else:
            self.device = None

    def _autocast_context(self):
        """Return a CUDA autocast context on GPU, or a no-op context otherwise."""
        if self.device is not None and self.device.type == 'cuda':
            return torch.amp.autocast('cuda')
        return nullcontext()

    def _extract_features_torch_fast(self, gray, gray_uint8):
        """
        Fast path: keep denoising + statistics on-device (no large CPU copies).
        Only used when we can run the full frame without tiling.

        Args:
            gray: HxW float grayscale image; values are clipped to [0,255].
            gray_uint8: Unused; kept so the signature stays stable for callers.
        Returns:
            dict: The same eight ``residual_*`` features as ``extract_features``.
        Raises:
            RuntimeError: If no denoiser is loaded or torch is unavailable.
        """
        if self.denoiser is None or torch is None:
            raise RuntimeError("Torch fast path requires a denoiser and torch.")

        device = self.device
        gray_clamped = np.clip(gray, 0, 255)
        gray_t = torch.from_numpy(gray_clamped).to(device=device, dtype=torch.float32)
        gray_4d = gray_t.unsqueeze(0).unsqueeze(0)
        denoise_input = (gray_4d / 255.0).clamp(0.0, 1.0)

        with torch.inference_mode():
            with self._autocast_context():
                denoised = self.denoiser(denoise_input)

            # Autocast may hand back fp16; scale in fp32 so values near 255
            # are not quantized before the subtraction.
            residual = gray_4d - denoised.float() * 255.0
            residual_flat = residual.reshape(-1)
            abs_res = residual_flat.abs()

            mean = residual_flat.mean()
            var = residual_flat.var(unbiased=False)
            std = torch.sqrt(var + 1e-12)

            # Match scipy defaults: bias=True, fisher=True
            centered = residual_flat - mean
            skew_val = (centered.pow(3).mean()) / (std.pow(3) + 1e-12)
            kurt_val = (centered.pow(4).mean()) / (std.pow(4) + 1e-12) - 3.0

            features = {
                'residual_mean': float(mean.item()),
                'residual_std': float(std.item()),
                'residual_skew': float(skew_val.item()),
                'residual_kurtosis': float(kurt_val.item()),
                'residual_energy': float(residual_flat.pow(2).mean().item()),
                'residual_energy_mean': float(abs_res.mean().item()),
                'residual_energy_std': float(abs_res.std(unbiased=False).item()),
                'residual_energy_p95': float(torch.quantile(abs_res, 0.95).item()),
            }
        return features

    def _run_denoiser_tiled(self, gray_uint8):
        """
        Process large images in tiles to avoid VRAM exhaustion.

        Each tile is denoised with ``tile_overlap`` pixels of extra context on
        every side that has neighbours; only the tile's core region is written
        to the output, so the cores partition the image exactly.

        Args:
            gray_uint8: HxW uint8 grayscale image.
        Returns:
            denoised image as float32 in [0,255].
        Raises:
            ValueError: If tiling is required but ``tile_overlap`` is negative
                or leaves no room for a tile core.
        """
        h, w = gray_uint8.shape
        tile_size = self.max_tile_size

        # Check if image is small enough to process directly
        if h <= tile_size and w <= tile_size:
            return self._run_denoiser_single(gray_uint8)

        overlap = self.tile_overlap
        slide = tile_size - 2 * overlap  # Core (output) size of one tile
        if overlap < 0 or slide <= 0:
            raise ValueError(
                f"tile_overlap ({overlap}) must be >= 0 and less than half of "
                f"max_tile_size ({tile_size})."
            )

        denoised = np.zeros((h, w), dtype=np.float32)
        for y0 in range(0, h, slide):
            y1 = min(h, y0 + slide)           # core rows [y0, y1)
            in_y0 = max(0, y0 - overlap)      # input rows incl. context
            in_y1 = min(h, y1 + overlap)
            for x0 in range(0, w, slide):
                x1 = min(w, x0 + slide)       # core cols [x0, x1)
                in_x0 = max(0, x0 - overlap)  # input cols incl. context
                in_x1 = min(w, x1 + overlap)

                tile_denoised = self._run_denoiser_single(
                    gray_uint8[in_y0:in_y1, in_x0:in_x1]
                )
                # Copy only the core; the context margin is discarded.
                denoised[y0:y1, x0:x1] = tile_denoised[
                    y0 - in_y0:y1 - in_y0,
                    x0 - in_x0:x1 - in_x0,
                ]
        return denoised

    def _run_denoiser_single(self, gray_uint8):
        """
        Process a single image/tile through the denoiser.
        Args:
            gray_uint8: HxW uint8 grayscale.
        Returns:
            denoised image as float32 in [0,255].
        Raises:
            RuntimeError: If no denoiser is loaded or torch is unavailable.
        """
        if self.denoiser is None or torch is None:
            raise RuntimeError("Denoiser not available.")

        img = gray_uint8.astype(np.float32) / 255.0
        tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            with self._autocast_context():
                output = self.denoiser(tensor)
            # Cast to fp32 before leaving the device so the *255 scaling below
            # is not carried out in half precision.
            denoised = output.float().cpu().numpy()[0, 0]

        return np.clip(denoised * 255.0, 0, 255).astype(np.float32)

    def _run_denoiser(self, gray_uint8):
        """
        Main denoiser entry point with automatic memory management.
        Args:
            gray_uint8: HxW uint8 grayscale.
        Returns:
            denoised image as float32 in [0,255].
        Raises:
            RuntimeError: If no denoiser is loaded or torch is unavailable.
        """
        if self.denoiser is None or torch is None:
            raise RuntimeError("Denoiser not available.")

        h, w = gray_uint8.shape

        # Auto-downscale extremely large images
        if self.auto_downscale and (h > self.max_image_size or w > self.max_image_size):
            scale = min(self.max_image_size / h, self.max_image_size / w)
            # Clamp so extreme aspect ratios never produce a zero-sized target.
            new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
            print(f"ResidualExtractor: Downscaling {h}x{w} -> {new_h}x{new_w} to fit in VRAM")
            gray_downscaled = cv2.resize(gray_uint8, (new_w, new_h), interpolation=cv2.INTER_AREA)
            denoised_downscaled = self._run_denoiser_tiled(gray_downscaled)
            # Upscale back to original size
            return cv2.resize(denoised_downscaled, (w, h), interpolation=cv2.INTER_LINEAR)

        # Use tiled processing for large images
        return self._run_denoiser_tiled(gray_uint8)

    def extract_features(self, image):
        """
        Extracts residual-based features.
        Uses a proper denoiser if provided, otherwise falls back to Gaussian blur.
        Args:
            image (PIL.Image or np.ndarray): Input image. 3-D inputs are
                converted to grayscale with ``cv2.COLOR_RGB2GRAY``.
        Returns:
            dict: Dictionary of features (``residual_mean``, ``residual_std``,
                ``residual_skew``, ``residual_kurtosis``, ``residual_energy``,
                ``residual_energy_mean``, ``residual_energy_std``,
                ``residual_energy_p95``).
        """
        if not isinstance(image, np.ndarray):
            image = np.array(image)

        # A trailing singleton channel is already grayscale; cvtColor would
        # reject it, so drop the axis instead.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        gray = gray.astype(np.float32)
        gray_uint8 = np.clip(gray, 0, 255).astype(np.uint8)

        if self.denoiser is not None:
            h, w = gray.shape
            needs_downscale = self.auto_downscale and (
                h > self.max_image_size or w > self.max_image_size
            )
            use_fast_path = (
                torch is not None and
                h <= self.max_tile_size and w <= self.max_tile_size and
                not needs_downscale
            )
            if use_fast_path:
                return self._extract_features_torch_fast(gray, gray_uint8)
            denoised = self._run_denoiser(gray_uint8)
        else:
            # Fallback: simple Gaussian denoiser (weaker forensic signal)
            denoised = cv2.GaussianBlur(gray_uint8, (5, 5), 0).astype(np.float32)

        residual = gray - denoised
        abs_res = np.abs(residual)
        flat = residual.ravel()
        features = {
            # legacy stats
            'residual_mean': float(np.mean(residual)),
            'residual_std': float(np.std(residual)),
            'residual_skew': float(skew(flat)),
            'residual_kurtosis': float(kurtosis(flat)),
            'residual_energy': float(np.sum(residual ** 2)) / residual.size,
            'residual_energy_mean': float(abs_res.mean()),
            'residual_energy_std': float(abs_res.std()),
            'residual_energy_p95': float(np.percentile(abs_res, 95)),
        }
        return features