awinml commited on
Commit
6c5ce7a
·
verified ·
1 Parent(s): 8a1f4a8

Upload 11 files

Browse files
scripts/crag.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from rag_pipelines.embeddings.dense import DenseEmbeddings
4
+ from rag_pipelines.embeddings.sparse import SparseEmbeddings
5
+ from rag_pipelines.llms.groq import ChatGroqGenerator
6
+ from rag_pipelines.pipelines.crag import CorrectiveRAGPipeline
7
+ from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
8
+ from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
9
+ from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
10
+ from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
11
+
12
+
13
def main():
    """Build a Corrective RAG pipeline from CLI arguments and answer one query.

    Wires together: dense + sparse embeddings, a Pinecone hybrid index and
    retriever, an LLM-based retrieval evaluator / document grader, and a Groq
    chat generator; then runs the supplied --query and prints the result.
    """
    parser = argparse.ArgumentParser(description="Run the Corrective RAG pipeline.")

    # Dense embeddings
    parser.add_argument("--dense_model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Dense embedding model name.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to run the dense embedding model.")

    # Sparse embeddings
    parser.add_argument("--sparse_max_seq_length", type=int, default=512, help="Maximum sequence length for sparse embeddings.")

    # Pinecone
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument("--namespace", type=str, default="edgar-all", help="Namespace for Pinecone retriever.")

    # Retriever
    parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")

    # LLM
    parser.add_argument("--llm_model", type=str, default="llama-3.2-90b-vision-preview", help="Language model name.")
    parser.add_argument("--temperature", type=float, default=0, help="Temperature for the language model.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")

    # Retrieval evaluator / document grader
    parser.add_argument("--relevance_threshold", type=float, default=0.7, help="Relevance threshold for document grading.")

    # Query
    parser.add_argument("--query", type=str, required=True, help="Query to run through the Corrective RAG pipeline.")

    args = parser.parse_args()

    # Embedding models: normalized dense vectors plus a sparse model for
    # hybrid (dense + sparse) search.
    dense = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})

    # Pinecone index and the hybrid retriever over it.
    vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
    )
    retriever = PineconeHybridRetriever(
        index=vector_db.index,
        dense_embedding_model=dense.embedding_model,
        sparse_embedding_model=sparse.sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )

    # LLM-based relevance grading of retrieved documents.
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Generator LLM that produces the final answer.
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )

    pipeline = CorrectiveRAGPipeline(
        retriever=retriever.hybrid_retriever,
        prompt=retrieval_evaluator.prompt_template,
        llm=generator.llm,
        document_grader=grader,
        tracing_project_name="sec_corrective_rag",
    )

    print(pipeline.run(args.query))


if __name__ == "__main__":
    main()
scripts/crag_evaluation.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from dataloaders import (
4
+ ARCDataloader,
5
+ EdgarDataLoader,
6
+ FactScoreDataloader,
7
+ PopQADataloader,
8
+ TriviaQADataloader,
9
+ )
10
+
11
+ from rag_pipelines.embeddings.dense import DenseEmbeddings
12
+ from rag_pipelines.embeddings.sparse import SparseEmbeddings
13
+ from rag_pipelines.evaluation import (
14
+ AnswerRelevancyScorer,
15
+ ContextualPrecisionScorer,
16
+ ContextualRecallScorer,
17
+ ContextualRelevancyScorer,
18
+ Evaluator,
19
+ FaithfulnessScorer,
20
+ HallucinationScorer,
21
+ SummarizationScorer,
22
+ )
23
+ from rag_pipelines.evaluation.evaluator import Evaluator
24
+ from rag_pipelines.llms.groq import ChatGroqGenerator
25
+ from rag_pipelines.pipelines.crag import CorrectiveRAGPipeline
26
+ from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
27
+ from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
28
+ from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
29
+ from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
30
+
31
# Registry of evaluation dataset loaders, keyed by the --dataset_type CLI value.
SUPPORTED_DATASETS = {
    "arc": ARCDataloader,
    "edgar": EdgarDataLoader,
    "popqa": PopQADataloader,
    "factscore": FactScoreDataloader,
    "triviaqa": TriviaQADataloader,
}

# Registry of scorer classes, keyed by the --scorers CLI values.
SCORER_CLASSES = {
    "contextual_precision": ContextualPrecisionScorer,
    "contextual_recall": ContextualRecallScorer,
    "contextual_relevancy": ContextualRelevancyScorer,
    "answer_relevancy": AnswerRelevancyScorer,
    "faithfulness": FaithfulnessScorer,
    "summarization": SummarizationScorer,
    "hallucination": HallucinationScorer,
}


def main():
    """Run and evaluate the Corrective RAG pipeline on a benchmark dataset.

    Builds the hybrid retriever + Corrective RAG pipeline from CLI arguments,
    loads the requested evaluation dataset, instantiates the requested
    scorers, and prints the evaluation results.
    """
    # BUG FIX: the original created a *second* argparse.ArgumentParser halfway
    # through, discarding every argument registered before it
    # (--pinecone_api_key, --llm_api_key, --namespace, ...), which made the
    # script crash with AttributeError when those args were read later.
    # All arguments are now registered on a single parser.
    parser = argparse.ArgumentParser(description="Run the Corrective RAG pipeline.")

    # Dense embeddings
    parser.add_argument("--dense_model_name", type=str, default="sentence-transformers/all-MiniLM-L6-v2", help="Dense embedding model name.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to run the dense embedding model.")

    # Sparse embeddings
    parser.add_argument("--sparse_max_seq_length", type=int, default=512, help="Maximum sequence length for sparse embeddings.")

    # Pinecone
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument("--namespace", type=str, default="edgar-all", help="Namespace for Pinecone retriever.")

    # Retriever
    parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
    parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")

    # LLM
    parser.add_argument("--llm_model", type=str, default="llama-3.2-90b-vision-preview", help="Language model name.")
    parser.add_argument("--temperature", type=float, default=0, help="Temperature for the language model.")
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")

    # Retrieval evaluator / document grader
    parser.add_argument("--relevance_threshold", type=float, default=0.7, help="Relevance threshold for document grading.")

    # Evaluation dataset
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="edgar",
        choices=SUPPORTED_DATASETS.keys(),
        help="Dataset to load from. Options: arc, edgar, popqa, factscore, triviaqa.",
    )
    parser.add_argument("--hf_dataset_path", type=str, default="lamini/earnings-calls-qa", help="Path to the HuggingFace dataset.")
    parser.add_argument("--dataset_split", type=str, default="test", help="Split of the dataset to use (e.g., train, validation, test).")

    # NOTE(review): --scorer is superseded by --scorers below and is never
    # read; it is kept only so the existing CLI surface stays unchanged.
    parser.add_argument(
        "--scorer",
        type=str,
        default="contextual_precision",
        choices=[
            "contextual_precision",
            "contextual_recall",
            "contextual_relevancy",
            "answer_relevancy",
            "faithfulness",
            "summarization",
            "hallucination",
        ],
        help="Scorer to use.",
    )

    # Evaluation run
    parser.add_argument("--evaluation_name", type=str, default="hybrid_rag", help="Name of the evaluation.")
    parser.add_argument(
        "--scorers",
        type=str,
        nargs="+",
        choices=SCORER_CLASSES.keys(),
        required=True,
        help="List of scorers to use. Options: contextual_precision, contextual_recall, contextual_relevancy, "
        "answer_relevancy, faithfulness, summarization, hallucination.",
    )

    # Shared scorer parameters
    parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for evaluation.")
    parser.add_argument("--model", type=str, default="gpt-4", help="Model to use for scoring.")
    parser.add_argument("--include_reason", action="store_true", help="Include reasons in scoring.")
    parser.add_argument("--assessment_questions", type=str, nargs="*", help="List of assessment questions for scoring.")
    parser.add_argument("--strict_mode", action="store_true", help="Enable strict mode for scoring.")
    parser.add_argument("--async_mode", action="store_true", help="Enable asynchronous processing.")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output.")
    parser.add_argument("--truths_extraction_limit", type=int, default=None, help="Limit for truth extraction in scoring.")

    args = parser.parse_args()

    # Initialize the dataloader. The `choices` on --dataset_type already
    # restricts the value, so the KeyError branch is purely defensive; only
    # the registry lookup sits inside the try so an unrelated KeyError from
    # the loader constructor is not misreported.
    try:
        dataloader_cls = SUPPORTED_DATASETS[args.dataset_type]
    except KeyError:
        msg = (
            f"Dataset '{args.dataset_type}' is not supported. "
            f"Supported options are: {', '.join(SUPPORTED_DATASETS.keys())}."
        )
        raise ValueError(msg) from None
    dataloader = dataloader_cls(dataset_name=args.hf_dataset_path, split=args.dataset_split)

    eval_dataset = dataloader.get_eval_data()

    # Embedding models: normalized dense vectors plus a sparse model for
    # hybrid (dense + sparse) search.
    dense_embeddings = DenseEmbeddings(
        model_name=args.dense_model_name,
        model_kwargs={"device": args.device},
        encode_kwargs={"normalize_embeddings": True},
        show_progress=True,
    )
    sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})

    # Pinecone index and the hybrid retriever over it.
    pinecone_vector_db = PineconeHybridVectorDB(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
    )
    pinecone_retriever = PineconeHybridRetriever(
        index=pinecone_vector_db.index,
        dense_embedding_model=dense_embeddings.embedding_model,
        sparse_embedding_model=sparse_embeddings.sparse_embedding_model,
        alpha=args.alpha,
        top_k=args.top_k,
        namespace=args.namespace,
    )

    # LLM-based relevance grading of retrieved documents.
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Generator LLM that produces the final answer.
    generator = ChatGroqGenerator(
        model=args.llm_model,
        api_key=args.llm_api_key,
        llm_params={"temperature": args.temperature},
    )

    corrective_rag = CorrectiveRAGPipeline(
        retriever=pinecone_retriever.hybrid_retriever,
        prompt=retrieval_evaluator.prompt_template,
        llm=generator.llm,
        document_grader=document_grader,
        tracing_project_name="sec_corrective_rag",
    )

    # Instantiate the requested scorers; membership is guaranteed by the
    # `choices` on --scorers.
    scorers = []
    for scorer_name in args.scorers:
        scorer_cls = SCORER_CLASSES[scorer_name]
        scorers.append(
            scorer_cls(
                threshold=args.threshold,
                model=args.model,
                include_reason=args.include_reason,
                assessment_questions=args.assessment_questions,
                strict_mode=args.strict_mode,
                async_mode=args.async_mode,
                verbose=args.verbose,
                truths_extraction_limit=args.truths_extraction_limit,
            )
        )

    # BUG FIX: `scorers` was previously wrapped in another list
    # (scorers=[scorers]), handing the evaluator a nested list instead of the
    # scorer instances themselves.
    evaluator = Evaluator(
        evaluation_name=args.evaluation_name,
        pipeline=corrective_rag,
        dataset=eval_dataset,
        scorers=scorers,
    )

    print(evaluator.evaluate())


