Ilia Tambovtsev committed on
Commit
e37d064
·
1 Parent(s): 2cb5a84

feat: implement retrieval with llm reranking

Browse files
Files changed (3) hide show
  1. src/eval/eval_mlflow.py +2 -2
  2. src/eval/evaluate.py +20 -15
  3. src/rag/storage.py +150 -3
src/eval/eval_mlflow.py CHANGED
@@ -295,7 +295,7 @@ class RAGEvaluator:
295
  retriever = PresentationRetriever(
296
  storage=self.storage,
297
  scorer=scorer,
298
- n_contexts=self.config.n_contexts,
299
  )
300
 
301
  # Run evaluation for each question
@@ -313,7 +313,7 @@ class RAGEvaluator:
313
  "pages": [int(x) if x else -1 for x in row["page"].split(",")],
314
  }
315
 
316
- output = retriever.retrieve(row["question"]) # pyright: ignore[reportArgumentType]
317
 
318
  self._logger.info(
319
  f"Retrieved {len(output['contexts'])} presentations"
 
295
  retriever = PresentationRetriever(
296
  storage=self.storage,
297
  scorer=scorer,
298
+ n_pages=self.config.n_contexts,
299
  )
300
 
301
  # Run evaluation for each question
 
313
  "pages": [int(x) if x else -1 for x in row["page"].split(",")],
314
  }
315
 
316
+ output = retriever(dict(question=row["question"]))
317
 
318
  self._logger.info(
319
  f"Retrieved {len(output['contexts'])} presentations"
src/eval/evaluate.py CHANGED
@@ -4,7 +4,7 @@ import time
4
  from collections import OrderedDict
5
  from functools import partial
6
  from textwrap import dedent
7
- from typing import Dict, List, Optional
8
 
9
  import pandas as pd
10
  from langchain_core import outputs
@@ -17,6 +17,7 @@ from langsmith.evaluation.evaluator import DynamicRunEvaluator, EvaluationResult
17
  from langsmith.schemas import Dataset
18
  from langsmith.utils import LangSmithError
19
  from pandas._libs.tslibs.np_datetime import py_td64_to_tdstruct
 
20
  from pydantic import BaseModel, ConfigDict, Field
21
  from ragas import SingleTurnSample
22
  from ragas.llms.base import LangchainLLMWrapper
@@ -29,6 +30,7 @@ from src.rag import (
29
  PresentationRetriever,
30
  ScorerTypes,
31
  )
 
32
 
33
 
34
  @run_evaluator
@@ -211,13 +213,15 @@ class EvaluationConfig(BaseModel):
211
 
212
  # Configure Retrieval
213
  scorers: List[ScorerTypes] = [MinScorer(), HyperbolicScorer()]
 
214
 
215
  # Setup Evaluators
216
  evaluators: List[DynamicRunEvaluator] = [presentation_match, page_match]
217
 
218
  # Configure RAGAS
219
  # ragas_metrics: List[type] = [Faithfulness] # List of metric classes
220
- n_contexts: int = 2
 
221
 
222
  # Configure evaluation
223
  max_concurrency: int = 2
@@ -232,7 +236,6 @@ class RAGEvaluatorLangsmith:
232
 
233
  def __init__(
234
  self,
235
- storage: ChromaSlideStore,
236
  config: EvaluationConfig,
237
  llm: ChatOpenAI = Config().model_config.load_vsegpt(model="openai/gpt-4o-mini"),
238
  ):
@@ -248,11 +251,10 @@ class RAGEvaluatorLangsmith:
248
  )
249
 
250
  # Setup class
251
- self.storage = storage
252
  self.client = Client()
253
  self.config = config
254
- llm_unwrapped = llm
255
- self.llm = LangchainLLMWrapper(llm_unwrapped)
256
 
257
  @classmethod
258
  def load_questions_from_sheet(cls, *args, **kwargs) -> pd.DataFrame:
@@ -332,15 +334,17 @@ class RAGEvaluatorLangsmith:
332
  else:
333
  experiment_prefix = f"{scorer.id}"
334
 
