#!/usr/bin/env python3
"""
Image Quality Filter (GPU-Accelerated)
Filters raw scraped images based on resolution, sharpness, aspect ratio,
file size, and deduplication. Uses GPU for batch sharpness and color analysis.
Outputs high-quality images to data/processed/.
"""
import os
import sys
import json
import shutil
import logging
import argparse
from pathlib import Path
from collections import defaultdict
import yaml
import cv2
import numpy as np
import imagehash
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
# ─── SM120 (Blackwell) CUDA optimizations ───────────────────────────────────
# When a CUDA device is present, let cuDNN and CUDA matmul kernels use TF32.
if torch.cuda.is_available():
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

# ─────────────────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────────────────
# Root logger: INFO and above, timestamped, with the level name in brackets.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────────────────────────────────────
def load_config(config_path: str = "configs/config.yaml") -> dict:
    """
    Load the pipeline configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping. An empty file yields an empty dict.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty document; honour the `dict` contract
    # so callers can use .get()/[] without a NoneType error.
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config root must be a mapping, got {type(data).__name__}: {config_path}"
        )
    return data
# ─────────────────────────────────────────────────────────────────────────────
# GPU-Accelerated Quality Checker
# ─────────────────────────────────────────────────────────────────────────────
class ImageQualityChecker:
    """
    Evaluate image quality using GPU-accelerated sharpness and color analysis.
    Falls back to CPU if no CUDA device is available.

    An image passes when its file size, shortest side, aspect ratio, Laplacian
    variance (sharpness) and overall pixel standard deviation are all within
    the configured limits.
    """

    # Laplacian kernel for GPU sharpness detection.
    # The two unsqueeze calls give it the 4-D shape F.conv2d expects for a weight.
    LAPLACIAN_KERNEL = torch.tensor(
        [[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32
    ).unsqueeze(0).unsqueeze(0)

    def __init__(
        self,
        min_resolution: int = 512,
        min_sharpness: float = 50.0,
        min_aspect_ratio: float = 0.4,
        max_aspect_ratio: float = 2.5,
        min_file_size_kb: int = 20,
        max_file_size_mb: int = 50,
        device: str = "auto",
        min_color_std: float = 15.0,
    ):
        """
        Args:
            min_resolution: Minimum length, in pixels, of the shorter image side.
            min_sharpness: Minimum variance of the Laplacian response.
            min_aspect_ratio: Minimum allowed width / height.
            max_aspect_ratio: Maximum allowed width / height.
            min_file_size_kb: Minimum file size in KiB.
            max_file_size_mb: Maximum file size in MiB.
            device: Torch device string, or "auto" to pick CUDA when available.
            min_color_std: Minimum standard deviation over all RGB values;
                images below it are rejected as "too_uniform".
        """
        self.min_resolution = min_resolution
        self.min_sharpness = min_sharpness
        self.min_aspect_ratio = min_aspect_ratio
        self.max_aspect_ratio = max_aspect_ratio
        self.min_color_std = min_color_std
        self.min_file_size_bytes = min_file_size_kb * 1024
        self.max_file_size_bytes = max_file_size_mb * 1024 * 1024

        # GPU setup
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self._kernel = self.LAPLACIAN_KERNEL.to(self.device)
        logger.info(f"Quality checker using device: {self.device}")

    def _gpu_sharpness(self, img_array: np.ndarray) -> float:
        """Return the variance of the Laplacian of the grayscale image (higher = sharper)."""
        # Convert to grayscale (input is treated as RGB)
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        # Shape [1, 1, H, W] so it can be fed to conv2d
        tensor = torch.from_numpy(gray.astype(np.float32)).unsqueeze(0).unsqueeze(0)
        tensor = tensor.to(self.device)
        # Apply Laplacian convolution on the selected device
        laplacian = F.conv2d(tensor, self._kernel, padding=1)
        return laplacian.var().item()

    def _gpu_color_std(self, img_array: np.ndarray) -> float:
        """Return the standard deviation over every value of the image array."""
        tensor = torch.from_numpy(img_array.astype(np.float32)).to(self.device)
        return tensor.std().item()

    def check(self, image_path: Path) -> tuple[bool, dict]:
        """
        Check image quality. Returns (passed, metrics_dict).

        Checks run cheapest-first and stop at the first failure, so
        ``metrics`` only contains the measurements taken up to that point.
        ``metrics["reason"]`` names the failed check, or is None on success.
        Sharpness and color checks run on the configured torch device.
        """
        metrics = {
            "path": str(image_path),
            "passed": False,
            "reason": None,
        }

        # File size check (CPU — trivial). A file that vanished or cannot be
        # stat'ed is reported as unreadable instead of aborting the caller's loop.
        try:
            file_size = image_path.stat().st_size
        except OSError:
            metrics["reason"] = "unreadable"
            return False, metrics
        metrics["file_size_bytes"] = file_size
        if file_size < self.min_file_size_bytes:
            metrics["reason"] = "file_too_small"
            return False, metrics
        if file_size > self.max_file_size_bytes:
            metrics["reason"] = "file_too_large"
            return False, metrics

        # Load image; the context manager releases the file handle once the
        # RGB copy has been made.
        try:
            with Image.open(image_path) as src:
                img = src.convert("RGB")
        except Exception:
            metrics["reason"] = "unreadable"
            return False, metrics

        w, h = img.size
        metrics["width"] = w
        metrics["height"] = h

        # Resolution check (CPU — trivial)
        if min(w, h) < self.min_resolution:
            metrics["reason"] = "low_resolution"
            return False, metrics

        # Aspect ratio check (CPU — trivial)
        aspect = w / h
        metrics["aspect_ratio"] = round(aspect, 3)
        if aspect < self.min_aspect_ratio or aspect > self.max_aspect_ratio:
            metrics["reason"] = "bad_aspect_ratio"
            return False, metrics

        img_array = np.array(img)

        # Sharpness check (Laplacian variance on the configured device)
        try:
            sharpness = self._gpu_sharpness(img_array)
        except Exception as exc:
            logger.debug(f"Sharpness check failed for {image_path}: {exc}")
            metrics["reason"] = "sharpness_check_failed"
            return False, metrics
        metrics["sharpness"] = round(sharpness, 2)
        if sharpness < self.min_sharpness:
            metrics["reason"] = "too_blurry"
            return False, metrics

        # Color variance check: near-constant images (blank or solid fills) fail
        std = self._gpu_color_std(img_array)
        metrics["color_std"] = round(float(std), 2)
        if std < self.min_color_std:
            metrics["reason"] = "too_uniform"
            return False, metrics

        metrics["passed"] = True
        return True, metrics

    def check_batch(self, image_paths: list[Path]) -> list[tuple[bool, dict]]:
        """
        Check several images and return one (passed, metrics) tuple per path,
        in input order.

        This is a convenience wrapper: images are evaluated one at a time via
        ``check``; no cross-image GPU batching is performed.
        """
        return [self.check(path) for path in image_paths]
# ─────────────────────────────────────────────────────────────────────────────
# Deduplicator
# ─────────────────────────────────────────────────────────────────────────────
class Deduplicator:
    """
    Remove near-duplicate images using perceptual hashing.

    ``hashes`` maps a path string to the hash of an image that has already
    been accepted. Values are normally ``imagehash.ImageHash`` objects, but
    the pipeline in this file also seeds the dict with the plain 64-bit
    integers returned by ``GPUHasher.compute_hashes``; those are unpacked to
    ``ImageHash`` on first use so both kinds can be compared.
    """

    def __init__(self, hash_size: int = 8, threshold: int = 5):
        """
        Args:
            hash_size: Side length passed to ``imagehash.phash``.
            threshold: Maximum number of differing bits for two images to be
                considered duplicates.
        """
        self.hash_size = hash_size
        self.threshold = threshold
        self.hashes: dict[str, "imagehash.ImageHash"] = {}

    @staticmethod
    def _to_image_hash(stored):
        """
        Return ``stored`` as an ImageHash.

        Integer hashes are read as 64 bits where bit ``i`` is cell ``i`` of a
        row-major 8x8 grid, the packing used by ``GPUHasher.compute_hashes``.
        Python's arithmetic right shift makes this work for negative
        (sign-wrapped) values too.
        """
        if isinstance(stored, imagehash.ImageHash):
            return stored
        value = int(stored)
        bits = np.array([(value >> i) & 1 for i in range(64)], dtype=bool)
        return imagehash.ImageHash(bits.reshape(8, 8))

    def is_duplicate(self, image_path: Path) -> bool:
        """
        Return True if the image is within ``threshold`` bits of a known hash.

        A new, non-duplicate image is remembered under ``str(image_path)``.
        Images that cannot be opened or hashed are reported as duplicates so
        the caller skips them.
        """
        try:
            with Image.open(image_path) as src:
                candidate = imagehash.phash(src.convert("RGB"), hash_size=self.hash_size)
        except Exception as exc:
            # Can't hash → treat as duplicate
            logger.debug(f"Could not hash {image_path}: {exc}")
            return True

        for key, known in self.hashes.items():
            if not isinstance(known, imagehash.ImageHash):
                # Previously `ImageHash - int` raised, the blanket except
                # swallowed it, and every image was reported as a duplicate.
                known = self._to_image_hash(known)
                # Cache the converted value; replacing a value for an existing
                # key does not resize the dict, so iteration stays valid.
                self.hashes[key] = known
            if known.hash.size != candidate.hash.size:
                # Hashes of different sizes are not comparable; skip.
                continue
            if candidate - known <= self.threshold:
                return True

        self.hashes[str(image_path)] = candidate
        return False
class GPUHasher:
    """
    GPU-accelerated Perceptual Hashing (pHash).
    Strictly forces GPU usage.

    Each image is converted to grayscale, resized to 32x32 on the device,
    transformed with a 2-D DCT, and the top-left 8x8 coefficients are
    thresholded at their median to form a 64-bit hash.
    """

    def __init__(self, device="cuda"):
        """
        Args:
            device: Torch device to run on.

        Raises:
            RuntimeError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("❌ CUDA is not available! GPUHasher requires a GPU.")
        self.device = device
        logger.info(f"⚡ GPUHasher initialized on: {str(self.device).upper()}")
        self.dct_matrix = self._get_dct_matrix(32).to(self.device)

    def _get_dct_matrix(self, N):
        """Return the N x N DCT-II basis, entry [k, n] = cos(pi / N * (n + 0.5) * k), as float32."""
        k = np.arange(N, dtype=np.float64).reshape(-1, 1)  # frequency index (rows)
        n = np.arange(N, dtype=np.float64).reshape(1, -1)  # sample index (columns)
        return torch.from_numpy(np.cos(np.pi / N * (n + 0.5) * k)).float()

    def compute_hashes(self, image_paths: list[Path], batch_size=64) -> dict[str, int]:
        """
        Compute pHash for a list of image paths using GPU acceleration.

        Args:
            image_paths: Images to hash.
            batch_size: Number of images processed per GPU batch.

        Returns:
            Dictionary {path_str: hash_int}. Images that cannot be read, and
            whole batches whose GPU step fails, are omitted from the result.
        """
        results = {}

        # Bit weights [2^0, 2^1, ... 2^63], built once instead of once per batch.
        # NOTE(review): 2**63 does not fit in int64, so hashes whose top bit is
        # set presumably come out as wrapped (negative) integers — confirm.
        powers = (2 ** torch.arange(64, device=self.device)).long()

        with tqdm(total=len(image_paths), desc=" Computing hashes (GPU)", unit="img") as pbar:
            for i in range(0, len(image_paths), batch_size):
                batch_paths = image_paths[i : i + batch_size]
                batch_tensors = []
                valid_paths = []

                for p in batch_paths:
                    try:
                        # Open as grayscale; resizing is left to the GPU to save CPU.
                        with Image.open(p) as src:
                            gray = np.array(src.convert("L"))
                    except Exception as e:
                        logger.debug(f"Skipping unreadable image {p}: {e}")
                        continue
                    # [1, H, W], scaled to 0..1
                    batch_tensors.append(torch.from_numpy(gray).float().unsqueeze(0) / 255.0)
                    valid_paths.append(str(p))

                # Progress counts attempted images, readable or not.
                pbar.update(len(batch_paths))

                if not batch_tensors:
                    continue

                # GPU Processing
                try:
                    gpu_tensors = []
                    for t in batch_tensors:
                        # Images differ in size, so each is moved and resized alone.
                        t_gpu = t.to(self.device, non_blocking=True).unsqueeze(0)  # [1, 1, H, W]
                        t_resized = F.interpolate(
                            t_gpu, size=(32, 32), mode='bilinear', align_corners=False
                        )
                        gpu_tensors.append(t_resized.squeeze(0))  # [1, 32, 32]

                    # Stack: [B, 32, 32]
                    pixel_batch = torch.stack(gpu_tensors).squeeze(1)

                    # Compute DCT: D * I * D^T
                    # [32, 32] @ [B, 32, 32] @ [32, 32] -> [B, 32, 32]
                    dct = torch.matmul(self.dct_matrix, pixel_batch)
                    dct = torch.matmul(dct, self.dct_matrix.T)

                    # Top-left 8x8 low-frequency block (the DC term at [0, 0]
                    # is included), flattened to [B, 64]
                    dct_low = dct[:, :8, :8].reshape(-1, 64)

                    # Per-image median, then 1 bit per coefficient: 1 if > median
                    medians = dct_low.median(dim=1, keepdim=True).values
                    bits = (dct_low > medians).long()

                    # Pack the 64 bits into one integer per image
                    hashes = (bits * powers).sum(dim=1).cpu().numpy()

                    for p, h in zip(valid_paths, hashes):
                        results[p] = int(h)
                except Exception as e:
                    # A failed batch drops every image in it from the result,
                    # so make that visible at the default log level.
                    logger.warning(f"GPU Hash batch failed: {e}")
                    continue

        return results