if __name__ == "__main__":
    main()
scripts/dspy_rag.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import weave
4
+ from dataloaders.langchain import FinanceBenchDataloader
5
+ from dspy import LM
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+
8
+ from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
9
+ from rag_pipelines.evaluation import AnswerRelevancyScorer, ContextualPrecisionScorer, Evaluator
10
+ from rag_pipelines.pipelines import RAG, DSPyRAGPipeline
11
+ from rag_pipelines.vectordb import DSPyMilvusRetriever as MilvusRetriever
12
+ from rag_pipelines.vectordb import MilvusVectorDB, milvus_retriever
13
+
14
+ os.environ["WEAVE_PARALLELISM"] = "1"
15
+ os.environ["WEAVE_TRACE_LANGCHAIN"] = "false"
16
+
17
+ dense_model = "intfloat/multilingual-e5-large"
18
+ encode_kwargs = {"prompt": "query: "}
19
+ model_kwargs = {
20
+ "device": "cpu",
21
+ "trust_remote_code": True,
22
+ "backend": "onnx",
23
+ "model_kwargs": {"file_name": "onnx/model.onnx"},
24
+ }
25
+
26
+ ##### Use the e5-large-instruct model for everything now
27
+
28
+ dense_field = "dense_vector"
29
+ sparse_field = "sparse_vector"
30
+ text_field = "text"
31
+ metadata_field = "metadata"
32
+
33
+ dense_search_params = {
34
+ "metric_type": "COSINE",
35
+ }
36
+
37
+ sparse_search_params = {
38
+ "metric_type": "IP",
39
+ }
40
+
41
+ milvus_uri = "https://in03-8aaa331b36bf39c.serverless.gcp-us-west1.cloud.zilliz.com"
42
+ milvus_token = (
43
+ "cd567c8418a6b8fe4b438300cfc56212f22ef1347bc12031b0114bd72ba0aec3978ce8c107c11a4ae01239b010c15765358cdf37"
44
+ )
45
+ milvus_collection_name = "financebenchsub"
46
+
47
+
48
+ tracing_project_name = "dspy_rag"
49
+ weave_params = {}
50
+
51
+ client = weave.init(tracing_project_name, **weave_params)
52
+
53
+ dense_embeddings = HuggingFaceEmbeddings(model_name=dense_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
54
+ sparse_embeddings = SparseEmbeddings(model_name="Splade_PP_en_v1")
55
+
56
+ milvus_vector_db = MilvusVectorDB(
57
+ uri=milvus_uri,
58
+ token=milvus_token,
59
+ collection_name=milvus_collection_name,
60
+ )
61
+
62
+ milvus_retriever = MilvusRetriever(
63
+ collection=milvus_vector_db.collection,
64
+ dense_embedding_model=dense_embeddings,
65
+ sparse_embedding_model=sparse_embeddings,
66
+ anns_fields=[dense_field, sparse_field],
67
+ field_search_params=[dense_search_params, sparse_search_params],
68
+ text_field=text_field,
69
+ top_k=3,
70
+ )
71
+
72
+ llm = LM(
73
+ "groq/llama-3.3-70b-versatile",
74
+ api_key="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",
75
+ num_retries=120,
76
+ )
77
+
78
+ dspy_rag_module = RAG(milvus_retriever)
79
+
80
+ pipeline = DSPyRAGPipeline(rag_module=dspy_rag_module, llm=llm)
81
+
82
+ dataloader = FinanceBenchDataloader(
83
+ dataset_name="PatronusAI/financebench",
84
+ split="train[:2]",
85
+ )
86
+
87
+ data = dataloader.load_data()
88
+ eval_data = dataloader.get_evaluation_data()
89
+
90
+ evaluation_dataset = weave.Dataset(name="financebench_test_evaluation_dataset", rows=eval_data)
91
+
92
+ questions = dataloader.get_questions()
93
+
94
+ # dataloader.publish_to_weave(
95
+ # weave_project_name="financebench_test",
96
+ # dataset_name="financebench_test_dataset",
97
+ # evaluation_dataset_name="financebench_test_evaluation_dataset",
98
+ # )
99
+
100
+ answer_relevancy_scorer = AnswerRelevancyScorer(
101
+ threshold=0.5,
102
+ model="gpt-4",
103
+ include_reason=True,
104
+ verbose=True,
105
+ )
106
+ contextual_precision_scorer = ContextualPrecisionScorer(
107
+ threshold=0.5,
108
+ model="gpt-4",
109
+ include_reason=True,
110
+ verbose=True,
111
+ )
112
+
113
+ evaluator = Evaluator(
114
+ evaluation_name="financebench_test_subset_2",
115
+ evaluation_dataset=evaluation_dataset,
116
+ evaluation_scorers=[answer_relevancy_scorer, contextual_precision_scorer],
117
+ pipeline=pipeline,
118
+ )
119
+
120
+ evaluation_results = evaluator.evaluate()
121
+ print(evaluation_results)
scripts/hybrid_rag.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from langchain_core.prompts import ChatPromptTemplate
4
+
5
+ from rag_pipelines.embeddings.dense import DenseEmbeddings
6
+ from rag_pipelines.embeddings.sparse import SparseEmbeddings
7
+ from rag_pipelines.llms.groq import ChatGroqGenerator
8
+ from rag_pipelines.pipelines.rag import RAGPipeline
9
+ from rag_pipelines.prompts.rag_prompt import RAG_PROMPT
10
+ from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
11
+ from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
12
+
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser(description="Run the Hybrid RAG pipeline.")
16
+
17
+ # Dense embeddings arguments
18
+ parser.add_argument(
19
+ "--dense_model_name",
20
+ type=str,
21
+ default="sentence-transformers/all-MiniLM-L6-v2",
22
+ help="Dense embedding model name.",
23
+ )
24
+ parser.add_argument(
25
+ "--device",
26
+ type=str,
27
+ default="cpu",
28
+ help="Device to run the dense embedding model.",
29
+ )
30
+
31
+ # Sparse embeddings arguments
32
+ parser.add_argument(
33
+ "--sparse_max_seq_length",
34
+ type=int,
35
+ default=512,
36
+ help="Maximum sequence length for sparse embeddings.",
37
+ )
38
+
39
+ # Pinecone arguments
40
+ parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
41
+ parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
42
+ parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
43
+ parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
44
+ parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
45
+ parser.add_argument("--cloud", type=str, default="aws", help="Pinecone cloud provider.")
46
+ parser.add_argument(
47
+ "--namespace",
48
+ type=str,
49
+ default="edgar-all",
50
+ help="Namespace for Pinecone retriever.",
51
+ )
52
+
53
+ # Retriever arguments
54
+ parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
55
+ parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")
56
+
57
+ # LLM arguments
58
+ parser.add_argument(
59
+ "--llm_model",
60
+ type=str,
61
+ default="llama-3.2-90b-vision-preview",
62
+ help="Language model name.",
63
+ )
64
+ parser.add_argument(
65
+ "--temperature",
66
+ type=float,
67
+ default=0,
68
+ help="Temperature for the language model.",
69
+ )
70
+ parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
71
+
72
+ # Query
73
+ parser.add_argument(
74
+ "--query",
75
+ type=str,
76
+ required=True,
77
+ help="Query to run through the Hybrid RAG pipeline.",
78
+ )
79
+
80
+ args = parser.parse_args()
81
+
82
+ # Initialize embeddings
83
+ dense_embeddings = DenseEmbeddings(
84
+ model_name=args.dense_model_name,
85
+ model_kwargs={"device": args.device},
86
+ encode_kwargs={"normalize_embeddings": True},
87
+ show_progress=True,
88
+ )
89
+ sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})
90
+
91
+ dense_embedding_model = dense_embeddings.embedding_model
92
+ sparse_embedding_model = sparse_embeddings.sparse_embedding_model
93
+
94
+ # Initialize Pinecone vector DB
95
+ pinecone_vector_db = PineconeHybridVectorDB(
96
+ api_key=args.pinecone_api_key,
97
+ index_name=args.index_name,
98
+ dimension=args.dimension,
99
+ metric=args.metric,
100
+ region=args.region,
101
+ cloud=args.cloud,
102
+ )
103
+
104
+ # Initialize Pinecone retriever
105
+ pinecone_retriever = PineconeHybridRetriever(
106
+ index=pinecone_vector_db.index,
107
+ dense_embedding_model=dense_embedding_model,
108
+ sparse_embedding_model=sparse_embedding_model,
109
+ alpha=args.alpha,
110
+ top_k=args.top_k,
111
+ namespace=args.namespace,
112
+ )
113
+
114
+ # Load the prompt
115
+
116
+ prompt = ChatPromptTemplate.from_messages(
117
+ [
118
+ ("human", RAG_PROMPT),
119
+ ]
120
+ )
121
+
122
+ # Initialize the LLM
123
+ generator = ChatGroqGenerator(
124
+ model=args.llm_model,
125
+ api_key=args.llm_api_key,
126
+ llm_params={"temperature": args.temperature},
127
+ )
128
+ llm = generator.llm
129
+
130
+ # Initialize the Hybrid RAG pipeline
131
+ hybrid_rag = RAGPipeline(
132
+ retriever=pinecone_retriever.hybrid_retriever,
133
+ prompt=prompt,
134
+ llm=llm,
135
+ tracing_project_name="sec_hybrid_rag",
136
+ )
137
+
138
+ # Run the pipeline
139
+ output = hybrid_rag.predict(args.query)
140
+ print(output)
141
+
142
+
143
+ if __name__ == "__main__":
144
+ main()
scripts/indexing_financebench_milvus.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from dataloaders.langchain import FinanceBenchDataloader
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from pymilvus import CollectionSchema, DataType, FieldSchema
6
+
7
+ from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
8
+ from rag_pipelines.unstructured import UnstructuredChunker, UnstructuredDocumentLoader
9
+ from rag_pipelines.utils import dict_type
10
+ from rag_pipelines.vectordb import MilvusVectorDB
11
+
12
+
13
def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments.

    Returns:
        argparse.Namespace: Parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents."
    )

    # FinanceBench dataset parameters
    parser.add_argument("--dataset_name", type=str, default="PatronusAI/financebench", help="HuggingFace dataset name.")
    parser.add_argument("--split", type=str, default="train", help="Dataset split to use (e.g., 'train').")

    # PDF directory for the unstructured document loader
    parser.add_argument("--pdf_dir", type=str, default="pdfs/", help="Directory path containing PDF files.")

    # UnstructuredDocumentLoader parameters
    parser.add_argument("--strategy", type=str, default="fast", help="Processing strategy for the unstructured document loader.")
    parser.add_argument("--mode", type=str, default="elements", help="Extraction mode for the unstructured document loader.")

    # Milvus connection parameters
    parser.add_argument("--milvus_uri", type=str, help="URI for the Milvus server.")
    parser.add_argument("--milvus_token", type=str, help="Authentication token for Milvus.")
    parser.add_argument("--collection_name", type=str, default="financebench", help="Name of the Milvus collection to create/use.")

    # Dense embedding model parameters
    parser.add_argument("--dense_embedding_model", type=str, default="sentence-transformers/all-mpnet-base-v2", help="Model name for dense embeddings.")
    parser.add_argument("--dense_model_kwargs", type=dict_type, default='{"device": "cpu", "trust_remote_code": true}', help="Keyword arguments for dense embeddings model initialization.")
    parser.add_argument("--dense_encode_kwargs", type=dict_type, default='{"normalize_embeddings": true}', help="Keyword arguments for dense embeddings encoding.")

    # Sparse embedding model parameters
    parser.add_argument("--sparse_embedding_model", type=str, default="Splade_PP_en_v1", help="Model name for sparse embeddings.")

    # Schema configuration: field names
    parser.add_argument("--pk_field", type=str, default="doc_id", help="Name of the primary key field.")
    parser.add_argument("--dense_field", type=str, default="dense_vector", help="Name of the dense vector field.")
    parser.add_argument("--sparse_field", type=str, default="sparse_vector", help="Name of the sparse vector field.")
    parser.add_argument("--text_field", type=str, default="text", help="Name of the text field.")
    parser.add_argument("--metadata_field", type=str, default="metadata", help="Name of the metadata field.")

    # Schema configuration: sizes
    parser.add_argument("--dense_dim", type=int, default=768, help="Dimension of dense embeddings.")
    parser.add_argument("--pk_max_length", type=int, default=100, help="Max length for the primary key field.")
    parser.add_argument("--text_max_length", type=int, default=65535, help="Max length for the text field.")

    # Index parameters
    parser.add_argument("--dense_index_params", type=dict_type, default='{"index_type": "FLAT", "metric_type": "IP"}', help="JSON string specifying dense index parameters.")
    parser.add_argument("--sparse_index_params", type=dict_type, default='{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}', help="JSON string specifying sparse index parameters.")

    # Collection creation flag
    parser.add_argument("--create_new_collection", action="store_true", help="Create a new collection or use existing. Defaults to False.")

    return parser.parse_args()
