import html import json import logging import logging as stdlib_logging from copy import deepcopy from typing import ( Any, ) import argilla as rg from datasets import Dataset from datasets import load_dataset as hf_load_dataset from distilabel.distiset import Distiset from distilabel.models import OpenAILLM from distilabel.models.base_clients.openai import SecretStr from distilabel.models.llms.utils import prepare_output from distilabel.steps.tasks.apigen.execution_checker import load_module_from_path from distilabel.typing import FormattedInput, GenerateOutput from openai.types.chat import ChatCompletion as OpenAIChatCompletion from pydantic import BaseModel, NonNegativeInt, PositiveInt from typing_extensions import override from linalg_zero.config.data import ( DistillationConfig, LlamaCppServerConfig, VllmServerConfig, ) from linalg_zero.distillation.components.models import ( ModelParameters, ModelType, ) from linalg_zero.shared.lib import get_tools from linalg_zero.shared.system_prompts import THINK_CLOSE, THINK_OPEN from linalg_zero.shared.utils import get_libpath, get_logger, setup_logging logger = get_logger(__name__) # TODO: is this the right file to store this class in? class CustomOpenAILLM(OpenAILLM): """ Patched OpenAI LLM that supports tool calls by bypassing the restrictive validation. This allows using the full OpenAI API format with tool_calls and tool roles. """ @override async def agenerate( self, input: FormattedInput, num_generations: int = 1, max_new_tokens: NonNegativeInt = 128, logprobs: bool = False, top_logprobs: PositiveInt | None = None, echo: bool = False, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, stop: str | list[str] | None = None, response_format: dict[str, str] | None = None, extra_body: dict[str, Any] | None = None, ) -> GenerateOutput: """Override agenerate to bypass validation and support tool calls.""" if isinstance(input, str): return await self._generate_completion( input=input, num_generations=num_generations, max_new_tokens=max_new_tokens, echo=echo, top_logprobs=top_logprobs, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, temperature=temperature, top_p=top_p, extra_body=extra_body, ) return await self._generate_chat_completion( input=input, num_generations=num_generations, max_new_tokens=max_new_tokens, logprobs=logprobs, top_logprobs=top_logprobs, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, temperature=temperature, top_p=top_p, stop=stop, response_format=response_format, extra_body=extra_body, ) def _generations_from_openai_completion_a3b(self, completion: "OpenAIChatCompletion") -> "GenerateOutput": """Get the generations from the OpenAI Chat Completion object. Args: completion: the completion object to get the generations from. Returns: A list of strings containing the generated responses for the input. """ generations = [] logprobs = [] for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( f"Received no response using OpenAI client (model: '{self.model}')." f" Finish reason was: {choice.finish_reason}" ) generations.append(content) if choice_logprobs := self._get_logprobs_from_chat_completion_choice(choice): logprobs.append(choice_logprobs) statistics = self._get_llm_statistics(completion) return prepare_output( generations=generations, input_tokens=statistics["input_tokens"], output_tokens=statistics["output_tokens"], logprobs=logprobs, ) def _generations_from_openai_completion(self, completion: "OpenAIChatCompletion") -> "GenerateOutput": """Get the generations from the OpenAI Chat Completion object. Args: completion: the completion object to get the generations from. Returns: A list of strings containing the generated responses for the input. """ generations = [] logprobs = [] for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( f"Received no response using OpenAI client (model: '{self.model}')." f" Finish reason was: {choice.finish_reason}" ) if (reasoning_content := choice.message.reasoning_content) is not None: content = THINK_OPEN + reasoning_content.strip() + THINK_CLOSE + (content or "") else: content = THINK_OPEN + "\n\n" + THINK_CLOSE + (content or "") generations.append(content) if choice_logprobs := self._get_logprobs_from_chat_completion_choice(choice): logprobs.append(choice_logprobs) statistics = self._get_llm_statistics(completion) return prepare_output( generations=generations, input_tokens=statistics["input_tokens"], output_tokens=statistics["output_tokens"], logprobs=logprobs, ) def get_openai_client( model: str, base_url: str, model_type: str, timeout: int = 900, retries: int = 3, max_new_tokens: int = 8192, deterministic: bool = True, structured_output: dict[str, Any] | None = None, ) -> OpenAILLM: generation_kwargs: dict[str, Any] = {"max_new_tokens": max_new_tokens} params: ModelParameters = ModelType(model_type).get_model_parameters() generation_kwargs = params.set_recommended_defaults(generation_kwargs, deterministic=deterministic) return CustomOpenAILLM( model=model, base_url=base_url, api_key=SecretStr("not-used"), timeout=timeout, max_retries=retries, generation_kwargs=generation_kwargs, structured_output=structured_output, ) def create_llm_clients( server: LlamaCppServerConfig | VllmServerConfig, args: DistillationConfig, schema: type[BaseModel] ) -> OpenAILLM: """Create structured and non-structured LLM clients.""" base_params: dict[str, Any] = { "model": server.model, "base_url": f"http://{server.host}:{server.port}/v1", "timeout": args.timeout, "retries": args.retries, "max_new_tokens": args.max_new_tokens, "model_type": args.model_type, "deterministic": args.deterministic, } if args.structured_output: base_params["structured_output"] = {"schema": schema} else: base_params["structured_output"] = None llm = get_openai_client(**base_params) return llm def get_function_schema() -> str: """Returns the tools for function calling.""" libpath_module = load_module_from_path(get_libpath()) tools = libpath_module.get_tools() function_definitions = [tool_info["function"] for tool_info in tools] function_schema = json.dumps(function_definitions, indent=2) return function_schema def is_openai_format(messages: Any) -> bool: """Checks if the input is in OpenAI chat-like format: ```python [ {"role": "user", "content": "Turn on the living room lights."}, {"role": "assistant", "tool_calls": [ {"type": "function", "function": { "name": "control_light", "arguments": {"room": "living room", "state": "on"} }}] }, {"role": "tool", "name": "control_light", "content": "The lights in the living room are now on."}, {"role": "assistant", "content": "Done!"} ] ``` Args: input: The input to check. Returns: A boolean indicating if the input is in OpenAI chat-like format. """ if not isinstance(messages, list): return False return all(isinstance(x, dict) and "role" in x and ("content" in x or "tool_calls" in x) for x in messages) def save_distiset_to_disk(distiset: Distiset, path: str) -> None: """Save the distiset to a directory.""" distiset.save_to_disk(path) def print_statistics(distilabel_train: list[dict[str, Any]]) -> None: total_train = len(distilabel_train) train_correct = sum(1 for row in distilabel_train if row["is_correct"]) logger.info(f" Math verify successes: {train_correct}/{total_train}") def cleanup() -> None: """Cleans up logging to prevent multiprocessing queue errors.""" root_logger = stdlib_logging.getLogger() queue_handlers = [h for h in root_logger.handlers if hasattr(h, "queue")] for handler in queue_handlers: root_logger.removeHandler(handler) # Reinitialize logging setup_logging(level=logging.INFO, include_timestamp=True) def create_argilla_dataset_settings() -> rg.Settings: """Create Argilla dataset settings for linear algebra distillation results.""" return rg.Settings( guidelines="""Review and validate the model's reasoning for linear algebra problems.""", fields=[ rg.TextField( name="problem_type", title="Problem Type", use_markdown=False, ), rg.TextField( name="tool_calls", title="Number of Tool Calls Made", use_markdown=False, ), rg.TextField( name="query", title="User's Linear Algebra Problem Query", use_markdown=False, ), rg.TextField( name="is_correct", title="Is Answer Correct?", use_markdown=False, ), rg.TextField( name="ground_truth", title="Ground Truth Result", use_markdown=False, ), rg.TextField( name="stepwise_ground_truths", title="Stepwise Ground Truth Solutions", use_markdown=False, ), rg.TextField( name="final_answer", title="Model's Final Answer", use_markdown=False, ), rg.TextField( name="messages", title="Full Conversation", use_markdown=False, ), rg.TextField( name="diagnostics", title="Diagnostics (per turn)", use_markdown=False, ), rg.TextField( name="diagnostic_messages", title="Diagnostic raw messages (failed turns)", use_markdown=False, ), rg.TextField( name="composition_dependencies", title="Composition Dependencies", use_markdown=False, ), rg.TextField( name="composition_type", title="Composition Type", use_markdown=False, ), rg.TextField( name="dependency_edges", title="Dependency Edges", use_markdown=False, ), rg.TextField( name="model_name", title="Model Name Used", use_markdown=False, ), ], questions=[ rg.LabelQuestion( name="reasoning_quality", title="How would you rate the overall reasoning quality?", labels=["excellent", "good", "fair", "poor"], ), rg.LabelQuestion( name="mathematical_accuracy", title="Is the mathematical reasoning correct?", labels=["correct", "minor_errors", "major_errors", "incorrect"], ), rg.LabelQuestion( name="tool_usage", title="Are the tool calls appropriate and effective?", labels=["optimal", "good", "suboptimal", "incorrect"], ), rg.LabelQuestion( name="final_correctness", title="Is the final answer correct?", labels=["correct", "close", "wrong", "no_answer"], ), rg.TextQuestion( name="feedback", title="Additional feedback or observations", ), ], ) def _delete_existing_argilla_dataset(client: rg.Argilla, dataset_name: str) -> None: """Delete existing Argilla dataset if it exists.""" logger = get_logger(__name__) try: existing_dataset = client.datasets(name=dataset_name) if existing_dataset: existing_dataset.delete() logger.info(f"Deleted existing Argilla dataset: {dataset_name}") except Exception: logger.exception("Failed to delete existing Argilla dataset") # Dataset doesn't exist pass def _format_value(value: Any) -> Any: """Recursively format values, applying safe_str_with_xml to strings and recursing through dicts.""" if isinstance(value, dict): return {k: _format_value(v) for k, v in value.items()} else: return safe_str_with_xml(value) def _format_indexed_list(items: list[Any]) -> str: """Format a list with indexed headers and separators for better readability.""" if not items: return "" indexed_dict = [] for i, item in enumerate(items): # Skip empty or whitespace-only strings if isinstance(item, str) and not item.strip(): continue if isinstance(item, dict): indexed_dict.append({"index": i, "content": _format_value(item)}) else: indexed_dict.append({"index": i, "content": safe_str_with_xml(item)}) return json.dumps(indexed_dict, indent=2) def safe_str_with_xml(value: Any) -> str: if value is None: return "N/A" str_value = str(value) return html.escape(str_value) def _convert_item_to_argilla_record(item: dict[str, Any]) -> dict[str, str] | None: """Convert a single distillation item to an Argilla record.""" logger = get_logger(__name__) try: metadata = item.get("distilabel_metadata", {}) diagnostics_key = next((k for k in metadata if k.startswith("diagnostics_")), None) diagnostic_msgs_key = next((k for k in metadata if k.startswith("diagnostic_messages_")), None) diagnostics_list = metadata.get(diagnostics_key, []) if diagnostics_key else [] diagnostic_msgs_list = metadata.get(diagnostic_msgs_key, []) if diagnostic_msgs_key else [] query = item.get("query", "N/A") ground_truth = item.get("ground_truth", "N/A") stepwise_ground_truths = item.get("stepwise_ground_truths", "N/A") problem_type = item.get("problem_type", "N/A") composition_type = item.get("composition_type", "N/A") composition_dependencies = item.get("composition_dependencies", "N/A") messages = item.get("messages", []) dependency_edges = item.get("dependency_edges", "N/A") final_answer = item.get("final_answer", "N/A") is_correct = item.get("is_correct", "N/A") model_name = item.get("model_name", "N/A") num_tool_calls = len(json.loads(stepwise_ground_truths)) return { "query": str(query), "ground_truth": str(ground_truth), "stepwise_ground_truths": str(stepwise_ground_truths), "tool_calls": str(num_tool_calls), "problem_type": str(problem_type), "composition_type": str(composition_type), "composition_dependencies": str(composition_dependencies), "messages": _format_indexed_list(messages), "dependency_edges": str(dependency_edges), "final_answer": str(final_answer), "is_correct": str(is_correct), "model_name": str(model_name), "diagnostics": _format_indexed_list(diagnostics_list), "diagnostic_messages": _format_indexed_list(diagnostic_msgs_list), } except Exception as e: logger.warning(f"Failed to process record: {e}") return None def create_argilla_dataset( dataset_name: str, distiset_data: list[dict[str, Any]], client: rg.Argilla, private: bool ) -> None: """Create and populate an Argilla dataset from distillation results.""" logger = get_logger(__name__) try: # Delete existing dataset if it exists to ensure clean reupload # _delete_existing_argilla_dataset(client, dataset_name) # Create dataset with settings settings = create_argilla_dataset_settings() dataset = rg.Dataset( name=dataset_name, settings=settings, client=client, ) _ = dataset.create() logger.info(f"Created Argilla dataset: {dataset_name}") # Convert distilabel data to Argilla records records = [] for item in distiset_data: record = _convert_item_to_argilla_record(item) if record is not None: records.append(record) # Log records to dataset if records: dataset.records.log(records=records) logger.info(f"Logged {len(records)} records to Argilla dataset") else: logger.warning("No valid records found to log") domain = dataset_name.replace("/", "-").replace("-debug", "").replace("-train", "").replace("-validation", "") logger.info("✅ Argilla dataset created successfully") logger.info(f" Privacy: {'Private' if private else 'Public'}") logger.info(f" Access URL: https://{domain}.hf.space") except Exception: logger.exception("Failed to create Argilla dataset") def filter_dataset_by_correctness(distiset: Distiset, is_correct: bool = True) -> Distiset: """Filter dataset by is_correct flag.""" filtered_distiset = deepcopy(distiset) for split_name in filtered_distiset["default"]: split_data = filtered_distiset["default"][split_name] # Keep only correct entries for SFT training if only_correct=True, otherwise keep all filtered_data = split_data.filter(lambda x: x["is_correct"] is is_correct) filtered_distiset["default"][split_name] = filtered_data return filtered_distiset def push_to_huggingface(distiset: Distiset, dataset_name: str, private: bool) -> None: prepare_dataset_for_sft(distiset) strip_diagnostic_messages_from_metadata(distiset) normalize_schema(distiset) try: distiset.push_to_hub( dataset_name, private=private, ) logger.info(f"✅ Dataset successfully pushed to: {dataset_name}") logger.info(f" Privacy: {'Private' if private else 'Public'}") logger.info(f" Access URL: https://huggingface.co/datasets/{dataset_name}") except Exception: logger.exception("Failed to push dataset to Hugging Face Hub") def push_argilla_dataset(argilla_client: rg.Argilla, distiset: Distiset, args: DistillationConfig) -> None: success = filter_dataset_by_correctness(distiset, is_correct=True) if len(success["default"]["train"]) > 0: create_argilla_dataset( dataset_name=f"{args.argilla_output_dataset}", distiset_data=success["default"]["train"], client=argilla_client, private=args.private, ) failures = filter_dataset_by_correctness(distiset, is_correct=False) if len(failures["default"]["train"]) > 0: create_argilla_dataset( dataset_name=f"{args.argilla_output_dataset}-failures", distiset_data=failures["default"]["train"], client=argilla_client, private=args.private, ) def push_datasets_to_huggingface(distiset: Distiset, args: DistillationConfig) -> None: """Push two datasets to Hugging Face: one with all entries and one with only correct entries.""" assert args.hf_output_dataset is not None private = args.private # Push all entries dataset all_entries_name = f"{args.hf_output_dataset}-failures" logger.info(f"Pushing dataset with all entries to: {all_entries_name}") all_entries_distiset = filter_dataset_by_correctness(distiset, is_correct=False) if len(all_entries_distiset["default"]["train"]) > 0: push_to_huggingface(all_entries_distiset, all_entries_name, private) # Push correct entries only dataset correct_only_name = args.hf_output_dataset logger.info(f"Pushing dataset with correct entries only to: {correct_only_name}") correct_only_distiset = filter_dataset_by_correctness(distiset, is_correct=True) if len(correct_only_distiset["default"]["train"]) > 0: push_to_huggingface(correct_only_distiset, correct_only_name, private) def prepare_dataset_for_sft(distiset: Distiset) -> None: """Adds the tools column to the dataset.""" TOOLS = get_tools() def add_tools_column(example: dict[str, Any]) -> dict[str, Any]: example["tools"] = TOOLS return example distiset["default"]["train"] = distiset["default"]["train"].map(add_tools_column) if "validation" in distiset["default"]: distiset["default"]["validation"] = distiset["default"]["validation"].map(add_tools_column) def normalize_schema(distiset: Distiset) -> None: ns = distiset["default"] # 1) Stringify nested columns if present for split in list(ns.keys()): if "messages" in ns[split].column_names: ns[split] = ns[split].map(lambda r: {"messages": json.dumps(r.get("messages", []))}) if "distilabel_metadata" in ns[split].column_names: ns[split] = ns[split].map(lambda r: {"distilabel_metadata": json.dumps(r.get("distilabel_metadata", {}))}) # 2) Align columns by UNION: add missing columns with empty placeholders all_cols = set() for split in ns: all_cols |= set(ns[split].column_names) for split in list(ns.keys()): missing = sorted(all_cols - set(ns[split].column_names)) if missing: for col in missing: ns[split] = ns[split].add_column(col, [None] * len(ns[split])) def convert_dataset_to_list_of_dicts(dataset: Dataset) -> list[dict[str, Any]]: """Convert dataset from dict format to list of dicts.""" dataset_dict = dataset.to_dict() return [dict(zip(dataset_dict.keys(), vals, strict=True)) for vals in zip(*dataset_dict.values(), strict=True)] def load_dataset_split( dataset_name: str, dataset_config: str | None, split: str, take_n: int | None = None ) -> Dataset: """Loads a single dataset split either from the hub or from a local file.""" logger = get_logger(__name__) try: logger.info(f"Loading '{dataset_name}' (config: {dataset_config}, split: {split}) dataset.") dataset = hf_load_dataset(dataset_name, dataset_config, split=split) assert isinstance(dataset, Dataset) logger.info("Dataset loaded!") except Exception as err: raise FileNotFoundError(f"The dataset {dataset_name} is not available on the Hugging Face Hub.") from err else: if take_n is not None: dataset = dataset.select(range(take_n)) return dataset def load_datasets_for_distillation(args: DistillationConfig) -> dict[str, list[dict[str, Any]]]: """Loads train and optionally validation splits as lists of dicts.""" take_n = args.take_n datasets: dict[str, list[dict[str, Any]]] = {} if args.dataset_name is None: raise ValueError("dataset_name must be provided") # TODO: Remove once done debugging runpod # if not args.debug_mode: dataset = load_dataset_split(args.dataset_name, args.dataset_config, "train", take_n=take_n) datasets["train"] = convert_dataset_to_list_of_dicts(dataset) # else: # failures_dataset = load_dataset_split(f"{args.hf_output_dataset}-failures", args.dataset_config, "train") # datasets["train"] = convert_dataset_to_list_of_dicts(failures_dataset) return datasets def strip_diagnostic_messages_from_metadata(distiset: Distiset) -> None: """Remove diagnostic_messages_* keys from distilabel_metadata for all splits (before HF push).""" ns = distiset["default"] def strip_md(record: dict[str, Any]) -> dict[str, Any]: md = record.get("distilabel_metadata", {}) if isinstance(md, dict): keys_to_remove = [k for k in md if k.startswith("diagnostic_messages_")] if keys_to_remove: for k in keys_to_remove: md.pop(k, None) return {"distilabel_metadata": md} return {"distilabel_metadata": md} for split in list(ns.keys()): if "distilabel_metadata" in ns[split].column_names: ns[split] = ns[split].map(strip_md)