Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| from mistralai import Mistral | |
| import numpy as np | |
| import time | |
| import chromadb | |
| from chromadb.config import Settings | |
| import json | |
| import hashlib | |
# Load variables from a .env file into the process environment so that the
# os.getenv() lookup below can see them.
load_dotenv()

# API key handed to the Mistral client; None when the variable is not set.
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
# Default collection name used by the functions in this module.
COLLECTION_NAME = "webpages_collection"
# Path handed to chromadb.PersistentClient by the functions in this module.
PERSIST_DIRECTORY = "./chroma_db"
def vectorize(input_texts, batch_size=5, max_retries=None):
    """
    Get the text embeddings for the given inputs using Mistral API.

    The inputs are sent to the ``mistral-embed`` model in batches of
    ``batch_size``. A batch rejected with a "rate limit exceeded" error is
    retried after a 10 second pause; any other error is printed and re-raised.

    :param input_texts: A sequence of strings to embed.
    :param batch_size: Number of texts sent per request (must be at least 1).
    :param max_retries: Maximum number of rate-limit retries per batch.
        ``None`` (the default) keeps retrying indefinitely.
    :return: A list with one embedding per input text, in input order. An
        empty list is returned when there is nothing to embed or when the
        Mistral client cannot be initialized.
    :raises ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently drop every input.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    # Nothing to embed: skip client construction and the network entirely.
    if len(input_texts) == 0:
        return []

    try:
        client = Mistral(api_key=MISTRAL_API_KEY)
    except Exception as e:
        print(f"Error initializing Mistral client: {e}")
        return []

    embeddings = []
    for i in range(0, len(input_texts), batch_size):
        batch = input_texts[i:i + batch_size]
        retries = 0
        while True:
            try:
                response = client.embeddings.create(
                    model="mistral-embed",
                    inputs=batch,
                )
            except Exception as e:
                rate_limited = "rate limit exceeded" in str(e).lower()
                may_retry = max_retries is None or retries < max_retries
                if rate_limited and may_retry:
                    retries += 1
                    print("Rate limit exceeded. Retrying after 10 seconds...")
                    time.sleep(10)
                    continue
                print(f"Error in embedding batch: {e}")
                raise
            # Short pause after every successful request to throttle the
            # request rate.
            time.sleep(1)
            embeddings.extend(data.embedding for data in response.data)
            break
    return embeddings
def chunk_content(markdown_content, chunk_size=2048):
    """
    Split the given markdown content into chunks of roughly ``chunk_size``
    characters without cutting sentences.

    Every chunk is extended past ``chunk_size`` up to the end of the next
    sentence (a run of '.', '!' or '?'), so a chunk may be longer than
    ``chunk_size``. When no sentence end follows, the chunk is cut at exactly
    ``chunk_size`` characters. Chunks are stripped of surrounding whitespace
    and chunks that end up empty are dropped.

    :param markdown_content: The text to split.
    :param chunk_size: Target chunk length in characters (must be at least 1).
    :return: A list of non-empty chunk strings, in document order.
    :raises ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        # A non-positive size cannot guarantee forward progress through the
        # text, so the splitting loop might never terminate.
        raise ValueError(f"chunk_size must be at least 1, got {chunk_size}")

    punctuations = {'.', '!', '?'}
    length = len(markdown_content)

    def find_sentence_end(text, start):
        """Return the index just past the first punctuation run at or after
        ``start``, or ``start`` itself when no punctuation follows."""
        end = start
        while end < length and text[end] not in punctuations:
            end += 1
        if end == length:
            # No terminator ahead: cut at the requested position instead.
            return start
        # Swallow the whole run so "..." or "?!" stays with its sentence.
        while end < length and text[end] in punctuations:
            end += 1
        return end

    chunks = []
    start = 0
    while start < length:
        target = min(start + chunk_size, length)
        end = find_sentence_end(markdown_content, target)
        chunk = markdown_content[start:end].strip()
        # Whitespace-only spans carry no content worth storing or embedding.
        if chunk:
            chunks.append(chunk)
        start = end
    return chunks
def generate_chunk_id(chunk):
    """Return the hexadecimal SHA-256 digest of the chunk's UTF-8 bytes."""
    hasher = hashlib.sha256()
    hasher.update(chunk.encode('utf-8'))
    return hasher.hexdigest()
def load_in_vector_db(markdown_content, metadatas=None, collection_name=COLLECTION_NAME):
    """
    Load the text embeddings into a ChromaDB collection for efficient similarity search.

    The content is split with ``chunk_content`` and every chunk is identified
    by ``generate_chunk_id``. Chunks whose ID is already stored in the
    collection are skipped; the remaining ones are embedded with ``vectorize``
    and added one by one, so a failure on one chunk does not stop the others.

    :param markdown_content: The text to split, embed and store.
    :param metadatas: A single metadata value attached to every stored chunk
        (optional).
    :param collection_name: The collection to write to; created if missing.
    :return: None. ChromaDB errors are printed and end the call early;
        exceptions raised by ``vectorize`` propagate to the caller.
    """
    try:
        client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    except Exception as e:
        print(f"Error initializing ChromaDB client: {e}")
        return

    try:
        # One call instead of list-then-create: no dependence on what
        # list_collections() returns and no window between check and create.
        collection = client.get_or_create_collection(collection_name)
    except Exception as e:
        print(f"Error accessing collection: {e}")
        return

    # Map ID -> chunk. This hashes each chunk once, drops empty chunks and
    # collapses repeated chunks so they are embedded only once.
    candidates = {}
    for chunk in chunk_content(markdown_content):
        if chunk:
            candidates.setdefault(generate_chunk_id(chunk), chunk)

    existing_ids = set()
    if candidates:
        try:
            # Ask only for the candidate IDs and no payload fields; the
            # stored documents and metadata are not needed for this check.
            existing_items = collection.get(ids=list(candidates), include=[])
        except Exception as e:
            print(f"Error retrieving existing items: {e}")
            return
        if 'ids' in existing_items:
            existing_ids.update(existing_items['ids'])

    new_ids = [chunk_id for chunk_id in candidates if chunk_id not in existing_ids]
    text_to_vectorize = [candidates[chunk_id] for chunk_id in new_ids]
    print(f"New chunks to vectorize: {len(text_to_vectorize)}")
    if not text_to_vectorize:
        return

    embeddings = vectorize(text_to_vectorize)
    if len(embeddings) != len(text_to_vectorize):
        # zip() below stops at the shorter list; make the shortfall visible.
        print(
            f"Warning: expected {len(text_to_vectorize)} embeddings but received "
            f"{len(embeddings)}; some chunks will not be stored."
        )

    for chunk_id, chunk, embedding in zip(new_ids, text_to_vectorize, embeddings):
        try:
            collection.add(
                embeddings=[embedding],
                documents=[chunk],
                metadatas=[metadatas],
                ids=[chunk_id],
            )
        except Exception as e:
            print(f"Error adding embedding to collection: {e}")
