"""
Example script demonstrating how to use the RAG system programmatically.
"""
import os

# NOTE(review): `settings` is not referenced anywhere in this script; it is
# presumably imported for a side effect such as loading .env — confirm.
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator


def main():
    """Run the end-to-end RAG example, printing progress to stdout.

    The steps are: load a small "hotpotqa" sample, index it into a vector
    store collection, create the Groq LLM client and the RAG pipeline, run
    three demo queries, evaluate five test samples with the TRACE
    evaluator, and finally report the chat history length.

    Returns early (after printing a hint) when the GROQ_API_KEY environment
    variable is unset, empty, or still equal to the placeholder value.
    """
    # An unset or empty variable falls back to the placeholder, so a single
    # comparison covers "missing" as well as "never replaced".
    placeholder_key = "your_api_key_here"
    api_key = os.getenv("GROQ_API_KEY") or placeholder_key
    if api_key == placeholder_key:
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    banner = "=" * 50
    print(banner)
    print("RAG System Example")
    print(banner)

    # 1. Load dataset
    print("\n1. Loading dataset...")
    bench_loader = RAGBenchLoader()
    samples = bench_loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(samples)} samples")

    # 2. Create vector store and collection
    print("\n2. Creating vector store...")
    store = ChromaDBManager()
    strategy = "hybrid"
    print(f"Loading data into collection with {strategy} chunking...")
    store.load_dataset_into_collection(
        collection_name="example_collection",
        embedding_model_name="emilyalsentzer/Bio_ClinicalBERT",
        chunking_strategy=strategy,
        dataset_data=samples,
        chunk_size=512,
        overlap=50,
    )

    # 3. Initialize LLM client
    print("\n3. Initializing LLM client...")
    llm = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,
        rate_limit_delay=2.0,
    )

    # 4. Create RAG pipeline
    print("\n4. Creating RAG pipeline...")
    pipeline = RAGPipeline(llm, store)

    # 5. Query the system
    print("\n5. Querying the system...")
    demo_questions = (
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?",
    )
    for query_no, question in enumerate(demo_questions, 1):
        print(f"\n--- Query {query_no}: {question} ---")
        answer = pipeline.query(question, n_results=3)
        print(f"Response: {answer['response']}")

        hits = answer['retrieved_documents']
        print(f"\nRetrieved {len(hits)} documents:")
        for hit_no, hit in enumerate(hits, 1):
            print(f"\nDocument {hit_no} (Distance: {hit.get('distance', 'N/A')}):")
            # Only the first 200 characters are shown as a preview.
            print(f"{hit['document'][:200]}...")

    # 6. Run evaluation
    print("\n6. Running TRACE evaluation...")
    trace = TRACEEvaluator(llm)

    # Prepare test cases: one pipeline query per held-out sample, kept as an
    # explicit loop so the queries run in sample order.
    cases = []
    for sample in bench_loader.get_test_data("hotpotqa", num_samples=5):
        answer = pipeline.query(sample["question"], n_results=5)
        case = {
            "query": sample["question"],
            "response": answer["response"],
            "retrieved_documents": [
                hit["document"] for hit in answer["retrieved_documents"]
            ],
            "ground_truth": sample.get("answer", ""),
        }
        cases.append(case)

    scores = trace.evaluate_batch(cases)

    print("\nTRACE Evaluation Results:")
    metric_rows = (
        ("Utilization", "utilization"),
        ("Relevance", "relevance"),
        ("Adherence", "adherence"),
        ("Completeness", "completeness"),
        ("Average", "average"),
    )
    for label, key in metric_rows:
        print(f"{label}: {scores[key]:.3f}")

    # 7. View chat history
    print("\n7. Chat History:")
    conversations = pipeline.get_chat_history()
    print(f"Total conversations: {len(conversations)}")

    print("\n" + banner)
    print("Example completed successfully!")
    print(banner)


if __name__ == "__main__":
    main()