Spaces:
Running
Running
| import os | |
| import sqlite3 | |
| import imagehash | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import logging | |
| from concurrent.futures import ThreadPoolExecutor | |
| import time | |
# Module-level logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)
class GlobalImageDeduplicator:
    """
    Globally tracks perceptual hashes of all images in the data directory
    to prevent downloading duplicates across all subfolders and phases.
    Uses an SQLite database for persistent caching to speed up initialization.

    Cache rows are keyed by ``str(path)`` for the paths yielded by
    ``Path(data_dir).rglob(...)``. Any cached path not produced by that scan is
    treated as a deleted file and purged on the next startup, so paths handed
    to ``add_to_disk_cache`` should use the same (``data_dir``-rooted) form.

    The SQLite connection is opened with ``check_same_thread=False`` so the
    instance may be shared between threads. An internal lock makes the
    check-and-register step of ``is_duplicate`` atomic and serialises the
    database writes of ``add_to_disk_cache``.
    """

    # File suffixes (compared case-insensitively) that are treated as images.
    IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp"})

    def __init__(self, data_dir: str, db_path: str = None, hash_size: int = 8, threshold: int = 5):
        """
        Open (or create) the hash cache and bring it in sync with ``data_dir``.

        Args:
            data_dir: Root directory scanned recursively for images.
            db_path: SQLite cache file; defaults to ``<data_dir>/phash_cache.db``.
            hash_size: ``hash_size`` passed to ``imagehash.phash``. Cached hashes
                computed with a different size are recomputed during sync.
            threshold: Maximum hash distance (inclusive) at which two images
                are considered duplicates.
        """
        import threading

        self.data_dir = Path(data_dir)
        if db_path is None:
            # Store at root/data/phash_cache.db
            self.db_path = self.data_dir / "phash_cache.db"
        else:
            self.db_path = Path(db_path)
        self.hash_size = hash_size
        self.threshold = threshold
        self.hashes = []  # List of (filepath, imagehash.ImageHash)
        # The connection below may be used from several threads; sqlite3 leaves
        # serialising writes to the caller, and is_duplicate's scan-then-append
        # must be atomic to avoid two threads both accepting the same image.
        self._lock = threading.Lock()
        logger.info(f"Initializing Global Image Deduplicator using DB: {self.db_path}")
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        self._init_db()
        self._load_and_sync()

    def _init_db(self):
        """Create the ``phashes`` table (filepath -> mtime, hex hash) if it does not exist."""
        with self.conn:
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS phashes (
                    filepath TEXT PRIMARY KEY,
                    mtime REAL,
                    hash_str TEXT
                )
            ''')

    def _hash_hex_len(self) -> int:
        """Return the length of ``str(phash(...))`` at the configured ``hash_size``."""
        # Measured from a real hash of a blank image instead of a formula, so it
        # matches the encoding of whichever imagehash version is installed.
        probe = imagehash.phash(Image.new("RGB", (32, 32)), hash_size=self.hash_size)
        return len(str(probe))

    def _load_and_sync(self):
        """
        Reconcile the cache with the files on disk, then load every hash into memory.

        Files are (re)hashed when they are new, when their mtime differs from the
        cached value, or when their cached hash was computed with a different
        ``hash_size``. Rows for files that no longer exist are deleted.
        """
        logger.info(f"Scanning {self.data_dir} for images...")
        hex_len = self._hash_hex_len()

        # Get existing from DB
        cursor = self.conn.cursor()
        cursor.execute("SELECT filepath, mtime, hash_str FROM phashes")
        db_records = {row[0]: (row[1], row[2]) for row in cursor.fetchall()}

        to_hash = []
        current_files = set()
        # A single traversal with a case-insensitive suffix check, so that
        # e.g. "photo.JPG" is not skipped on case-sensitive file systems.
        for f in self.data_dir.rglob("*"):
            if f.suffix.lower() not in self.IMAGE_SUFFIXES:
                continue
            try:
                mtime = os.path.getmtime(f)
            except OSError:
                # Removed or unreadable between the directory scan and the stat.
                continue
            f_str = str(f)
            current_files.add(f_str)
            cached = db_records.get(f_str)
            # Rehash on any mtime change (a file replaced by a copy can have an
            # *older* timestamp) and when the stored hash has the wrong size.
            if cached is None or cached[0] != mtime or len(cached[1] or "") != hex_len:
                to_hash.append((f_str, f, mtime))

        # Delete missing files from DB
        to_delete = [(db_file,) for db_file in db_records if db_file not in current_files]
        if to_delete:
            logger.info(f"Removing {len(to_delete)} deleted files from cache.")
            with self.conn:
                self.conn.executemany("DELETE FROM phashes WHERE filepath = ?", to_delete)

        # Hash new or modified files
        if to_hash:
            logger.info(f"Hashing {len(to_hash)} new/modified images. This might take a while...")

            def compute_hash(args):
                f_str, f, mtime = args
                try:
                    with Image.open(f) as img:
                        # Convert to RGB to be safe and avoid issues with alpha channels
                        conv_img = img.convert("RGB")
                        h = imagehash.phash(conv_img, hash_size=self.hash_size)
                    return f_str, mtime, str(h)
                except Exception as e:
                    # Best effort: an unreadable/corrupt file is skipped, not fatal.
                    logger.debug(f"Error hashing {f}: {e}")
                    return None

            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                results = [
                    res
                    for res in tqdm(executor.map(compute_hash, to_hash), total=len(to_hash), desc="Hashing")
                    if res is not None
                ]

            failed = len(to_hash) - len(results)
            if failed:
                logger.warning(f"Could not hash {failed} image(s); enable debug logging for details.")

            # Save new hashes to DB
            with self.conn:
                self.conn.executemany(
                    "INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)",
                    results,
                )

        # Load all hashes into memory for fast comparison. Rows whose length does
        # not match the current hash_size (e.g. a file that failed to rehash after
        # hash_size changed) are skipped: ImageHash subtraction raises TypeError
        # for hashes of different shapes.
        cursor.execute("SELECT filepath, hash_str FROM phashes")
        for filepath, hash_str in cursor.fetchall():
            if hash_str and len(hash_str) == hex_len:
                self.hashes.append((filepath, imagehash.hex_to_hash(hash_str)))
        logger.info(f"Loaded {len(self.hashes)} image hashes for deduplication.")

    def is_duplicate(self, img: Image.Image, save_path: str = None) -> bool:
        """
        Check if an image is a duplicate of any globally known image.

        If save_path is provided, and it's NOT a duplicate, it adds the hash to the in-memory
        cache immediately so we don't download the same duplicate in the same session.
        The scan and the registration happen under one lock, so concurrent callers
        cannot both accept the same image. The database is not touched.

        Returns:
            True if some known hash is within ``threshold`` of this image's hash.
        """
        # Ensure RGB
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # Hash outside the lock; only the scan + registration must be atomic.
        h = imagehash.phash(img, hash_size=self.hash_size)
        with self._lock:
            if any(abs(h - existing_hash) <= self.threshold for _, existing_hash in self.hashes):
                return True
            if save_path:
                self.hashes.append((str(save_path), h))
        return False

    def add_to_disk_cache(self, filepath: str, img: Image.Image):
        """
        Manually add an image to the DB cache. Use this after saving an image to disk
        so next time we run, it's already in the DB.

        The file's current mtime is stored, so call this only after the file has been
        fully written and closed. The key is ``str(filepath)``. It must match the
        form produced by scanning ``data_dir``, or the next startup will treat the
        row as a deleted file. The in-memory hash list is not updated; pass
        ``save_path`` to ``is_duplicate`` for that.

        Raises:
            OSError: if ``filepath`` cannot be stat'ed (e.g. it does not exist).
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        h = imagehash.phash(img, hash_size=self.hash_size)
        # No artificial delay: a fixed sleep cannot guarantee the writer has
        # flushed, and once the file is closed its mtime is already final.
        mtime = os.path.getmtime(filepath)
        with self._lock, self.conn:
            self.conn.execute(
                "INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)",
                (str(filepath), mtime, str(h)),
            )

    def close(self):
        """Close the SQLite connection; ``is_duplicate`` keeps working on the in-memory hashes."""
        self.conn.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()