# Source: test-ragp/scripts/dspy_rag.py — uploaded by awinml ("Upload 11 files", rev 6c5ce7a)
import os
import weave
from dataloaders.langchain import FinanceBenchDataloader
from dspy import LM
from langchain_huggingface import HuggingFaceEmbeddings
from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
from rag_pipelines.evaluation import AnswerRelevancyScorer, ContextualPrecisionScorer, Evaluator
from rag_pipelines.pipelines import RAG, DSPyRAGPipeline
from rag_pipelines.vectordb import DSPyMilvusRetriever as MilvusRetriever
from rag_pipelines.vectordb import MilvusVectorDB, milvus_retriever
# --- Runtime environment configuration ---
# Serialize Weave evaluation calls and disable LangChain trace capture.
os.environ.update(
    WEAVE_PARALLELISM="1",
    WEAVE_TRACE_LANGCHAIN="false",
)

# Dense embedding model, run on CPU through the ONNX backend.
dense_model = "intfloat/multilingual-e5-large"
encode_kwargs = {"prompt": "query: "}
model_kwargs = dict(
    device="cpu",
    trust_remote_code=True,
    backend="onnx",
    model_kwargs={"file_name": "onnx/model.onnx"},
)

# Milvus collection field names used by the hybrid retriever.
dense_field, sparse_field = "dense_vector", "sparse_vector"
text_field, metadata_field = "text", "metadata"

# Per-field ANN search metrics: cosine for dense, inner product for sparse.
dense_search_params = dict(metric_type="COSINE")
sparse_search_params = dict(metric_type="IP")
# --- Milvus / Weave connection settings ---
# SECURITY: credentials must come from the environment, never from source
# control. A Zilliz serverless token was previously hard-coded here and must
# be treated as leaked — rotate it in the Zilliz console and export the
# replacement as MILVUS_TOKEN before running this script.
milvus_uri = os.environ.get(
    "MILVUS_URI",
    "https://in03-8aaa331b36bf39c.serverless.gcp-us-west1.cloud.zilliz.com",
)
milvus_token = os.environ.get("MILVUS_TOKEN", "")
milvus_collection_name = "financebenchsub"

# Weave tracing: single project, default client options.
tracing_project_name = "dspy_rag"
weave_params = {}
client = weave.init(tracing_project_name, **weave_params)
# --- Embedding models and hybrid Milvus retriever ---
# Dense encoder: E5 model via ONNX on CPU; queries get the "query: " prefix.
dense_embeddings = HuggingFaceEmbeddings(
    model_name=dense_model,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs,
)
# Sparse encoder: SPLADE term-expansion vectors.
sparse_embeddings = SparseEmbeddings(model_name="Splade_PP_en_v1")

# Handle to the pre-built Milvus collection.
milvus_vector_db = MilvusVectorDB(
    uri=milvus_uri,
    token=milvus_token,
    collection_name=milvus_collection_name,
)

# Hybrid (dense + sparse) retriever returning the top 3 chunks per query.
# NOTE(review): this assignment shadows the `milvus_retriever` name imported
# from rag_pipelines.vectordb above — consider renaming one of the two.
hybrid_fields = [dense_field, sparse_field]
hybrid_search_params = [dense_search_params, sparse_search_params]
milvus_retriever = MilvusRetriever(
    collection=milvus_vector_db.collection,
    dense_embedding_model=dense_embeddings,
    sparse_embedding_model=sparse_embeddings,
    anns_fields=hybrid_fields,
    field_search_params=hybrid_search_params,
    text_field=text_field,
    top_k=3,
)
# --- LLM and DSPy RAG pipeline ---
# SECURITY: a Groq API key was previously hard-coded here and must be treated
# as leaked — revoke it in the Groq console and export the replacement as
# GROQ_API_KEY before running this script.
llm = LM(
    "groq/llama-3.3-70b-versatile",
    api_key=os.environ.get("GROQ_API_KEY", ""),
    num_retries=120,  # tolerate rate limiting on the Groq endpoint
)
dspy_rag_module = RAG(milvus_retriever)
pipeline = DSPyRAGPipeline(rag_module=dspy_rag_module, llm=llm)
# --- Dataset, scorers, and evaluation ---
# Load a two-example FinanceBench subset (smoke-test sized).
dataloader = FinanceBenchDataloader(
    dataset_name="PatronusAI/financebench",
    split="train[:2]",
)
data = dataloader.load_data()
eval_data = dataloader.get_evaluation_data()
evaluation_dataset = weave.Dataset(
    name="financebench_test_evaluation_dataset",
    rows=eval_data,
)
questions = dataloader.get_questions()

# Publishing the datasets to Weave is currently disabled:
# dataloader.publish_to_weave(
#     weave_project_name="financebench_test",
#     dataset_name="financebench_test_dataset",
#     evaluation_dataset_name="financebench_test_evaluation_dataset",
# )

# Both scorers share the same judge configuration.
scorer_kwargs = dict(threshold=0.5, model="gpt-4", include_reason=True, verbose=True)
answer_relevancy_scorer = AnswerRelevancyScorer(**scorer_kwargs)
contextual_precision_scorer = ContextualPrecisionScorer(**scorer_kwargs)

# Run the evaluation over the pipeline and print the aggregated scores.
evaluator = Evaluator(
    evaluation_name="financebench_test_subset_2",
    evaluation_dataset=evaluation_dataset,
    evaluation_scorers=[answer_relevancy_scorer, contextual_precision_scorer],
    pipeline=pipeline,
)
evaluation_results = evaluator.evaluate()
print(evaluation_results)