# File size: 6,938 Bytes
# 52fadc8
import os
import sqlite3
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
from PIL import Image
from tqdm import tqdm
# Filesystem layout, resolved relative to the parent of this file's directory.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
IMAGES_DIR = os.path.join(ROOT_DIR, "images")
# Directory that holds the "<post_id>.png" files scanned by this script.
STASH_DIR = os.path.join(IMAGES_DIR, "Stash")
# SQLite database containing the pixif_cache table.
DB_PATH = os.path.join(ROOT_DIR, "db.sqlite")
# Thread-pool size: never more than 16; 8 is assumed when the CPU count is unknown.
MAX_WORKERS = min(16, os.cpu_count() or 8)
# Number of leading bytes read from each file when looking for a text chunk.
EXIF_METADATA_MAX_BYTES = 512
# Known metadata kinds; the integer code stored in the database is the
# 1-based position of the name in this tuple.
EXIF_TYPE_ORDER = ("novelai", "sd", "comfy", "mj", "celsys", "photoshop", "stealth")
EXIF_TYPE_TO_CODE = {name: idx + 1 for idx, name in enumerate(EXIF_TYPE_ORDER)}
# Eight-byte signature every PNG file starts with.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
def open_db(path: str) -> sqlite3.Connection:
    """Connect to the SQLite database at *path* and prepare ``pixif_cache``.

    The table is created when it does not exist yet, the creation is
    committed, and ``ensure_db_schema`` is then run so that tables created by
    an older layout also gain the ``exif_type`` column.

    Returns:
        The open connection; the caller is responsible for closing it.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS pixif_cache (
            post_id TEXT PRIMARY KEY,
            url TEXT,
            exif_type INTEGER
        )
        """
    connection = sqlite3.connect(path)
    connection.execute(create_table_sql)
    connection.commit()
    ensure_db_schema(connection)
    return connection
def ensure_db_schema(conn: sqlite3.Connection) -> None:
    """Add the ``exif_type`` column to ``pixif_cache`` when it is missing.

    The column list is read from ``PRAGMA table_info``, whose second field is
    the column name. When the column has to be added, the change is committed.
    """
    column_names = {info[1] for info in conn.execute("PRAGMA table_info(pixif_cache)")}
    if "exif_type" not in column_names:
        conn.execute("ALTER TABLE pixif_cache ADD COLUMN exif_type INTEGER")
        conn.commit()
def determine_exif_type(metadata: Optional[bytes]) -> Optional[str]:
    """Classify the content of a PNG text chunk by its leading bytes.

    Args:
        metadata: Chunk content as produced by ``parse_png_metadata``, or
            ``None`` when no text chunk was found.

    Returns:
        ``None`` when *metadata* is ``None``; otherwise the name of the first
        matching rule, checked in this order: ``"novelai"``, ``"sd"``,
        ``"comfy"``, ``"celsys"``. Anything else is reported as
        ``"photoshop"``.
    """
    if metadata is None:
        return None
    # Order matters: the first rule that matches decides the result.
    rules = (
        (metadata == b"TitleAI generated image", "novelai"),
        (metadata.startswith(b"parameter"), "sd"),
        (b'{"' in metadata, "comfy"),
        (metadata.startswith(b"SoftwareCelsys"), "celsys"),
    )
    return next((label for matched, label in rules if matched), "photoshop")
def exif_type_to_code(exif_type: Optional[str]) -> Optional[int]:
    """Translate a metadata type name into its integer code.

    Returns ``None`` for ``None``, for an empty name, and for names that are
    not keys of ``EXIF_TYPE_TO_CODE``.
    """
    return EXIF_TYPE_TO_CODE.get(exif_type) if exif_type else None
def parse_png_metadata(data: bytes) -> Optional[bytes]:
    """Return the content of the first ``tEXt`` or ``iTXt`` chunk in *data*.

    *data* is expected to begin with the 8-byte PNG signature, which is
    skipped without being checked. Chunks are then walked using the layout
    4-byte big-endian length, 4-byte type, payload, 4-byte trailer.

    Args:
        data: Leading bytes of a PNG file; may be cut off mid-chunk.

    Returns:
        For ``tEXt``: the payload with every NUL byte removed. For ``iTXt``:
        the payload with surrounding ASCII whitespace stripped. A payload that
        runs past the end of *data* is returned truncated. ``None`` when no
        complete chunk header of either type is found.
    """
    header_size = 8
    total = len(data)
    offset = 8  # skip the PNG signature
    while offset + header_size <= total:
        payload_len = int.from_bytes(data[offset:offset + 4], "big")
        kind = data[offset + 4:offset + header_size]
        payload_start = offset + header_size
        if kind in (b"tEXt", b"iTXt"):
            payload = data[payload_start:payload_start + payload_len]
            if kind == b"tEXt":
                return payload.replace(b"\0", b"")
            return payload.strip()
        # Jump over the payload and the 4 bytes that follow it.
        offset = payload_start + payload_len + 4
    return None
