Aryan Jain committed on
Commit
e9024b0
·
1 Parent(s): 6aaeb0e

feat: add ChromaClient for managing embeddings with ChromaDB

Browse files

- Updated pyproject.toml to include chromadb as a dependency.
- Implemented ChromaClient class for handling text embeddings and interactions with ChromaDB.
- Added methods for upserting texts and querying embeddings.
- Integrated OpenAI's embedding model for generating text embeddings.

Files changed (3) hide show
  1. poetry.lock +0 -0
  2. pyproject.toml +2 -1
  3. src/utils/_chroma_client.py +102 -0
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -16,7 +16,8 @@ dependencies = [
16
  "alembic (>=1.18.4,<2.0.0)",
17
  "aiosqlite (>=0.22.1,<0.23.0)",
18
  "pydantic-ai-slim (>=1.67.0,<2.0.0)",
19
- "openai (>=2.26.0,<3.0.0)"
 
20
  ]
21
 
22
 
 
16
  "alembic (>=1.18.4,<2.0.0)",
17
  "aiosqlite (>=0.22.1,<0.23.0)",
18
  "pydantic-ai-slim (>=1.67.0,<2.0.0)",
19
+ "openai (>=2.26.0,<3.0.0)",
20
+ "chromadb (>=1.5.5,<2.0.0)"
21
  ]
22
 
23
 
src/utils/_chroma_client.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import uuid
3
+ import os
4
+ import chromadb
5
+ from chromadb.config import Settings
6
+ from openai import AsyncOpenAI
7
+
8
+
9
+ class ChromaClient:
10
+ def __init__(self):
11
+ self.collection_name = os.getenv("CHROMA_COLLECTION_NAME", "default_collection")
12
+ self.persist_directory = os.getenv("CHROMA_PERSIST_DIRECTORY", "./chroma_db")
13
+ self.use_persistent = (
14
+ os.getenv("CHROMA_USE_PERSISTENT", "true").lower() == "true"
15
+ )
16
+ self.chroma_host = os.getenv("CHROMA_HOST", "localhost")
17
+ self.chroma_port = int(os.getenv("CHROMA_PORT", "8000"))
18
+ self.embedding_model = os.getenv(
19
+ "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
20
+ )
21
+
22
+ self.openai = AsyncOpenAI()
23
+ self.client = None
24
+ self.collection = None
25
+
26
+ async def __aenter__(self):
27
+ if self.use_persistent:
28
+ self.client = chromadb.PersistentClient(path=self.persist_directory)
29
+ else:
30
+ self.client = chromadb.HttpClient(
31
+ host=self.chroma_host,
32
+ port=self.chroma_port,
33
+ settings=Settings(anonymized_telemetry=False),
34
+ )
35
+
36
+ self.collection = self.client.get_or_create_collection(
37
+ name=self.collection_name,
38
+ metadata={"hnsw:space": "cosine"},
39
+ )
40
+ return self
41
+
42
+ async def __aexit__(self, exc_type, exc_value, traceback):
43
+ self.client = None
44
+ self.collection = None
45
+
46
+ async def _get_text_embedding(self, text: str) -> list[float]:
47
+ response = await self.openai.embeddings.create(
48
+ input=text,
49
+ model=self.embedding_model,
50
+ )
51
+ return response.data[0].embedding
52
+
53
+ async def _get_batch_embeddings(self, texts: list[str]) -> list[list[float]]:
54
+ response = await self.openai.embeddings.create(
55
+ input=texts,
56
+ model=self.embedding_model,
57
+ )
58
+ return [item.embedding for item in response.data]
59
+
60
+ async def upsert(self, texts: list[str], metadatas: list[dict] = None):
61
+ if not texts:
62
+ return
63
+
64
+ if metadatas is None:
65
+ metadatas = [{} for _ in texts]
66
+
67
+ if len(texts) != len(metadatas):
68
+ raise ValueError("texts and metadatas must have the same length")
69
+
70
+ ids = [meta.pop("id", str(uuid.uuid4())) for meta in metadatas]
71
+ embeddings = await self._get_batch_embeddings(texts)
72
+
73
+ loop = asyncio.get_event_loop()
74
+ await loop.run_in_executor(
75
+ None,
76
+ lambda: self.collection.upsert(
77
+ ids=ids,
78
+ embeddings=embeddings,
79
+ documents=texts,
80
+ metadatas=metadatas,
81
+ ),
82
+ )
83
+
84
+ async def query(self, query: str, n_results: int = 5) -> dict:
85
+ query_embedding = await self._get_text_embedding(query)
86
+
87
+ loop = asyncio.get_event_loop()
88
+ results = await loop.run_in_executor(
89
+ None,
90
+ lambda: self.collection.query(
91
+ query_embeddings=[query_embedding],
92
+ n_results=n_results,
93
+ include=["documents", "metadatas", "distances"],
94
+ ),
95
+ )
96
+
97
+ return {
98
+ "ids": results["ids"][0],
99
+ "documents": results["documents"][0],
100
+ "metadatas": results["metadatas"][0],
101
+ "distances": results["distances"][0],
102
+ }