# fabi/backend/tests/evaluation/utils.py
import time
from abc import ABC, abstractmethod
from enum import StrEnum
from typing import Any, Callable, ClassVar, NamedTuple, Self, Sequence, cast
from uuid import UUID
from deepeval.metrics import GEval
from deepeval.test_case import LLMTestCase, LLMTestCaseParams
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field, model_validator
from dataline.models.llm_flow.enums import QueryResultType
from dataline.models.llm_flow.schema import ResultType, SQLQueryRunResult
from dataline.models.message.schema import QueryOut
from dataline.models.result.schema import ResultOut
def snake(s: str) -> str:
return s.replace(" ", "_").lower()
def evaluate_message_content(
generated_message: str,
query: str,
eval_steps: list[str],
raise_if_fail: bool = False,
eval_params: list[LLMTestCaseParams] | None = None,
) -> GEval:
"""
Evaluates the content of the generated message against the query.
Args:
- generated_message: generated AI message
- query: human message
- eval_steps: an explanation of how to evaluate the output
- raise_if_fail: if True, runs an assertion that the evaluation is successful
- eval_params: the parameters that are relevant for evaluation
"""
if eval_params is None:
eval_params = [LLMTestCaseParams.INPUT, LLMTestCaseParams.ACTUAL_OUTPUT]
test_case = LLMTestCase(input=query, actual_output=generated_message)
# don't really care about eval name
correctness_metric = GEval(
name="Metric", model="gemini-2.0-flash-lite", evaluation_steps=eval_steps, evaluation_params=eval_params
)
correctness_metric.measure(test_case)
if raise_if_fail:
assert correctness_metric.is_successful(), "\n".join(
[
f"score={correctness_metric.score}",
f"reason={correctness_metric.reason}",
f"{query=}",
f"{generated_message=}",
]
)
return correctness_metric
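# Illustrative usage (the question, answer, and eval step below are made up,
# not taken from the test suite):
#
#     metric = evaluate_message_content(
#         generated_message="There are 275 artists in the database.",
#         query="How many artists are in the database?",
#         eval_steps=["Check that the answer directly addresses the question."],
#         raise_if_fail=True,
#     )
#     print(metric.score, metric.reason)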
def get_results_of_type(result_type: QueryResultType, response: QueryOut) -> list[ResultOut]:
return [result for result in response.ai_message.results if result.type == result_type.value]
def get_query_out_from_stream(lines: list[str]) -> QueryOut:
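    # The streamed response is expected to end with a server-sent-events block
    # shaped roughly like this (inferred from the assertions below):
    #   event: stored_messages_event
    #   data: <QueryOut serialized as JSON>
    #   <one trailing line, typically blank>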
assert len(lines) >= 3
assert lines[-3] == "event: stored_messages_event"
assert lines[-2].startswith("data: ")
    return QueryOut.model_validate_json(lines[-2].removeprefix("data: "))
def call_query_endpoint(client: TestClient, conversation_id: UUID, query: str) -> QueryOut:
body = {"message_options": {"secure_data": True}}
with client.stream(
method="post", url=f"/conversation/{conversation_id}/query", params={"query": query}, json=body
) as response:
assert response.status_code == 200
q_out = get_query_out_from_stream(list(response.iter_lines()))
assert q_out.human_message.content == query
return q_out
EvalName = str
class Tag(StrEnum):
RESPONSE_QUALITY = "Response Quality"
DATA_RESULTS_ACCURACY = "Data Results Accuracy"
SQL_QUALITY = "SQL Quality"
EXPLORATION_ABILITY = "Exploration Ability"
MULTI_QUESTION_ACCURACY = "Multi-Question Accuracy"
FOLLOWUP_ABILITY = "Followup Ability"
def render_tags(tags: Sequence[Tag]) -> list[str]:
return [snake(t.value) for t in tags]
class TestMetadata(BaseModel):
weight: float = Field(default=1.0, ge=0.0, le=1.0)
tags: Sequence[Tag] = []
class EvalBlockBase(ABC, BaseModel):
metadata: TestMetadata = TestMetadata()
# https://github.com/pydantic/pydantic/discussions/2410#discussioncomment-408613
@property
@abstractmethod
def name(self) -> EvalName:
"""The name of the test"""
@abstractmethod
def evaluate(self, response: QueryOut) -> float:
raise NotImplementedError
def evaluate_and_serialize(
self, response: QueryOut
) -> tuple[str, float, float, float, Sequence[str]] | list[tuple[str, float, float, float, Sequence[str]]]:
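        # Each serialized result is (eval name, score, weight, seconds taken, tags);
        # grouped blocks override this to return a list of such tuples instead.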
start = time.time()
score = self.evaluate(response)
time_delta = round(time.time() - start, 4)
return self.name, score, self.metadata.weight, time_delta, render_tags(self.metadata.tags)
class EvalAIText(EvalBlockBase):
name: ClassVar[str] = "ai_text_evaluation"
eval_steps: list[str]
def evaluate(self, response: QueryOut) -> float:
ai_text = response.ai_message.message.content
evaluation = evaluate_message_content(ai_text, response.human_message.content, self.eval_steps)
return evaluation.score or 0.0 # score: float | None is annoying
def comparator(operator: Callable[[int, int], bool], number: int) -> Callable[[int], bool]:
return lambda x: operator(x, number)
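# Illustrative usage with the stdlib operator module (not imported above):
#
#     import operator
#     at_least_two = comparator(operator.ge, 2)
#     at_least_two(3)  # True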
class EvalCountResultTuple(NamedTuple):
comparator: Callable[[int], bool]
weight: float = 1.0
class EvalCountResult(EvalBlockBase):
name: ClassVar[str] = "result_count"
eval_count: tuple[QueryResultType, EvalCountResultTuple]
def evaluate(self, response: QueryOut) -> float:
result_type, result_tuple = self.eval_count
filtered_results = get_results_of_type(result_type, response)
if result_tuple.comparator(len(filtered_results)):
return 1
return 0
def evaluate_and_serialize(self, response: QueryOut) -> tuple[str, float, float, float, Sequence[str]]:
start = time.time()
score = self.evaluate(response)
time_delta = round(time.time() - start, 4)
name = f"{self.name}_{self.eval_count[0].value}"
return name, score, self.metadata.weight, time_delta, render_tags(self.metadata.tags)
class GroupedEvalCountResult(EvalBlockBase):
name: ClassVar[str] = "grouped_result_count_evaluation"
eval_counts: Sequence[EvalCountResult]
def evaluate(self, response: QueryOut) -> float:
return 0.0 # not used for this class
def evaluate_and_serialize(self, response: QueryOut) -> list[tuple[str, float, float, float, Sequence[str]]]:
results = []
total_weight = sum(eval_count.metadata.weight for eval_count in self.eval_counts)
for eval_count in self.eval_counts:
result = eval_count.evaluate_and_serialize(response)
if isinstance(result, list):
raise ValueError("EvalCountResult should not return multiple results")
            # normalize each sub-eval weight by the group's total, then scale it
            # by this grouped block's own weight
            adjusted_weight = result[2] / total_weight * self.metadata.weight
            results.append((result[0], result[1], adjusted_weight, result[3], result[4]))
return results
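# Illustrative construction of a grouped count evaluation (the result types,
# comparators, and weights below are hypothetical):
#
#     import operator
#     grouped = GroupedEvalCountResult(
#         metadata=TestMetadata(weight=0.5, tags=[Tag.SQL_QUALITY]),
#         eval_counts=[
#             EvalCountResult(
#                 eval_count=(
#                     QueryResultType.SQL_QUERY_STRING_RESULT,
#                     EvalCountResultTuple(comparator=comparator(operator.eq, 1)),
#                 )
#             ),
#             EvalCountResult(
#                 eval_count=(
#                     QueryResultType.SQL_QUERY_RUN_RESULT,
#                     EvalCountResultTuple(comparator=comparator(operator.ge, 1)),
#                 )
#             ),
#         ],
#     )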
class EvalSQLString(EvalBlockBase):
name: ClassVar[str] = "sql_string_evaluation"
eval_steps: list[str]
eval_params: list[LLMTestCaseParams] = [LLMTestCaseParams.ACTUAL_OUTPUT]
def evaluate(self, response: QueryOut) -> float:
sql_string_results = get_results_of_type(QueryResultType.SQL_QUERY_STRING_RESULT, response)
if len(sql_string_results) == 0:
return 0.0 # we'd expect there to be an sql string if we use this eval block
sum_scores = 0.0
for sql_str_res in sql_string_results:
generated_sql = cast(str, sql_str_res.content["sql"])
evaluation = evaluate_message_content(
generated_sql,
response.human_message.content,
self.eval_steps,
eval_params=self.eval_params,
)
sum_scores += evaluation.score or 0.0
return sum_scores / len(sql_string_results)
class EvalSQLRun(EvalBlockBase):
# TODO: currently assumes we expect one SQLRun
name: ClassVar[str] = "sql_run_evaluation"
expected_row_count: int | None = None
expected_col_names: list[str] | None = None
expected_row_values: list[list[Any]] | None = None # type: ignore[misc]
@staticmethod
def result_out_to_sql_run_result(result_out: ResultOut) -> SQLQueryRunResult:
assert result_out.linked_id is not None
return SQLQueryRunResult(
columns=result_out.content.get("columns"),
rows=result_out.content.get("rows"),
result_id=result_out.result_id,
linked_id=result_out.linked_id,
created_at=result_out.created_at,
)
def evaluate(self, response: QueryOut) -> float:
scores: dict[str, float] = {}
sql_run_results = get_results_of_type(QueryResultType.SQL_QUERY_RUN_RESULT, response)
if len(sql_run_results) == 0:
return 0.0
sql_run_result_out = sql_run_results[0]
sql_run_result = self.result_out_to_sql_run_result(sql_run_result_out)
if self.expected_row_count is not None:
scores["expected_row_count"] = float(len(sql_run_result.rows) == self.expected_row_count)
if self.expected_col_names:
            cols_present = sum(col in sql_run_result.columns for col in self.expected_col_names)
            scores["expected_col_names"] = cols_present / len(self.expected_col_names)
if self.expected_row_values:
rows_present = 0
for expected_row in self.expected_row_values:
expected_row_set = set(expected_row)
for row in sql_run_result.rows:
row_set = set(row)
if expected_row_set.issubset(row_set):
rows_present += 1
break
scores["expected_row_values"] = rows_present * 1.0 / len(self.expected_row_values)
        return sum(scores.values()) / len(scores)
@model_validator(mode="after")
def check_any_evaluation_present(self) -> Self:
has_evaluation_present = (
bool(self.expected_col_names) or self.expected_row_count is not None or bool(self.expected_row_values)
)
if not has_evaluation_present:
raise ValueError("Must have at least one of expected_col_names, expected_row_count, expected_row_values")
return self
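# Illustrative construction (the expected columns and row count are hypothetical;
# per the validator above, at least one expected_* field must be set):
#
#     sql_run_eval = EvalSQLRun(
#         expected_row_count=3,
#         expected_col_names=["artist", "total_sales"],
#         metadata=TestMetadata(weight=0.8, tags=[Tag.DATA_RESULTS_ACCURACY]),
#     )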
class Message(BaseModel):
content: str
# options: Optional[MessageOptions] # Kept for when we test msg options
class MessageWithResults(BaseModel):
message: Message
results: list[ResultType]
class MessagePair(BaseModel):
human_message: Message
ai_message: MessageWithResults
class TestCase(BaseModel):
__test__ = False # so pytest doesn't collect this as a test
test_name: str
message_history: list[MessagePair] = []
query: str
evaluations: list[EvalBlockBase]
scores: list[float] = []
total_score: float | None = None
    # TODO: Maybe create named tuples for the results - getting a bit messy
def run(
self, client: TestClient, conversation_id: UUID
) -> list[tuple[str, str, float, float, float, Sequence[str]]]:
evaluation_results: list[tuple[str, float, float, float, Sequence[str]]] = []
response = call_query_endpoint(client, conversation_id, self.query)
for evaluation in self.evaluations:
result = evaluation.evaluate_and_serialize(response)
if isinstance(result, list):
evaluation_results.extend(result)
else:
evaluation_results.append(result)
return [(snake(self.test_name), *result) for result in evaluation_results]
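# Illustrative end-to-end usage (the query, eval steps, and fixtures are
# hypothetical; a real test would obtain `client` and `conversation_id` from
# its own setup):
#
#     case = TestCase(
#         test_name="Count artists",
#         query="How many artists are in the database?",
#         evaluations=[
#             EvalAIText(eval_steps=["Verify the answer states how many artists exist."]),
#             EvalSQLString(eval_steps=["Verify the SQL counts rows in the artists table."]),
#         ],
#     )
#     rows = case.run(client, conversation_id)
#     # each row: (test_name, eval_name, score, weight, seconds, tags)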