# File size: 4,576 Bytes | 6c5ce7a
import argparse
import os

from langchain_core.prompts import ChatPromptTemplate
from rag_pipelines.embeddings.dense import DenseEmbeddings
from rag_pipelines.embeddings.sparse import SparseEmbeddings
from rag_pipelines.llms.groq import ChatGroqGenerator
from rag_pipelines.pipelines.rag import RAGPipeline
from rag_pipelines.prompts.rag_prompt import RAG_PROMPT
from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Hybrid RAG pipeline.

    The two API-key options default to the ``PINECONE_API_KEY`` and
    ``GROQ_API_KEY`` environment variables, so secrets do not have to be typed
    on the command line (where they end up in shell history and process
    listings). Whether a key was actually supplied is checked by the caller.

    Returns:
        The configured parser; nothing has been parsed yet.
    """
    parser = argparse.ArgumentParser(description="Run the Hybrid RAG pipeline.")

    # Dense embeddings arguments
    parser.add_argument(
        "--dense_model_name",
        type=str,
        default="sentence-transformers/all-MiniLM-L6-v2",
        help="Dense embedding model name.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run the dense embedding model.",
    )

    # Sparse embeddings arguments
    parser.add_argument(
        "--sparse_max_seq_length",
        type=int,
        default=512,
        help="Maximum sequence length for sparse embeddings.",
    )

    # Pinecone arguments
    parser.add_argument(
        "--pinecone_api_key",
        type=str,
        default=os.environ.get("PINECONE_API_KEY"),
        help="Pinecone API key (defaults to the PINECONE_API_KEY environment variable).",
    )
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument("--cloud", type=str, default="aws", help="Pinecone cloud provider.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )

    # Retriever arguments
    parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")

    # LLM arguments
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        # argparse does not run ``type`` on non-string defaults, so the default
        # is spelled as a float to match what a user-supplied value would be.
        default=0.0,
        help="Temperature for the language model.",
    )
    parser.add_argument(
        "--llm_api_key",
        type=str,
        default=os.environ.get("GROQ_API_KEY"),
        help="API key for the language model (defaults to the GROQ_API_KEY environment variable).",
    )

    # Query
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Hybrid RAG pipeline.",
    )
    return parser


def main(argv=None) -> None:
    """Run one query through the Hybrid RAG pipeline and print the result.

    Parses the command line, builds the dense and sparse embedding models, the
    Pinecone index wrapper and hybrid retriever, the chat prompt and the Groq
    LLM, wires them into a ``RAGPipeline`` and prints the output of
    ``predict`` for the given query.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) makes argparse read ``sys.argv``,
            which is the behavior of the original zero-argument ``main()``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--query`` is
            missing or when an API key is supplied neither as an option nor
            through its environment variable.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Fail fast on missing credentials, before any model is loaded or any
    # remote service is contacted.
    if not args.pinecone_api_key:
        parser.error("a Pinecone API key is required: pass --pinecone_api_key or set PINECONE_API_KEY")
    if not args.llm_api_key:
        parser.error("an LLM API key is required: pass --llm_api_key or set GROQ_API_KEY")

    # Initialize embeddings
    dense_embeddings = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})

    dense_embedding_model = dense_embeddings.embedding_model
    sparse_embedding_model = sparse_embeddings.sparse_embedding_model

    # Initialize Pinecone vector DB
    pinecone_vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
        cloud=args.cloud,
    )

    # Initialize Pinecone retriever
    pinecone_retriever = PineconeHybridRetriever(
        index=pinecone_vector_db.index,
        dense_embedding_model=dense_embedding_model,
        sparse_embedding_model=sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )

    # Load the prompt: the RAG template is sent as a single human message.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("human", RAG_PROMPT),
        ]
    )

    # Initialize the LLM
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )
    llm = generator.llm

    # Initialize the Hybrid RAG pipeline
    hybrid_rag = RAGPipeline(
        retriever=pinecone_retriever.hybrid_retriever,
        prompt=prompt,
        llm=llm,
        tracing_project_name="sec_hybrid_rag",
    )

    # Run the pipeline
    output = hybrid_rag.predict(args.query)
    print(output)
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|