File size: 5,897 Bytes
b75c637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
"""
Storage — SQL persistence for pipeline runs.

Uses SQLite by default (``runs/<run_id>/db.sqlite``).
Set ``DATABASE_URL`` env var for Postgres (e.g. ``postgresql://user:pass@host/db``).

Tables:
    normalized_records – all NormalizedRecord rows
    stage_results      – per-stage summary + status
    artifacts          – artifact metadata
"""

from __future__ import annotations

import json
import logging
import os
import sqlite3
from pathlib import Path
from typing import List, Optional

from engine.io_contract import Artifact, EngineOutput, NormalizedRecord

logger = logging.getLogger("engine.storage")

# ---------------------------------------------------------------------------
# Schema DDL.  The CREATE TABLE statements are SQLite-compatible and also
# valid Postgres, but NOTE: the write paths below use SQLite-only syntax
# ("INSERT OR REPLACE", "?" placeholders) — a Postgres backend would need
# ON CONFLICT upserts and %s placeholders.
# ---------------------------------------------------------------------------

_DDL = """
CREATE TABLE IF NOT EXISTS normalized_records (
    row_id        TEXT PRIMARY KEY,
    source_file   TEXT,
    source_type   TEXT,
    timestamp     TEXT,
    entity_name   TEXT,
    entity_phone  TEXT,
    entity_email  TEXT,
    entity_ip     TEXT,
    entity_domain TEXT,
    entity_hash   TEXT,
    raw_text      TEXT,
    extra         TEXT
);

CREATE TABLE IF NOT EXISTS stage_results (
    stage   TEXT PRIMARY KEY,
    status  TEXT,
    summary TEXT,
    error   TEXT,
    metadata TEXT
);

CREATE TABLE IF NOT EXISTS artifacts (
    name        TEXT PRIMARY KEY,
    path        TEXT,
    mime_type   TEXT,
    description TEXT
);
"""


# ---------------------------------------------------------------------------
# Storage backend
# ---------------------------------------------------------------------------

class StorageBackend:
    """
    Thin wrapper around a SQLite connection.

    For this iteration we use raw ``sqlite3``.  A future iteration can
    swap in SQLAlchemy / SQLModel for Postgres parity.

    The backend may also be used as a context manager, which opens the
    connection on entry and closes it on exit::

        with StorageBackend(path) as store:
            store.insert_records(records)
    """

    def __init__(self, db_path: Path):
        """Remember *db_path* and ensure its parent directory exists.

        The connection itself is opened lazily on first use of
        :attr:`conn`, or explicitly via :meth:`connect`.
        """
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn: Optional[sqlite3.Connection] = None

    # -- lifecycle -----------------------------------------------------------

    def connect(self) -> None:
        """Open the SQLite connection and apply the schema DDL (idempotent).

        If a connection is already open it is closed first, so repeated
        ``connect()`` calls do not leak file handles.
        """
        if self._conn is not None:
            # Previously a second connect() rebound self._conn and leaked
            # the old open connection; release it before reconnecting.
            self._conn.close()
            self._conn = None
        logger.info("Connecting to SQLite: %s", self.db_path)
        self._conn = sqlite3.connect(str(self.db_path))
        self._conn.executescript(_DDL)
        self._conn.commit()

    def close(self) -> None:
        """Close the connection if open.  Safe to call multiple times."""
        if self._conn:
            self._conn.close()
            self._conn = None

    def __enter__(self) -> "StorageBackend":
        """Context-manager entry: open the connection and return self."""
        self.connect()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Context-manager exit: always close the connection."""
        self.close()

    @property
    def conn(self) -> sqlite3.Connection:
        """The live connection, opened on first access if necessary."""
        if self._conn is None:
            self.connect()
        assert self._conn is not None
        return self._conn

    # -- writes --------------------------------------------------------------

    def insert_records(self, records: List[NormalizedRecord]) -> int:
        """Upsert normalized records.

        Returns the number of rows written (via ``INSERT OR REPLACE``,
        so pre-existing rows with the same ``row_id`` are replaced,
        not duplicated).  An empty input is a no-op returning 0.
        """
        if not records:
            return 0
        sql = """
            INSERT OR REPLACE INTO normalized_records
            (row_id, source_file, source_type, timestamp,
             entity_name, entity_phone, entity_email,
             entity_ip, entity_domain, entity_hash,
             raw_text, extra)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
        rows = [
            (
                r.row_id,
                r.source_file,
                # Enum value, or "" when source_type is unset (timestamp
                # below uses None instead — kept for schema compatibility).
                r.source_type.value if r.source_type else "",
                r.timestamp.isoformat() if r.timestamp else None,
                r.entity_name,
                r.entity_phone,
                r.entity_email,
                r.entity_ip,
                r.entity_domain,
                r.entity_hash,
                r.raw_text,
                # default=str so non-JSON-native values (dates, Paths, …)
                # serialize instead of raising.
                json.dumps(r.extra, ensure_ascii=False, default=str),
            )
            for r in records
        ]
        self.conn.executemany(sql, rows)
        self.conn.commit()
        logger.info("Inserted %d normalized records", len(rows))
        return len(rows)

    def insert_stage_result(self, output: EngineOutput) -> None:
        """Upsert a stage result row (keyed by stage name)."""
        sql = """
            INSERT OR REPLACE INTO stage_results
            (stage, status, summary, error, metadata)
            VALUES (?, ?, ?, ?, ?)
        """
        self.conn.execute(sql, (
            output.stage,
            output.status.value,
            output.summary,
            output.error,
            json.dumps(output.metadata, ensure_ascii=False, default=str),
        ))
        self.conn.commit()

    def insert_artifact(self, artifact: Artifact) -> None:
        """Upsert an artifact metadata row (keyed by artifact name)."""
        sql = """
            INSERT OR REPLACE INTO artifacts
            (name, path, mime_type, description)
            VALUES (?, ?, ?, ?)
        """
        self.conn.execute(sql, (
            artifact.name,
            str(artifact.path),
            artifact.mime_type,
            artifact.description,
        ))
        self.conn.commit()

    # -- reads ---------------------------------------------------------------

    def count_records(self) -> int:
        """Return the total number of normalized records stored."""
        cur = self.conn.execute("SELECT COUNT(*) FROM normalized_records")
        return cur.fetchone()[0]

    def fetch_all_records(self) -> List[dict]:
        """Return all normalized records as column-name → value dicts."""
        cur = self.conn.execute("SELECT * FROM normalized_records")
        cols = [d[0] for d in cur.description]
        return [dict(zip(cols, row)) for row in cur.fetchall()]

    def fetch_stage_results(self) -> List[dict]:
        """Return all stage result rows as column-name → value dicts."""
        cur = self.conn.execute("SELECT * FROM stage_results")
        cols = [d[0] for d in cur.description]
        return [dict(zip(cols, row)) for row in cur.fetchall()]

def create_storage(db_path: Path) -> StorageBackend:
    """Factory helper.

    Instantiates a :class:`StorageBackend` for *db_path*, opens its
    connection (applying the schema DDL), and returns the ready-to-use
    backend.
    """
    storage = StorageBackend(db_path)
    storage.connect()
    return storage