File size: 4,990 Bytes
82e122c
 
 
c76014a
 
72e96d1
82e122c
c76014a
72e96d1
c76014a
f55959d
82e122c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72e96d1
 
 
82e122c
 
 
 
 
 
c76014a
 
 
82e122c
 
 
 
 
 
72e96d1
 
 
82e122c
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from __future__ import annotations

import logging
import os
import threading
import time
from pathlib import Path
from typing import Dict, Tuple, Optional

log = logging.getLogger(__name__)


class DbUploadStore:
    """
    In-memory registry for uploaded DB files with simple TTL-based cleanup.

    Responsibilities:
    - Track uploaded DBs by db_id -> filesystem path.
    - Enforce a TTL for uploaded DBs.
    - Remove stale entries and delete underlying files when expired.

    Every read/write of the registry happens while holding ``self._lock``,
    so one instance can be shared across threads. Expiry is measured on a
    monotonic clock, so wall-clock adjustments cannot shorten or extend a TTL.
    """

    def __init__(self, upload_dir: str, ttl_seconds: int) -> None:
        self.upload_dir = upload_dir
        self.ttl_seconds = ttl_seconds
        # db_id -> (filesystem path, registration timestamp from self._now()).
        self._entries: Dict[str, Tuple[str, float]] = {}
        # Guards _entries. Re-entrant so public methods can call the
        # *_locked helpers while already holding it.
        self._lock = threading.RLock()

        Path(self.upload_dir).mkdir(parents=True, exist_ok=True)
        log.debug(
            "Initialized DbUploadStore",
            extra={
                "upload_dir": self.upload_dir,
                "ttl_seconds": self.ttl_seconds,
            },
        )

    def _now(self) -> float:
        """Return the current time in seconds on a monotonic clock."""
        # Timestamps are only ever subtracted from one another, never shown
        # to users, so a monotonic clock is the correct basis for TTLs.
        return time.monotonic()

    def _is_expired(self, ts: float, now: Optional[float] = None) -> bool:
        """Return True if an entry registered at ``ts`` is older than the TTL."""
        if now is None:
            now = self._now()
        return (now - ts) > self.ttl_seconds

    def _delete_file(self, db_id: str, path: str) -> None:
        """
        Best-effort removal of an uploaded DB file.

        A file that is already gone is treated as success. Any other
        filesystem error is logged and swallowed so cleanup never crashes
        the app.
        """
        try:
            os.remove(path)
        except FileNotFoundError:
            # Removed externally (or never created) - nothing left to clean.
            return
        except OSError as exc:
            # Surface at warning level: a failed delete leaks disk space.
            log.warning(
                "Failed to delete expired uploaded DB file",
                extra={"db_id": db_id, "path": path},
                exc_info=exc,
            )
            return
        log.debug(
            "Deleted expired uploaded DB file",
            extra={"db_id": db_id, "path": path},
        )

    def _gc_locked(self, now: Optional[float] = None) -> None:
        """
        Internal garbage collector. Caller must hold ``self._lock``.

        Drops entries that are expired as of ``now`` or whose file no longer
        exists, and deletes the corresponding files on disk if present.
        """
        if now is None:
            now = self._now()

        # Collect first, then mutate: never change a dict while iterating it.
        stale = [
            (db_id, path)
            for db_id, (path, ts) in self._entries.items()
            if self._is_expired(ts, now) or not os.path.exists(path)
        ]
        for db_id, path in stale:
            self._entries.pop(db_id, None)
            self._delete_file(db_id, path)

    def cleanup_stale(self) -> None:
        """
        Public cleanup entry point.

        Can be called periodically (or on access) to remove expired DBs.
        """
        with self._lock:
            self._gc_locked()

    def register(self, db_id: str, path: str) -> None:
        """
        Register a new uploaded DB with the given db_id and filesystem path.

        Re-registering an existing db_id overwrites its path and restarts its
        TTL. Stale entries are garbage-collected as a side effect.
        """
        with self._lock:
            now = self._now()
            self._entries[db_id] = (path, now)
            log.debug(
                "Registered uploaded DB in DbUploadStore",
                extra={"db_id": db_id, "path": path},
            )
            # Opportunistically clean up old entries as we go.
            self._gc_locked(now=now)

    def resolve(self, db_id: str) -> Optional[str]:
        """
        Resolve db_id to a filesystem path if it exists and is not expired.

        Returns:
            str path if valid, or None if missing/expired.
        """
        with self._lock:
            # One clock snapshot for the whole call: after the GC pass, any
            # surviving entry is unexpired as of `now` and its file existed,
            # so no second expiry/existence check is needed.
            self._gc_locked(now=self._now())
            entry = self._entries.get(db_id)
        return entry[0] if entry is not None else None


# --------------------------------------------------------------------
# Module-level singleton and legacy helper functions
# --------------------------------------------------------------------

# Directory where uploaded DB files are stored; overridable via env.
_DB_UPLOAD_DIR = os.getenv("DB_UPLOAD_DIR", "/tmp/nl2sql_dbs")

# TTL for uploaded DBs, in seconds; overridable via env.
_DB_TTL_RAW = os.getenv("DB_TTL_SECONDS", "7200")  # default: 2 hours
try:
    _DB_TTL_SECONDS = int(_DB_TTL_RAW)
except ValueError as exc:
    # Still fail fast at import, but name the offending variable and value
    # so a bad deployment config is easy to diagnose.
    raise ValueError(
        f"DB_TTL_SECONDS must be an integer number of seconds, got {_DB_TTL_RAW!r}"
    ) from exc

# Process-wide store used by the legacy helper functions below.
_STORE = DbUploadStore(upload_dir=_DB_UPLOAD_DIR, ttl_seconds=_DB_TTL_SECONDS)


def register_db(db_id: str, path: str) -> None:
    """
    Legacy module-level API, kept for backwards compatibility.

    Delegates to ``_STORE.register`` on the shared, process-wide
    ``DbUploadStore``, which records ``path`` for ``db_id`` and runs stale
    entry cleanup as a side effect.
    """
    store = _STORE
    store.register(db_id, path)


def cleanup_stale_dbs() -> None:
    """
    Legacy module-level API, kept for backwards compatibility.

    Delegates to ``_STORE.cleanup_stale``, which drops expired or vanished
    entries from the shared store and deletes their files.
    """
    store = _STORE
    store.cleanup_stale()


def get_db_path(db_id: str) -> Optional[str]:
    """
    Legacy module-level API, kept for backwards compatibility.

    Delegates to ``_STORE.resolve``: returns the filesystem path registered
    for ``db_id``, or ``None`` if it is unknown, expired, or its file is gone.
    """
    resolved_path = _STORE.resolve(db_id)
    return resolved_path