oluinioluwa814 committed on
Commit
c697c36
·
verified ·
1 Parent(s): b434dd2

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +31 -32
rag_pipeline.py CHANGED
@@ -1,57 +1,56 @@
1
  from sentence_transformers import SentenceTransformer
2
- import chromadb
3
- from chromadb.config import Settings
4
  import uuid
5
 
6
  class RAGPipeline:
7
  """
8
- Retrieval-Augmented Generation (RAG) pipeline:
9
- - Stores documents in a vector database (ChromaDB)
10
- - Generates embeddings using SentenceTransformer
11
- - Retrieves relevant documents for queries
12
  """
13
 
14
  def __init__(self, db_dir: str = "./chroma_store"):
15
- # Initialize the embedding model
16
  self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
17
 
18
- # Initialize ChromaDB client
19
- self.client = chromadb.Client(
20
- Settings(chroma_db_impl="duckdb+parquet", persist_directory=db_dir)
21
- )
22
-
23
- # Create or get a collection for storing documents
24
- self.collection = self.client.get_or_create_collection(name="rag_collection")
 
 
 
 
 
 
 
25
 
26
  def add_document(self, text: str, doc_id: str = None):
27
- """
28
- Adds a single document to the vector database.
29
- """
30
  if not doc_id:
31
  doc_id = str(uuid.uuid4())
32
-
33
- embedding = self.embedder.encode(text).tolist()
34
- self.collection.add(documents=[text], ids=[doc_id], embeddings=[embedding])
 
35
 
36
  def retrieve(self, query: str, top_k: int = 3):
37
- """
38
- Retrieves the top_k most relevant documents for a query.
39
- """
40
- q_embedding = self.embedder.encode(query).tolist()
41
  results = self.collection.query(
42
- query_embeddings=[q_embedding],
43
  n_results=top_k
44
  )
45
-
46
- # Return list of documents
47
  return results.get("documents", [[]])[0]
48
 
49
  def reset_vector_store(self):
50
- """
51
- Clears all documents from the collection.
52
- """
53
  try:
54
- self.client.delete_collection("rag_collection")
55
  except Exception:
56
  pass
57
- self.collection = self.client.get_or_create_collection(name="rag_collection")
 
 
 
 
 
 
1
  from sentence_transformers import SentenceTransformer
2
+ from chromadb import Client
3
+ from chromadb.utils import embedding_functions
4
  import uuid
5
 
6
  class RAGPipeline:
7
  """
8
+ RAG pipeline using ChromaDB new client API.
 
 
 
9
  """
10
 
11
  def __init__(self, db_dir: str = "./chroma_store"):
12
+ # Embedding model
13
  self.embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
14
 
15
+ # ChromaDB new client
16
+ self.client = Client() # No Settings() needed in v0.5+
17
+ self.collection_name = "rag_collection"
18
+
19
+ # Check if collection exists; create if not
20
+ if self.collection_name in [c.name for c in self.client.list_collections()]:
21
+ self.collection = self.client.get_collection(self.collection_name)
22
+ else:
23
+ self.collection = self.client.create_collection(
24
+ name=self.collection_name,
25
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
26
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
27
+ )
28
+ )
29
 
30
  def add_document(self, text: str, doc_id: str = None):
 
 
 
31
  if not doc_id:
32
  doc_id = str(uuid.uuid4())
33
+ self.collection.add(
34
+ documents=[text],
35
+ ids=[doc_id],
36
+ )
37
 
38
  def retrieve(self, query: str, top_k: int = 3):
 
 
 
 
39
  results = self.collection.query(
40
+ query_texts=[query],
41
  n_results=top_k
42
  )
 
 
43
  return results.get("documents", [[]])[0]
44
 
45
  def reset_vector_store(self):
46
+ # Delete and recreate the collection
 
 
47
  try:
48
+ self.client.delete_collection(self.collection_name)
49
  except Exception:
50
  pass
51
+ self.collection = self.client.create_collection(
52
+ name=self.collection_name,
53
+ embedding_function=embedding_functions.SentenceTransformerEmbeddingFunction(
54
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
55
+ )
56
+ )