import os
from typing import Any, List, Optional

from dotenv import load_dotenv
from google.api_core.exceptions import ResourceExhausted
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.providers.google_gla import GoogleGLAProvider

from csv_service import get_csv_basic_info
from orchestrator_functions import csv_chart, csv_chat
|
|
# Populate os.environ from a .env file so the GEMINI_* settings read below are available.
load_dotenv()
|
|
|
|
| |
| GEMINI_API_KEYS = os.getenv("GEMINI_API_KEYS", "").split(",") |
|
|
| |
def initialize_model(api_key: str) -> GeminiModel:
    """
    Build a Gemini model bound to a single Google GLA API key.

    The model name is read from the ``GEMINI_LLM_MODEL`` environment variable.

    Args:
        api_key (str): API key handed to ``GoogleGLAProvider``.

    Returns:
        GeminiModel: Model instance suitable for constructing an ``Agent``.

    Raises:
        RuntimeError: If ``GEMINI_LLM_MODEL`` is unset or blank.
    """
    model_name = (os.getenv("GEMINI_LLM_MODEL") or "").strip()
    if not model_name:
        # Fail with an actionable message rather than passing None on as the model name.
        raise RuntimeError("GEMINI_LLM_MODEL environment variable is not set")
    return GeminiModel(model_name, provider=GoogleGLAProvider(api_key=api_key))
|
|
| |
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    Answer each user question about a CSV file, one question at a time.

    Every question is forwarded to the csv_chat function together with the CSV
    URL, and each reply is paired with the question it answers.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): The questions to answer, in order.

    Returns:
        List[Dict[str, Any]]: One dictionary per question with the keys
        "question" and "answer", in the same order as user_questions.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # The async comprehension awaits each csv_chat call in turn, so questions are
    # answered sequentially and the results keep the input order.
    return [
        {"question": q, "answer": await csv_chat(csv_url, q)}
        for q in user_questions
    ]
|
|
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and chart image URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
        chat_id (str): The chat ID from the current context, passed through to csv_chart.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries with the keys "question" and
        "image_url" (the chart URL) for each question, in input order.

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    # NOTE: pydantic-ai presents this docstring to the model as the tool
    # description, so the documented key names must match the returned ones.
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    charts = []
    # Charts are generated sequentially so results line up with the input order.
    for question in user_questions:
        chart = await csv_chart(csv_url, question, chat_id)
        charts.append({"question": question, "image_url": chart})

    return charts
|
|
| |
def _build_system_prompt(csv_url: str, csv_metadata: Any, conversation_history: List, chat_id: str) -> str:
    """Render the data-analyst system prompt for one CSV / chat session."""
    # The image-markdown examples are what make the "Wrong Response" cases
    # differ from the correct ones; without them the examples contradict themselves.
    return f"""
# Role: Data Analyst Assistant
**Specialization:** CSV Analysis & Visualization

## Critical Rules:

### 1. Tool Usage - MANDATORY
- You MUST use `generate_csv_answer` tool for ALL data analysis questions
- You MUST use `generate_chart` tool ONLY when explicitly asked for visualization, graph, chart, or plot
- NEVER generate image markdown syntax (![...](...)) unless you have called `generate_chart` tool and received a real URL
- NEVER fabricate or create placeholder image URLs

### 2. When to Generate Visualizations
**ONLY create visualizations when the user explicitly requests:**
- "show me a chart/graph/plot"
- "visualize this data"
- "create a visualization"
- "plot the data"
- Any similar explicit visualization request

**DO NOT create visualizations for:**
- Simple data retrieval questions (e.g., "how many rows?", "what is the average?")
- Questions that can be answered with text alone
- Questions that don't explicitly ask for visual representation

### 3. Response Format
- For questions WITHOUT visualization request: Provide only the textual answer from `generate_csv_answer`
- For questions WITH visualization request: Provide both textual answer AND call `generate_chart`, then include the image using the URL returned by the tool

### 4. Output Guidelines
- Use markdown formatting for text responses
- Only include image syntax `![Chart](url)` if you actually called `generate_chart` and got a real URL back
- Provide clear, concise answers with explanations
- Never mention tool names to the user

## Current Context:
- **Dataset:** {csv_url}
- **Metadata:** {csv_metadata}
- **History:** {conversation_history}
- **Chat ID:** {chat_id}

## Example Behavior:

**Question: "How many rows are in the dataset?"**
✅ Correct Response: "The dataset contains 1,000 rows."
❌ Wrong Response: "The dataset contains 1,000 rows. ![Chart](https://example.com/chart.png)"

**Question: "What is the average age?"**
✅ Correct Response: "The average age is 35 years."
❌ Wrong Response: "The average age is 35 years. ![Age Chart](https://example.com/age_chart.png)"

**Question: "Show me a chart of the distribution of ages"**
✅ Correct Response: "The age distribution shows... [call generate_chart tool] ![Age Distribution](<URL returned by generate_chart>)"
❌ Wrong Response: "The age distribution shows... ![Age Distribution](https://placeholder.com/age_chart.png)"

**Remember:**
- Always generate fresh, tool-assisted responses
- Never reuse previous answers
- Never create fake image URLs
- Only use visualization when explicitly requested by the user
- Text-only answers are perfectly acceptable for most questions
- Never re-use previous image URLs without calling `generate_chart`
- Each request is unique and should be handled separately
"""


