|
|
"""HuggingFace Chat Client adapter for Microsoft Agent Framework. |
|
|
|
|
|
This client enables the use of HuggingFace Inference API (including the free tier) |
|
|
as a backend for the agent framework, allowing "Advanced Mode" to work without |
|
|
an OpenAI API key. |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
from collections.abc import AsyncIterable, MutableSequence |
|
|
from functools import partial |
|
|
from typing import Any, cast |
|
|
|
|
|
import structlog |
|
|
from agent_framework import ( |
|
|
BaseChatClient, |
|
|
ChatMessage, |
|
|
ChatOptions, |
|
|
ChatResponse, |
|
|
ChatResponseUpdate, |
|
|
FinishReason, |
|
|
Role, |
|
|
) |
|
|
from agent_framework._middleware import use_chat_middleware |
|
|
from agent_framework._tools import use_function_invocation |
|
|
from agent_framework._types import FunctionCallContent, FunctionResultContent |
|
|
from agent_framework.observability import use_observability |
|
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
from src.middleware import RetryMiddleware, TokenTrackingMiddleware |
|
|
from src.utils.config import settings |
|
|
|
|
|
logger = structlog.get_logger() |
|
|
|
|
|
|
|
|
@use_function_invocation
@use_observability
@use_chat_middleware
class HuggingFaceChatClient(BaseChatClient):
    """Adapter for HuggingFace Inference API with full function calling support.

    Framework messages and tools are translated into the OpenAI-style dictionaries
    passed to ``InferenceClient.chat_completion``. The client is synchronous, so
    every blocking call is pushed onto a worker thread with ``asyncio.to_thread``.
    """

    def __init__(
        self,
        model_id: str | None = None,
        api_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the HuggingFace chat client.

        Args:
            model_id: The HuggingFace model ID (default: configured value or Qwen2.5-7B).
            api_key: HF_TOKEN (optional, falls back to ``settings.hf_token``).
            **kwargs: Additional arguments passed to BaseChatClient.
        """
        middleware = [
            RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
            TokenTrackingMiddleware(),
        ]

        super().__init__(middleware=middleware, **kwargs)

        self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
        self.api_key = api_key or settings.hf_token

        self._client = InferenceClient(
            model=self.model_id,
            token=self.api_key,
            timeout=60,
        )
        logger.info("Initialized HuggingFaceChatClient", model=self.model_id)

    @staticmethod
    def _tool_call_to_dict(call: FunctionCallContent) -> dict[str, Any]:
        """Render one framework function call as an OpenAI-style ``tool_calls`` entry."""
        arguments = call.arguments
        if not isinstance(arguments, str):
            # The wire format carries arguments as a JSON-encoded string.
            arguments = json.dumps(arguments)
        return {
            "id": call.call_id,
            "type": "function",
            "function": {"name": call.name, "arguments": arguments},
        }

    @staticmethod
    def _stringify_result(result: Any) -> str:
        """Turn a tool result into message content: ``None`` -> ``""``, str as-is, else JSON."""
        if result is None:
            return ""
        if isinstance(result, str):
            return result
        return json.dumps(result)

    def _convert_messages(self, messages: MutableSequence[ChatMessage]) -> list[dict[str, Any]]:
        """Convert framework messages to HuggingFace format.

        Each ``FunctionResultContent`` becomes its own ``role="tool"`` message, so a
        single framework message carrying several tool results yields several
        HuggingFace messages. Previously only the last result in a message survived.

        Args:
            messages: Conversation history in framework types.

        Returns:
            A list of ``{"role", "content", ...}`` dictionaries. Messages containing
            function calls get ``tool_calls``; tool messages get ``tool_call_id`` and,
            when the originating call was seen earlier in ``messages``, ``name``.

        Raises:
            TypeError: If a non-string tool result or call argument is not JSON
                serializable (propagated from ``json.dumps``).
        """
        hf_messages: list[dict[str, Any]] = []

        # Tool results only carry a call id; remember the function name announced by
        # the matching call so it can be attached to the tool message.
        call_id_to_name: dict[str, str] = {}

        for msg in messages:
            # The role may be an enum-like object exposing ``.value`` or a plain string.
            role_str = str(getattr(msg.role, "value", msg.role))

            tool_calls: list[dict[str, Any]] = []
            tool_results: list[FunctionResultContent] = []
            for item in msg.contents or []:
                if isinstance(item, FunctionCallContent):
                    call_id_to_name[item.call_id] = item.name
                    tool_calls.append(self._tool_call_to_dict(item))
                elif isinstance(item, FunctionResultContent):
                    tool_results.append(item)

            # A message without tool results (or one that also carries calls) is
            # emitted under its own role with its text.
            if tool_calls or not tool_results:
                message_dict: dict[str, Any] = {"role": role_str, "content": msg.text or ""}
                if tool_calls:
                    message_dict["tool_calls"] = tool_calls
                hf_messages.append(message_dict)

            # One tool message per result keeps every call id answered.
            for result in tool_results:
                tool_message: dict[str, Any] = {
                    "role": "tool",
                    "content": self._stringify_result(result.result),
                }
                if result.call_id:
                    tool_message["tool_call_id"] = result.call_id
                tool_name = call_id_to_name.get(result.call_id)
                if tool_name:
                    tool_message["name"] = tool_name
                hf_messages.append(tool_message)

        return hf_messages

    def _convert_tools(self, tools: list[Any] | None) -> list[dict[str, Any]] | None:
        """Convert AIFunction objects to OpenAI-compatible tool definitions.

        AIFunction.to_dict() returns:
            {'type': 'ai_function', 'name': '...', 'input_model': {...}}

        OpenAI/HuggingFace expects:
            {'type': 'function', 'function': {'name': '...', 'parameters': {...}}}

        Args:
            tools: Tool objects exposing ``to_dict()`` or ready-made dictionaries.

        Returns:
            The converted definitions, or ``None`` when there is nothing usable.
            Tools that fail conversion or have an unsupported type are logged and
            skipped rather than aborting the request.
        """
        if not tools:
            return None

        json_tools: list[dict[str, Any]] = []
        for tool in tools:
            if hasattr(tool, "to_dict"):
                try:
                    t_dict = tool.to_dict()
                    json_tools.append(
                        {
                            "type": "function",
                            "function": {
                                "name": t_dict["name"],
                                "description": t_dict.get("description", ""),
                                "parameters": t_dict["input_model"],
                            },
                        }
                    )
                except (KeyError, TypeError) as e:
                    logger.warning("Failed to convert tool", tool=str(tool), error=str(e))
            elif isinstance(tool, dict):
                # Dictionaries are passed through unchanged.
                json_tools.append(tool)
            else:
                logger.warning("Skipping non-serializable tool", tool_type=str(type(tool)))

        return json_tools if json_tools else None

    def _parse_tool_calls(self, message: Any) -> list[FunctionCallContent]:
        """Parse HuggingFace tool_calls into framework FunctionCallContent.

        Args:
            message: A completion message; it may lack ``tool_calls`` entirely.

        Returns:
            One ``FunctionCallContent`` per well-formed tool call. Malformed entries
            are logged and skipped.
        """
        contents: list[FunctionCallContent] = []

        raw_calls = getattr(message, "tool_calls", None)
        if not raw_calls:
            return contents

        for tc in raw_calls:
            try:
                contents.append(
                    FunctionCallContent(
                        call_id=tc.id,
                        name=tc.function.name,
                        arguments=tc.function.arguments,
                    )
                )
            except (AttributeError, TypeError) as e:
                logger.warning("Failed to parse tool call", error=str(e))

        return contents

    @staticmethod
    def _resolve_tool_choice(tools: list[dict[str, Any]] | None, tool_choice: Any) -> str | None:
        """Map the framework tool-choice setting to the HuggingFace ``tool_choice`` value.

        Only the "auto" mode is forwarded; anything else leaves the API default.
        """
        if not tools or tool_choice is None:
            return None
        # NOTE(review): the match is case-insensitive because the string form of the
        # framework's tool mode may be "auto" rather than "AUTO" - confirm against
        # the installed agent_framework version.
        if "auto" in str(tool_choice).lower():
            return "auto"
        return None

    def _build_completion_call(
        self,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        *,
        stream: bool,
    ) -> partial[Any]:
        """Prepare the ``chat_completion`` invocation shared by both response paths.

        Returns:
            A zero-argument callable suitable for ``asyncio.to_thread``.
        """
        hf_messages = self._convert_messages(messages)
        tools = self._convert_tools(chat_options.tools if chat_options.tools else None)
        hf_tool_choice = self._resolve_tool_choice(tools, chat_options.tool_choice)

        # Explicit ``is not None`` checks so that a legitimate 0 / 0.0 is respected.
        max_tokens = chat_options.max_tokens if chat_options.max_tokens is not None else 2048
        temperature = chat_options.temperature if chat_options.temperature is not None else 0.7

        return partial(
            self._client.chat_completion,
            messages=hf_messages,
            tools=tools,
            tool_choice=hf_tool_choice,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=stream,
        )

    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Non-streaming response generation using chat_completion.

        The blocking HuggingFace call runs on a worker thread.

        Args:
            messages: Conversation history.
            chat_options: Tools, tool choice, ``max_tokens`` and ``temperature``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``ChatResponse`` holding the first choice, or an empty response with
            id ``"error-no-choices"`` when the API returned no choices.

        Raises:
            Exception: Any error from the HuggingFace call is logged and re-raised.
        """
        call_fn = self._build_completion_call(messages, chat_options, stream=False)

        try:
            response = await asyncio.to_thread(call_fn)

            choices = response.choices
            if not choices:
                return ChatResponse(messages=[], response_id="error-no-choices")

            message = choices[0].message
            tool_call_contents = self._parse_tool_calls(message)

            response_msg = ChatMessage(
                role=cast(Any, message.role),
                text=message.content or "",
                contents=tool_call_contents if tool_call_contents else None,
            )

            return ChatResponse(
                messages=[response_msg],
                response_id=response.id or "hf-response",
            )

        except Exception as e:
            logger.error("HuggingFace API error", error=str(e))
            raise

    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Streaming response generation.

        Text deltas are yielded as they arrive. Tool-call fragments are accumulated
        per index and emitted as one final update once the stream is exhausted.

        Args:
            messages: Conversation history.
            chat_options: Tools, tool choice, ``max_tokens`` and ``temperature``.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            ``ChatResponseUpdate`` objects: one per text delta, then at most one
            carrying the assembled function calls with ``FinishReason.TOOL_CALLS``.

        Raises:
            Exception: Any error while opening or reading the stream is logged and
                re-raised.
        """
        call_fn = self._build_completion_call(messages, chat_options, stream=True)

        try:
            stream = await asyncio.to_thread(call_fn)

            # The HuggingFace stream is a synchronous iterator whose ``next`` blocks on
            # network I/O. Fetch each chunk on a worker thread so the event loop stays
            # responsive. A sentinel default is used because StopIteration cannot be
            # propagated through a Future.
            chunk_iter = iter(stream)
            exhausted = object()

            # Fragments of each tool call, keyed by the call's index in the response.
            tool_call_accumulator: dict[int, dict[str, Any]] = {}

            while True:
                chunk = await asyncio.to_thread(next, chunk_iter, exhausted)
                if chunk is exhausted:
                    break
                if not chunk.choices:
                    continue

                delta = chunk.choices[0].delta

                if delta.content:
                    yield ChatResponseUpdate(
                        role=cast(Any, delta.role) if delta.role else None,
                        text=delta.content,
                    )

                for tc in getattr(delta, "tool_calls", None) or []:
                    slot = tool_call_accumulator.setdefault(
                        tc.index, {"id": "", "name": "", "arguments": ""}
                    )
                    if tc.id:
                        slot["id"] += tc.id
                    if tc.function:
                        if tc.function.name:
                            slot["name"] += tc.function.name
                        if tc.function.arguments:
                            slot["arguments"] += tc.function.arguments

            # Emit the assembled calls in index order. Entries that never received
            # both an id and a name are incomplete and dropped.
            contents: list[FunctionCallContent] = [
                FunctionCallContent(
                    call_id=tc_data["id"],
                    name=tc_data["name"],
                    arguments=tc_data["arguments"],
                )
                for _, tc_data in sorted(tool_call_accumulator.items())
                if tc_data["id"] and tc_data["name"]
            ]

            if contents:
                yield ChatResponseUpdate(
                    contents=contents,
                    role=Role.ASSISTANT,
                    finish_reason=FinishReason.TOOL_CALLS,
                )

        except Exception as e:
            logger.error("HuggingFace Streaming error", error=str(e))
            raise
|
|
|