| import os |
| import sqlite3 |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from typing import Dict, List, Optional, Sequence, Tuple |
|
|
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
|
|
# Project layout, resolved relative to the parent of this file's directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
IMAGES_DIR = os.path.join(ROOT_DIR, "images")
STASH_DIR = os.path.join(IMAGES_DIR, "Stash")
DB_PATH = os.path.join(ROOT_DIR, "db.sqlite")

# Thread-pool size: one worker per CPU (8 when the count is unknown), capped at 16.
MAX_WORKERS = min(os.cpu_count() or 8, 16)

# Number of leading bytes read from each file when looking for text chunks.
EXIF_METADATA_MAX_BYTES = 512

# Known metadata kinds; the integer code of a kind is its 1-based position here.
EXIF_TYPE_ORDER = ("novelai", "sd", "comfy", "mj", "celsys", "photoshop", "stealth")
EXIF_TYPE_TO_CODE = {
    name: code for code, name in enumerate(EXIF_TYPE_ORDER, start=1)
}

# Eight-byte signature expected at the start of a PNG file.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
|
|
def open_db(path: str) -> sqlite3.Connection:
    """Open the SQLite database at *path* and prepare the ``pixif_cache`` table.

    The table is created when it does not exist yet, the creation is
    committed, and ``ensure_db_schema`` is then run on the connection.

    Args:
        path: Filesystem path of the SQLite database file.

    Returns:
        An open connection. The caller is responsible for closing it.

    Raises:
        sqlite3.Error: If the table cannot be created or the schema check
            fails. The connection is closed before the error propagates.
    """
    conn = sqlite3.connect(path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS pixif_cache (
                post_id TEXT PRIMARY KEY,
                url TEXT,
                exif_type INTEGER
            )
            """
        )
        conn.commit()
        ensure_db_schema(conn)
    except Exception:
        # Setup failed (for example the file is not a database): release the
        # handle instead of leaking it, then let the caller see the error.
        conn.close()
        raise
    return conn
|
|
def ensure_db_schema(conn: sqlite3.Connection) -> None:
    """Add the ``exif_type`` column to ``pixif_cache`` when it is missing.

    Args:
        conn: Open connection to a database that contains ``pixif_cache``.

    The ``ALTER TABLE`` is committed; nothing is executed or committed when
    the column is already present.
    """
    # Field 1 of every ``PRAGMA table_info`` row is the column name.
    existing = {info[1] for info in conn.execute("PRAGMA table_info(pixif_cache)")}
    if "exif_type" in existing:
        return
    conn.execute("ALTER TABLE pixif_cache ADD COLUMN exif_type INTEGER")
    conn.commit()
|
|
def determine_exif_type(metadata: Optional[bytes]) -> Optional[str]:
    """Map a text-chunk payload to a metadata kind name.

    Args:
        metadata: Payload bytes, or ``None`` when no payload was found.

    Returns:
        ``None`` for ``None`` input. Otherwise the first matching label, with
        the rules tested in this order:

        * ``"novelai"``   - payload equals ``b"TitleAI generated image"``
        * ``"sd"``        - payload starts with ``b"parameter"``
        * ``"comfy"``     - payload contains ``b'{"'``
        * ``"celsys"``    - payload starts with ``b"SoftwareCelsys"``
        * ``"photoshop"`` - anything else, including an empty payload
    """
    if metadata is None:
        return None

    if metadata == b"TitleAI generated image":
        kind = "novelai"
    elif metadata.startswith(b"parameter"):
        kind = "sd"
    elif b'{"' in metadata:
        kind = "comfy"
    elif metadata.startswith(b"SoftwareCelsys"):
        kind = "celsys"
    else:
        # Fallback label for every payload that matched no rule above.
        kind = "photoshop"
    return kind
|
|
def exif_type_to_code(exif_type: Optional[str]) -> Optional[int]:
    """Translate a metadata kind name into its integer code.

    Args:
        exif_type: Kind name, or ``None``/empty when no kind was determined.

    Returns:
        The code registered in ``EXIF_TYPE_TO_CODE``, or ``None`` when
        *exif_type* is falsy or not a known name.
    """
    return EXIF_TYPE_TO_CODE.get(exif_type) if exif_type else None
|
|
def parse_png_metadata(data: bytes) -> Optional[bytes]:
    """Return the payload of the first ``tEXt`` or ``iTXt`` chunk in *data*.

    Scanning starts at byte offset 8 (the signature is not validated here)
    and walks the chunks in order. *data* may be a truncated prefix of a
    file: a payload that runs past the end of *data* is returned cut short.

    Args:
        data: Raw bytes beginning with the 8-byte PNG signature.

    Returns:
        For ``tEXt`` the payload with all NUL bytes removed; for ``iTXt`` the
        payload with leading/trailing whitespace stripped; ``None`` when no
        such chunk starts within *data*.
    """
    total = len(data)
    offset = 8
    # A chunk header is 8 bytes: 4-byte big-endian length, then 4-byte type.
    while offset + 8 <= total:
        length = int.from_bytes(data[offset:offset + 4], "big")
        kind = data[offset + 4:offset + 8]
        body_start = offset + 8
        if kind == b"tEXt":
            return data[body_start:body_start + length].replace(b"\0", b"")
        if kind == b"iTXt":
            return data[body_start:body_start + length].strip()
        # Jump over the payload plus the 4 bytes that follow every chunk.
        offset = body_start + length + 4
    return None
|
|
def parse_png_metadata_file(path: str) -> Optional[bytes]:
    """Read the start of the file at *path* and extract its first text payload.

    Only the first ``EXIF_METADATA_MAX_BYTES`` bytes are read, so a text
    chunk that starts beyond that prefix is not seen.

    Args:
        path: Path of the file to inspect.

    Returns:
        The result of ``parse_png_metadata`` on the prefix, or ``None`` when
        the prefix does not start with ``PNG_SIGNATURE`` or any exception
        occurs (best-effort: errors are swallowed on purpose).
    """
    try:
        with open(path, "rb") as stream:
            prefix = stream.read(EXIF_METADATA_MAX_BYTES)
        return parse_png_metadata(prefix) if prefix.startswith(PNG_SIGNATURE) else None
    except Exception:
        return None
|
|
def byteize(alpha: np.ndarray) -> np.ndarray:
    """Pack the least-significant bit of every element of *alpha* into bytes.

    The array is transposed before flattening, so a 2-D input is read column
    by column. Trailing elements that do not fill a whole group of eight are
    dropped.

    Args:
        alpha: Integer array (the code applies ``& 1`` and ``np.packbits``).

    Returns:
        Array of shape ``(n, 1)`` holding one packed byte per group of eight
        consecutive bits, the first bit of a group being the most significant
        (``np.packbits`` default bit order).
    """
    flat = alpha.T.reshape(-1)
    usable = flat.shape[0] - flat.shape[0] % 8
    bits = (flat[:usable] & 1).reshape(-1, 8)
    return np.packbits(bits, axis=1)
|
|
class LSBExtractor:
    """Sequential reader over the bytes produced by ``byteize`` for an array."""

    def __init__(self, alpha: np.ndarray) -> None:
        """Pack *alpha* with ``byteize`` and position the reader at byte 0."""
        self.data = byteize(alpha)
        self.pos = 0

    def get_next_n_bytes(self, n: int) -> bytearray:
        """Return the next *n* bytes and advance the read position by *n*.

        Fewer than *n* bytes are returned when the data runs out; the
        position is still advanced by the full *n*.
        """
        end = self.pos + n
        window = self.data[self.pos:end]
        self.pos = end
        return bytearray(window)

    def read_32bit_integer(self) -> Optional[int]:
        """Read four bytes as a big-endian integer.

        Returns:
            The decoded integer, or ``None`` when fewer than four bytes
            remained.
        """
        raw = self.get_next_n_bytes(4)
        if len(raw) != 4:
            return None
        return int.from_bytes(raw, byteorder="big")
|
|
def extract_stealth_metadata(image: Image.Image) -> bool:
    """Check whether the alpha channel's low bits start with the stealth header.

    The low bits of the ``A`` band are packed into bytes via ``LSBExtractor``;
    the first bytes must decode to ``"stealth_pngcomp"`` and be followed by a
    complete 32-bit length field. The payload itself is not read.

    Args:
        image: Opened image to inspect.

    Returns:
        ``True`` when the header is present. The function never returns
        ``False``; every failure is reported by raising.

    Raises:
        AssertionError: ``"image format"`` when there is no ``A`` band,
            ``"magic number"`` when the marker differs, ``"length missing"``
            when the length field is incomplete.
        UnicodeDecodeError: When the marker bytes are not valid UTF-8.
    """
    if "A" not in image.getbands():
        raise AssertionError("image format")

    reader = LSBExtractor(np.array(image.getchannel("A")))
    expected = "stealth_pngcomp"
    found = reader.get_next_n_bytes(len(expected)).decode("utf-8")
    if found != expected:
        raise AssertionError("magic number")
    if reader.read_32bit_integer() is None:
        raise AssertionError("length missing")
    return True
|
|
def has_stealth_png_path(path: str) -> bool:
    """Report whether the image at *path* passes ``extract_stealth_metadata``.

    Args:
        path: Path of the image file to open.

    Returns:
        The result of ``extract_stealth_metadata`` for the opened image, or
        ``False`` when opening or checking raises any exception.
    """
    try:
        with Image.open(path) as opened:
            detected = extract_stealth_metadata(opened)
    except Exception:
        # Unreadable file or failed header check: treat as "not present".
        return False
    return detected
|
|
def detect_exif_code_from_path(path: str) -> Optional[int]:
    """Determine the metadata kind code for the image file at *path*.

    The text-chunk payload is tried first; the alpha-channel check via
    ``has_stealth_png_path`` runs only when that yields no code.

    Args:
        path: Path of the image file.

    Returns:
        The integer code of the detected kind, or ``None`` when neither
        check finds anything.
    """
    text_code = exif_type_to_code(determine_exif_type(parse_png_metadata_file(path)))
    if text_code is not None:
        return text_code
    return EXIF_TYPE_TO_CODE.get("stealth") if has_stealth_png_path(path) else None
|
|
def fetch_pending_post_ids(conn: sqlite3.Connection) -> List[str]:
    """List the post ids in ``pixif_cache`` that still need a kind code.

    A row qualifies when ``exif_type`` is NULL and ``url`` is neither NULL
    nor empty.

    Args:
        conn: Open connection to the cache database.

    Returns:
        The matching ``post_id`` values converted to ``str``, in the order
        the query returns them.
    """
    cursor = conn.execute(
        """
        SELECT post_id
        FROM pixif_cache
        WHERE exif_type IS NULL
        AND COALESCE(url, '') != ''
        """
    )
    return [str(post_id) for (post_id,) in cursor]
|
|
def update_exif_types(conn: sqlite3.Connection, rows: Sequence[Tuple[int, str]]) -> None:
    """Write kind codes into ``pixif_cache``.

    Args:
        conn: Open connection to the cache database.
        rows: ``(exif_type, post_id)`` pairs, in that order. An empty
            sequence is a no-op.

    No commit is issued here; the caller controls the transaction.
    """
    statement = """
        UPDATE pixif_cache SET exif_type = ?
        WHERE post_id = ?
        """
    if rows:
        conn.executemany(statement, rows)
|
|
def detect_exif_codes_from_files(
    post_ids: Sequence[str],
    stash_dir: str,
    max_workers: int = MAX_WORKERS,
) -> Dict[str, Optional[int]]:
    """Run ``detect_exif_code_from_path`` over many images using a thread pool.

    Args:
        post_ids: Ids to scan; each maps to ``<stash_dir>/<post_id>.png``.
        stash_dir: Directory that holds the image files.
        max_workers: Upper bound on worker threads.

    Returns:
        Mapping from post id to its detected code. The value is ``None``
        when nothing was detected or the worker raised. Empty input yields
        an empty mapping.
    """
    results: Dict[str, Optional[int]] = {}
    if not post_ids:
        return results

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which post id every submitted job belongs to.
        pending = {}
        for post_id in post_ids:
            image_path = os.path.join(stash_dir, f"{post_id}.png")
            pending[pool.submit(detect_exif_code_from_path, image_path)] = post_id

        with tqdm(total=len(pending), unit="image", desc="Scanning exif") as progress:
            for finished in as_completed(pending):
                owner = pending[finished]
                try:
                    results[owner] = finished.result()
                except Exception:
                    # A failed scan is recorded as "no code", not raised.
                    results[owner] = None
                progress.update(1)
    return results
|
|
def main() -> int:
    """Scan stashed images for pending cache rows and store detected codes.

    Steps: make sure ``STASH_DIR`` exists, open the database, load the
    pending post ids, keep those with a ``<post_id>.png`` file in the stash,
    detect their codes, and write the non-``None`` codes back in a single
    transaction.

    Returns:
        Always ``0``. The connection is closed on every exit path.
    """
    os.makedirs(STASH_DIR, exist_ok=True)
    conn = open_db(DB_PATH)
    try:
        pending = fetch_pending_post_ids(conn)
        if not pending:
            print("No pending rows.")
            return 0

        on_disk = [
            pid
            for pid in pending
            if os.path.exists(os.path.join(STASH_DIR, f"{pid}.png"))
        ]
        if not on_disk:
            print("No matching images in stash.")
            return 0

        detected = detect_exif_codes_from_files(on_disk, STASH_DIR)
        # Rows are (exif_type, post_id), the order update_exif_types expects.
        rows = [(code, pid) for pid, code in detected.items() if code is not None]
        if rows:
            # "with conn" commits on success and rolls back on error.
            with conn:
                update_exif_types(conn, rows)
        print(f"Updated {len(rows)} rows.")
        return 0
    finally:
        conn.close()
|
|
# When run as a script, call main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|