# GAIA_Agent / llm_rotator.py
# (header preserved from hosting page: uploaded by nikhmr1235,
#  commit e6739e2 "add descriptive docstring for llm_rotator.py", verified)
"""
Provides a thread-safe, round-robin API key rotator for LangChain.
This module defines the ApiKeyRotator class, a custom LangChain Runnable
designed to cycle through a list of Google API keys for each invocation.
Its primary purpose is to distribute API calls across multiple keys, helping
to manage usage quotas and avoid rate-limiting errors in high-volume
applications.
The rotator is thread-safe, making it suitable for use in concurrent
environments.
"""
import threading
from typing import List, Any
from langchain_core.runnables import RunnableSerializable
from langchain_google_genai import ChatGoogleGenerativeAI
from pydantic import PrivateAttr
class ApiKeyRotator(RunnableSerializable):
    """
    A custom LangChain Runnable that rotates Google API keys for each call.

    This allows distributing requests across multiple API keys to stay within
    free tier limits. It is designed to be a drop-in replacement for the
    standard ChatGoogleGenerativeAI instance in a LangChain chain.

    Raises:
        ValueError: On invocation, if ``api_keys`` is empty.
    """
    # Pool of Google API keys, consumed round-robin; must be non-empty.
    api_keys: List[str]
    # Model name forwarded to ChatGoogleGenerativeAI (e.g. "gemini-1.5-flash").
    model: str
    temperature: float = 0.0
    streaming: bool = False
    # Private, non-pydantic attributes for managing state thread-safely.
    # These are excluded from serialization and deepcopy operations.
    _lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
    _next_key_index: int = PrivateAttr(default=0)

    def _get_next_key(self) -> str:
        """Safely gets the next API key from the list in a round-robin fashion.

        Raises:
            ValueError: If ``api_keys`` is empty (previously this surfaced as
                an opaque IndexError/ZeroDivisionError).
        """
        if not self.api_keys:
            raise ValueError("ApiKeyRotator requires at least one API key.")
        with self._lock:
            key = self.api_keys[self._next_key_index]
            self._next_key_index = (self._next_key_index + 1) % len(self.api_keys)
            return key

    def _build_llm(self, api_key: str) -> ChatGoogleGenerativeAI:
        """Create a short-lived LLM bound to *api_key* (shared by sync/async paths)."""
        # Log only the key suffix so full credentials never reach stdout.
        print(f"--- Using Google API Key ending in: ...{api_key[-4:]} ---")
        return ChatGoogleGenerativeAI(
            model=self.model,
            google_api_key=api_key,
            temperature=self.temperature,
            streaming=self.streaming,
        )

    def invoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Synchronous invocation with a temporary LLM instance."""
        return self._build_llm(self._get_next_key()).invoke(input, config, **kwargs)

    async def ainvoke(self, input: Any, config: Any = None, **kwargs: Any) -> Any:
        """Asynchronous invocation with a temporary LLM instance.

        Routed through the same helper as ``invoke`` so the async path now
        emits the same key-usage log line (previously it was silently skipped).
        """
        return await self._build_llm(self._get_next_key()).ainvoke(input, config, **kwargs)