Spaces:
Running on Zero
Running on Zero
| import json | |
| import uuid | |
| from abc import ABC, abstractmethod | |
| from enum import Enum | |
| from typing import TYPE_CHECKING, Any | |
# distilabel is an optional dependency: prefer its native error class so callers
# catching DistilabelUserError interoperate with the real library when present.
try:
    from distilabel.errors import DistilabelUserError
except ImportError:

    class DistilabelUserError(ValueError):
        """Stand-in for distilabel's user error when distilabel is not installed."""
# At runtime a chat message is just a plain dict; the precise distilabel type is
# only needed by the type checker, so import it lazily under TYPE_CHECKING.
if not TYPE_CHECKING:
    ChatType = dict[str, Any]
else:
    from distilabel.typing import ChatType
| from linalg_zero.distillation.data import ThoughtSchema | |
| from linalg_zero.shared.system_prompts import ( | |
| ANSWER_CLOSE, | |
| ANSWER_OPEN, | |
| THINK_CLOSE, | |
| THINK_OPEN, | |
| ) | |
# NOTE: removed a leftover empty `if TYPE_CHECKING: pass` guard — it was dead code
# (a second, body-less TYPE_CHECKING block after the one defining ChatType above).

# Prefix for diagnostic log/output lines emitted by this module.
DIAG_PREFIX = "[diag]"
class ModelType(str, Enum):
    """Supported model families, each mapping to a parameter preset."""

    DEFAULT = "default"

    def get_model_parameters(self) -> "ModelParameters":
        """Return the generation-parameter preset associated with this model type."""
        # Only one preset exists today; every member resolves to the default config.
        return DefaultConfig()
| class ModelParameters(ABC): | |
| def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]: | |
| """Inject recommended generation defaults for the model. | |
| deterministic=True should configure sampling deterministically (e.g., temperature=0, top_p=1). | |
| """ | |
| raise NotImplementedError | |
| def format_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None: | |
| """Return an OpenAI-compatible assistant message for the given parsed output.""" | |
| raise NotImplementedError | |
| def append_policy(self) -> bool: | |
| """Return whether to append the assistant message to the conversation.""" | |
| pass | |
class DefaultConfig(ModelParameters):
    """Default generation settings and message formatting (ModelType.DEFAULT)."""

    def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]:
        """Inject recommended generation defaults, mutating and returning generation_kwargs.

        deterministic=True forces greedy-style sampling; otherwise high-quality
        sampling defaults are applied only where the caller has not set a value.
        """
        if deterministic:
            generation_kwargs["temperature"] = 0.0
            # Fixed: was 0.95, contradicting the base-class contract of top_p=1 for
            # deterministic decoding (nucleus sampling is moot at temperature 0, but
            # 1.0 makes the intent explicit and matches the documented contract).
            generation_kwargs["top_p"] = 1.0
        else:
            # Recommended non-deterministic defaults for high-quality generations.
            # Aligns with Qwen best practices while remaining backend-agnostic.
            generation_kwargs.setdefault("temperature", 0.6)
            generation_kwargs.setdefault("top_p", 0.95)
            # Some backends (e.g., vLLM) accept additional sampling params via extra_body.
            extra_body = generation_kwargs.setdefault("extra_body", {})
            extra_body.setdefault("top_k", 20)
            extra_body.setdefault("min_p", 0)
        return generation_kwargs

    def append_policy(self) -> bool:
        """Return whether to append the assistant message to the conversation."""
        return False

    def format_assistant_message(self, message: "ThoughtSchema") -> dict[str, Any] | None:
        """Build an OpenAI-compatible assistant message from a parsed model output.

        Returns a final-answer message when message.completed is True, a tool-call
        message when message.tool_call is set, and None otherwise.

        Raises:
            DistilabelUserError: if completed=True but final_answer is None.
        """
        if message.completed:
            if message.final_answer is None:
                raise DistilabelUserError("final_answer cannot be None when completed=True")
            return {
                "role": "assistant",
                "content": f"{THINK_OPEN}{message.thought}{THINK_CLOSE}\n\n{ANSWER_OPEN}{message.final_answer}{ANSWER_CLOSE}",
            }
        if message.tool_call is not None:
            return {
                "role": "assistant",
                "content": THINK_OPEN + message.thought + THINK_CLOSE,
                "tool_calls": [
                    {
                        # Synthetic id so the subsequent tool response can reference this call.
                        "id": str(uuid.uuid4()),
                        "type": "function",
                        "function": {
                            "name": message.tool_call.name,
                            "arguments": json.dumps(message.tool_call.arguments),
                        },
                    }
                ],
            }
        return None

    def create_tool_message(self, conversation: list["ChatType"], message: dict[str, Any]) -> dict[str, Any]:
        """Build the tool-role response message for the most recent tool call.

        NOTE: Finds the last assistant message with tool calls. This only works for
        single-turn tool calls; if we transition to multiple calls per turn, we must
        match by name or position.

        Raises:
            DistilabelUserError: if no assistant message with tool_calls exists.
        """
        tool_call_id = None
        for msg in reversed(conversation):
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                tool_call_id = msg.get("tool_calls", [{}])[0].get("id")
                break
        if tool_call_id is None:
            raise DistilabelUserError("No assistant message with tool_calls found for tool response")
        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": message["function_name"],
            "content": message["execution_result"],
        }