"""Schema introspection and mutation helpers for synthetic database generation."""

from __future__ import annotations

import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path


@dataclass
class MutationResult:
    """Result of applying a single mutation to a database."""

    mutation_name: str
    # Names of the tables the mutation changed, sorted alphabetically.
    tables_affected: list[str]
    # Rows inserted; remap_ids reports the number of remapped primary keys here.
    rows_added: int
    success: bool


@dataclass
class TableSchema:
    """Schema information for a single table."""

    name: str
    columns: list[str]
    # Primary-key columns in key order (the ``pk`` field of PRAGMA table_info).
    pk_columns: list[str]
    # One ``(column, parent_table, parent_column)`` triple per foreign-key column.
    fk_columns: list[tuple[str, str, str]]


@contextmanager
def _open_database(db_path: str | Path) -> Iterator[sqlite3.Connection]:
    """Open ``db_path`` for one unit of work and always close the connection.

    Using ``sqlite3.Connection`` as a context manager only commits (on success)
    or rolls back (on error) the pending transaction; it never closes the
    connection. This helper keeps that behaviour and adds the missing close.
    """
    connection = sqlite3.connect(db_path)
    try:
        with connection:
            yield connection
    finally:
        connection.close()


def _quote_identifier(identifier: str) -> str:
    """Return ``identifier`` as a double-quoted SQL identifier, escaping quotes."""
    escaped = identifier.replace('"', '""')
    return f'"{escaped}"'


def _column_affinity(declared_type: str) -> str:
    """Map a declared column type to an SQLite affinity name.

    Uses SQLite's substring rules (INT -> INTEGER; CHAR/CLOB/TEXT -> TEXT;
    REAL/FLOA/DOUB -> REAL; BLOB -> BLOB; anything else -> NUMERIC).

    NOTE(review): SQLite assigns BLOB affinity to columns with no declared type
    and tests BLOB before REAL. This helper returns NUMERIC for an empty type
    and tests REAL first. As a consequence inject_irrelevant_rows fills untyped
    columns with synthetic text instead of NULL; kept as-is for that behaviour.
    """
    normalized = declared_type.upper()
    if "INT" in normalized:
        return "INTEGER"
    if any(token in normalized for token in ("CHAR", "CLOB", "TEXT")):
        return "TEXT"
    if any(token in normalized for token in ("REAL", "FLOA", "DOUB")):
        return "REAL"
    if "BLOB" in normalized:
        return "BLOB"
    return "NUMERIC"


def _table_info(cursor: sqlite3.Cursor, table_name: str) -> list[tuple]:
    """Return the rows of ``PRAGMA table_info`` for ``table_name``.

    Row layout: (cid, name, type, notnull, dflt_value, pk). An unknown table
    yields an empty list.
    """
    cursor.execute(f"PRAGMA table_info({_quote_identifier(table_name)})")
    return cursor.fetchall()


def _primary_key_columns(table_info: list[tuple]) -> list[str]:
    """Return primary-key column names ordered by their position in the key."""
    ranked = sorted((int(row[5]), str(row[1])) for row in table_info if row[5])
    return [column_name for _, column_name in ranked]


def _foreign_key_columns(
    cursor: sqlite3.Cursor,
    table_name: str,
    pk_cache: dict[str, list[str]],
) -> list[tuple[str, str, str]]:
    """Return ``(column, parent_table, parent_column)`` for each FK column.

    When a constraint omits the parent column list (``REFERENCES parent``),
    SQLite reports ``to`` as NULL and the parent's primary key is the implied
    target. That key column is resolved here by position (``seq``). Rows that
    still cannot be resolved (e.g. the parent table is missing) are skipped.

    ``pk_cache`` maps a parent table name to its ordered PK columns and is
    filled lazily to avoid repeating PRAGMA calls.
    """
    cursor.execute(f"PRAGMA foreign_key_list({_quote_identifier(table_name)})")
    # Row layout: (id, seq, table, from, to, on_update, on_delete, match).
    fk_rows = cursor.fetchall()

    resolved: list[tuple[str, str, str]] = []
    for row in fk_rows:
        seq, parent_table, from_column, to_column = row[1], row[2], row[3], row[4]
        if not (from_column and parent_table):
            continue
        if not to_column:
            parent_name = str(parent_table)
            if parent_name not in pk_cache:
                pk_cache[parent_name] = _primary_key_columns(
                    _table_info(cursor, parent_name)
                )
            parent_pk = pk_cache[parent_name]
            if int(seq) >= len(parent_pk):
                continue
            to_column = parent_pk[int(seq)]
        resolved.append((str(from_column), str(parent_table), str(to_column)))
    return resolved