335
- retriever = PresentationRetriever(
336
- storage=self.storage, scorer=scorer, n_contexts=self.config.n_contexts
337
- )
338
  evaluate(
339
  retriever,
340
  experiment_prefix=experiment_prefix,
341
  data=self.config.dataset_name,
342
  evaluators=list(chains.values()),
343
- metadata=dict(scorer=scorer.id),
 
 
 
344
  max_concurrency=self.config.max_concurrency,
345
  )
346
 
@@ -369,12 +373,13 @@ def main():
369
  storage = ChromaSlideStore(collection_name="pres0", embedding_model=embeddings)
370
  eval_config = EvaluationConfig(
371
  dataset_name="PresRetrieve_5",
 
372
  evaluators=[
373
- # presentation_match,
374
- # presentation_found,
375
- # page_match,
376
- # page_found,
377
- create_llm_relevance_evaluator(llm),
378
  ],
379
  scorers=[MinScorer(), ExponentialScorer()],
380
  max_concurrency=1,
 
4
  from collections import OrderedDict
5
  from functools import partial
6
  from textwrap import dedent
7
+ from typing import ClassVar, Dict, List, Optional
8
 
9
  import pandas as pd
10
  from langchain_core import outputs
 
17
  from langsmith.schemas import Dataset
18
  from langsmith.utils import LangSmithError
19
  from pandas._libs.tslibs.np_datetime import py_td64_to_tdstruct
20
+ from pandas.core.dtypes.dtypes import re
21
  from pydantic import BaseModel, ConfigDict, Field
22
  from ragas import SingleTurnSample
23
  from ragas.llms.base import LangchainLLMWrapper
 
30
  PresentationRetriever,
31
  ScorerTypes,
32
  )
33
+ from src.rag.storage import LLMPresentationRetriever
34
 
35
 
36
  @run_evaluator
 
213
 
214
  # Configure Retrieval
215
  scorers: List[ScorerTypes] = [MinScorer(), HyperbolicScorer()]
216
+ retriever: PresentationRetriever
217
 
218
  # Setup Evaluators
219
  evaluators: List[DynamicRunEvaluator] = [presentation_match, page_match]
220
 
221
  # Configure RAGAS
222
  # ragas_metrics: List[type] = [Faithfulness] # List of metric classes
223
+ n_contexts: int = 10
224
+ n_pages: int = 3
225
 
226
  # Configure evaluation
227
  max_concurrency: int = 2
 
236
 
237
  def __init__(
238
  self,
 
239
  config: EvaluationConfig,
240
  llm: ChatOpenAI = Config().model_config.load_vsegpt(model="openai/gpt-4o-mini"),
241
  ):
 
251
  )
252
 
253
  # Setup class
 
254
  self.client = Client()
255
  self.config = config
256
+ self.llm = llm
257
+ self.llm_wrapped = LangchainLLMWrapper(self.llm)
258
 
259
  @classmethod
260
  def load_questions_from_sheet(cls, *args, **kwargs) -> pd.DataFrame:
 
334
  else:
335
  experiment_prefix = f"{scorer.id}"
336
 
337
+ retriever = self.config.retriever
338
+ retriever.set_scorer(scorer)
 
339
  evaluate(
340
  retriever,
341
  experiment_prefix=experiment_prefix,
342
  data=self.config.dataset_name,
343
  evaluators=list(chains.values()),
344
+ metadata=dict(
345
+ scorer=scorer.id,
346
+ retriever=self.config.retriever.__class__.__name__,
347
+ ),
348
  max_concurrency=self.config.max_concurrency,
349
  )
350
 
 
373
  storage = ChromaSlideStore(collection_name="pres0", embedding_model=embeddings)
