# linalg-zero: linalg_zero/sft/diagnostics.py
# atomwalk12 — "initial commit" (0dd6c2f)
"""Diagnostic metrics tracker for tool-calling evaluation."""
from __future__ import annotations
import json as _json
from typing import Any
from linalg_zero.grpo.verifiers.xml_parser import XMLParser
from linalg_zero.grpo.verify import parse_string, verify_answers
from linalg_zero.sft.tool_evaluation import EvaluationState
from linalg_zero.shared.lib import get_lib_fn_names
from linalg_zero.shared.types import LibTypes
class DiagnosticTracker:
    """Tracks messages and samples for evaluation logging.

    Each call to :meth:`update` records one evaluated conversation. Aggregate
    format/answer statistics are derived on demand from the stored history by
    :meth:`get_history`, :meth:`calculate_loss_metrics` and
    :meth:`get_progress_info`.
    """

    def __init__(self) -> None:
        # Store all messages and samples for Weave logging
        self.all_messages: list[list[dict[str, Any]]] = []
        # Only the samples that were present; states with sample=None are not recorded here.
        self.all_samples: list[dict[str, Any]] = []
        self.all_strict_formats: list[float] = []
        self.all_partial_formats: list[float] = []
        self.all_generated_answers: list[str | LibTypes | None] = []
        # One entry per update() call (None when the state carried no sample),
        # index-aligned with all_messages. all_samples alone cannot be paired
        # with all_messages once a sample has been dropped.
        self._conversation_samples: list[dict[str, Any] | None] = []

    def update(self, state: EvaluationState) -> None:
        """Record the outcome of one evaluated conversation.

        Args:
            state: Evaluation state whose format scores, generated answer,
                messages and (optional) sample are appended to the history.
        """
        self.all_strict_formats.append(state.strict_format_match)
        self.all_partial_formats.append(state.partial_format_score)
        self.all_generated_answers.append(state.generated_answer)
        # Store messages and sample from this evaluation
        self.all_messages.append(state.messages)
        self._conversation_samples.append(state.sample)
        if state.sample is not None:
            self.all_samples.append(state.sample)

    def get_history(self) -> tuple[list[list[dict[str, Any]]], dict[str, int | float], dict[str, float]]:
        """Return all messages, metadata and loss metrics for Weave logging best model selection.

        Returns:
            A ``(messages, combined, loss_metrics)`` tuple, where ``combined``
            merges the counts from :meth:`_compute_metadata` with ``loss_metrics``.
        """
        metadata = self._compute_metadata()
        # Pass the metadata through so every conversation is parsed only once.
        loss_metrics = self.calculate_loss_metrics(metadata)
        return self.all_messages, {**metadata, **loss_metrics}, loss_metrics

    def _compute_metadata(self) -> dict[str, int]:
        """Compute metadata statistics from messages and samples.

        Conversations recorded without a sample are skipped: they have no
        ground truth and are not counted in ``total_samples`` either.

        Returns:
            Counts keyed by ``total_samples``, ``total_expected_tool_calls``,
            ``total_actual_tool_calls``, ``total_expected_answers``,
            ``total_actual_answers`` and ``total_correct_answers``.
        """
        parser = XMLParser()
        tool_names = get_lib_fn_names()

        total_samples = len(self.all_samples)
        total_expected_tool_calls = 0
        total_actual_tool_calls = 0
        total_actual_answers = 0
        total_correct_answers = 0

        for messages, sample in zip(self.all_messages, self._conversation_samples, strict=True):
            if sample is None:
                continue

            # Expected tool calls: one per entry of the JSON-encoded stepwise ground truths.
            if "stepwise_ground_truths" in sample:
                total_expected_tool_calls += len(_json.loads(sample["stepwise_ground_truths"]))

            # Only the last answer found in the conversation is graded.
            last_answer = None
            # Skip the first two messages (system and user). enumerate keeps the
            # true position, so each message is analysed against its own prefix
            # even if an identical message occurred earlier (list.index would
            # return that earlier position).
            for position, message in enumerate(messages[2:], start=2):
                if message.get("role") != "assistant":
                    continue
                analysis = parser.analyze_message_in_context(
                    messages[: position + 1], message=message.get("content", ""), tool_names=tool_names
                )
                tool = analysis.get("tool")
                if tool and tool.get("json_valid"):
                    total_actual_tool_calls += 1
                if analysis.get("has_answer"):
                    total_actual_answers += 1
                    last_answer = analysis.get("answer")

            if (
                last_answer is not None
                and "ground_truth" in sample
                and self._answer_matches(sample["ground_truth"], last_answer)
            ):
                total_correct_answers += 1

        return {
            "total_samples": total_samples,
            "total_expected_tool_calls": total_expected_tool_calls,
            "total_actual_tool_calls": total_actual_tool_calls,
            # One answer is expected per sample.
            "total_expected_answers": total_samples,
            "total_actual_answers": total_actual_answers,
            "total_correct_answers": total_correct_answers,
        }

    @staticmethod
    def _answer_matches(ground_truth: Any, answer: Any) -> bool:
        """Return True when ``answer`` verifies against ``ground_truth``.

        Best effort: any parsing/verification failure counts as an incorrect answer.
        """
        try:
            return bool(verify_answers(parse_string(ground_truth), parse_string(answer)))
        except Exception:  # noqa: BLE001 - malformed answers are scored as incorrect
            return False

    @staticmethod
    def _safe_ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0

    def calculate_loss_metrics(self, metadata: dict[str, int] | None = None) -> dict[str, float]:
        """Turn the metadata counts into accuracy-style ratios.

        Args:
            metadata: Precomputed result of :meth:`_compute_metadata`. When
                omitted it is computed here (which parses every conversation).

        Returns:
            ``format_accuracy``, ``format_tool_call_accuracy``,
            ``format_answer_accuracy`` and ``answer_accuracy``. Each is 0.0 when
            its denominator is zero. The format ratios count every valid tool
            call / answer found, so they are not capped at 1.0.
        """
        metrics = self._compute_metadata() if metadata is None else metadata
        expected_tool_calls = metrics["total_expected_tool_calls"]
        expected_answers = metrics["total_expected_answers"]
        total_samples = metrics["total_samples"]

        # Every sample expects its tool calls plus one final answer.
        total_actions = expected_tool_calls + total_samples
        format_overall_accuracy = self._safe_ratio(
            metrics["total_actual_tool_calls"] + metrics["total_actual_answers"], total_actions
        )
        format_tool_call_accuracy = self._safe_ratio(metrics["total_actual_tool_calls"], expected_tool_calls)
        answer_attempt_accuracy = self._safe_ratio(metrics["total_actual_answers"], expected_answers)
        answer_accuracy = self._safe_ratio(metrics["total_correct_answers"], total_samples)

        return {
            "format_accuracy": format_overall_accuracy,
            "format_tool_call_accuracy": format_tool_call_accuracy,
            "format_answer_accuracy": answer_attempt_accuracy,
            "answer_accuracy": answer_accuracy,
        }

    def get_progress_info(self) -> dict[str, str]:
        """Return current progress info for progress bar (3 key metrics).

        Values are formatted with three decimals. ``"correct"`` is the fraction
        of recorded answers that are not None (an answer-generation rate); it is
        not checked against the ground truth.
        """
        # NOTE(review): the denominator counts only non-None samples, while the
        # numerators cover every update() call.
        total_samples = len(self.all_samples)
        if total_samples == 0:
            return {
                "strict": "0.000",
                "partial": "0.000",
                "correct": "0.000",
            }
        partial_format = sum(self.all_partial_formats) / total_samples
        strict_format = sum(self.all_strict_formats) / total_samples
        answers_generated = sum(1 for ans in self.all_generated_answers if ans is not None)
        answer_rate = answers_generated / total_samples
        return {
            "strict": f"{strict_format:.3f}",
            "partial": f"{partial_format:.3f}",
            "correct": f"{answer_rate:.3f}",
        }