| """HuggingFace Chat Client adapter for Microsoft Agent Framework. |
| |
| This client enables the use of HuggingFace Inference API (including the free tier) |
| as a backend for the agent framework, allowing "Advanced Mode" to work without |
| an OpenAI API key. |
| """ |
|
|
| import asyncio |
| import json |
| from collections.abc import AsyncIterable, MutableSequence |
| from functools import partial |
| from typing import Any, cast |
|
|
| import structlog |
| from agent_framework import ( |
| BaseChatClient, |
| ChatMessage, |
| ChatOptions, |
| ChatResponse, |
| ChatResponseUpdate, |
| FinishReason, |
| Role, |
| ) |
| from agent_framework._middleware import use_chat_middleware |
| from agent_framework._tools import use_function_invocation |
| from agent_framework._types import FunctionCallContent, FunctionResultContent |
| from agent_framework.observability import use_observability |
| from huggingface_hub import InferenceClient |
|
|
| from src.middleware import RetryMiddleware, TokenTrackingMiddleware |
| from src.utils.config import settings |
|
|
| logger = structlog.get_logger() |
|
|
|
|
@use_function_invocation
@use_observability
@use_chat_middleware
class HuggingFaceChatClient(BaseChatClient):
    """Adapter for HuggingFace Inference API with full function calling support.

    Translates between the agent framework's message/content types and the
    OpenAI-compatible payload shape expected by
    ``huggingface_hub.InferenceClient.chat_completion``. All blocking SDK
    calls are dispatched via ``asyncio.to_thread`` so the event loop is
    never blocked.
    """

    def __init__(
        self,
        model_id: str | None = None,
        api_key: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the HuggingFace chat client.

        Args:
            model_id: The HuggingFace model ID (default: configured value or Qwen2.5-7B).
            api_key: HF_TOKEN (optional, defaults to env var).
            **kwargs: Additional arguments passed to BaseChatClient.
        """
        # Retry + token accounting are wired in as client-level middleware.
        middleware = [
            RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
            TokenTrackingMiddleware(),
        ]
        super().__init__(middleware=middleware, **kwargs)

        # Fall back through: explicit argument -> configured setting -> hard default.
        self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
        self.api_key = api_key or settings.hf_token

        self._client = InferenceClient(
            model=self.model_id,
            token=self.api_key,
            timeout=60,
        )
        logger.info("Initialized HuggingFaceChatClient", model=self.model_id)

    def _convert_messages(self, messages: MutableSequence[ChatMessage]) -> list[dict[str, Any]]:
        """Convert framework messages to the HuggingFace (OpenAI-style) format.

        Args:
            messages: Framework chat messages, possibly containing
                FunctionCallContent / FunctionResultContent items.

        Returns:
            A list of plain dicts with "role"/"content" keys, plus
            "tool_calls", "tool_call_id" and "name" where applicable.
        """
        hf_messages: list[dict[str, Any]] = []

        # Some inference backends require the tool *name* on tool-result
        # messages; remember name-by-call_id from earlier call messages.
        call_id_to_name: dict[str, str] = {}

        for msg in messages:
            # Role may be an enum (with .value) or a plain string.
            role_str = str(msg.role.value) if hasattr(msg.role, "value") else str(msg.role)

            content_str = msg.text or ""
            tool_calls: list[dict[str, Any]] = []
            tool_call_id: str | None = None
            tool_name: str | None = None

            if msg.contents:
                for item in msg.contents:
                    if isinstance(item, FunctionCallContent):
                        call_id_to_name[item.call_id] = item.name
                        # Arguments may already be a JSON string; otherwise serialize.
                        args = (
                            item.arguments
                            if isinstance(item.arguments, str)
                            else json.dumps(item.arguments)
                        )
                        tool_calls.append(
                            {
                                "id": item.call_id,
                                "type": "function",
                                "function": {"name": item.name, "arguments": args},
                            }
                        )
                    elif isinstance(item, FunctionResultContent):
                        # A function result turns the whole message into a
                        # "tool" message. NOTE(review): assumes at most one
                        # result per framework message — a later result would
                        # overwrite an earlier one here.
                        role_str = "tool"
                        tool_call_id = item.call_id
                        tool_name = call_id_to_name.get(item.call_id)

                        if item.result is None:
                            content_str = ""
                        elif isinstance(item.result, str):
                            content_str = item.result
                        else:
                            content_str = json.dumps(item.result)

            message_dict: dict[str, Any] = {"role": role_str, "content": content_str}
            if tool_calls:
                message_dict["tool_calls"] = tool_calls
            if tool_call_id:
                message_dict["tool_call_id"] = tool_call_id
            if tool_name:
                message_dict["name"] = tool_name

            hf_messages.append(message_dict)

        return hf_messages

    def _convert_tools(self, tools: list[Any] | None) -> list[dict[str, Any]] | None:
        """Convert AIFunction objects to OpenAI-compatible tool definitions.

        AIFunction.to_dict() returns:
            {'type': 'ai_function', 'name': '...', 'input_model': {...}}

        OpenAI/HuggingFace expects:
            {'type': 'function', 'function': {'name': '...', 'parameters': {...}}}

        Returns:
            The converted tool list, or None when nothing usable was supplied.
        """
        if not tools:
            return None

        json_tools: list[dict[str, Any]] = []
        for tool in tools:
            if hasattr(tool, "to_dict"):
                try:
                    t_dict = tool.to_dict()
                    json_tools.append(
                        {
                            "type": "function",
                            "function": {
                                "name": t_dict["name"],
                                "description": t_dict.get("description", ""),
                                "parameters": t_dict["input_model"],
                            },
                        }
                    )
                except (KeyError, TypeError) as e:
                    # Best effort: skip a malformed tool rather than failing the call.
                    logger.warning("Failed to convert tool", tool=str(tool), error=str(e))
            elif isinstance(tool, dict):
                # Already in wire format; pass through untouched.
                json_tools.append(tool)
            else:
                logger.warning("Skipping non-serializable tool", tool_type=str(type(tool)))

        return json_tools or None

    @staticmethod
    def _convert_tool_choice(
        chat_options: ChatOptions, tools: list[dict[str, Any]] | None
    ) -> str | None:
        """Map the framework tool_choice onto the HF string form.

        Only "auto" is forwarded (detected via substring match on the
        framework value's string form); anything else — or no tools —
        yields None so the backend default applies.
        """
        if tools and chat_options.tool_choice is not None:
            if "AUTO" in str(chat_options.tool_choice):
                return "auto"
        return None

    def _parse_tool_calls(self, message: Any) -> list[FunctionCallContent]:
        """Parse HuggingFace tool_calls into framework FunctionCallContent.

        Malformed entries are logged and skipped rather than raised.
        """
        contents: list[FunctionCallContent] = []

        if not hasattr(message, "tool_calls") or not message.tool_calls:
            return contents

        for tc in message.tool_calls:
            try:
                contents.append(
                    FunctionCallContent(
                        call_id=tc.id,
                        name=tc.function.name,
                        arguments=tc.function.arguments,
                    )
                )
            except (AttributeError, TypeError) as e:
                logger.warning("Failed to parse tool call", error=str(e))

        return contents

    async def _inner_get_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> ChatResponse:
        """Generate a complete (non-streaming) response via chat_completion.

        The blocking SDK call runs in a worker thread so the event loop
        stays responsive.

        Raises:
            Exception: Re-raises any error from the HuggingFace SDK after logging.
        """
        hf_messages = self._convert_messages(messages)
        tools = self._convert_tools(chat_options.tools if chat_options.tools else None)
        hf_tool_choice = self._convert_tool_choice(chat_options, tools)

        try:
            # Apply defaults only when the caller didn't specify them.
            max_tokens = chat_options.max_tokens if chat_options.max_tokens is not None else 2048
            temperature = chat_options.temperature if chat_options.temperature is not None else 0.7

            call_fn = partial(
                self._client.chat_completion,
                messages=hf_messages,
                tools=tools,
                tool_choice=hf_tool_choice,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False,
            )

            # Run the blocking HTTP call off the event loop.
            response = await asyncio.to_thread(call_fn)

            choices = response.choices
            if not choices:
                # Degenerate API reply: surface an empty response rather than crash.
                return ChatResponse(messages=[], response_id="error-no-choices")

            choice = choices[0]
            message = choice.message
            message_content = message.content or ""

            tool_call_contents = self._parse_tool_calls(message)

            response_msg = ChatMessage(
                role=cast(Any, message.role),
                text=message_content,
                contents=tool_call_contents if tool_call_contents else None,
            )

            return ChatResponse(
                messages=[response_msg],
                response_id=response.id or "hf-response",
            )

        except Exception as e:
            logger.error("HuggingFace API error", error=str(e))
            raise

    async def _inner_get_streaming_response(
        self,
        *,
        messages: MutableSequence[ChatMessage],
        chat_options: ChatOptions,
        **kwargs: Any,
    ) -> AsyncIterable[ChatResponseUpdate]:
        """Streaming response generation.

        The HF SDK returns a *synchronous* chunk iterator; each ``next()``
        performs a blocking network read, so chunks are fetched via
        ``asyncio.to_thread`` to avoid stalling the event loop.

        Yields:
            ChatResponseUpdate items: text deltas as they arrive, then (if
            the model emitted tool calls) one final update carrying the
            fully accumulated FunctionCallContent list.

        Raises:
            Exception: Re-raises any error from the HuggingFace SDK after logging.
        """
        hf_messages = self._convert_messages(messages)
        tools = self._convert_tools(chat_options.tools if chat_options.tools else None)
        hf_tool_choice = self._convert_tool_choice(chat_options, tools)

        try:
            max_tokens = chat_options.max_tokens if chat_options.max_tokens is not None else 2048
            temperature = chat_options.temperature if chat_options.temperature is not None else 0.7

            call_fn = partial(
                self._client.chat_completion,
                messages=hf_messages,
                tools=tools,
                tool_choice=hf_tool_choice,
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )

            stream = await asyncio.to_thread(call_fn)

            # Tool-call fragments arrive spread across chunks, keyed by index;
            # accumulate them and emit once the stream is exhausted.
            tool_call_accumulator: dict[int, dict[str, Any]] = {}

            stream_iter = iter(stream)
            _done = object()  # sentinel marking stream exhaustion
            while True:
                # Each next() is a blocking network read — keep it off the loop.
                chunk = await asyncio.to_thread(next, stream_iter, _done)
                if chunk is _done:
                    break

                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = choice.delta

                if delta.content:
                    yield ChatResponseUpdate(
                        role=cast(Any, delta.role) if delta.role else None,
                        text=delta.content,
                    )

                if hasattr(delta, "tool_calls") and delta.tool_calls:
                    for tc in delta.tool_calls:
                        idx = tc.index
                        if idx not in tool_call_accumulator:
                            tool_call_accumulator[idx] = {
                                "id": "",
                                "name": "",
                                "arguments": "",
                            }

                        # Fields may arrive piecemeal; concatenate fragments.
                        if tc.id:
                            tool_call_accumulator[idx]["id"] += tc.id
                        if tc.function:
                            if tc.function.name:
                                tool_call_accumulator[idx]["name"] += tc.function.name
                            if tc.function.arguments:
                                tool_call_accumulator[idx]["arguments"] += tc.function.arguments

            if tool_call_accumulator:
                contents: list[FunctionCallContent] = []
                for idx in sorted(tool_call_accumulator.keys()):
                    tc_data = tool_call_accumulator[idx]
                    # Only emit calls that accumulated both an id and a name.
                    if tc_data["id"] and tc_data["name"]:
                        contents.append(
                            FunctionCallContent(
                                call_id=tc_data["id"],
                                name=tc_data["name"],
                                arguments=tc_data["arguments"],
                            )
                        )

                if contents:
                    yield ChatResponseUpdate(
                        contents=contents,
                        role=Role.ASSISTANT,
                        finish_reason=FinishReason.TOOL_CALLS,
                    )

        except Exception as e:
            logger.error("HuggingFace Streaming error", error=str(e))
            raise
|
|