374
  eval_config = EvaluationConfig(
375
  dataset_name="PresRetrieve_5",
376
+ retriever_cls=LLMPresentationRetriever,
377
  evaluators=[
378
+ presentation_match,
379
+ presentation_found,
380
+ page_match,
381
+ page_found,
382
+ # create_llm_relevance_evaluator(llm),
383
  ],
384
  scorers=[MinScorer(), ExponentialScorer()],
385
  max_concurrency=1,
src/rag/storage.py CHANGED
@@ -9,8 +9,14 @@ import numpy as np
9
  from chromadb.api.types import QueryResult
10
  from chromadb.config import Settings
11
  from datasets.utils import metadata
 
12
  from langchain.schema import Document
 
13
  from langchain_core.embeddings import Embeddings
 
 
 
 
14
  from langchain_openai.embeddings import OpenAIEmbeddings
15
  from pandas.core.algorithms import rank
16
  from pydantic import BaseModel, ConfigDict, Field, conbytes
@@ -769,6 +775,7 @@ class PresentationRetriever(BaseModel):
769
  storage: ChromaSlideStore
770
  scorer: BaseScorer = ExponentialScorer()
771
  n_contexts: int = -1
 
772
  retrieve_page_contexts: bool = True
773
 
774
  model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -850,8 +857,12 @@ class PresentationRetriever(BaseModel):
850
  metadata_filter=metadata_filter,
851
  )
852
 
 
 
 
853
  contexts = []
854
- for pres in results.presentations:
 
855
 
856
  # Gather relevant info from presentation
857
  pres_info = dict(
@@ -860,8 +871,10 @@ class PresentationRetriever(BaseModel):
860
  )
861
 
862
  if self.retrieve_page_contexts:
863
- page_contexts = self.format_contexts(pres, self.n_contexts)
864
- pres_info["contexts"] = page_contexts
 
 
865
 
866
  contexts.append(pres_info)
867
 
@@ -874,6 +887,140 @@ class PresentationRetriever(BaseModel):
874
  def __call__(self, inputs: Dict[str, Any]):
875
  return self.retrieve(inputs["question"])
876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
 
878
  # def create_slides_database(
879
  # presentations: List[PresentationAnalysis], collection_name: str = "slides"
 
9
  from chromadb.api.types import QueryResult
10
  from chromadb.config import Settings
11
  from datasets.utils import metadata
12
+ from langchain.chains.base import Chain
13
  from langchain.schema import Document
14
+ from langchain_core.callbacks.manager import CallbackManagerForChainRun
15
  from langchain_core.embeddings import Embeddings
16
+ from langchain_core.language_models import BaseLanguageModel
17
+ from langchain_core.output_parsers import JsonOutputParser
18
+ from langchain_core.prompts import PromptTemplate
19
+ from langchain_openai import ChatOpenAI
20
  from langchain_openai.embeddings import OpenAIEmbeddings
21
  from pandas.core.algorithms import rank
22
  from pydantic import BaseModel, ConfigDict, Field, conbytes
 
775
  storage: ChromaSlideStore
776
  scorer: BaseScorer = ExponentialScorer()
777
  n_contexts: int = -1
778
+ n_pages: int = -1
779
  retrieve_page_contexts: bool = True
780
 
781
  model_config = ConfigDict(arbitrary_types_allowed=True)
 
857
  metadata_filter=metadata_filter,
858
  )
859
 
860
+ return self.results2contexts(results)
861
+
862
+ def results2contexts(self, results: ScoredPresentations):
863
  contexts = []
864
+ n_pres = self.n_contexts if self.n_contexts > 0 else len(results)
865
+ for i, pres in enumerate(results.presentations[:n_pres]):
866
 
867
  # Gather relevant info from presentation
868
  pres_info = dict(
 
871
  )
872
 
873
  if self.retrieve_page_contexts:
874
+ page_contexts = self.format_contexts(pres, self.n_pages)
875
+ pres_info["contexts"] = (
876
+ page_contexts # pyright: ignore[reportArgumentType]
877
+ )
878
 
879
  contexts.append(pres_info)
880
 
 
887
  def __call__(self, inputs: Dict[str, Any]):
888
  return self.retrieve(inputs["question"])
889
 