180
+
181
+
182
def main() -> None:
    """Run the FinanceBench document processing pipeline.

    This function performs the following steps:
    1. Loads the FinanceBench dataset.
    2. Retrieves PDF documents from the specified directory.
    3. Processes PDFs using the UnstructuredDocumentLoader.
    4. Chunks documents using the UnstructuredChunker.
    5. Generates dense and sparse embeddings with specified parameters.
    6. Sets up a Milvus vector database and indexes the documents.
    """
    args = parse_arguments()

    # FinanceBench dataloader (used later to fetch the corpus PDFs).
    finbench_loader = FinanceBenchDataloader(
        dataset_name=args.dataset_name,
        split=args.split,
    )

    # PDF loader and chunker for the unstructured-document pipeline.
    pdf_loader = UnstructuredDocumentLoader(strategy=args.strategy, mode=args.mode)
    doc_chunker = UnstructuredChunker()

    # Dense (HuggingFace) and sparse embedding models.
    dense_model = HuggingFaceEmbeddings(
        model_name=args.dense_embedding_model,
        model_kwargs=args.dense_model_kwargs,
        encode_kwargs=args.dense_encode_kwargs,
    )
    sparse_model = SparseEmbeddings(model_name=args.sparse_embedding_model)

    # Milvus collection schema: auto-id primary key, dense + sparse vectors,
    # raw text, and a JSON metadata field. Dynamic fields are disabled so the
    # schema is the single source of truth.
    schema = CollectionSchema(
        fields=[
            FieldSchema(
                name=args.pk_field,
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=True,
                max_length=args.pk_max_length,
            ),
            FieldSchema(name=args.dense_field, dtype=DataType.FLOAT_VECTOR, dim=args.dense_dim),
            FieldSchema(name=args.sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
            FieldSchema(name=args.text_field, dtype=DataType.VARCHAR, max_length=args.text_max_length),
            FieldSchema(name=args.metadata_field, dtype=DataType.JSON),
        ],
        enable_dynamic_field=False,
    )

    # Milvus vector database client over that schema.
    vector_db = MilvusVectorDB(
        uri=args.milvus_uri,
        token=args.milvus_token,
        collection_name=args.collection_name,
        collection_schema=schema,
        dense_field=args.dense_field,
        sparse_field=args.sparse_field,
        text_field=args.text_field,
        metadata_field=args.metadata_field,
        dense_index_params=args.dense_index_params,
        sparse_index_params=args.sparse_index_params,
        create_new_collection=args.create_new_collection,
    )

    # Fetch the corpus PDFs, load and chunk them, then index into Milvus.
    finbench_loader.get_corpus_pdfs()
    raw_documents = pdf_loader.transform_documents(args.pdf_dir)
    chunks = doc_chunker.transform_documents(raw_documents)
    vector_db.add_documents(
        documents=chunks,
        dense_embedding_model=dense_model,
        sparse_embedding_model=sparse_model,
    )


if __name__ == "__main__":
    main()
scripts/indexing_financebench_milvus_voyage.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from dataloaders.langchain import FinanceBenchDataloader
4
+ from langchain_huggingface import HuggingFaceEmbeddings
5
+ from pymilvus import CollectionSchema, DataType, FieldSchema
6
+
7
+ from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
8
+ from rag_pipelines.unstructured import UnstructuredChunker, UnstructuredDocumentLoader
9
+ from rag_pipelines.utils import dict_type
10
+ from rag_pipelines.vectordb import MilvusVectorDB
11
+
12
+
13
+ def parse_arguments() -> argparse.Namespace:
14
+ """Parse command-line arguments.
15
+
16
+ Returns:
17
+ argparse.Namespace: Parsed command-line arguments.
18
+ """
19
+ parser = argparse.ArgumentParser(
20
+ description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents."
21
+ )
22
+
23
+ # FinanceBench dataset parameters
24
+ parser.add_argument(
25
+ "--dataset_name",
26
+ type=str,
27
+ default="PatronusAI/financebench",
28
+ help="HuggingFace dataset name.",
29
+ )
30
+ parser.add_argument(
31
+ "--split",
32
+ type=str,
33
+ default="train",
34
+ help="Dataset split to use (e.g., 'train').",
35
+ )
36
+
37
+ # PDF directory for unstructured document loader
38
+ parser.add_argument(
39
+ "--pdf_dir",
40
+ type=str,
41
+ default="pdfs/",
42
+ help="Directory path containing PDF files.",
43
+ )
44
+
45
+ # UnstructuredDocumentLoader parameters
46
+ parser.add_argument(
47
+ "--strategy",
48
+ type=str,
49
+ default="fast",
50
+ help="Processing strategy for the unstructured document loader.",
51
+ )
52
+ parser.add_argument(
53
+ "--mode",
54
+ type=str,
55
+ default="elements",
56
+ help="Extraction mode for the unstructured document loader.",
57
+ )
58
+
59
+ # Milvus connection parameters
60
+ parser.add_argument(
61
+ "--milvus_uri",
62
+ type=str,
63
+ help="URI for the Milvus server.",
64
+ )
65
+ parser.add_argument(
66
+ "--milvus_token",
67
+ type=str,
68
+ help="Authentication token for Milvus.",
69
+ )
70
+ parser.add_argument(
71
+ "--collection_name",
72
+ type=str,
73
+ default="financebench",
74
+ help="Name of the Milvus collection to create/use.",
75
+ )
76
+
77
+ # Dense embedding model parameters
78
+ parser.add_argument(
79
+ "--dense_embedding_model",
80
+ type=str,
81
+ default="sentence-transformers/all-mpnet-base-v2",
82
+ help="Model name for dense embeddings.",
83
+ )
84
+ parser.add_argument(
85
+ "--dense_model_kwargs",
86
+ type=dict_type,
87
+ default='{"device": "cpu", "trust_remote_code": true}',
88
+ help="Keyword arguments for dense embeddings model initialization.",
89
+ )
90
+ parser.add_argument(
91
+ "--dense_encode_kwargs",
92
+ type=dict_type,
93
+ default='{"normalize_embeddings": true}',
94
+ help="Keyword arguments for dense embeddings encoding.",
95
+ )
96
+
97
+ # Sparse embedding model parameters
98
+ parser.add_argument(
99
+ "--sparse_embedding_model",
100
+ type=str,
101
+ default="Splade_PP_en_v1",
102
+ help="Model name for sparse embeddings.",
103
+ )
104
+
105
+ # Schema configuration parameters
106
+
107
+ # Field names
108
+ parser.add_argument(
109
+ "--pk_field",
110
+ type=str,
111
+ default="doc_id",
112
+ help="Name of the primary key field.",
113
+ )
114
+ parser.add_argument(
115
+ "--dense_field",
116
+ type=str,
117
+ default="dense_vector",
118
+ help="Name of the dense vector field.",
119
+ )
120
+ parser.add_argument(
121
+ "--sparse_field",
122
+ type=str,
123
+ default="sparse_vector",
124
+ help="Name of the sparse vector field.",
125
+ )
126
+ parser.add_argument(
127
+ "--text_field",
128
+ type=str,
129
+ default="text",
130
+ help="Name of the text field.",
131
+ )
132
+ parser.add_argument(
133
+ "--metadata_field",
134
+ type=str,
135
+ default="metadata",
136
+ help="Name of the metadata field.",
137
+ )
138
+
139
+ parser.add_argument(
140
+ "--dense_dim",
141
+ type=int,
142
+ default=768,
143
+ help="Dimension of dense embeddings.",
144
+ )
145
+ parser.add_argument(
146
+ "--pk_max_length",
147
+ type=int,
148
+ default=100,
149
+ help="Max length for the primary key field.",
150
+ )
151
+ parser.add_argument(
152
+ "--text_max_length",
153
+ type=int,
154
+ default=65535,
155
+ help="Max length for the text field.",
156
+ )
157
+
158
+ # Index parameters
159
+ parser.add_argument(
160
+ "--dense_index_params",
161
+ type=dict_type,
162
+ default='{"index_type": "FLAT", "metric_type": "IP"}',
163
+ help="JSON string specifying dense index parameters.",
164
+ )
165
+ parser.add_argument(
166
+ "--sparse_index_params",
167
+ type=dict_type,
168
+ default='{"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP"}',
169
+ help="JSON string specifying sparse index parameters.",
170
+ )
171
+
172
+ # Collection creation flag
173
+ parser.add_argument(
174
+ "--create_new_collection",
175
+ action="store_true",
176
+ help="Create a new collection or use existing. Defaults to False.",
177
+ )
178
+
179
+ return parser.parse_args()
180
+
181
+
182
+ def main() -> None:
183
+ """Run the FinanceBench document processing pipeline.
184
+
185
+ This function performs the following steps:
186
+ 1. Loads the FinanceBench dataset.
187
+ 2. Retrieves PDF documents from the specified directory.
188
+ 3. Processes PDFs using the UnstructuredDocumentLoader.
189
+ 4. Chunks documents using the UnstructuredChunker.
190
+ 5. Generates dense and sparse embeddings with specified parameters.
191
+ 6. Sets up a Milvus vector database and indexes the documents.
192
+ """
193
+ args = parse_arguments()
194
+
195
+ # Initialize FinanceBench dataloader and load the corpus PDFs
196
+ dataloader = FinanceBenchDataloader(
197
+ dataset_name=args.dataset_name,
198
+ split=args.split,
199
+ )
200
+
201
+ # Load and transform PDF documents from the provided directory
202
+ unstructured_document_loader = UnstructuredDocumentLoader(
203
+ strategy=args.strategy,
204
+ mode=args.mode,
205
+ )
206
+
207
+ # Chunk the documents using the UnstructuredChunker
208
+ chunker = UnstructuredChunker()
209
+
210
+ # Initialize dense and sparse embedding models with additional parameters
211
+ dense_embeddings = HuggingFaceEmbeddings(
212
+ model_name=args.dense_embedding_model,
213
+ model_kwargs=args.dense_model_kwargs,
214
+ encode_kwargs=args.dense_encode_kwargs,
215
+ )
216
+ sparse_embeddings = SparseEmbeddings(
217
+ model_name=args.sparse_embedding_model,
218
+ )
219
+
220
+ # Define Milvus collection fields and schema
221
+ pk_field = args.pk_field
222
+ dense_field = args.dense_field
223
+ sparse_field = args.sparse_field
224
+ text_field = args.text_field
225
+ metadata_field = args.metadata_field
226
+
227
+ fields = [
228
+ FieldSchema(
229
+ name=pk_field,
230
+ dtype=DataType.VARCHAR,
231
+ is_primary=True,
232
+ auto_id=True,
233
+ max_length=args.pk_max_length,
234
+ ),
235
+ FieldSchema(name=dense_field, dtype=DataType.FLOAT_VECTOR, dim=args.dense_dim),
236
+ FieldSchema(name=sparse_field, dtype=DataType.SPARSE_FLOAT_VECTOR),
237
+ FieldSchema(name=text_field, dtype=DataType.VARCHAR, max_length=args.text_max_length),
238
+ FieldSchema(name=metadata_field, dtype=DataType.JSON),
239
+ ]
240
+ schema = CollectionSchema(fields=fields, enable_dynamic_field=False)
241
+
242
+ # Initialize the Milvus vector database client
243
+ milvus_vector_db = MilvusVectorDB(
244
+ uri=args.milvus_uri,
245
+ token=args.milvus_token,
246
+ collection_name=args.collection_name,
247
+ collection_schema=schema,
248
+ dense_field=dense_field,
249
+ sparse_field=sparse_field,
250
+ text_field=text_field,
251
+ metadata_field=metadata_field,
252
+ dense_index_params=args.dense_index_params,
253
+ sparse_index_params=args.sparse_index_params,
254
+ create_new_collection=args.create_new_collection,
255
+ )
256
+
257
+ # Add documents to the Milvus vector database
258
+ dataloader.get_corpus_pdfs()
259
+ documents = unstructured_document_loader.transform_documents(args.pdf_dir)
260
+ chunked_documents = chunker.transform_documents(documents)
261
+ milvus_vector_db.add_documents(
262
+ documents=chunked_documents,
263
+ dense_embedding_model=dense_embeddings,
264
+ sparse_embedding_model=sparse_embeddings,
265
+ )
266
+
267
+
268
+ if __name__ == "__main__":
269
+ main()
scripts/indexing_pinecone.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import weave
4
+ from dataloaders.langchain import FinanceBenchDataloader
5
+
6
+ from rag_pipelines.embeddings.dense import DenseEmbeddings
7
+ from rag_pipelines.embeddings.sparse_pinecone_text import SparseEmbeddings
8
+ from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
9
+
10
+
11
+ def parse_arguments() -> argparse.Namespace:
12
+ """Parse command-line arguments for the FinanceBench pipeline.
13
+
14
+ Returns:
15
+ argparse.Namespace: Parsed command-line arguments.
16
+ """
17
+ parser = argparse.ArgumentParser(
18
+ description="Process FinanceBench data, generate embeddings, and add processed documents to a Pinecone hybrid index."
19
+ )
20
+
21
+ # Weave tracing project name
22
+ parser.add_argument(
23
+ "--project_name",
24
+ required=True,
25
+ help="Weave project name to initialize tracing.",
26
+ )
27
+
28
+ # FinanceBench dataloader arguments
29
+ parser.add_argument(
30
+ "--dataset_name",
31
+ type=str,
32
+ required=True,
33
+ help="Name of the FinanceBench dataset (e.g., 'PatronusAI/financebench').",
34
+ )
35
+ parser.add_argument(
36
+ "--split",
37
+ type=str,
38
+ default="train[:1]",
39
+ help="Dataset split to use (e.g., 'train[:1]').",
40
+ )
41
+
42
+ # Dense Embeddings arguments
43
+ parser.add_argument(
44
+ "--dense_model_name",
45
+ type=str,
46
+ required=True,
47
+ help="Dense embedding model name (e.g., 'sentence-transformers/all-MiniLM-L6-v2').",
48
+ )
49
+ parser.add_argument(
50
+ "--dense_device",
51
+ type=str,
52
+ default="cpu",
53
+ help="Device to run the dense embedding model (e.g., 'cpu' or 'cuda').",
54
+ )
55
+ parser.add_argument(
56
+ "--normalize_embeddings",
57
+ action="store_true",
58
+ help="Flag to normalize embeddings during encoding.",
59
+ )
60
+ parser.add_argument(
61
+ "--show_progress",
62
+ action="store_true",
63
+ help="Flag to show progress during embedding generation.",
64
+ )
65
+
66
+ # Sparse Embeddings arguments
67
+ parser.add_argument(
68
+ "--sparse_max_seq_length",
69
+ type=int,
70
+ required=True,
71
+ help="Maximum sequence length for sparse embeddings.",
72
+ )
73
+
74
+ # Semantic Chunking arguments (if applicable in your downstream pipeline)
75
+ parser.add_argument(
76
+ "--chunking_threshold_type",
77
+ type=str,
78
+ default="percentile",
79
+ help="Threshold type for semantic chunking (e.g., 'percentile' or 'absolute').",
80
+ )
81
+
82
+ # Pinecone configuration arguments
83
+ parser.add_argument(
84
+ "--pinecone_api_key",
85
+ type=str,
86
+ required=True,
87
+ help="API key for the Pinecone vector database.",
88
+ )
89
+ parser.add_argument(
90
+ "--pinecone_index_name",
91
+ type=str,
92
+ required=True,
93
+ help="Name of the Pinecone index.",
94
+ )
95
+ parser.add_argument(
96
+ "--pinecone_dimension",
97
+ type=int,
98
+ required=True,
99
+ help="Vector dimension in the Pinecone index.",
100
+ )
101
+ parser.add_argument(
102
+ "--pinecone_metric",
103
+ type=str,
104
+ required=True,
105
+ help="Similarity metric for the Pinecone index (e.g., 'dotproduct' or 'cosine').",
106
+ )
107
+ parser.add_argument(
108
+ "--pinecone_region",
109
+ type=str,
110
+ required=True,
111
+ help="Pinecone region (e.g., 'us-east-1').",
112
+ )
113
+ parser.add_argument(
114
+ "--pinecone_cloud",
115
+ type=str,
116
+ required=True,
117
+ help="Pinecone cloud provider (e.g., 'aws').",
118
+ )
119
+ parser.add_argument(
120
+ "--namespace",
121
+ type=str,
122
+ required=True,
123
+ help="Namespace for document storage in Pinecone.",
124
+ )
125
+
126
+ return parser.parse_args()
127
+
128
+
129
+ def main() -> None:
130
+ """Load FinanceBench data, generate dense and sparse embeddings, add processed documents to a Pinecone index.
131
+
132
+ The pipeline performs the following steps:
133
+ 1. Initialize Weave tracing.
134
+ 2. Load FinanceBench documents.
135
+ 3. Generate dense and sparse embeddings for the documents.
136
+ 4. Initialize and configure the Pinecone hybrid vector database.
137
+ 5. Index the processed documents in Pinecone.
138
+ """
139
+ args = parse_arguments()
140
+
141
+ # Initialize Weave tracing
142
+ weave.init(args.project_name)
143
+
144
+ # Load FinanceBench dataset using FinanceBenchDataloader
145
+ data_loader = FinanceBenchDataloader(
146
+ dataset_name=args.dataset_name,
147
+ split=args.split,
148
+ )
149
+ # Download and prepare PDF documents from the dataset (if not already cached)
150
+ data_loader.get_corpus_pdfs()
151
+ # Create structured documents from the downloaded PDFs
152
+ documents = data_loader.create_documents()
153
+ print("Loaded Documents:")
154
+ print(documents)
155
+
156
+ # Initialize dense embedding model
157
+ dense_embeddings = DenseEmbeddings(
158
+ model_name=args.dense_model_name,
159
+ model_kwargs={"device": args.dense_device},
160
+ encode_kwargs={"normalize_embeddings": args.normalize_embeddings},
161
+ show_progress=args.show_progress,
162
+ )
163
+
164
+ # Initialize sparse embedding model
165
+ sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})
166
+
167
+ # Extract embedding models for use in the Pinecone vector database
168
+ dense_embedding_model = dense_embeddings.embedding_model
169
+ sparse_embedding_model = sparse_embeddings.sparse_embedding_model
170
+
171
+ # Initialize PineconeHybridVectorDB with specified configuration
172
+ pinecone_vector_db = PineconeHybridVectorDB(
173
+ api_key=args.pinecone_api_key,
174
+ index_name=args.pinecone_index_name,
175
+ dimension=args.pinecone_dimension,
176
+ metric=args.pinecone_metric,
177
+ region=args.pinecone_region,
178
+ cloud=args.pinecone_cloud,
179
+ )
180
+
181
+ # Add the processed documents to the Pinecone hybrid index using both dense and sparse embeddings
182
+ pinecone_vector_db.add_documents(
183
+ documents=documents,
184
+ dense_embedding_model=dense_embedding_model,
185
+ sparse_embedding_model=sparse_embedding_model,
186
+ namespace=args.namespace,
187
+ )
188
+
189
+ print("Documents have been indexed successfully in Pinecone.")
190
+
191
+
192
+ if __name__ == "__main__":
193
+ main()
scripts/indexing_weaviate.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+
4
+ import weave
5
+ from dataloaders.langchain import FinanceBenchDataloader
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+
8
+ from rag_pipelines.unstructured.unstructured_chunker import UnstructuredChunker
9
+ from rag_pipelines.unstructured.unstructured_pdf_loader import UnstructuredDocumentLoader
10
+ from rag_pipelines.utils.logging import LoggerFactory
11
+ from rag_pipelines.vectordb.weaviate import (
12
+ WeaviateVectorDB,
13
+ ) # Assumes the WeaviateVectorDB class is defined as shown above
14
+
15
+ logger_factory = LoggerFactory(logger_name=__name__, log_level=logging.INFO)
16
+ logger = logger_factory.get_logger()
17
+
18
+
19
+ def parse_arguments() -> argparse.Namespace:
20
+ """Parse command-line arguments.
21
+
22
+ Returns:
23
+ argparse.Namespace: Parsed command-line arguments.
24
+ """
25
+ parser = argparse.ArgumentParser(
26
+ description="Run the FinanceBench pipeline to load, process, chunk, embed, and index documents in Weaviate."
27
+ )
28
+
29
+ # FinanceBench dataset parameters
30
+ parser.add_argument(
31
+ "--dataset_name",
32
+ type=str,
33
+ default="PatronusAI/financebench",
34
+ help="Name of the FinanceBench dataset to use.",
35
+ )
36
+ parser.add_argument(
37
+ "--split",
38
+ type=str,
39
+ default="train[:1]",
40
+ help="Dataset split to use (e.g., 'train[:1]').",
41
+ )
42
+
43
+ # PDF directory for unstructured document loader
44
+ parser.add_argument(
45
+ "--pdf_dir",
46
+ type=str,
47
+ default="pdfs/",
48
+ help="Directory path containing PDF files.",
49
+ )
50
+
51
+ # UnstructuredDocumentLoader parameters
52
+ parser.add_argument(
53
+ "--strategy",
54
+ type=str,
55
+ default="fast",
56
+ help="Processing strategy for the unstructured document loader.",
57
+ )
58
+ parser.add_argument(
59
+ "--mode",
60
+ type=str,
61
+ default="elements",
62
+ help="Extraction mode for the unstructured document loader.",
63
+ )
64
+
65
+ # Weaviate connection parameters
66
+ parser.add_argument(
67
+ "--cluster_url",
68
+ type=str,
69
+ required=True,
70
+ help="URL of the Weaviate cluster.",
71
+ )
72
+ parser.add_argument(
73
+ "--api_key",
74
+ type=str,
75
+ required=True,
76
+ help="API key for Weaviate authentication.",
77
+ )
78
+ parser.add_argument(
79
+ "--collection_name",
80
+ type=str,
81
+ default="financebench",
82
+ help="Name of the Weaviate collection to create/use.",
83
+ )
84
+ parser.add_argument(
85
+ "--text_field",
86
+ type=str,
87
+ default="text",
88
+ help="Field name that contains document text in Weaviate.",
89
+ )
90
+
91
+ # Dense embedding model parameters
92
+ parser.add_argument(
93
+ "--dense_model_name",
94
+ type=str,
95
+ default="sentence-transformers/all-mpnet-base-v2",
96
+ help="Dense embedding model name.",
97
+ )
98
+
99
+ return parser.parse_args()
100
+
101
+
102
+ def main() -> None:
103
+ """Run the FinanceBench document processing pipeline using Weaviate.
104
+
105
+ The pipeline performs the following steps:
106
+ 1. Initializes Weave tracing.
107
+ 2. Loads a subset of the FinanceBench dataset.
108
+ 3. Retrieves PDF documents from the specified directory.
109
+ 4. Processes PDFs using the UnstructuredDocumentLoader.
110
+ 5. Chunks documents using the UnstructuredChunker.
111
+ 6. Generates dense embeddings.
112
+ 7. Sets up a Weaviate vector database and indexes the documents.
113
+ """
114
+ args = parse_arguments()
115
+
116
+ # Initialize Weave tracing
117
+ weave.init("financebench_test")
118
+
119
+ # Load FinanceBench dataset and retrieve corpus PDFs
120
+ dataloader = FinanceBenchDataloader(
121
+ dataset_name=args.dataset_name,
122
+ split=args.split,
123
+ )
124
+ dataloader.get_corpus_pdfs()
125
+
126
+ # Load and transform PDF documents from the specified directory
127
+ unstructured_document_loader = UnstructuredDocumentLoader(
128
+ strategy=args.strategy,
129
+ mode=args.mode,
130
+ )
131
+ documents = unstructured_document_loader.transform_documents(args.pdf_dir)
132
+ logger.info("Loaded Documents:")
133
+ logger.info(documents)
134
+
135
+ # Chunk the documents using the UnstructuredChunker
136
+ chunker = UnstructuredChunker()
137
+ chunked_documents = chunker.transform_documents(documents)
138
+ logger.info("Chunked Documents:")
139
+ logger.info(chunked_documents)
140
+
141
+ # Initialize the dense embedding model
142
+ embeddings = HuggingFaceEmbeddings(model_name=args.dense_model_name)
143
+
144
+ # Initialize the Weaviate vector database client
145
+ weaviate_vector_db = WeaviateVectorDB(
146
+ cluster_url=args.cluster_url,
147
+ api_key=args.api_key,
148
+ collection_name=args.collection_name,
149
+ text_field=args.text_field,
150
+ dense_embedding_model=embeddings,
151
+ )
152
+
153
+ # Index the chunked documents in Weaviate using the dense embeddings
154
+ weaviate_vector_db.add_documents(documents=chunked_documents)
155
+ logger.info("Documents have been indexed successfully in Weaviate.")
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
scripts/rag_evaluation.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from dataloaders import (
4
+ ARCDataloader,
5
+ EdgarDataLoader,
6
+ FactScoreDataloader,
7
+ PopQADataloader,
8
+ TriviaQADataloader,
9
+ )
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+
12
+ from rag_pipelines.embeddings.dense import DenseEmbeddings
13
+ from rag_pipelines.embeddings.sparse import SparseEmbeddings
14
+ from rag_pipelines.evaluation import (
15
+ AnswerRelevancyScorer,
16
+ ContextualPrecisionScorer,
17
+ ContextualRecallScorer,
18
+ ContextualRelevancyScorer,
19
+ Evaluator,
20
+ FaithfulnessScorer,
21
+ HallucinationScorer,
22
+ SummarizationScorer,
23
+ )
24
+ from rag_pipelines.llms.groq import ChatGroqGenerator
25
+ from rag_pipelines.pipelines.rag import RAGPipeline
26
+ from rag_pipelines.prompts.rag_prompt import RAG_PROMPT
27
+ from rag_pipelines.vectordb.pinecone_hybrid_index import PineconeHybridVectorDB
28
+ from rag_pipelines.vectordb.pinecone_hybrid_retriever import PineconeHybridRetriever
29
+
30
+ SUPPORTED_DATASETS = {
31
+ "arc": ARCDataloader,
32
+ "edgar": EdgarDataLoader,
33
+ "popqa": PopQADataloader,
34
+ "factscore": FactScoreDataloader,
35
+ "triviaqa": TriviaQADataloader,
36
+ }
37
+
38
+ SCORER_CLASSES = {
39
+ "contextual_precision": ContextualPrecisionScorer,
40
+ "contextual_recall": ContextualRecallScorer,
41
+ "contextual_relevancy": ContextualRelevancyScorer,
42
+ "answer_relevancy": AnswerRelevancyScorer,
43
+ "faithfulness": FaithfulnessScorer,
44
+ "summarization": SummarizationScorer,
45
+ "hallucination": HallucinationScorer,
46
+ }
47
+
48
+
49
+ def main():
50
+ parser = argparse.ArgumentParser(description="Run the Hybrid RAG pipeline.")
51
+
52
+ # Dense embeddings arguments
53
+ parser.add_argument(
54
+ "--dense_model_name",
55
+ type=str,
56
+ default="sentence-transformers/all-MiniLM-L6-v2",
57
+ help="Dense embedding model name.",
58
+ )
59
+ parser.add_argument(
60
+ "--device",
61
+ type=str,
62
+ default="cpu",
63
+ help="Device to run the dense embedding model.",
64
+ )
65
+
66
+ # Sparse embeddings arguments
67
+ parser.add_argument(
68
+ "--sparse_max_seq_length",
69
+ type=int,
70
+ default=512,
71
+ help="Maximum sequence length for sparse embeddings.",
72
+ )
73
+
74
+ # Pinecone arguments
75
+ parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
76
+ parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
77
+ parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
78
+ parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
79
+ parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
80
+ parser.add_argument("--cloud", type=str, default="aws", help="Pinecone cloud provider.")
81
+ parser.add_argument(
82
+ "--namespace",
83
+ type=str,
84
+ default="edgar-all",
85
+ help="Namespace for Pinecone retriever.",
86
+ )
87
+
88
+ # Retriever arguments
89
+ parser.add_argument("--alpha", type=float, default=0.5, help="Alpha parameter for hybrid retriever.")
90
+ parser.add_argument("--top_k", type=int, default=5, help="Number of top documents to retrieve.")
91
+
92
+ # LLM arguments
93
+ parser.add_argument(
94
+ "--llm_model",
95
+ type=str,
96
+ default="llama-3.2-90b-vision-preview",
97
+ help="Language model name.",
98
+ )
99
+ parser.add_argument(
100
+ "--temperature",
101
+ type=float,
102
+ default=0,
103
+ help="Temperature for the language model.",
104
+ )
105
+ parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
106
+
107
+ # Load evaluation data
108
+ parser = argparse.ArgumentParser(description="Load evaluation dataset and initialize the dataloader.")
109
+ parser.add_argument(
110
+ "--dataset_type",
111
+ type=str,
112
+ default="edgar",
113
+ choices=SUPPORTED_DATASETS.keys(),
114
+ help="Dataset to load from. Options: arc, edgar, popqa, factscore, triviaqa.",
115
+ )
116
+ parser.add_argument(
117
+ "--hf_dataset_path",
118
+ type=str,
119
+ default="lamini/earnings-calls-qa",
120
+ help="Path to the HuggingFace dataset.",
121
+ )
122
+ parser.add_argument(
123
+ "--dataset_split",
124
+ type=str,
125
+ default="test",
126
+ help="Split of the dataset to use (e.g., train, validation, test).",
127
+ )
128
+
129
+ # Scorer arguments
130
+ parser.add_argument(
131
+ "--scorer",
132
+ type=str,
133
+ default="contextual_precision",
134
+ choices=[
135
+ "contextual_precision",
136
+ "contextual_recall",
137
+ "contextual_relevancy",
138
+ "answer_relevancy",
139
+ "faithfulness",
140
+ "summarization",
141
+ "hallucination",
142
+ ],
143
+ help="Scorer to use.",
144
+ )
145
+
146
+ # Evaluation arguments
147
+ parser.add_argument(
148
+ "--evaluation_name",
149
+ type=str,
150
+ default="hybrid_rag",
151
+ help="Name of the evaluation.",
152
+ )
153
+
154
+ # Add argument for selecting scorers
155
+ parser.add_argument(
156
+ "--scorers",
157
+ type=str,
158
+ nargs="+",
159
+ choices=SCORER_CLASSES.keys(),
160
+ required=True,
161
+ help="List of scorers to use. Options: contextual_precision, contextual_recall, contextual_relevancy, "
162
+ "answer_relevancy, faithfulness, summarization, hallucination.",
163
+ )
164
+
165
+ # Add shared arguments for scorer parameters
166
+ parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for evaluation.")
167
+ parser.add_argument("--model", type=str, default="gpt-4", help="Model to use for scoring.")
168
+ parser.add_argument("--include_reason", action="store_true", help="Include reasons in scoring.")
169
+ parser.add_argument(
170
+ "--assessment_questions",
171
+ type=str,
172
+ nargs="*",
173
+ help="List of assessment questions for scoring.",
174
+ )
175
+ parser.add_argument("--strict_mode", action="store_true", help="Enable strict mode for scoring.")
176
+ parser.add_argument("--async_mode", action="store_true", help="Enable asynchronous processing.")
177
+ parser.add_argument("--verbose", action="store_true", help="Enable verbose output.")
178
+ parser.add_argument(
179
+ "--truths_extraction_limit",
180
+ type=int,
181
+ default=None,
182
+ help="Limit for truth extraction in scoring.",
183
+ )
184
+
185
+ args = parser.parse_args()
186
+
187
+ # Initialize dataloader based on the dataset type
188
+ try:
189
+ DataLoaderClass = SUPPORTED_DATASETS[args.dataset_type]
190
+ dataloader = DataLoaderClass(dataset_name=args.hf_dataset_path, split=args.dataset_split)
191
+ except KeyError:
192
+ msg = (
193
+ f"Dataset '{args.dataset_type}' is not supported. "
194
+ f"Supported options are: {', '.join(SUPPORTED_DATASETS.keys())}."
195
+ )
196
+ raise ValueError(msg)
197
+
198
+ eval_dataset = dataloader.get_eval_data()
199
+
200
+ # Initialize embeddings
201
+ dense_embeddings = DenseEmbeddings(
202
+ model_name=args.dense_model_name,
203
+ model_kwargs={"device": args.device},
204
+ encode_kwargs={"normalize_embeddings": True},
205
+ show_progress=True,
206
+ )
207
+ sparse_embeddings = SparseEmbeddings(model_kwargs={"max_seq_length": args.sparse_max_seq_length})
208
+
209
+ dense_embedding_model = dense_embeddings.embedding_model
210
+ sparse_embedding_model = sparse_embeddings.sparse_embedding_model
211
+
212
+ # Initialize Pinecone vector DB
213
+ pinecone_vector_db = PineconeHybridVectorDB(
214
+ api_key=args.pinecone_api_key,
215
+ index_name=args.index_name,
216
+ dimension=args.dimension,
217
+ metric=args.metric,
218
+ region=args.region,
219
+ cloud=args.cloud,
220
+ )
221
+
222
+ # Initialize Pinecone retriever
223
+ pinecone_retriever = PineconeHybridRetriever(
224
+ index=pinecone_vector_db.index,
225
+ dense_embedding_model=dense_embedding_model,
226
+ sparse_embedding_model=sparse_embedding_model,
227
+ alpha=args.alpha,
228
+ top_k=args.top_k,
229
+ namespace=args.namespace,
230
+ )
231
+
232
+ # Load the prompt
233
+
234
+ prompt = ChatPromptTemplate.from_messages(
235
+ [
236
+ ("human", RAG_PROMPT),
237
+ ]
238
+ )
239
+
240
+ # Initialize the LLM
241
+ generator = ChatGroqGenerator(
242
+ model=args.llm_model,
243
+ api_key=args.llm_api_key,
244
+ llm_params={"temperature": args.temperature},
245
+ )
246
+ llm = generator.llm
247
+
248
+ # Initialize the Hybrid RAG pipeline
249
+ hybrid_rag = RAGPipeline(
250
+ retriever=pinecone_retriever.hybrid_retriever,
251
+ prompt=prompt,
252
+ llm=llm,
253
+ tracing_project_name="sec_hybrid_rag",
254
+ )
255
+
256
+ # Initialize the scorers with the provided arguments
257
+ scorers = []
258
+ for scorer_name in args.scorers:
259
+ if scorer_name in SCORER_CLASSES:
260
+ ScorerClass = SCORER_CLASSES[scorer_name]
261
+ scorer = ScorerClass(
262
+ threshold=args.threshold,
263
+ model=args.model,
264
+ include_reason=args.include_reason,
265
+ assessment_questions=args.assessment_questions,
266
+ strict_mode=args.strict_mode,
267
+ async_mode=args.async_mode,
268
+ verbose=args.verbose,
269
+ truths_extraction_limit=args.truths_extraction_limit,
270
+ )
271
+ scorers.append(scorer)
272
+ else:
273
+ msg = f"Scorer '{scorer_name}' is not supported."
274
+ raise ValueError(msg)
275
+
276
+ # Run the pipeline
277
+ evaluator = Evaluator(
278
+ evaluation_name=args.evaluation_name,
279
+ pipeline=hybrid_rag,
280
+ dataset=eval_dataset,
281
+ scorers=[scorers],
282
+ )
283
+
284
+ evaluation_results = evaluator.evaluate()
285
+ print(evaluation_results)
286
+
287
+
288
+ if __name__ == "__main__":
289
+ main()
scripts/self_rag.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
4
+ from langchain_core.prompts.chat import ChatPromptTemplate
5
+ from langchain_groq import ChatGroq
6
+
7
+ from rag_pipelines.pipelines.self_rag import SelfRAGPipeline
8
+ from rag_pipelines.query_transformer.query_transformer import QueryTransformer
9
+ from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
10
+ from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
11
+ from rag_pipelines.websearch.web_search import WebSearch
12
+
13
+
14
def main():
    """Run the Self-RAG pipeline end to end for a single query.

    Parses CLI arguments, wires up the retriever, query transformer,
    retrieval evaluator, document grader, web search and LLM, then runs
    the pipeline on ``--query`` and prints the generated answer to stdout.
    """
    parser = argparse.ArgumentParser(description="Run the Self-RAG pipeline.")

    # Pinecone retriever arguments
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )

    # Query Transformer arguments
    parser.add_argument(
        "--query_transformer_model",
        type=str,
        default="t5-small",
        help="Model used for query transformation.",
    )

    # Retrieval Evaluator arguments
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name for retrieval evaluator.",
    )
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for the language model.",
    )
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )

    # Web Search arguments
    parser.add_argument("--web_search_api_key", type=str, required=True, help="API key for web search.")

    # Prompt arguments
    parser.add_argument(
        "--prompt_template_path",
        type=str,
        required=True,
        help="Path to the prompt template for LLM.",
    )

    # Query
    parser.add_argument(
        "--query",
        type=str,
        required=True,
        help="Query to run through the Self-RAG pipeline.",
    )

    args = parser.parse_args()

    # Initialize Pinecone retriever.
    # NOTE(review): langchain's PineconeHybridSearchRetriever is normally
    # constructed from `embeddings`, `sparse_encoder` and `index` objects, not
    # connection settings — confirm these keyword arguments against the
    # installed langchain_community version.
    retriever = PineconeHybridSearchRetriever(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
        namespace=args.namespace,
    )

    # Initialize Query Transformer
    query_transformer = QueryTransformer(model_name=args.query_transformer_model)

    # Initialize Retrieval Evaluator and Document Grader
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Initialize Web Search
    web_search = WebSearch(api_key=args.web_search_api_key)

    # Load the prompt template (explicit encoding so the template reads the
    # same on every platform).
    with open(args.prompt_template_path, encoding="utf-8") as file:
        prompt_template_str = file.read()
    prompt = ChatPromptTemplate.from_template(prompt_template_str)

    # Initialize the LLM. ChatGroq accepts `temperature` directly; the
    # previous `llm_params={...}` keyword is not a ChatGroq field and is
    # rejected by its pydantic model.
    llm = ChatGroq(
        model=args.llm_model,
        api_key=args.llm_api_key,
        temperature=args.temperature,
    )

    # Initialize Self-RAG Pipeline
    self_rag_pipeline = SelfRAGPipeline(
        retriever=retriever,
        query_transformer=query_transformer,
        retrieval_evaluator=retrieval_evaluator,
        document_grader=document_grader,
        web_search=web_search,
        prompt=prompt,
        llm=llm,
    )

    # Run the pipeline and print the generated answer.
    output = self_rag_pipeline.run(args.query)
    print(output)


