Spaces:
Build error
Build error
Ilia Tambovtsev committed on
Commit ·
e37d064
1
Parent(s): 2cb5a84
feat: implement retrieval with llm reranking
Browse files- src/eval/eval_mlflow.py +2 -2
- src/eval/evaluate.py +20 -15
- src/rag/storage.py +150 -3
src/eval/eval_mlflow.py
CHANGED
|
@@ -295,7 +295,7 @@ class RAGEvaluator:
|
|
| 295 |
retriever = PresentationRetriever(
|
| 296 |
storage=self.storage,
|
| 297 |
scorer=scorer,
|
| 298 |
-
|
| 299 |
)
|
| 300 |
|
| 301 |
# Run evaluation for each question
|
|
@@ -313,7 +313,7 @@ class RAGEvaluator:
|
|
| 313 |
"pages": [int(x) if x else -1 for x in row["page"].split(",")],
|
| 314 |
}
|
| 315 |
|
| 316 |
-
output = retriever
|
| 317 |
|
| 318 |
self._logger.info(
|
| 319 |
f"Retrieved {len(output['contexts'])} presentations"
|
|
|
|
| 295 |
retriever = PresentationRetriever(
|
| 296 |
storage=self.storage,
|
| 297 |
scorer=scorer,
|
| 298 |
+
n_pages=self.config.n_contexts,
|
| 299 |
)
|
| 300 |
|
| 301 |
# Run evaluation for each question
|
|
|
|
| 313 |
"pages": [int(x) if x else -1 for x in row["page"].split(",")],
|
| 314 |
}
|
| 315 |
|
| 316 |
+
output = retriever(dict(question=row["question"]))
|
| 317 |
|
| 318 |
self._logger.info(
|
| 319 |
f"Retrieved {len(output['contexts'])} presentations"
|
src/eval/evaluate.py
CHANGED
|
@@ -4,7 +4,7 @@ import time
|
|
| 4 |
from collections import OrderedDict
|
| 5 |
from functools import partial
|
| 6 |
from textwrap import dedent
|
| 7 |
-
from typing import Dict, List, Optional
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
from langchain_core import outputs
|
|
@@ -17,6 +17,7 @@ from langsmith.evaluation.evaluator import DynamicRunEvaluator, EvaluationResult
|
|
| 17 |
from langsmith.schemas import Dataset
|
| 18 |
from langsmith.utils import LangSmithError
|
| 19 |
from pandas._libs.tslibs.np_datetime import py_td64_to_tdstruct
|
|
|
|
| 20 |
from pydantic import BaseModel, ConfigDict, Field
|
| 21 |
from ragas import SingleTurnSample
|
| 22 |
from ragas.llms.base import LangchainLLMWrapper
|
|
@@ -29,6 +30,7 @@ from src.rag import (
|
|
| 29 |
PresentationRetriever,
|
| 30 |
ScorerTypes,
|
| 31 |
)
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
@run_evaluator
|
|
@@ -211,13 +213,15 @@ class EvaluationConfig(BaseModel):
|
|
| 211 |
|
| 212 |
# Configure Retrieval
|
| 213 |
scorers: List[ScorerTypes] = [MinScorer(), HyperbolicScorer()]
|
|
|
|
| 214 |
|
| 215 |
# Setup Evaluators
|
| 216 |
evaluators: List[DynamicRunEvaluator] = [presentation_match, page_match]
|
| 217 |
|
| 218 |
# Configure RAGAS
|
| 219 |
# ragas_metrics: List[type] = [Faithfulness] # List of metric classes
|
| 220 |
-
n_contexts: int =
|
|
|
|
| 221 |
|
| 222 |
# Configure evaluation
|
| 223 |
max_concurrency: int = 2
|
|
@@ -232,7 +236,6 @@ class RAGEvaluatorLangsmith:
|
|
| 232 |
|
| 233 |
def __init__(
|
| 234 |
self,
|
| 235 |
-
storage: ChromaSlideStore,
|
| 236 |
config: EvaluationConfig,
|
| 237 |
llm: ChatOpenAI = Config().model_config.load_vsegpt(model="openai/gpt-4o-mini"),
|
| 238 |
):
|
|
@@ -248,11 +251,10 @@ class RAGEvaluatorLangsmith:
|
|
| 248 |
)
|
| 249 |
|
| 250 |
# Setup class
|
| 251 |
-
self.storage = storage
|
| 252 |
self.client = Client()
|
| 253 |
self.config = config
|
| 254 |
-
|
| 255 |
-
self.
|
| 256 |
|
| 257 |
@classmethod
|
| 258 |
def load_questions_from_sheet(cls, *args, **kwargs) -> pd.DataFrame:
|
|
@@ -332,15 +334,17 @@ class RAGEvaluatorLangsmith:
|
|
| 332 |
else:
|
| 333 |
experiment_prefix = f"{scorer.id}"
|
| 334 |
|
| 335 |
-
retriever =
|
| 336 |
-
|
| 337 |
-
)
|
| 338 |
evaluate(
|
| 339 |
retriever,
|
| 340 |
experiment_prefix=experiment_prefix,
|
| 341 |
data=self.config.dataset_name,
|
| 342 |
evaluators=list(chains.values()),
|
| 343 |
-
metadata=dict(
|
|
|
|
|
|
|
|
|
|
| 344 |
max_concurrency=self.config.max_concurrency,
|
| 345 |
)
|
| 346 |
|
|
@@ -369,12 +373,13 @@ def main():
|
|
| 369 |
storage = ChromaSlideStore(collection_name="pres0", embedding_model=embeddings)
|
| 370 |
eval_config = EvaluationConfig(
|
| 371 |
dataset_name="PresRetrieve_5",
|
|
|
|
| 372 |
evaluators=[
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
create_llm_relevance_evaluator(llm),
|
| 378 |
],
|
| 379 |
scorers=[MinScorer(), ExponentialScorer()],
|
| 380 |
max_concurrency=1,
|
|
|
|
| 4 |
from collections import OrderedDict
|
| 5 |
from functools import partial
|
| 6 |
from textwrap import dedent
|
| 7 |
+
from typing import ClassVar, Dict, List, Optional
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
from langchain_core import outputs
|
|
|
|
| 17 |
from langsmith.schemas import Dataset
|
| 18 |
from langsmith.utils import LangSmithError
|
| 19 |
from pandas._libs.tslibs.np_datetime import py_td64_to_tdstruct
|
| 20 |
+
from pandas.core.dtypes.dtypes import re
|
| 21 |
from pydantic import BaseModel, ConfigDict, Field
|
| 22 |
from ragas import SingleTurnSample
|
| 23 |
from ragas.llms.base import LangchainLLMWrapper
|
|
|
|
| 30 |
PresentationRetriever,
|
| 31 |
ScorerTypes,
|
| 32 |
)
|
| 33 |
+
from src.rag.storage import LLMPresentationRetriever
|
| 34 |
|
| 35 |
|
| 36 |
@run_evaluator
|
|
|
|
| 213 |
|
| 214 |
# Configure Retrieval
|
| 215 |
scorers: List[ScorerTypes] = [MinScorer(), HyperbolicScorer()]
|
| 216 |
+
retriever: PresentationRetriever
|
| 217 |
|
| 218 |
# Setup Evaluators
|
| 219 |
evaluators: List[DynamicRunEvaluator] = [presentation_match, page_match]
|
| 220 |
|
| 221 |
# Configure RAGAS
|
| 222 |
# ragas_metrics: List[type] = [Faithfulness] # List of metric classes
|
| 223 |
+
n_contexts: int = 10
|
| 224 |
+
n_pages: int = 3
|
| 225 |
|
| 226 |
# Configure evaluation
|
| 227 |
max_concurrency: int = 2
|
|
|
|
| 236 |
|
| 237 |
def __init__(
|
| 238 |
self,
|
|
|
|
| 239 |
config: EvaluationConfig,
|
| 240 |
llm: ChatOpenAI = Config().model_config.load_vsegpt(model="openai/gpt-4o-mini"),
|
| 241 |
):
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
# Setup class
|
|
|
|
| 254 |
self.client = Client()
|
| 255 |
self.config = config
|
| 256 |
+
self.llm = llm
|
| 257 |
+
self.llm_wrapped = LangchainLLMWrapper(self.llm)
|
| 258 |
|
| 259 |
@classmethod
|
| 260 |
def load_questions_from_sheet(cls, *args, **kwargs) -> pd.DataFrame:
|
|
|
|
| 334 |
else:
|
| 335 |
experiment_prefix = f"{scorer.id}"
|
| 336 |
|
| 337 |
+
retriever = self.config.retriever
|
| 338 |
+
retriever.set_scorer(scorer)
|
|
|
|
| 339 |
evaluate(
|
| 340 |
retriever,
|
| 341 |
experiment_prefix=experiment_prefix,
|
| 342 |
data=self.config.dataset_name,
|
| 343 |
evaluators=list(chains.values()),
|
| 344 |
+
metadata=dict(
|
| 345 |
+
scorer=scorer.id,
|
| 346 |
+
retriever=self.config.retriever.__class__.__name__,
|
| 347 |
+
),
|
| 348 |
max_concurrency=self.config.max_concurrency,
|
| 349 |
)
|
| 350 |
|
|
|
|
| 373 |
storage = ChromaSlideStore(collection_name="pres0", embedding_model=embeddings)
|
| 374 |
eval_config = EvaluationConfig(
|
| 375 |
dataset_name="PresRetrieve_5",
|
| 376 |
+
retriever_cls=LLMPresentationRetriever,
|
| 377 |
evaluators=[
|
| 378 |
+
presentation_match,
|
| 379 |
+
presentation_found,
|
| 380 |
+
page_match,
|
| 381 |
+
page_found,
|
| 382 |
+
# create_llm_relevance_evaluator(llm),
|
| 383 |
],
|
| 384 |
scorers=[MinScorer(), ExponentialScorer()],
|
| 385 |
max_concurrency=1,
|
src/rag/storage.py
CHANGED
|
@@ -9,8 +9,14 @@ import numpy as np
|
|
| 9 |
from chromadb.api.types import QueryResult
|
| 10 |
from chromadb.config import Settings
|
| 11 |
from datasets.utils import metadata
|
|
|
|
| 12 |
from langchain.schema import Document
|
|
|
|
| 13 |
from langchain_core.embeddings import Embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 15 |
from pandas.core.algorithms import rank
|
| 16 |
from pydantic import BaseModel, ConfigDict, Field, conbytes
|
|
@@ -769,6 +775,7 @@ class PresentationRetriever(BaseModel):
|
|
| 769 |
storage: ChromaSlideStore
|
| 770 |
scorer: BaseScorer = ExponentialScorer()
|
| 771 |
n_contexts: int = -1
|
|
|
|
| 772 |
retrieve_page_contexts: bool = True
|
| 773 |
|
| 774 |
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
@@ -850,8 +857,12 @@ class PresentationRetriever(BaseModel):
|
|
| 850 |
metadata_filter=metadata_filter,
|
| 851 |
)
|
| 852 |
|
|
|
|
|
|
|
|
|
|
| 853 |
contexts = []
|
| 854 |
-
|
|
|
|
| 855 |
|
| 856 |
# Gather relevant info from presentation
|
| 857 |
pres_info = dict(
|
|
@@ -860,8 +871,10 @@ class PresentationRetriever(BaseModel):
|
|
| 860 |
)
|
| 861 |
|
| 862 |
if self.retrieve_page_contexts:
|
| 863 |
-
page_contexts = self.format_contexts(pres, self.
|
| 864 |
-
pres_info["contexts"] =
|
|
|
|
|
|
|
| 865 |
|
| 866 |
contexts.append(pres_info)
|
| 867 |
|
|
@@ -874,6 +887,140 @@ class PresentationRetriever(BaseModel):
|
|
| 874 |
def __call__(self, inputs: Dict[str, Any]):
|
| 875 |
return self.retrieve(inputs["question"])
|
| 876 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 877 |
|
| 878 |
# def create_slides_database(
|
| 879 |
# presentations: List[PresentationAnalysis], collection_name: str = "slides"
|
|
|
|
| 9 |
from chromadb.api.types import QueryResult
|
| 10 |
from chromadb.config import Settings
|
| 11 |
from datasets.utils import metadata
|
| 12 |
+
from langchain.chains.base import Chain
|
| 13 |
from langchain.schema import Document
|
| 14 |
+
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
| 15 |
from langchain_core.embeddings import Embeddings
|
| 16 |
+
from langchain_core.language_models import BaseLanguageModel
|
| 17 |
+
from langchain_core.output_parsers import JsonOutputParser
|
| 18 |
+
from langchain_core.prompts import PromptTemplate
|
| 19 |
+
from langchain_openai import ChatOpenAI
|
| 20 |
from langchain_openai.embeddings import OpenAIEmbeddings
|
| 21 |
from pandas.core.algorithms import rank
|
| 22 |
from pydantic import BaseModel, ConfigDict, Field, conbytes
|
|
|
|
| 775 |
storage: ChromaSlideStore
|
| 776 |
scorer: BaseScorer = ExponentialScorer()
|
| 777 |
n_contexts: int = -1
|
| 778 |
+
n_pages: int = -1
|
| 779 |
retrieve_page_contexts: bool = True
|
| 780 |
|
| 781 |
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
|
|
| 857 |
metadata_filter=metadata_filter,
|
| 858 |
)
|
| 859 |
|
| 860 |
+
return self.results2contexts(results)
|
| 861 |
+
|
| 862 |
+
def results2contexts(self, results: ScoredPresentations):
|
| 863 |
contexts = []
|
| 864 |
+
n_pres = self.n_contexts if self.n_contexts > 0 else len(results)
|
| 865 |
+
for i, pres in enumerate(results.presentations[:n_pres]):
|
| 866 |
|
| 867 |
# Gather relevant info from presentation
|
| 868 |
pres_info = dict(
|
|
|
|
| 871 |
)
|
| 872 |
|
| 873 |
if self.retrieve_page_contexts:
|
| 874 |
+
page_contexts = self.format_contexts(pres, self.n_pages)
|
| 875 |
+
pres_info["contexts"] = (
|
| 876 |
+
page_contexts # pyright: ignore[reportArgumentType]
|
| 877 |
+
)
|
| 878 |
|
| 879 |
contexts.append(pres_info)
|
| 880 |
|
|
|
|
| 887 |
def __call__(self, inputs: Dict[str, Any]):
|
| 888 |
return self.retrieve(inputs["question"])
|
| 889 |
|
| 890 |
+
def set_scorer(self, scorer: ScorerTypes):
    """Replace this retriever's ``scorer`` attribute with ``scorer``."""
    self.scorer = scorer
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
class LLMPresentationRetriever(PresentationRetriever):
    """Retriever that reorders base retrieval results by LLM-assigned relevance.

    Candidates come from the parent class's ``retrieve``. When more than one
    candidate is returned, they are formatted into a single prompt, scored by
    ``llm`` through structured output, and returned best-first with the LLM's
    score and explanation attached to each entry.
    """

    # NOTE: the schema classes are documented with comments rather than
    # docstrings on purpose: pydantic copies a model's docstring into its JSON
    # schema, which would alter what is sent to the LLM.
    class RelevanceRanking(BaseModel):
        # Structured-output schema: one RelevanceEval per evaluated document.
        class RelevanceEval(BaseModel):
            document_id: int = Field(description="The id of the document")
            relevance: int = Field(description="Relevance score from 1-10")
            explanation: str = Field(
                description="Short passage to clarify relevance score"
            )

        results: list[RelevanceEval]

    # Chat model used to score candidates.
    llm: ChatOpenAI
    # Maximum number of presentations kept after reranking.
    top_k: int = 10

    # Only used to render the format instructions that are injected into the prompt.
    _parser: JsonOutputParser = JsonOutputParser(pydantic_object=RelevanceRanking)

    rerank_prompt: PromptTemplate = PromptTemplate(
        template="""You are evaluating search results for presentation slides.
Rate how relevant each document is to the given query.
The relevance score should be from 1-10 where:
- 1-3: Low relevance, mostly unrelated content
- 4-6: Moderate relevance, some related points
- 7-8: High relevance, clearly addresses the query
- 9-10: Perfect match, directly answers the query

Evaluate ALL documents and provide brief explanations.

Presentations to evaluate:

{context_str}

Question: {query_str}

Output Formatting:
{format_instructions}
""",
        input_variables=["context_str", "query_str", "format_instructions"],
    )

    def _format_presentations(self, presentations: List[Dict[str, Any]]) -> str:
        """Render the candidates as numbered "Document N" sections for the prompt.

        Numbering is 1-based; ``_rerank_results`` relies on this when mapping the
        LLM's ``document_id`` values back to positions in ``presentations``.
        """
        formatted = []
        for position, pres in enumerate(presentations, start=1):
            content = [f"Document {position}:", f"Title: {pres['pres_name']}"]

            # "contexts" is only present when the base retriever attached page contexts.
            if "contexts" in pres:
                content.append("Content:")
                content.extend(pres["contexts"])

            formatted.append("\n".join(content))

        return "\n\n".join(formatted)

    def _rerank_results(
        self,
        results: List[Dict[str, Any]],
        query: str,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Dict[str, Any]]:
        """Rerank ``results`` by LLM relevance score, best first.

        Args:
            results: Candidate presentation dicts, in base-retrieval order.
            query: The user question the candidates are scored against.
            run_manager: Accepted for interface compatibility; currently unused.

        Returns:
            At most ``top_k`` shallow copies of entries from ``results``, each
            extended with ``llm_score`` and ``llm_explanation``. Candidates the
            LLM did not score, or referenced with an invalid id, are omitted.
        """
        context_str = self._format_presentations(results)

        chain = self.rerank_prompt | self.llm.with_structured_output(
            self.RelevanceRanking
        )
        ranking = chain.invoke(
            {
                "context_str": context_str,
                "query_str": query,
                "format_instructions": self._parser.get_format_instructions(),
            },
        )

        evaluations = ranking.results  # pyright: ignore[reportAttributeAccessIssue]
        if len(evaluations) != len(results):
            logger.warning(
                f"Reranker returned {len(evaluations)} results when should {len(results)}"
            )

        # sorted() is stable, so equally scored documents keep the LLM's order.
        sorted_evals = sorted(
            evaluations,
            key=lambda evaluation: evaluation.relevance,
            reverse=True,
        )

        reranked: List[Dict[str, Any]] = []
        seen_indices: set[int] = set()
        for evaluation in sorted_evals:
            if len(reranked) >= self.top_k:
                break

            index = evaluation.document_id - 1
            # Reject ids outside 1..len(results): a non-positive id would otherwise
            # wrap around to a negative index and silently pick the wrong document.
            if not 0 <= index < len(results):
                continue
            # The LLM may repeat a document id; keep its highest-scored mention only.
            if index in seen_indices:
                continue
            seen_indices.add(index)

            # Attach the score from the same evaluation that selected the entry so
            # score and document stay paired even when evaluations were skipped.
            entry = results[index].copy()
            entry["llm_score"] = evaluation.relevance
            entry["llm_explanation"] = evaluation.explanation
            reranked.append(entry)

        return reranked

    def __call__(
        self,
        inputs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Retrieve presentations for ``inputs["question"]`` and rerank them.

        Returns:
            A dict with a single ``contexts`` key holding the presentation dicts.
            Reranking is skipped when the base retrieval yields at most one
            candidate, in which case the base contexts are returned unchanged.
        """
        question = inputs["question"]
        contexts = super().retrieve(query=question)["contexts"]

        if len(contexts) > 1:
            contexts = self._rerank_results(contexts, question)

        return dict(contexts=contexts)
|
| 1023 |
+
|
| 1024 |
|
| 1025 |
# def create_slides_database(
|
| 1026 |
# presentations: List[PresentationAnalysis], collection_name: str = "slides"
|