def retrieve_from_database(query, collection_name=COLLECTION_NAME, n_results=5, distance_threshold=None):
    """
    Look up the stored documents that are closest to a text query.

    :param query: The text to embed and search with.
    :param collection_name: The collection to search within.
    :param n_results: The number of nearest results requested from the collection.
    :param distance_threshold: When given, only hits whose distance is less
        than or equal to this value are kept (optional).
    :return: Without a threshold, the raw query result. With a threshold, a
        dict with flat "ids", "distances", "metadatas" and "documents" lists,
        or a message string when no hit passes the threshold. ``None`` when
        accessing the collection, embedding the query or querying fails (the
        error is printed).
    """
    try:
        db_client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
        collection = db_client.get_collection(collection_name)
    except Exception as e:
        print(f"Error accessing collection: {e}")
        return None

    try:
        query_vectors = vectorize([query])
    except Exception as e:
        print(f"Error vectorizing query: {e}")
        return None

    try:
        raw_results = collection.query(
            query_embeddings=query_vectors,
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )
    except Exception as e:
        print(f"Error querying collection: {e}")
        return None

    # No threshold requested: hand back the untouched query result.
    if distance_threshold is None:
        return raw_results

    kept = {"ids": [], "distances": [], "metadatas": [], "documents": []}
    # Index [0] selects the result lists of the single query that was sent.
    for position, distance in enumerate(raw_results['distances'][0]):
        if not (distance <= distance_threshold):
            continue
        kept['ids'].append(raw_results['ids'][0][position])
        kept['distances'].append(distance)
        kept['metadatas'].append(raw_results['metadatas'][0][position])
        kept['documents'].append(raw_results['documents'][0][position])

    if not kept['documents']:
        return "No relevant data found in the knowledge database. Have you checked any webpages? If so, please try to find more relevant data."
    return kept
def search_documents(collection_name=COLLECTION_NAME, query=None, query_embedding=None, metadata_filter=None, n_results=10):
    """
    Search for documents in a ChromaDB collection.

    A non-empty ``query`` is embedded with ``vectorize`` and takes precedence
    over ``query_embedding``. When an embedding is available a similarity
    query is run; otherwise documents are fetched by metadata filter only.

    :param collection_name: The name of the collection to search within.
    :param query: The text query to search for (optional).
    :param query_embedding: The embedding query to search for (optional).
    :param metadata_filter: A filter to apply to the metadata (optional).
    :param n_results: The number of results to return (default is 10).
    :return: The search results.
    :raises IndexError: If ``vectorize`` returns no embedding for ``query``.
    """
    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    collection = client.get_collection(collection_name)

    # Treat an empty filter the same as no filter at all.
    where = metadata_filter or None

    if query:
        query_vectors = vectorize([query])
        if not query_vectors:
            # Same exception type the bare ``[0]`` lookup used to raise, but
            # with a message that explains what went wrong.
            raise IndexError("vectorize() returned no embedding for the query")
        query_embedding = query_vectors[0]

    # Use an explicit None/length check: the truth value of a multi-element
    # numpy array is ambiguous and raises ValueError.
    if query_embedding is not None and len(query_embedding) > 0:
        return collection.query(
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
        )
    return collection.get(where=where, limit=n_results)
def delete_documents(collection_name=COLLECTION_NAME, ids=None):
    """
    Delete documents from a ChromaDB collection based on their IDs.

    When ``ids`` is ``None`` or empty nothing is deleted and the database is
    not opened.

    :param collection_name: The name of the collection.
    :param ids: A list of IDs of the documents to delete.
    """
    if not ids:
        # Never issue a delete without an explicit selection of documents.
        # NOTE(review): what collection.delete() does with no ids and no
        # filter appears to depend on the Chroma version - confirm.
        print("No document IDs provided; nothing was deleted.")
        return

    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
    collection = client.get_collection(collection_name)
    collection.delete(ids=ids)
    print(f"Documents with IDs {ids} have been deleted from the collection {collection_name}.")
def delete_collection(collection_name=COLLECTION_NAME):
    """
    Remove a whole collection from the persistent ChromaDB store.

    :param collection_name: The name of the collection to delete.
    """
    chromadb.PersistentClient(path=PERSIST_DIRECTORY).delete_collection(collection_name)
    print(f"Collection {collection_name} has been deleted.")