import json
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from groq import Groq
from google import genai
from google.genai import types
class ProviderError(Exception):
    """Root of the exception hierarchy for every provider-related failure."""
class QuotaExhaustedError(ProviderError):
    """Signals that the provider's credits or usage limit have run out."""
class ProviderDownError(ProviderError):
    """Signals that the provider is temporarily unusable (e.g. HTTP 500, 429)."""
class LLMResponse(BaseModel):
    # Provider-independent reply returned by every LLMProvider.get_response().
    # Documented with comments instead of a docstring so the generated model
    # schema stays exactly as it was.

    # Plain-text answer from the model; stays None when not supplied.
    content: Optional[str] = None

    # Requested tool invocation as a dict; the providers in this file fill in
    # "name" and "args" (Groq additionally adds "id"). None when not supplied.
    tool_call: Optional[Dict[str, Any]] = None
class LLMProvider(ABC):
    """Abstract interface that each concrete LLM backend must implement."""

    @abstractmethod
    def get_response(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict],
    ) -> LLMResponse:
        """Produce the model's next reply for the given conversation.

        Args:
            messages: Full conversation history
                [{"role": "user", "content": "..."}, ...]
            tools: JSON Schema definitions for tools.

        Returns:
            An LLMResponse describing the model's reply.
        """
class GroqProvider(LLMProvider):
    """LLMProvider implementation that talks to Groq's chat-completions API."""

    def __init__(self, api_key: str, model_name: str = 'llama-3.1-8b-instant'):
        """Create the Groq client.

        Args:
            api_key: Groq API key handed to the client.
            model_name: Model identifier sent with every request.
        """
        self.client = Groq(api_key=api_key)
        self.model_name = model_name

    def get_response(self, messages: List[Dict[str, str]], tools: List[Dict]) -> LLMResponse:
        """Send the conversation to Groq and return the reply.

        Args:
            messages: Full conversation history in Groq/OpenAI message format.
            tools: Tool definitions in Groq/OpenAI format. May be empty or
                None, in which case no tool parameters are sent at all.

        Returns:
            An LLMResponse with ``tool_call`` set (keys "name", "args", "id")
            when the model requested a tool, otherwise with ``content`` set.
            Only the first tool call of a reply is returned.

        Raises:
            QuotaExhaustedError: The failure text mentions "resource_exhausted"
                or "quota".
            ProviderDownError: Any other failure while calling the API or
                decoding its reply.
        """
        request: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "temperature": 0.1,
        }
        # Only advertise tools when there are some: this avoids sending
        # tool_choice="auto" together with an empty/None tool list.
        if tools:
            request["tools"] = tools
            request["tool_choice"] = "auto"

        try:
            response = self.client.chat.completions.create(**request)
            message = response.choices[0].message

            if message.tool_calls:
                # We take the first tool call
                call = message.tool_calls[0]
                raw_args = call.function.arguments
                return LLMResponse(
                    tool_call={
                        "name": call.function.name,
                        # A call without arguments may arrive as an empty
                        # string, which json.loads would reject.
                        "args": json.loads(raw_args) if raw_args else {},
                        "id": call.id,  # Store ID for history tracking
                    }
                )

            # Return text content
            return LLMResponse(content=message.content)
        except Exception as e:
            # Boundary handler: translate anything that went wrong into the
            # provider exception hierarchy, keeping the original as __cause__.
            error_msg = str(e).lower()
            if "resource_exhausted" in error_msg or "quota" in error_msg:
                raise QuotaExhaustedError("Groq Quota Exhausted") from e
            raise ProviderDownError(f"Groq Error: {e}") from e
