p
File size: 6,938 Bytes
52fadc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
from PIL import Image
from tqdm import tqdm

# Filesystem layout: the project root is the parent of this script's directory.
_SCRIPT_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.abspath(os.path.join(_SCRIPT_DIR, os.pardir))
IMAGES_DIR = os.path.join(ROOT_DIR, "images")
STASH_DIR = os.path.join(IMAGES_DIR, "Stash")
DB_PATH = os.path.join(ROOT_DIR, "db.sqlite")

# Thread-pool size: never more than 16; 8 is the fallback base when the CPU
# count cannot be determined.
MAX_WORKERS = min(16, os.cpu_count() or 8)

# Number of leading bytes read from each PNG when looking for a text chunk.
EXIF_METADATA_MAX_BYTES = 512

# Known metadata labels. A label's integer code is its 1-based position in
# this tuple, so reordering entries changes the codes.
EXIF_TYPE_ORDER = ("novelai", "sd", "comfy", "mj", "celsys", "photoshop", "stealth")
EXIF_TYPE_TO_CODE = {
    label: code for code, label in enumerate(EXIF_TYPE_ORDER, start=1)
}

# The fixed 8-byte prefix that a file must start with to be treated as PNG.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

def open_db(path: str) -> sqlite3.Connection:
    """Open the SQLite database at *path* and prepare the cache table.

    Creates the ``pixif_cache`` table if it does not exist, commits, and then
    calls ``ensure_db_schema`` so a table created without ``exif_type`` gains
    that column.

    Returns:
        The open connection; this function does not close it.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS pixif_cache (
            post_id TEXT PRIMARY KEY,
            url TEXT,
            exif_type INTEGER
        )
        """
    connection = sqlite3.connect(path)
    connection.execute(create_table_sql)
    connection.commit()
    ensure_db_schema(connection)
    return connection

def ensure_db_schema(conn: sqlite3.Connection) -> None:
    """Add the ``exif_type`` column to ``pixif_cache`` when it is missing.

    Does nothing when the column is already present. When it is added, the
    change is committed before returning.
    """
    table_info = conn.execute("PRAGMA table_info(pixif_cache)")
    # Each row returned by PRAGMA table_info holds the column name at index 1.
    existing_columns = [info[1] for info in table_info]
    if "exif_type" in existing_columns:
        return
    conn.execute("ALTER TABLE pixif_cache ADD COLUMN exif_type INTEGER")
    conn.commit()

def determine_exif_type(metadata: Optional[bytes]) -> Optional[str]:
    """Map the bytes of a PNG text chunk to a metadata label.

    The checks run in a fixed order and the first match wins, so content that
    contains ``{"`` is labelled "comfy" even if it starts with
    ``SoftwareCelsys``.

    Args:
        metadata: Text-chunk content, or None when no chunk was found.

    Returns:
        None when *metadata* is None; otherwise one of "novelai", "sd",
        "comfy", "celsys" or "photoshop". "photoshop" is the fallback for
        every other non-None value, including empty bytes.
    """
    if metadata is None:
        return None

    if metadata == b"TitleAI generated image":
        label = "novelai"
    elif metadata.startswith(b"parameter"):
        label = "sd"
    elif b'{"' in metadata:
        label = "comfy"
    elif metadata.startswith(b"SoftwareCelsys"):
        label = "celsys"
    else:
        label = "photoshop"
    return label

def exif_type_to_code(exif_type: Optional[str]) -> Optional[int]:
    """Translate a metadata label into its integer code.

    Returns:
        The code from ``EXIF_TYPE_TO_CODE``, or None when *exif_type* is
        None, empty, or not a known label.
    """
    return EXIF_TYPE_TO_CODE.get(exif_type) if exif_type else None

def parse_png_metadata(data: bytes) -> Optional[bytes]:
    """Return the content of the first ``tEXt`` or ``iTXt`` chunk in *data*.

    Scanning starts at offset 8, so *data* is expected to begin with the
    8-byte PNG signature; the signature itself is not checked here. *data*
    may be a truncated prefix of a file: a chunk whose payload runs past the
    end of *data* yields whatever part of the payload is present.

    Returns:
        For ``tEXt``: the payload with all NUL bytes removed.
        For ``iTXt``: the payload with leading/trailing ASCII whitespace
        stripped (embedded NUL bytes are kept).
        None when no such chunk header is found in *data*.
    """
    header_size = 8  # 4-byte big-endian length followed by 4-byte chunk type
    total = len(data)
    offset = 8
    while offset + header_size <= total:
        payload_len = int.from_bytes(data[offset:offset + 4], "big")
        kind = data[offset + 4:offset + header_size]
        payload_start = offset + header_size
        if kind in (b"tEXt", b"iTXt"):
            payload = data[payload_start:payload_start + payload_len]
            if kind == b"tEXt":
                return payload.replace(b"\0", b"")
            return payload.strip()
        # Jump over the payload and the 4 bytes that follow it (the chunk CRC
        # in the PNG layout) to reach the next chunk header.
        offset = payload_start + payload_len + 4
    return None

