# linalg_zero/sft/tool_calling_accuracy.py
"""
Tool calling accuracy callback for SFT training.
Evaluates structural and correctnessfo metrics for tool-use generations on all eval data.
"""
from __future__ import annotations
import json as _json
import os
from typing import Any
import torch
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from weave import EvaluationLogger
from weave.trace.context import weave_client_context
from linalg_zero.distillation.components.models import DefaultConfig
from linalg_zero.distillation.data import FunctionInvocationInfo, ThoughtSchema
from linalg_zero.grpo.verifiers.xml_parser import XMLParser
from linalg_zero.sft.diagnostics import DiagnosticTracker
from linalg_zero.sft.tool_evaluation import EvaluationState
from linalg_zero.shared.lib import get_lib, get_lib_fn_names
from linalg_zero.shared.system_prompts import (
ANSWER_CLOSE,
ANSWER_OPEN,
THINK_CLOSE,
THINK_OPEN,
TOOL_CALL_CLOSE,
TOOL_CALL_OPEN,
TOOL_RESPONSE_CLOSE,
TOOL_RESPONSE_OPEN,
)
from linalg_zero.shared.utils import get_logger
logger = get_logger(__name__)
class ToolCallingAccuracyCallback(TrainerCallback):
    """
    Callback to evaluate tool calling accuracy during SFT training.

    Evaluates all samples in the eval_dataset for robust metric computation.
    For each sample it runs a multi-turn generate -> parse -> execute loop,
    tracks structural/format adherence via a DiagnosticTracker, merges the
    resulting metrics into the trainer's eval metrics, and optionally logs
    per-sample conversations to Weave.
    """

    def __init__(self, model_name: str, dataset_name: str, eval_dataset: Any) -> None:
        """Store evaluation inputs and initialise parsing/execution helpers.

        Args:
            model_name: Model identifier, used for logging and Weave metadata.
            dataset_name: Eval dataset identifier, used for Weave metadata.
            eval_dataset: Indexable dataset whose samples contain "messages",
                "tools" and "stepwise_ground_truths" fields.
        """
        self.eval_dataset = eval_dataset
        self.model_name = model_name
        self.dataset_name = dataset_name
        # Registry of callable tools available for execution during eval.
        self.library = get_lib()
        self._parser = XMLParser()
        self.model_config = DefaultConfig()
        # Built lazily on first generation, once the tokenizer is available
        # (stop-token ids depend on the tokenizer's vocabulary).
        self.generation_config = None

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        metrics: dict[str, float] | None = None,
        **kwargs: Any,
    ) -> None:
        """Run tool-calling evaluation after each trainer evaluation pass.

        Only the main process generates (avoids duplicated work under DDP).
        Computed metrics are logged and merged into ``metrics`` with an
        ``eval_`` prefix so the Trainer forwards them to its reporters.
        """
        if not state.is_world_process_zero:
            return
        model = kwargs.get("model")
        tokenizer = kwargs.get("processing_class")
        if model is None or tokenizer is None:
            return
        # Get max_new_tokens from training args (eval_max_new_tokens config)
        max_new_tokens = getattr(args, "eval_max_new_tokens", 1024)
        # Determine eval subset size; None or a negative value means "use all".
        max_eval_samples = getattr(args, "max_eval_samples", None)
        if max_eval_samples is None or max_eval_samples < 0:
            num_samples = len(self.eval_dataset)
        else:
            num_samples = min(max_eval_samples, len(self.eval_dataset))
        logger.info(f"Computing tool-calling metrics on {num_samples}/{len(self.eval_dataset)} eval samples...")
        all_messages, weave_metadata, eval_metrics = self._compute_metrics(
            model=model,
            tokenizer=tokenizer,
            dataset=self.eval_dataset,
            max_new_tokens=max_new_tokens,
            max_samples=num_samples,
        )
        if eval_metrics:
            self._log_evaluation_metrics(eval_metrics, state, prefix="eval")
            if metrics is not None:
                metrics.update({f"eval_{k}": v for k, v in eval_metrics.items()})
        self._log_to_weave(state, all_messages, weave_metadata)

    def _log_evaluation_metrics(self, metrics: dict[str, float], state: TrainerState, prefix: str = "eval") -> None:
        """Log evaluation metrics to trainer state and logger (Trainer will forward to W&B)."""
        for metric_name, value in metrics.items():
            # FIX: previously the prefix was hard-coded to "eval", silently
            # ignoring the `prefix` parameter. Output is unchanged for the
            # default prefix.
            logger.info(f"{prefix}/{metric_name}: {value:.3f}")

    def _log_to_weave(
        self, state: TrainerState, all_messages: list[list[dict[str, Any]]], metadata: dict[str, int]
    ) -> None:
        """Log predictions and summary to Weave.

        No-op when no Weave client has been initialised in this process.
        """
        client = weave_client_context.get_weave_client()
        if client is None:
            return
        eval_attributes = {
            "training_step": state.global_step,
            "model_name": self.model_name,
            "generation_config": (self.generation_config if self.generation_config else None),
        }
        # Sanitize model name for Weave scorer validation
        sanitized_model_name = self.model_name.replace("/", "_").replace("-", "_").replace(".", "_")
        weave_logger = EvaluationLogger(
            name=f"linalg_zero_sft_eval_{state.global_step}",
            scorers=[],
            model=sanitized_model_name,
            dataset=self.dataset_name,
            eval_attributes=eval_attributes,
        )
        logger.info(f"Logging {len(all_messages)} samples to Weave...")
        # Log each sample as a prediction: first two messages (system + user)
        # are the input, everything after is the generated conversation.
        for messages in all_messages:
            inputs = {"messages": messages[:2]}
            output = {"messages": messages[2:]} if len(messages) > 2 else None
            pred = weave_logger.log_prediction(inputs=inputs, output=output)
            pred.finish()
        # Log summary with metadata
        weave_logger.log_summary(metadata)
        logger.info(f"Weave logging completed with metadata: {metadata}")

    def _compute_metrics(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        dataset: Any,
        max_new_tokens: int,
        max_samples: int | None = None,
    ) -> tuple[list[list[dict[str, Any]]], dict[str, int | float], dict[str, float]]:
        """Compute metrics on eval samples with fair turn allocation per sample.

        Each sample is given one generation turn per ground-truth tool step
        plus one extra turn for the final answer.

        Returns:
            A tuple of (all conversations, Weave metadata, aggregated metrics).

        Raises:
            ValueError: If the dataset is empty.
        """
        model.eval()
        if not dataset or len(dataset) == 0:
            raise ValueError("Dataset is empty")
        # Subset dataset if max_samples specified
        num_samples = len(dataset) if max_samples is None else min(max_samples, len(dataset))
        samples: list[dict[str, Any]] = [dataset[i] for i in range(num_samples)]
        # Initialize metrics tracker
        tracker = DiagnosticTracker()
        pbar = tqdm(samples, desc="Evaluating tool calling", unit="sample", disable=False)
        for sample in pbar:
            # Determine fair n_turns based on sample complexity: one turn per
            # expected tool call, plus one for the final answer.
            steps = sample["stepwise_ground_truths"]
            arr = _json.loads(steps)
            num_tool_turns = len(arr)
            n_turns = num_tool_turns + 1
            # Run evaluation and update tracker
            state = self._run_evaluation_turns(model, tokenizer, sample, n_turns, max_new_tokens)
            tracker.update(state)
            # Update progress bar
            pbar.set_postfix({**tracker.get_progress_info()})
        all_messages, metadata, eval_metrics = tracker.get_history()
        return all_messages, metadata, eval_metrics

    def _generate(
        self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt_text: str, max_new_tokens: int
    ) -> str:
        """Generate output for a single prompt.

        Decodes only the newly generated tokens, strips special tokens except
        the structural tags the parser needs, and warns on truncation.
        """
        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            # FIX: `bool(pad_token_id)` evaluated to False when the pad id was
            # 0 (a common value), wrongly disabling padding. Compare to None.
            padding=getattr(tokenizer, "pad_token_id", None) is not None,
        )
        if inputs["input_ids"].shape[1] == tokenizer.model_max_length:
            logger.warning(f"Input truncated to {tokenizer.model_max_length} tokens during tool calling evaluation")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            if self.generation_config is None:
                self.generation_config = self.get_generation_config(max_new_tokens, tokenizer)
            outputs = model.generate(  # type: ignore[operator]
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                tokenizer=tokenizer,
                **self.generation_config,
            )
        # Extract only the generated tokens (after the input)
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, prompt_length:]
        # Special tokens to preserve: structural tags the XML parser relies on.
        KEEP_TOKENS = {ANSWER_OPEN, ANSWER_CLOSE, THINK_OPEN, THINK_CLOSE, TOOL_CALL_OPEN, TOOL_CALL_CLOSE}
        # Decode without skipping any special tokens
        output_raw = tokenizer.decode(generated_tokens[0], skip_special_tokens=False)
        # Remove unwanted special tokens
        output = output_raw
        for special_token in tokenizer.all_special_tokens:
            if special_token not in KEEP_TOKENS:
                output = output.replace(special_token, "")
        # Check if generation was truncated due to max_new_tokens: budget was
        # fully consumed and the last token is not EOS.
        if (
            generated_tokens.shape[1] == max_new_tokens
            and getattr(tokenizer, "eos_token_id", None) is not None
            and generated_tokens[0, -1].item() != tokenizer.eos_token_id
        ):
            logger.warning(f"Generation may have been truncated at max_new_tokens={max_new_tokens}")
        # Opt-in verbose dump of raw generations for debugging.
        if os.getenv("SFT_LOG_GENERATIONS") == "1":
            logger.info("=" * 100)
            logger.info(f"Generated output (len={len(output)}): {output}")
            logger.info("=" * 100)
        return output

    def get_generation_config(self, max_new_tokens: int, tokenizer: PreTrainedTokenizer) -> dict[str, Any]:
        """Build a deterministic (greedy) generation config for evaluation.

        Treats the closing tool-call/answer tags as extra EOS tokens so a
        turn stops as soon as an action is complete, and stops on tool
        response tags to prevent the model from hallucinating tool output.

        Raises:
            ValueError: If the tokenizer lacks an EOS or PAD token id.
        """
        pad_id = getattr(tokenizer, "pad_token_id", None)
        eos_id = getattr(tokenizer, "eos_token_id", None)
        # FIX: asserts are stripped under `python -O`; raise explicitly.
        if eos_id is None:
            raise ValueError("EOS token ID is not set")
        if pad_id is None:
            raise ValueError("PAD token ID is not set")
        stop_strings = [TOOL_RESPONSE_OPEN, TOOL_RESPONSE_CLOSE]
        stop_token_ids = [eos_id]
        for token_str in [TOOL_CALL_CLOSE, ANSWER_CLOSE]:
            token_ids = tokenizer.encode(token_str, add_special_tokens=False)
            if len(token_ids) != 1:
                # Multi-token tags can't act as single stop ids; rely on
                # stop_strings instead and surface the misconfiguration.
                logger.warning(
                    f"Special token '{token_str}' was split into {len(token_ids)} tokens. "
                    f"Ensure it's registered as a special token."
                )
            else:
                stop_token_ids.append(token_ids[0])
        return {
            "max_new_tokens": max_new_tokens,
            "do_sample": False,
            "pad_token_id": pad_id,
            "eos_token_id": stop_token_ids,
            "stop_strings": stop_strings,
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "top_k": None,
            "top_p": None,
            "temperature": None,
            "use_cache": True,
        }

    def _run_evaluation_turns(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        sample: dict[str, Any],
        n_turns: int,
        max_new_tokens: int,
    ) -> EvaluationState:
        """Run evaluation using simplified GRPO-based conversation processing.

        Maintains two parallel transcripts: ``context`` holds structured
        messages fed back to the model, ``all_context`` holds raw outputs used
        for lenient format scoring and Weave logging.
        """
        state = EvaluationState()
        state.sample = sample  # Store the original dataset sample
        context = list(sample["messages"])
        all_context = list(sample["messages"])
        tools = sample["tools"]
        # Multi-turn conversation loop
        for _ in range(n_turns):
            # Generate assistant response (include tools so template exposes function signatures)
            prompt = tokenizer.apply_chat_template(context, tools=tools, tokenize=False, add_generation_prompt=True)
            if not isinstance(prompt, str):
                state.early_stop_reason = "prompt_format_error"
                break
            output = self._generate(model, tokenizer, prompt, max_new_tokens)
            message = self.extract_exact_match(output, context)
            self.add_message("assistant", all_context, output, unstructured=True)
            # Check if message extraction worked
            if message is None:
                state.early_stop_reason = "extraction_failed"
                break
            if message.final_answer is None and message.tool_call is None:
                state.early_stop_reason = "no_action"
                break
            # Track diagnostic info
            if message.tool_call is not None:
                state.tool_parse_success = True
            self.add_message("assistant", context, message)
            # Execute tool call if it exists
            if message.tool_call is not None:
                tool_call = self._execute(message)
                self.add_message("tool", context, tool_call)
                self.add_message("tool", all_context, tool_call, unstructured=True)
            if message.final_answer is not None:
                state.generated_answer = message.final_answer
                state.early_stop_reason = "final_answer_provided"
                break
        # Compute partial format score on complete conversation
        state.strict_format_match = self.calculate_exact_match(context)
        state.partial_format_score = self.calculate_partial_match(all_context)
        state.messages = all_context
        return state

    def _execute(self, msg: ThoughtSchema) -> dict[str, str]:
        """Execute the message's tool call against the local function library.

        Errors (unknown function, bad arguments, runtime failures) are
        returned as ``ERROR: ...`` result strings rather than raised, so the
        conversation can continue and the model sees the failure.

        Raises:
            ValueError: If ``msg`` carries no tool call.
        """
        if msg.tool_call is None:
            raise ValueError("Tool call is required for execution")
        name = msg.tool_call.name
        arguments = msg.tool_call.arguments
        try:
            if name not in self.library:
                return {
                    "function_name": name,
                    "execution_result": f"ERROR: Function '{name}' not found in library",
                }
            result = self.library[name](**arguments)
            return {"function_name": name, "execution_result": str(result)}
        except Exception as exc:
            return {
                "function_name": name,
                "execution_result": f"ERROR: {type(exc).__name__}: {exc}",
            }

    def calculate_exact_match(self, context: list[dict]) -> float:
        """Calculate exact match for the entire conversation.

        Returns the fraction of post-prompt turns whose content strictly
        follows the think-then-(tool-or-answer) format, or 0 if there are no
        such turns.
        """
        reward: list[float] = []
        # Discard the first two turns (system and user)
        for completion in context[2:]:
            msg = completion["content"]
            score = 1 if self._parser._is_valid_think_then_tool_or_answer(msg) else 0
            reward.append(score)
        return sum(reward) / len(reward) if len(reward) > 0 else 0

    def calculate_partial_match(self, context: list[dict]) -> float:
        """
        Compute format adherence score for lenient evaluation.

        Awards per-turn credit for a closed think tag plus either a complete
        tool-call tag pair or (failing that) a complete answer tag pair, and
        averages across post-prompt turns (0 if there are none).
        """
        scores: list[float] = []
        # Discard the first two turns (system and user)
        for completion in context[2:]:
            score: float = 0
            tool_or_answer_reward = 0
            response = completion["content"]
            # No need to reward <start_working_out> since we always prepend it
            # score += 0.5 if response.count(reasoning_start) == 1 else -1.0
            score += 1 if response.count(THINK_CLOSE) == 1 else 0
            tool_or_answer_reward += 1 if response.count(TOOL_CALL_OPEN) == 1 else 0
            tool_or_answer_reward += 1 if response.count(TOOL_CALL_CLOSE) == 1 else 0
            # Only fall back to answer tags when no tool-call tags were found.
            if tool_or_answer_reward == 0:
                tool_or_answer_reward += 1 if response.count(ANSWER_OPEN) == 1 else 0
                tool_or_answer_reward += 1 if response.count(ANSWER_CLOSE) == 1 else 0
            score += tool_or_answer_reward
            scores.append(score)
        return sum(scores) / len(scores) if len(scores) > 0 else 0

    def extract_exact_match(self, message: str, context: list[dict]) -> ThoughtSchema | None:
        """Extract output from messages that do not enforce structured output.

        Returns None when the message violates the expected format or the
        answer policy; otherwise a ThoughtSchema with the parsed thought,
        optional tool call and optional final answer.
        """
        analysis = self._parser.analyze_message_in_context(context, message=message, tool_names=get_lib_fn_names())
        if not bool(analysis["is_valid_think_then_tool_or_answer"]):
            return None
        if analysis["has_answer"] and not bool(analysis["answer_policy_valid"]):
            return None
        thought = analysis["thought"] or ""
        # Enforce a single tool call per turn: take only the last tool block
        tool_call: FunctionInvocationInfo | None = None
        tool_info = analysis["tool"]
        if tool_info and tool_info["json_valid"]:
            tool_call = FunctionInvocationInfo(
                name=str(tool_info["name"]),
                arguments=dict(tool_info["arguments"]),
            )
        # Mark completion based on presence of answer
        answer = analysis["answer"]
        return ThoughtSchema(
            thought=thought,
            tool_call=tool_call,
            final_answer=answer,
            completed=answer is not None,
        )

    def add_message(
        self,
        role: str,
        context: list[dict[str, Any]],
        message: ThoughtSchema | dict[str, str] | str,
        *,
        unstructured: bool = False,
    ) -> None:
        """Append a formatted message to ``context``.

        Args:
            role: "assistant" or "tool".
            context: Conversation transcript to append to (mutated in place).
            message: Raw string (unstructured assistant), ThoughtSchema
                (structured assistant), or tool-execution dict.
            unstructured: For "tool", append the raw execution result instead
                of the model-config formatted tool message.

        Raises:
            ValueError: If ``role`` is neither "assistant" nor "tool".
        """
        if role == "assistant":
            if isinstance(message, str):
                msg = {"role": "assistant", "content": message}
            else:
                assert isinstance(message, ThoughtSchema)
                msg = self.model_config.format_assistant_message(message)
        elif role == "tool":
            if unstructured:
                assert isinstance(message, dict)
                msg = {
                    "role": "tool",
                    "content": message["execution_result"],
                    "name": message["function_name"],
                }
            else:
                assert isinstance(message, dict)
                msg = self.model_config.create_tool_message(context, message)
        else:
            raise ValueError(f"Invalid role: {role}")
        assert msg is not None, f"Message is None for role: {role}"
        context.append(msg)