# FastApi / orchestrator_agent.py
# Soumik Bose
# go
# 274e9c9
import os
from typing import List, Any
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider
from google.api_core.exceptions import ResourceExhausted # Import the exception for quota exhaustion
from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
from dotenv import load_dotenv
load_dotenv()
# Load all API keys from the GEMINI_API_KEYS environment variable, which is
# expected to hold a comma-separated list of keys. Surrounding whitespace is
# stripped and blank entries are dropped, so an unset/empty variable (or a
# trailing comma) yields no key at all instead of an empty-string "key".
GEMINI_API_KEYS = [
    key.strip()
    for key in os.getenv("GEMINI_API_KEYS", "").split(",")
    if key.strip()
]
# Function to initialize the model with a specific API key
def initialize_model(api_key: str) -> GeminiModel:
    """Build a Gemini model client bound to a single API key.

    The model name is read from the ``GEMINI_LLM_MODEL`` environment variable
    each time this function is called.

    Args:
        api_key: The key handed to the Google GLA provider.

    Returns:
        A ``GeminiModel`` constructed with that model name and provider.
    """
    model_name = os.getenv("GEMINI_LLM_MODEL")
    gla_provider = GoogleGLAProvider(api_key=api_key)
    return GeminiModel(model_name, provider=gla_provider)
# Define the tools
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates answers for the given user questions using the CSV URL.
    It uses the csv_chat function to process each question and return the answers.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    # NOTE(review): this function is registered as an agent tool in
    # create_agent; the docstring above is presumably surfaced to the model as
    # the tool description, so its wording is deliberately left as it was.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Each question is awaited to completion before the next one starts, so
    # the results come back in the same order the questions were given.
    return [
        {"question": item, "answer": await csv_chat(csv_url, item)}
        for item in user_questions
    ]
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and image URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
        chat_id (str): The ID of the current chat; passed through unchanged to csv_chart.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and image URL for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    # NOTE(review): this function is registered as an agent tool in
    # create_agent, so the docstring is presumably what the model sees as the
    # tool description. It now names the "image_url" key the code actually
    # returns (it previously said "chart_url") and documents chat_id.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Charts are produced one question at a time, in input order; whatever
    # csv_chart returns is stored under "image_url".
    return [
        {"question": item, "image_url": await csv_chart(csv_url, item, chat_id)}
        for item in user_questions
    ]
