from typing import Callable, Optional, Sequence, Union

import logging
from collections import defaultdict
from inspect import signature

from ..llm.client import LLMClient, get_default_client
from ..utils.analytics_collector import analytics
from .knowledge_base import KnowledgeBase
from .metrics import CorrectnessMetric, Metric
from .question_generators.utils import maybe_tqdm
from .recommendation import get_rag_recommendation
from .report import RAGReport
from .testset import QATestset
from .testset_generation import generate_testset

logger = logging.getLogger(__name__)

ANSWER_FN_HISTORY_PARAM = "history"


def evaluate(
    answer_fn: Union[Callable, Sequence[str]],
    testset: Optional[QATestset] = None,
    knowledge_base: Optional[KnowledgeBase] = None,
    llm_client: Optional[LLMClient] = None,
    agent_description: str = "This agent is a chatbot that answers question from users.",
    metrics: Optional[Sequence[Callable]] = None,
) -> RAGReport:
| """Evaluate an agent by comparing its answers on a QATestset. | |
| Parameters | |
| ---------- | |
| answers_fn : Union[Callable, Sequence[str]] | |
| The prediction function of the agent to evaluate or a list of precalculated answers on the testset. | |
| testset : QATestset, optional | |
| The test set to evaluate the agent on. If not provided, a knowledge base must be provided and a default testset will be created from the knowledge base. | |
| Note that if the answers_fn is a list of answers, the testset is required. | |
| knowledge_base : KnowledgeBase, optional | |
| The knowledge base of the agent to evaluate. If not provided, a testset must be provided. | |
| llm_client : LLMClient, optional | |
| The LLM client to use for the evaluation. If not provided, a default openai client will be used. | |
| agent_description : str, optional | |
| Description of the agent to be tested. | |
| metrics : Optional[Sequence[Callable]], optional | |
| Metrics to compute on the test set. | |
| Returns | |
| ------- | |
| RAGReport | |
| The report of the evaluation. | |
| """ | |
    validate_inputs(answer_fn, knowledge_base, testset)

    testset = testset or generate_testset(knowledge_base)
    answers = retrieve_answers(answer_fn, testset)
    llm_client = llm_client or get_default_client()
    metrics = get_metrics(metrics, llm_client, agent_description)
    metrics_results = compute_metrics(metrics, testset, answers)
    report = get_report(testset, answers, metrics_results, knowledge_base)
    add_recommendation(report, llm_client, metrics)
    track_analytics(report, testset, knowledge_base, agent_description, metrics)

    return report
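

# Usage sketch (illustrative; assumes this module ships as `giskard.rag`, as the
# relative imports above suggest, and that `QATestset` exposes a `load` classmethod;
# `my_agent` and the testset path are hypothetical):
#
#     from giskard.rag import QATestset, evaluate
#
#     testset = QATestset.load("my_testset.jsonl")
#
#     def answer_fn(question: str) -> str:
#         return my_agent.answer(question)
#
#     report = evaluate(answer_fn, testset=testset)
#     print(report.correctness)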


def validate_inputs(answer_fn, knowledge_base, testset):
    if testset is None:
        if knowledge_base is None:
            raise ValueError("At least one of testset or knowledge base must be provided to the evaluate function.")
        if isinstance(answer_fn, Sequence):
            raise ValueError(
                "If the testset is not provided, answer_fn must be a function to ensure the matching between questions and answers."
            )

    # Check basic types, in case the user passed the params in the wrong order
    if knowledge_base is not None and not isinstance(knowledge_base, KnowledgeBase):
        raise ValueError(
            f"knowledge_base must be a KnowledgeBase object (got {type(knowledge_base)} instead). Are you sure you passed the parameters in the right order?"
        )

    if testset is not None and not isinstance(testset, QATestset):
        raise ValueError(
            f"testset must be a QATestset object (got {type(testset)} instead). Are you sure you passed the parameters in the right order?"
        )


def retrieve_answers(answer_fn, testset):
    return answer_fn if isinstance(answer_fn, Sequence) else _compute_answers(answer_fn, testset)


def get_metrics(metrics, llm_client, agent_description):
    metrics = list(metrics) if metrics is not None else []

    if not any(isinstance(metric, CorrectnessMetric) for metric in metrics):
        # By default only correctness is computed, as it is required to build the report
        metrics.insert(
            0, CorrectnessMetric(name="correctness", llm_client=llm_client, agent_description=agent_description)
        )

    return metrics
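

# Any extra metric can be a plain callable taking a testset sample and an answer and
# returning a dict of results: `compute_metrics` below merges the returned dict into
# the per-sample results with `dict.update`. A minimal sketch (the metric name and
# its logic are hypothetical):
#
#     def answer_length(sample, answer) -> dict:
#         return {"answer_length": len(answer)}
#
#     report = evaluate(answer_fn, testset=testset, metrics=[answer_length])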


def compute_metrics(metrics, testset, answers):
    metrics_results = defaultdict(dict)

    for metric in metrics:
        metric_name = getattr(
            metric, "name", metric.__class__.__name__ if isinstance(metric, Metric) else metric.__name__
        )
        for sample, answer in maybe_tqdm(
            zip(testset.to_pandas().to_records(index=True), answers),
            desc=f"{metric_name} evaluation",
            total=len(answers),
        ):
            metrics_results[sample["id"]].update(metric(sample, answer))

    return metrics_results


def get_report(testset, answers, metrics_results, knowledge_base):
    return RAGReport(testset, answers, metrics_results, knowledge_base)


def add_recommendation(report, llm_client, metrics):
    recommendation = get_rag_recommendation(
        report.topics,
        report.correctness_by_question_type().to_dict()[metrics[0].name],
        report.correctness_by_topic().to_dict()[metrics[0].name],
        llm_client,
    )
    report._recommendation = recommendation


def track_analytics(report, testset, knowledge_base, agent_description, metrics):
    analytics.track(
        "raget:evaluation",
        {
            "testset_size": len(testset),
            "knowledge_base_size": len(knowledge_base) if knowledge_base else -1,
            "agent_description": agent_description,
            "num_metrics": len(metrics),
            "correctness": report.correctness,
        },
    )


def _compute_answers(answer_fn, testset):
    answers = []

    # Pass the conversation history to the agent only if its signature accepts it.
    fn_parameters = signature(answer_fn).parameters
    needs_history = len(fn_parameters) > 1 and ANSWER_FN_HISTORY_PARAM in fn_parameters

    for sample in maybe_tqdm(testset.samples, desc="Asking questions to the agent", total=len(testset)):
        kwargs = {}
        if needs_history:
            kwargs[ANSWER_FN_HISTORY_PARAM] = sample.conversation_history
        answers.append(answer_fn(sample.question, **kwargs))

    return answers
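

# Sketch of an agent function that opts into the conversation history (the agent
# object and its `chat` method are hypothetical; what matters is the `history`
# keyword, which `_compute_answers` detects by name in the function signature):
#
#     def answer_fn(question: str, history: Optional[list] = None) -> str:
#         messages = list(history or []) + [{"role": "user", "content": question}]
#         return my_agent.chat(messages)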