890
+ def set_scorer(self, scorer: ScorerTypes):
891
+ self.scorer = scorer
892
+
893
+
894
class LLMPresentationRetriever(PresentationRetriever):
    """LLM-enhanced retriever that reranks results using structured relevance scoring.

    Runs the parent retriever first, then asks the LLM to rate each retrieved
    presentation's relevance to the query (1-10) and reorders the results by
    that score, keeping at most ``top_k`` of them.
    """

    class RelevanceRanking(BaseModel):
        """Structured output schema for the reranking LLM call."""

        class RelevanceEval(BaseModel):
            document_id: int = Field(description="The id of the document")
            relevance: int = Field(description="Relevance score from 1-10")
            explanation: str = Field(
                description="Short passage to clarify relevance score"
            )

        results: list[RelevanceEval]

    llm: ChatOpenAI
    top_k: int = 10

    # NOTE(review): leading-underscore attribute on a pydantic model — relies on
    # pydantic treating it as a private attribute; confirm this survives model
    # construction under the project's pydantic version.
    _parser: JsonOutputParser = JsonOutputParser(pydantic_object=RelevanceRanking)

    rerank_prompt: PromptTemplate = PromptTemplate(
        template="""You are evaluating search results for presentation slides.
    Rate how relevant each document is to the given query.
    The relevance score should be from 1-10 where:
    - 1-3: Low relevance, mostly unrelated content
    - 4-6: Moderate relevance, some related points
    - 7-8: High relevance, clearly addresses the query
    - 9-10: Perfect match, directly answers the query

    Evaluate ALL documents and provide brief explanations.

    Presentations to evaluate:

    {context_str}

    Question: {query_str}

    Output Formatting:
    {format_instructions}
    """,
        input_variables=["context_str", "query_str", "format_instructions"],
    )

    def _format_presentations(self, presentations: List[Dict[str, Any]]) -> str:
        """Format presentations as numbered documents for LLM evaluation.

        Document ids shown to the LLM are 1-based; ``_rerank_results``
        converts them back with ``document_id - 1``.
        """
        formatted = []
        for i, pres in enumerate(presentations):
            content = [f"Document {i+1}:"]
            content.append(f"Title: {pres['pres_name']}")

            if "contexts" in pres:
                content.append("Content:")
                content.extend(pres["contexts"])

            formatted.append("\n".join(content))

        return "\n\n".join(formatted)

    def _rerank_results(
        self,
        results: List[Dict[str, Any]],
        query: str,
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> List[Dict[str, Any]]:
        """Rerank results by LLM relevance scoring.

        Returns at most ``top_k`` copies of the input results, ordered by
        descending LLM relevance, each annotated with ``llm_score`` and
        ``llm_explanation``.
        """
        # Format input for LLM
        context_str = self._format_presentations(results)

        # Get LLM evaluation as a structured RelevanceRanking
        chain = self.rerank_prompt | self.llm.with_structured_output(
            self.RelevanceRanking
        )

        ranking = chain.invoke(
            {
                "context_str": context_str,
                "query_str": query,
                "format_instructions": self._parser.get_format_instructions(),
            },
        )

        if len(ranking.results) != len(results):
            # Best-effort: proceed with whatever the LLM returned, but record it.
            logger.warning(
                f"Reranker returned {len(ranking.results)} results when should {len(results)}"
            )

        # Sort evaluations by relevance score, best first
        sorted_evals = sorted(
            ranking.results,  # pyright: ignore[reportAttributeAccessIssue]
            key=lambda ev: ev.relevance,
            reverse=True,
        )

        # Pair each surviving result with ITS OWN evaluation so scores stay
        # aligned even when the LLM returns an out-of-range document_id.
        # (Fixes a bug where reranked[i] was annotated with sorted_evals[i]
        # positionally after filtering, attaching scores to the wrong results;
        # also adds the missing lower bound so document_id 0 no longer
        # silently indexes results[-1].)
        reranked: List[Dict[str, Any]] = []
        for ev in sorted_evals:
            idx = ev.document_id - 1
            if not 0 <= idx < len(results):
                continue
            item = results[idx].copy()
            item["llm_score"] = ev.relevance
            item["llm_explanation"] = ev.explanation
            reranked.append(item)
            if len(reranked) >= self.top_k:
                break

        return reranked

    def __call__(
        self,
        inputs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Run base retrieval for ``inputs["question"]``, then LLM reranking."""
        # Get base retrieval results
        base_results = super().retrieve(query=inputs["question"])

        # Rerank using LLM only when there is more than one candidate
        if len(base_results["contexts"]) > 1:
            reranked = self._rerank_results(
                base_results["contexts"],
                inputs["question"],
            )
        else:
            reranked = base_results["contexts"]

        return dict(
            contexts=reranked,
        )
1023
+
1024
 
1025
  # def create_slides_database(
1026
  # presentations: List[PresentationAnalysis], collection_name: str = "slides"