def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
    """
    Build a pydantic-ai Agent that answers questions about one CSV file.

    Args:
        csv_url (str): URL of the CSV file; used to fetch basic metadata via
            get_csv_basic_info and embedded in the system prompt.
        api_key (str): Gemini API key used to build the model.
        conversation_history (List): Prior conversation, embedded verbatim
            (via its string form) in the system prompt.
        chat_id (str): Chat identifier embedded in the system prompt —
            presumably so the model can pass it to generate_chart.

    Returns:
        Agent: An agent with generate_csv_answer and generate_chart registered
        as tools and tool retries disabled.
    """
    csv_metadata = get_csv_basic_info(csv_url)
    system_prompt = _build_system_prompt(csv_url, csv_metadata, conversation_history, chat_id)

    return Agent(
        model=initialize_model(api_key),
        deps_type=str,
        tools=[generate_csv_answer, generate_chart],
        system_prompt=system_prompt,
        retries=0,
    )
|
|
def _mask_api_key(api_key: str) -> str:
    """Return a log-safe form of ``api_key`` revealing at most its last 4 characters."""
    if len(api_key) <= 8:
        return "****"
    return f"****{api_key[-4:]}"


def csv_orchestrator_chat_gemini(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> Optional[str]:
    """
    Answer a user question about a CSV file, rotating through GEMINI_API_KEYS.

    Each non-blank key is tried in order: an agent is built for it and run
    synchronously on the question. The first successful result is returned;
    any failure moves on to the next key.

    Args:
        csv_url (str): URL of the CSV file being analysed.
        user_question (str): The user's question.
        conversation_history (List): Prior conversation, forwarded to create_agent.
        chat_id (str): Chat identifier, forwarded to create_agent.

    Returns:
        Optional[str]: ``result.data`` from the first key that succeeds, or
        None if every key fails or no keys are configured.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    for raw_key in GEMINI_API_KEYS:
        api_key = raw_key.strip()
        if not api_key:
            # A blank entry (unset variable, trailing comma) is not a usable key.
            continue
        # Never write the full secret to stdout/logs.
        masked_key = _mask_api_key(api_key)
        print(f"Attempting with API key: {masked_key}")
        try:
            agent = create_agent(csv_url, api_key, conversation_history, chat_id)
            answer = agent.run_sync(user_question).data
        except ResourceExhausted:
            # NOTE(review): pydantic-ai's Gemini model may report quota errors as its
            # own HTTP errors rather than ResourceExhausted; those fall through to the
            # generic branch below and still rotate to the next key.
            print(f"Quota exhausted for API key: {masked_key}. Switching to the next key.")
            continue
        except Exception as e:
            print(f"Error with API key {masked_key}: {e}")
            continue
        print("Orchestrator Result:", answer)
        return answer

    print("All API keys have been exhausted or failed.")
    return None