"""Backfill ``pixif_cache.exif_type`` from PNG files in the local stash.

For every ``pixif_cache`` row whose ``exif_type`` is still NULL and whose
``url`` is non-empty, look for ``images/Stash/<post_id>.png``, classify the
metadata that file carries, and store the resulting integer code (see
``EXIF_TYPE_TO_CODE``).
"""

import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
from PIL import Image
from tqdm import tqdm

ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
IMAGES_DIR = os.path.join(ROOT_DIR, "images")
STASH_DIR = os.path.join(IMAGES_DIR, "Stash")
DB_PATH = os.path.join(ROOT_DIR, "db.sqlite")
MAX_WORKERS = min(16, os.cpu_count() or 8)

# Only this many leading bytes of each PNG are read when looking for a
# textual metadata chunk; a text chunk that extends past this point is
# returned truncated.
EXIF_METADATA_MAX_BYTES = 512
# Tuple order defines the stored integer codes: "novelai" -> 1, "sd" -> 2,
# ..., "stealth" -> 7. Do not reorder without migrating existing rows.
EXIF_TYPE_ORDER = ("novelai", "sd", "comfy", "mj", "celsys", "photoshop", "stealth")
EXIF_TYPE_TO_CODE = {name: idx + 1 for idx, name in enumerate(EXIF_TYPE_ORDER)}
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def open_db(path: str) -> sqlite3.Connection:
    """Open the SQLite database at *path* and make sure ``pixif_cache`` is usable.

    Creates the table if it does not exist, then calls ``ensure_db_schema`` so
    that tables created before the ``exif_type`` column existed get it added.

    Returns:
        An open connection; the caller is responsible for closing it.
    """
    conn = sqlite3.connect(path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS pixif_cache (
                post_id TEXT PRIMARY KEY,
                url TEXT,
                exif_type INTEGER
            )
            """
        )
        conn.commit()
        ensure_db_schema(conn)
    except BaseException:
        # Do not leak the connection if schema setup fails.
        conn.close()
        raise
    return conn


def ensure_db_schema(conn: sqlite3.Connection) -> None:
    """Add the ``exif_type`` column to ``pixif_cache`` if it is missing."""
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    columns = [row[1] for row in conn.execute("PRAGMA table_info(pixif_cache)")]
    if "exif_type" not in columns:
        conn.execute("ALTER TABLE pixif_cache ADD COLUMN exif_type INTEGER")
        conn.commit()


def determine_exif_type(metadata: Optional[bytes]) -> Optional[str]:
    """Classify a NUL-stripped PNG text-chunk payload by its content.

    Args:
        metadata: Payload as returned by ``parse_png_metadata``, or None.

    Returns:
        None when *metadata* is None; otherwise one of "novelai", "sd",
        "comfy", "celsys", or "photoshop" (the fallback for any other text).
        NOTE(review): "mj" exists in EXIF_TYPE_ORDER but is never returned
        here; confirm whether Midjourney detection is expected elsewhere.
    """
    if metadata is None:
        return None
    if metadata == b"TitleAI generated image":
        return "novelai"
    if metadata.startswith(b"parameter"):
        return "sd"
    if b'{"' in metadata:
        # Embedded JSON object, e.g. a "prompt"/"workflow" text chunk.
        return "comfy"
    if metadata.startswith(b"SoftwareCelsys"):
        return "celsys"
    return "photoshop"


def exif_type_to_code(exif_type: Optional[str]) -> Optional[int]:
    """Map an exif type name to its stored integer code (None if unknown/empty)."""
    if not exif_type:
        return None
    return EXIF_TYPE_TO_CODE.get(exif_type)


def parse_png_metadata(data: bytes) -> Optional[bytes]:
    """Return the payload of the first ``tEXt``/``iTXt`` chunk in *data*.

    *data* is assumed to begin with the 8-byte PNG signature, which is skipped
    without being validated (``parse_png_metadata_file`` checks it). NUL
    separators are removed from the payload, so keyword ``Title`` plus text
    ``AI generated image`` becomes ``b"TitleAI generated image"``. If the
    chunk extends past the end of *data*, only the available part is returned.

    Returns:
        The cleaned payload, or None if no text chunk header lies in *data*.
    """
    index = len(PNG_SIGNATURE)
    size = len(data)
    # Each chunk: 4-byte big-endian length, 4-byte type, data, 4-byte CRC.
    while index + 8 <= size:
        chunk_len = int.from_bytes(data[index:index + 4], "big")
        chunk_type = data[index + 4:index + 8]
        index += 8
        if chunk_type == b"tEXt":
            return data[index:index + chunk_len].replace(b"\0", b"")
        if chunk_type == b"iTXt":
            # iTXt also separates keyword / compression flag / method /
            # language / translated keyword with NUL bytes. ``strip()`` alone
            # does not remove NULs, so drop them like the tEXt branch does;
            # otherwise the prefix checks in determine_exif_type cannot match.
            return data[index:index + chunk_len].replace(b"\0", b"").strip()
        # Skip this chunk's data plus its 4-byte CRC.
        index += chunk_len + 4
    return None


def parse_png_metadata_file(path: str) -> Optional[bytes]:
    """Read the head of the file at *path* and extract its PNG text metadata.

    Returns:
        The payload from ``parse_png_metadata``, or None when the file cannot
        be read, is not a PNG, or has no text chunk in its first
        ``EXIF_METADATA_MAX_BYTES`` bytes.
    """
    try:
        with open(path, "rb") as handle:
            head = handle.read(EXIF_METADATA_MAX_BYTES)
    except OSError:
        # Missing / unreadable files are simply "no metadata".
        return None
    if not head.startswith(PNG_SIGNATURE):
        return None
    return parse_png_metadata(head)


def byteize(alpha: np.ndarray) -> np.ndarray:
    """Pack the least-significant bits of *alpha* into bytes.

    The array is transposed before flattening, so for a 2-D (rows, cols)
    array the bits are taken column by column. Trailing bits that do not fill
    a whole byte are dropped. The result has shape (n_bytes, 1).
    """
    alpha = alpha.T.reshape((-1,))
    alpha = alpha[:(alpha.shape[0] // 8) * 8]
    alpha = np.bitwise_and(alpha, 1)
    alpha = alpha.reshape((-1, 8))
    alpha = np.packbits(alpha, axis=1)
    return alpha


class LSBExtractor:
    """Sequential byte reader over the LSB-packed bytes of an alpha channel."""

    def __init__(self, alpha: np.ndarray) -> None:
        self.data = byteize(alpha)
        self.pos = 0  # index of the next unread byte

    def get_next_n_bytes(self, n: int) -> bytearray:
        """Return up to *n* bytes from the current position and advance by *n*."""
        n_bytes = self.data[self.pos:self.pos + n]
        self.pos += n
        return bytearray(n_bytes)

    def read_32bit_integer(self) -> Optional[int]:
        """Read a big-endian 32-bit unsigned integer, or None if data runs out."""
        bytes_list = self.get_next_n_bytes(4)
        if len(bytes_list) == 4:
            return int.from_bytes(bytes_list, byteorder="big")
        return None


def extract_stealth_metadata(image: Image.Image) -> bool:
    """Check whether *image* carries the "stealth_pngcomp" LSB alpha header.

    Only the magic string and the following 32-bit length field are checked;
    the payload itself is not read.

    Returns:
        True when the header is present.

    Raises:
        AssertionError: if the image has no alpha band, the magic does not
            match, or the length field is missing.
        UnicodeDecodeError: if the magic bytes are not valid UTF-8.
    """
    if "A" not in image.getbands():
        raise AssertionError("image format")
    alpha = np.array(image.getchannel("A"))
    reader = LSBExtractor(alpha)
    magic = "stealth_pngcomp"
    read_magic = reader.get_next_n_bytes(len(magic)).decode("utf-8")
    if magic != read_magic:
        raise AssertionError("magic number")
    read_len = reader.read_32bit_integer()
    if read_len is None:
        raise AssertionError("length missing")
    return True


def has_stealth_png_path(path: str) -> bool:
    """Return True if the image at *path* has a stealth LSB header, else False."""
    try:
        with Image.open(path) as image:
            return extract_stealth_metadata(image)
    except Exception:
        # Best effort: unreadable images, decoder errors and the
        # AssertionError/UnicodeDecodeError "not stealth" signals from
        # extract_stealth_metadata all mean "no stealth metadata".
        return False


def detect_exif_code_from_path(path: str) -> Optional[int]:
    """Return the exif type code for the PNG at *path*, or None if undetected.

    Text-chunk metadata takes precedence; the (more expensive) stealth check
    that decodes the image only runs when no text chunk was classified.
    """
    metadata = parse_png_metadata_file(path)
    exif_type = determine_exif_type(metadata)
    code = exif_type_to_code(exif_type)
    if code is not None:
        return code
    if has_stealth_png_path(path):
        return EXIF_TYPE_TO_CODE.get("stealth")
    return None


def fetch_pending_post_ids(conn: sqlite3.Connection) -> List[str]:
    """Return post ids whose exif_type is NULL and whose url is non-empty."""
    rows = conn.execute(
        """
        SELECT post_id
        FROM pixif_cache
        WHERE exif_type IS NULL
          AND COALESCE(url, '') != ''
        """
    ).fetchall()
    return [str(row[0]) for row in rows]


def update_exif_types(conn: sqlite3.Connection, rows: Sequence[Tuple[int, str]]) -> None:
    """Set ``exif_type`` for each ``(exif_type, post_id)`` pair in *rows*.

    Does not commit; the caller controls the transaction.
    """
    if not rows:
        return
    conn.executemany(
        """
        UPDATE pixif_cache
        SET exif_type = ?
        WHERE post_id = ?
        """,
        rows,
    )


def detect_exif_codes_from_files(
    post_ids: Sequence[str],
    stash_dir: str,
    max_workers: int = MAX_WORKERS,
) -> Dict[str, Optional[int]]:
    """Detect exif type codes for ``<stash_dir>/<post_id>.png`` concurrently.

    Args:
        post_ids: Post ids to scan.
        stash_dir: Directory holding the ``<post_id>.png`` files.
        max_workers: Thread pool size.

    Returns:
        Mapping of every post id to its code, or None when nothing was
        detected or detection raised.
    """
    if not post_ids:
        return {}

    results: Dict[str, Optional[int]] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {
            executor.submit(
                detect_exif_code_from_path,
                os.path.join(stash_dir, f"{post_id}.png"),
            ): post_id
            for post_id in post_ids
        }
        with tqdm(total=len(futures), unit="image", desc="Scanning exif") as pbar:
            for future in as_completed(futures):
                post_id = futures[future]
                try:
                    code = future.result()
                except Exception:
                    # One bad file must not abort the whole scan.
                    code = None
                results[post_id] = code
                pbar.update(1)
    return results


def main() -> int:
    """Scan stash images for pending rows and store detected exif types.

    NOTE(review): rows whose image has no detectable metadata keep a NULL
    exif_type and are therefore rescanned on every run.

    Returns:
        Process exit status (always 0).
    """
    os.makedirs(STASH_DIR, exist_ok=True)
    conn = open_db(DB_PATH)
    try:
        post_ids = fetch_pending_post_ids(conn)
        if not post_ids:
            print("No pending rows.")
            return 0

        existing = [
            post_id
            for post_id in post_ids
            if os.path.exists(os.path.join(STASH_DIR, f"{post_id}.png"))
        ]
        if not existing:
            print("No matching images in stash.")
            return 0

        results = detect_exif_codes_from_files(existing, STASH_DIR)
        rows = [
            (exif_type, post_id)
            for post_id, exif_type in results.items()
            if exif_type is not None
        ]
        if rows:
            # ``with conn`` commits on success and rolls back on error.
            with conn:
                update_exif_types(conn, rows)
        print(f"Updated {len(rows)} rows.")
        return 0
    finally:
        conn.close()


if __name__ == "__main__":
    raise SystemExit(main())