def parse_png_metadata_file(path: str) -> Optional[bytes]:
    """Read the start of the file at *path* and extract its first text chunk.

    Only the first ``EXIF_METADATA_MAX_BYTES`` bytes are read, so a text
    chunk that starts later in the file is not seen and a long one is
    returned truncated.

    Returns:
        The result of ``parse_png_metadata`` on that prefix, or None when the
        file does not start with the PNG signature or any exception occurs
        (best-effort: errors are deliberately swallowed).
    """
    try:
        with open(path, "rb") as stream:
            prefix = stream.read(EXIF_METADATA_MAX_BYTES)
        if prefix.startswith(PNG_SIGNATURE):
            return parse_png_metadata(prefix)
        return None
    except Exception:
        return None

def byteize(alpha: np.ndarray) -> np.ndarray:
    """Pack the least-significant bits of *alpha* into bytes.

    The array is transposed and flattened, so a 2-D input is walked column by
    column. Trailing elements that do not fill a whole group of eight are
    dropped. In each group the first element becomes the most significant bit
    of the byte.

    Returns:
        A ``uint8`` array of shape ``(k, 1)``, one packed byte per row.
    """
    flat = alpha.T.ravel()
    usable = (flat.shape[0] // 8) * 8
    low_bits = flat[:usable] & 1
    return np.packbits(low_bits.reshape((-1, 8)), axis=1)

class LSBExtractor:
    """Sequential reader over the bytes packed from an alpha channel's low bits.

    ``data`` holds the output of ``byteize``; ``pos`` is the index of the next
    unread byte.
    """

    def __init__(self, alpha: np.ndarray) -> None:
        """Pack *alpha* with ``byteize`` and start reading at the first byte."""
        self.data = byteize(alpha)
        self.pos = 0

    def get_next_n_bytes(self, n: int) -> bytearray:
        """Return up to *n* bytes from the current position.

        The position always advances by *n*, even when fewer bytes remain, in
        which case the returned bytearray is shorter than *n*.
        """
        end = self.pos + n
        window = self.data[self.pos:end]
        self.pos = end
        return bytearray(window)

    def read_32bit_integer(self) -> Optional[int]:
        """Read four bytes as a big-endian integer; None if fewer remain."""
        raw = self.get_next_n_bytes(4)
        if len(raw) != 4:
            return None
        return int.from_bytes(raw, byteorder="big")

def extract_stealth_metadata(image: Image.Image) -> bool:
    """Check whether *image* carries a ``stealth_pngcomp`` header in its alpha LSBs.

    The least-significant bits of the alpha channel are read as a byte
    stream. The stream must start with the ASCII magic ``stealth_pngcomp``
    followed by a 32-bit length. Only the presence of the length is checked;
    the payload itself is not read or validated.

    Args:
        image: An opened PIL image.

    Returns:
        True when the header is present. The function never returns False.

    Raises:
        AssertionError: "image format" when the image has no alpha band,
            "magic number" when the stream does not start with the magic,
            "length missing" when fewer than four bytes follow the magic.
    """
    if "A" not in image.getbands():
        raise AssertionError("image format")
    alpha = np.array(image.getchannel("A"))
    reader = LSBExtractor(alpha)
    magic = b"stealth_pngcomp"
    # Compare raw bytes rather than decoding them. Arbitrary alpha LSBs (for
    # example 0xFF from a fully opaque image) are not valid UTF-8, and
    # decoding them would raise UnicodeDecodeError instead of the
    # "magic number" error this check is meant to report.
    if reader.get_next_n_bytes(len(magic)) != magic:
        raise AssertionError("magic number")
    if reader.read_32bit_integer() is None:
        raise AssertionError("length missing")
    return True

def has_stealth_png_path(path: str) -> bool:
    """Report whether the image file at *path* carries a stealth header.

    Best-effort: any exception raised while opening or inspecting the image
    (missing file, unsupported format, failed header check) yields False.
    """
    try:
        with Image.open(path) as opened:
            detected = extract_stealth_metadata(opened)
    except Exception:
        detected = False
    return detected

def detect_exif_code_from_path(path: str) -> Optional[int]:
    """Determine the metadata code for the PNG file at *path*.

    A code derived from the file's text chunk takes precedence. The alpha
    channel stealth check runs only when the text chunk yields no code.

    Returns:
        The integer code, or None when neither check identifies the file.
    """
    text_chunk = parse_png_metadata_file(path)
    text_code = exif_type_to_code(determine_exif_type(text_chunk))
    if text_code is None and has_stealth_png_path(path):
        return EXIF_TYPE_TO_CODE.get("stealth")
    return text_code

def fetch_pending_post_ids(conn: sqlite3.Connection) -> List[str]:
    """List the post ids in ``pixif_cache`` that still need an exif type.

    A row is pending when its ``exif_type`` is NULL and its ``url`` is neither
    NULL nor empty.

    Returns:
        The matching ``post_id`` values as strings, in the order SQLite
        returns them (the query has no ORDER BY).
    """
    query = """
        SELECT post_id
        FROM pixif_cache
        WHERE exif_type IS NULL
          AND COALESCE(url, '') != ''
        """
    cursor = conn.execute(query)
    pending: List[str] = []
    for record in cursor.fetchall():
        pending.append(str(record[0]))
    return pending

def update_exif_types(conn: sqlite3.Connection, rows: Sequence[Tuple[int, str]]) -> None:
    """Write exif type codes to ``pixif_cache``.

    Args:
        conn: Open connection. This function does not commit.
        rows: ``(exif_type, post_id)`` pairs, in that order to match the
            statement's placeholders. An empty sequence is a no-op.
    """
    statement = """
        UPDATE pixif_cache SET exif_type = ?
        WHERE post_id = ?
        """
    if rows:
        conn.executemany(statement, rows)

def detect_exif_codes_from_files(
    post_ids: Sequence[str],
    stash_dir: str,
    max_workers: int = MAX_WORKERS,
) -> Dict[str, Optional[int]]:
    """Detect metadata codes for many stash images using a thread pool.

    Each post id is resolved to ``<stash_dir>/<post_id>.png`` and passed to
    ``detect_exif_code_from_path``. A tqdm progress bar advances as each file
    completes.

    Args:
        post_ids: Ids of the images to scan.
        stash_dir: Directory containing the ``.png`` files.
        max_workers: Size of the thread pool.

    Returns:
        A mapping from post id to its detected code. The value is None when
        nothing was detected or the worker raised. Empty when *post_ids* is
        empty.
    """
    codes: Dict[str, Optional[int]] = {}
    if not post_ids:
        return codes

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which post id each future belongs to, since completion
        # order is not submission order.
        owner_by_future = {}
        for post_id in post_ids:
            image_path = os.path.join(stash_dir, f"{post_id}.png")
            task = pool.submit(detect_exif_code_from_path, image_path)
            owner_by_future[task] = post_id

        with tqdm(total=len(owner_by_future), unit="image", desc="Scanning exif") as progress:
            for finished in as_completed(owner_by_future):
                owner = owner_by_future[finished]
                try:
                    detected = finished.result()
                except Exception:
                    # A failing worker must not abort the whole scan.
                    detected = None
                codes[owner] = detected
                progress.update(1)
    return codes

def main() -> int:
    """Scan stash images for pending rows and store their exif type codes.

    Steps: make sure the stash directory exists, open the database, collect
    rows that still lack an ``exif_type``, keep those whose PNG is present in
    the stash, detect a code for each, and write the detected codes back in
    one transaction. Rows for which nothing was detected are left unchanged.

    Returns:
        Always 0. The connection is closed on every exit path.
    """
    os.makedirs(STASH_DIR, exist_ok=True)
    conn = open_db(DB_PATH)
    try:
        pending = fetch_pending_post_ids(conn)
        if not pending:
            print("No pending rows.")
            return 0

        on_disk = []
        for post_id in pending:
            if os.path.exists(os.path.join(STASH_DIR, f"{post_id}.png")):
                on_disk.append(post_id)
        if not on_disk:
            print("No matching images in stash.")
            return 0

        detected = detect_exif_codes_from_files(on_disk, STASH_DIR)
        # Parameter order must match the UPDATE statement: (exif_type, post_id).
        updates = []
        for post_id, code in detected.items():
            if code is not None:
                updates.append((code, post_id))
        if updates:
            # Using the connection as a context manager commits on success
            # and rolls back if the update raises.
            with conn:
                update_exif_types(conn, updates)
        print(f"Updated {len(updates)} rows.")
        return 0
    finally:
        conn.close()

# When run as a script, execute main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())