Ilia Tambovtsev committed on
Commit
e887b28
·
1 Parent(s): b105b1d

feat: add presentationidx metric

Browse files
Files changed (1) hide show
  1. src/eval/eval_mlflow.py +30 -10
src/eval/eval_mlflow.py CHANGED
@@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Protocol, Union
12
 
13
  import mlflow
14
  import mlflow.config
 
15
  import pandas as pd
16
  from dotenv import load_dotenv
17
  from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
@@ -96,6 +97,25 @@ class PresentationFound(BaseMetric):
96
  )
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  class PageMatch(BaseMetric):
100
  """Check if best page matches ground truth"""
101
 
@@ -169,8 +189,6 @@ class PresentationCount(BaseMetric):
169
 
170
 
171
  class BestChunkMatch(BaseMetric):
172
- """Count number of retrieved presentations"""
173
-
174
  async def acalculate(self, run_output: Dict, ground_truth: Dict) -> MetricResult:
175
  """Count presentations in retrieved results"""
176
  best_pres = run_output["contexts"][0]
@@ -293,6 +311,7 @@ class MetricsRegistry:
293
  _metrics = {
294
  "presentationmatch": PresentationMatch,
295
  "presentationfound": PresentationFound,
 
296
  "pagematch": PageMatch,
297
  "pagefound": PageFound,
298
  "presentationcount": PresentationCount,
@@ -303,7 +322,6 @@ class MetricsRegistry:
303
  @classmethod
304
  def create(cls, metric_name: str, **kwargs) -> BaseMetric:
305
  """Create metric instance by name"""
306
- # __import__('pdb').set_trace()
307
  metric_cls = cls._metrics.get(metric_name.lower())
308
  if metric_cls is None:
309
  raise ValueError(f"Unknown metric: {metric_name}")
@@ -316,6 +334,7 @@ class MetricPresets:
316
  BASIC = [
317
  "presentationmatch",
318
  "presentationfound",
 
319
  "pagematch",
320
  "pagefound",
321
  "presentationcount",
@@ -618,15 +637,16 @@ class RAGEvaluatorMlflow:
618
  # Initialize retriever
619
  retriever = self.config.get_retriever_with_scorer(scorer)
620
 
 
 
 
 
 
 
621
  with mlflow.start_run(
622
- run_name=f"scorer_{scorer.id}__retriever_{retriever.id}"
623
  ):
624
  # Log preprocessor
625
- preprocessor_id = (
626
- retriever.storage.query_preprocessor.id
627
- if retriever.storage.query_preprocessor
628
- else "None"
629
- )
630
  mlflow.log_params({"preprocessing": preprocessor_id})
631
  self._logger.info(f"Using preprocessor: {preprocessor_id}")
632
 
@@ -696,7 +716,7 @@ class RAGEvaluatorMlflow:
696
  # Log metrics
697
  for name, values in metric_values.items():
698
  if values:
699
- mean_value = sum(values) / len(values)
700
  mlflow.log_metric(f"mean_{name}", mean_value)
701
  mlflow.log_metric(f"n_questions", len(questions_df))
702
  mlflow.log_metric(f"error_rate", n_errors / len(questions_df))
 
12
 
13
  import mlflow
14
  import mlflow.config
15
+ import numpy as np
16
  import pandas as pd
17
  from dotenv import load_dotenv
18
  from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
 
97
  )
98
 
99
 
100
class PresentationIdx(BaseMetric):
    """Rank (1-based) of the ground-truth presentation among retrieved contexts.

    The score is the position of the first retrieved context whose
    ``pres_name`` matches ``ground_truth["pres_name"]``, or NaN when the
    presentation was not retrieved at all (NaN-aware aggregation such as
    ``np.nanmean`` can then skip misses).
    """

    async def acalculate(self, run_output: Dict, ground_truth: Dict) -> MetricResult:
        """Return a MetricResult whose score is the 1-based match index (NaN if absent)."""
        found_pres_names = [c["pres_name"] for c in run_output["contexts"]]
        score = float("nan")
        for i, pres in enumerate(found_pres_names):
            if pres == ground_truth["pres_name"]:
                score = float(i + 1)
                # First occurrence is the best rank; without this break the
                # loop kept overwriting with the LAST duplicate's position.
                break

        # BUG FIX: the original tested `score != float("nan")`, which is
        # always True (NaN compares unequal to everything, including NaN),
        # so the explanation always claimed the presentation was found.
        return MetricResult(
            name=self.name,
            score=score,
            explanation=(
                f"Presentation was found at position {score}"
                if not np.isnan(score)
                else "Presentation was not found"
            ),
        )
117
+
118
+
119
  class PageMatch(BaseMetric):
120
  """Check if best page matches ground truth"""
121
 
 
189
 
190
 
191
  class BestChunkMatch(BaseMetric):
 
 
192
  async def acalculate(self, run_output: Dict, ground_truth: Dict) -> MetricResult:
193
  """Count presentations in retrieved results"""
194
  best_pres = run_output["contexts"][0]
 
311
  _metrics = {
312
  "presentationmatch": PresentationMatch,
313
  "presentationfound": PresentationFound,
314
+ "presentationidx": PresentationIdx,
315
  "pagematch": PageMatch,
316
  "pagefound": PageFound,
317
  "presentationcount": PresentationCount,
 
322
  @classmethod
323
  def create(cls, metric_name: str, **kwargs) -> BaseMetric:
324
  """Create metric instance by name"""
 
325
  metric_cls = cls._metrics.get(metric_name.lower())
326
  if metric_cls is None:
327
  raise ValueError(f"Unknown metric: {metric_name}")
 
334
  BASIC = [
335
  "presentationmatch",
336
  "presentationfound",
337
+ "presentationidx",
338
  "pagematch",
339
  "pagefound",
340
  "presentationcount",
 
637
  # Initialize retriever
638
  retriever = self.config.get_retriever_with_scorer(scorer)
639
 
640
+ # Get preprocessor id
641
+ preprocessor_id = (
642
+ retriever.storage.query_preprocessor.id
643
+ if retriever.storage.query_preprocessor
644
+ else "None"
645
+ )
646
  with mlflow.start_run(
647
+ run_name=f"scorer_{scorer.id}__retriever_{retriever.id}__preprocessor_{preprocessor_id}"
648
  ):
649
  # Log preprocessor
 
 
 
 
 
650
  mlflow.log_params({"preprocessing": preprocessor_id})
651
  self._logger.info(f"Using preprocessor: {preprocessor_id}")
652
 
 
716
  # Log metrics
717
  for name, values in metric_values.items():
718
  if values:
719
+ mean_value = np.nanmean(values)
720
  mlflow.log_metric(f"mean_{name}", mean_value)
721
  mlflow.log_metric(f"n_questions", len(questions_df))
722
  mlflow.log_metric(f"error_rate", n_errors / len(questions_df))