"""
Base provider adapter using Agno framework.
Defines the interface for all provider adapters.
"""
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, Literal
@dataclass
class ProviderConfig:
    """Provider configuration.

    Static description of a provider: its name, connection defaults, and
    capability flags that callers use to gate optional features.
    """
    # Provider identifier (used to select/register the adapter).
    name: str
    # API endpoint override; None means the provider's own default URL.
    base_url: str | None = None
    # Model used when the caller does not specify one.
    default_model: str = ""
    # Capability flags — consumed by callers to decide which request
    # features may be sent to this provider.
    supports_streaming: bool = True
    supports_tools: bool = True
    # Whether tool calls can arrive incrementally inside a stream.
    supports_streaming_tool_calls: bool = False
    supports_json_schema: bool = True
    # Whether the provider can emit reasoning/"thinking" output.
    supports_thinking: bool = False
    supports_vision: bool = False
@dataclass
class ExecutionContext:
    """Execution context for chat completion.

    Bundles the conversation plus all per-request options handed to a
    provider adapter's ``execute``.
    """
    # Conversation messages as role/content dicts (chat-completions style).
    messages: list[dict[str, Any]]
    # Tool definitions offered to the model; None means no tools.
    tools: list[dict[str, Any]] | None = None
    # Tool-selection directive; format is provider-specific.
    tool_choice: Any = None
    # Sampling parameters; None means "use the provider default".
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    frequency_penalty: float | None = None
    presence_penalty: float | None = None
    # Structured-output request (e.g. a JSON schema), if any.
    response_format: dict[str, Any] | None = None
    # Reasoning/"thinking" configuration for providers that support it.
    thinking: dict[str, Any] | None = None
    # Whether the response should be streamed.
    stream: bool = True
    # NOTE(review): presumably used by a web-search tool integration —
    # confirm against the adapters that read it.
    tavily_api_key: str | None = None
@dataclass
class StreamChunk:
    """A streaming chunk from the provider.

    A tagged union: ``type`` says which of the payload fields below is
    meaningful for this chunk.
    """
    # Discriminator for the chunk's payload.
    type: Literal["text", "thought", "tool_calls", "done", "error"]
    # Text delta (for type == "text").
    content: str = ""
    # Reasoning delta (for type == "thought").
    thought: str = ""
    # Tool calls emitted by the model (for type == "tool_calls").
    tool_calls: list[dict[str, Any]] | None = None
    # Provider finish reason, when known (e.g. on "done").
    finish_reason: str | None = None
    # Error message (for type == "error").
    error: str | None = None
class BaseProviderAdapter(ABC):
    """
    Base provider adapter using Agno framework.

    All provider adapters should inherit from this class and implement
    the abstract methods to support their specific provider API.
    """

    def __init__(self, config: ProviderConfig):
        # Static capabilities/defaults for the concrete provider.
        self.config = config

    @abstractmethod
    def build_model(
        self,
        api_key: str,
        model: str | None = None,
        base_url: str | None = None,
        **kwargs,
    ) -> Any:
        """
        Build an Agno model instance for this provider.

        Args:
            api_key: API key for the provider
            model: Model name (uses the provider default if not specified)
            base_url: Custom base URL (uses the provider default if not specified)
            **kwargs: Additional model parameters

        Returns:
            Configured Agno model instance
        """

    @abstractmethod
    async def execute(
        self,
        context: ExecutionContext,
        api_key: str,
        model: str | None = None,
        base_url: str | None = None,
    ) -> AsyncGenerator[StreamChunk, None]:
        """
        Execute chat completion with streaming support.

        Args:
            context: Execution context with messages and parameters
            api_key: API key for the provider
            model: Model name to use
            base_url: Custom base URL

        Yields:
            StreamChunk objects with content, thoughts, and tool calls
        """

    def _first_choice(self, response: dict[str, Any]) -> dict[str, Any]:
        """
        Return the first ``choices`` entry of the raw provider response.

        Returns an empty dict when ``__raw_response`` or ``choices`` is
        missing *or empty*.  (Previously ``raw.get("choices", [{}])[0]``
        raised IndexError on an explicit empty ``choices`` list, because
        the ``get`` default only covers a missing key.)
        """
        raw = response.get("additional_kwargs", {}).get("__raw_response", {})
        choices = raw.get("choices")
        return choices[0] if choices else {}

    def extract_thinking_content(self, chunk: dict[str, Any]) -> str | None:
        """
        Extract thinking/reasoning content from a streaming chunk.

        Args:
            chunk: Raw chunk from the provider

        Returns:
            Thinking content if present, None otherwise
        """
        # Check the raw response's delta first (chat-completions style).
        delta = self._first_choice(chunk).get("delta", {})
        reasoning = delta.get("reasoning_content") or delta.get("reasoning")
        if reasoning:
            return str(reasoning)
        # Fall back to additional_kwargs set directly on the chunk.
        additional = chunk.get("additional_kwargs", {})
        reasoning = additional.get("reasoning_content") or additional.get("reasoning")
        if reasoning:
            return str(reasoning)
        return None

    def parse_tool_calls(self, response: dict[str, Any]) -> list[dict[str, Any]] | None:
        """
        Parse tool calls from a response.

        Checks, in order: the raw response's first choice message, the
        response's ``additional_kwargs``, and the response top level.

        Args:
            response: Response from the provider

        Returns:
            List of tool calls or None

        Note:
            An empty ``choices`` list in the raw response no longer
            raises IndexError; it simply falls through to the other
            locations.
        """
        message = self._first_choice(response).get("message", {})
        tool_calls = (
            message.get("tool_calls")
            or response.get("additional_kwargs", {}).get("tool_calls")
            or response.get("tool_calls")
        )
        return list(tool_calls) if tool_calls else None

    def get_finish_reason(self, response: dict[str, Any]) -> str | None:
        """Get the finish reason from the raw response, or None if absent."""
        # _first_choice guards against a missing or empty choices list.
        return self._first_choice(response).get("finish_reason")

    def normalize_tool_call(self, tool_call: dict[str, Any]) -> dict[str, Any]:
        """
        Normalize a tool call to the standard shape:
        ``{"id", "type", "function": {"name", "arguments"}}``.

        Accepts either a nested ``function`` dict or flat
        ``name``/``arguments`` keys on the tool call itself.
        """
        function = tool_call.get("function", {})
        return {
            "id": tool_call.get("id"),
            "type": tool_call.get("type", "function"),
            "function": {
                "name": function.get("name") or tool_call.get("name"),
                "arguments": function.get("arguments") or tool_call.get("arguments", ""),
            },
        }

    def apply_context_limit(
        self,
        messages: list[dict[str, Any]],
        limit: int | None,
    ) -> list[dict[str, Any]]:
        """
        Apply context limit to messages, preserving system messages.

        Args:
            messages: Message list
            limit: Maximum number of non-system messages to keep

        Returns:
            Filtered message list: all system messages first, followed by
            the most recent ``limit`` non-system messages.  The input is
            returned unchanged when no limit applies.
        """
        if not limit or limit <= 0 or len(messages) <= limit:
            return messages
        system_messages = [m for m in messages if m.get("role") == "system"]
        non_system_messages = [m for m in messages if m.get("role") != "system"]
        return system_messages + non_system_messages[-limit:]
|