import os
import ssl
import asyncio
import base64

import aiohttp
from dotenv import load_dotenv
from llama_index.core.agent.workflow import AgentWorkflow
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from langfuse import get_client
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Import tool functions and initializers from tools.py
from tools import (
    get_tavily_tool,
    get_arxiv_reader,
    get_wikipedia_reader,
    get_wikipedia_tool,
    get_arxiv_tool,
    get_search_tool,
    get_calculator_tool,
    get_hub_stats_tool,
)

load_dotenv("env.local")
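

# LlamaIndexAgent wires the tool set from tools.py into a LlamaIndex
# AgentWorkflow backed by a Hugging Face Inference API LLM, and traces
# every run to Langfuse over OTLP.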
class LlamaIndexAgent:
    def __init__(self):
        # Tool initializations using imported functions
        self.tavily_tool = get_tavily_tool()
        self.arxiv_reader = get_arxiv_reader()
        self.wikipedia_reader = get_wikipedia_reader()
        self.wikipedia_tool = get_wikipedia_tool(self.wikipedia_reader)
        self.arxiv_tool = get_arxiv_tool(self.arxiv_reader)
        self.search_tool = get_search_tool()
        self.calculator_tool = get_calculator_tool()
        self.hub_stats_tool = get_hub_stats_tool()

        with open("system_prompt.txt", "r") as f:
            self.system_prompt = f.read()
        print("system_prompt loaded:", self.system_prompt[:80], "...")

        print("DEBUG: search_tool:", self.search_tool, type(self.search_tool))
        print("DEBUG: calculator_tool:", self.calculator_tool, type(self.calculator_tool))
        print("DEBUG: wikipedia_tool:", self.wikipedia_tool, type(self.wikipedia_tool))
        print("DEBUG: arxiv_tool:", self.arxiv_tool, type(self.arxiv_tool))
        print("DEBUG: hub_stats_tool:", self.hub_stats_tool, type(self.hub_stats_tool))

        # search_tool and calculator_tool return lists of tools, hence the unpacking
        all_tools = [*self.search_tool, *self.calculator_tool, self.wikipedia_tool, self.arxiv_tool, self.hub_stats_tool]
        print("DEBUG: All tools list:", all_tools)
        print("DEBUG: Types in all_tools:", [type(t) for t in all_tools])

        # LLM and agent workflow (streaming disabled; 60s HTTP client timeout)
        self.llm = HuggingFaceInferenceAPI(
            model_name="Qwen/Qwen2.5-Coder-32B-Instruct",
            streaming=False,
            client_kwargs={"timeout": 60},
        )
        self.alfred = AgentWorkflow.from_tools_or_functions(
            all_tools,
            llm=self.llm,
            system_prompt=self.system_prompt,
            # verbose=True
        )
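
        # Langfuse ingests OTLP traces on its /api/public/otel endpoint,
        # authenticated with Basic auth built from the public/secret key
        # pair loaded from env.local.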
        LANGFUSE_AUTH = base64.b64encode(
            f"{os.getenv('LANGFUSE_PUBLIC_KEY')}:{os.getenv('LANGFUSE_SECRET_KEY')}".encode()
        ).decode()
        os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = os.environ.get("LANGFUSE_HOST") + "/api/public/otel"
        os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {LANGFUSE_AUTH}"

        # Set up OpenTelemetry tracing
        self.tracer_provider = TracerProvider()
        self.tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
        trace.set_tracer_provider(self.tracer_provider)

        # OpenInference's LlamaIndexInstrumentor takes no constructor
        # arguments (the original key/host kwargs were silently ignored);
        # the Langfuse credentials are already carried by the OTLP exporter
        # environment variables set above.
        self.instrumentor = LlamaIndexInstrumentor()
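
    # Retry transient network failures up to 3 times with exponential
    # backoff (4-10s). KeyError is retried too, since the upstream client
    # can surface dropped sessions that way (see the handler below).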
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((
            aiohttp.ClientConnectionError,
            aiohttp.ClientOSError,
            ssl.SSLError,
            KeyError,
            ConnectionError,
        )),
    )
    async def run_query(self, query: str):
        # Instrument LlamaIndex with OpenTelemetry (idempotent across retries)
        if not self.instrumentor.is_instrumented_by_opentelemetry:
            self.instrumentor.instrument(tracer_provider=self.tracer_provider)
        langfuse = get_client()  # Picks up LANGFUSE_PUBLIC_KEY, etc.
        # Wrap the LlamaIndex call in a Langfuse span for trace context
        with langfuse.start_as_current_span(name="llamaindex-query") as span:
            # Optionally set trace attributes
            span.update_trace(user_id="user_123", input={"query": query})
            try:
                response = await self.alfred.run(query)
            # ClientOSError subclasses ClientConnectionError, so it must come first
            except aiohttp.ClientOSError as e:
                span.update_trace(output={"response": f"Client OS error: {e}"})
                raise  # Re-raise for retry logic
            except aiohttp.ClientConnectionError as e:
                span.update_trace(output={"response": f"Connection error: {e}"})
                raise  # Re-raise for retry logic
            except ssl.SSLError as e:
                span.update_trace(output={"response": f"SSL error: {e}"})
                raise  # Re-raise for retry logic
            except (KeyError, ConnectionError) as e:
                span.update_trace(output={"response": f"Session/Connection error: {e}"})
                raise  # Re-raise for retry logic
            except Exception as e:
                span.update_trace(output={"response": f"General error: {e}"})
                return f"AGENT ERROR: {e}"
            # Record the trace output
            span.update_trace(output={"response": str(response)})
        # Flush (rather than shut down) so later queries keep exporting spans
        langfuse.flush()
        self.tracer_provider.force_flush()
        return response
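

# Simple smoke test: run one query end-to-end and print the response.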
def main():
    agent = LlamaIndexAgent()
    query = "What is the capital of Maharashtra?"
    print(f"Running query: {query}")
    response = asyncio.run(agent.run_query(query))
    print("\n🎩 Agent's Response:")
    print(response)


if __name__ == "__main__":
    main()