File size: 3,083 Bytes
f793029
4f09ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0be6708
4f09ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# ui/agent/completion.py
from __future__ import annotations

from typing import Any

from huggingface_hub import InferenceClient

from .messages import parse_text_tool_calls, parse_tool_calls


def complete_turn(
    client: InferenceClient,
    api_messages: list[dict[str, Any]],
    *,
    max_tokens: int,
    temperature: float,
    top_p: float,
    tools: list[dict[str, Any]] | None,
    tool_choice: Any = "auto",
) -> tuple[str, str, list[dict[str, Any]] | None]:
    """Run one model turn, preferring streaming and falling back to a single request.

    The request is first sent with ``stream=True`` and the streamed deltas
    (content, reasoning and tool-call fragments) are accumulated. The same
    request is re-sent without streaming only when the stream produced no
    content, no reasoning and no tool calls. Exceptions raised by either
    request are not caught here; they propagate to the caller.

    Native (structured) tool calls take precedence over tool calls parsed out
    of the text. Whenever the text parser finds tool calls in the content or
    the reasoning, both strings are returned empty.

    Args:
        client: Inference client used for both the streaming and the
            non-streaming request.
        api_messages: Chat history, passed unchanged to ``chat_completion``.
        max_tokens: Token limit forwarded to ``chat_completion``.
        temperature: Sampling temperature forwarded to ``chat_completion``.
        top_p: Nucleus-sampling value forwarded to ``chat_completion``.
        tools: Tool schemas offered to the model, or ``None``.
        tool_choice: Forwarded only when ``tools`` is truthy; ``None`` is sent
            otherwise.

    Returns:
        ``(content, reasoning, tool_calls)`` where ``tool_calls`` is the native
        tool-call list if non-empty, else whatever ``parse_text_tool_calls``
        returned.
    """
    # Shared by both requests so the streaming and fallback calls cannot drift apart.
    request_kwargs: dict[str, Any] = {
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "tools": tools,
        "tool_choice": tool_choice if tools else None,
    }

    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tool_calls_map: dict[int, dict[str, Any]] = {}

    for chunk in client.chat_completion(api_messages, stream=True, **request_kwargs):
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            content_parts.append(delta.content)
        # getattr: `reasoning` is not a declared field on every huggingface_hub
        # version's delta type, and a plain attribute access would raise.
        delta_reasoning = getattr(delta, "reasoning", None)
        if delta_reasoning:
            reasoning_parts.append(delta_reasoning)
        for tool_call in delta.tool_calls or ():
            _merge_tool_call_fragment(tool_calls_map, tool_call)

    streamed = _finalize_turn(
        "".join(content_parts),
        "".join(reasoning_parts),
        list(tool_calls_map.values()),
    )
    # Text-parsed tool calls imply a non-empty tool list, so checking the
    # finalized tuple is equivalent to checking the raw streamed values.
    if any(streamed):
        return streamed

    # The stream yielded nothing usable: retry once as a plain request.
    response = client.chat_completion(api_messages, **request_kwargs)
    assistant_msg = response.choices[0].message
    return _finalize_turn(
        assistant_msg.content or "",
        getattr(assistant_msg, "reasoning", None) or "",
        parse_tool_calls(assistant_msg),
    )


def _merge_tool_call_fragment(
    tool_calls_map: dict[int, dict[str, Any]], tool_call: Any
) -> None:
    """Fold one streamed tool-call fragment into the entry keyed by its index."""
    function = tool_call.function
    name = (function.name if function is not None else None) or ""
    arguments = (function.arguments if function is not None else None) or ""

    entry = tool_calls_map.get(tool_call.index)
    if entry is None:
        tool_calls_map[tool_call.index] = {
            "id": tool_call.id,
            "type": tool_call.type,
            "function": {"name": name, "arguments": arguments},
        }
        return

    # If the first fragment carried no id/type, take them from a later one
    # rather than keeping the empty value forever.
    if not entry["id"] and tool_call.id:
        entry["id"] = tool_call.id
    if not entry["type"] and tool_call.type:
        entry["type"] = tool_call.type
    if name:
        entry["function"]["name"] = name
    if arguments:
        # Argument JSON arrives in pieces; concatenate in arrival order.
        entry["function"]["arguments"] += arguments


def _finalize_turn(
    content: str,
    reasoning: str,
    native_tool_calls: list[dict[str, Any]] | None,
) -> tuple[str, str, list[dict[str, Any]] | None]:
    """Pick native or text-parsed tool calls and blank the text when the latter were found."""
    text_tool_calls = parse_text_tool_calls(content) or parse_text_tool_calls(reasoning)
    tool_calls = native_tool_calls or text_tool_calls
    if text_tool_calls:
        return "", "", tool_calls
    return content, reasoning, tool_calls