File size: 6,222 Bytes
a8aea21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import os
import sqlite3
import imagehash
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor
import time

logger = logging.getLogger(__name__)

class GlobalImageDeduplicator:
    """Detect near-duplicate images across an entire data directory.

    On construction every image found under ``data_dir`` (recursively) is
    perceptually hashed with ``imagehash.phash``. Hashes are persisted in an
    SQLite table keyed by file path, together with the file's modification
    time, so that later runs only hash files that are new or have changed.
    All hashes are also held in memory (``self.hashes``) for comparison.

    The instance owns an open SQLite connection; call :meth:`close` (or use
    the instance as a context manager) when finished with it.

    NOTE: the connection is opened with ``check_same_thread=False``, but this
    class performs no locking of its own around ``self.hashes`` or the
    connection.
    """

    # Glob patterns, applied recursively under ``data_dir``, that are treated
    # as image files.
    _IMAGE_PATTERNS = ('*.jpg', '*.jpeg', '*.png', '*.webp')

    def __init__(self, data_dir: str, db_path: str = None, hash_size: int = 8, threshold: int = 5):
        """Open (or create) the hash cache and synchronise it with the disk.

        Args:
            data_dir: Root directory that is scanned recursively for images.
            db_path: Location of the SQLite cache. Defaults to
                ``<data_dir>/phash_cache.db``.
            hash_size: ``hash_size`` argument passed to ``imagehash.phash``.
            threshold: Largest hash difference (as returned by subtracting
                two ``ImageHash`` objects) still considered a duplicate.
        """
        self.data_dir = Path(data_dir)
        if db_path is None:
            # Store at root/data/phash_cache.db
            self.db_path = self.data_dir / "phash_cache.db"
        else:
            self.db_path = Path(db_path)

        self.hash_size = hash_size
        self.threshold = threshold
        self.hashes = []  # List of (filepath, imagehash.ImageHash)

        logger.info(f"Initializing Global Image Deduplicator using DB: {self.db_path}")
        # SQLite does not create missing parent directories, so a first run
        # against a not-yet-existing data directory would otherwise fail.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
        try:
            self._init_db()
            self._load_and_sync()
        except BaseException:
            # Do not leak the connection when initialisation fails.
            self.conn.close()
            raise

    def __enter__(self):
        """Return the deduplicator itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the SQLite connection when the ``with`` block ends."""
        self.close()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()

    def _init_db(self):
        """Create the ``phashes`` table if it does not exist yet."""
        with self.conn:
            self.conn.execute('''
                CREATE TABLE IF NOT EXISTS phashes (
                    filepath TEXT PRIMARY KEY,
                    mtime REAL,
                    hash_str TEXT
                )
            ''')

    def _expected_hex_len(self) -> int:
        """Return the hex-string length of a hash made with ``self.hash_size``.

        ``str(ImageHash)`` renders the ``hash_size * hash_size`` bits as hex
        digits, four bits per digit, rounded up.
        """
        return (self.hash_size * self.hash_size + 3) // 4

    def _compute_hash(self, args):
        """Hash one image file.

        Args:
            args: ``(path_string, path, mtime)`` tuple describing the file.

        Returns:
            ``(path_string, mtime, hash_hex)`` on success, or ``None`` when
            the file could not be opened or hashed (best effort: the error
            is logged and the file is skipped).
        """
        f_str, f, mtime = args
        try:
            with Image.open(f) as img:
                # Convert to RGB to be safe and avoid issues with alpha channels
                conv_img = img.convert("RGB")
                h = imagehash.phash(conv_img, hash_size=self.hash_size)
                return f_str, mtime, str(h)
        except Exception as e:
            logger.debug(f"Error hashing {f}: {e}")
            return None

    def _load_and_sync(self):
        """Bring the DB cache in line with the files on disk, then load it.

        Steps:
          1. Scan ``data_dir`` for image files.
          2. Drop cache rows whose file no longer exists.
          3. (Re)hash files that are new, whose mtime differs from the cached
             one, or whose cached hash was produced with another hash size.
          4. Rebuild ``self.hashes`` from the table.
        """
        logger.info(f"Scanning {self.data_dir} for images...")
        all_files = []
        for pattern in self._IMAGE_PATTERNS:
            all_files.extend(self.data_dir.rglob(pattern))

        # Get existing from DB
        db_records = {
            filepath: (mtime, hash_str)
            for filepath, mtime, hash_str in self.conn.execute(
                "SELECT filepath, mtime, hash_str FROM phashes"
            )
        }

        expected_len = self._expected_hex_len()
        current_files = set()
        to_hash = []

        # Determine what needs hashing
        for f in all_files:
            f_str = str(f)
            try:
                mtime = os.path.getmtime(f)
            except OSError:
                # The file vanished between the scan and the stat call; treat
                # it as absent so any cached row for it is removed below.
                continue
            current_files.add(f_str)

            cached = db_records.get(f_str)
            if cached is None:
                to_hash.append((f_str, f, mtime))
                continue
            cached_mtime, cached_hash = cached
            # Any mtime difference (not only a newer one) means the content
            # may have changed. A hash of unexpected length was created with
            # a different hash_size and cannot be compared with new hashes.
            if cached_mtime != mtime or len(cached_hash or "") != expected_len:
                to_hash.append((f_str, f, mtime))

        to_delete = [db_file for db_file in db_records if db_file not in current_files]

        # Delete missing files from DB
        if to_delete:
            logger.info(f"Removing {len(to_delete)} deleted files from cache.")
            with self.conn:
                self.conn.executemany("DELETE FROM phashes WHERE filepath = ?", [(f,) for f in to_delete])

        # Hash new or modified files
        if to_hash:
            logger.info(f"Hashing {len(to_hash)} new/modified images. This might take a while...")

            results = []
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                for res in tqdm(executor.map(self._compute_hash, to_hash), total=len(to_hash), desc="Hashing"):
                    if res is not None:
                        results.append(res)

            # A file that needed rehashing but could not be hashed must not
            # keep its outdated row: that hash no longer describes the file.
            hashed_paths = {row[0] for row in results}
            stale = [
                (f_str,)
                for f_str, _, _ in to_hash
                if f_str in db_records and f_str not in hashed_paths
            ]

            # Save new hashes to DB
            with self.conn:
                self.conn.executemany("INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)", results)
                self.conn.executemany("DELETE FROM phashes WHERE filepath = ?", stale)

        # Load all hashes into memory for fast comparison. The list is rebuilt
        # rather than appended to, so running the sync again cannot duplicate
        # entries.
        self.hashes = [
            (filepath, imagehash.hex_to_hash(hash_str))
            for filepath, hash_str in self.conn.execute("SELECT filepath, hash_str FROM phashes")
        ]

        logger.info(f"Loaded {len(self.hashes)} image hashes for deduplication.")

    def is_duplicate(self, img: Image.Image, save_path: str = None) -> bool:
        """Check whether ``img`` matches any image already known.

        The image is converted to RGB if necessary, hashed, and compared with
        every hash in ``self.hashes``; a difference of at most
        ``self.threshold`` counts as a match.

        Args:
            img: Image to test.
            save_path: If given and the image is NOT a duplicate, its hash is
                added to the in-memory list under this path, so the same
                image is rejected later in the same session. The DB is not
                touched here; see :meth:`add_to_disk_cache`.

        Returns:
            ``True`` if a known image is within the threshold, else ``False``.
        """
        # Ensure RGB
        if img.mode != 'RGB':
            img = img.convert('RGB')

        h = imagehash.phash(img, hash_size=self.hash_size)

        if any(h - existing_hash <= self.threshold for _, existing_hash in self.hashes):
            return True

        if save_path:
            self.hashes.append((str(save_path), h))

        return False

    def add_to_disk_cache(self, filepath: str, img: Image.Image):
        """Store the hash of an already-saved image in the DB cache.

        Call this after writing ``img`` to ``filepath`` so that the next run
        finds the file in the cache instead of hashing it again. Only the DB
        is updated; the in-memory list is not.

        NOTE: the startup sync keys rows by ``str()`` of the paths found by
        scanning ``data_dir``. A ``filepath`` spelled differently from that
        form will not match on the next run and the file will be rehashed.

        Args:
            filepath: Path of the saved file; must exist, since its
                modification time is read.
            img: The image that was saved there.

        Raises:
            OSError: If ``filepath`` cannot be stat'ed.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')
        h = imagehash.phash(img, hash_size=self.hash_size)
        # No delay before reading the mtime: a mismatch with a later reading
        # only makes the next startup sync rehash this one file.
        mtime = os.path.getmtime(filepath)
        with self.conn:
            self.conn.execute("INSERT OR REPLACE INTO phashes (filepath, mtime, hash_str) VALUES (?, ?, ?)",
                            (str(filepath), mtime, str(h)))