class GeminiProvider(LLMProvider):
    """LLMProvider implementation that talks to Google Gemini via google-genai."""

    def __init__(self, api_key: str, model_name: str = 'gemini-2.0-flash'):
        """Create the Gemini client.

        Args:
            api_key: API key handed to ``genai.Client``.
            model_name: Model identifier sent with every request.
        """
        self.client = genai.Client(api_key=api_key)
        self.model_name = model_name

    def _map_tools(self, tools: List[Dict]) -> List[types.Tool]:
        """Convert OpenAI/Groq-style tool definitions into Gemini ``types.Tool`` objects.

        Entries whose "type" is not "function" are skipped. Each function is
        wrapped in its own ``types.Tool``. An empty or None ``tools`` yields
        an empty list.
        """
        gemini_tools = []
        for t in tools or []:
            # Only the OpenAI schema {"type": "function", "function": {...}}
            # is understood.
            if t.get("type") != "function":
                continue
            func_def = t["function"]
            fn_decl = types.FunctionDeclaration(
                name=func_def["name"],
                description=func_def.get("description"),
                parameters=func_def.get("parameters"),
            )
            gemini_tools.append(types.Tool(function_declarations=[fn_decl]))
        return gemini_tools

    def _default_history_format(self, messages: List[Dict]) -> str:
        """Flatten an OpenAI-style message history into one text prompt.

        Each message becomes a labelled line ("System Instruction", "User",
        "Assistant", "Tool Output"); messages with any other role are
        ignored. Assistant messages carrying tool calls are rendered as one
        "Assistant (Thought)" line per call; a missing, None or empty
        "tool_calls" entry falls back to the plain assistant text.
        """
        chunks = []
        for msg in messages:
            role = msg["role"]
            content = msg.get("content", "") or ""
            if role == "system":
                chunks.append(f"System Instruction: {content}\n\n")
            elif role == "user":
                chunks.append(f"User: {content}\n")
            elif role == "assistant":
                # .get() + truthiness: the key may be present but None/[].
                tool_calls = msg.get("tool_calls")
                if tool_calls:
                    for tc in tool_calls:
                        chunks.append(
                            f"Assistant (Thought): I will call tool '{tc['function']['name']}' with args {tc['function']['arguments']}.\n"
                        )
                else:
                    chunks.append(f"Assistant: {content}\n")
            elif role == "tool":
                chunks.append(f"Tool Output ({msg.get('name')}): {content}\n")
        chunks.append("\nBased on the history above, provide the next response or tool call.")
        return "".join(chunks)

    def get_response(self, messages: List[Dict[str, str]], tools: List[Dict]) -> LLMResponse:
        """Send the conversation to Gemini and return the reply.

        The whole history is flattened into a single user-role prompt and the
        tools are translated to Gemini declarations before the call.

        Args:
            messages: Full conversation history in OpenAI-style message format.
            tools: Tool definitions in OpenAI/Groq format; may be empty.

        Returns:
            An LLMResponse with ``tool_call`` set (keys "name", "args") for
            the first function call found in the reply, otherwise with
            ``content`` set to the concatenated text of the reply's parts
            (None if no part carries text).

        Raises:
            QuotaExhaustedError: The failure text mentions "resource_exhausted"
                or "quota".
            ProviderDownError: Any other failure, including a reply without
                candidates or without content parts.
        """
        try:
            # 1. Translate History
            full_prompt = self._default_history_format(messages)
            gemini_messages = [
                types.Content(role="user", parts=[types.Part(text=full_prompt)])
            ]

            # 2. Translate Tools (None rather than an empty list when there
            #    is nothing to declare)
            mapped_tools = self._map_tools(tools)
            config = types.GenerateContentConfig(
                tools=mapped_tools or None,
                temperature=0.0,
            )

            response = self.client.models.generate_content(
                model=self.model_name,
                contents=gemini_messages,
                config=config,
            )

            # 3. Decode the reply, failing with a readable message instead of
            #    an IndexError/TypeError when the reply is empty.
            if not response.candidates:
                raise ProviderDownError("Gemini Error: response contained no candidates")
            candidate = response.candidates[0]
            parts = (candidate.content.parts if candidate.content else None) or []
            if not parts:
                raise ProviderDownError("Gemini Error: response contained no content parts")

            for part in parts:
                if part.function_call:
                    call = part.function_call
                    return LLMResponse(
                        tool_call={
                            "name": call.name,
                            # Always hand callers a plain dict, even when the
                            # call carries no arguments.
                            "args": dict(call.args) if call.args else {},
                        }
                    )

            # Text may be split over several parts; return all of it.
            texts = [part.text for part in parts if part.text]
            return LLMResponse(content="".join(texts) if texts else None)
        except ProviderError:
            # Already classified above; do not wrap it a second time.
            raise
        except Exception as e:
            # Boundary handler: translate anything that went wrong into the
            # provider exception hierarchy, keeping the original as __cause__.
            error_msg = str(e).lower()
            if "resource_exhausted" in error_msg or "quota" in error_msg:
                raise QuotaExhaustedError("Gemini Quota Exhausted") from e
            raise ProviderDownError(f"Gemini Error: {e}") from e