def parse_png_metadata_file(path: str) -> Optional[bytes]:
    """Read the start of the file at *path* and extract its first text chunk.

    Only the first ``EXIF_METADATA_MAX_BYTES`` bytes are read. Files that do
    not start with the PNG signature yield ``None``. Any exception raised
    while opening, reading or parsing is swallowed and also yields ``None``,
    so this is a best-effort probe.
    """
    try:
        with open(path, "rb") as stream:
            prefix = stream.read(EXIF_METADATA_MAX_BYTES)
            looks_like_png = prefix.startswith(PNG_SIGNATURE)
            return parse_png_metadata(prefix) if looks_like_png else None
    except Exception:
        return None
def byteize(alpha: np.ndarray) -> np.ndarray:
    """Pack the least-significant bit of each value in *alpha* into bytes.

    The array is transposed and flattened first, so a 2-D input is read
    column by column. Trailing values that do not fill a whole byte are
    dropped. Within each group of eight values the first one becomes the most
    significant bit.

    Returns:
        An array of shape ``(n_bytes, 1)`` holding the packed bytes.
    """
    flat = alpha.T.reshape((-1,))
    usable = (flat.shape[0] // 8) * 8
    low_bits = np.bitwise_and(flat[:usable], 1)
    return np.packbits(low_bits.reshape((-1, 8)), axis=1)
class LSBExtractor:
    """Sequential reader over the bytes packed from an alpha channel's low bits."""

    def __init__(self, alpha: np.ndarray) -> None:
        """Pack *alpha* with ``byteize`` and start reading at the first byte."""
        self.pos = 0
        self.data = byteize(alpha)

    def get_next_n_bytes(self, n: int) -> bytearray:
        """Return up to *n* bytes from the current position and advance by *n*.

        The position moves forward by *n* even when fewer bytes are left, in
        which case the returned buffer is shorter than requested.
        """
        start = self.pos
        self.pos = start + n
        return bytearray(self.data[start:self.pos])

    def read_32bit_integer(self) -> Optional[int]:
        """Read four bytes as a big-endian integer, or ``None`` if too few remain."""
        raw = self.get_next_n_bytes(4)
        if len(raw) != 4:
            return None
        return int.from_bytes(raw, byteorder="big")
def extract_stealth_metadata(image: Image.Image) -> bool:
    """Check whether *image* carries a stealth payload in its alpha channel.

    The low bits of the alpha channel are read through ``LSBExtractor``; the
    stream must start with the magic ``stealth_pngcomp`` followed by a 32-bit
    length field. The payload itself is not read.

    Args:
        image: An opened image; it must have an ``"A"`` band.

    Returns:
        ``True`` when the magic and the length field are both present. The
        function never returns ``False``; every failure is raised instead.

    Raises:
        AssertionError: ``"image format"`` when there is no alpha band,
            ``"magic number"`` when the leading bytes are not the magic, and
            ``"length missing"`` when fewer than four bytes follow it.
    """
    if "A" not in image.getbands():
        raise AssertionError("image format")
    reader = LSBExtractor(np.array(image.getchannel("A")))
    magic = b"stealth_pngcomp"
    # Compare raw bytes rather than decoding first: the low bits of an
    # ordinary image are arbitrary and often not valid UTF-8, which used to
    # surface as UnicodeDecodeError instead of the "magic number" failure.
    if bytes(reader.get_next_n_bytes(len(magic))) != magic:
        raise AssertionError("magic number")
    if reader.read_32bit_integer() is None:
        raise AssertionError("length missing")
    return True
def has_stealth_png_path(path: str) -> bool:
    """Report whether the image file at *path* contains a stealth payload.

    The file is opened with ``Image.open`` and handed to
    ``extract_stealth_metadata``. Any exception, whether the file cannot be
    opened or the check itself fails, is turned into ``False``.
    """
    try:
        with Image.open(path) as opened:
            found = extract_stealth_metadata(opened)
    except Exception:
        return False
    return found
def detect_exif_code_from_path(path: str) -> Optional[int]:
    """Determine the metadata type code for the image file at *path*.

    A code derived from the file's first text chunk takes precedence. Only
    when that yields nothing is the alpha channel probed for a stealth
    payload.

    Returns:
        The integer code, or ``None`` when neither probe finds anything.
    """
    text_code = exif_type_to_code(determine_exif_type(parse_png_metadata_file(path)))
    if text_code is None and has_stealth_png_path(path):
        return EXIF_TYPE_TO_CODE.get("stealth")
    return text_code
def fetch_pending_post_ids(conn: sqlite3.Connection) -> List[str]:
    """List the post ids that still need a metadata type.

    A row is pending when its ``exif_type`` is NULL and its ``url`` is neither
    NULL nor empty. Ids are returned as strings in the order SQLite yields
    them.
    """
    cursor = conn.execute(
        """
        SELECT post_id
        FROM pixif_cache
        WHERE exif_type IS NULL
        AND COALESCE(url, '') != ''
        """
    )
    return [str(post_id) for (post_id,) in cursor]
def update_exif_types(conn: sqlite3.Connection, rows: Sequence[Tuple[int, str]]) -> None:
    """Write metadata type codes back to ``pixif_cache``.

    Args:
        conn: Open database connection. Nothing is committed here; the caller
            decides when to commit.
        rows: ``(exif_type, post_id)`` pairs. An empty sequence is a no-op.
    """
    if rows:
        conn.executemany(
            """
            UPDATE pixif_cache SET exif_type = ?
            WHERE post_id = ?
            """,
            rows,
        )
def detect_exif_codes_from_files(
    post_ids: Sequence[str],
    stash_dir: str,
    max_workers: int = MAX_WORKERS,
) -> Dict[str, Optional[int]]:
    """Detect the metadata type code of many stash images in parallel.

    Each post id is mapped to ``<stash_dir>/<post_id>.png`` and scanned with
    ``detect_exif_code_from_path`` on a thread pool while a progress bar
    tracks completion.

    Args:
        post_ids: Ids of the images to scan.
        stash_dir: Directory containing the PNG files.
        max_workers: Size of the thread pool.

    Returns:
        A dict from post id to detected code. The value is ``None`` when
        nothing was detected or when the scan of that file raised.
    """
    results: Dict[str, Optional[int]] = {}
    if not post_ids:
        return results
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Remember which post id each future belongs to.
        owner_of = {}
        for post_id in post_ids:
            image_path = os.path.join(stash_dir, f"{post_id}.png")
            owner_of[pool.submit(detect_exif_code_from_path, image_path)] = post_id
        with tqdm(total=len(owner_of), unit="image", desc="Scanning exif") as progress:
            for finished in as_completed(owner_of):
                try:
                    outcome = finished.result()
                except Exception:
                    # One unreadable file must not abort the whole scan.
                    outcome = None
                results[owner_of[finished]] = outcome
                progress.update(1)
    return results
def main() -> int:
    """Fill in ``exif_type`` for pending rows whose image exists in the stash.

    The stash directory is created if needed and the database is opened. Rows
    that are still pending and have a matching ``<post_id>.png`` are scanned,
    and every detected code is written back in a single transaction. The
    connection is always closed on the way out.

    Returns:
        ``0`` in every non-raising case, for use as the process exit status.
    """
    os.makedirs(STASH_DIR, exist_ok=True)
    conn = open_db(DB_PATH)
    try:
        pending_ids = fetch_pending_post_ids(conn)
        if not pending_ids:
            print("No pending rows.")
            return 0

        on_disk = []
        for post_id in pending_ids:
            if os.path.exists(os.path.join(STASH_DIR, f"{post_id}.png")):
                on_disk.append(post_id)
        if not on_disk:
            print("No matching images in stash.")
            return 0

        detected = detect_exif_codes_from_files(on_disk, STASH_DIR)
        # Rows without a detected code stay NULL so a later run retries them.
        rows = []
        for post_id, code in detected.items():
            if code is not None:
                rows.append((code, post_id))
        if rows:
            # The connection context manager commits on success and rolls
            # back if the update raises.
            with conn:
                update_exif_types(conn, rows)
        print(f"Updated {len(rows)} rows.")
        return 0
    finally:
        conn.close()
# When run as a script, execute main() and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|