MahatirTusher committed on
Commit
1d53a6e
·
verified ·
1 Parent(s): 76685ff

Update chroma_db_utils.py

Browse files
Files changed (1) hide show
  1. chroma_db_utils.py +25 -13
chroma_db_utils.py CHANGED
@@ -35,7 +35,7 @@ embedding_function = MistralEmbeddingFunction()
35
 
36
  def create_chroma_db(documents: List[str]):
37
  """
38
- Creates a persistent Chroma database using the provided documents.
39
  """
40
  # Create a persistent directory for ChromaDB
41
  persist_directory = "chroma_db"
@@ -46,22 +46,20 @@ def create_chroma_db(documents: List[str]):
46
  path=persist_directory,
47
  )
48
 
49
- # Get or create collection
50
  try:
51
- # Try to get existing collection
52
- db = chroma_client.get_collection(
53
  name="document_collection",
54
  embedding_function=embedding_function
55
  )
56
- # Clear existing documents
57
- db.delete(db.get()["ids"])
 
 
 
58
  except Exception as e:
59
- print(f"Error getting collection: {e}. Creating a new collection...")
60
- # Create new collection if it doesn't exist
61
- db = chroma_client.create_collection(
62
- name="document_collection",
63
- embedding_function=embedding_function
64
- )
65
 
66
  # Add documents in batches to avoid memory issues
67
  batch_size = 20
@@ -72,9 +70,11 @@ def create_chroma_db(documents: List[str]):
72
  documents=batch,
73
  ids=[f"doc_{j}" for j in range(i, i + len(batch))]
74
  )
 
75
  except Exception as e:
76
  print(f"Error adding batch {i} to ChromaDB: {e}")
77
 
 
78
  return db
79
 
80
  def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
@@ -108,4 +108,16 @@ def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
108
  return documents # Return only valid results
109
  except Exception as e:
110
  print(f"Error in get_relevant_passage: {str(e)}")
111
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def create_chroma_db(documents: List[str]):
37
  """
38
+ Creates or updates a persistent Chroma database using the provided documents.
39
  """
40
  # Create a persistent directory for ChromaDB
41
  persist_directory = "chroma_db"
 
46
  path=persist_directory,
47
  )
48
 
49
+ # Use get_or_create_collection to avoid UniqueConstraintError
50
  try:
51
+ db = chroma_client.get_or_create_collection(
 
52
  name="document_collection",
53
  embedding_function=embedding_function
54
  )
55
+ # Optionally clear existing documents if you want a fresh start
56
+ existing_ids = db.get()["ids"]
57
+ if existing_ids:
58
+ print(f"Clearing {len(existing_ids)} existing documents from collection...")
59
+ db.delete(ids=existing_ids)
60
  except Exception as e:
61
+ print(f"Error accessing or creating collection: {e}")
62
+ raise # Re-raise to halt execution if something goes wrong
 
 
 
 
63
 
64
  # Add documents in batches to avoid memory issues
65
  batch_size = 20
 
70
  documents=batch,
71
  ids=[f"doc_{j}" for j in range(i, i + len(batch))]
72
  )
73
+ print(f"Added batch {i} to {i + len(batch) - 1} successfully.")
74
  except Exception as e:
75
  print(f"Error adding batch {i} to ChromaDB: {e}")
76
 
77
+ print(f"ChromaDB collection 'document_collection' created/updated with {db.count()} documents.")
78
  return db
79
 
80
  def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
 
108
  return documents # Return only valid results
109
  except Exception as e:
110
  print(f"Error in get_relevant_passage: {str(e)}")
111
+ return []
112
+
113
+ # Example usage (runs only when this file is executed directly as a script)
114
+ if __name__ == "__main__":
115
+ sample_docs = [
116
+ "The quick brown fox jumps over the lazy dog.",
117
+ "Artificial intelligence is transforming the world.",
118
+ "ChromaDB is a vector database for embeddings."
119
+ ]
120
+ db = create_chroma_db(sample_docs)
121
+ query = "What is AI doing to the world?"
122
+ passages = get_relevant_passage(query, db)
123
+ print("Retrieved passages:", passages)