|
|
""" |
|
|
Example script demonstrating how to use the RAG system programmatically. |
|
|
""" |
|
|
import os |
|
|
from config import settings |
|
|
from dataset_loader import RAGBenchLoader |
|
|
from vector_store import ChromaDBManager |
|
|
from llm_client import GroqLLMClient, RAGPipeline |
|
|
from trace_evaluator import TRACEEvaluator |
|
|
|
|
|
|
|
|
def main():
    """Run an end-to-end example of the RAG system.

    Steps: load a dataset sample, index it into a ChromaDB collection,
    answer a few demo queries through the RAG pipeline, score held-out
    samples with the TRACE evaluator, and print a summary. Requires a
    valid GROQ_API_KEY in the environment (or .env); exits early with a
    message otherwise.
    """
    api_key = os.getenv("GROQ_API_KEY") or "your_api_key_here"
    if api_key == "your_api_key_here":
        # No real key available -- bail out instead of hitting the API.
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    print("=" * 50)
    print("RAG System Example")
    print("=" * 50)

    loader, vector_store = _build_index()
    rag, llm_client = _build_pipeline(api_key, vector_store)
    _run_example_queries(rag)
    _run_trace_evaluation(rag, llm_client, loader)

    print("\n7. Chat History:")
    history = rag.get_chat_history()
    print(f"Total conversations: {len(history)}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
    print("=" * 50)


def _build_index():
    """Steps 1-2: load the dataset and index it into ChromaDB.

    Returns:
        Tuple of (loader, vector_store); the loader is reused later to
        fetch test samples for evaluation.
    """
    print("\n1. Loading dataset...")
    loader = RAGBenchLoader()
    # Small sample keeps the example fast; raise max_samples for real runs.
    dataset = loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(dataset)} samples")

    print("\n2. Creating vector store...")
    vector_store = ChromaDBManager()

    collection_name = "example_collection"
    embedding_model = "emilyalsentzer/Bio_ClinicalBERT"
    chunking_strategy = "hybrid"

    print(f"Loading data into collection with {chunking_strategy} chunking...")
    vector_store.load_dataset_into_collection(
        collection_name=collection_name,
        embedding_model_name=embedding_model,
        chunking_strategy=chunking_strategy,
        dataset_data=dataset,
        chunk_size=512,
        overlap=50,
    )
    return loader, vector_store


def _build_pipeline(api_key, vector_store):
    """Steps 3-4: create the Groq LLM client and wire it into a pipeline.

    Args:
        api_key: Groq API key (already validated by the caller).
        vector_store: populated ChromaDBManager used for retrieval.

    Returns:
        Tuple of (rag_pipeline, llm_client); the client is also needed
        separately by the TRACE evaluator.
    """
    print("\n3. Initializing LLM client...")
    llm_client = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,  # presumably sized to the Groq free-tier rate limit -- confirm
        rate_limit_delay=2.0,
    )

    print("\n4. Creating RAG pipeline...")
    rag = RAGPipeline(llm_client, vector_store)
    return rag, llm_client


def _run_example_queries(rag):
    """Step 5: issue demo queries and print responses plus retrieved docs."""
    print("\n5. Querying the system...")
    queries = [
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?",
    ]

    for i, query in enumerate(queries, 1):
        print(f"\n--- Query {i}: {query} ---")
        result = rag.query(query, n_results=3)

        print(f"Response: {result['response']}")
        print(f"\nRetrieved {len(result['retrieved_documents'])} documents:")
        for j, doc in enumerate(result['retrieved_documents'], 1):
            print(f"\nDocument {j} (Distance: {doc.get('distance', 'N/A')}):")
            # Truncate long passages so the console output stays readable.
            print(f"{doc['document'][:200]}...")


def _run_trace_evaluation(rag, llm_client, loader):
    """Step 6: evaluate the pipeline on held-out samples, print TRACE metrics."""
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(llm_client)

    test_samples = loader.get_test_data("hotpotqa", num_samples=5)

    # Run each held-out question through the pipeline to collect the
    # (query, response, retrieved docs, ground truth) tuples TRACE scores.
    test_cases = []
    for sample in test_samples:
        result = rag.query(sample["question"], n_results=5)
        test_cases.append({
            "query": sample["question"],
            "response": result["response"],
            "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
            "ground_truth": sample.get("answer", ""),
        })

    results = evaluator.evaluate_batch(test_cases)

    print("\nTRACE Evaluation Results:")
    print(f"Utilization: {results['utilization']:.3f}")
    print(f"Relevance: {results['relevance']:.3f}")
    print(f"Adherence: {results['adherence']:.3f}")
    print(f"Completeness: {results['completeness']:.3f}")
    print(f"Average: {results['average']:.3f}")
|
|
|
|
|
|
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":


    main()
|
|
|