Spaces:
Paused
Paused
| """Diagnostic metrics tracker for tool calling evaluation.""" | |
| from __future__ import annotations | |
| import json as _json | |
| from typing import Any | |
| from linalg_zero.grpo.verifiers.xml_parser import XMLParser | |
| from linalg_zero.grpo.verify import parse_string, verify_answers | |
| from linalg_zero.sft.tool_evaluation import EvaluationState | |
| from linalg_zero.shared.lib import get_lib_fn_names | |
| from linalg_zero.shared.types import LibTypes | |
class DiagnosticTracker:
    """Tracks messages and samples for evaluation logging.

    Each call to :meth:`update` records one evaluation. Messages, format scores
    and generated answers are recorded for every evaluation; the sample is
    recorded only when the evaluation state carries one, so ``all_samples`` may
    be shorter than ``all_messages``.
    """

    def __init__(self) -> None:
        # Store all messages and samples for Weave logging
        self.all_messages: list[list[dict[str, Any]]] = []
        self.all_samples: list[dict[str, Any]] = []
        self.all_strict_formats: list[float] = []
        self.all_partial_formats: list[float] = []
        self.all_generated_answers: list[str | LibTypes | None] = []
        # (messages, sample) pairs for evaluations that carried a sample. The
        # metadata pass iterates these instead of zipping ``all_messages`` with
        # ``all_samples``: those two lists differ in length whenever a state has
        # no sample, and zipping them would raise or mis-pair conversations.
        self._scored_pairs: list[tuple[list[dict[str, Any]], dict[str, Any]]] = []

    def update(self, state: EvaluationState) -> None:
        """Record one evaluation state.

        Args:
            state: Finished evaluation. Its format scores, generated answer and
                messages are always recorded; its sample only if not ``None``.
        """
        self.all_strict_formats.append(state.strict_format_match)
        self.all_partial_formats.append(state.partial_format_score)
        self.all_generated_answers.append(state.generated_answer)
        # Store messages and sample from this evaluation
        self.all_messages.append(state.messages)
        if state.sample is not None:
            self.all_samples.append(state.sample)
            self._scored_pairs.append((state.messages, state.sample))

    def get_history(self) -> tuple[list[list[dict[str, Any]]], dict[str, int | float], dict[str, float]]:
        """Return all messages, metadata and loss metrics for Weave logging best model selection.

        Returns:
            A tuple ``(all_messages, metadata_and_loss_metrics, loss_metrics)``
            where the middle dict merges the count metadata with the loss metrics.
        """
        # Compute the (parser-heavy) metadata once and derive the ratios from it.
        metadata = self._compute_metadata()
        loss_metrics = self._loss_metrics_from_metadata(metadata)
        return self.all_messages, {**metadata, **loss_metrics}, loss_metrics

    def _compute_metadata(self) -> dict[str, int]:
        """Compute count statistics over every evaluation that carried a sample.

        Returns:
            Counts of samples, expected/actual tool calls, expected/actual
            answers, and final answers verified against ``ground_truth``.
        """
        parser = XMLParser()
        tool_names = get_lib_fn_names()

        total_samples = len(self._scored_pairs)
        total_expected_tool_calls = 0
        total_actual_tool_calls = 0
        total_actual_answers = 0
        total_correct_answers = 0

        for messages, sample in self._scored_pairs:
            # Expected tool calls: one per entry of the JSON-encoded stepwise ground truths.
            if "stepwise_ground_truths" in sample:
                total_expected_tool_calls += len(_json.loads(sample["stepwise_ground_truths"]))

            tool_calls, answers, last_answer = self._count_assistant_actions(parser, tool_names, messages)
            total_actual_tool_calls += tool_calls
            total_actual_answers += answers

            # Only the last answer of the conversation is graded.
            if (
                last_answer is not None
                and "ground_truth" in sample
                and self._answer_is_correct(sample["ground_truth"], last_answer)
            ):
                total_correct_answers += 1

        return {
            "total_samples": total_samples,
            "total_expected_tool_calls": total_expected_tool_calls,
            "total_actual_tool_calls": total_actual_tool_calls,
            "total_expected_answers": total_samples,  # One answer per sample
            "total_actual_answers": total_actual_answers,
            "total_correct_answers": total_correct_answers,
        }

    @staticmethod
    def _count_assistant_actions(
        parser: XMLParser, tool_names: Any, messages: list[dict[str, Any]]
    ) -> tuple[int, int, Any]:
        """Count valid tool calls and answers in one conversation's assistant turns.

        Returns:
            ``(tool_calls, answers, last_answer)`` where ``last_answer`` is the
            answer from the last assistant turn that had one, else ``None``.
        """
        tool_calls = 0
        answers = 0
        last_answer = None
        # Skip the first two messages (system and user). Track the real position
        # with enumerate: ``messages.index`` would return the first *equal*
        # message and truncate the context when an assistant turn repeats.
        for position, message in enumerate(messages[2:], start=2):
            if message.get("role") != "assistant":
                continue
            analysis = parser.analyze_message_in_context(
                messages[: position + 1], message=message.get("content", ""), tool_names=tool_names
            )
            tool = analysis.get("tool")
            if tool and tool.get("json_valid"):
                tool_calls += 1
            if analysis.get("has_answer"):
                answers += 1
                last_answer = analysis.get("answer")
        return tool_calls, answers, last_answer

    @staticmethod
    def _answer_is_correct(ground_truth: Any, answer: Any) -> bool:
        """Return whether ``answer`` verifies against ``ground_truth``.

        Any failure while parsing or verifying counts as an incorrect answer,
        so one malformed answer does not abort the whole metadata pass.
        """
        try:
            return bool(verify_answers(parse_string(ground_truth), parse_string(answer)))
        except Exception:  # noqa: BLE001
            return False

    def calculate_loss_metrics(self) -> dict[str, float]:
        """Compute ratio metrics from the current metadata counts.

        Returns:
            ``format_accuracy``, ``format_tool_call_accuracy``,
            ``format_answer_accuracy`` and ``answer_accuracy``; each is 0.0
            when its denominator is zero.
        """
        return self._loss_metrics_from_metadata(self._compute_metadata())

    @staticmethod
    def _loss_metrics_from_metadata(metrics: dict[str, int]) -> dict[str, float]:
        """Turn the counts produced by ``_compute_metadata`` into ratio metrics.

        Ratios are not capped: actual counts may exceed expected counts, so the
        format ratios can be greater than 1.0.
        """
        expected_tool_calls = metrics["total_expected_tool_calls"]
        expected_answers = metrics["total_expected_answers"]
        total_samples = metrics["total_samples"]
        # Every sample expects its tool calls plus one final answer.
        total_actions = expected_tool_calls + total_samples

        format_overall_accuracy = (
            (metrics["total_actual_tool_calls"] + metrics["total_actual_answers"]) / total_actions
            if total_actions > 0
            else 0.0
        )
        format_tool_call_accuracy = (
            metrics["total_actual_tool_calls"] / expected_tool_calls if expected_tool_calls > 0 else 0.0
        )
        answer_attempt_accuracy = metrics["total_actual_answers"] / expected_answers if expected_answers > 0 else 0.0
        answer_accuracy = metrics["total_correct_answers"] / total_samples if total_samples > 0 else 0.0

        return {
            "format_accuracy": format_overall_accuracy,
            "format_tool_call_accuracy": format_tool_call_accuracy,
            "format_answer_accuracy": answer_attempt_accuracy,
            "answer_accuracy": answer_accuracy,
        }

    def get_progress_info(self) -> dict[str, str]:
        """Return current progress info for progress bar (3 key metrics).

        Averages are taken over every recorded evaluation, formatted with three
        decimals. Note: ``"correct"`` is the fraction of evaluations that
        produced any answer, not the fraction of verified-correct answers.
        """
        # The per-evaluation lists all grow together in ``update``; ``all_samples``
        # does not (samples may be missing), so it is not a valid denominator.
        num_evaluations = len(self.all_strict_formats)
        if num_evaluations == 0:
            return {
                "strict": "0.000",
                "partial": "0.000",
                "correct": "0.000",
            }

        partial_format = sum(self.all_partial_formats) / num_evaluations
        strict_format = sum(self.all_strict_formats) / num_evaluations
        answers_generated = sum(1 for ans in self.all_generated_answers if ans is not None)
        answer_rate = answers_generated / num_evaluations

        return {
            "strict": f"{strict_format:.3f}",
            "partial": f"{partial_format:.3f}",
            "correct": f"{answer_rate:.3f}",
        }