if __name__ == "__main__":
    main()
scripts/self_rag_evaluation.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from dataloaders.langchain import (
4
+ ARCDataloader,
5
+ EdgarDataLoader,
6
+ FactScoreDataloader,
7
+ PopQADataloader,
8
+ TriviaQADataloader,
9
+ )
10
+ from langchain_community.retrievers import PineconeHybridSearchRetriever
11
+ from langchain_core.prompts.chat import ChatPromptTemplate
12
+ from langchain_groq import ChatGroq
13
+
14
+ from rag_pipelines.evaluation import (
15
+ AnswerRelevancyScorer,
16
+ ContextualPrecisionScorer,
17
+ ContextualRecallScorer,
18
+ ContextualRelevancyScorer,
19
+ Evaluator,
20
+ FaithfulnessScorer,
21
+ HallucinationScorer,
22
+ SummarizationScorer,
23
+ )
24
+ from rag_pipelines.pipelines.self_rag import SelfRAGPipeline
25
+ from rag_pipelines.query_transformer.query_transformer import QueryTransformer
26
+ from rag_pipelines.retrieval_evaluator.document_grader import DocumentGrader
27
+ from rag_pipelines.retrieval_evaluator.retrieval_evaluator import RetrievalEvaluator
28
+ from rag_pipelines.websearch.web_search import WebSearch
29
+
30
# Registry mapping CLI dataset identifiers to their dataloader classes.
# The keys double as the valid values for the --dataset_type argument.
SUPPORTED_DATASETS = {
    "arc": ARCDataloader,
    "edgar": EdgarDataLoader,
    "popqa": PopQADataloader,
    "factscore": FactScoreDataloader,
    "triviaqa": TriviaQADataloader,
}

