File size: 18,069 Bytes
900df0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
"""
OmniFile AI Processor — Pattern Database
==========================================
Source: arabic-ocr-pro/ai/pattern_db.py

Provides SQLite-based storage for:
- User corrections (original text → corrected text)
- Pattern images (cropped word images + labels)
- Usage statistics
- Training status tracking

The database enables the system to learn from user corrections
and improve OCR accuracy over time through pattern matching.
"""

from __future__ import annotations

import logging
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)


class PatternDatabase:
    """SQLite database for storing OCR correction patterns.

    Manages persistent storage of user corrections and word pattern
    images, enabling the system to learn and improve over time.

    Attributes:
        db_path: Path to the SQLite database file.
        _connection: Active SQLite connection.
    """

    def __init__(self, db_path: str | Path = "data/corrections.db") -> None:
        """Initialize the pattern database.

        Creates the database file and tables if they don't exist.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        self._connection: Optional[sqlite3.Connection] = None
        self._initialize_database()

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create a database connection.

        Returns:
            Active SQLite connection.
        """
        if self._connection is None:
            self._connection = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            self._connection.row_factory = sqlite3.Row
            # Enable WAL mode for better concurrent access
            self._connection.execute("PRAGMA journal_mode=WAL")
        return self._connection

    def _initialize_database(self) -> None:
        """Create database tables if they don't exist.

        Creates:
        - corrections: Stores original → corrected text mappings
        - patterns: Stores word pattern images (BLOB) with labels
        - statistics: Tracks usage statistics
        - training_status: Tracks model training progress
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.executescript("""
            CREATE TABLE IF NOT EXISTS corrections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                original_text TEXT NOT NULL,
                corrected_text TEXT NOT NULL,
                engine TEXT DEFAULT '',
                confidence REAL DEFAULT 0.0,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_corrections_original
                ON corrections(original_text);

            CREATE TABLE IF NOT EXISTS patterns (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                label TEXT NOT NULL,
                image_data BLOB,
                image_width INTEGER,
                image_height INTEGER,
                ocr_text TEXT,
                confidence REAL DEFAULT 0.0,
                source_engine TEXT DEFAULT '',
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );

            CREATE INDEX IF NOT EXISTS idx_patterns_label
                ON patterns(label);

            CREATE TABLE IF NOT EXISTS statistics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                stat_key TEXT NOT NULL UNIQUE,
                stat_value TEXT NOT NULL,
                updated_at TEXT NOT NULL DEFAULT (datetime('now'))
            );

            CREATE TABLE IF NOT EXISTS training_status (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_name TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                total_samples INTEGER DEFAULT 0,
                trained_samples INTEGER DEFAULT 0,
                accuracy REAL DEFAULT 0.0,
                last_trained_at TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now'))
            );
        """)

        conn.commit()
        logger.debug(f"Database initialized: {self.db_path}")

    # ------------------------------------------------------------------
    # Corrections CRUD
    # ------------------------------------------------------------------

    def add_correction(
        self,
        original_text: str,
        corrected_text: str,
        engine: str = "",
        confidence: float = 0.0,
    ) -> int:
        """Add a new correction to the database.

        If the same correction already exists, increments its use count
        instead of creating a duplicate.

        Args:
            original_text: Original (incorrect) OCR text.
            corrected_text: User-provided corrected text.
            engine: OCR engine that produced the original text.
            confidence: Confidence score of the original OCR result.

        Returns:
            Row ID of the correction record.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        # Check if correction already exists
        cursor.execute(
            "SELECT id, use_count FROM corrections "
            "WHERE original_text = ? AND corrected_text = ?",
            (original_text, corrected_text),
        )
        existing = cursor.fetchone()

        if existing:
            cursor.execute(
                "UPDATE corrections SET use_count = ?, "
                "last_used_at = datetime('now') WHERE id = ?",
                (existing["use_count"] + 1, existing["id"]),
            )
            conn.commit()
            return existing["id"]

        cursor.execute(
            """INSERT INTO corrections (original_text, corrected_text, engine, confidence)
               VALUES (?, ?, ?, ?)""",
            (original_text, corrected_text, engine, confidence),
        )
        conn.commit()
        row_id = cursor.lastrowid

        # Update statistics
        self._increment_stat("total_corrections")

        logger.debug(
            f"Added correction: '{original_text}' -> '{corrected_text}'"
        )
        return row_id

    def get_corrections(
        self,
        limit: int = 1000,
        min_use_count: int = 0,
    ) -> list[dict]:
        """Get all stored corrections.

        Args:
            limit: Maximum number of corrections to return.
            min_use_count: Minimum use count filter.

        Returns:
            List of correction dictionaries.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """SELECT id, original_text, corrected_text, engine, confidence,
                      use_count, created_at
               FROM corrections
               WHERE use_count >= ?
               ORDER BY use_count DESC, created_at DESC
               LIMIT ?""",
            (min_use_count, limit),
        )

        return [
            {
                "id": row["id"],
                "original_text": row["original_text"],
                "corrected_text": row["corrected_text"],
                "engine": row["engine"],
                "confidence": row["confidence"],
                "use_count": row["use_count"],
                "created_at": row["created_at"],
            }
            for row in cursor.fetchall()
        ]

    def find_correction(self, original_text: str) -> Optional[dict]:
        """Look up a correction for specific original text.

        Args:
            original_text: The text to look up.

        Returns:
            Correction dictionary if found, None otherwise.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """SELECT id, original_text, corrected_text, engine, confidence,
                      use_count, created_at
               FROM corrections
               WHERE original_text = ?
               ORDER BY use_count DESC
               LIMIT 1""",
            (original_text,),
        )

        row = cursor.fetchone()
        if row:
            return {
                "id": row["id"],
                "original_text": row["original_text"],
                "corrected_text": row["corrected_text"],
                "engine": row["engine"],
                "confidence": row["confidence"],
                "use_count": row["use_count"],
                "created_at": row["created_at"],
            }
        return None

    def delete_correction(self, correction_id: int) -> bool:
        """Delete a correction record.

        Args:
            correction_id: ID of the correction to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            "DELETE FROM corrections WHERE id = ?", (correction_id,)
        )
        conn.commit()
        deleted = cursor.rowcount > 0

        if deleted:
            logger.debug(f"Deleted correction id={correction_id}")

        return deleted

    # ------------------------------------------------------------------
    # Patterns CRUD
    # ------------------------------------------------------------------

    def add_pattern(
        self,
        label: str,
        image_data: bytes,
        image_width: int,
        image_height: int,
        ocr_text: str = "",
        confidence: float = 0.0,
        source_engine: str = "",
    ) -> int:
        """Add a new word pattern image to the database.

        Stores a cropped word image along with its label (correct text)
        for future pattern matching.

        Args:
            label: Correct text label for the pattern.
            image_data: Raw image bytes (PNG or JPEG encoded).
            image_width: Width of the pattern image.
            image_height: Height of the pattern image.
            ocr_text: OCR result that produced this pattern.
            confidence: Confidence score of the OCR result.
            source_engine: OCR engine that produced the result.

        Returns:
            Row ID of the pattern record.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """INSERT INTO patterns (label, image_data, image_width, image_height,
                                      ocr_text, confidence, source_engine)
               VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (label, image_data, image_width, image_height,
             ocr_text, confidence, source_engine),
        )
        conn.commit()
        row_id = cursor.lastrowid

        self._increment_stat("total_patterns")
        logger.debug(
            f"Added pattern: label='{label}', size={len(image_data)} bytes"
        )
        return row_id

    def get_patterns(
        self,
        label: Optional[str] = None,
        limit: int = 500,
    ) -> list[dict]:
        """Get stored pattern images.

        Args:
            label: Optional label filter.
            limit: Maximum number of patterns to return.

        Returns:
            List of pattern dictionaries with image data.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        if label:
            cursor.execute(
                """SELECT id, label, image_data, image_width, image_height,
                          ocr_text, confidence, source_engine, use_count
                   FROM patterns
                   WHERE label = ?
                   ORDER BY use_count DESC
                   LIMIT ?""",
                (label, limit),
            )
        else:
            cursor.execute(
                """SELECT id, label, image_data, image_width, image_height,
                          ocr_text, confidence, source_engine, use_count
                   FROM patterns
                   ORDER BY use_count DESC
                   LIMIT ?""",
                (limit,),
            )

        return [
            {
                "id": row["id"],
                "label": row["label"],
                "image_data": row["image_data"],
                "image_width": row["image_width"],
                "image_height": row["image_height"],
                "ocr_text": row["ocr_text"],
                "confidence": row["confidence"],
                "source_engine": row["source_engine"],
                "use_count": row["use_count"],
            }
            for row in cursor.fetchall()
        ]

    def get_unique_labels(self) -> list[str]:
        """Get all unique pattern labels.

        Returns:
            Sorted list of unique label strings.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("SELECT DISTINCT label FROM patterns ORDER BY label")
        return [row["label"] for row in cursor.fetchall()]

    def increment_pattern_use(self, pattern_id: int) -> None:
        """Increment the use count for a pattern.

        Args:
            pattern_id: ID of the pattern to update.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            "UPDATE patterns SET use_count = use_count + 1, "
            "last_used_at = datetime('now') WHERE id = ?",
            (pattern_id,),
        )
        conn.commit()

    def delete_pattern(self, pattern_id: int) -> bool:
        """Delete a pattern record.

        Args:
            pattern_id: ID of the pattern to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("DELETE FROM patterns WHERE id = ?", (pattern_id,))
        conn.commit()
        deleted = cursor.rowcount > 0

        if deleted:
            logger.debug(f"Deleted pattern id={pattern_id}")

        return deleted

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    def _increment_stat(self, key: str, increment: int = 1) -> None:
        """Increment a statistics counter.

        Args:
            key: Statistics key name.
            increment: Amount to increment by.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            """INSERT INTO statistics (stat_key, stat_value)
               VALUES (?, ?)
               ON CONFLICT(stat_key) DO UPDATE SET
                   stat_value = CAST(stat_value AS INTEGER) + ?,
                   updated_at = datetime('now')""",
            (key, str(increment), increment),
        )
        conn.commit()

    def get_stat(self, key: str) -> int:
        """Get a statistics value.

        Args:
            key: Statistics key name.

        Returns:
            Integer value, or 0 if not found.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute(
            "SELECT stat_value FROM statistics WHERE stat_key = ?", (key,)
        )
        row = cursor.fetchone()

        if row:
            try:
                return int(row["stat_value"])
            except (ValueError, TypeError):
                return 0
        return 0

    def get_all_stats(self) -> dict[str, int]:
        """Get all statistics.

        Returns:
            Dictionary of all statistic key-value pairs.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cursor.execute("SELECT stat_key, stat_value FROM statistics")
        return {
            row["stat_key"]: int(row["stat_value"])
            for row in cursor.fetchall()
        }

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    def cleanup(self, max_age_days: int = 90) -> int:
        """Remove old records that haven't been used recently.

        Args:
            max_age_days: Maximum age in days for unused records.

        Returns:
            Number of records deleted.
        """
        conn = self._get_connection()
        cursor = conn.cursor()

        cutoff = f"datetime('now', '-{max_age_days} days')"

        cursor.execute(
            f"DELETE FROM corrections WHERE last_used_at IS NULL "
            f"AND created_at < {cutoff} AND use_count = 0"
        )
        deleted_corrections = cursor.rowcount

        cursor.execute(
            f"DELETE FROM patterns WHERE last_used_at IS NULL "
            f"AND created_at < {cutoff} AND use_count = 0"
        )
        deleted_patterns = cursor.rowcount

        conn.commit()
        total = deleted_corrections + deleted_patterns

        if total > 0:
            logger.info(
                f"Cleanup: deleted {deleted_corrections} corrections "
                f"and {deleted_patterns} patterns"
            )

        return total

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the database connection."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.debug("Database connection closed")

    def __enter__(self) -> "PatternDatabase":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit."""
        self.close()