# Hugging Face Spaces page header ("Spaces: Sleeping") — scrape artifact, not code.
import os

import chromadb
from datasets import load_dataset
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.chroma import ChromaVectorStore
from tqdm import tqdm

# GAIA is a gated dataset; define the ID once so the call below and the
# error message in the except branch can never disagree.
dataset_id = "gaia-benchmark/GAIA"

# Smoke-test that the dataset is reachable before doing any real work.
try:
    gaia_dataset = load_dataset(dataset_id, '2023_all')
    print("Dataset loaded successfully!")
    print(gaia_dataset)
except Exception as e:
    print(f"Error loading the dataset: {e}")
    print("Make sure to:")
    print("1. Have logged in with 'huggingface-cli login'.")
    print("2. Have been granted access to the dataset on Hugging Face.")
    print(f"3. That the dataset ID '{dataset_id}' is correct.")
    print("4. That you've added 'trust_remote_code=True' to the load_dataset call.")
def setup_chroma_db():
    """Configure ChromaDB as the vector database.

    Creates (or reuses) a persistent ChromaDB store under ./chroma_db,
    instantiates a HuggingFace embedding model, and wraps the
    "gaia_examples" collection in a LlamaIndex vector-store index.

    Returns:
        tuple: (chromadb PersistentClient, chroma collection,
        ChromaVectorStore, VectorStoreIndex).
    """
    # Persist the database under the current working directory.
    db_path = os.path.join(os.getcwd(), "chroma_db")
    os.makedirs(db_path, exist_ok=True)

    db = chromadb.PersistentClient(path=db_path)

    # Embedding model used by the LlamaIndex index below.
    # NOTE(review): collection.add() in load_gaia_to_chroma() relies on
    # Chroma's default embedding function, not this model — confirm the
    # two embedding spaces are intentionally different.
    embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")

    # Reuse the collection if it exists, otherwise create it. The except
    # clause is narrowed from a bare `except:` so KeyboardInterrupt and
    # SystemExit are no longer swallowed.
    try:
        chroma_collection = db.get_collection("gaia_examples")
        print("Existing collection found...")
    except Exception:
        chroma_collection = db.create_collection(
            "gaia_examples",
            metadata={"description": "GAIA benchmark examples for agent training"}
        )
        print("New collection created...")

    # Bridge the Chroma collection into LlamaIndex.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(
        vector_store=vector_store,
        embed_model=embed_model
    )
    return db, chroma_collection, vector_store, index
def load_gaia_to_chroma():
    """Load the GAIA validation and test splits into ChromaDB.

    Each example becomes one document whose text combines the question,
    level, annotator steps/tools, and final answer; documents are added
    in batches so Chroma computes embeddings a chunk at a time.
    """
    batch_size = 50  # flush to Chroma every this many examples

    # Only the collection is used for ingestion; the client, vector store
    # and index returned by setup are not needed here.
    _, collection, _, _ = setup_chroma_db()

    print("Loading GAIA dataset...")
    gaia_dataset = load_dataset('gaia-benchmark/GAIA', '2023_all')

    total_examples = len(gaia_dataset['validation']) + len(gaia_dataset['test'])
    print(f"Total examples to process: {total_examples}")

    # Running count of examples actually written to Chroma.
    example_counter = 0

    for split in ['validation', 'test']:
        print(f"\nProcessing {split} set...")
        split_examples = gaia_dataset[split]
        last_idx = len(split_examples) - 1  # hoisted: don't re-len() per item

        # Batch buffers, flushed together via collection.add().
        ids = []
        texts = []
        metadatas = []

        for idx, example in enumerate(tqdm(split_examples, desc=f"Processing {split}")):
            # Unique, split-qualified document ID.
            doc_id = f"{split}_{example['task_id']}"

            # Text that will be embedded and stored.
            text_content = f"""
            Question: {example['Question']}
            Level: {example['Level']}
            Steps to solve:
            {example['Annotator Metadata']['Steps']}
            Tools used:
            {example['Annotator Metadata']['Tools']}
            Final Answer: {example['Final answer']}
            """

            metadata = {
                "task_id": example['task_id'],
                "level": example['Level'],
                "type": "gaia_example",
                "split": split
            }

            ids.append(doc_id)
            texts.append(text_content)
            metadatas.append(metadata)

            # Flush a full batch, or whatever remains on the last example.
            if len(ids) >= batch_size or idx == last_idx:
                # Chroma computes embeddings with its default embedding
                # function when none is supplied here.
                collection.add(
                    ids=ids,
                    documents=texts,
                    metadatas=metadatas
                )
                print(f"Batch of {len(ids)} examples loaded...")
                example_counter += len(ids)
                ids = []
                texts = []
                metadatas = []

    print(f"\nLoad complete. {example_counter} examples stored in ChromaDB.")
    print(f"Data is saved at: {os.path.join(os.getcwd(), 'chroma_db')}")
def test_chroma_search():
    """Run one sample semantic query against the GAIA collection and print the hits."""
    _, collection, _, index = setup_chroma_db()

    # A GAIA-flavoured question to exercise the similarity search.
    sample_question = "What is the last word before the second chorus of a famous song?"

    # Restrict the search to documents ingested by load_gaia_to_chroma().
    results = collection.query(
        query_texts=[sample_question],
        n_results=2,
        where={"type": "gaia_example"}
    )

    print("\n=== Example search results ===")
    hits = zip(results["ids"][0], results["metadatas"][0], results["documents"][0])
    for rank, (hit_id, hit_meta, hit_doc) in enumerate(hits, start=1):
        print(f"\nResult #{rank}:")
        print(f"ID: {hit_id}")
        print(f"Metadata: {hit_meta}")
        print(f"Content: {hit_doc[:200]}...")  # preview: first 200 characters
    print("\n=== End of results ===")
def _main():
    """Ingest GAIA into ChromaDB, then exercise a sample search."""
    print("Starting GAIA data load to ChromaDB...")
    load_gaia_to_chroma()
    print("\nTesting search...")
    test_chroma_search()


# Run the process only when executed as a script.
if __name__ == "__main__":
    _main()