# Registry mapping scorer names (as accepted by --scorers) to scorer classes.
# main() instantiates one class per requested name with shared CLI settings.
SCORER_CLASSES = {
    "contextual_precision": ContextualPrecisionScorer,
    "contextual_recall": ContextualRecallScorer,
    "contextual_relevancy": ContextualRelevancyScorer,
    "answer_relevancy": AnswerRelevancyScorer,
    "faithfulness": FaithfulnessScorer,
    "summarization": SummarizationScorer,
    "hallucination": HallucinationScorer,
}
47
+
48
+
49
def main():
    """Evaluate the Self-RAG pipeline on a benchmark dataset.

    Builds the pipeline from CLI arguments, loads the evaluation split of the
    chosen dataset, instantiates the requested scorers, runs the evaluator
    and prints the results.
    """
    parser = argparse.ArgumentParser(description="Run the Self-RAG pipeline.")

    # Pinecone retriever arguments
    parser.add_argument("--pinecone_api_key", type=str, required=True, help="Pinecone API key.")
    parser.add_argument("--index_name", type=str, default="edgar", help="Pinecone index name.")
    parser.add_argument("--dimension", type=int, default=384, help="Dimension of embeddings.")
    parser.add_argument("--metric", type=str, default="dotproduct", help="Metric for similarity search.")
    parser.add_argument("--region", type=str, default="us-east-1", help="Pinecone region.")
    parser.add_argument(
        "--namespace",
        type=str,
        default="edgar-all",
        help="Namespace for Pinecone retriever.",
    )

    # Query Transformer arguments
    parser.add_argument(
        "--query_transformer_model",
        type=str,
        default="t5-small",
        help="Model used for query transformation.",
    )

    # Retrieval Evaluator arguments
    parser.add_argument(
        "--llm_model",
        type=str,
        default="llama-3.2-90b-vision-preview",
        help="Language model name for retrieval evaluator.",
    )
    parser.add_argument("--llm_api_key", type=str, required=True, help="API key for the language model.")
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for the language model.",
    )
    parser.add_argument(
        "--relevance_threshold",
        type=float,
        default=0.7,
        help="Relevance threshold for document grading.",
    )

    # Web Search arguments
    parser.add_argument("--web_search_api_key", type=str, required=True, help="API key for web search.")

    # Prompt arguments
    parser.add_argument(
        "--prompt_template_path",
        type=str,
        required=True,
        help="Path to the prompt template for LLM.",
    )

    # Evaluation dataset arguments.
    # BUG FIX: the original re-created `argparse.ArgumentParser` at this point,
    # which silently discarded every argument registered above and made the
    # earlier required flags unparseable. All arguments now live on one parser.
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="edgar",
        choices=SUPPORTED_DATASETS.keys(),
        help="Dataset to load from. Options: arc, edgar, popqa, factscore, triviaqa.",
    )
    parser.add_argument(
        "--hf_dataset_path",
        type=str,
        default="lamini/earnings-calls-qa",
        help="Path to the HuggingFace dataset.",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default="test",
        help="Split of the dataset to use (e.g., train, validation, test).",
    )

    # Scorer arguments.
    # Choices are derived from SCORER_CLASSES so the option list cannot drift
    # from the registry (fixes the earlier "faithfullness" typo).
    # NOTE(review): --scorer is never read below; it is kept for CLI
    # compatibility, while --scorers drives the evaluation.
    parser.add_argument(
        "--scorer",
        type=str,
        default="contextual_precision",
        choices=SCORER_CLASSES.keys(),
        help="Scorer to use.",
    )

    # Evaluation arguments
    parser.add_argument(
        "--evaluation_name",
        type=str,
        default="hybrid_rag",
        help="Name of the evaluation.",
    )

    # Add argument for selecting scorers
    parser.add_argument(
        "--scorers",
        type=str,
        nargs="+",
        choices=SCORER_CLASSES.keys(),
        required=True,
        help="List of scorers to use. Options: contextual_precision, contextual_recall, contextual_relevancy, "
        "answer_relevancy, faithfulness, summarization, hallucination.",
    )

    # Add shared arguments for scorer parameters
    parser.add_argument("--threshold", type=float, default=0.5, help="Threshold for evaluation.")
    parser.add_argument("--model", type=str, default="gpt-4", help="Model to use for scoring.")
    parser.add_argument("--include_reason", action="store_true", help="Include reasons in scoring.")
    parser.add_argument(
        "--assessment_questions",
        type=str,
        nargs="*",
        help="List of assessment questions for scoring.",
    )
    parser.add_argument("--strict_mode", action="store_true", help="Enable strict mode for scoring.")
    parser.add_argument("--async_mode", action="store_true", help="Enable asynchronous processing.")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output.")
    parser.add_argument(
        "--truths_extraction_limit",
        type=int,
        default=None,
        help="Limit for truth extraction in scoring.",
    )

    args = parser.parse_args()

    # Initialize dataloader based on the dataset type. argparse `choices`
    # already rejects unknown values; the explicit error is kept as a guard
    # for programmatic callers.
    try:
        dataloader_cls = SUPPORTED_DATASETS[args.dataset_type]
    except KeyError:
        msg = (
            f"Dataset '{args.dataset_type}' is not supported. "
            f"Supported options are: {', '.join(SUPPORTED_DATASETS.keys())}."
        )
        raise ValueError(msg) from None
    dataloader = dataloader_cls(dataset_name=args.hf_dataset_path, split=args.dataset_split)

    eval_dataset = dataloader.get_eval_data()

    # Initialize Pinecone retriever.
    # NOTE(review): langchain's PineconeHybridSearchRetriever is normally
    # constructed from `embeddings`, `sparse_encoder` and `index` objects —
    # confirm these keyword arguments against the installed version.
    retriever = PineconeHybridSearchRetriever(
        api_key=args.pinecone_api_key,
        index_name=args.index_name,
        dimension=args.dimension,
        metric=args.metric,
        region=args.region,
        namespace=args.namespace,
    )

    # Initialize Query Transformer
    query_transformer = QueryTransformer(model_name=args.query_transformer_model)

    # Initialize Retrieval Evaluator and Document Grader
    retrieval_evaluator = RetrievalEvaluator(
        llm_model=args.llm_model,
        llm_api_key=args.llm_api_key,
        temperature=args.temperature,
    )
    document_grader = DocumentGrader(
        evaluator=retrieval_evaluator,
        threshold=args.relevance_threshold,
    )

    # Initialize Web Search
    web_search = WebSearch(api_key=args.web_search_api_key)

    # Load the prompt template (explicit encoding for portability).
    with open(args.prompt_template_path, encoding="utf-8") as file:
        prompt_template_str = file.read()
    prompt = ChatPromptTemplate.from_template(prompt_template_str)

    # Initialize the LLM. ChatGroq accepts `temperature` directly; the
    # previous `llm_params={...}` keyword is not a ChatGroq field and is
    # rejected by its pydantic model.
    llm = ChatGroq(
        model=args.llm_model,
        api_key=args.llm_api_key,
        temperature=args.temperature,
    )

    # Initialize Self-RAG Pipeline
    self_rag_pipeline = SelfRAGPipeline(
        retriever=retriever,
        query_transformer=query_transformer,
        retrieval_evaluator=retrieval_evaluator,
        document_grader=document_grader,
        web_search=web_search,
        prompt=prompt,
        llm=llm,
    )

    # Initialize the scorers with the shared CLI-provided parameters.
    scorers = []
    for scorer_name in args.scorers:
        scorer_cls = SCORER_CLASSES.get(scorer_name)
        if scorer_cls is None:
            msg = f"Scorer '{scorer_name}' is not supported."
            raise ValueError(msg)
        scorers.append(
            scorer_cls(
                threshold=args.threshold,
                model=args.model,
                include_reason=args.include_reason,
                assessment_questions=args.assessment_questions,
                strict_mode=args.strict_mode,
                async_mode=args.async_mode,
                verbose=args.verbose,
                truths_extraction_limit=args.truths_extraction_limit,
            )
        )

    # Run the evaluation.
    # BUG FIX: pass the scorer list directly; the original wrapped it in
    # another list (`scorers=[scorers]`), handing the evaluator a nested list.
    evaluator = Evaluator(
        evaluation_name=args.evaluation_name,
        pipeline=self_rag_pipeline,
        dataset=eval_dataset,
        scorers=scorers,
    )

    evaluation_results = evaluator.evaluate()
    print(evaluation_results)


if __name__ == "__main__":
    main()