File size: 4,484 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import json
import uuid
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any

try:
    from distilabel.errors import DistilabelUserError
except ImportError:
    # Fallback so this module stays importable when the optional `distilabel`
    # dependency is not installed; subclassing ValueError keeps generic
    # `except ValueError` handlers working.
    class DistilabelUserError(ValueError):
        """Fallback error type for runtime paths that do not install distilabel."""


if TYPE_CHECKING:
    from distilabel.typing import ChatType
else:
    # At runtime the precise alias from `distilabel.typing` is unavailable;
    # approximate a chat message as a plain dict for isinstance-free use.
    ChatType = dict[str, Any]

from linalg_zero.distillation.data import ThoughtSchema
from linalg_zero.shared.system_prompts import (
    ANSWER_CLOSE,
    ANSWER_OPEN,
    THINK_CLOSE,
    THINK_OPEN,
)

# Removed a redundant empty `if TYPE_CHECKING: pass` guard; the guard near the
# top of the file already handles type-checking-only imports.


DIAG_PREFIX = "[diag]"


class ModelType(str, Enum):
    """Known model families, each mapped to a generation-parameter profile."""

    DEFAULT = "default"

    def get_model_parameters(self) -> "ModelParameters":
        """Return the parameter profile associated with this model type."""
        # Dispatch table keeps adding new model types a one-line change.
        profile_factories: dict[ModelType, type[ModelParameters]] = {
            ModelType.DEFAULT: DefaultConfig,
        }
        return profile_factories[self]()


class ModelParameters(ABC):
    """Strategy interface describing per-model generation behaviour.

    Concrete subclasses encapsulate the sampling defaults and the
    message-formatting rules for a given model family.
    """

    @abstractmethod
    def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]:
        """Inject recommended generation defaults for the model.

        Args:
            generation_kwargs: Mutable kwargs dict, updated in place and
                returned for convenience.
            deterministic: When True, configure sampling deterministically
                (e.g., temperature=0, top_p=1).

        Returns:
            The (mutated) ``generation_kwargs`` dict.
        """
        raise NotImplementedError

    @abstractmethod
    def format_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
        """Return an OpenAI-compatible assistant message for the given parsed output."""
        raise NotImplementedError

    @abstractmethod
    def append_policy(self) -> bool:
        """Return whether to append the assistant message to the conversation."""
        # Consistent with the sibling abstract methods: raise instead of a
        # silent `pass` so an erroneous super() call fails loudly.
        raise NotImplementedError


class DefaultConfig(ModelParameters):
    """Backend-agnostic default generation profile (Qwen-style sampling)."""

    def set_recommended_defaults(self, generation_kwargs: dict[str, Any], *, deterministic: bool) -> dict[str, Any]:
        """Inject recommended sampling defaults, mutating and returning the dict.

        deterministic=True forces greedy decoding per the ModelParameters
        contract (temperature=0, top_p=1); otherwise high-quality sampling
        defaults are applied only where the caller has not already set a value.
        """
        if deterministic:
            # Assign (not setdefault) so deterministic runs are reproducible
            # even when the caller supplied sampling values. top_p=1.0 matches
            # the documented contract; with temperature=0 nucleus sampling is
            # a no-op anyway, so this only corrects the reported config.
            generation_kwargs["temperature"] = 0.0
            generation_kwargs["top_p"] = 1.0
        else:
            # Recommended non-deterministic defaults for high-quality generations.
            # Aligns with Qwen best practices while remaining backend-agnostic.
            generation_kwargs.setdefault("temperature", 0.6)
            generation_kwargs.setdefault("top_p", 0.95)

            # Some backends (e.g., vLLM) accept additional sampling params.
            extra_body = generation_kwargs.setdefault("extra_body", {})
            extra_body.setdefault("top_k", 20)
            extra_body.setdefault("min_p", 0)
        return generation_kwargs

    def append_policy(self) -> bool:
        """Return whether to append the assistant message to the conversation."""
        return False

    def format_assistant_message(self, message: ThoughtSchema) -> dict[str, Any] | None:
        """Build an OpenAI-style assistant message from a parsed model thought.

        Returns a final-answer message when ``message.completed`` is True, a
        tool-call message when ``message.tool_call`` is present, and ``None``
        otherwise.

        Raises:
            DistilabelUserError: If ``completed`` is True but ``final_answer``
                is None.
        """
        if message.completed:
            if message.final_answer is None:
                raise DistilabelUserError("final_answer cannot be None when completed=True")
            return {
                "role": "assistant",
                "content": f"{THINK_OPEN}{message.thought}{THINK_CLOSE}\n\n{ANSWER_OPEN}{message.final_answer}{ANSWER_CLOSE}",
            }
        if message.tool_call is not None:
            return {
                "role": "assistant",
                "content": THINK_OPEN + message.thought + THINK_CLOSE,
                "tool_calls": [
                    {
                        # Fresh id so the paired tool response can reference this call.
                        "id": str(uuid.uuid4()),
                        "type": "function",
                        "function": {
                            "name": message.tool_call.name,
                            "arguments": json.dumps(message.tool_call.arguments),
                        },
                    }
                ],
            }
        return None

    def create_tool_message(self, conversation: list[ChatType], message: dict[str, Any]) -> dict[str, Any]:
        """Create the tool-role response message for the most recent tool call.

        NOTE: Finds the last assistant message with tool calls. This only works
        for single-turn tool calls; if we transition to multiple calls per turn,
        we must match by name or position.

        Raises:
            DistilabelUserError: If no assistant message with tool_calls exists.
        """
        tool_call_id = None
        for msg in reversed(conversation):
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                tool_call_id = msg.get("tool_calls", [{}])[0].get("id")
                break

        if tool_call_id is None:
            raise DistilabelUserError("No assistant message with tool_calls found for tool response")

        return {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "name": message["function_name"],
            "content": message["execution_result"],
        }