import json
from typing import TYPE_CHECKING, Any

from distilabel.errors import DistilabelUserError
from distilabel.mixins.runtime_parameters import (
    RuntimeParameter,
    RuntimeParametersMixin,
)
from distilabel.models.llms.base import LLM
from pydantic import Field, PositiveInt, ValidationError

from linalg_zero.distillation.components.diagnostics import Diagnostics
from linalg_zero.distillation.components.models import ModelType
from linalg_zero.distillation.data import FunctionInvocationInfo, ThoughtSchema
from linalg_zero.grpo.verifiers.xml_parser import XMLParser
from linalg_zero.grpo.verify import parse_string, verify_answers
from linalg_zero.shared.lib import get_lib
from linalg_zero.shared.system_prompts import (
    ANSWER_OPEN,
    THINK_OPEN,
    TOOL_CALL_OPEN,
)
from linalg_zero.shared.utils import get_logger

if TYPE_CHECKING:
    from distilabel.typing import ChatType

    from linalg_zero.distillation.components.models import ModelParameters


class MultiTurnWithToolUseBase(RuntimeParametersMixin):
    llm: LLM
    _logger = get_logger(__name__)

    n_turns: PositiveInt = Field(
        description="The number of turns to generate for the conversation.",
    )
    include_system_prompt: RuntimeParameter[bool] = Field(
        default=True,
        description="Whether to include the system prompt used in the generated conversation.",
    )
    library: list[str] = Field(
        description="The list of function names available for tool calls.",
    )
    system_prompt: RuntimeParameter[str] = Field(
        default=None,
        description="The system prompt to use for the generation.",
    )
    structured_output: RuntimeParameter[bool] = Field(
        default=False,
        description="Whether to use structured output for the generation.",
    )
    enable_hint_injection: RuntimeParameter[bool] = Field(
        default=False,
        description="If true, inject a user hint about malformed outputs to guide the next turn. If false, track diagnostics only without modifying conversations.",
    )
    max_diagnostic_messages: RuntimeParameter[int | None] = Field(
        default=1,
        description="Maximum number of diagnostic user messages to retain in the conversation (None for unbounded).",
    )
    strict_format: RuntimeParameter[bool] = Field(
        default=True,
        description=f"If true, enforce strict '{THINK_OPEN} then {TOOL_CALL_OPEN}|{ANSWER_OPEN}' structure gate in parsing.",
    )
    min_successful_completions: int = Field(
        default=-1,
        exclude=True,
        description="If set, continue generating until this many successful completions are achieved (ignores dataset size).",
    )
    strip_think_prefix: RuntimeParameter[bool] = Field(
        default=True,
        description="If true, strip the think prefix from the conversation. This is needed for Qwen3 models.",
    )
    model_type: str = Field(
        description="The name of the model.",
    )

    def _prepare_inputs_for_instruction_generation(self, inputs: list[dict[str, Any]]) -> list["ChatType"]:
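        """Build the initial conversation for each input: an optional system prompt followed by the user query."""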
        prepared_inputs = []
        for data in inputs:
            conversation = []
            if self.system_prompt:
                conversation.append({"role": "system", "content": self.system_prompt})
            conversation.append({"role": "user", "content": data["query"]})
            prepared_inputs.append(conversation)
        return prepared_inputs

    def create_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
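        """Format a parsed `ThoughtSchema` as an assistant chat message using the model-specific parameters."""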
        config: ModelParameters = ModelType(self.model_type).get_model_parameters()
        return config.format_assistant_message(message)

    def create_tool_message(self, conversation: "ChatType", message: dict[str, Any]) -> dict[str, Any]:
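        """Build a tool-role message that responds to the most recent assistant tool call in the conversation."""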
        # NOTE: Find the last assistant message with tool calls. This only works for a single
        # tool call per turn; if we transition to multiple calls per turn, we must match the
        # response to the call by name or position.
        tool_call_id = None
        for msg in reversed(conversation):
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                tool_call_id = msg.get("tool_calls", [{}])[0].get("id")
                break
        if tool_call_id is None:
            raise DistilabelUserError("No assistant message with tool_calls found for tool response")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": message["function_name"],
            "content": message["execution_result"],
        }

    def _append_messages_to_conversations(
        self, role: str, messages: list[ThoughtSchema | None] | list[dict[str, str]], conversations: list["ChatType"]
    ) -> list["ChatType"]:
        """Appends the outputs generated by the LLM with the specified role to the conversations."""
        for message, conversation in zip(messages, conversations, strict=True):
            # The message may be None because of an incomplete generation.
            if message is None:
                continue
            if role == "assistant":
                if isinstance(message, ThoughtSchema):
                    # The previous message must be a user message (the initial query or a
                    # diagnostic hint) or a tool message.
                    last_role = conversation[-1]["role"]
                    assert last_role in ("user", "tool"), (
                        f"Invalid last role before assistant generation: '{last_role}'"
                    )
                    formatted_message = self.create_assistant_message(message)
                    if formatted_message is not None:
                        conversation.append(formatted_message)
                else:
                    continue
            elif role == "tool":
                if isinstance(message, dict):
                    formatted_message = self.create_tool_message(conversation, message)
                    conversation.append(formatted_message)
                else:
                    continue
            else:
                raise DistilabelUserError(f"Invalid role: {role}. Must be 'assistant' or 'tool'.")
        return conversations

    def _prepare_conversation_outputs(
        self, conversations: list["ChatType"], final_answers: list[str], success_indices: list[bool]
    ) -> list[dict[str, Any]]:
        """Prepare the output conversations, removing the system prompt if requested.

        Returns one dictionary per conversation with "messages", "final_answer" and "is_correct" keys.
        """
        diag = Diagnostics(model_type=ModelType(self.model_type))
        outputs: list[dict[str, Any]] = []
        for conversation, final_answer, is_correct in zip(conversations, final_answers, success_indices, strict=True):
            if conversation is None:
                raise DistilabelUserError("Conversation is None. This should not happen.")
            if len(conversation) == 0:
                # Something went wrong with the `LLM` and it didn't generate any message
                outputs.append({"messages": []})
                continue
            conv_out = list(conversation)
            # Remove any mid-conversation hint messages
            conv_out = diag.remove_hint_messages(conv_out)
            # Drop the initial system prompt if requested
            if not self.include_system_prompt and conv_out[0]["role"] == "system":
                conv_out = conv_out[1:]
            outputs.append({"messages": conv_out, "final_answer": final_answer, "is_correct": is_correct})
        return outputs

    def _sanitize_history_for_llm(self, conversations: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
        """Return a sanitized copy of conversations (each message dict copied) without THINK content in history.

        - Assistant messages: drop THINK segments; keep only the final answer text.
          If tool_calls are present, keep tool_calls and clear the content.
        - Other roles are kept as-is.
        """
        assert self.model_type == "default", "This function is only supported for the default model (Qwen3)"
        sanitized: list[list[dict[str, Any]]] = []
        for conv in conversations:
            new_conv: list[dict[str, Any]] = []
            for msg in conv:
                m = dict(msg)
                if m.get("role") == "assistant" and m.get("tool_calls"):
                    m["content"] = ""
                new_conv.append(m)
            sanitized.append(new_conv)
        return sanitized

    def _inject_hints(
        self,
        conversations: list["ChatType"],
        active_indices: list[int],
        parsed_active_msgs: list[ThoughtSchema | None],
        raw_messages: list[str | None],
    ) -> dict[int, tuple[str, str]]:
        """Track diagnostic reasons and optionally inject hint messages for malformed outputs.

        Returns a mapping from global sample index to (diagnostic reason, raw message).
        """
        reasons: dict[int, tuple[str, str]] = {}
        diag = Diagnostics(model_type=ModelType(self.model_type))
        for local_idx, parsed in enumerate(parsed_active_msgs):
            if parsed is None or (parsed.tool_call is None and not parsed.completed):
                global_idx = active_indices[local_idx]
                conv = conversations[global_idx]
                raw = raw_messages[local_idx] or ""
                hint = diag.analyze_and_build_hint(context=conv, message=raw, tool_names=self.library)
                if hint:
                    reasons[global_idx] = (hint, raw)
                    if self.enable_hint_injection:
                        # Inject a hint message into the conversation to guide the next turn
                        diag.apply_hint(conv, hint, max_hints=self.max_diagnostic_messages)
                        self._logger.info(f"Injected recovery hint for sample {global_idx}: {hint}")
                    else:
                        # Track the diagnostic only, without modifying the conversation
                        self._logger.info(f"Tracked malformed output for sample {global_idx}: {hint}")
        return reasons

    def _generate_conversation_turn(
        self, conversations: list["ChatType"], active_indices: list[int]
    ) -> tuple[list["ChatType"], list[int], tuple[Any, ...], list[ThoughtSchema | None], dict[int, tuple[str, str]]]:
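        """Run one assistant generation step over the active conversations.

        Returns the updated conversations, the indices that remain active, the per-generation
        statistics, the parsed messages aligned with the full conversations list, and a map of
        diagnostics for malformed outputs.
        """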
        # Generate an output for the conversations that are still active
        inputs = [conversations[idx] for idx in active_indices]
        outputs = self.llm.generate(
            inputs=inputs,  # TODO: Sanitizing history likely not necessary: self._sanitize_history_for_llm(inputs)
            num_generations=1,
            **self.llm.get_generation_kwargs(),
        )
        # Extract the single generated message and the statistics of each output into separate lists
        messages, statistics = zip(
            *[(output["generations"][0], output["statistics"]) for output in outputs],
            strict=True,
        )
        # Log potential truncation
        self._check_generation_truncation(statistics, active_indices)
        parser = XMLParser()
        seeded_messages: list[str | None] = [
            parser.ensure_think_prefix(msg) if msg is not None else None for msg in messages
        ]
        # The parsed messages may contain `None`s if the LLM didn't generate a message.
        contexts = [conversations[idx] for idx in active_indices]
        parsed_active_msgs = self.extract_output(list(seeded_messages), contexts=contexts)
        active_conversations = [conversations[idx] for idx in active_indices]
        updated_conversations = self._append_messages_to_conversations(
            role="assistant",
            messages=parsed_active_msgs,
            conversations=active_conversations,
        )
        for idx, conv in zip(active_indices, updated_conversations, strict=True):
            conversations[idx] = conv
        # Inject hints for failed messages to guide the next turn
        diagnostics: dict[int, tuple[str, str]] = self._inject_hints(
            conversations, active_indices, parsed_active_msgs, list(messages)
        )
        # Keep samples active even if the message is None (retry next turn)
        new_active_indices = [
            idx
            for idx, message in zip(active_indices, parsed_active_msgs, strict=True)
            if message is None or not message.completed
        ]
        # Create a full-sized array aligned with the conversations list
        full_parsed_messages: list[ThoughtSchema | None] = [None] * len(conversations)
        for active_idx, parsed_msg in zip(active_indices, parsed_active_msgs, strict=True):
            full_parsed_messages[active_idx] = parsed_msg
        return conversations, new_active_indices, statistics, full_parsed_messages, diagnostics

    def _check_generation_truncation(self, statistics: tuple[Any, ...], active_indices: list[int]) -> None:
        """Check and log potential generation truncation based on token statistics."""
        generation_kwargs = self.llm.get_generation_kwargs()
        # Use `.get` so a missing "max_new_tokens" key doesn't raise a KeyError
        max_new_tokens = generation_kwargs.get("max_new_tokens")
        if max_new_tokens is not None:
            for i, stats in enumerate(statistics):
                output_tokens = stats["output_tokens"]
                if output_tokens == max_new_tokens:
                    self._logger.warning(
                        f"Generation may have been truncated at max_new_tokens={max_new_tokens} "
                        f"for conversation {active_indices[i]} (output_tokens={output_tokens})"
                    )

    def _execute_tool_calls(
        self, conversations: list["ChatType"], active_indices: list[int], parsed_messages: list[ThoughtSchema | None]
    ) -> tuple[list["ChatType"], list[int], dict[int, str], dict[int, int]]:
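        """Execute the valid tool calls among the active conversations and append the tool responses.

        Returns the updated conversations, the still-active indices, the per-sample tool-call
        statistics, and the per-sample count of tool execution errors.
        """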
        # Filter only messages that have a valid tool call to execute
        exec_indices: list[int] = []
        exec_messages: list[ThoughtSchema] = []
        exec_conversations: list["ChatType"] = []
        for idx in active_indices:
            message = parsed_messages[idx]
            if message is None or message.tool_call is None:
                # Skip malformed or tool-less messages; the conversation remains active for the next turn
                continue
            exec_indices.append(idx)
            exec_messages.append(message)
            exec_conversations.append(conversations[idx])
        if not exec_indices:
            return conversations, active_indices, {}, {}
        active_results, tool_statistics = self._execute(inputs=exec_messages, active_indices=exec_indices)
        updated_conversations = self._append_messages_to_conversations(
            role="tool",
            messages=active_results,
            conversations=exec_conversations,
        )
        for idx, conv in zip(exec_indices, updated_conversations, strict=True):
            conversations[idx] = conv
        # Determine tool errors by inspecting the execution_result content
        error_counts: dict[int, int] = {}
        for local_pos, res in enumerate(active_results):
            if str(res.get("execution_result", "")).startswith("ERROR:"):
                global_idx = exec_indices[local_pos]
                error_counts[global_idx] = error_counts.get(global_idx, 0) + 1
        return conversations, active_indices, tool_statistics, error_counts

    def _execute(
        self, inputs: list[ThoughtSchema], active_indices: list[int]
    ) -> tuple[list[dict[str, str]], dict[int, str]]:
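        """Dispatch each tool call to the corresponding function in the library and collect the results.

        Errors are returned as "ERROR: ..." strings instead of raising, so a failing tool never
        stops the pipeline. Also returns a map from sample index to the called function name.
        """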
        results: list[dict[str, str]] = []
        statistics: dict[int, str] = {}
        for data, idx in zip(inputs, active_indices, strict=True):
            if data.tool_call is None:
                # Should not happen due to upstream filtering; skip safely
                continue
            name = data.tool_call.name
            arguments = data.tool_call.arguments
            # Track tool call frequency for this specific input
            statistics[idx] = name
            try:
                lib_functions = get_lib()
                if name not in lib_functions:
                    results.append({
                        "function_name": name,
                        "execution_result": f"ERROR: Function '{name}' not found in library",
                    })
                    continue
                result = lib_functions[name](**arguments)
                results.append({"function_name": name, "execution_result": str(result)})
            except Exception as exc:
                # Avoid stopping the pipeline on tool errors; propagate the error text instead
                results.append({
                    "function_name": name,
                    "execution_result": f"ERROR: {type(exc).__name__}: {exc}",
                })
        return results, statistics

    def extract_output(self, messages: list[str | None], contexts: list[list[dict]]) -> list[ThoughtSchema | None]:
        """Extract the structured output from the messages."""
        result: list[ThoughtSchema | None] = []
        for ctx, message in zip(contexts, messages, strict=True):
            # Default placeholder to avoid None entries
            placeholder = ThoughtSchema(thought="", tool_call=None, final_answer=None, completed=False)
            if message is None:
                result.append(placeholder)
                continue
            parsed: ThoughtSchema | None = None
            if self.structured_output:
                parsed, _ = self.extract_structured_output(message, context=ctx)
            else:
                parsed = self.extract_non_structured_output(message, context=ctx)
            result.append(parsed if parsed is not None else placeholder)
        return result

    def extract_structured_output(self, message: str, context: list[dict]) -> tuple[ThoughtSchema | None, str]:
        """Extract output from messages that enforce structured output."""
        parser = XMLParser()
        try:
            result = ThoughtSchema.model_validate_json(message)
        except ValidationError as e:
            return None, f"malformed JSON: {e!s}"
        # Enforce is_valid_think_then_tool_or_answer
        if result.tool_call is None and result.final_answer is None:
            return None, "missing tool_call or final_answer"
        # Enforce answer_policy_valid
        if result.final_answer is not None and not parser._has_tool_calls(context):
            return None, "answer without tool response"
        # Reject messages where both a tool call and an answer appear. In the
        # unstructured setting, is_valid_think_then_tool_or_answer already enforces
        # that tool_call and final_answer can never both be present at the same time.
        if result.tool_call is not None and result.final_answer is not None:
            return None, "both tool_call and final_answer present"
        return result, "ok"

    def extract_non_structured_output(self, message: str, context: list[dict]) -> ThoughtSchema | None:
        """Extract output from messages that do not enforce structured output."""
        parser = XMLParser()
        analysis = parser.analyze_message_in_context(context, message=message, tool_names=self.library)
        if self.strict_format and not bool(analysis["is_valid_think_then_tool_or_answer"]):
            return None
        if analysis["has_answer"] and not bool(analysis["answer_policy_valid"]):
            return None
        thought = analysis["thought"] or ""
        # Enforce a single tool call per turn: take only the last tool block
        tool_call: FunctionInvocationInfo | None = None
        tool_info = analysis["tool"]
        if tool_info and tool_info["json_valid"]:
            tool_call = FunctionInvocationInfo(
                name=str(tool_info["name"]),
                arguments=dict(tool_info["arguments"]),
            )
        # Mark completion based on the presence of an answer
        answer = analysis["answer"]
        return ThoughtSchema(
            thought=thought,
            tool_call=tool_call,
            final_answer=answer,
            completed=answer is not None,
        )

    def _diagnose_structured_output(self, msg: str | None, context: list[dict]) -> str:
        """Heuristic diagnosis of why a generation is unusable, to guide recovery."""
        if msg is None or not str(msg).strip():
            return "empty generation"
        _, reason = self.extract_structured_output(msg, context)
        return reason

    def _diagnose_unstructured_output(self, msg: str | None, context: list[dict]) -> str:
        """Heuristic diagnosis of why a generation is unusable, to guide recovery."""
        if msg is None or not str(msg).strip():
            return "empty generation"
        parser = XMLParser()
        msg = parser.ensure_think_prefix(msg) or ""
        analysis = parser.analyze_message_in_context(context, message=msg, tool_names=self.library)
        return parser.get_analysis_failure_reason(analysis, tool_names=self.library)

    def _generate_multi_turn_conversation(  # noqa: C901
        self, inputs: list[dict[str, Any]]
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[dict[str, int]]]:
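        """Drive the full generate -> parse -> execute-tools loop for up to `n_turns` turns.

        Returns the prepared conversation outputs, one per-sample statistics dictionary per
        input, and the merged per-sample tool-call counts.
        """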
        conversations = self._prepare_inputs_for_instruction_generation(inputs)
        # Keep track of the active conversations, as it could happen that for some conversations
        # we can't generate the next turn because the `LLM` returned `None`.
        active_indices = list(range(len(conversations)))
        stats_gen: list[dict[int, Any]] = []
        stats_tools: list[dict[int, str]] = []
        final_answers: list[str | None] = [None] * len(conversations)
        malformed_counts: list[int] = [0] * len(conversations)
        tool_errors_total: list[int] = [0] * len(conversations)
        diagnostics_reasons: list[list[str]] = [[] for _ in range(len(conversations))]
        diagnostics_messages: list[list[str]] = [[] for _ in range(len(conversations))]
        for i in range(self.n_turns):
            if not active_indices:
                break
            # Store the current active indices before they get updated
            current_active_indices = active_indices.copy()
            # Generate the assistant-tool interaction
            (
                conversations,
                active_indices,
                statistics_generation,
                parsed_messages,
                diagnostics_map,
            ) = self._generate_conversation_turn(
                conversations=conversations,
                active_indices=active_indices,
            )
            # Use the original active indices to access the full parsed messages array
            for idx in current_active_indices:
                message = parsed_messages[idx]
                if message and message.completed:
                    final_answers[idx] = message.final_answer
                if message is None:
                    malformed_counts[idx] += 1
            # Record diagnostics for this turn
            for idx, (reason, raw) in diagnostics_map.items():
                diagnostics_reasons[idx].append(f"turn {i}: {reason}")
                diagnostics_messages[idx].append(raw)
            if i == (self.n_turns - 1):
                # Map the tuple of stats back to global indices
                per_turn = dict(zip(current_active_indices, statistics_generation, strict=True))
                stats_gen.append(per_turn)
                break
            if not active_indices:
                break
            # Execute only the valid tool calls among the generated assistant messages
            conversations, active_indices, statistics_tools, error_counts = self._execute_tool_calls(
                conversations=conversations, parsed_messages=parsed_messages, active_indices=active_indices
            )
            stats_tools.append(statistics_tools)
            per_turn = dict(zip(current_active_indices, statistics_generation, strict=True))
            stats_gen.append(per_turn)
            # Log tool errors, if any
            if error_counts:
                self._logger.warning(f"Tool errors in turn {i}: {error_counts}")
                for idx, cnt in error_counts.items():
                    tool_errors_total[idx] += cnt
        # Convert None values to empty strings
        final_answers_clean = [answer if answer is not None else "" for answer in final_answers]
        success_indices = self.check_final_answers(final_answers_clean, inputs)
        self._logger.info(
            f"Multi-turn conversation completed. Success rate: {sum(success_indices)}/{len(success_indices)}"
        )
        # Build a per-sample generation stats list aligned to the inputs (a list of per-turn stats)
        gen_stats_list: list[list[Any]] = []
        for i in range(len(inputs)):
            per_sample_stats = [per_turn[i] for per_turn in stats_gen if i in per_turn]
            gen_stats_list.append(per_sample_stats)
        merged_stats_tools = self.merge_tool_stats(stats_tools, inputs=inputs)
        return (
            self._prepare_conversation_outputs(conversations, final_answers_clean, success_indices),
            [
                {
                    "gen_stats": gen_stats_list[i],
                    "malformed_turns": malformed_counts[i],
                    "tool_errors": tool_errors_total[i],
                    "diagnostics": diagnostics_reasons[i],
                    "diagnostic_messages": diagnostics_messages[i],
                }
                for i in range(len(inputs))
            ],
            merged_stats_tools,
        )

    def check_final_answers(self, final_answers: list[str], inputs: list[dict[str, Any]]) -> list[bool]:
        """Check whether each final answer matches the ground truth for its input."""
        is_correct = []
        for _input, lib_result in zip(inputs, final_answers, strict=True):
            ground_truth = json.loads(_input["ground_truth"])
            is_correct.append(verify_answers(ground_truth, parse_string(lib_result)))
        return is_correct

    def merge_tool_stats(
        self, batch_stats: list[dict[int, str]], inputs: list[dict[str, Any]]
    ) -> list[dict[str, int]]:
        """Merge the per-turn tool stats into one function-name -> call-count dictionary per input."""
        merged_stats: list[dict[str, int]] = [{} for _ in range(len(inputs))]
        for turn_stats in batch_stats:
            for active_index, fn_name in turn_stats.items():
                merged_stats[active_index][fn_name] = merged_stats[active_index].get(fn_name, 0) + 1
        return merged_stats

    def _prepare_diagnostics_lists(self, stats_gen: dict[str, Any]) -> tuple[list[str], list[str]]:
        """Prepare diagnostics and diagnostic_messages lists with safe filtering."""
        diagnostics_list = [str(v) for v in (stats_gen.get("diagnostics", []) or []) if v is not None and str(v) != ""]
        diagnostic_msgs_list = [
            str(v) for v in (stats_gen.get("diagnostic_messages", []) or []) if v is not None and str(v) != ""
        ]
        return diagnostics_list, diagnostic_msgs_list

    def _generate_with_pre_query_template(self, inputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Generate a list of instructions or conversations of the specified number of turns."""
        outputs, statistics_gens, statistics_tools = self._generate_multi_turn_conversation(inputs)
        generations = []
        for input_data, output, stats_gen, stats_tools in zip(
            inputs, outputs, statistics_gens, statistics_tools, strict=True
        ):
            generation = {
                **input_data,
                **output,
                "model_name": self.llm.model_name,
            }
            # Ensure stats_tools has at least one field to avoid Parquet serialization errors
            tool_stats = stats_tools if isinstance(stats_tools, dict) and stats_tools else {"_empty": 0}
            # Prepare diagnostics lists with safe filtering
            diagnostics_list, diagnostic_msgs_list = self._prepare_diagnostics_lists(stats_gen)
            generation["distilabel_metadata"] = {
                f"statistics_gen_{self.name}": stats_gen.get("gen_stats", []),
                f"statistics_tools_{self.name}": tool_stats,
                f"malformed_turns_{self.name}": int(stats_gen.get("malformed_turns", 0) or 0),
                f"tool_errors_{self.name}": int(stats_gen.get("tool_errors", 0) or 0),
                f"tool_calls_total_{self.name}": int(sum(tool_stats.values())) if isinstance(tool_stats, dict) else 0,
                f"diagnostics_{self.name}": (diagnostics_list if diagnostics_list else [""]),
                f"diagnostic_messages_{self.name}": (diagnostic_msgs_list if diagnostic_msgs_list else [""]),
            }
            generations.append(generation)
        return generations
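

# Usage sketch (illustrative only; the subclass and variable names below are assumptions,
# not part of this module). A concrete distilabel task is expected to mix this base class
# with a task/step class that provides `self.name`, and to supply the required fields:
#
#     class MultiTurnWithToolUse(MultiTurnWithToolUseBase, SomeDistilabelTask):
#         ...
#
#     task = MultiTurnWithToolUse(
#         llm=some_llm,                    # any distilabel `LLM` implementation
#         n_turns=6,
#         library=list(get_lib().keys()),  # tool names the parser should accept
#         model_type="default",            # must be a valid `ModelType` value
#     )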