Spaces:
Sleeping
Sleeping
Update chroma_db_utils.py
Browse files- chroma_db_utils.py +25 -13
chroma_db_utils.py
CHANGED
|
@@ -35,7 +35,7 @@ embedding_function = MistralEmbeddingFunction()
|
|
| 35 |
|
| 36 |
def create_chroma_db(documents: List[str]):
|
| 37 |
"""
|
| 38 |
-
Creates a persistent Chroma database using the provided documents.
|
| 39 |
"""
|
| 40 |
# Create a persistent directory for ChromaDB
|
| 41 |
persist_directory = "chroma_db"
|
|
@@ -46,22 +46,20 @@ def create_chroma_db(documents: List[str]):
|
|
| 46 |
path=persist_directory,
|
| 47 |
)
|
| 48 |
|
| 49 |
-
#
|
| 50 |
try:
|
| 51 |
-
|
| 52 |
-
db = chroma_client.get_collection(
|
| 53 |
name="document_collection",
|
| 54 |
embedding_function=embedding_function
|
| 55 |
)
|
| 56 |
-
#
|
| 57 |
-
db.
|
|
|
|
|
|
|
|
|
|
| 58 |
except Exception as e:
|
| 59 |
-
print(f"Error
|
| 60 |
-
#
|
| 61 |
-
db = chroma_client.create_collection(
|
| 62 |
-
name="document_collection",
|
| 63 |
-
embedding_function=embedding_function
|
| 64 |
-
)
|
| 65 |
|
| 66 |
# Add documents in batches to avoid memory issues
|
| 67 |
batch_size = 20
|
|
@@ -72,9 +70,11 @@ def create_chroma_db(documents: List[str]):
|
|
| 72 |
documents=batch,
|
| 73 |
ids=[f"doc_{j}" for j in range(i, i + len(batch))]
|
| 74 |
)
|
|
|
|
| 75 |
except Exception as e:
|
| 76 |
print(f"Error adding batch {i} to ChromaDB: {e}")
|
| 77 |
|
|
|
|
| 78 |
return db
|
| 79 |
|
| 80 |
def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
|
@@ -108,4 +108,16 @@ def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
|
| 108 |
return documents # Return only valid results
|
| 109 |
except Exception as e:
|
| 110 |
print(f"Error in get_relevant_passage: {str(e)}")
|
| 111 |
-
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def create_chroma_db(documents: List[str]):
|
| 37 |
"""
|
| 38 |
+
Creates or updates a persistent Chroma database using the provided documents.
|
| 39 |
"""
|
| 40 |
# Create a persistent directory for ChromaDB
|
| 41 |
persist_directory = "chroma_db"
|
|
|
|
| 46 |
path=persist_directory,
|
| 47 |
)
|
| 48 |
|
| 49 |
+
# Use get_or_create_collection to avoid UniqueConstraintError
|
| 50 |
try:
|
| 51 |
+
db = chroma_client.get_or_create_collection(
|
|
|
|
| 52 |
name="document_collection",
|
| 53 |
embedding_function=embedding_function
|
| 54 |
)
|
| 55 |
+
# Optionally clear existing documents if you want a fresh start
|
| 56 |
+
existing_ids = db.get()["ids"]
|
| 57 |
+
if existing_ids:
|
| 58 |
+
print(f"Clearing {len(existing_ids)} existing documents from collection...")
|
| 59 |
+
db.delete(ids=existing_ids)
|
| 60 |
except Exception as e:
|
| 61 |
+
print(f"Error accessing or creating collection: {e}")
|
| 62 |
+
raise # Re-raise to halt execution if something goes wrong
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# Add documents in batches to avoid memory issues
|
| 65 |
batch_size = 20
|
|
|
|
| 70 |
documents=batch,
|
| 71 |
ids=[f"doc_{j}" for j in range(i, i + len(batch))]
|
| 72 |
)
|
| 73 |
+
print(f"Added batch {i} to {i + len(batch) - 1} successfully.")
|
| 74 |
except Exception as e:
|
| 75 |
print(f"Error adding batch {i} to ChromaDB: {e}")
|
| 76 |
|
| 77 |
+
print(f"ChromaDB collection 'document_collection' created/updated with {db.count()} documents.")
|
| 78 |
return db
|
| 79 |
|
| 80 |
def get_relevant_passage(query: str, db, n_results: int = 5) -> List[str]:
|
|
|
|
| 108 |
return documents # Return only valid results
|
| 109 |
except Exception as e:
|
| 110 |
print(f"Error in get_relevant_passage: {str(e)}")
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
# Example usage (uncomment to test)
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
sample_docs = [
|
| 116 |
+
"The quick brown fox jumps over the lazy dog.",
|
| 117 |
+
"Artificial intelligence is transforming the world.",
|
| 118 |
+
"ChromaDB is a vector database for embeddings."
|
| 119 |
+
]
|
| 120 |
+
db = create_chroma_db(sample_docs)
|
| 121 |
+
query = "What is AI doing to the world?"
|
| 122 |
+
passages = get_relevant_passage(query, db)
|
| 123 |
+
print("Retrieved passages:", passages)
|