Spaces:
Running
Running
| """ | |
| Example script demonstrating how to use the RAG system programmatically. | |
| """ | |
| import os | |
| from config import settings | |
| from dataset_loader import RAGBenchLoader | |
| from vector_store import ChromaDBManager | |
| from llm_client import GroqLLMClient, RAGPipeline | |
| from trace_evaluator import TRACEEvaluator | |
def main():
    """Example usage of RAG system."""
    # Resolve the Groq API key; a missing env var falls through to the
    # placeholder, which we treat as "not configured" and abort early.
    api_key = os.getenv("GROQ_API_KEY") or "your_api_key_here"
    if api_key == "your_api_key_here":
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    banner = "=" * 50
    print(banner)
    print("RAG System Example")
    print(banner)

    # 1. Pull a small training slice of the hotpotqa subset.
    print("\n1. Loading dataset...")
    bench_loader = RAGBenchLoader()
    samples = bench_loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(samples)} samples")

    # 2. Embed and index the samples into a ChromaDB collection.
    print("\n2. Creating vector store...")
    store = ChromaDBManager()
    strategy = "hybrid"
    print(f"Loading data into collection with {strategy} chunking...")
    store.load_dataset_into_collection(
        collection_name="example_collection",
        embedding_model_name="emilyalsentzer/Bio_ClinicalBERT",
        chunking_strategy=strategy,
        dataset_data=samples,
        chunk_size=512,
        overlap=50,
    )

    # 3. Rate-limited Groq chat client used for both generation and evaluation.
    print("\n3. Initializing LLM client...")
    llm_client = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,
        rate_limit_delay=2.0,
    )

    # 4. Wire retrieval and generation together.
    print("\n4. Creating RAG pipeline...")
    rag = RAGPipeline(llm_client, store)

    # 5. Run a few demo questions and show what was retrieved for each.
    print("\n5. Querying the system...")
    demo_questions = [
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?",
    ]
    for idx, query in enumerate(demo_questions, 1):
        print(f"\n--- Query {idx}: {query} ---")
        answer = rag.query(query, n_results=3)
        print(f"Response: {answer['response']}")
        print(f"\nRetrieved {len(answer['retrieved_documents'])} documents:")
        for rank, hit in enumerate(answer["retrieved_documents"], 1):
            print(f"\nDocument {rank} (Distance: {hit.get('distance', 'N/A')}):")
            # Only preview the first 200 characters of each retrieved chunk.
            print(f"{hit['document'][:200]}...")

    # 6. Score a handful of held-out samples with the TRACE metrics.
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(llm_client)
    eval_cases = []
    for sample in bench_loader.get_test_data("hotpotqa", num_samples=5):
        answer = rag.query(sample["question"], n_results=5)
        eval_cases.append({
            "query": sample["question"],
            "response": answer["response"],
            "retrieved_documents": [hit["document"] for hit in answer["retrieved_documents"]],
            "ground_truth": sample.get("answer", ""),
        })
    results = evaluator.evaluate_batch(eval_cases)
    print("\nTRACE Evaluation Results:")
    # str.capitalize() reproduces the original labels exactly
    # ("utilization" -> "Utilization", etc.).
    for metric in ("utilization", "relevance", "adherence", "completeness", "average"):
        print(f"{metric.capitalize()}: {results[metric]:.3f}")

    # 7. The pipeline keeps a transcript of every query/response pair.
    print("\n7. Chat History:")
    history = rag.get_chat_history()
    print(f"Total conversations: {len(history)}")

    print("\n" + banner)
    print("Example completed successfully!")
    print(banner)
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()