File size: 3,793 Bytes
1d10b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Example script demonstrating how to use the RAG system programmatically.
"""
import os
from config import settings
from dataset_loader import RAGBenchLoader
from vector_store import ChromaDBManager
from llm_client import GroqLLMClient, RAGPipeline
from trace_evaluator import TRACEEvaluator


def main():
    """Walk through the full RAG workflow: load data, index it, query, and evaluate.

    Steps performed, in order:
      1. Load a small slice of the hotpotqa RAGBench split.
      2. Index it into a ChromaDB collection with hybrid chunking.
      3. Build a rate-limited Groq LLM client.
      4. Wire client + store into a RAGPipeline.
      5. Run a few demo queries and show retrieved documents.
      6. Score the pipeline with the TRACE evaluator.
      7. Dump the accumulated chat history size.

    Exits early (with a hint) if GROQ_API_KEY is not configured.
    """
    banner = "=" * 50
    placeholder = "your_api_key_here"

    # Resolve the API key; fall back to the placeholder so the check below fires.
    api_key = os.getenv("GROQ_API_KEY") or placeholder
    if api_key == placeholder:
        print("Please set your GROQ_API_KEY in .env file or environment variable")
        return

    print(banner)
    print("RAG System Example")
    print(banner)

    # 1. Load a capped sample of the benchmark dataset.
    print("\n1. Loading dataset...")
    data_loader = RAGBenchLoader()
    samples = data_loader.load_dataset("hotpotqa", split="train", max_samples=20)
    print(f"Loaded {len(samples)} samples")

    # 2. Index the samples into a fresh vector-store collection.
    print("\n2. Creating vector store...")
    store = ChromaDBManager()

    strategy = "hybrid"
    print(f"Loading data into collection with {strategy} chunking...")
    store.load_dataset_into_collection(
        collection_name="example_collection",
        embedding_model_name="emilyalsentzer/Bio_ClinicalBERT",
        chunking_strategy=strategy,
        dataset_data=samples,
        chunk_size=512,
        overlap=50,
    )

    # 3. Stand up the LLM client with conservative rate limiting.
    print("\n3. Initializing LLM client...")
    client = GroqLLMClient(
        api_key=api_key,
        model_name="llama-3.1-8b-instant",
        max_rpm=30,
        rate_limit_delay=2.0,
    )

    # 4. Combine retrieval and generation into one pipeline.
    print("\n4. Creating RAG pipeline...")
    pipeline = RAGPipeline(client, store)

    # 5. Fire a few demo queries and print what came back.
    print("\n5. Querying the system...")
    sample_questions = [
        "What is machine learning?",
        "How does neural network work?",
        "What is deep learning?",
    ]
    for idx, question in enumerate(sample_questions, start=1):
        print(f"\n--- Query {idx}: {question} ---")
        answer = pipeline.query(question, n_results=3)

        print(f"Response: {answer['response']}")
        docs = answer["retrieved_documents"]
        print(f"\nRetrieved {len(docs)} documents:")
        for rank, doc in enumerate(docs, start=1):
            print(f"\nDocument {rank} (Distance: {doc.get('distance', 'N/A')}):")
            # Truncate long passages for readable console output.
            print(f"{doc['document'][:200]}...")

    # 6. Build evaluation cases from held-out samples and score them.
    print("\n6. Running TRACE evaluation...")
    evaluator = TRACEEvaluator(client)

    eval_cases = []
    for sample in data_loader.get_test_data("hotpotqa", num_samples=5):
        outcome = pipeline.query(sample["question"], n_results=5)
        eval_cases.append(
            {
                "query": sample["question"],
                "response": outcome["response"],
                "retrieved_documents": [d["document"] for d in outcome["retrieved_documents"]],
                "ground_truth": sample.get("answer", ""),
            }
        )

    scores = evaluator.evaluate_batch(eval_cases)

    print("\nTRACE Evaluation Results:")
    # Pad each label to 14 columns so the scores line up in one column.
    for label, key in [
        ("Utilization", "utilization"),
        ("Relevance", "relevance"),
        ("Adherence", "adherence"),
        ("Completeness", "completeness"),
        ("Average", "average"),
    ]:
        print(f"{label + ':':<14}{scores[key]:.3f}")

    # 7. Show how many turns the pipeline has recorded.
    print("\n7. Chat History:")
    print(f"Total conversations: {len(pipeline.get_chat_history())}")

    print("\n" + banner)
    print("Example completed successfully!")
    print(banner)


# Run the example only when executed as a script (not on import).
if __name__ == "__main__":
    main()