File size: 5,078 Bytes
6c5ce7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import argparse
from rag_pipelines.embeddings.dense import DenseEmbeddings
from rag_pipelines.embeddings.sparse import SparseEmbeddings
from rag_pipelines.llms.groq import ChatGroqGenerator
from rag_pipelines.pipelines.crag import CorrectiveRAGPipeline
from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the Corrective RAG pipeline.

    Returns:
        argparse.ArgumentParser: Parser covering embedding, Pinecone,
        retriever, LLM, grading, and query options.
    """
    parser = argparse.ArgumentParser(description="Run the Corrective RAG pipeline.")
    # Dense embeddings arguments
    parser.add_argument(
        "--dense_model_name",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Dense embedding model name.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the dense embedding model.",
    )
    # Sparse embeddings arguments
    parser.add_argument(
        "--sparse_max_seq_length",
        type=int,
        default=512,
        help="Maximum sequence length for sparse embeddings.",
    )
    # Pinecone arguments
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )
    # Retriever arguments
    parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")
    # LLM arguments (the same model/key is shared by the evaluator and the generator)
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0,
        help="Temperature for the language model.",
    )
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
    # Retrieval Evaluator and Document Grader arguments
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )
    # Query
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Corrective RAG pipeline.",
    )
    return parser


def main():
    """Assemble and run the Corrective RAG pipeline from CLI arguments.

    Parses command-line options, wires up the dense/sparse embeddings,
    Pinecone hybrid vector store and retriever, the LLM-backed retrieval
    evaluator and document grader, and the Groq generator, then runs the
    user's query through the pipeline and prints the result to stdout.
    """
    args = _build_parser().parse_args()

    # Initialize embeddings; normalized dense vectors suit similarity search.
    dense_embeddings = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})
    dense_embedding_model = dense_embeddings.embedding_model
    sparse_embedding_model = sparse_embeddings.sparse_embedding_model

    # Initialize Pinecone vector DB
    pinecone_vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
    )

    # Initialize Pinecone hybrid retriever; alpha balances dense vs. sparse scores.
    pinecone_retriever = PineconeHybridRetriever(
        index=pinecone_vector_db.index,
        dense_embedding_model=dense_embedding_model,
        sparse_embedding_model=sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )

    # Initialize RetrievalEvaluator and DocumentGrader; documents scoring below
    # the relevance threshold are filtered out before generation.
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Initialize the generator; prompt comes from the evaluator's template below.
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )
    llm = generator.llm

    # Initialize the Corrective RAG pipeline
    corrective_rag = CorrectiveRAGPipeline(
        retriever=pinecone_retriever.hybrid_retriever,
        prompt=retrieval_evaluator.prompt_template,
        llm=llm,
        document_grader=document_grader,
        tracing_project_name="sec_corrective_rag",
    )

    # Run the pipeline and print the answer.
    output = corrective_rag.run(args.query)
    print(output)
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|