| | """LangGraph Agent with Direct Groq API and Custom Rate Limiting""" |
| | import os |
| | import time |
| | import threading |
| | from collections import deque |
| | from typing import Dict, Any, List |
| | from dotenv import load_dotenv |
| | from langgraph.graph import START, StateGraph, MessagesState |
| | from langgraph.prebuilt import tools_condition |
| | from langgraph.prebuilt import ToolNode |
| | from langchain_community.tools.tavily_search import TavilySearchResults |
| | from langchain_community.document_loaders import WikipediaLoader |
| | from langchain_community.document_loaders import ArxivLoader |
| | from langchain_core.messages import SystemMessage, HumanMessage, AIMessage |
| | from langchain_core.tools import tool |
| | from groq import Groq, RateLimitError |
| | import logging |
| |
|
| | load_dotenv() |
| |
|
| |
|
| | |
| | logging.basicConfig(level=logging.INFO) |
| | logger = logging.getLogger(__name__) |
| |
|


class GroqRateLimiter:
    """Thread-safe rate limiter for direct Groq API calls"""

    def __init__(self, rpm: int = 20, tpm: int = 6000):
        self.rpm = rpm
        self.tpm = tpm
        self.request_times = deque()  # timestamps of recent requests
        self.token_usage = deque()    # (timestamp, token_count) pairs
        self.lock = threading.Lock()

    def _clean_old_records(self, current_time: float):
        """Remove records older than 1 minute"""
        minute_ago = current_time - 60

        while self.request_times and self.request_times[0] <= minute_ago:
            self.request_times.popleft()

        while self.token_usage and self.token_usage[0][0] <= minute_ago:
            self.token_usage.popleft()
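
    # Example: a request recorded at t=10.0 counts against the sliding
    # 60-second window until t=70.0; from then on it is dropped and a
    # request slot frees up.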

    def can_make_request(self, estimated_tokens: int = 1000) -> tuple[bool, float]:
        """Check if a request can be made; return (can_proceed, wait_time)."""
        with self.lock:
            current_time = time.time()
            self._clean_old_records(current_time)

            wait_time = 0.0

            # Requests-per-minute check: wait until the oldest request ages out.
            if len(self.request_times) >= self.rpm:
                oldest_request = self.request_times[0]
                wait_time = max(wait_time, 60 - (current_time - oldest_request))

            # Tokens-per-minute check: wait until the oldest token record ages out.
            current_tokens = sum(tokens for _, tokens in self.token_usage)
            if current_tokens + estimated_tokens > self.tpm:
                if self.token_usage:
                    oldest_token_time = self.token_usage[0][0]
                    wait_time = max(wait_time, 60 - (current_time - oldest_token_time))

            return wait_time <= 0, wait_time

    def record_request(self, token_count: int):
        """Record a successful request"""
        with self.lock:
            current_time = time.time()
            self.request_times.append(current_time)
            self.token_usage.append((current_time, token_count))
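
# Illustrative standalone usage of the limiter (GroqWrapper below does this
# internally):
#
#   limiter = GroqRateLimiter(rpm=20, tpm=6000)
#   ok, wait = limiter.can_make_request(estimated_tokens=500)
#   if not ok:
#       time.sleep(wait)
#   # ... make the API call ...
#   limiter.record_request(token_count=480)  # 480 is an example value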


class GroqWrapper:
    """Wrapper for the direct Groq API with rate limiting and error handling"""

    def __init__(self, model: str = "qwen/qwen3-32b",
                 rpm: int = 30, tpm: int = 6000):
        self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        self.model = model
        self.rate_limiter = GroqRateLimiter(rpm=rpm, tpm=tpm)

    def estimate_tokens(self, messages: List[Dict]) -> int:
        """Rough token estimation (4 chars ≈ 1 token)"""
        total_chars = sum(len(str(msg.get('content', ''))) for msg in messages)
        return max(total_chars // 4, 100)
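
    # Example: a 2,000-character conversation estimates to 2000 // 4 = 500
    # tokens; the max(..., 100) floor keeps tiny prompts from under-reserving.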

    def invoke(self, messages: List[Dict], **kwargs) -> AIMessage:
        """Invoke the Groq API with rate limiting and retry logic"""
        # Convert LangChain message objects to Groq chat format.
        groq_messages = []
        for msg in messages:
            if hasattr(msg, 'content') and hasattr(msg, 'type'):
                # Map LangChain message types to Groq roles; anything else
                # (e.g. tool output) is passed through as user content, since
                # this wrapper does not do native tool calling.
                role = {"human": "user", "ai": "assistant", "system": "system"}.get(msg.type, "user")
                groq_messages.append({"role": role, "content": str(msg.content)})
            else:
                # Already a plain dict in Groq format.
                groq_messages.append(msg)

        estimated_tokens = self.estimate_tokens(groq_messages)

        max_retries = 3
        for attempt in range(max_retries):
            try:
                # Proactively wait if the local limiter says we are over budget.
                can_proceed, wait_time = self.rate_limiter.can_make_request(estimated_tokens)
                if not can_proceed:
                    logger.info(f"Rate limit: waiting {wait_time:.2f} seconds")
                    time.sleep(wait_time)

                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=groq_messages,
                    **kwargs
                )

                # Record actual usage when reported, else fall back to the estimate.
                usage = getattr(response, 'usage', None)
                actual_tokens = usage.total_tokens if usage else estimated_tokens
                self.rate_limiter.record_request(actual_tokens)

                content = response.choices[0].message.content
                return AIMessage(content=content)

            except RateLimitError as e:
                if attempt == max_retries - 1:
                    raise

                # Honor the server's retry-after header when present,
                # otherwise back off exponentially.
                retry_after = getattr(e.response, 'headers', {}).get('retry-after')
                delay = float(retry_after) if retry_after else 2 ** attempt

                logger.warning(f"Rate limited. Retrying in {delay} seconds (attempt {attempt + 1})")
                time.sleep(delay)

            except Exception as e:
                logger.error(f"Groq API error: {e}")
                if attempt == max_retries - 1:
                    raise
                time.sleep(2 ** attempt)

        raise RuntimeError("Max retries exceeded")

    def bind_tools(self, tools):
        """Compatibility shim for the LangChain interface.

        Tools are stored but not forwarded to the Groq API, so this wrapper
        never emits tool calls; see the note below.
        """
        self.tools = tools
        return self
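

# Note: because bind_tools above is a no-op, the "direct_groq" path never
# produces tool calls and tools_condition always routes to END. A minimal
# sketch of real tool wiring, assuming Groq's OpenAI-compatible `tools`
# parameter and LangChain's schema converter (not part of the original code):
#
#   from langchain_core.utils.function_calling import convert_to_openai_tool
#   response = self.client.chat.completions.create(
#       model=self.model,
#       messages=groq_messages,
#       tools=[convert_to_openai_tool(t) for t in self.tools],
#   )
#
# Any tool_calls on response.choices[0].message would then have to be copied
# onto the returned AIMessage for tools_condition to route to ToolNode.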


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    return a - b


@tool
def divide(a: float, b: float) -> float:
    """Divide two numbers."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    return a % b


@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return a maximum of 2 results."""
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        return "\n\n---\n\n".join(
            f'<Document source="{doc.metadata.get("source", "")}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        )
    except Exception as e:
        return f"Error: {e}"


@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return a maximum of 3 results."""
    try:
        search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
        return "\n\n---\n\n".join(
            f'<Document source="{doc.get("url", "")}">\n{doc.get("content", "")}\n</Document>'
            for doc in search_docs
        )
    except Exception as e:
        return f"Error: {e}"


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return a maximum of 3 results."""
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        # ArxivLoader metadata may not include a "source" key, so fall back
        # to the paper title.
        return "\n\n---\n\n".join(
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        )
    except Exception as e:
        return f"Error: {e}"


def load_system_prompt():
    """Load the system prompt, falling back to a generic default on error."""
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            return f.read()
    except OSError as e:
        # Fallback is a generic placeholder; the real prompt lives in
        # system_prompt.txt.
        logger.warning(f"Could not read system_prompt.txt ({e}); using a default prompt.")
        return "You are a helpful assistant. Use the available tools when needed."


system_prompt = load_system_prompt()
sys_msg = SystemMessage(content=system_prompt)

tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
]


def build_graph(provider: str = "direct_groq", model: str = "qwen/qwen3-32b"):
    """Build the graph with direct Groq API and custom rate limiting"""

    if provider == "direct_groq":
        # Custom wrapper with the sliding-window limiter defined above.
        llm = GroqWrapper(model=model, rpm=30, tpm=6000)

    elif provider == "langchain_groq":
        # LangChain's built-in limiter: 0.5 requests/sec ≈ 30 requests/min.
        from langchain_core.rate_limiters import InMemoryRateLimiter

        rate_limiter = InMemoryRateLimiter(
            requests_per_second=0.5,
            check_every_n_seconds=0.1,
            max_bucket_size=5,
        )

        from langchain_groq import ChatGroq
        llm = ChatGroq(
            model=model,
            temperature=0,
            groq_api_key=os.getenv("GROQ_API_KEY"),
            rate_limiter=rate_limiter,
        )
    else:
        raise ValueError("Choose 'direct_groq' or 'langchain_groq'")

    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: prepend the system prompt and call the model."""
        try:
            response = llm_with_tools.invoke([sys_msg] + state["messages"])
            return {"messages": [response]}
        except Exception as e:
            logger.error(f"Assistant failed: {e}")
            error_msg = AIMessage(content=f"I encountered an error: {e}")
            return {"messages": [error_msg]}

    # Wire the ReAct-style loop: assistant -> tools when the reply contains
    # tool calls (otherwise END), and tool results feed back to assistant.
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()


if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"

    try:
        graph = build_graph(provider="direct_groq")
        messages = [HumanMessage(content=question)]
        result = graph.invoke({"messages": messages})

        for m in result["messages"]:
            m.pretty_print()

    except Exception as e:
        logger.error(f"Test failed: {e}")