import json
from typing import TYPE_CHECKING, Any

from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import (
    RuntimeParameter,
    RuntimeParametersMixin,
)
from distilabel.models.llms.base import LLM
from pydantic import Field, PositiveInt, ValidationError

from linalg_zero.distillation.components.diagnostics import Diagnostics
from linalg_zero.distillation.components.models import ModelType
from linalg_zero.distillation.data import FunctionInvocationInfo, ThoughtSchema
from linalg_zero.grpo.verifiers.xml_parser import (
    XMLParser,
)
from linalg_zero.grpo.verify import parse_string, verify_answers
from linalg_zero.shared.lib import get_lib
from linalg_zero.shared.system_prompts import (
    ANSWER_OPEN,
    THINK_OPEN,
    TOOL_CALL_OPEN,
)
from linalg_zero.shared.utils import get_logger

if TYPE_CHECKING:
    from distilabel.typing import ChatType

    from linalg_zero.distillation.components.models import (
        ModelParameters,
    )


class MultiTurnWithToolUseBase(RuntimeParametersMixin):
    """Mixin that drives a multi-turn assistant/tool conversation with an LLM.

    Each turn asks the LLM for one assistant message per still-active
    conversation, parses it into a ``ThoughtSchema``, executes the requested
    library function (if any) and appends the tool result, until either
    ``n_turns`` is reached or every conversation produced a final answer.
    """

    llm: LLM
    _logger = get_logger(__name__)

    n_turns: PositiveInt = Field(
        description="The number of turns to generate for the conversation.",
    )
    include_system_prompt: RuntimeParameter[bool] = Field(
        default=True,
        description="Whether to include the system prompt used in the generated conversation.",
    )
    library: list[str] = Field(
        description="The list of function names available for tool calls.",
    )
    system_prompt: RuntimeParameter[str] = Field(
        default=None,
        description="The system prompt to use for the generation.",
    )
    structured_output: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to use structured output for the generation.",
    )
    enable_hint_injection: RuntimeParameter[bool] = Field(
        default=False,
        description="If true, inject a user hint about malformed outputs to guide the next turn. "
        "If false, track diagnostics only without modifying conversations.",
    )
    max_diagnostic_messages: RuntimeParameter[int | None] = Field(
        default=1,
        description="Maximum number of diagnostic user messages to retain in the conversation (None for unbounded).",
    )
    strict_format: RuntimeParameter[bool] = Field(
        default=True,
        description=f"If true, enforce strict '{THINK_OPEN} then {TOOL_CALL_OPEN}|{ANSWER_OPEN}' structure gate in parsing.",
    )
    min_successful_completions: int = Field(
        default=-1,
        exclude=True,
        description="If set, continue generating until this many successful completions are achieved (ignores dataset size).",
    )
    strip_think_prefix: RuntimeParameter[bool] = Field(
        default=True,
        description="If true, strip the think prefix from the conversation. This is needed for Qwen3 models.",
    )
    model_type: str = Field(
        description="The name of the model.",
    )

    @staticmethod
    def _is_malformed_turn(parsed: ThoughtSchema | None) -> bool:
        """Return True when a parsed turn carries neither a tool call nor a completed answer.

        ``extract_output`` substitutes an empty, non-completed placeholder for
        unusable generations, so checking for ``None`` alone is not enough.
        """
        return parsed is None or (parsed.tool_call is None and not parsed.completed)

    def _prepare_inputs_for_instruction_generation(self, inputs: list[dict[str, Any]]) -> list["ChatType"]:
        """Build the initial conversation for every input row.

        Each conversation starts with the system prompt (when one is set)
        followed by a user message holding the row's ``"query"`` value.
        """
        prepared_inputs = []
        for data in inputs:
            conversation = []
            if self.system_prompt:
                conversation.append({"role": "system", "content": self.system_prompt})
            conversation.append({"role": "user", "content": data["query"]})
            prepared_inputs.append(conversation)
        return prepared_inputs

    def create_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
        """Format a parsed message as an assistant chat message for the configured model type."""
        config: ModelParameters = ModelType(self.model_type).get_model_parameters()
        return config.format_assistant_message(message)

    def create_tool_message(self, conversation: list["ChatType"], message: dict[str, Any]) -> dict[str, Any]:
        """Build the ``tool`` role message answering the latest assistant tool call.

        Raises:
            DistilabelUserError: if no assistant message with ``tool_calls`` (or no
                id on its first tool call) can be found in ``conversation``.
        """
        # NOTE: Find the last assistant message with tool calls. This only works for single-turn tool calls,
        # if we transition to multiple calls per turn, must match by name or position.
        tool_call_id = None
        for msg in reversed(conversation):
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                tool_call_id = msg.get("tool_calls", [{}])[0].get("id")
                break

        if tool_call_id is None:
            raise DistilabelUserError("No assistant message with tool_calls found for tool response")

        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": message["function_name"],
            "content": message["execution_result"],
        }

    def _append_messages_to_conversations(
        self, role: str, messages: list[ThoughtSchema | None] | list[dict[str, str]], conversations: list["ChatType"]
    ) -> list["ChatType"]:
        """Appends the outputs generated by the LLM with the specified role to the conversations.

        Conversations are mutated in place and also returned. ``None`` messages and
        messages of an unexpected type for the role are skipped.

        Raises:
            DistilabelUserError: if ``role`` is neither ``'assistant'`` nor ``'tool'``.
        """
        for message, conversation in zip(messages, conversations, strict=True):
            # The message may be None because of incomplete generations.
            if message is None:
                continue

            if role == "assistant":
                if not isinstance(message, ThoughtSchema):
                    continue
                # The possible previous roles are diagnostic user messages or tool messages
                last_role = conversation[-1]["role"]
                assert last_role in ("user", "tool"), (
                    f"Invalid last role before assistant generation: '{last_role}'"
                )
                formatted_message = self.create_assistant_message(message)
                if formatted_message is not None:
                    conversation.append(formatted_message)
            elif role == "tool":
                if not isinstance(message, dict):
                    continue
                conversation.append(self.create_tool_message(conversation, message))
            else:
                raise DistilabelUserError(f"Invalid role: {role}. Must be 'assistant' or 'tool'.")
        return conversations

    def _prepare_conversation_outputs(
        self, conversations: list["ChatType"], final_answers: list[str], success_indices: list[bool]
    ) -> list[dict[str, Any]]:
        """Prepare the output conversation removing the system prompt if necessary.

        It will return a dictionary with a "messages" key (plus "final_answer" and
        "is_correct" for non-empty conversations).

        Raises:
            DistilabelUserError: if any conversation is ``None``.
        """
        diag = Diagnostics(model_type=ModelType(self.model_type))
        outputs: list[dict[str, Any]] = []
        for conversation, final_answer, is_correct in zip(conversations, final_answers, success_indices, strict=True):
            if conversation is None:
                raise DistilabelUserError("Conversation is None. This should not happen.")
            if len(conversation) == 0:
                # Something went wrong with the `LLM` and it didn't generate any message
                outputs.append({"messages": []})
                continue

            # Remove any mid-conversation hint messages (work on a copy of the list)
            conv_out = diag.remove_hint_messages(list(conversation))

            # Drop initial system if requested; guard against hint removal leaving nothing behind
            if not self.include_system_prompt and conv_out and conv_out[0]["role"] == "system":
                conv_out = conv_out[1:]

            outputs.append({"messages": conv_out, "final_answer": final_answer, "is_correct": is_correct})
        return outputs

    def _sanitize_history_for_llm(self, conversations: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
        """Return a sanitized copy of conversations for use as LLM history.

        - Assistant messages that carry ``tool_calls`` get their ``content`` cleared.
        - Every other message is kept as-is.

        Each message dict is shallow-copied, so the input conversations are not mutated.
        """
        assert self.model_type == "default", "This function is only supported for Default model (Qwen3)"
        sanitized: list[list[dict[str, Any]]] = []
        for conv in conversations:
            new_conv: list[dict[str, Any]] = []
            for msg in conv:
                m = dict(msg)
                if m.get("role") == "assistant" and m.get("tool_calls"):
                    m["content"] = ""
                new_conv.append(m)
            sanitized.append(new_conv)
        return sanitized

    def _inject_hints(
        self,
        conversations: list["ChatType"],
        active_indices: list[int],
        parsed_active_msgs: list[ThoughtSchema | None],
        raw_messages: list[str | None],
    ) -> dict[int, tuple[str, str]]:
        """Track diagnostic reasons and optionally inject hint messages for malformed outputs.

        Returns a mapping from global sample index to (diagnostic reason, raw message).
        """
        reasons: dict[int, tuple[str, str]] = {}
        diag = Diagnostics(model_type=ModelType(self.model_type))
        for local_idx, parsed in enumerate(parsed_active_msgs):
            if not self._is_malformed_turn(parsed):
                continue

            global_idx = active_indices[local_idx]
            conv = conversations[global_idx]
            raw = raw_messages[local_idx] or ""
            hint = diag.analyze_and_build_hint(context=conv, message=raw, tool_names=self.library)
            if not hint:
                continue

            reasons[global_idx] = (hint, raw)
            if self.enable_hint_injection:
                # Inject hint message into conversation to guide next turn
                diag.apply_hint(conv, hint, max_hints=self.max_diagnostic_messages)
                self._logger.info(f"Injected recovery hint for sample {global_idx}: {hint}")
            else:
                # Track diagnostic only without modifying conversation
                self._logger.info(f"Tracked malformed output for sample {global_idx}: {hint}")
        return reasons

    def _generate_conversation_turn(
        self, conversations: list["ChatType"], active_indices: list[int]
    ) -> tuple[list["ChatType"], list[int], tuple[Any, ...], list[ThoughtSchema | None], dict[int, tuple[str, str]]]:
        """Generate one assistant message for every active conversation.

        Returns:
            The (mutated) conversations, the indices still active after this turn,
            the per-output generation statistics (aligned with the *input*
            ``active_indices``), the parsed messages in an array aligned with
            ``conversations`` (``None`` for inactive samples) and the diagnostics
            mapping produced by ``_inject_hints``.
        """
        # Generate an output for the conversations that are still active
        active_conversations = [conversations[idx] for idx in active_indices]

        outputs = self.llm.generate(
            inputs=active_conversations,  # TODO: Sanitizing history likely not necessary: self._sanitize_history_for_llm(inputs)
            num_generations=1,
            **self.llm.get_generation_kwargs(),
        )

        # Extract the single message from the conversation and the statistics in separate lists
        messages, statistics = zip(
            *[(output["generations"][0], output["statistics"]) for output in outputs],
            strict=True,
        )

        # Log potential truncation
        self._check_generation_truncation(statistics, active_indices)

        parser = XMLParser()
        seeded_messages: list[str | None] = [
            None if msg is None else parser.ensure_think_prefix(msg) for msg in messages
        ]

        # Parsing happens against the conversation as it was *before* the new
        # assistant message is appended below.
        parsed_active_msgs = self.extract_output(seeded_messages, contexts=active_conversations)

        updated_conversations = self._append_messages_to_conversations(
            role="assistant",
            messages=parsed_active_msgs,
            conversations=active_conversations,
        )

        for idx, conv in zip(active_indices, updated_conversations, strict=True):
            conversations[idx] = conv

        # Inject hint for failed messages to guide next turn
        diagnostics = self._inject_hints(conversations, active_indices, parsed_active_msgs, list(messages))

        # Keep samples active even if message is None (retry next turn)
        new_active_indices = [
            idx
            for idx, message in zip(active_indices, parsed_active_msgs, strict=True)
            if (message is None or not message.completed)
        ]

        # Create full-sized array aligned with conversations list
        full_parsed_messages: list[ThoughtSchema | None] = [None] * len(conversations)
        for active_idx, parsed_msg in zip(active_indices, parsed_active_msgs, strict=True):
            full_parsed_messages[active_idx] = parsed_msg

        return conversations, new_active_indices, statistics, full_parsed_messages, diagnostics

    def _check_generation_truncation(self, statistics: tuple[Any, ...], active_indices: list[int]) -> None:
        """Check and log potential generation truncation based on token statistics.

        Nothing is logged when the generation kwargs carry no ``max_new_tokens``.
        ``output_tokens`` may be a single count or a per-generation list of counts.
        """
        max_new_tokens = self.llm.get_generation_kwargs().get("max_new_tokens")
        if max_new_tokens is None:
            return

        for conversation_idx, stats in zip(active_indices, statistics):
            output_tokens = stats.get("output_tokens") if isinstance(stats, dict) else None
            # NOTE(review): distilabel backends appear to report one count per
            # generation (a list); accept both layouts — confirm against the LLM in use.
            counts = output_tokens if isinstance(output_tokens, (list, tuple)) else [output_tokens]
            if any(isinstance(count, int) and count >= max_new_tokens for count in counts):
                self._logger.warning(
                    f"Generation may have been truncated at max_new_tokens={max_new_tokens} "
                    f"for conversation {conversation_idx} (output_tokens={output_tokens})"
                )

    def _execute_tool_calls(
        self, conversations: list["ChatType"], active_indices: list[int], parsed_messages: list[ThoughtSchema | None]
    ) -> tuple[list["ChatType"], list[int], dict[int, str], dict[int, int]]:
        """Execute the tool calls of the active conversations and append the tool responses.

        Returns:
            The (mutated) conversations, the unchanged ``active_indices``, a mapping
            of sample index to the called function name, and a mapping of sample
            index to the number of tool errors observed in this turn.
        """
        # Filter only messages that have a valid tool call to execute
        exec_indices: list[int] = []
        exec_messages: list[ThoughtSchema] = []
        exec_conversations: list[ChatType] = []
        for idx in active_indices:
            message = parsed_messages[idx]
            if message is None or message.tool_call is None:
                # Skip malformed or tool-less messages; conversation remains active for next turn
                continue
            exec_indices.append(idx)
            exec_messages.append(message)
            exec_conversations.append(conversations[idx])

        if not exec_indices:
            return conversations, active_indices, {}, {}

        active_results, tool_statistics = self._execute(inputs=exec_messages, active_indices=exec_indices)

        updated_conversations = self._append_messages_to_conversations(
            role="tool",
            messages=active_results,
            conversations=exec_conversations,
        )

        for idx, conv in zip(exec_indices, updated_conversations, strict=True):
            conversations[idx] = conv

        # Determine tool errors by inspecting execution_result content
        error_counts: dict[int, int] = {}
        for local_pos, res in enumerate(active_results):
            if str(res.get("execution_result", "")).startswith("ERROR:"):
                global_idx = exec_indices[local_pos]
                error_counts[global_idx] = error_counts.get(global_idx, 0) + 1

        return conversations, active_indices, tool_statistics, error_counts

    def _execute(
        self, inputs: list[ThoughtSchema], active_indices: list[int]
    ) -> tuple[list[dict[str, str]], dict[int, str]]:
        """Run the requested library function for every input.

        Failures never raise: they are reported as an ``execution_result`` that
        starts with ``"ERROR:"``.

        Returns:
            The list of ``{"function_name", "execution_result"}`` dicts and a mapping
            from sample index to the function name that was requested.
        """
        results: list[dict[str, str]] = []
        statistics: dict[int, str] = {}
        for data, idx in zip(inputs, active_indices, strict=True):
            if data.tool_call is None:
                # Should not happen due to upstream filtering; skip safely
                continue

            name = data.tool_call.name
            arguments = data.tool_call.arguments

            # Track tool call frequency for this specific input
            statistics[idx] = name

            try:
                lib_functions = get_lib()
                if name not in lib_functions:
                    results.append({
                        "function_name": name,
                        "execution_result": f"ERROR: Function '{name}' not found in library",
                    })
                    continue

                result = lib_functions[name](**arguments)
                results.append({"function_name": name, "execution_result": str(result)})
            except Exception as exc:
                # Avoid stopping the pipeline on tool errors; propagate error text
                results.append({
                    "function_name": name,
                    "execution_result": f"ERROR: {type(exc).__name__}: {exc}",
                })
        return results, statistics

    def extract_output(self, messages: list[str | None], contexts: list[list[dict]]) -> list[ThoughtSchema | None]:
        """Extract the structured output from the messages.

        Missing or unparsable messages yield an empty, non-completed
        ``ThoughtSchema`` placeholder rather than ``None``.
        """
        result: list[ThoughtSchema | None] = []
        for ctx, message in zip(contexts, messages, strict=True):
            # Default placeholder to avoid None entries
            placeholder = ThoughtSchema(thought="", tool_call=None, final_answer=None, completed=False)

            if message is None:
                result.append(placeholder)
                continue

            parsed: ThoughtSchema | None
            if self.structured_output:
                parsed, _ = self.extract_structured_output(message, context=ctx)
            else:
                parsed = self.extract_non_structured_output(message, context=ctx)

            result.append(placeholder if parsed is None else parsed)
        return result

    def extract_structured_output(self, message: str, context: list[dict]) -> tuple[ThoughtSchema | None, str]:
        """Extract output from messages that enforce structured output.

        Returns:
            ``(parsed, "ok")`` on success, otherwise ``(None, reason)``.
        """
        parser = XMLParser()
        try:
            result = ThoughtSchema.model_validate_json(message)
        except ValidationError as e:
            return None, f"malformed JSON: {e!s}"

        tool_call = result.tool_call
        answer = result.final_answer

        # Enforce is_valid_think_then_tool_or_answer
        if tool_call is None and answer is None:
            return None, "missing tool_call or final_answer"

        # Enforce answer_policy_valid
        if answer is not None and not parser._has_tool_calls(context):
            return None, "answer without tool response"

        # A turn must be either a tool call or a final answer, never both; the
        # unstructured path enforces the same via is_valid_think_then_tool_or_answer.
        if tool_call is not None and answer is not None:
            return None, "both tool_call and final_answer present"

        return result, "ok"

    def extract_non_structured_output(self, message: str, context: list[dict]) -> ThoughtSchema | None:
        """Extract output from messages that do not enforce structured output.

        Returns ``None`` when the strict format gate or the answer policy rejects
        the message.
        """
        parser = XMLParser()
        analysis = parser.analyze_message_in_context(context, message=message, tool_names=self.library)

        if self.strict_format and not bool(analysis["is_valid_think_then_tool_or_answer"]):
            return None

        if analysis["has_answer"] and not bool(analysis["answer_policy_valid"]):
            return None

        thought = analysis["thought"] or ""

        # Enforce a single tool call per turn: take only the last tool block
        tool_call: FunctionInvocationInfo | None = None
        tool_info = analysis["tool"]
        if tool_info and tool_info["json_valid"]:
            tool_call = FunctionInvocationInfo(
                name=str(tool_info["name"]),
                arguments=dict(tool_info["arguments"]),
            )

        # Mark completion based on presence of answer
        answer = analysis["answer"]
        return ThoughtSchema(
            thought=thought,
            tool_call=tool_call,
            final_answer=answer,
            completed=answer is not None,
        )

    def _diagnose_structured_output(self, msg: str | None, context: list[dict]) -> str:
        """Heuristic diagnosis of why a generation is unusable to guide recovery."""
        if msg is None or not str(msg).strip():
            return "empty generation"
        _, reason = self.extract_structured_output(msg, context)
        return reason

    def _diagnose_unstructured_output(self, msg: str | None, context: list[dict]) -> str:
        """Heuristic diagnosis of why a generation is unusable to guide recovery."""
        if msg is None or not str(msg).strip():
            return "empty generation"
        parser = XMLParser()
        msg = parser.ensure_think_prefix(msg) or ""
        analysis = parser.analyze_message_in_context(context, message=msg, tool_names=self.library)
        return parser.get_analysis_failure_reason(analysis, tool_names=self.library)

    def _generate_multi_turn_conversation(  # noqa: C901
        self, inputs: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, int]]]:
        """Run up to ``n_turns`` assistant/tool rounds for every input.

        Returns:
            A tuple of (conversation outputs, per-sample generation metadata,
            per-sample tool-call counts), each aligned with ``inputs``.
        """
        conversations = self._prepare_inputs_for_instruction_generation(inputs)
        # Keep track of the active conversations, as it could happen that for some conversation
        # we can't generate the next turn because the `LLM` returned `None`.
        active_indices = list(range(len(conversations)))
        stats_gen: list[dict[int, Any]] = []
        stats_tools: list[dict[int, str]] = []
        final_answers: list[str | None] = [None] * len(conversations)
        malformed_counts: list[int] = [0] * len(conversations)
        tool_errors_total: list[int] = [0] * len(conversations)
        diagnostics_reasons: list[list[str]] = [[] for _ in range(len(conversations))]
        diagnostics_messages: list[list[str]] = [[] for _ in range(len(conversations))]

        for i in range(self.n_turns):
            if not active_indices:
                break

            # Store current active indices before they get updated
            current_active_indices = active_indices.copy()

            # Generate assistant-tool interaction
            (
                conversations,
                active_indices,
                statistics_generation,
                parsed_messages,
                diagnostics_map,
            ) = self._generate_conversation_turn(
                conversations=conversations,
                active_indices=active_indices,
            )

            # Map tuple of stats back to global indices. Recorded right away so the
            # turn is never lost, whichever way the loop exits below.
            stats_gen.append(dict(zip(current_active_indices, statistics_generation, strict=True)))

            # Use the original active indices to access the full parsed messages array
            for idx in current_active_indices:
                message = parsed_messages[idx]
                if message is not None and message.completed:
                    final_answers[idx] = message.final_answer
                # Placeholders (no tool call, not completed) count as malformed too,
                # matching the criterion used for diagnostics.
                if self._is_malformed_turn(message):
                    malformed_counts[idx] += 1

            # Record diagnostics for this turn
            for idx, (reason, raw) in diagnostics_map.items():
                diagnostics_reasons[idx].append(f"turn {i}: {reason}")
                diagnostics_messages[idx].append(raw)

            # No tool execution after the last turn or once every sample is done
            if i == (self.n_turns - 1) or not active_indices:
                break

            # Generate assistant message; execute only valid tool calls
            conversations, active_indices, statistics_tools, error_counts = self._execute_tool_calls(
                conversations=conversations, parsed_messages=parsed_messages, active_indices=active_indices
            )

            stats_tools.append(statistics_tools)

            # Log tool errors if any
            if error_counts:
                self._logger.warning(f"Tool errors in turn {i}: {error_counts}")
                for idx, cnt in error_counts.items():
                    tool_errors_total[idx] += cnt

        # Convert None values to empty strings instead of None
        final_answers_clean = [answer if answer is not None else "" for answer in final_answers]
        success_indices = self.check_final_answers(final_answers_clean, inputs)

        self._logger.info(
            f"Multi-turn conversation completed. Success rate: {sum(success_indices)}/{len(success_indices)}"
        )

        # Build per-sample generation stats list aligned to inputs (list of per-turn stats)
        gen_stats_list: list[list[Any]] = [
            [per_turn[sample_idx] for per_turn in stats_gen if sample_idx in per_turn]
            for sample_idx in range(len(inputs))
        ]

        merged_stats_tools = self.merge_tool_stats(stats_tools, inputs=inputs)
        return (
            self._prepare_conversation_outputs(conversations, final_answers_clean, success_indices),
            [
                {
                    "gen_stats": gen_stats_list[sample_idx],
                    "malformed_turns": malformed_counts[sample_idx],
                    "tool_errors": tool_errors_total[sample_idx],
                    "diagnostics": diagnostics_reasons[sample_idx],
                    "diagnostic_messages": diagnostics_messages[sample_idx],
                }
                for sample_idx in range(len(inputs))
            ],
            merged_stats_tools,
        )

    def check_final_answers(self, final_answers: list[str], inputs: list[dict[str, Any]]) -> list[bool]:
        """Check if the final answers are correct.

        Each input's JSON-encoded ``"ground_truth"`` is compared with the parsed
        final answer via ``verify_answers``.
        """
        is_correct = []
        for _input, lib_result in zip(inputs, final_answers, strict=True):
            ground_truth = json.loads(_input["ground_truth"])
            is_correct.append(verify_answers(ground_truth, parse_string(lib_result)))
        return is_correct

    def merge_tool_stats(
        self, batch_stats: list[dict[int, str]], inputs: list[dict[str, Any]]
    ) -> list[dict[str, int]]:
        """Merge the per-turn tool stats into one call-count dictionary per input."""
        merged_stats: list[dict[str, int]] = [{} for _ in range(len(inputs))]
        for turn_stats in batch_stats:
            for active_index, fn_name in turn_stats.items():
                merged_stats[active_index][fn_name] = merged_stats[active_index].get(fn_name, 0) + 1
        return merged_stats

    def _prepare_diagnostics_lists(self, stats_gen: dict[str, Any]) -> tuple[list[str], list[str]]:
        """Prepare diagnostics and diagnostic_messages lists with safe filtering.

        ``None`` entries and entries whose string form is empty are dropped.
        """
        diagnostics_list = [str(v) for v in (stats_gen.get("diagnostics", []) or []) if v is not None and str(v) != ""]
        diagnostic_msgs_list = [
            str(v) for v in (stats_gen.get("diagnostic_messages", []) or []) if v is not None and str(v) != ""
        ]
        return diagnostics_list, diagnostic_msgs_list

    def _generate_with_pre_query_template(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Generate a list of instructions or conversations of the specified number of turns."""
        outputs, statistics_gens, statistics_tools = self._generate_multi_turn_conversation(inputs)

        generations = []
        for input_data, output, stats_gen, stats_tools in zip(
            inputs, outputs, statistics_gens, statistics_tools, strict=True
        ):
            generation = {
                **input_data,
                **output,
                "model_name": self.llm.model_name,
            }
            # Ensure stats_tools has at least one field to avoid Parquet serialization errors
            tool_stats = stats_tools if isinstance(stats_tools, dict) and stats_tools else {"_empty": 0}

            # Prepare diagnostics lists with safe filtering
            diagnostics_list, diagnostic_msgs_list = self._prepare_diagnostics_lists(stats_gen)

            generation["distilabel_metadata"] = {
                f"statistics_gen_{self.name}": stats_gen.get("gen_stats", []),
                f"statistics_tools_{self.name}": tool_stats,
                f"malformed_turns_{self.name}": int(stats_gen.get("malformed_turns", 0) or 0),
                f"tool_errors_{self.name}": int(stats_gen.get("tool_errors", 0) or 0),
                f"tool_calls_total_{self.name}": int(sum(tool_stats.values())) if isinstance(tool_stats, dict) else 0,
                f"diagnostics_{self.name}": (diagnostics_list if diagnostics_list else [""]),
                f"diagnostic_messages_{self.name}": (diagnostic_msgs_list if diagnostic_msgs_list else [""]),
            }
            generations.append(generation)
        return generations