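"""Evaluate a hybrid dense + sparse Milvus RAG pipeline on FinanceBench.

Retrieval combines multilingual-e5-large dense embeddings with SPLADE sparse
embeddings over a Zilliz-hosted Milvus collection; generation uses Groq-hosted
Llama 3.3 70B via DSPy. Answers are scored for relevancy and contextual
precision by GPT-4 judges, with the run tracked in Weave.
"""
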
import os

import weave
from dataloaders.langchain import FinanceBenchDataloader
from dspy import LM
from langchain_huggingface import HuggingFaceEmbeddings

from rag_pipelines.embeddings import SparseEmbeddingsMilvus as SparseEmbeddings
from rag_pipelines.evaluation import AnswerRelevancyScorer, ContextualPrecisionScorer, Evaluator
from rag_pipelines.pipelines import RAG, DSPyRAGPipeline
from rag_pipelines.vectordb import DSPyMilvusRetriever as MilvusRetriever
from rag_pipelines.vectordb import MilvusVectorDB

# Run Weave ops serially and leave LangChain auto-tracing off.
os.environ["WEAVE_PARALLELISM"] = "1"
os.environ["WEAVE_TRACE_LANGCHAIN"] = "false"

# Dense embedder: multilingual-e5-large running via ONNX on CPU. E5 models
# expect a "query: " prefix, passed here through encode_kwargs.
dense_model = "intfloat/multilingual-e5-large"
encode_kwargs = {"prompt": "query: "}
model_kwargs = {
    "device": "cpu",
    "trust_remote_code": True,
    "backend": "onnx",
    "model_kwargs": {"file_name": "onnx/model.onnx"},
}

# multilingual-e5-large is used for all dense embeddings.

# Milvus field names for the hybrid collection.
dense_field = "dense_vector"
sparse_field = "sparse_vector"
text_field = "text"
metadata_field = "metadata"

# Per-field search metrics: cosine similarity for the dense vectors,
# inner product for the sparse (SPLADE) vectors.
dense_search_params = {
    "metric_type": "COSINE",
}

sparse_search_params = {
    "metric_type": "IP",
}

# Zilliz/Milvus connection details. Credentials are read from the environment
# instead of being hard-coded; MILVUS_URI and MILVUS_TOKEN are assumed
# variable names for this setup.
milvus_uri = os.environ["MILVUS_URI"]
milvus_token = os.environ["MILVUS_TOKEN"]
milvus_collection_name = "financebenchsub"


# Initialize Weave tracing for the run.
tracing_project_name = "dspy_rag"
weave_params = {}

client = weave.init(tracing_project_name, **weave_params)

# Dense (e5) and sparse (SPLADE) embedding models for hybrid retrieval.
dense_embeddings = HuggingFaceEmbeddings(model_name=dense_model, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
sparse_embeddings = SparseEmbeddings(model_name="Splade_PP_en_v1")

# Connect to the Milvus collection and build a hybrid retriever that searches
# both vector fields and returns the top 3 chunks per query.
milvus_vector_db = MilvusVectorDB(
    uri=milvus_uri,
    token=milvus_token,
    collection_name=milvus_collection_name,
)

milvus_retriever = MilvusRetriever(
    collection=milvus_vector_db.collection,
    dense_embedding_model=dense_embeddings,
    sparse_embedding_model=sparse_embeddings,
    anns_fields=[dense_field, sparse_field],
    field_search_params=[dense_search_params, sparse_search_params],
    text_field=text_field,
    top_k=3,
)

# Generator LLM: Groq-hosted Llama 3.3 70B. The key is read from the
# environment (GROQ_API_KEY is the variable LiteLLM checks for groq/ models)
# rather than hard-coded.
llm = LM(
    "groq/llama-3.3-70b-versatile",
    api_key=os.environ["GROQ_API_KEY"],
    num_retries=120,
)

# Wire the retriever into a DSPy RAG module and wrap it in a pipeline.
dspy_rag_module = RAG(milvus_retriever)

pipeline = DSPyRAGPipeline(rag_module=dspy_rag_module, llm=llm)

# Small FinanceBench slice (first two train rows) for a quick evaluation run.
dataloader = FinanceBenchDataloader(
    dataset_name="PatronusAI/financebench",
    split="train[:2]",
)

# Load documents and evaluation rows; `data` and `questions` are currently
# unused below.
data = dataloader.load_data()
eval_data = dataloader.get_evaluation_data()

evaluation_dataset = weave.Dataset(name="financebench_test_evaluation_dataset", rows=eval_data)

questions = dataloader.get_questions()

# dataloader.publish_to_weave(
#     weave_project_name="financebench_test",
#     dataset_name="financebench_test_dataset",
#     evaluation_dataset_name="financebench_test_evaluation_dataset",
# )

# LLM-judged scorers: GPT-4 grades answer relevancy and contextual precision,
# each with a 0.5 pass threshold.
answer_relevancy_scorer = AnswerRelevancyScorer(
    threshold=0.5,
    model="gpt-4",
    include_reason=True,
    verbose=True,
)
contextual_precision_scorer = ContextualPrecisionScorer(
    threshold=0.5,
    model="gpt-4",
    include_reason=True,
    verbose=True,
)

# Run the evaluation over the dataset and print the aggregate scores.
evaluator = Evaluator(
    evaluation_name="financebench_test_subset_2",
    evaluation_dataset=evaluation_dataset,
    evaluation_scorers=[answer_relevancy_scorer, contextual_precision_scorer],
    pipeline=pipeline,
)

evaluation_results = evaluator.evaluate()
print(evaluation_results)
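
# Example invocation (env var names as assumed above; the GPT-4 judges
# presumably also need OPENAI_API_KEY):
#   export MILVUS_URI="https://<your-cluster>.cloud.zilliz.com"
#   export MILVUS_TOKEN="<your-token>"
#   export GROQ_API_KEY="<your-key>"
#   export OPENAI_API_KEY="<your-key>"
#   python evaluate_financebench.py  # script filename assumed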