atomwalk12's picture
initial commit
0dd6c2f
import json
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any
try:
    from distilabel.errors import DistilabelUserError
except ImportError:
    # distilabel is optional at runtime: fall back to a ValueError subclass so the
    # `raise DistilabelUserError(...)` sites in this module still work without it.
    class DistilabelUserError(ValueError):
        """Fallback error type for runtime paths that do not install distilabel."""
# ChatType is imported from distilabel only for static type checking; at runtime
# (TYPE_CHECKING is False) it is a plain alias so annotations still evaluate.
if TYPE_CHECKING:
    from distilabel.typing import ChatType
else:
    # NOTE(review): this runtime alias describes a single message dict and is used as
    # list[ChatType] for a conversation; distilabel's ChatType is presumably already a
    # list of messages, so a type checker may read list[ChatType] differently — confirm.
    ChatType = dict[str, Any]
from linalg_zero.distillation.data import ThoughtSchema
from linalg_zero.shared.system_prompts import (
ANSWER_CLOSE,
ANSWER_OPEN,
THINK_CLOSE,
THINK_OPEN,
)
# Marker string for diagnostic output.
# NOTE(review): not referenced in this section; presumably consumed elsewhere — confirm before renaming.
DIAG_PREFIX = "[diag]"
class ModelType(str, Enum):
    """Known model families, each resolvable to a ModelParameters implementation.

    Members are also ``str`` instances, so ``ModelType("default")`` and comparisons
    against plain strings work.
    """

    DEFAULT = "default"

    def get_model_parameters(self) -> "ModelParameters":
        """Return a newly constructed parameters object for this model type.

        Every member currently resolves to ``DefaultConfig``.
        """
        parameters = DefaultConfig()
        return parameters
class ModelParameters(ABC):
    """Interface for per-model generation defaults and assistant-message formatting."""

    @abstractmethod
    def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]:
        """Inject recommended generation defaults for the model.

        deterministic=True should configure sampling deterministically (e.g., temperature=0, top_p=1).
        Implementations may update ``generation_kwargs`` in place and return the resulting dict.
        """
        raise NotImplementedError

    @abstractmethod
    def format_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
        """Return an OpenAI-compatible assistant message for the given parsed output.

        Returning None signals that there is no assistant message to emit for this output.
        """
        raise NotImplementedError

    @abstractmethod
    def append_policy(self) -> bool:
        """Return whether to append the assistant message to the conversation."""
        # Consistent with the other abstract methods: fail loudly if reached via super().
        raise NotImplementedError
class DefaultConfig(ModelParameters):
    """Default generation parameters and OpenAI-style chat message formatting.

    Assistant turns wrap the parsed thought in THINK_OPEN/THINK_CLOSE (and, on completion,
    the final answer in ANSWER_OPEN/ANSWER_CLOSE). Tool turns are linked back to the
    assistant tool call they answer via ``tool_call_id``.
    """

    def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]:
        """Fill in sampling parameters on ``generation_kwargs`` (mutated in place) and return it.

        Args:
            generation_kwargs: Keyword arguments destined for the generation backend.
            deterministic: When True, force ``temperature=0.0`` and ``top_p=1.0``, overriding
                any caller-supplied values. When False, only fill in sampling parameters the
                caller has not already set.

        Returns:
            The same ``generation_kwargs`` dict.
        """
        if deterministic:
            # Overwrite unconditionally: determinism must win over caller-supplied values.
            # top_p=1.0 disables nucleus truncation, as the ModelParameters contract specifies.
            generation_kwargs["temperature"] = 0.0
            generation_kwargs["top_p"] = 1.0
            return generation_kwargs

        # Recommended non-deterministic defaults for high-quality generations.
        # Aligns with Qwen best practices while remaining backend-agnostic.
        generation_kwargs.setdefault("temperature", 0.6)
        generation_kwargs.setdefault("top_p", 0.95)
        # Some backends (e.g., vLLM) accept additional sampling params via extra_body.
        # An explicit extra_body=None is treated like a missing key instead of crashing
        # on None.setdefault.
        if generation_kwargs.get("extra_body") is None:
            generation_kwargs["extra_body"] = {}
        extra_body = generation_kwargs["extra_body"]
        extra_body.setdefault("top_k", 20)
        extra_body.setdefault("min_p", 0)
        return generation_kwargs

    def append_policy(self) -> bool:
        """Return whether to append the assistant message to the conversation.

        This configuration always returns False.
        """
        return False

    def format_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
        """Build an OpenAI-compatible assistant message from a parsed model output.

        Returns:
            - If ``message.completed``: a message whose content is the wrapped thought,
              a blank line, then the wrapped final answer.
            - Else, if ``message.tool_call`` is set: a message with the wrapped thought and
              a single ``tool_calls`` entry (fresh UUID id, JSON-encoded arguments).
            - Otherwise: None.

        Raises:
            DistilabelUserError: If ``message.completed`` is True but ``final_answer`` is None.
        """
        if message.completed:
            if message.final_answer is None:
                raise DistilabelUserError("final_answer cannot be None when completed=True")
            return {
                "role": "assistant",
                "content": f"{THINK_OPEN}{message.thought}{THINK_CLOSE}\n\n{ANSWER_OPEN}{message.final_answer}{ANSWER_CLOSE}",
            }
        if message.tool_call is not None:
            # NOTE(review): string concatenation assumes message.thought is a str here.
            return {
                "role": "assistant",
                "content": THINK_OPEN + message.thought + THINK_CLOSE,
                "tool_calls": [
                    {
                        # Random id; create_tool_message echoes it back as tool_call_id.
                        "id": str(uuid.uuid4()),
                        "type": "function",
                        "function": {
                            "name": message.tool_call.name,
                            # OpenAI-style APIs carry tool arguments as a JSON-encoded string.
                            "arguments": json.dumps(message.tool_call.arguments),
                        },
                    }
                ],
            }
        # Neither a final answer nor a tool call: nothing to emit.
        return None

    def create_tool_message(self, conversation: list[ChatType], message: dict[str, Any]) -> dict[str, Any]:
        """Build a ``tool`` message answering the latest assistant tool call.

        The conversation is scanned from the end for the newest assistant message that
        carries ``tool_calls``. Within that message, the call whose function name equals
        ``message["function_name"]`` is chosen. Duplicate names resolve to the first
        match, and if no name matches, the first call is used (the only one in
        single-call turns).

        Args:
            conversation: Chat messages so far, newest last.
            message: Tool execution record; ``function_name`` and ``execution_result`` are read.

        Returns:
            An OpenAI-compatible ``tool`` message whose ``tool_call_id`` links it to that call.

        Raises:
            DistilabelUserError: If no assistant message with tool calls exists, or the
                selected tool call has no id.
            KeyError: If ``message`` lacks ``function_name`` or ``execution_result``.
        """
        last_call_msg = next(
            (msg for msg in reversed(conversation) if msg.get("role") == "assistant" and msg.get("tool_calls")),
            None,
        )
        if last_call_msg is None:
            raise DistilabelUserError("No assistant message with tool_calls found for tool response")

        tool_calls = last_call_msg["tool_calls"]
        function_name = message.get("function_name")
        # Match by function name so multi-call turns link each result to its own call;
        # otherwise keep the historical behaviour of answering the first call.
        selected = next(
            (call for call in tool_calls if (call.get("function") or {}).get("name") == function_name),
            tool_calls[0],
        )
        tool_call_id = selected.get("id")
        if tool_call_id is None:
            raise DistilabelUserError("Assistant tool call has no id to reference in the tool response")

        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": message["function_name"],
            "content": message["execution_result"],
        }