"""
OmniFile AI Processor — Pattern Database
==========================================
Source: arabic-ocr-pro/ai/pattern_db.py

Provides a SQLite-based storage for:
- User corrections (original text → corrected text)
- Pattern images (cropped word images + labels)
- Usage statistics
- Training status tracking

The database enables the system to learn from user corrections
and improve OCR accuracy over time through pattern matching.
"""

from __future__ import annotations

import logging
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


class PatternDatabase:
    """SQLite database for storing OCR correction patterns.

    Manages persistent storage of user corrections and word pattern
    images, enabling the system to learn and improve over time.

    The instance can be used as a context manager; leaving the ``with``
    block closes the connection.

    Attributes:
        db_path: Path to the SQLite database file.
        _connection: Active SQLite connection, or ``None`` when no
            connection is currently open.
    """

    def __init__(self, db_path: str | Path = "data/corrections.db") -> None:
        """Initialize the pattern database.

        Creates the parent directory, the database file and the tables
        if they don't exist.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._connection: Optional[sqlite3.Connection] = None
        self._initialize_database()

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create the database connection.

        The connection is opened lazily on first use and then reused
        until :meth:`close` is called.

        Returns:
            Active SQLite connection whose rows are ``sqlite3.Row``.
        """
        if self._connection is None:
            # check_same_thread=False lets other threads reuse this
            # connection; no locking is added here, so callers that share
            # an instance across threads must serialise access themselves.
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
            # Enable WAL mode for better concurrent access
            self._connection.execute("PRAGMA journal_mode=WAL")
        return self._connection

    def _initialize_database(self) -> None:
        """Create database tables if they don't exist.

        Creates:
            - corrections: Stores original → corrected text mappings
            - patterns: Stores word pattern images (BLOB) with labels
            - statistics: Tracks usage statistics
            - training_status: Tracks model training progress
        """
        conn = self._get_connection()
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS corrections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                original_text TEXT NOT NULL,
                corrected_text TEXT NOT NULL,
                engine TEXT DEFAULT '',
                confidence REAL DEFAULT 0.0,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_corrections_original
                ON corrections(original_text);

            CREATE TABLE IF NOT EXISTS patterns (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                label TEXT NOT NULL,
                image_data BLOB,
                image_width INTEGER,
                image_height INTEGER,
                ocr_text TEXT,
                confidence REAL DEFAULT 0.0,
                source_engine TEXT DEFAULT '',
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_patterns_label
                ON patterns(label);

            CREATE TABLE IF NOT EXISTS statistics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                stat_key TEXT NOT NULL UNIQUE,
                stat_value TEXT NOT NULL,
                updated_at TEXT NOT NULL DEFAULT (datetime('now'))
            );

            CREATE TABLE IF NOT EXISTS training_status (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_name TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                total_samples INTEGER DEFAULT 0,
                trained_samples INTEGER DEFAULT 0,
                accuracy REAL DEFAULT 0.0,
                last_trained_at TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now'))
            );
        """)
        conn.commit()
        logger.debug("Database initialized: %s", self.db_path)

    # ------------------------------------------------------------------
    # Corrections CRUD
    # ------------------------------------------------------------------

    def add_correction(
        self,
        original_text: str,
        corrected_text: str,
        engine: str = "",
        confidence: float = 0.0,
    ) -> int:
        """Add a new correction to the database.

        If the same (original, corrected) pair already exists, its use
        count is incremented instead of creating a duplicate. In that
        case ``engine`` and ``confidence`` are not stored and the
        ``total_corrections`` statistic is left unchanged.

        Args:
            original_text: Original (incorrect) OCR text.
            corrected_text: User-provided corrected text.
            engine: OCR engine that produced the original text.
            confidence: Confidence score of the original OCR result.

        Returns:
            Row ID of the correction record (existing or newly created).
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Check if correction already exists
        cursor.execute(
            "SELECT id FROM corrections "
            "WHERE original_text = ? AND corrected_text = ?",
            (original_text, corrected_text),
        )
        existing = cursor.fetchone()

        if existing is not None:
            existing_id = existing["id"]
            # Increment inside SQL rather than writing back a value read
            # earlier, so the counter cannot be overwritten with a stale
            # number by another writer.
            cursor.execute(
                "UPDATE corrections SET use_count = use_count + 1, "
                "last_used_at = datetime('now') WHERE id = ?",
                (existing_id,),
            )
            conn.commit()
            return existing_id

        cursor.execute(
            """INSERT INTO corrections
               (original_text, corrected_text, engine, confidence)
               VALUES (?, ?, ?, ?)""",
            (original_text, corrected_text, engine, confidence),
        )
        conn.commit()
        row_id = cursor.lastrowid

        # Update statistics
        self._increment_stat("total_corrections")

        logger.debug(
            "Added correction: '%s' -> '%s'", original_text, corrected_text
        )
        return row_id

    def get_corrections(
        self,
        limit: int = 1000,
        min_use_count: int = 0,
    ) -> list[dict]:
        """Get stored corrections, most used first.

        Results are ordered by use count (descending), then by creation
        time (descending).

        Args:
            limit: Maximum number of corrections to return.
            min_use_count: Minimum use count filter.

        Returns:
            List of correction dictionaries with the keys ``id``,
            ``original_text``, ``corrected_text``, ``engine``,
            ``confidence``, ``use_count`` and ``created_at``.
        """
        cursor = self._get_connection().cursor()
        cursor.execute(
            """SELECT id, original_text, corrected_text, engine,
                      confidence, use_count, created_at
               FROM corrections
               WHERE use_count >= ?
               ORDER BY use_count DESC, created_at DESC
               LIMIT ?""",
            (min_use_count, limit),
        )
        # sqlite3.Row exposes keys(), so dict(row) yields exactly the
        # selected columns in SELECT order.
        return [dict(row) for row in cursor.fetchall()]

    def find_correction(self, original_text: str) -> Optional[dict]:
        """Look up a correction for specific original text.

        When several corrections share the same original text, the one
        with the highest use count is returned.

        Args:
            original_text: The text to look up.

        Returns:
            Correction dictionary (same keys as :meth:`get_corrections`)
            if found, None otherwise.
        """
        cursor = self._get_connection().cursor()
        cursor.execute(
            """SELECT id, original_text, corrected_text, engine,
                      confidence, use_count, created_at
               FROM corrections
               WHERE original_text = ?
               ORDER BY use_count DESC
               LIMIT 1""",
            (original_text,),
        )
        row = cursor.fetchone()
        return dict(row) if row is not None else None

    def delete_correction(self, correction_id: int) -> bool:
        """Delete a correction record.

        Args:
            correction_id: ID of the correction to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            "DELETE FROM corrections WHERE id = ?", (correction_id,)
        )
        conn.commit()
        deleted = cursor.rowcount > 0
        if deleted:
            logger.debug("Deleted correction id=%s", correction_id)
        return deleted

    # ------------------------------------------------------------------
    # Patterns CRUD
    # ------------------------------------------------------------------

    def add_pattern(
        self,
        label: str,
        image_data: bytes,
        image_width: int,
        image_height: int,
        ocr_text: str = "",
        confidence: float = 0.0,
        source_engine: str = "",
    ) -> int:
        """Add a new word pattern image to the database.

        Stores a cropped word image along with its label (correct text)
        for future pattern matching, and increments the
        ``total_patterns`` statistic.

        Args:
            label: Correct text label for the pattern.
            image_data: Raw image bytes (PNG or JPEG encoded).
            image_width: Width of the pattern image.
            image_height: Height of the pattern image.
            ocr_text: OCR result that produced this pattern.
            confidence: Confidence score of the OCR result.
            source_engine: OCR engine that produced the result.

        Returns:
            Row ID of the pattern record.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute(
            """INSERT INTO patterns
               (label, image_data, image_width, image_height,
                ocr_text, confidence, source_engine)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (
                label,
                image_data,
                image_width,
                image_height,
                ocr_text,
                confidence,
                source_engine,
            ),
        )
        conn.commit()
        row_id = cursor.lastrowid

        self._increment_stat("total_patterns")

        logger.debug(
            "Added pattern: label='%s', size=%s bytes", label, len(image_data)
        )
        return row_id

    def get_patterns(
        self,
        label: Optional[str] = None,
        limit: int = 500,
    ) -> list[dict]:
        """Get stored pattern images, most used first.

        Args:
            label: Optional label filter. Any falsy value (``None`` or
                an empty string) disables the filter.
            limit: Maximum number of patterns to return.

        Returns:
            List of pattern dictionaries with the keys ``id``,
            ``label``, ``image_data``, ``image_width``,
            ``image_height``, ``ocr_text``, ``confidence``,
            ``source_engine`` and ``use_count``.
        """
        query = (
            "SELECT id, label, image_data, image_width, image_height, "
            "ocr_text, confidence, source_engine, use_count "
            "FROM patterns "
        )
        params: tuple = (limit,)
        if label:
            query += "WHERE label = ? "
            params = (label, limit)
        query += "ORDER BY use_count DESC LIMIT ?"

        cursor = self._get_connection().cursor()
        cursor.execute(query, params)
        return [dict(row) for row in cursor.fetchall()]

    def get_unique_labels(self) -> list[str]:
        """Get all unique pattern labels.

        Returns:
            Sorted list of unique label strings.
        """
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT DISTINCT label FROM patterns ORDER BY label")
        return [row["label"] for row in cursor.fetchall()]

    def increment_pattern_use(self, pattern_id: int) -> None:
        """Increment the use count for a pattern.

        Also stamps ``last_used_at`` with the current time. An unknown
        ID updates no rows and is not reported as an error.

        Args:
            pattern_id: ID of the pattern to update.
        """
        conn = self._get_connection()
        conn.execute(
            "UPDATE patterns SET use_count = use_count + 1, "
            "last_used_at = datetime('now') WHERE id = ?",
            (pattern_id,),
        )
        conn.commit()

    def delete_pattern(self, pattern_id: int) -> bool:
        """Delete a pattern record.

        Args:
            pattern_id: ID of the pattern to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        cursor = conn.cursor()
        cursor.execute("DELETE FROM patterns WHERE id = ?", (pattern_id,))
        conn.commit()
        deleted = cursor.rowcount > 0
        if deleted:
            logger.debug("Deleted pattern id=%s", pattern_id)
        return deleted

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_stat_value(raw: object) -> int:
        """Convert a stored statistic to int, falling back to 0.

        Args:
            raw: Value read from the ``stat_value`` column.

        Returns:
            The integer value, or 0 when the stored value cannot be
            converted.
        """
        try:
            return int(raw)
        except (ValueError, TypeError):
            return 0

    def _increment_stat(self, key: str, increment: int = 1) -> None:
        """Increment a statistics counter.

        Creates the counter with ``increment`` as its initial value when
        the key does not exist yet.

        Args:
            key: Statistics key name.
            increment: Amount to increment by.
        """
        conn = self._get_connection()
        # Upsert: the first bound value seeds a new row, the last one is
        # added to the existing value when the key is already present.
        conn.execute(
            """INSERT INTO statistics (stat_key, stat_value)
               VALUES (?, ?)
               ON CONFLICT(stat_key) DO UPDATE SET
                   stat_value = CAST(stat_value AS INTEGER) + ?,
                   updated_at = datetime('now')""",
            (key, str(increment), increment),
        )
        conn.commit()

    def get_stat(self, key: str) -> int:
        """Get a statistics value.

        Args:
            key: Statistics key name.

        Returns:
            Integer value, or 0 if the key is not found or its stored
            value is not an integer.
        """
        cursor = self._get_connection().cursor()
        cursor.execute(
            "SELECT stat_value FROM statistics WHERE stat_key = ?", (key,)
        )
        row = cursor.fetchone()
        if row is None:
            return 0
        return self._parse_stat_value(row["stat_value"])

    def get_all_stats(self) -> dict[str, int]:
        """Get all statistics.

        Values that cannot be converted to an integer are reported as 0,
        matching :meth:`get_stat`.

        Returns:
            Dictionary of all statistic key-value pairs.
        """
        cursor = self._get_connection().cursor()
        cursor.execute("SELECT stat_key, stat_value FROM statistics")
        return {
            row["stat_key"]: self._parse_stat_value(row["stat_value"])
            for row in cursor.fetchall()
        }

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    def cleanup(self, max_age_days: int = 90) -> int:
        """Remove old records that have never been used.

        A correction or pattern is removed only when it has no
        ``last_used_at`` timestamp, a use count of 0, and was created
        more than ``max_age_days`` days ago.

        Args:
            max_age_days: Maximum age in days for unused records.

        Returns:
            Number of records deleted (corrections plus patterns).
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Bind the age modifier as a parameter instead of formatting it
        # into the SQL text, so the argument can never alter the statement.
        age_modifier = f"-{max_age_days} days"

        cursor.execute(
            "DELETE FROM corrections WHERE last_used_at IS NULL "
            "AND created_at < datetime('now', ?) AND use_count = 0",
            (age_modifier,),
        )
        deleted_corrections = cursor.rowcount

        cursor.execute(
            "DELETE FROM patterns WHERE last_used_at IS NULL "
            "AND created_at < datetime('now', ?) AND use_count = 0",
            (age_modifier,),
        )
        deleted_patterns = cursor.rowcount

        conn.commit()

        total = deleted_corrections + deleted_patterns
        if total > 0:
            logger.info(
                "Cleanup: deleted %s corrections and %s patterns",
                deleted_corrections,
                deleted_patterns,
            )
        return total

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the database connection.

        Does nothing when no connection is open. A later call that needs
        the database opens a fresh connection.
        """
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.debug("Database connection closed")

    def __enter__(self) -> "PatternDatabase":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit; closes the connection."""
        self.close()