"""SQLite-backed implementation of Chroma's database components."""
import os
import sqlite3
from threading import local
from types import TracebackType
from typing import Sequence, cast, Optional, Type, Any
from uuid import UUID

import pypika
from importlib_resources import files
from importlib_resources.abc import Traversable
from overrides import override
from typing_extensions import Literal

import chromadb.db.base as base
from chromadb.config import System, Settings
from chromadb.db.impl.sqlite_pool import Connection, LockPool, PerThreadPool, Pool
from chromadb.db.migrations import MigratableDB, Migration
from chromadb.db.mixins.embeddings_queue import SqlEmbeddingsQueue
from chromadb.db.mixins.sysdb import SqlSysDB
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
class TxWrapper(base.TxWrapper):
    """Context manager yielding a cursor inside a SQLite transaction.

    Transaction scopes may be nested on the same thread; nesting depth is
    tracked via the thread-local ``stack``. Only the outermost wrapper issues
    ``BEGIN`` and the final commit/rollback.
    """

    _conn: Connection
    _pool: Pool

    def __init__(self, conn_pool: Pool, stack: local):
        self._tx_stack = stack
        # Check a connection out of the pool; it is returned in __exit__.
        self._conn = conn_pool.connect()
        self._pool = conn_pool

    def __enter__(self) -> base.Cursor:
        # Only the outermost scope on this thread starts the actual SQLite
        # transaction (and sets the LIKE pragma for it).
        if len(self._tx_stack.stack) == 0:
            self._conn.execute("PRAGMA case_sensitive_like = ON")
            self._conn.execute("BEGIN;")
        self._tx_stack.stack.append(self)
        return self._conn.cursor()  # type: ignore

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> Literal[False]:
        self._tx_stack.stack.pop()
        if len(self._tx_stack.stack) == 0:
            # Outermost scope finished: commit on success, roll back if an
            # exception is propagating.
            if exc_type is None:
                self._conn.commit()
            else:
                self._conn.rollback()
        # NOTE(review): nesting of the two lines below relative to the `if`
        # was reconstructed (source indentation was lost); unconditional
        # return-to-pool balances the connect() in __init__ — confirm.
        self._conn.cursor().close()
        self._pool.return_to_pool(self._conn)
        return False  # never suppress the caller's exception