def get_table_schemas(db_path: str) -> list[TableSchema]:
    """Extract table schema metadata (columns, PKs, and FKs) from a SQLite DB.

    Internal ``sqlite_%`` tables are excluded and results are ordered by table
    name.

    Raises:
        sqlite3.OperationalError: if ``db_path`` does not exist, or wrapping
            any ``sqlite3.DatabaseError`` raised while reading the schema.
    """
    path = Path(db_path)
    # Checked up front because sqlite3.connect would silently create the file.
    if not path.exists():
        raise sqlite3.OperationalError(f"Database does not exist: {db_path}")

    try:
        with _open_database(path) as connection:
            cursor = connection.cursor()
            cursor.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type = 'table' AND name NOT LIKE 'sqlite_%'
                ORDER BY name
                """
            )
            table_names = [row[0] for row in cursor.fetchall()]

            pk_cache: dict[str, list[str]] = {}
            schemas: list[TableSchema] = []
            for table_name in table_names:
                table_info = _table_info(cursor, table_name)
                pk_columns = _primary_key_columns(table_info)
                pk_cache[table_name] = pk_columns
                schemas.append(
                    TableSchema(
                        name=table_name,
                        columns=[row[1] for row in table_info],
                        pk_columns=pk_columns,
                        fk_columns=_foreign_key_columns(cursor, table_name, pk_cache),
                    )
                )
            return schemas
    except sqlite3.DatabaseError as exc:
        raise sqlite3.OperationalError(str(exc)) from exc


def detect_bridge_tables(schemas: list[TableSchema]) -> list[str]:
    """Return tables that look like bridge tables (2 or more foreign keys).

    The count is per foreign-key column, so one composite foreign key that
    spans two columns also qualifies.
    """
    return [schema.name for schema in schemas if len(schema.fk_columns) >= 2]


def _max_numeric_value(
    cursor: sqlite3.Cursor, quoted_table: str, column_name: str
) -> int | None:
    """Return MAX(column) over numeric values truncated to int, or None if none.

    Text and BLOB values are ignored because ``int()`` cannot convert them.
    """
    quoted_column = _quote_identifier(column_name)
    cursor.execute(
        f"SELECT MAX({quoted_column}) FROM {quoted_table}"
        f" WHERE typeof({quoted_column}) IN ('integer', 'real')"
    )
    value = cursor.fetchone()[0]
    return None if value is None else int(value)


def _synthetic_value(
    table_name: str,
    column_name: str,
    affinity: str,
    row_index: int,
    integer_base: int,
) -> object:
    """Return the synthetic value for a plain (non-PK, non-FK) column.

    BLOB-affinity columns get None: no synthetic payload is generated for them.
    """
    if affinity == "INTEGER":
        # Offset well above the current maximum so injected values stand out.
        return integer_base + 1000 + row_index
    if affinity == "REAL":
        return float(row_index + 1)
    if affinity in ("TEXT", "NUMERIC"):
        return f"SYNTHETIC_{table_name}_{column_name}_{row_index}"
    return None


def _inject_into_table(
    cursor: sqlite3.Cursor, schema: TableSchema, n_rows: int
) -> int:
    """Insert ``n_rows`` synthetic rows into one table and return the count.

    Returns 0 without inserting when the single primary-key column is missing
    or not INTEGER-affine, or when a NOT NULL column would get NULL and has no
    DEFAULT to fall back on.
    """
    pk_column = schema.pk_columns[0]
    table_info = _table_info(cursor, schema.name)
    column_by_name = {str(row[1]): row for row in table_info}
    pk_info = column_by_name.get(pk_column)
    if pk_info is None or _column_affinity(str(pk_info[2])) != "INTEGER":
        return 0

    quoted_table = _quote_identifier(schema.name)
    max_pk = _max_numeric_value(cursor, quoted_table, pk_column)
    first_pk = max_pk + 1 if max_pk is not None else 1

    # Each FK column reuses the first value returned from the parent column,
    # or NULL when the parent table is empty.
    fk_targets: dict[str, object] = {}
    for fk_column, ref_table, ref_column in schema.fk_columns:
        cursor.execute(
            f"SELECT {_quote_identifier(ref_column)}"
            f" FROM {_quote_identifier(ref_table)} LIMIT 1"
        )
        first_row = cursor.fetchone()
        fk_targets[fk_column] = None if first_row is None else first_row[0]

    integer_bases: dict[str, int] = {}
    for row in table_info:
        column_name = str(row[1])
        if column_name == pk_column or column_name in fk_targets:
            continue
        if _column_affinity(str(row[2])) == "INTEGER":
            column_max = _max_numeric_value(cursor, quoted_table, column_name)
            integer_bases[column_name] = column_max if column_max is not None else 0

    def value_for(row: tuple, row_index: int) -> object:
        column_name = str(row[1])
        if column_name == pk_column:
            return first_pk + row_index
        if column_name in fk_targets:
            return fk_targets[column_name]
        return _synthetic_value(
            schema.name,
            column_name,
            _column_affinity(str(row[2])),
            row_index,
            integer_bases.get(column_name, 0),
        )

    insert_columns: list[tuple] = []
    for row in table_info:
        # Whether a column yields NULL does not depend on row_index, so probing
        # row 0 decides it for every row.
        if value_for(row, 0) is None and row[3]:
            if row[4] is None:
                return 0
            # Leave the column out so SQLite evaluates its DEFAULT. dflt_value
            # (row[4]) is the expression's SQL text, e.g. "'abc'" or
            # "CURRENT_TIMESTAMP", not a value that can be bound directly.
            continue
        insert_columns.append(row)

    column_sql = ", ".join(_quote_identifier(str(row[1])) for row in insert_columns)
    placeholders = ", ".join("?" for _ in insert_columns)
    cursor.executemany(
        f"INSERT INTO {quoted_table} ({column_sql}) VALUES ({placeholders})",
        [
            [value_for(row, row_index) for row in insert_columns]
            for row_index in range(n_rows)
        ],
    )
    return n_rows


def inject_irrelevant_rows(
    db_path: str,
    schemas: list[TableSchema],
    n_rows: int = 5,
) -> MutationResult:
    """Inject synthetic rows into non-bridge tables with integer primary keys.

    Only tables with exactly one INTEGER-affine primary-key column are touched.
    New PKs continue after the current maximum, and FK columns point at an
    existing parent value. All inserts are committed together, or rolled back
    if any statement raises.
    """
    if n_rows <= 0:
        return MutationResult(
            mutation_name="inject_irrelevant_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    bridge_tables = set(detect_bridge_tables(schemas))
    rows_added = 0
    tables_affected: list[str] = []

    with _open_database(db_path) as connection:
        cursor = connection.cursor()
        for schema in schemas:
            if schema.name in bridge_tables or len(schema.pk_columns) != 1:
                continue
            inserted = _inject_into_table(cursor, schema, n_rows)
            if inserted > 0:
                tables_affected.append(schema.name)
                rows_added += inserted
        connection.commit()

    return MutationResult(
        mutation_name="inject_irrelevant_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )


def _plan_id_remap(
    cursor: sqlite3.Cursor, schemas: list[TableSchema]
) -> dict[str, tuple[str, dict[int, int]]]:
    """Build ``{table: (pk_column, {old_id: new_id})}`` for remappable tables.

    A table qualifies when it has exactly one INTEGER-affine PK column holding
    at least one integer value. Its ids, in ascending order, map onto a
    contiguous block starting at ``max(id) + 1000``. Because every new id is
    greater than every old one, a single UPDATE cannot collide.
    """
    plan: dict[str, tuple[str, dict[int, int]]] = {}
    for schema in schemas:
        if len(schema.pk_columns) != 1:
            continue
        pk_column = schema.pk_columns[0]
        column_by_name = {str(row[1]): row for row in _table_info(cursor, schema.name)}
        pk_info = column_by_name.get(pk_column)
        if pk_info is None or _column_affinity(str(pk_info[2])) != "INTEGER":
            continue

        quoted_table = _quote_identifier(schema.name)
        quoted_pk = _quote_identifier(pk_column)
        # Non-integer values (text, real) cannot be remapped with int keys.
        cursor.execute(
            f"SELECT {quoted_pk} FROM {quoted_table}"
            f" WHERE typeof({quoted_pk}) = 'integer'"
            f" ORDER BY {quoted_pk}"
        )
        source_ids = [int(row[0]) for row in cursor.fetchall()]
        if not source_ids:
            continue
        start_id = max(source_ids) + 1000
        plan[schema.name] = (
            pk_column,
            {old_id: start_id + offset for offset, old_id in enumerate(source_ids)},
        )
    return plan


def _apply_id_mapping(
    cursor: sqlite3.Cursor,
    table_name: str,
    column_name: str,
    mapping: dict[int, int],
) -> int:
    """Rewrite ``column_name`` in ``table_name`` through ``mapping``.

    Returns the rowcount reported by the UPDATE. The ids are Python ints, so
    interpolating them into the SQL text is safe.
    """
    quoted_table = _quote_identifier(table_name)
    quoted_column = _quote_identifier(column_name)
    case_parts = " ".join(
        f"WHEN {old_id} THEN {new_id}" for old_id, new_id in mapping.items()
    )
    where_values = ", ".join(str(old_id) for old_id in mapping)
    cursor.execute(
        f"UPDATE {quoted_table}"
        f" SET {quoted_column} = CASE {quoted_column}"
        f" {case_parts} ELSE {quoted_column} END"
        f" WHERE {quoted_column} IN ({where_values})"
    )
    return cursor.rowcount


def remap_ids(db_path: str, schemas: list[TableSchema]) -> MutationResult:
    """Remap integer primary keys and matching foreign keys with a bijection.

    Every FK column whose target is a remapped table's PK column is rewritten
    with the same mapping. The result is validated with ``PRAGMA
    foreign_key_check``, which inspects the whole database. ``rows_added``
    reports how many primary-key values were remapped.

    Raises:
        sqlite3.IntegrityError: if any foreign-key violation remains after
            remapping; all changes are then rolled back.
    """
    tables_affected: set[str] = set()
    rows_updated = 0

    with _open_database(db_path) as connection:
        cursor = connection.cursor()
        remap_plan = _plan_id_remap(cursor, schemas)
        if not remap_plan:
            return MutationResult(
                mutation_name="remap_ids",
                tables_affected=[],
                rows_added=0,
                success=True,
            )

        # Disable enforcement while parents and children are briefly out of
        # sync. This must run before the first UPDATE: SQLite ignores changes
        # to this pragma inside a transaction, and the sqlite3 module opens one
        # implicitly before DML. The setting is per-connection and the
        # connection is closed afterwards, so it is not restored.
        cursor.execute("PRAGMA foreign_keys = OFF")

        for table_name, (pk_column, mapping) in remap_plan.items():
            _apply_id_mapping(cursor, table_name, pk_column, mapping)
            tables_affected.add(table_name)
            rows_updated += len(mapping)

        for child_schema in schemas:
            for fk_column, ref_table, ref_column in child_schema.fk_columns:
                parent_plan = remap_plan.get(ref_table)
                if parent_plan is None:
                    continue
                parent_pk_column, parent_mapping = parent_plan
                if ref_column != parent_pk_column:
                    continue
                if _apply_id_mapping(
                    cursor, child_schema.name, fk_column, parent_mapping
                ) > 0:
                    tables_affected.add(child_schema.name)

        cursor.execute("PRAGMA foreign_key_check")
        fk_violations = cursor.fetchall()
        if fk_violations:
            # Raising inside the connection context rolls the transaction back.
            raise sqlite3.IntegrityError(
                "Foreign key integrity check failed"
                f" after ID remapping: {fk_violations[0]}"
            )
        connection.commit()

    return MutationResult(
        mutation_name="remap_ids",
        tables_affected=sorted(tables_affected),
        rows_added=rows_updated,
        success=True,
    )


def duplicate_bridge_rows(
    db_path: str,
    schemas: list[TableSchema],
    bridge_tables: list[str],
) -> MutationResult:
    """Duplicate bridge-table rows, skipping rows blocked by constraints.

    Each existing row of every listed table that appears in ``schemas`` is
    re-inserted verbatim. Rows rejected with ``sqlite3.IntegrityError`` (e.g.
    by a PRIMARY KEY or UNIQUE constraint) are skipped individually.
    """
    if not bridge_tables:
        return MutationResult(
            mutation_name="duplicate_bridge_rows",
            tables_affected=[],
            rows_added=0,
            success=True,
        )

    schema_names = {schema.name for schema in schemas}
    rows_added = 0
    tables_affected: list[str] = []

    with _open_database(db_path) as connection:
        cursor = connection.cursor()
        for table_name in bridge_tables:
            if table_name not in schema_names:
                continue
            table_info = _table_info(cursor, table_name)
            if not table_info:
                continue

            quoted_table = _quote_identifier(table_name)
            quoted_columns = ", ".join(
                _quote_identifier(str(row[1])) for row in table_info
            )
            placeholders = ", ".join("?" for _ in table_info)
            cursor.execute(f"SELECT {quoted_columns} FROM {quoted_table}")
            existing_rows = cursor.fetchall()

            insert_sql = (
                f"INSERT INTO {quoted_table}"
                f" ({quoted_columns})"
                f" VALUES ({placeholders})"
            )
            inserted_for_table = 0
            for row in existing_rows:
                try:
                    cursor.execute(insert_sql, row)
                except sqlite3.IntegrityError:
                    continue
                inserted_for_table += 1

            if inserted_for_table > 0:
                tables_affected.append(table_name)
                rows_added += inserted_for_table
        connection.commit()

    return MutationResult(
        mutation_name="duplicate_bridge_rows",
        tables_affected=sorted(tables_affected),
        rows_added=rows_added,
        success=True,
    )