# linalg_zero/distillation/components/multi_turn_generation_base.py
import json
from typing import TYPE_CHECKING, Any
from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
)
from distilabel.models.llms.base import LLM
from pydantic import Field, PositiveInt, ValidationError
from linalg_zero.distillation.components.diagnostics import Diagnostics
from linalg_zero.distillation.components.models import ModelType
from linalg_zero.distillation.data import FunctionInvocationInfo, ThoughtSchema
from linalg_zero.grpo.verifiers.xml_parser import XMLParser
from linalg_zero.grpo.verify import parse_string, verify_answers
from linalg_zero.shared.lib import get_lib
from linalg_zero.shared.system_prompts import (
ANSWER_OPEN,
THINK_OPEN,
TOOL_CALL_OPEN,
)
from linalg_zero.shared.utils import get_logger
if TYPE_CHECKING:
from distilabel.typing import ChatType
from linalg_zero.distillation.components.models import (
ModelParameters,
)
class MultiTurnWithToolUseBase(RuntimeParametersMixin):
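    """Mixin implementing a multi-turn, tool-using generation loop on top of a distilabel `LLM`.

    The mixin drives the turn loop: it generates assistant messages, parses them into
    `ThoughtSchema`s, executes tool calls against the local function library, and verifies
    final answers against the ground truth. Subclasses (typically distilabel tasks) are
    expected to provide the surrounding task plumbing, e.g. `self.name`.
    """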
llm: LLM
_logger = get_logger(__name__)
n_turns: PositiveInt = Field(
description="The number of turns to generate for the conversation.",
)
include_system_prompt: RuntimeParameter[bool] = Field(
default=True,
description="Whether to include the system prompt used in the generated conversation.",
)
library: list[str] = Field(
description="The list of function names available for tool calls.",
)
system_prompt: RuntimeParameter[str] = Field(
default=None,
description="The system prompt to use for the generation.",
)
structured_output: RuntimeParameter[bool] = Field(
default=False,
description="Whether to use structured output for the generation.",
)
enable_hint_injection: RuntimeParameter[bool] = Field(
default=False,
description="If true, inject a user hint about malformed outputs to guide the next turn. If false, track diagnostics only without modifying conversations.",
)
max_diagnostic_messages: RuntimeParameter[int | None] = Field(
default=1,
description="Maximum number of diagnostic user messages to retain in the conversation (None for unbounded).",
)
strict_format: RuntimeParameter[bool] = Field(
default=True,
description=f"If true, enforce strict '{THINK_OPEN} then {TOOL_CALL_OPEN}|{ANSWER_OPEN}' structure gate in parsing.",
)
    min_successful_completions: int = Field(
        default=-1,
        exclude=True,
        description="If set to a positive value, continue generating until this many successful completions are achieved (ignores dataset size); the default of -1 disables this behavior.",
    )
strip_think_prefix: RuntimeParameter[bool] = Field(
default=True,
description="If true, strip the think prefix from the conversation. This is needed for Qwen3 models.",
)
    model_type: str = Field(
        description="The model type identifier used to select model-specific message formatting (see `ModelType`).",
    )
def _prepare_inputs_for_instruction_generation(self, inputs: list[dict[str, Any]]) -> list["ChatType"]:
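        """Build the initial chat for each input: an optional system prompt followed by the user query."""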
prepared_inputs = []
for data in inputs:
conversation = []
if self.system_prompt:
conversation.append({"role": "system", "content": self.system_prompt})
conversation.append({"role": "user", "content": data["query"]})
prepared_inputs.append(conversation)
return prepared_inputs
def create_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
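        """Format a parsed `ThoughtSchema` as an assistant chat message using the model-specific formatter."""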
config: ModelParameters = ModelType(self.model_type).get_model_parameters()
return config.format_assistant_message(message)
def create_tool_message(self, conversation: list["ChatType"], message: dict[str, Any]) -> dict[str, Any]:
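        """Build a `role="tool"` response linked via `tool_call_id` to the latest assistant tool call."""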
        # NOTE: Find the last assistant message with tool calls. This only works with a single
        # tool call per turn; if we transition to multiple calls per turn, we must match by name or position.
tool_call_id = None
for msg in reversed(conversation):
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_call_id = msg.get("tool_calls", [{}])[0].get("id")
break
if tool_call_id is None:
raise DistilabelUserError("No assistant message with tool_calls found for tool response")
return {
"role": "tool",
"tool_call_id": tool_call_id,
"name": message["function_name"],
"content": message["execution_result"],
}
def _append_messages_to_conversations(
self, role: str, messages: list[ThoughtSchema | None] | list[dict[str, str]], conversations: list["ChatType"]
) -> list["ChatType"]:
"""Appends the outputs generated by the LLM with the specified role to the conversations."""
for message, conversation in zip(messages, conversations, strict=True):
# The message may be None because of incomplete generations.
if message is None:
continue
if role == "assistant":
if isinstance(message, ThoughtSchema):
# The possible previous roles are diagnostic user messages or tool messages
last_role = conversation[-1]["role"]
assert last_role in ("user", "tool"), (
f"Invalid last role before assistant generation: '{last_role}'"
)
formatted_message = self.create_assistant_message(message)
if formatted_message is not None:
conversation.append(formatted_message)
else:
continue
elif role == "tool":
if isinstance(message, dict):
formatted_message = self.create_tool_message(conversation, message)
conversation.append(formatted_message)
else:
continue
else:
raise DistilabelUserError(f"Invalid role: {role}. Must be 'assistant' or 'tool'.")
return conversations
def _prepare_conversation_outputs(
self, conversations: list["ChatType"], final_answers: list[str], success_indices: list[bool]
) -> list[dict[str, Any]]:
"""Prepare the output conversation removing the system prompt if necessary.
It will return a dictionary with a "messages" key."""
diag = Diagnostics(model_type=ModelType(self.model_type))
outputs: list[dict[str, Any]] = []
for conversation, final_answer, is_correct in zip(conversations, final_answers, success_indices, strict=True):
if conversation is None:
raise DistilabelUserError("Conversation is None. This should not happen.")
if len(conversation) == 0:
# Something went wrong with the `LLM` and it didn't generate any message
outputs.append({"messages": []})
continue
conv_out = list(conversation)
# Remove any mid-conversation hint messages
conv_out = diag.remove_hint_messages(conv_out)
# Drop initial system if requested
if not self.include_system_prompt and conv_out[0]["role"] == "system":
conv_out = conv_out[1:]
output = {"messages": conv_out, "final_answer": final_answer, "is_correct": is_correct}
outputs.append(output)
return outputs
def _sanitize_history_for_llm(self, conversations: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
"""Return a sanitized deep copy of conversations without THINK content in history.
- Assistant messages: drop THINK segments; keep only final answer text.
If tool_calls present, keep tool_calls and clear content.
- Other roles are kept as-is.
"""
assert self.model_type == "default", "This function is only supported for Default model (Qwen3)"
sanitized: list[list[dict[str, Any]]] = []
for conv in conversations:
new_conv: list[dict[str, Any]] = []
for msg in conv:
m = dict(msg)
if m.get("role") == "assistant" and m.get("tool_calls"):
m["content"] = ""
new_conv.append(m)
sanitized.append(new_conv)
return sanitized
def _inject_hints(
self,
conversations: list["ChatType"],
active_indices: list[int],
parsed_active_msgs: list[ThoughtSchema | None],
raw_messages: list[str | None],
) -> dict[int, tuple[str, str]]:
"""Track diagnostic reasons and optionally inject hint messages for malformed outputs.
Returns a mapping from global sample index to (diagnostic reason, raw message).
"""
reasons: dict[int, tuple[str, str]] = {}
diag = Diagnostics(model_type=ModelType(self.model_type))
for local_idx, parsed in enumerate(parsed_active_msgs):
if parsed is None or (parsed.tool_call is None and not parsed.completed):
global_idx = active_indices[local_idx]
conv = conversations[global_idx]
raw = raw_messages[local_idx] or ""
hint = diag.analyze_and_build_hint(context=conv, message=raw, tool_names=self.library)
if hint:
reasons[global_idx] = (hint, raw)
if self.enable_hint_injection:
# Inject hint message into conversation to guide next turn
diag.apply_hint(conv, hint, max_hints=self.max_diagnostic_messages)
self._logger.info(f"Injected recovery hint for sample {global_idx}: {hint}")
else:
# Track diagnostic only without modifying conversation
self._logger.info(f"Tracked malformed output for sample {global_idx}: {hint}")
return reasons
def _generate_conversation_turn(
self, conversations: list["ChatType"], active_indices: list[int]
) -> tuple[list["ChatType"], list[int], tuple[Any, ...], list[ThoughtSchema | None], dict[int, tuple[str, str]]]:
# Generate an output for the conversations that are still active
inputs = []
for idx in active_indices:
conversation = conversations[idx]
inputs.append(conversation)
outputs = self.llm.generate(
inputs=inputs, # TODO: Sanitizing history likely not necessary: self._sanitize_history_for_llm(inputs)
num_generations=1,
**self.llm.get_generation_kwargs(),
)
        # Split each output into its single generated message and its generation statistics
messages, statistics = zip(
*[(output["generations"][0], output["statistics"]) for output in outputs],
strict=True,
)
# Log potential truncation
self._check_generation_truncation(statistics, active_indices)
parser = XMLParser()
seeded_messages: list[str | None] = []
for msg in messages:
if msg is None:
seeded_messages.append(None)
else:
seeded_messages.append(parser.ensure_think_prefix(msg))
        # `extract_output` substitutes placeholder `ThoughtSchema`s for messages the LLM failed to generate.
        active_conversations = [conversations[idx] for idx in active_indices]
        parsed_active_msgs = self.extract_output(list(seeded_messages), contexts=active_conversations)
updated_conversations = self._append_messages_to_conversations(
role="assistant",
messages=parsed_active_msgs,
conversations=active_conversations,
)
for idx, conv in zip(active_indices, updated_conversations, strict=True):
conversations[idx] = conv
# Inject hint for failed messages to guide next turn
        diagnostics: dict[int, tuple[str, str]] = self._inject_hints(
            conversations, active_indices, parsed_active_msgs, list(messages)
        )
        # Keep samples active until they produce a completed answer (malformed turns are retried)
new_active_indices = [
idx
for idx, message in zip(active_indices, parsed_active_msgs, strict=True)
if (message is None or not message.completed)
]
# Create full-sized array aligned with conversations list
full_parsed_messages: list[ThoughtSchema | None] = [None] * len(conversations)
for active_idx, parsed_msg in zip(active_indices, parsed_active_msgs, strict=True):
full_parsed_messages[active_idx] = parsed_msg
return conversations, new_active_indices, statistics, full_parsed_messages, diagnostics
def _check_generation_truncation(self, statistics: tuple[Any, ...], active_indices: list[int]) -> None:
"""Check and log potential generation truncation based on token statistics."""
generation_kwargs = self.llm.get_generation_kwargs()
        max_new_tokens = generation_kwargs.get("max_new_tokens")
if max_new_tokens is not None:
for i, stats in enumerate(statistics):
output_tokens = stats["output_tokens"]
if output_tokens == max_new_tokens:
self._logger.warning(
f"Generation may have been truncated at max_new_tokens={max_new_tokens} "
f"for conversation {active_indices[i]} (output_tokens={output_tokens})"
)
def _execute_tool_calls(
self, conversations: list["ChatType"], active_indices: list[int], parsed_messages: list[ThoughtSchema | None]
) -> tuple[list["ChatType"], list[int], dict[int, str], dict[int, int]]:
# Filter only messages that have a valid tool call to execute
exec_indices: list[int] = []
exec_messages: list[ThoughtSchema] = []
        exec_conversations: list["ChatType"] = []
for idx in active_indices:
message = parsed_messages[idx]
if message is None or message.tool_call is None:
# Skip malformed or tool-less messages; conversation remains active for next turn
continue
exec_indices.append(idx)
exec_messages.append(message)
exec_conversations.append(conversations[idx])
if not exec_indices:
return conversations, active_indices, {}, {}
active_results, tool_statistics = self._execute(inputs=exec_messages, active_indices=exec_indices)
updated_conversations = self._append_messages_to_conversations(
role="tool",
messages=active_results,
conversations=exec_conversations,
)
for idx, conv in zip(exec_indices, updated_conversations, strict=True):
conversations[idx] = conv
# Determine tool errors by inspecting execution_result content
error_counts: dict[int, int] = {}
for local_pos, res in enumerate(active_results):
if str(res.get("execution_result", "")).startswith("ERROR:"):
global_idx = exec_indices[local_pos]
error_counts[global_idx] = error_counts.get(global_idx, 0) + 1
return conversations, active_indices, tool_statistics, error_counts
def _execute(
self, inputs: list[ThoughtSchema], active_indices: list[int]
) -> tuple[list[dict[str, str]], dict[int, str]]:
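        """Run each tool call against the local function library.

        Returns one result dict per call and a map from sample index to the tool name invoked;
        errors are captured as "ERROR: ..." strings rather than raised.
        """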
results: list[dict[str, str]] = []
statistics: dict[int, str] = {}
for data, idx in zip(inputs, active_indices, strict=True):
if data.tool_call is None:
# Should not happen due to upstream filtering; skip safely
continue
name = data.tool_call.name
arguments = data.tool_call.arguments
# Track tool call frequency for this specific input
statistics[idx] = name
try:
lib_functions = get_lib()
if name not in lib_functions:
results.append({
"function_name": name,
"execution_result": f"ERROR: Function '{name}' not found in library",
})
continue
result = lib_functions[name](**arguments)
results.append({"function_name": name, "execution_result": str(result)})
except Exception as exc:
# Avoid stopping the pipeline on tool errors; propagate error text
results.append({
"function_name": name,
"execution_result": f"ERROR: {type(exc).__name__}: {exc}",
})
return results, statistics
def extract_output(self, messages: list[str | None], contexts: list[list[dict]]) -> list[ThoughtSchema | None]:
"""Extract the structured output from the messages."""
result: list[ThoughtSchema | None] = []
for ctx, message in zip(contexts, messages, strict=True):
# Default placeholder to avoid None entries
placeholder = ThoughtSchema(thought="", tool_call=None, final_answer=None, completed=False)
if message is None:
result.append(placeholder)
continue
parsed: ThoughtSchema | None = None
if self.structured_output:
parsed, _ = self.extract_structured_output(message, context=ctx)
else:
parsed = self.extract_non_structured_output(message, context=ctx)
if parsed is None:
result.append(placeholder)
else:
result.append(parsed)
return result
def extract_structured_output(self, message: str, context: list[dict]) -> tuple[ThoughtSchema | None, str]:
"""Extract output from messages that enforce structured output."""
parser = XMLParser()
try:
result = ThoughtSchema.model_validate_json(message)
except ValidationError as e:
return None, f"malformed JSON: {e!s}"
# Enforce is_valid_think_then_tool_or_answer
if result.tool_call is None and result.final_answer is None:
return None, "missing tool_call or final_answer"
# Enforce answer_policy_valid
if result.final_answer is not None and not parser._has_tool_calls(context):
return None, "answer without tool response"
tool_call = result.tool_call
answer = result.final_answer
        # Reject messages that contain both a tool call and a final answer. In the
        # unstructured setting, is_valid_think_then_tool_or_answer enforces the same
        # mutual exclusion, so both paths stay consistent.
if tool_call is not None and answer is not None:
return None, "both tool_call and final_answer present"
return result, "ok"
def extract_non_structured_output(self, message: str, context: list[dict]) -> ThoughtSchema | None:
"""Extract output from messages that do not enforce structured output."""
parser = XMLParser()
analysis = parser.analyze_message_in_context(context, message=message, tool_names=self.library)
if self.strict_format and not bool(analysis["is_valid_think_then_tool_or_answer"]):
return None
if analysis["has_answer"] and not bool(analysis["answer_policy_valid"]):
return None
thought = analysis["thought"] or ""
# Enforce a single tool call per turn: take only the last tool block
tool_call: FunctionInvocationInfo | None = None
tool_info = analysis["tool"]
if tool_info and tool_info["json_valid"]:
tool_call = FunctionInvocationInfo(
name=str(tool_info["name"]),
arguments=dict(tool_info["arguments"]),
)
# Mark completion based on presence of answer
answer = analysis["answer"]
return ThoughtSchema(
thought=thought,
tool_call=tool_call,
final_answer=answer,
completed=answer is not None,
)
def _diagnose_structured_output(self, msg: str | None, context: list[dict]) -> str:
"""Heuristic diagnosis of why a generation is unusable to guide recovery."""
if msg is None or not str(msg).strip():
return "empty generation"
_, reason = self.extract_structured_output(msg, context)
return reason
def _diagnose_unstructured_output(self, msg: str | None, context: list[dict]) -> str:
"""Heuristic diagnosis of why a generation is unusable to guide recovery."""
if msg is None or not str(msg).strip():
return "empty generation"
parser = XMLParser()
msg = parser.ensure_think_prefix(msg) or ""
analysis = parser.analyze_message_in_context(context, message=msg, tool_names=self.library)
return parser.get_analysis_failure_reason(analysis, tool_names=self.library)
def _generate_multi_turn_conversation( # noqa: C901
self, inputs: list[dict[str, Any]]
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, int]]]:
conversations = self._prepare_inputs_for_instruction_generation(inputs)
        # Keep track of the active conversations; a conversation leaves the active set
        # once it produces a completed final answer, and stays active otherwise.
active_indices = list(range(len(conversations)))
stats_gen: list[dict[int, Any]] = []
stats_tools: list[dict[int, str]] = []
final_answers: list[str | None] = [None] * len(conversations)
malformed_counts: list[int] = [0] * len(conversations)
tool_errors_total: list[int] = [0] * len(conversations)
diagnostics_reasons: list[list[str]] = [[] for _ in range(len(conversations))]
diagnostics_messages: list[list[str]] = [[] for _ in range(len(conversations))]
for i in range(self.n_turns):
if not active_indices:
break
# Store current active indices before they get updated
current_active_indices = active_indices.copy()
# Generate assistant-tool interaction
(
conversations,
active_indices,
statistics_generation,
parsed_messages,
diagnostics_map,
) = self._generate_conversation_turn(
conversations=conversations,
active_indices=active_indices,
)
# Use the original active indices to access the full parsed messages array
for idx in current_active_indices:
message = parsed_messages[idx]
                if message and message.completed:
                    final_answers[idx] = message.final_answer
                # `extract_output` returns placeholder schemas rather than None, so count a
                # message with neither a tool call nor a completed answer as a malformed turn.
                if message is None or (message.tool_call is None and not message.completed):
                    malformed_counts[idx] += 1
# Record diagnostics for this turn
for idx, (reason, raw) in diagnostics_map.items():
diagnostics_reasons[idx].append(f"turn {i}: {reason}")
diagnostics_messages[idx].append(raw)
if i == (self.n_turns - 1):
# Map tuple of stats back to global indices
per_turn = dict(zip(current_active_indices, statistics_generation, strict=True))
stats_gen.append(per_turn)
break
if not active_indices:
break
            # Execute only the valid tool calls parsed from this turn's assistant messages
conversations, active_indices, statistics_tools, error_counts = self._execute_tool_calls(
conversations=conversations, parsed_messages=parsed_messages, active_indices=active_indices
)
stats_tools.append(statistics_tools)
per_turn = dict(zip(current_active_indices, statistics_generation, strict=True))
stats_gen.append(per_turn)
# Log tool errors if any
if error_counts:
self._logger.warning(f"Tool errors in turn {i}: {error_counts}")
for idx, cnt in error_counts.items():
tool_errors_total[idx] += cnt
        # Replace None final answers with empty strings for downstream verification
final_answers_clean = [answer if answer is not None else "" for answer in final_answers]
success_indices = self.check_final_answers(final_answers_clean, inputs)
self._logger.info(
f"Multi-turn conversation completed. Success rate: {sum(success_indices)}/{len(success_indices)}"
)
# Build per-sample generation stats list aligned to inputs (list of per-turn stats)
gen_stats_list: list[list[Any]] = []
for i in range(len(inputs)):
per_sample_stats = [per_turn[i] for per_turn in stats_gen if i in per_turn]
gen_stats_list.append(per_sample_stats)
merged_stats_tools = self.merge_tool_stats(stats_tools, inputs=inputs)
return (
self._prepare_conversation_outputs(conversations, final_answers_clean, success_indices),
[
{
"gen_stats": gen_stats_list[i],
"malformed_turns": malformed_counts[i],
"tool_errors": tool_errors_total[i],
"diagnostics": diagnostics_reasons[i],
"diagnostic_messages": diagnostics_messages[i],
}
for i in range(len(inputs))
],
merged_stats_tools,
)
def check_final_answers(self, final_answers: list[str], inputs: list[dict[str, Any]]) -> list[bool]:
"""Check if the final answers are correct."""
is_correct = []
for _input, lib_result in zip(inputs, final_answers, strict=True):
ground_truth = json.loads(_input["ground_truth"])
is_correct.append(verify_answers(ground_truth, parse_string(lib_result)))
return is_correct
def merge_tool_stats(
self, batch_stats: list[dict[int, str]], inputs: list[dict[str, Any]]
) -> list[dict[str, int]]:
"""Merge the tool stats into a single dictionary."""
merged_stats: list[dict[str, int]] = [{} for _ in range(len(inputs))]
for turn_stats in batch_stats:
for active_index, fn_name in turn_stats.items():
merged_stats[active_index][fn_name] = merged_stats[active_index].get(fn_name, 0) + 1
return merged_stats
def _prepare_diagnostics_lists(self, stats_gen: dict[str, Any]) -> tuple[list[str], list[str]]:
"""Prepare diagnostics and diagnostic_messages lists with safe filtering."""
diagnostics_list = [str(v) for v in (stats_gen.get("diagnostics", []) or []) if v is not None and str(v) != ""]
diagnostic_msgs_list = [
str(v) for v in (stats_gen.get("diagnostic_messages", []) or []) if v is not None and str(v) != ""
]
return diagnostics_list, diagnostic_msgs_list
def _generate_with_pre_query_template(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Generate a list of instructions or conversations of the specified number of turns."""
outputs, statistics_gens, statistics_tools = self._generate_multi_turn_conversation(inputs)
generations = []
for input_data, output, stats_gen, stats_tools in zip(
inputs, outputs, statistics_gens, statistics_tools, strict=True
):
generation = {
**input_data,
**output,
"model_name": self.llm.model_name,
}
# Ensure stats_tools has at least one field to avoid Parquet serialization errors
tool_stats = stats_tools if isinstance(stats_tools, dict) and stats_tools else {"_empty": 0}
# Prepare diagnostics lists with safe filtering
diagnostics_list, diagnostic_msgs_list = self._prepare_diagnostics_lists(stats_gen)
generation["distilabel_metadata"] = {
f"statistics_gen_{self.name}": stats_gen.get("gen_stats", []),
f"statistics_tools_{self.name}": tool_stats,
f"malformed_turns_{self.name}": int(stats_gen.get("malformed_turns", 0) or 0),
f"tool_errors_{self.name}": int(stats_gen.get("tool_errors", 0) or 0),
f"tool_calls_total_{self.name}": int(sum(tool_stats.values())) if isinstance(tool_stats, dict) else 0,
f"diagnostics_{self.name}": (diagnostics_list if diagnostics_list else [""]),
f"diagnostic_messages_{self.name}": (diagnostic_msgs_list if diagnostic_msgs_list else [""]),
}
generations.append(generation)
return generations
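
# A minimal usage sketch (hypothetical wiring; `MyTask` and `my_llm` are illustrative and
# not part of this module):
#
#     class MyTask(MultiTurnWithToolUseBase, SomeDistilabelTask):
#         ...
#
#     task = MyTask(
#         llm=my_llm,                   # any distilabel `LLM`
#         n_turns=4,                    # max assistant/tool turns per conversation
#         library=["matrix_multiply"],  # tool names resolvable via `get_lib()`
#         model_type="default",         # selects Qwen3-style message formatting
#     )
#     outputs = task._generate_with_pre_query_template(
#         [{"query": "...", "ground_truth": '{"answer": 1}'}]
#     )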