Spaces:
Running
Running
| import sqlite3 | |
| from functools import partial | |
| from itertools import islice | |
| from pathlib import Path | |
| import click | |
# Number of caption files read and inserted per transaction/progress step.
BATCH_SIZE = 1024
def main(input_dir: Path, output: Path, image_host: str, explicit: bool):
    """Build the caption database at *output* from the text files in *input_dir*.

    Opens the SQLite file *output* (sqlite3 creates it if missing), hands the
    connection to ``_main_with_connection`` and always closes it afterwards,
    even when the import raises.

    :param input_dir: Directory containing the ``*.txt`` caption files.
    :param output: Path of the SQLite database file to write.
    :param image_host: URL prefix for the ``images`` view; falsy to skip the view.
    :param explicit: When false, captions tagged as NSFW are removed.
    """
    db = sqlite3.connect(output)
    try:
        _main_with_connection(input_dir, db, image_host=image_host, explicit=explicit)
    finally:
        db.close()
def _main_with_connection(input_dir: Path, connection: sqlite3.Connection,
                          image_host: "str | None" = None, explicit=True):
    """Populate *connection* with captions read from ``*.txt`` files in *input_dir*.

    Steps:

    1. Create the ``captions`` table. When *image_host* is given, also create
       an ``images`` view that builds ``<image_host><image_key>.jpg`` URLs.
    2. Insert one row per text file, using the file stem as ``image_key`` and
       the file contents as ``caption``. Each batch of ``BATCH_SIZE`` files is
       committed in its own transaction.
    3. Unless *explicit* is true, delete rows whose caption mentions an NSFW
       rating/marker.
    4. Create (if needed) and rebuild the ``captions_fts`` full-text index
       over the remaining rows.

    :param input_dir: Directory whose ``*.txt`` files are imported (not recursive).
    :param connection: Open SQLite connection. It is not closed here.
    :param image_host: Optional URL prefix used to build the ``images`` view.
    :param explicit: Keep captions tagged as explicit/NSFW when true (the default).
    :raises sqlite3.IntegrityError: If a file's stem already exists as an
        ``image_key`` in ``captions`` (e.g. the same directory is imported twice).
    """
    _create_schema(connection, image_host)
    _import_captions(connection, input_dir)
    if not explicit:
        _remove_explicit_captions(connection)
    _build_fts_index(connection)


def _create_schema(connection, image_host):
    """Create the captions table and, if *image_host* is set, the images view."""
    connection.execute("CREATE TABLE IF NOT EXISTS "
                       " captions(image_key text PRIMARY KEY, caption text NOT NULL);")
    if image_host:
        # SQLite does not allow bound parameters in a view definition, so the
        # host is inlined as a literal quoted by SQLite itself.
        connection.execute(f"""
    CREATE VIEW IF NOT EXISTS images AS
    SELECT {sql_quote(connection, image_host)} || image_key || '.jpg' AS image,
           caption,
           rowid
    FROM captions
    """)


def _import_captions(connection, input_dir):
    """Insert one captions row per ``*.txt`` file, BATCH_SIZE files per transaction."""
    text_files = input_dir.glob("*.txt")
    with click.progressbar(chunked(text_files, BATCH_SIZE)) as progress:
        for batch in progress:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            rows = ((path.stem, path.read_text(encoding="utf-8")) for path in batch)
            # The connection context manager commits the batch on success and
            # rolls it back if an exception escapes.
            with connection:
                connection.executemany("INSERT INTO captions(image_key, caption) "
                                       "VALUES(?, ?) ", rows)


def _remove_explicit_captions(connection):
    """Delete captions containing any NSFW marker and print a count per marker."""
    # Each marker is matched as LIKE '%<marker>%'. The '%' inside
    # "subreddit:%nsfw" is therefore a wildcard: it matches captions where
    # "subreddit:" is followed, anywhere later, by "nsfw".
    ratings = ["rating:unsafe", "rating:explicit", "rating:mature", "meta:nsfw",
               "subreddit:%nsfw"]
    for rating in ratings:
        with connection:
            c = connection.execute("DELETE FROM captions WHERE caption LIKE ?",
                                   (f"%{rating}%",))
            print(f"Removed {c.rowcount} {rating} rows")


def _build_fts_index(connection):
    """Create (if needed) and fully rebuild the captions_fts FTS5 index."""
    with connection:
        # IF NOT EXISTS lets the script run against an existing database,
        # matching how the captions table is created.
        connection.execute("""CREATE VIRTUAL TABLE IF NOT EXISTS
            captions_fts USING
            fts5(caption, image_key UNINDEXED, content=captions)
            """)
        # captions_fts is an external-content table backed by captions. The
        # special 'rebuild' command discards the whole index and re-indexes
        # every current captions row. A plain INSERT ... SELECT would instead
        # add duplicate entries for rows that were already indexed.
        connection.execute("INSERT INTO captions_fts(captions_fts) VALUES('rebuild')")
def chunked(iterable, n):
    """Split *iterable* into consecutive lists of at most *n* items.

    The final list may be shorter than *n*. Items are consumed lazily, so only
    one batch is materialized at a time.

    :param iterable: Any iterable; it is iterated exactly once.
    :param n: Maximum batch size; must be at least 1.
    :return: An iterator of non-empty lists.
    :raises ValueError: If *n* is less than 1. A batch size of 0 would
        otherwise yield nothing and silently drop every item.
    """
    if n < 1:
        raise ValueError(f"batch size must be at least 1, got {n!r}")
    iterator = iter(iterable)
    # iter(callable, sentinel) keeps calling until an empty batch comes back,
    # i.e. until the underlying iterator is exhausted.
    return iter(lambda: list(islice(iterator, n)), [])
def take(n, iterable):
    """Return (at most) the first *n* items of *iterable* as a list.

    When *iterable* is an iterator, exactly the returned items are consumed
    from it.
    """
    head = islice(iterable, n)
    return list(head)
def sql_quote(connection, value: str) -> str:
    """
    Return *value* as a SQLite string literal, surrounding single quotes included.

    The escaping is delegated to SQLite's own ``quote()`` function (run on
    *connection*). This is for the places where bound parameters cannot be
    used, e.g. inside a view definition or a column's ``DEFAULT 'value'`` clause.

    :param connection: Open ``sqlite3`` connection, used only to run the query.
    :param value: String to quote
    :return: The quoted literal, e.g. ``'it''s'`` for ``it's``.
    """
    cursor = connection.execute("SELECT quote(:value)", {"value": value})
    row = cursor.fetchone()
    return row[0]
if __name__ == "__main__":
    # main() takes four required arguments, so they must be parsed from the
    # command line before calling it. click (already imported) does the
    # parsing, and main itself stays a plain, directly callable function.
    @click.command()
    @click.argument("input_dir", type=click.Path(exists=True, file_okay=False))
    @click.argument("output", type=click.Path(dir_okay=False))
    @click.option("--image-host", default=None,
                  help="URL prefix for the optional `images` view "
                       "(image_key and '.jpg' are appended).")
    @click.option("--explicit/--no-explicit", default=True,
                  help="Keep captions tagged as explicit/NSFW (default: keep).")
    def _cli(input_dir, output, image_host, explicit):
        """Load INPUT_DIR/*.txt captions into the SQLite database OUTPUT."""
        main(Path(input_dir), Path(output), image_host, explicit)

    _cli()