# Campus-AI / scripts / image_deduplicator.py
# realruneett's picture
# Final Release: CampusGen AI Pipeline & Compositor
# a8aea21
import os
import sqlite3
import imagehash
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor
import time
logger = logging.getLogger(__name__)


class GlobalImageDeduplicator:
    """
    Globally tracks perceptual hashes of all images in the data directory
    to prevent downloading duplicates across all subfolders and phases.
    Uses an SQLite database for persistent caching to speed up initialization.

    The instance owns an open SQLite connection. Call ``close()`` when done,
    or use the object as a context manager::

        with GlobalImageDeduplicator("data") as dedup:
            if not dedup.is_duplicate(img, save_path=target):
                img.save(target)
                dedup.add_to_disk_cache(target, img)
    """

    # File suffixes treated as images when scanning; compared case-insensitively
    # so that e.g. "photo.JPG" is picked up on case-sensitive filesystems too.
    IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp"})

    def __init__(self, data_dir: str, db_path: str = None, hash_size: int = 8, threshold: int = 5):
        """
        Open (or create) the hash cache and synchronise it with ``data_dir``.

        Args:
            data_dir: Root directory that is scanned recursively for images.
            db_path: Location of the SQLite cache. Defaults to
                ``<data_dir>/phash_cache.db`` when ``None``.
            hash_size: ``hash_size`` passed to ``imagehash.phash``.
            threshold: Maximum hash distance (inclusive) at which two images
                are considered duplicates.
        """
        self.data_dir = Path(data_dir)
        if db_path is None:
            # Store at root/data/phash_cache.db
            self.db_path = self.data_dir / "phash_cache.db"
        else:
            self.db_path = Path(db_path)
        self.hash_size = hash_size
        self.threshold = threshold
        self.hashes = []  # List of (filepath, imagehash.ImageHash)

        logger.info("Initializing Global Image Deduplicator using DB: %s", self.db_path)
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            self._init_db()
            self._load_and_sync()
        except BaseException:
            # Do not leak the connection when start-up fails or is interrupted.
            self.close()
            raise

    def __enter__(self):
        """Return the deduplicator itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the database connection on leaving the ``with`` block."""
        self.close()

    def close(self):
        """Close the SQLite connection. Safe to call more than once."""
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
            self.conn = None

    def _init_db(self):
        """Create the ``phashes`` table if it does not exist yet."""
        with self.conn:
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS phashes (
                    filepath TEXT PRIMARY KEY,
                    mtime REAL,
                    hash_str TEXT
                )
            ''')

    def _expected_hex_length(self):
        """Return the length of the hex string imagehash emits for ``hash_size``."""
        # A hash has hash_size * hash_size bits and is rendered 4 bits per hex digit.
        return (self.hash_size * self.hash_size + 3) // 4

    def _iter_image_files(self):
        """Yield every regular file below ``data_dir`` with a known image suffix."""
        for path in self.data_dir.rglob("*"):
            if path.suffix.lower() in self.IMAGE_SUFFIXES and path.is_file():
                yield path

    def _hash_file(self, job):
        """
        Compute the perceptual hash for one ``(filepath_str, path, mtime)`` job.

        Returns ``(filepath_str, mtime, hash_str)``; ``hash_str`` is ``None``
        when the image could not be opened or hashed.
        """
        f_str, path, mtime = job
        try:
            with Image.open(path) as img:
                # Convert to RGB to be safe and avoid issues with alpha channels
                hash_str = str(imagehash.phash(img.convert("RGB"), hash_size=self.hash_size))
        except Exception as e:  # best effort: one unreadable file must not abort the scan
            logger.debug("Error hashing %s: %s", path, e)
            return f_str, mtime, None
        return f_str, mtime, hash_str

    def _load_and_sync(self):
        """
        Bring the DB cache in line with the files on disk, then load all
        hashes into ``self.hashes``.

        A file is (re)hashed when it is not cached, when its modification time
        differs from the cached one, or when the cached hash was produced with
        a different ``hash_size``. Rows for files that no longer exist, or
        that need rehashing but can no longer be hashed, are removed.
        """
        logger.info("Scanning %s for images...", self.data_dir)

        # Get existing from DB
        rows = self.conn.execute("SELECT filepath, mtime, hash_str FROM phashes").fetchall()
        db_records = {filepath: (mtime, hash_str) for filepath, mtime, hash_str in rows}
        hex_len = self._expected_hex_length()

        # Determine what needs hashing
        current_files = set()
        to_hash = []
        for path in self._iter_image_files():
            try:
                mtime = os.path.getmtime(path)
            except OSError:
                # File vanished between listing and stat; treat it as absent.
                continue
            f_str = str(path)
            current_files.add(f_str)
            cached = db_records.get(f_str)
            if (
                cached is None
                or cached[0] != mtime  # any change, including a restore of an older file
                or len(cached[1] or "") != hex_len  # cached with another hash_size
            ):
                to_hash.append((f_str, path, mtime))

        # Delete missing files from DB
        to_delete = [db_file for db_file in db_records if db_file not in current_files]
        if to_delete:
            logger.info("Removing %d deleted files from cache.", len(to_delete))
            with self.conn:
                self.conn.executemany(
                    "DELETE FROM phashes WHERE filepath = ?", [(f,) for f in to_delete]
                )

        # Hash new or modified files
        if to_hash:
            logger.info(
                "Hashing %d new/modified images. This might take a while...", len(to_hash)
            )
            results = []
            stale = []  # cached rows whose file changed but can no longer be hashed
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                hashed = executor.map(self._hash_file, to_hash)
                for f_str, mtime, hash_str in tqdm(hashed, total=len(to_hash), desc="Hashing"):
                    if hash_str is not None:
                        results.append((f_str, mtime, hash_str))
                    elif f_str in db_records:
                        stale.append(f_str)

            # Save new hashes to DB and drop rows that would otherwise keep an outdated hash
            with self.conn:
                if stale:
                    logger.info("Removing %d outdated unreadable files from cache.", len(stale))
                    self.conn.executemany(
                        "DELETE FROM phashes WHERE filepath = ?", [(f,) for f in stale]
                    )
                self.conn.executemany(
                    "INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)",
                    results,
                )

        # Load all hashes into memory for fast comparison
        self.hashes = [
            (filepath, imagehash.hex_to_hash(hash_str))
            for filepath, hash_str in self.conn.execute("SELECT filepath, hash_str FROM phashes")
        ]
        logger.info("Loaded %d image hashes for deduplication.", len(self.hashes))

    def is_duplicate(self, img: "Image.Image", save_path: str = None) -> bool:
        """
        Check if an image is a duplicate of any globally known image.

        If save_path is provided, and it's NOT a duplicate, it adds the hash to the in-memory
        cache immediately so we don't download the same duplicate in the same session.

        Args:
            img: Image to test.
            save_path: Path the image will be stored under, if it is kept.

        Returns:
            True when a known hash lies within ``threshold`` of the image's
            hash, False otherwise.
        """
        # Ensure RGB
        if img.mode != 'RGB':
            img = img.convert('RGB')
        h = imagehash.phash(img, hash_size=self.hash_size)

        if any(h - existing_hash <= self.threshold for _, existing_hash in self.hashes):
            return True

        if save_path:
            self.hashes.append((str(save_path), h))
        return False

    def add_to_disk_cache(self, filepath: str, img: "Image.Image"):
        """
        Manually add an image to the DB cache. Use this after saving an image to disk
        so next time we run, it's already in the DB.

        Args:
            filepath: Path of the already-saved file; its mtime is read from disk.
            img: The image that was saved, used to compute the stored hash.

        Raises:
            OSError: If ``filepath`` does not exist or cannot be stat'ed.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        h = imagehash.phash(img, hash_size=self.hash_size)

        # Wait slightly to ensure mtime is written
        time.sleep(0.01)
        mtime = os.path.getmtime(filepath)

        with self.conn:
            self.conn.execute(
                "INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)",
                (str(filepath), mtime, str(h)),
            )