"""
OmniFile AI Processor — Pattern Database
==========================================

Source: arabic-ocr-pro/ai/pattern_db.py

Provides SQLite-based storage for:
- User corrections (original text → corrected text)
- Pattern images (cropped word images + labels)
- Usage statistics
- Training status tracking

The database enables the system to learn from user corrections
and improve OCR accuracy over time through pattern matching.
"""
from __future__ import annotations
import logging
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)


class PatternDatabase:
    """SQLite database for storing OCR correction patterns.

    Manages persistent storage of user corrections and word pattern
    images. Four tables are created on construction: ``corrections``,
    ``patterns``, ``statistics`` and ``training_status`` (the last one is
    only created here; no method of this class reads or writes it).

    A single connection is opened lazily and reused for every call. The
    instance can be used as a context manager, which closes the
    connection on exit.

    Attributes:
        db_path: Path to the SQLite database file.
        _connection: Active SQLite connection, or None when closed.
    """

    # Column lists shared by the read queries so that every method
    # returning a correction/pattern dict exposes exactly the same keys.
    _CORRECTION_COLUMNS = (
        "id, original_text, corrected_text, engine, confidence, "
        "use_count, created_at"
    )
    _PATTERN_COLUMNS = (
        "id, label, image_data, image_width, image_height, "
        "ocr_text, confidence, source_engine, use_count"
    )

    def __init__(self, db_path: str | Path = "data/corrections.db") -> None:
        """Initialize the pattern database.

        Creates the parent directory, the database file and the tables
        if they don't exist.

        Args:
            db_path: Path to the SQLite database file.
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._connection: Optional[sqlite3.Connection] = None
        self._initialize_database()

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _get_connection(self) -> sqlite3.Connection:
        """Get or create the database connection.

        The connection is cached only after it has been fully configured,
        so a failure while configuring it never leaves a half-initialized
        connection behind.

        Returns:
            Active SQLite connection with ``sqlite3.Row`` as row factory.
        """
        if self._connection is None:
            conn = sqlite3.connect(
                str(self.db_path),
                check_same_thread=False,
            )
            try:
                conn.row_factory = sqlite3.Row
                # Enable WAL mode for better concurrent access
                conn.execute("PRAGMA journal_mode=WAL")
            except sqlite3.Error:
                conn.close()
                raise
            self._connection = conn
        return self._connection

    def _initialize_database(self) -> None:
        """Create database tables if they don't exist.

        Creates:
            - corrections: Stores original → corrected text mappings
            - patterns: Stores word pattern images (BLOB) with labels
            - statistics: Tracks usage statistics
            - training_status: Tracks model training progress
        """
        conn = self._get_connection()
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS corrections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                original_text TEXT NOT NULL,
                corrected_text TEXT NOT NULL,
                engine TEXT DEFAULT '',
                confidence REAL DEFAULT 0.0,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );
            CREATE INDEX IF NOT EXISTS idx_corrections_original
                ON corrections(original_text);
            CREATE TABLE IF NOT EXISTS patterns (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                label TEXT NOT NULL,
                image_data BLOB,
                image_width INTEGER,
                image_height INTEGER,
                ocr_text TEXT,
                confidence REAL DEFAULT 0.0,
                source_engine TEXT DEFAULT '',
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                use_count INTEGER DEFAULT 0,
                last_used_at TEXT
            );
            CREATE INDEX IF NOT EXISTS idx_patterns_label
                ON patterns(label);
            CREATE TABLE IF NOT EXISTS statistics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                stat_key TEXT NOT NULL UNIQUE,
                stat_value TEXT NOT NULL,
                updated_at TEXT NOT NULL DEFAULT (datetime('now'))
            );
            CREATE TABLE IF NOT EXISTS training_status (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_name TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                total_samples INTEGER DEFAULT 0,
                trained_samples INTEGER DEFAULT 0,
                accuracy REAL DEFAULT 0.0,
                last_trained_at TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now'))
            );
        """)
        conn.commit()
        logger.debug("Database initialized: %s", self.db_path)

    # ------------------------------------------------------------------
    # Corrections CRUD
    # ------------------------------------------------------------------

    def add_correction(
        self,
        original_text: str,
        corrected_text: str,
        engine: str = "",
        confidence: float = 0.0,
    ) -> int:
        """Add a new correction to the database.

        If the same (original, corrected) pair already exists, its use
        count is incremented and ``last_used_at`` refreshed instead of
        creating a duplicate; ``engine`` and ``confidence`` of the
        existing row are left unchanged.

        Args:
            original_text: Original (incorrect) OCR text.
            corrected_text: User-provided corrected text.
            engine: OCR engine that produced the original text.
            confidence: Confidence score of the original OCR result.

        Returns:
            Row ID of the correction record (existing or newly inserted).
        """
        conn = self._get_connection()
        # ``with conn`` commits on success and rolls back if anything
        # inside raises, so the insert and its statistics counter are
        # persisted together or not at all.
        with conn:
            existing = conn.execute(
                "SELECT id FROM corrections "
                "WHERE original_text = ? AND corrected_text = ?",
                (original_text, corrected_text),
            ).fetchone()
            if existing is not None:
                # Increment inside SQL rather than read-modify-write in
                # Python, matching increment_pattern_use().
                conn.execute(
                    "UPDATE corrections SET use_count = use_count + 1, "
                    "last_used_at = datetime('now') WHERE id = ?",
                    (existing["id"],),
                )
                return existing["id"]

            cursor = conn.execute(
                """INSERT INTO corrections
                       (original_text, corrected_text, engine, confidence)
                   VALUES (?, ?, ?, ?)""",
                (original_text, corrected_text, engine, confidence),
            )
            row_id = cursor.lastrowid
            # Update statistics (only for genuinely new corrections)
            self._increment_stat("total_corrections")

        logger.debug(
            "Added correction: '%s' -> '%s'", original_text, corrected_text
        )
        return row_id

    def get_corrections(
        self,
        limit: int = 1000,
        min_use_count: int = 0,
    ) -> list[dict]:
        """Get stored corrections, most used first.

        Args:
            limit: Maximum number of corrections to return.
            min_use_count: Minimum use count filter.

        Returns:
            List of correction dictionaries ordered by use count, then
            creation time, both descending.
        """
        rows = self._get_connection().execute(
            f"""SELECT {self._CORRECTION_COLUMNS}
                FROM corrections
                WHERE use_count >= ?
                ORDER BY use_count DESC, created_at DESC
                LIMIT ?""",
            (min_use_count, limit),
        ).fetchall()
        return [dict(row) for row in rows]

    def find_correction(self, original_text: str) -> Optional[dict]:
        """Look up a correction for specific original text.

        When several corrections exist for the same original text, the
        one with the highest use count is returned.

        Args:
            original_text: The text to look up (exact match).

        Returns:
            Correction dictionary if found, None otherwise.
        """
        row = self._get_connection().execute(
            f"""SELECT {self._CORRECTION_COLUMNS}
                FROM corrections
                WHERE original_text = ?
                ORDER BY use_count DESC
                LIMIT 1""",
            (original_text,),
        ).fetchone()
        return dict(row) if row is not None else None

    def delete_correction(self, correction_id: int) -> bool:
        """Delete a correction record.

        Args:
            correction_id: ID of the correction to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM corrections WHERE id = ?", (correction_id,)
            )
        deleted = cursor.rowcount > 0
        if deleted:
            logger.debug("Deleted correction id=%s", correction_id)
        return deleted

    # ------------------------------------------------------------------
    # Patterns CRUD
    # ------------------------------------------------------------------

    def add_pattern(
        self,
        label: str,
        image_data: bytes,
        image_width: int,
        image_height: int,
        ocr_text: str = "",
        confidence: float = 0.0,
        source_engine: str = "",
    ) -> int:
        """Add a new word pattern image to the database.

        Stores a cropped word image along with its label (correct text)
        for future pattern matching. Unlike corrections, patterns are
        never de-duplicated: every call inserts a new row.

        Args:
            label: Correct text label for the pattern.
            image_data: Raw image bytes (PNG or JPEG encoded).
            image_width: Width of the pattern image.
            image_height: Height of the pattern image.
            ocr_text: OCR result that produced this pattern.
            confidence: Confidence score of the OCR result.
            source_engine: OCR engine that produced the result.

        Returns:
            Row ID of the pattern record.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                """INSERT INTO patterns
                       (label, image_data, image_width, image_height,
                        ocr_text, confidence, source_engine)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (label, image_data, image_width, image_height,
                 ocr_text, confidence, source_engine),
            )
            row_id = cursor.lastrowid
            self._increment_stat("total_patterns")

        # The column is nullable, so a missing image must not make the
        # log line raise after the row has already been stored.
        size = len(image_data) if image_data is not None else 0
        logger.debug("Added pattern: label='%s', size=%s bytes", label, size)
        return row_id

    def get_patterns(
        self,
        label: Optional[str] = None,
        limit: int = 500,
    ) -> list[dict]:
        """Get stored pattern images, most used first.

        Args:
            label: Optional label filter. ``None`` and the empty string
                both mean "no filter".
            limit: Maximum number of patterns to return.

        Returns:
            List of pattern dictionaries with image data.
        """
        conn = self._get_connection()
        if label:
            cursor = conn.execute(
                f"""SELECT {self._PATTERN_COLUMNS}
                    FROM patterns
                    WHERE label = ?
                    ORDER BY use_count DESC
                    LIMIT ?""",
                (label, limit),
            )
        else:
            cursor = conn.execute(
                f"""SELECT {self._PATTERN_COLUMNS}
                    FROM patterns
                    ORDER BY use_count DESC
                    LIMIT ?""",
                (limit,),
            )
        return [dict(row) for row in cursor.fetchall()]

    def get_unique_labels(self) -> list[str]:
        """Get all unique pattern labels.

        Returns:
            Sorted list of unique label strings.
        """
        cursor = self._get_connection().execute(
            "SELECT DISTINCT label FROM patterns ORDER BY label"
        )
        return [row["label"] for row in cursor.fetchall()]

    def increment_pattern_use(self, pattern_id: int) -> None:
        """Increment the use count for a pattern.

        Also refreshes ``last_used_at``. An unknown ID is a no-op.

        Args:
            pattern_id: ID of the pattern to update.
        """
        conn = self._get_connection()
        with conn:
            conn.execute(
                "UPDATE patterns SET use_count = use_count + 1, "
                "last_used_at = datetime('now') WHERE id = ?",
                (pattern_id,),
            )

    def delete_pattern(self, pattern_id: int) -> bool:
        """Delete a pattern record.

        Args:
            pattern_id: ID of the pattern to delete.

        Returns:
            True if deleted, False if not found.
        """
        conn = self._get_connection()
        with conn:
            cursor = conn.execute(
                "DELETE FROM patterns WHERE id = ?", (pattern_id,)
            )
        deleted = cursor.rowcount > 0
        if deleted:
            logger.debug("Deleted pattern id=%s", pattern_id)
        return deleted

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_count(raw: object) -> int:
        """Convert a stored statistic value to int, falling back to 0."""
        try:
            return int(raw)  # type: ignore[call-overload]
        except (ValueError, TypeError):
            return 0

    def _increment_stat(self, key: str, increment: int = 1) -> None:
        """Increment a statistics counter, creating it if missing.

        Args:
            key: Statistics key name.
            increment: Amount to increment by.
        """
        conn = self._get_connection()
        # When called from inside another ``with conn`` block this
        # commits the caller's pending statements together with the
        # counter update.
        with conn:
            conn.execute(
                """INSERT INTO statistics (stat_key, stat_value)
                   VALUES (?, ?)
                   ON CONFLICT(stat_key) DO UPDATE SET
                       stat_value = CAST(stat_value AS INTEGER) + ?,
                       updated_at = datetime('now')""",
                (key, str(increment), increment),
            )

    def get_stat(self, key: str) -> int:
        """Get a statistics value.

        Args:
            key: Statistics key name.

        Returns:
            Integer value, or 0 if the key is missing or its stored
            value is not an integer.
        """
        row = self._get_connection().execute(
            "SELECT stat_value FROM statistics WHERE stat_key = ?", (key,)
        ).fetchone()
        if row is None:
            return 0
        return self._coerce_count(row["stat_value"])

    def get_all_stats(self) -> dict[str, int]:
        """Get all statistics.

        Returns:
            Dictionary of all statistic key-value pairs. Values that
            cannot be read as integers are reported as 0, consistent
            with get_stat().
        """
        cursor = self._get_connection().execute(
            "SELECT stat_key, stat_value FROM statistics"
        )
        return {
            row["stat_key"]: self._coerce_count(row["stat_value"])
            for row in cursor.fetchall()
        }

    # ------------------------------------------------------------------
    # Maintenance
    # ------------------------------------------------------------------

    def cleanup(self, max_age_days: int = 90) -> int:
        """Remove old records that have never been used.

        A record is removed only when it has no ``last_used_at``, a
        ``use_count`` of 0 and was created more than ``max_age_days``
        days ago. Both deletions happen in one transaction.

        Args:
            max_age_days: Maximum age in days for unused records.

        Returns:
            Number of records deleted (corrections plus patterns).
        """
        conn = self._get_connection()
        # Bound as a parameter instead of being spliced into the SQL text.
        age_modifier = f"-{max_age_days} days"
        with conn:
            deleted_corrections = conn.execute(
                "DELETE FROM corrections WHERE last_used_at IS NULL "
                "AND created_at < datetime('now', ?) AND use_count = 0",
                (age_modifier,),
            ).rowcount
            deleted_patterns = conn.execute(
                "DELETE FROM patterns WHERE last_used_at IS NULL "
                "AND created_at < datetime('now', ?) AND use_count = 0",
                (age_modifier,),
            ).rowcount

        total = deleted_corrections + deleted_patterns
        if total > 0:
            logger.info(
                "Cleanup: deleted %s corrections and %s patterns",
                deleted_corrections,
                deleted_patterns,
            )
        return total

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Close the database connection (safe to call repeatedly)."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None
            logger.debug("Database connection closed")

    def __enter__(self) -> "PatternDatabase":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit: close the connection."""
        self.close()
|