Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Image Quality Filter (GPU-Accelerated) | |
| Filters raw scraped images based on resolution, sharpness, aspect ratio, | |
| file size, and deduplication. Uses GPU for batch sharpness and color analysis. | |
| Outputs high-quality images to data/processed/. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import shutil | |
| import logging | |
| import argparse | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import yaml | |
| import cv2 | |
| import numpy as np | |
| import imagehash | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from tqdm import tqdm | |
# ─── TF32 fast math for CUDA matmul / cuDNN ─────────────────────────────────
# NOTE(review): labeled for SM120 (Blackwell), but allow_tf32 applies to any
# TF32-capable GPU (Ampere and newer) — confirm the intended target range.
# TF32 trades a small amount of matmul precision for throughput; acceptable
# here because all GPU results are only compared against coarse thresholds.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# ─────────────────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
# Module-level logger shared by every class/function in this file.
logger = logging.getLogger(__name__)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Config | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def load_config(config_path: str = "configs/config.yaml") -> dict:
    """Load the YAML pipeline configuration.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        Parsed configuration as a dict. An empty or comment-only file yields
        ``{}`` instead of ``None`` (``yaml.safe_load`` returns ``None`` for
        empty input, which would break callers that index into the result).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # Guard: honor the declared return type even for empty/non-mapping files.
    return data if isinstance(data, dict) else {}
| # ───────────────────────────────────────────────────────────────────────────── | |
| # GPU-Accelerated Quality Checker | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class ImageQualityChecker:
    """
    Evaluate image quality using GPU-accelerated sharpness and color analysis.
    Falls back to CPU if no CUDA device is available.

    Checks applied, in order: file size, readability, resolution, aspect
    ratio, sharpness (variance of the Laplacian), color variance.
    """

    # 3x3 Laplacian kernel, reshaped to [out_ch=1, in_ch=1, 3, 3] as required
    # by F.conv2d. Moved to the target device once in __init__.
    LAPLACIAN_KERNEL = torch.tensor(
        [[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32
    ).unsqueeze(0).unsqueeze(0)

    def __init__(
        self,
        min_resolution: int = 512,
        min_sharpness: float = 50.0,
        min_aspect_ratio: float = 0.4,
        max_aspect_ratio: float = 2.5,
        min_file_size_kb: int = 20,
        max_file_size_mb: int = 50,
        device: str = "auto",
        min_color_std: float = 15.0,
    ):
        """
        Args:
            min_resolution: Minimum allowed length of the shorter side (px).
            min_sharpness: Minimum Laplacian variance (higher = sharper).
            min_aspect_ratio: Minimum width/height ratio.
            max_aspect_ratio: Maximum width/height ratio.
            min_file_size_kb: Reject files smaller than this many KB.
            max_file_size_mb: Reject files larger than this many MB.
            device: "auto" picks CUDA when available, else CPU; or an
                explicit torch device string such as "cuda:0" / "cpu".
            min_color_std: Minimum overall pixel std; images below this are
                rejected as "too_uniform" (previously a hard-coded 15.0).
        """
        self.min_resolution = min_resolution
        self.min_sharpness = min_sharpness
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.min_file_size_bytes = min_file_size_kb * 1024
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024
        self.min_color_std = min_color_std
        # GPU setup
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        # Transfer the kernel once so per-image checks avoid repeated copies.
        self._kernel = self.LAPLACIAN_KERNEL.to(self.device)
        logger.info(f"Quality checker using device: {self.device}")

    def _gpu_sharpness(self, img_array: np.ndarray) -> float:
        """Return the variance of the Laplacian (sharpness proxy) on GPU.

        Args:
            img_array: RGB image as an HxWx3 uint8 array.
        """
        # Convert to grayscale on CPU (cheap), convolve on the device.
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        tensor = torch.from_numpy(gray.astype(np.float32)).unsqueeze(0).unsqueeze(0)
        tensor = tensor.to(self.device)
        laplacian = F.conv2d(tensor, self._kernel, padding=1)
        return laplacian.var().item()

    def _gpu_color_std(self, img_array: np.ndarray) -> float:
        """Return the standard deviation over all pixels/channels on GPU."""
        tensor = torch.from_numpy(img_array.astype(np.float32)).to(self.device)
        return tensor.std().item()

    def check(self, image_path: Path) -> tuple[bool, dict]:
        """
        Check image quality. Returns (passed, metrics_dict).
        Sharpness and color checks run on GPU.

        Cheap CPU checks (file size, readability, resolution, aspect ratio)
        run first so obviously-bad images never touch the GPU. On rejection,
        metrics["reason"] names the failed check.
        """
        metrics = {
            "path": str(image_path),
            "passed": False,
            "reason": None,
        }
        # File size check (CPU — trivial)
        file_size = image_path.stat().st_size
        metrics["file_size_bytes"] = file_size
        if file_size < self.min_file_size_bytes:
            metrics["reason"] = "file_too_small"
            return False, metrics
        if file_size > self.max_file_size_bytes:
            metrics["reason"] = "file_too_large"
            return False, metrics
        # Load image; any decode failure counts as unreadable.
        try:
            img = Image.open(image_path).convert("RGB")
        except Exception:
            metrics["reason"] = "unreadable"
            return False, metrics
        w, h = img.size
        metrics["width"] = w
        metrics["height"] = h
        # Resolution check (CPU — trivial)
        if min(w, h) < self.min_resolution:
            metrics["reason"] = "low_resolution"
            return False, metrics
        # Aspect ratio check (CPU — trivial)
        aspect = w / h
        metrics["aspect_ratio"] = round(aspect, 3)
        if aspect < self.min_aspect_ratio or aspect > self.max_aspect_ratio:
            metrics["reason"] = "bad_aspect_ratio"
            return False, metrics
        img_array = np.array(img)
        # Sharpness check (GPU-accelerated Laplacian)
        try:
            sharpness = self._gpu_sharpness(img_array)
            metrics["sharpness"] = round(sharpness, 2)
            if sharpness < self.min_sharpness:
                metrics["reason"] = "too_blurry"
                return False, metrics
        except Exception:
            metrics["reason"] = "sharpness_check_failed"
            return False, metrics
        # Color variance check (GPU-accelerated). Near-uniform images are
        # rejected as likely blanks/placeholders.
        std = self._gpu_color_std(img_array)
        metrics["color_std"] = round(float(std), 2)
        if std < self.min_color_std:
            metrics["reason"] = "too_uniform"
            return False, metrics
        metrics["passed"] = True
        return True, metrics

    def check_batch(self, image_paths: list[Path]) -> list[tuple[bool, dict]]:
        """Run check() on each path and return the per-image results.

        NOTE: images are processed one at a time; only the per-image
        sharpness/color math runs on GPU — there is no cross-image batching.
        (The previous docstring claimed batched GPU ops that did not exist.)
        """
        return [self.check(path) for path in image_paths]
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Deduplicator | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class Deduplicator:
    """Remove near-duplicate images using perceptual hashing.

    Hashes are stored internally as plain 64-bit ints so the store can be
    seeded with integer hashes produced elsewhere (e.g. GPUHasher) without
    breaking comparisons. BUG FIX: previously `abs(h - existing_hash)` was
    applied to a mix of imagehash.ImageHash objects and seeded ints, which
    raised inside the loop and — via the broad except — silently marked
    every subsequent image as a duplicate.
    """

    def __init__(self, hash_size: int = 8, threshold: int = 5):
        """
        Args:
            hash_size: Side length of the pHash grid (8 -> 64-bit hash).
            threshold: Max Hamming distance for two images to be duplicates.
        """
        self.hash_size = hash_size
        self.threshold = threshold
        # Maps path string -> integer perceptual hash. May be pre-seeded
        # externally (run_quality_filter updates it with GPUHasher output).
        self.hashes: dict[str, int] = {}

    @staticmethod
    def _to_int(h) -> int:
        """Normalize an imagehash.ImageHash (hex str()) or int to an int."""
        return h if isinstance(h, int) else int(str(h), 16)

    @staticmethod
    def _hamming(a: int, b: int) -> int:
        """Hamming distance between two 64-bit hashes.

        The XOR is masked to 64 bits so negative ints (GPUHasher's bit-63
        weight overflows int64 and wraps negative) compare correctly.
        """
        return bin((a ^ b) & 0xFFFFFFFFFFFFFFFF).count("1")

    def is_duplicate(self, image_path: Path) -> bool:
        """Return True if the image is a near-duplicate of any seen so far.

        Side effect: a non-duplicate image's hash is recorded for future
        comparisons. Unreadable images are treated as duplicates (rejected).
        """
        try:
            img = Image.open(image_path).convert("RGB")
            h = self._to_int(imagehash.phash(img, hash_size=self.hash_size))
            for existing_hash in self.hashes.values():
                if self._hamming(h, self._to_int(existing_hash)) <= self.threshold:
                    return True
            self.hashes[str(image_path)] = h
            return False
        except Exception:
            return True  # Can't hash → treat as duplicate
class GPUHasher:
    """
    GPU-accelerated Perceptual Hashing (pHash).
    Strictly forces GPU usage.

    Per batch: decode to grayscale on CPU, resize to 32x32 on GPU, apply a
    2D DCT via matrix multiplication, keep the top-left 8x8 coefficient
    block, and threshold each coefficient against the per-image median to
    produce a 64-bit hash.
    """

    def __init__(self, device="cuda"):
        """
        Args:
            device: Torch device string; must resolve to a CUDA device.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("❌ CUDA is not available! GPUHasher requires a GPU.")
        self.device = device
        logger.info(f"⚡ GPUHasher initialized on: {str(self.device).upper()}")
        # Precompute loop-invariant constants once: the 32x32 DCT basis and
        # the bit weights used to pack 64 hash bits into one integer.
        # NOTE: 2**63 overflows int64 and wraps negative, so hashes with
        # bit 63 set come out negative — consumers must mask to 64 bits
        # before computing Hamming distances.
        self.dct_matrix = self._get_dct_matrix(32).to(self.device)
        self._powers = (2 ** torch.arange(64, device=self.device)).long()

    def _get_dct_matrix(self, N):
        """Return an unnormalized NxN DCT-II basis matrix (float32).

        No orthonormal scaling is applied. That is acceptable for hashing,
        because bits come from a per-image median comparison, but the raw
        coefficients differ from scipy's normalized DCT.
        """
        dct_m = np.zeros((N, N))
        for k in range(N):
            for n in range(N):
                dct_m[k, n] = np.cos(np.pi / N * (n + 0.5) * k)
        return torch.from_numpy(dct_m).float()

    def compute_hashes(self, image_paths: list[Path], batch_size=64) -> dict[str, int]:
        """
        Compute pHash for a list of image paths using GPU acceleration.

        Args:
            image_paths: Images to hash; unreadable files are skipped.
            batch_size: Number of images transferred/hashed per GPU batch.

        Returns:
            Dictionary {path_str: hash_int}.
        """
        results = {}
        # Use tqdm for progress bar
        with tqdm(total=len(image_paths), desc=" Computing hashes (GPU)", unit="img") as pbar:
            for i in range(0, len(image_paths), batch_size):
                batch_paths = image_paths[i : i + batch_size]
                batch_tensors = []
                valid_paths = []
                for p in batch_paths:
                    try:
                        # Grayscale decode on CPU; resizing happens on GPU
                        # below to keep the CPU side cheap.
                        img = Image.open(p).convert("L")
                        # Convert to tensor [1, H, W] directly
                        t = torch.from_numpy(np.array(img)).float().unsqueeze(0) / 255.0
                        batch_tensors.append(t)
                        valid_paths.append(str(p))
                    except Exception:
                        pass  # Unreadable image: skip silently.
                # Update pbar for the batch processed
                pbar.update(len(batch_paths))
                if not batch_tensors:
                    continue
                # GPU Processing
                try:
                    gpu_tensors = []
                    for t in batch_tensors:
                        # Move to GPU; images have mixed sizes, so each must
                        # be resized individually before stacking.
                        t_gpu = t.to(self.device, non_blocking=True).unsqueeze(0)  # [1, 1, H, W]
                        t_resized = F.interpolate(t_gpu, size=(32, 32), mode='bilinear', align_corners=False)
                        gpu_tensors.append(t_resized.squeeze(0))  # [1, 32, 32]
                    # Stack: [B, 32, 32]
                    pixel_batch = torch.stack(gpu_tensors).squeeze(1)
                    # 2D DCT as D @ I @ D^T:
                    # [32, 32] @ [B, 32, 32] @ [32, 32] -> [B, 32, 32]
                    dct = torch.matmul(self.dct_matrix, pixel_batch)
                    dct = torch.matmul(dct, self.dct_matrix.T)
                    # Keep the low-frequency top-left 8x8 block, flattened to
                    # [B, 64]. NOTE: the DC term at (0, 0) IS included here —
                    # classic pHash drops it before taking the median. (The
                    # previous comment incorrectly claimed it was excluded.)
                    dct_low = dct[:, :8, :8].reshape(-1, 64)
                    # Per-image median threshold.
                    medians = dct_low.median(dim=1, keepdim=True).values
                    # Bit i is set when coefficient i exceeds the median.
                    bits = (dct_low > medians).long()
                    # Pack the 64 bits into one integer with the precomputed
                    # weights [2^0, 2^1, ..., 2^63].
                    hashes = (bits * self._powers).sum(dim=1).cpu().numpy()
                    for p, h in zip(valid_paths, hashes):
                        results[p] = int(h)
                except Exception as e:
                    logger.debug(f"GPU Hash batch failed: {e}")
                    continue
        return results
