import time
from abc import ABC, abstractmethod
from enum import StrEnum
from typing import Any, Callable, ClassVar, NamedTuple, Self, Sequence, cast
from uuid import UUID

from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field, model_validator

from dataline.models.llm_flow.enums import QueryResultType
from dataline.models.llm_flow.schema import ResultType, SQLQueryRunResult
from dataline.models.message.schema import QueryOut
from dataline.models.result.schema import ResultOut


def snake(s: str) -> str:
    return s.replace(" ", "_").lower()


def evaluate_message_content(
    generated_message: str,
    query: str,
    eval_steps: list[str],
    raise_if_fail: bool = False,
    eval_params: list[LLMTestCaseParams] | None = None,
) -> GEval:
    """
    Evaluates the content of the generated message against the query.

    Args:
    - generated_message: generated AI message
    - query: human message
    - eval_steps: an explanation of how to evaluate the output
    - raise_if_fail: if True, runs an assertion that the evaluation is successful
    - eval_params: the parameters that are relevant for evaluation
    """
    if eval_params is None:
        eval_params = [LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT]
    test_case = LLMTestCase(input=query, actual_output=generated_message)

    correctness_metric = GEval(
        name="Metric", model="gemini-2.0-flash-lite", evaluation_steps=eval_steps, evaluation_params=eval_params
    )
    correctness_metric.measure(test_case)
    if raise_if_fail:
        assert correctness_metric.is_successful(), "\n".join(
            [
                f"score={correctness_metric.score}",
                f"reason={correctness_metric.reason}",
                f"{query=}",
                f"{generated_message=}",
            ]
        )
    return correctness_metric


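# Example usage of evaluate_message_content (a sketch; the strings are illustrative and the
# GEval call requires the configured judge model to be reachable):
#
#   metric = evaluate_message_content(
#       generated_message="There are 275 artists in the database.",
#       query="How many artists are there?",
#       eval_steps=["Check that the answer directly addresses the question with a concrete number."],
#   )
#   metric.score   # score assigned by the judge
#   metric.reason  # judge's explanation for the score
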
def get_results_of_type(result_type: QueryResultType, response: QueryOut) -> list[ResultOut]:
    return [result for result in response.ai_message.results if result.type == result_type.value]


def get_query_out_from_stream(lines: list[str]) -> QueryOut:
    assert len(lines) >= 3
    assert lines[-3] == "event: stored_messages_event"
    assert lines[-2].startswith("data: ")
    # removeprefix drops the literal "data: " prefix; lstrip("data: ") would treat it as a
    # character set and could strip into the payload itself.
    return QueryOut.model_validate_json(lines[-2].removeprefix("data: "))


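# get_query_out_from_stream expects the SSE stream to end with the stored-messages event,
# i.e. the final lines look roughly like (payload abbreviated):
#
#   event: stored_messages_event
#   data: {"human_message": {...}, "ai_message": {...}}
#   <final line of the stream>
#
# Only the "data: " payload is parsed, as a QueryOut.
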
def call_query_endpoint(client: TestClient, conversation_id: UUID, query: str) -> QueryOut:
    body = {"message_options": {"secure_data": True}}
    with client.stream(
        method="post", url=f"/conversation/{conversation_id}/query", params={"query": query}, json=body
    ) as response:
        assert response.status_code == 200
        q_out = get_query_out_from_stream(list(response.iter_lines()))

    assert q_out.human_message.content == query
    return q_out


EvalName = str


class Tag(StrEnum):
    RESPONSE_QUALITY = "Response Quality"
    DATA_RESULTS_ACCURACY = "Data Results Accuracy"
    SQL_QUALITY = "SQL Quality"
    EXPLORATION_ABILITY = "Exploration Ability"
    MULTI_QUESTION_ACCURACY = "Multi-Question Accuracy"
    FOLLOWUP_ABILITY = "Followup Ability"


def render_tags(tags: Sequence[Tag]) -> list[str]:
    return [snake(t.value) for t in tags]


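# TestMetadata attaches a relative weight (0.0-1.0, enforced below) and reporting tags to an
# evaluation block; render_tags converts the tags to snake_case for the serialized output.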
class TestMetadata(BaseModel):
    weight: float = Field(default=1.0, ge=0.0, le=1.0)
    tags: Sequence[Tag] = []


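# EvalBlockBase is the contract shared by all evaluation blocks: subclasses define a ClassVar
# `name` and implement evaluate(), returning a score between 0.0 and 1.0. evaluate_and_serialize()
# wraps that into a (name, score, weight, elapsed_seconds, tags) row; grouped blocks may return a
# list of such rows instead.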
class EvalBlockBase(ABC, BaseModel):
    metadata: TestMetadata = TestMetadata()

    @property
    @abstractmethod
    def name(self) -> EvalName:
        """The name of the test"""

    @abstractmethod
    def evaluate(self, response: QueryOut) -> float:
        raise NotImplementedError

    def evaluate_and_serialize(
        self, response: QueryOut
    ) -> tuple[str, float, float, float, Sequence[str]] | list[tuple[str, float, float, float, Sequence[str]]]:
        start = time.time()
        score = self.evaluate(response)
        time_delta = round(time.time() - start, 4)
        return self.name, score, self.metadata.weight, time_delta, render_tags(self.metadata.tags)


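# EvalAIText scores the AI's natural-language answer against the human query using GEval,
# guided by the caller-supplied eval_steps.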
class EvalAIText(EvalBlockBase):
    name: ClassVar[str] = "ai_text_evaluation"
    eval_steps: list[str]

    def evaluate(self, response: QueryOut) -> float:
        ai_text = response.ai_message.message.content
        evaluation = evaluate_message_content(ai_text, response.human_message.content, self.eval_steps)
        return evaluation.score or 0.0


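# comparator binds the right-hand operand of a binary predicate. For example (a sketch, using the
# stdlib operator module): comparator(operator.ge, 1) yields a predicate that is True for counts >= 1.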
def comparator(operator: Callable[[int, int], bool], number: int) -> Callable[[int], bool]:
    return lambda x: operator(x, number)


class EvalCountResultTuple(NamedTuple):
    comparator: Callable[[int], bool]
    weight: float = 1.0


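# EvalCountResult scores 1.0 when the number of results of the given type satisfies the predicate,
# else 0.0. Example (a sketch; assumes `import operator` at the top of the file):
#
#   EvalCountResult(
#       eval_count=(QueryResultType.SQL_QUERY_RUN_RESULT, EvalCountResultTuple(comparator(operator.ge, 1)))
#   )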
class EvalCountResult(EvalBlockBase):
    name: ClassVar[str] = "result_count"
    eval_count: tuple[QueryResultType, EvalCountResultTuple]

    def evaluate(self, response: QueryOut) -> float:
        result_type, result_tuple = self.eval_count
        filtered_results = get_results_of_type(result_type, response)
        if result_tuple.comparator(len(filtered_results)):
            return 1
        return 0

    def evaluate_and_serialize(self, response: QueryOut) -> tuple[str, float, float, float, Sequence[str]]:
        start = time.time()
        score = self.evaluate(response)
        time_delta = round(time.time() - start, 4)
        name = f"{self.name}_{self.eval_count[0].value}"
        return name, score, self.metadata.weight, time_delta, render_tags(self.metadata.tags)


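# GroupedEvalCountResult reports one row per child EvalCountResult, scaling each child's weight to
# its share of the group's overall weight. Its own evaluate() is a stub (scoring happens per child),
# so only evaluate_and_serialize() is meaningful here.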
class GroupedEvalCountResult(EvalBlockBase):
    name: ClassVar[str] = "grouped_result_count_evaluation"
    eval_counts: Sequence[EvalCountResult]

    def evaluate(self, response: QueryOut) -> float:
        return 0.0

    def evaluate_and_serialize(self, response: QueryOut) -> list[tuple[str, float, float, float, Sequence[str]]]:
        results = []
        total_weight = sum(eval_count.metadata.weight for eval_count in self.eval_counts)

        for eval_count in self.eval_counts:
            result = eval_count.evaluate_and_serialize(response)
            if isinstance(result, list):
                raise ValueError("EvalCountResult should not return multiple results")

            # result[2] is the child's own weight; normalize it by the group total and scale by the
            # group's weight so that the children's adjusted weights sum to the group's weight.
            adjusted_weight = result[2] / total_weight * self.metadata.weight
            results.append((result[0], result[1], adjusted_weight, result[3], result[4]))

        return results


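# EvalSQLString judges every generated SQL string against the user query with GEval and averages
# the scores. Example eval_steps (illustrative): ["The SQL should answer the question",
# "The SQL should not reference tables that are irrelevant to the question"].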
class EvalSQLString(EvalBlockBase):
    name: ClassVar[str] = "sql_string_evaluation"
    eval_steps: list[str]
    eval_params: list[LLMTestCaseParams] = [LLMTestCaseParams.ACTUAL_OUTPUT]

    def evaluate(self, response: QueryOut) -> float:
        sql_string_results = get_results_of_type(QueryResultType.SQL_QUERY_STRING_RESULT, response)
        if len(sql_string_results) == 0:
            return 0.0
        sum_scores = 0.0
        for sql_str_res in sql_string_results:
            generated_sql = cast(str, sql_str_res.content["sql"])
            evaluation = evaluate_message_content(
                generated_sql,
                response.human_message.content,
                self.eval_steps,
                eval_params=self.eval_params,
            )
            sum_scores += evaluation.score or 0.0
        return sum_scores / len(sql_string_results)


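# EvalSQLRun checks the first SQL run result against whichever expectations are set and averages
# the enabled checks. Example (a sketch; the column name is illustrative):
#
#   EvalSQLRun(expected_row_count=1, expected_col_names=["artist_count"])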
class EvalSQLRun(EvalBlockBase):
    name: ClassVar[str] = "sql_run_evaluation"
    expected_row_count: int | None = None
    expected_col_names: list[str] | None = None
    expected_row_values: list[list[Any]] | None = None

    @staticmethod
    def result_out_to_sql_run_result(result_out: ResultOut) -> SQLQueryRunResult:
        assert result_out.linked_id is not None
        return SQLQueryRunResult(
            columns=result_out.content.get("columns"),
            rows=result_out.content.get("rows"),
            result_id=result_out.result_id,
            linked_id=result_out.linked_id,
            created_at=result_out.created_at,
        )

    def evaluate(self, response: QueryOut) -> float:
        scores: dict[str, float] = {}
        sql_run_results = get_results_of_type(QueryResultType.SQL_QUERY_RUN_RESULT, response)
        if len(sql_run_results) == 0:
            return 0.0
        sql_run_result_out = sql_run_results[0]
        sql_run_result = self.result_out_to_sql_run_result(sql_run_result_out)

        if self.expected_row_count is not None:
            scores["expected_row_count"] = float(len(sql_run_result.rows) == self.expected_row_count)

        if self.expected_col_names:
            # Count how many expected columns actually appear; len() of the original list
            # comprehension always equalled the number of expected columns, not the matches.
            cols_present = sum(col in sql_run_result.columns for col in self.expected_col_names)
            scores["expected_col_names"] = cols_present * 1.0 / len(self.expected_col_names)

        if self.expected_row_values:
            rows_present = 0
            for expected_row in self.expected_row_values:
                expected_row_set = set(expected_row)
                for row in sql_run_result.rows:
                    row_set = set(row)
                    if expected_row_set.issubset(row_set):
                        rows_present += 1
                        break
            scores["expected_row_values"] = rows_present * 1.0 / len(self.expected_row_values)

        return sum(val for val in scores.values()) / len(scores)

    @model_validator(mode="after")
    def check_any_evaluation_present(self) -> Self:
        has_evaluation_present = (
            bool(self.expected_col_names) or self.expected_row_count is not None or bool(self.expected_row_values)
        )
        if not has_evaluation_present:
            raise ValueError("Must have at least one of expected_col_names, expected_row_count, expected_row_values")
        return self


class Message(BaseModel):
    content: str


class MessageWithResults(BaseModel):
    message: Message
    results: list[ResultType]


class MessagePair(BaseModel):
    human_message: Message
    ai_message: MessageWithResults


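# TestCase bundles one query with the evaluation blocks used to score its response. __test__ = False
# keeps pytest from collecting this pydantic model as a test class.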
class TestCase(BaseModel):
    __test__ = False
    test_name: str
    message_history: list[MessagePair] = []
    query: str
    evaluations: list[EvalBlockBase]
    scores: list[float] = []
    total_score: float | None = None

    def run(
        self, client: TestClient, conversation_id: UUID
    ) -> list[tuple[str, str, float, float, float, Sequence[str]]]:
        evaluation_results: list[tuple[str, float, float, float, Sequence[str]]] = []
        response = call_query_endpoint(client, conversation_id, self.query)
        for evaluation in self.evaluations:
            result = evaluation.evaluate_and_serialize(response)
            if isinstance(result, list):
                evaluation_results.extend(result)
            else:
                evaluation_results.append(result)

        return [(snake(self.test_name), *result) for result in evaluation_results]
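
# End-to-end example (a sketch; the query, steps, and expectations are illustrative and assume a
# seeded conversation):
#
#   case = TestCase(
#       test_name="Artist count",
#       query="How many artists are in the database?",
#       evaluations=[
#           EvalAIText(eval_steps=["The answer should state a single numeric count."]),
#           EvalSQLRun(expected_row_count=1),
#       ],
#   )
#   rows = case.run(client, conversation_id)
#   # -> [(test_name, eval_name, score, weight, elapsed_seconds, tags), ...]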