# Function to create an agent with a specific CSV URL
def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
    """Create an agent configured for one CSV file, API key, and chat.

    The CSV's basic info (from ``get_csv_basic_info``), the conversation
    history, and the chat ID are interpolated into the system prompt, and the
    two tool functions ``generate_csv_answer`` and ``generate_chart`` are
    registered on the agent.

    Args:
        csv_url: URL of the CSV file the agent should analyse.
        api_key: Gemini API key used to build the underlying model.
        conversation_history: Prior conversation, embedded verbatim in the prompt.
        chat_id: Identifier of the current chat, embedded verbatim in the prompt.

    Returns:
        An ``Agent`` built with ``retries=0`` and the system prompt below.
    """
    csv_metadata = get_csv_basic_info(csv_url)

    # The prompt body is kept flush-left on purpose: it sits inside a
    # triple-quoted string, so any leading indentation would become part of
    # the text sent to the model. The "✅" markers replace mis-encoded
    # characters ("βœ…") that appeared in the example section.
    system_prompt = f"""
# Role: Data Analyst Assistant
**Specialization:** CSV Analysis & Visualization
## Critical Rules:
### 1. Tool Usage - MANDATORY
- You MUST use `generate_csv_answer` tool for ALL data analysis questions
- You MUST use `generate_chart` tool ONLY when explicitly asked for visualization, graph, chart, or plot
- NEVER generate image markdown syntax (![...](url)) unless you have called `generate_chart` tool and received a real URL
- NEVER fabricate or create placeholder image URLs
### 2. When to Generate Visualizations
**ONLY create visualizations when the user explicitly requests:**
- "show me a chart/graph/plot"
- "visualize this data"
- "create a visualization"
- "plot the data"
- Any similar explicit visualization request
**DO NOT create visualizations for:**
- Simple data retrieval questions (e.g., "how many rows?", "what is the average?")
- Questions that can be answered with text alone
- Questions that don't explicitly ask for visual representation
### 3. Response Format
- For questions WITHOUT visualization request: Provide only the textual answer from `generate_csv_answer`
- For questions WITH visualization request: Provide both textual answer AND call `generate_chart`, then include the image using the URL returned by the tool
### 4. Output Guidelines
- Use markdown formatting for text responses
- Only include image syntax `![Description](url)` if you actually called `generate_chart` and got a real URL back
- Provide clear, concise answers with explanations
- Never mention tool names to the user
## Current Context:
- **Dataset:** {csv_url}
- **Metadata:** {csv_metadata}
- **History:** {conversation_history}
- **Chat ID:** {chat_id}
## Example Behavior:
**Question: "How many rows are in the dataset?"**
✅ Correct Response: "The dataset contains 1,000 rows."
❌ Wrong Response: "The dataset contains 1,000 rows. ![Number of Rows](url)"
**Question: "What is the average age?"**
✅ Correct Response: "The average age is 35 years."
❌ Wrong Response: "The average age is 35 years. ![Average Age](https://example.com/chart.png)"
**Question: "Show me a chart of the distribution of ages"**
✅ Correct Response: "The age distribution shows... [call generate_chart tool] ![Age Distribution](actual_url_from_tool)"
❌ Wrong Response: "The age distribution shows... ![Age Distribution](https://example.com/chart.png)"
**Remember:**
- Always generate fresh, tool-assisted responses
- Never reuse previous answers
- Never create fake image URLs
- Only use visualization when explicitly requested by the user
- Text-only answers are perfectly acceptable for most questions
- Never re-use previous image URLs without calling `generate_chart`
- Each request is unique and should be handled separately
"""
    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
        retries=0,
    )
def csv_orchestrator_chat_gemini(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
    """Answer a question about a CSV, trying each configured Gemini API key in turn.

    For every key in ``GEMINI_API_KEYS`` a fresh agent is created and run
    synchronously on the question. The first successful result is returned;
    any failure moves on to the next key.

    Args:
        csv_url: URL of the CSV file being discussed.
        user_question: The question to pass to the agent.
        conversation_history: Prior conversation, forwarded to ``create_agent``.
        chat_id: Identifier of the current chat, forwarded to ``create_agent``.

    Returns:
        ``result.data`` from the first agent run that succeeds, or ``None``
        when there are no usable keys or every key fails (despite the ``str``
        annotation, callers must handle ``None``).
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Keys are identified in logs by their 1-based position only, so the
    # secret values themselves never end up in stdout or log files.
    for position, raw_key in enumerate(GEMINI_API_KEYS, start=1):
        api_key = raw_key.strip() if raw_key else ""
        if not api_key:
            # A blank entry (unset env var, trailing comma) can never succeed.
            continue
        try:
            print(f"Attempting with API key #{position}")
            agent = create_agent(csv_url, api_key, conversation_history, chat_id)
            result = agent.run_sync(user_question)
            print("Orchestrator Result:", result.data)
            return result.data
        except ResourceExhausted:
            # The original clause was `ResourceExhausted or Exception`, which
            # evaluates to ResourceExhausted alone; spelled out plainly here.
            print(f"Quota exhausted for API key #{position}. Switching to the next key.")
        except Exception as e:
            # Deliberate best-effort boundary: any other failure just means
            # this key is skipped and the next one is tried.
            print(f"Error with API key #{position}: {e}")

    # If all keys are exhausted or fail
    print("All API keys have been exhausted or failed.")
    return None