class SqliteDB(MigratableDB, SqlEmbeddingsQueue, SqlSysDB):
    """SQLite-backed implementation of the system DB, embeddings queue, and
    schema-migration machinery.

    Depending on the ``is_persistent`` setting, the database lives either in a
    shared-cache in-memory SQLite instance or in
    ``<persist_directory>/chroma.sqlite3``.
    """

    _conn_pool: Pool
    _settings: Settings
    _migration_imports: Sequence[Traversable]
    _db_file: str
    # Thread-local stack of currently open TxWrapper instances; supports
    # nested tx() scopes on the same thread.
    _tx_stack: local
    _is_persistent: bool

    def __init__(self, system: System):
        self._settings = system.settings
        # Packages whose bundled *.sql files form the migration set, applied
        # in this order.
        self._migration_imports = [
            files("chromadb.migrations.embeddings_queue"),
            files("chromadb.migrations.sysdb"),
            files("chromadb.migrations.metadb"),
        ]
        self._is_persistent = self._settings.require("is_persistent")
        self._opentelemetry_client = system.require(OpenTelemetryClient)
        if not self._is_persistent:
            # In order to allow sqlite to be shared between multiple threads, we need to use a
            # URI connection string with shared cache.
            # See https://www.sqlite.org/sharedcache.html
            # https://stackoverflow.com/questions/3315046/sharing-a-memory-database-between-different-threads-in-python-using-sqlite3-pa
            self._db_file = "file::memory:?cache=shared"
            self._conn_pool = LockPool(self._db_file, is_uri=True)
        else:
            self._db_file = (
                self._settings.require("persist_directory") + "/chroma.sqlite3"
            )
            if not os.path.exists(self._db_file):
                os.makedirs(os.path.dirname(self._db_file), exist_ok=True)
            self._conn_pool = PerThreadPool(self._db_file)
        self._tx_stack = local()
        super().__init__(system)

    def start(self) -> None:
        """Start the component: enable required pragmas and apply pending migrations."""
        super().start()
        with self.tx() as cur:
            cur.execute("PRAGMA foreign_keys = ON")
            cur.execute("PRAGMA case_sensitive_like = ON")
        self.initialize_migrations()

    def stop(self) -> None:
        """Stop the component and close every pooled connection."""
        super().stop()
        self._conn_pool.close()

    @staticmethod
    def querybuilder() -> Type[pypika.Query]:
        """Return the pypika query class used to build SQL for this backend."""
        # FIX: declared without `self` but previously undecorated — calling it
        # through an instance would have passed `self` and raised TypeError.
        return pypika.Query  # type: ignore

    @staticmethod
    def parameter_format() -> str:
        """Return the SQL placeholder token used for parameterized queries."""
        return "?"

    @staticmethod
    def migration_scope() -> str:
        """Return the scope tag recorded with each applied migration."""
        # FIX: must be a staticmethod — db_migrations() below invokes it as
        # `self.migration_scope()`.
        return "sqlite"

    def migration_dirs(self) -> Sequence[Traversable]:
        """Return the directories containing migration files, in apply order."""
        return self._migration_imports

    def tx(self) -> TxWrapper:
        """Open a (possibly nested) transaction scope for the calling thread."""
        # Lazily create this thread's stack the first time it starts a
        # transaction (thread-local attributes don't exist until set).
        if not hasattr(self._tx_stack, "stack"):
            self._tx_stack.stack = []
        return TxWrapper(self._conn_pool, stack=self._tx_stack)

    def reset_state(self) -> None:
        """Drop all tables and reinitialize the database from scratch.

        Raises:
            ValueError: if the ``allow_reset`` setting is not enabled.
        """
        if not self._settings.require("allow_reset"):
            raise ValueError(
                "Resetting the database is not allowed. Set `allow_reset` to true in the config in tests or other non-production environments where reset should be permitted."
            )
        with self.tx() as cur:
            # Drop all tables. Names come from sqlite_master, not user input,
            # so interpolating them into the DROP statement is safe here.
            cur.execute(
                """
                SELECT name FROM sqlite_master
                WHERE type='table'
                """
            )
            for row in cur.fetchall():
                cur.execute(f"DROP TABLE IF EXISTS {row[0]}")
        self._conn_pool.close()
        # Rebuild the pool and pragmas, then re-run migrations on the empty DB.
        self.start()
        super().reset_state()

    def setup_migrations(self) -> None:
        """Create the bookkeeping table that records applied migrations."""
        with self.tx() as cur:
            cur.execute(
                """
                CREATE TABLE IF NOT EXISTS migrations (
                    dir TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    filename TEXT NOT NULL,
                    sql TEXT NOT NULL,
                    hash TEXT NOT NULL,
                    PRIMARY KEY (dir, version)
                )
                """
            )

    def migrations_initialized(self) -> bool:
        """Return True if the ``migrations`` bookkeeping table exists."""
        with self.tx() as cur:
            cur.execute(
                """SELECT count(*) FROM sqlite_master
                   WHERE type='table' AND name='migrations'"""
            )
            # count(*) here is 0 or 1; convert directly instead of branching.
            return cur.fetchone()[0] != 0

    def db_migrations(self, dir: Traversable) -> Sequence[Migration]:
        """Return the migrations already applied for *dir*, oldest first."""
        with self.tx() as cur:
            cur.execute(
                """
                SELECT dir, version, filename, sql, hash
                FROM migrations
                WHERE dir = ?
                ORDER BY version ASC
                """,
                (dir.name,),
            )
            migrations = []
            for row in cur.fetchall():
                found_dir = cast(str, row[0])
                found_version = cast(int, row[1])
                found_filename = cast(str, row[2])
                found_sql = cast(str, row[3])
                found_hash = cast(str, row[4])
                migrations.append(
                    Migration(
                        dir=found_dir,
                        version=found_version,
                        filename=found_filename,
                        sql=found_sql,
                        hash=found_hash,
                        scope=self.migration_scope(),
                    )
                )
            return migrations

    def apply_migration(self, cur: base.Cursor, migration: Migration) -> None:
        """Execute a migration's SQL and record it in the migrations table.

        NOTE(review): sqlite3's ``executescript`` implicitly commits any
        pending transaction before running (legacy autocommit behavior), so
        the surrounding tx() scope may not cover the script body — confirm
        this is acceptable for the migration workflow.
        """
        cur.executescript(migration["sql"])
        cur.execute(
            """
            INSERT INTO migrations (dir, version, filename, sql, hash)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                migration["dir"],
                migration["version"],
                migration["filename"],
                migration["sql"],
                migration["hash"],
            ),
        )

    @staticmethod
    def uuid_from_db(value: Optional[Any]) -> Optional[UUID]:
        """Convert a DB string value back into a UUID; None passes through."""
        return UUID(value) if value is not None else None

    @staticmethod
    def uuid_to_db(uuid: Optional[UUID]) -> Optional[Any]:
        """Convert a UUID to its DB string representation; None passes through."""
        return str(uuid) if uuid is not None else None

    @staticmethod
    def unique_constraint_error() -> Type[BaseException]:
        """Return the exception type raised on unique-constraint violations."""
        return sqlite3.IntegrityError