# File size: 5,078 Bytes
# 6c5ce7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import argparse
import os

from rag_pipelines.embeddings.dense import DenseEmbeddings
from rag_pipelines.embeddings.sparse import SparseEmbeddings
from rag_pipelines.llms.groq import ChatGroqGenerator
from rag_pipelines.pipelines.crag import CorrectiveRAGPipeline
from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever


def _add_api_key_argument(parser, flag, env_var, description):
    """Register a secret-valued option that falls back to the environment variable ``env_var``.

    The option is mandatory only when ``env_var`` is unset or empty, so existing
    invocations that pass the key on the command line keep working unchanged.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend.
        flag: Command-line flag, e.g. ``"--pinecone_api_key"``.
        env_var: Environment variable consulted when the flag is omitted.
        description: Leading sentence of the option's help text.
    """
    env_value = os.environ.get(env_var) or None  # an empty variable counts as unset
    parser.add_argument(
        flag,
        type=str,
        default=env_value,
        required=env_value is None,
        help=(
            f"{description} Defaults to ${env_var}. Prefer the environment variable: "
            "command-line values are visible in shell history and process listings."
        ),
    )


def _build_arg_parser():
    """Create the argument parser for the Corrective RAG command-line script."""
    parser = argparse.ArgumentParser(description="Run the Corrective RAG pipeline.")

    # Dense embeddings arguments
    parser.add_argument(
        "--dense_model_name",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Dense embedding model name.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the dense embedding model.",
    )

    # Sparse embeddings arguments
    parser.add_argument(
        "--sparse_max_seq_length",
        type=int,
        default=512,
        help="Maximum sequence length for sparse embeddings.",
    )

    # Pinecone arguments
    _add_api_key_argument(parser, "--pinecone_api_key", "PINECONE_API_KEY", "Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    # 384 is the output size of the default all-MiniLM-L6-v2 model; change the two together.
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    # Pinecone only stores sparse values on dotproduct indexes, so hybrid search needs the default.
    parser.add_argument(
        "--metric",
        type=str,
        default="dotproduct",
        choices=("cosine", "euclidean", "dotproduct"),
        help="Metric for similarity search.",
    )
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )

    # Retriever arguments
    parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")

    # LLM arguments
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature for the language model.",
    )
    # The same key is handed to the RetrievalEvaluator and to ChatGroqGenerator below.
    _add_api_key_argument(parser, "--llm_api_key", "GROQ_API_KEY", "API key for the language model.")

    # Retrieval Evaluator and Document Grader arguments
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )

    # Tracing
    parser.add_argument(
        "--tracing_project_name",
        type=str,
        default="sec_corrective_rag",
        help="Tracing project name passed to the Corrective RAG pipeline.",
    )

    # Query
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Corrective RAG pipeline.",
    )

    return parser


def _validate_args(parser, args):
    """Reject option values that cannot yield a working pipeline; exits through ``parser.error``."""
    # Hybrid search weights dense vs. sparse scores by alpha, which must lie in [0, 1].
    # The negated range check also rejects NaN.
    if not 0.0 <= args.alpha <= 1.0:
        parser.error(f"--alpha must be between 0 and 1, got {args.alpha}")
    if args.top_k < 1:
        parser.error(f"--top_k must be a positive integer, got {args.top_k}")
    if args.dimension < 1:
        parser.error(f"--dimension must be a positive integer, got {args.dimension}")
    if not args.query.strip():
        parser.error("--query must not be empty")


def main(argv=None):
    """Build the Corrective RAG pipeline from command-line options and print its output for ``--query``.

    Args:
        argv: Argument list parsed in place of ``sys.argv[1:]``. ``None`` (the
            default) parses the real command line, so ``main()`` behaves as before.

    Raises:
        SystemExit: Raised by argparse when options are missing or invalid.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    _validate_args(parser, args)

    # Initialize embeddings; encode_kwargs requests normalised dense vectors.
    dense_embeddings = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})

    dense_embedding_model = dense_embeddings.embedding_model
    sparse_embedding_model = sparse_embeddings.sparse_embedding_model

    # Initialize Pinecone vector DB
    pinecone_vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
    )

    # Initialize Pinecone retriever over that index, combining dense and sparse scores.
    pinecone_retriever = PineconeHybridRetriever(
        index=pinecone_vector_db.index,
        dense_embedding_model=dense_embedding_model,
        sparse_embedding_model=sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )

    # Initialize RetrievalEvaluator and DocumentGrader (same model/key/temperature as the generator).
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Initialize the generator
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )
    llm = generator.llm

    # Initialize the Corrective RAG pipeline.
    # NOTE(review): the prompt passed here is the retrieval evaluator's prompt_template;
    # confirm CorrectiveRAGPipeline expects that rather than a separate answer-generation prompt.
    corrective_rag = CorrectiveRAGPipeline(
        retriever=pinecone_retriever.hybrid_retriever,
        prompt=retrieval_evaluator.prompt_template,
        llm=llm,
        document_grader=document_grader,
        tracing_project_name=args.tracing_project_name,
    )

    # Run the pipeline
    output = corrective_rag.run(args.query)
    print(output)


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()