# Initial commit for HuggingFace Spaces - RAG Capstone Project with Qdrant Cloud (1d10b0a)
"""
Example script demonstrating how to use the RAG system programmatically.
"""
import os
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator
def _load_dataset():
    """Load a small hotpotqa training sample.

    Returns:
        (loader, dataset): the RAGBenchLoader instance (reused later for
        test data) and the loaded sample list.
    """
    print("\n1. Loading dataset...")
    loader = RAGBenchLoader()
    # Keep the sample count small so the example runs quickly
    dataset = loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(dataset)} samples")
    return loader, dataset


def _build_vector_store(dataset):
    """Create the vector store and index *dataset* into an example collection.

    Returns the populated ChromaDBManager instance.
    """
    print("\n2. Creating vector store...")
    vector_store = ChromaDBManager()
    collection_name = "example_collection"
    # NOTE(review): a clinical-domain embedding model for hotpotqa looks
    # intentional in the original example — confirm this is the desired model.
    embedding_model = "emilyalsentzer/Bio_ClinicalBERT"
    chunking_strategy = "hybrid"
    print(f"Loading data into collection with {chunking_strategy} chunking...")
    vector_store.load_dataset_into_collection(
        collection_name=collection_name,
        embedding_model_name=embedding_model,
        chunking_strategy=chunking_strategy,
        dataset_data=dataset,
        chunk_size=512,
        overlap=50
    )
    return vector_store


def _build_pipeline(api_key, vector_store):
    """Initialize the Groq LLM client and wrap it in a RAG pipeline.

    Returns:
        (llm_client, rag): the client (reused by the evaluator) and the
        ready-to-query RAGPipeline.
    """
    print("\n3. Initializing LLM client...")
    llm_client = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,              # throttle to stay under API rate limits
        rate_limit_delay=2.0
    )
    print("\n4. Creating RAG pipeline...")
    rag = RAGPipeline(llm_client, vector_store)
    return llm_client, rag


def _run_example_queries(rag):
    """Issue a few demo queries and print each response with its sources."""
    print("\n5. Querying the system...")
    queries = [
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?"
    ]
    for i, query in enumerate(queries, 1):
        print(f"\n--- Query {i}: {query} ---")
        result = rag.query(query, n_results=3)
        print(f"Response: {result['response']}")
        print(f"\nRetrieved {len(result['retrieved_documents'])} documents:")
        for j, doc in enumerate(result['retrieved_documents'], 1):
            print(f"\nDocument {j} (Distance: {doc.get('distance', 'N/A')}):")
            # Truncate document text so the console output stays readable
            print(f"{doc['document'][:200]}...")


def _run_trace_evaluation(rag, llm_client, loader):
    """Build test cases from held-out samples and print TRACE metrics."""
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(llm_client)
    # Prepare test cases by running each held-out question through the pipeline
    test_cases = []
    test_samples = loader.get_test_data("hotpotqa", num_samples=5)
    for sample in test_samples:
        result = rag.query(sample["question"], n_results=5)
        test_cases.append({
            "query": sample["question"],
            "response": result["response"],
            "retrieved_documents": [doc["document"] for doc in result["retrieved_documents"]],
            "ground_truth": sample.get("answer", "")
        })
    results = evaluator.evaluate_batch(test_cases)
    print("\nTRACE Evaluation Results:")
    print(f"Utilization: {results['utilization']:.3f}")
    print(f"Relevance: {results['relevance']:.3f}")
    print(f"Adherence: {results['adherence']:.3f}")
    print(f"Completeness: {results['completeness']:.3f}")
    print(f"Average: {results['average']:.3f}")


def main():
    """Example usage of RAG system.

    Walks the full pipeline end-to-end: dataset loading, vector-store
    indexing, LLM client setup, demo queries, TRACE evaluation, and a
    final look at the accumulated chat history.
    """
    # Set your API key; the placeholder lets us detect a missing key below
    api_key = os.getenv("GROQ_API_KEY") or "your_api_key_here"
    if api_key == "your_api_key_here":
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    print("=" * 50)
    print("RAG System Example")
    print("=" * 50)

    loader, dataset = _load_dataset()
    vector_store = _build_vector_store(dataset)
    llm_client, rag = _build_pipeline(api_key, vector_store)
    _run_example_queries(rag)
    _run_trace_evaluation(rag, llm_client, loader)

    # 7. View chat history
    print("\n7. Chat History:")
    history = rag.get_chat_history()
    print(f"Total conversations: {len(history)}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
    print("=" * 50)
# Script entry point: run the example only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()