# ─────────────────────────────────────────────────────────────────────────────
# Main Pipeline
# ─────────────────────────────────────────────────────────────────────────────
def run_quality_filter(config: dict) -> dict:
    """Main quality filter pipeline (GPU-accelerated) with Auto-Scrape Top-Up.

    For each theme directory under the raw data dir: quality-filter and
    deduplicate its images into the processed dir; when a category holds
    fewer than TARGET_COUNT accepted images, launch the Pinterest scraper
    to fetch more raw material and rescan, repeating until the target is
    met or the scraper fails.

    Args:
        config: Parsed config.yaml contents (paths, dataset.quality, ...).

    Returns:
        Dict mapping category name -> {"total", "passed", "failed",
        "duplicates"} counters accumulated this run.
    """
    from pinterest_scraper import PinterestScraper, DEFAULT_QUERIES  # Lazy import to avoid circular deps
    raw_dir = Path(config["paths"]["data"]["raw"])
    processed_dir = Path(config["paths"]["data"]["processed"])
    # Per-category goal for accepted images in the processed dir.
    TARGET_COUNT = 1300
    if not raw_dir.exists():
        logger.error(f"Raw data directory does not exist: {raw_dir}")
        sys.exit(1)
    # Quality settings from config
    quality_cfg = config.get("dataset", {}).get("quality", {})
    checker = ImageQualityChecker(
        min_resolution=quality_cfg.get("min_resolution", 512),
        min_sharpness=quality_cfg.get("min_sharpness", 50.0),
        min_aspect_ratio=quality_cfg.get("min_aspect_ratio", 0.4),
        max_aspect_ratio=quality_cfg.get("max_aspect_ratio", 2.5),
    )
    dedup = Deduplicator()
    # Initialize scraper (but don't start driver yet)
    scraper = PinterestScraper(config, str(raw_dir))
    # Log GPU status
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"🎮 GPU detected: {gpu_name}. Total memory: {gpu_mem:.2f} GB")
    else:
        logger.info("🖥️ No GPU detected — running on CPU (slower)")
    # Stats
    stats = defaultdict(lambda: {"total": 0, "passed": 0, "failed": 0, "duplicates": 0})
    # 1. LOAD ALL EXISTING PROCESSED IMAGES (Global Deduplication)
    logger.info("🧠 Learning ALL existing images to prevent duplicates...")
    all_processed_files = []
    for root, _, files in os.walk(processed_dir):
        for file in files:
            if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
                all_processed_files.append(Path(root) / file)
    existing_hashes = 0
    if all_processed_files:
        # NOTE(review): GPUHasher raises if CUDA is unavailable, so this
        # pre-seeding step hard-requires a GPU even though the quality
        # checker itself falls back to CPU — confirm that is intended.
        hasher = GPUHasher()
        # Compute hashes for everything currently in processed
        batch_hashes = hasher.compute_hashes(all_processed_files, batch_size=128)
        # NOTE(review): these values are plain ints, while is_duplicate()
        # computes imagehash objects — verify Deduplicator compares the two
        # representations without raising.
        dedup.hashes.update(batch_hashes)
        existing_hashes = len(batch_hashes)
    logger.info(f"✅ Memorized {existing_hashes} unique images in processed dataset.")
    # Collect all leaf directories (directories that contain images, not just parents)
    leaf_dirs = []
    for root, dirs, files in os.walk(raw_dir):
        root_path = Path(root)
        # Check if this is a leaf node we want to process
        # (It might be empty now but was scraped before, or we want to scrape it)
        # For now, rely on existing folders in raw.
        rel_path = root_path.relative_to(raw_dir)
        # Skip the root directory itself (files directly in data/raw)
        if str(rel_path) == ".":
            continue
        leaf_dirs.append((rel_path, root_path))
    if not leaf_dirs:
        logger.warning("No directories found in raw data.")
        return {}
    logger.info(f"Found {len(leaf_dirs)} theme directories to process")
    for rel_path, dir_path in sorted(leaf_dirs):
        category = str(rel_path).replace("\\", "/")
        out_dir = processed_dir / rel_path
        out_dir.mkdir(parents=True, exist_ok=True)
        # We assume leaf dir if it has no subdirs with images?
        # Simpler: just process if we found it.
        # NOTE(review): if the scraper keeps succeeding but never yields
        # enough passing images, this loop re-launches scraping with ever
        # larger targets — confirm an upper bound exists upstream.
        while True:
            # Check current status in processed folder
            # NOTE(review): .webp is accepted by the copy step below but not
            # counted here — counts may under-report; confirm the filter set.
            processed_images = [f for f in os.listdir(out_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            current_count = len(processed_images)
            # If we met the target, break loop and move to next category
            if current_count >= TARGET_COUNT:
                logger.info(f"✅ {category}: Target met ({current_count} images).")
                break
            needed = TARGET_COUNT - current_count
            logger.info(f"\nCategory: {category}")
            logger.info(f" Current: {current_count} | Needed: {needed}")
            # Get raw images
            raw_images = sorted([
                dir_path / f for f in os.listdir(dir_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', '.bmp'))
            ])
            logger.info(f" Raw images available: {len(raw_images)}")
            # Identify candidates (raw images NOT yet in processed folder by filename)
            existing_filenames = set(processed_images)
            candidates = [p for p in raw_images if p.name not in existing_filenames]
            added_this_round = 0
            if candidates:
                logger.info(f" Processing {len(candidates)} new candidates...")
                pbar = tqdm(candidates, desc=f" {category} (Filter)", unit="img")
                for img_path in pbar:
                    # Stop as soon as this category's shortfall is covered.
                    if added_this_round >= needed:
                        break
                    stats[category]["total"] += 1
                    # Quality check (GPU-accelerated sharpness + color)
                    passed, metrics = checker.check(img_path)
                    if not passed:
                        stats[category]["failed"] += 1
                        # logger.debug(f" REJECTED {img_path.name}: {metrics['reason']}")
                        continue
                    # Dedup check (Hash-based)
                    if dedup.is_duplicate(img_path):
                        stats[category]["duplicates"] += 1
                        # logger.debug(f" DUPLICATE {img_path.name}")
                        continue
                    # Copy to processed
                    dest = out_dir / img_path.name
                    shutil.copy2(img_path, dest)
                    stats[category]["passed"] += 1
                    added_this_round += 1
                pbar.close()
            current_count += added_this_round
            if current_count >= TARGET_COUNT:
                continue  # Re-evaluate loop condition (which will break)
            # If still short, trigger scraper
            needed = TARGET_COUNT - current_count
            if needed > 0:
                logger.warning(f" ⚠️ Short by {needed} images! Launching Scraper to fetch more...")
                # Fetch query list
                queries = DEFAULT_QUERIES.get(category)
                if not queries:
                    # Fallback queries
                    theme = category.split("/")[-1]
                    queries = [f"{theme} poster", f"{theme} design", f"{theme} advertisement"]
                # Scrape 2x what we need
                scrape_target = len(raw_images) + (needed * 2)
                # Ensure we at least target 2800 if we are really low
                scrape_target = max(scrape_target, 2800)
                scraper.TARGET_PER_THEME = scrape_target
                logger.info(f" 🕷️ Scraping target set to {scrape_target} for {category}...")
                try:
                    # scraper.scrape_category downloads to raw_dir/{category}
                    # It returns total downloaded count
                    new_total = scraper.scrape_category(category, queries)
                    logger.info(f" ✅ Scraping finished. Raw total is now {new_total}. Rescanning...")
                except Exception as e:
                    logger.error(f" ❌ Scraper failed: {e}")
                    break  # Stop trying for this category if scraper fails
            else:
                break  # Should be caught by top check, but safe fallback
    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return dict(stats)
def print_summary(stats: dict) -> None:
    """Print a per-category pass/fail/duplicate summary table to stdout.

    Args:
        stats: Mapping of category name -> dict with keys "total",
            "passed", "failed", "duplicates" (as produced by
            run_quality_filter).
    """
    print("\n" + "=" * 60)
    print(f"{'Category':<35} | {'Total':<8} | {'Pass':<6} | {'Fail':<6} | {'Dupes':<6}")
    print("-" * 60)
    total_passed = 0
    for cat, data in sorted(stats.items()):
        print(f"{cat:<35} | {data['total']:<8} | {data['passed']:<6} | {data['failed']:<6} | {data['duplicates']:<6}")
        total_passed += data['passed']
    print("-" * 60)
    print(f"Total High-Quality Images: {total_passed}")
    print("=" * 60 + "\n")
def _log_summary_table(stats: dict) -> None:
    """Log a detailed per-category summary table with pass rates.

    Args:
        stats: Mapping of category name -> counter dict from
            run_quality_filter.
    """
    logger.info("\n" + "=" * 80)
    logger.info("QUALITY FILTER SUMMARY")
    logger.info("=" * 80)
    logger.info(f" {'Category':35s} {'Total':>7s} {'Passed':>7s} {'Failed':>7s} {'Dupes':>7s} {'Rate':>7s}")
    logger.info(f" {'-'*35} {'-'*7} {'-'*7} {'-'*7} {'-'*7} {'-'*7}")
    grand_total = grand_passed = 0
    for cat, s in sorted(stats.items()):
        # max(..., 1) avoids division by zero for empty categories.
        rate = f"{s['passed']/max(s['total'],1)*100:.1f}%"
        logger.info(
            f" {cat:35s} {s['total']:7d} {s['passed']:7d} "
            f"{s['failed']:7d} {s['duplicates']:7d} {rate:>7s}"
        )
        grand_total += s["total"]
        grand_passed += s["passed"]
    rate = f"{grand_passed/max(grand_total,1)*100:.1f}%"
    logger.info(f" {'-'*35} {'-'*7} {'-'*7} {'-'*7} {'-'*7} {'-'*7}")
    logger.info(f" {'TOTAL':35s} {grand_total:7d} {grand_passed:7d}{'':>17s} {rate:>7s}")
    logger.info("=" * 80)


def main() -> None:
    """CLI entry point: parse args, run the filter once, report results."""
    parser = argparse.ArgumentParser(description="Image Quality Filter (GPU-Accelerated)")
    parser.add_argument("--config", default="configs/config.yaml", help="Path to config.yaml")
    args = parser.parse_args()
    config = load_config(args.config)
    stats = run_quality_filter(config)
    print_summary(stats)
    _log_summary_table(stats)


# BUG FIX: the file previously had TWO `if __name__ == "__main__":` guards,
# each parsing arguments and invoking run_quality_filter — the entire
# pipeline executed twice per invocation. A single guard now runs it once.
if __name__ == "__main__":
    main()