# ─────────────────────────────────────────────────────────────────────────────
# Main Pipeline
# ─────────────────────────────────────────────────────────────────────────────
def _unpack_int_hash(value: int) -> "imagehash.ImageHash":
    """Unpack a 64-bit integer pHash (bit i = row-major cell i of an 8x8 grid) into an ImageHash."""
    value = int(value)
    bits = np.array([(value >> i) & 1 for i in range(64)], dtype=bool)
    return imagehash.ImageHash(bits.reshape(8, 8))


def run_quality_filter(config: dict) -> dict:
    """
    Main quality filter pipeline (GPU-accelerated) with Auto-Scrape Top-Up.

    For every directory below the raw data directory, raw images that pass
    the quality checks and are not near-duplicates are copied to the matching
    directory under the processed data directory until it holds TARGET_COUNT
    images. When the raw images run out, the Pinterest scraper is asked for
    more and the directory is rescanned.

    Args:
        config: Parsed configuration. Reads ``paths.data.raw``,
            ``paths.data.processed`` and the optional ``dataset.quality``
            section.

    Returns:
        ``{category: {"total", "passed", "failed", "duplicates"}}`` for every
        category in which at least one candidate was evaluated; ``{}`` if the
        raw directory has no sub-directories.

    Raises:
        SystemExit: If the raw data directory does not exist.
    """
    from pinterest_scraper import PinterestScraper, DEFAULT_QUERIES  # Lazy import to avoid circular deps

    raw_dir = Path(config["paths"]["data"]["raw"])
    processed_dir = Path(config["paths"]["data"]["processed"])
    TARGET_COUNT = 1300
    # One extension list for raw scanning, processed counting and dedup seeding.
    # Previously the processed count ignored .webp/.bmp files that this very
    # loop copies, so the target was under-counted.
    image_exts = ('.jpg', '.jpeg', '.png', '.webp', '.bmp')

    if not raw_dir.exists():
        logger.error(f"Raw data directory does not exist: {raw_dir}")
        sys.exit(1)

    # Quality settings from config
    quality_cfg = config.get("dataset", {}).get("quality", {})
    checker = ImageQualityChecker(
        min_resolution=quality_cfg.get("min_resolution", 512),
        min_sharpness=quality_cfg.get("min_sharpness", 50.0),
        min_aspect_ratio=quality_cfg.get("min_aspect_ratio", 0.4),
        max_aspect_ratio=quality_cfg.get("max_aspect_ratio", 2.5),
    )
    dedup = Deduplicator()

    # Initialize scraper (but don't start driver yet)
    scraper = PinterestScraper(config, str(raw_dir))

    # Log GPU status
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        logger.info(f"🎮 GPU detected: {gpu_name}. Total memory: {gpu_mem:.2f} GB")
    else:
        logger.info("🖥️ No GPU detected — running on CPU (slower)")

    # Stats
    stats = defaultdict(lambda: {"total": 0, "passed": 0, "failed": 0, "duplicates": 0})

    # 1. LOAD ALL EXISTING PROCESSED IMAGES (Global Deduplication)
    logger.info("🧠 Learning ALL existing images to prevent duplicates...")
    all_processed_files = []
    for root, _, files in os.walk(processed_dir):
        for file in files:
            if file.lower().endswith(image_exts):
                all_processed_files.append(Path(root) / file)

    existing_hashes = 0
    if all_processed_files:
        if cuda_available:
            hasher = GPUHasher()
            # Compute hashes for everything currently in processed
            batch_hashes = hasher.compute_hashes(all_processed_files, batch_size=128)
            # The deduplicator compares ImageHash objects, so unpack the packed
            # integers instead of storing them raw.
            # NOTE(review): the GPU hash resizes bilinearly, so it may differ by
            # a few bits from imagehash.phash for the same image — confirm the
            # dedup threshold tolerates that.
            dedup.hashes.update(
                {path: _unpack_int_hash(h) for path, h in batch_hashes.items()}
            )
            existing_hashes = len(batch_hashes)
        else:
            # GPUHasher refuses to run without CUDA; fall back to the
            # deduplicator's own CPU hashing, which records each new hash.
            for path in tqdm(all_processed_files, desc=" Computing hashes (CPU)", unit="img"):
                dedup.is_duplicate(path)
            existing_hashes = len(dedup.hashes)
    logger.info(f"✅ Memorized {existing_hashes} unique images in processed dataset.")

    # Collect every directory below raw_dir (the root itself is skipped).
    # NOTE(review): parent directories that only contain sub-directories are
    # collected too and will trigger scraping — confirm that is intended.
    leaf_dirs = []
    for root, _dirs, _files in os.walk(raw_dir):
        root_path = Path(root)
        rel_path = root_path.relative_to(raw_dir)
        # Skip the root directory itself (files directly in data/raw)
        if str(rel_path) == ".":
            continue
        leaf_dirs.append((rel_path, root_path))

    if not leaf_dirs:
        logger.warning("No directories found in raw data.")
        return {}

    logger.info(f"Found {len(leaf_dirs)} theme directories to process")

    for rel_path, dir_path in sorted(leaf_dirs):
        category = str(rel_path).replace("\\", "/")
        out_dir = processed_dir / rel_path
        out_dir.mkdir(parents=True, exist_ok=True)

        # Raw filenames already judged in this run. Without it, every rescan
        # re-checked (and re-counted) the same rejected files.
        evaluated: set[str] = set()

        while True:
            # Check current status in processed folder
            processed_images = [
                f for f in os.listdir(out_dir) if f.lower().endswith(image_exts)
            ]
            current_count = len(processed_images)

            # If we met the target, break loop and move to next category
            if current_count >= TARGET_COUNT:
                logger.info(f"✅ {category}: Target met ({current_count} images).")
                break

            needed = TARGET_COUNT - current_count
            logger.info(f"\nCategory: {category}")
            logger.info(f" Current: {current_count} | Needed: {needed}")

            # Get raw images
            raw_images = sorted(
                dir_path / f for f in os.listdir(dir_path)
                if f.lower().endswith(image_exts)
            )
            logger.info(f" Raw images available: {len(raw_images)}")

            # Candidates: raw images neither in the processed folder (by
            # filename) nor already evaluated in this run
            existing_filenames = set(processed_images)
            candidates = [
                p for p in raw_images
                if p.name not in existing_filenames and p.name not in evaluated
            ]

            added_this_round = 0
            if candidates:
                logger.info(f" Processing {len(candidates)} new candidates...")
                pbar = tqdm(candidates, desc=f" {category} (Filter)", unit="img")
                for img_path in pbar:
                    if added_this_round >= needed:
                        break
                    evaluated.add(img_path.name)
                    stats[category]["total"] += 1

                    # Quality check (GPU-accelerated sharpness + color)
                    passed, _metrics = checker.check(img_path)
                    if not passed:
                        stats[category]["failed"] += 1
                        continue

                    # Dedup check (Hash-based)
                    if dedup.is_duplicate(img_path):
                        stats[category]["duplicates"] += 1
                        continue

                    # Copy to processed
                    shutil.copy2(img_path, out_dir / img_path.name)
                    stats[category]["passed"] += 1
                    added_this_round += 1
                pbar.close()

            current_count += added_this_round
            if current_count >= TARGET_COUNT:
                continue  # Re-count from disk at the top of the loop, which will break

            # Still short: every available raw image has been judged, so scrape.
            needed = TARGET_COUNT - current_count
            logger.warning(f" ⚠️ Short by {needed} images! Launching Scraper to fetch more...")

            # Fetch query list
            queries = DEFAULT_QUERIES.get(category)
            if not queries:
                # Fallback queries built from the last path component
                theme = category.split("/")[-1]
                queries = [f"{theme} poster", f"{theme} design", f"{theme} advertisement"]

            # Scrape 2x what we need, and at least 2800 if we are really low
            scrape_target = max(len(raw_images) + (needed * 2), 2800)
            scraper.TARGET_PER_THEME = scrape_target
            logger.info(f" 🕷️ Scraping target set to {scrape_target} for {category}...")

            try:
                # Presumably downloads to raw_dir/{category} and returns the
                # total downloaded count — verify against PinterestScraper.
                new_total = scraper.scrape_category(category, queries)
                logger.info(f" ✅ Scraping finished. Raw total is now {new_total}. Rescanning...")
            except Exception as e:
                logger.error(f" ❌ Scraper failed: {e}")
                break  # Stop trying for this category if scraper fails

            # If scraping produced no file we have not already handled, another
            # pass cannot make progress; stop instead of looping forever.
            has_new_raw = any(
                f.lower().endswith(image_exts)
                and f not in evaluated
                and f not in existing_filenames
                for f in os.listdir(dir_path)
            )
            if not has_new_raw:
                logger.warning(
                    f" Scraper added no new images for {category}; "
                    f"stopping at {current_count}/{TARGET_COUNT}."
                )
                break

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return dict(stats)
def print_summary(stats: dict):
    """
    Print a per-category results table to stdout.

    Each value of ``stats`` must provide the keys "total", "passed", "failed"
    and "duplicates". Categories are listed in sorted order, followed by the
    sum of all "passed" counts.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    header = f"{'Category':<35} | {'Total':<8} | {'Pass':<6} | {'Fail':<6} | {'Dupes':<6}"

    print("\n" + heavy_rule)
    print(header)
    print(light_rule)

    passed_counts = []
    for category in sorted(stats):
        row = stats[category]
        print(
            f"{category:<35} | {row['total']:<8} | {row['passed']:<6} | "
            f"{row['failed']:<6} | {row['duplicates']:<6}"
        )
        passed_counts.append(row['passed'])

    print(light_rule)
    print(f"Total High-Quality Images: {sum(passed_counts)}")
    print(heavy_rule + "\n")
def _log_summary(stats: dict) -> None:
    """Log the per-category table with pass rates, plus grand totals, via the module logger."""
    divider = f" {'-'*35} {'-'*7} {'-'*7} {'-'*7} {'-'*7} {'-'*7}"

    logger.info("\n" + "=" * 80)
    logger.info("QUALITY FILTER SUMMARY")
    logger.info("=" * 80)
    logger.info(f" {'Category':35s} {'Total':>7s} {'Passed':>7s} {'Failed':>7s} {'Dupes':>7s} {'Rate':>7s}")
    logger.info(divider)

    grand_total = grand_passed = 0
    for cat, s in sorted(stats.items()):
        # max(..., 1) avoids dividing by zero for categories with no candidates
        rate = f"{s['passed']/max(s['total'],1)*100:.1f}%"
        logger.info(
            f" {cat:35s} {s['total']:7d} {s['passed']:7d} "
            f"{s['failed']:7d} {s['duplicates']:7d} {rate:>7s}"
        )
        grand_total += s["total"]
        grand_passed += s["passed"]

    rate = f"{grand_passed/max(grand_total,1)*100:.1f}%"
    logger.info(divider)
    logger.info(f" {'TOTAL':35s} {grand_total:7d} {grand_passed:7d}{'':>17s} {rate:>7s}")
    logger.info("=" * 80)


def main():
    """
    Command-line entry point.

    Parses ``--config``, runs the quality filter once, then reports the
    results both on stdout and through the logger.
    """
    parser = argparse.ArgumentParser(description="Image Quality Filter (GPU-Accelerated)")
    parser.add_argument("--config", default="configs/config.yaml", help="Path to config.yaml")
    args = parser.parse_args()

    config = load_config(args.config)
    stats = run_quality_filter(config)
    print_summary(stats)
    _log_summary(stats)


# Single entry guard: the file previously had two, which ran the whole
# pipeline (including scraping) twice per invocation.
if __name__ == "__main__":
    main()