Spaces:
Paused
Paused
| import sqlite3 | |
| from abc import ABC, abstractmethod | |
| from typing import Any, Set | |
| import threading | |
| from overrides import override | |
| class Connection: | |
| """A threadpool connection that returns itself to the pool on close()""" | |
| _pool: "Pool" | |
| _db_file: str | |
| _conn: sqlite3.Connection | |
| def __init__( | |
| self, pool: "Pool", db_file: str, is_uri: bool, *args: Any, **kwargs: Any | |
| ): | |
| self._pool = pool | |
| self._db_file = db_file | |
| self._conn = sqlite3.connect( | |
| db_file, timeout=1000, check_same_thread=False, uri=is_uri, *args, **kwargs | |
| ) # type: ignore | |
| self._conn.isolation_level = None # Handle commits explicitly | |
| def execute(self, sql: str, parameters=...) -> sqlite3.Cursor: # type: ignore | |
| if parameters is ...: | |
| return self._conn.execute(sql) | |
| return self._conn.execute(sql, parameters) | |
| def commit(self) -> None: | |
| self._conn.commit() | |
| def rollback(self) -> None: | |
| self._conn.rollback() | |
| def cursor(self) -> sqlite3.Cursor: | |
| return self._conn.cursor() | |
| def close_actual(self) -> None: | |
| """Actually closes the connection to the db""" | |
| self._conn.close() | |
| class Pool(ABC): | |
| """Abstract base class for a pool of connections to a sqlite database.""" | |
| def __init__(self, db_file: str, is_uri: bool) -> None: | |
| pass | |
| def connect(self, *args: Any, **kwargs: Any) -> Connection: | |
| """Return a connection from the pool.""" | |
| pass | |
| def close(self) -> None: | |
| """Close all connections in the pool.""" | |
| pass | |
| def return_to_pool(self, conn: Connection) -> None: | |
| """Return a connection to the pool.""" | |
| pass | |
| class LockPool(Pool): | |
| """A pool that has a single connection per thread but uses a lock to ensure that only one thread can use it at a time. | |
| This is used because sqlite does not support multithreaded access with connection timeouts when using the | |
| shared cache mode. We use the shared cache mode to allow multiple threads to share a database. | |
| """ | |
| _connections: Set[Connection] | |
| _lock: threading.RLock | |
| _connection: threading.local | |
| _db_file: str | |
| _is_uri: bool | |
| def __init__(self, db_file: str, is_uri: bool = False): | |
| self._connections = set() | |
| self._connection = threading.local() | |
| self._lock = threading.RLock() | |
| self._db_file = db_file | |
| self._is_uri = is_uri | |
| def connect(self, *args: Any, **kwargs: Any) -> Connection: | |
| self._lock.acquire() | |
| if hasattr(self._connection, "conn") and self._connection.conn is not None: | |
| return self._connection.conn # type: ignore # cast doesn't work here for some reason | |
| else: | |
| new_connection = Connection( | |
| self, self._db_file, self._is_uri, *args, **kwargs | |
| ) | |
| self._connection.conn = new_connection | |
| self._connections.add(new_connection) | |
| return new_connection | |
| def return_to_pool(self, conn: Connection) -> None: | |
| try: | |
| self._lock.release() | |
| except RuntimeError: | |
| pass | |
| def close(self) -> None: | |
| for conn in self._connections: | |
| conn.close_actual() | |
| self._connections.clear() | |
| self._connection = threading.local() | |
| try: | |
| self._lock.release() | |
| except RuntimeError: | |
| pass | |
| class PerThreadPool(Pool): | |
| """Maintains a connection per thread. For now this does not maintain a cap on the number of connections, but it could be | |
| extended to do so and block on connect() if the cap is reached. | |
| """ | |
| _connections: Set[Connection] | |
| _lock: threading.Lock | |
| _connection: threading.local | |
| _db_file: str | |
| _is_uri_: bool | |
| def __init__(self, db_file: str, is_uri: bool = False): | |
| self._connections = set() | |
| self._connection = threading.local() | |
| self._lock = threading.Lock() | |
| self._db_file = db_file | |
| self._is_uri = is_uri | |
| def connect(self, *args: Any, **kwargs: Any) -> Connection: | |
| if hasattr(self._connection, "conn") and self._connection.conn is not None: | |
| return self._connection.conn # type: ignore # cast doesn't work here for some reason | |
| else: | |
| new_connection = Connection( | |
| self, self._db_file, self._is_uri, *args, **kwargs | |
| ) | |
| self._connection.conn = new_connection | |
| with self._lock: | |
| self._connections.add(new_connection) | |
| return new_connection | |
| def close(self) -> None: | |
| with self._lock: | |
| for conn in self._connections: | |
| conn.close_actual() | |
| self._connections.clear() | |
| self._connection = threading.local() | |
| def return_to_pool(self, conn: Connection) -> None: | |
| pass # Each thread gets its own connection, so we don't need to return it to the pool | |