Soumik Bose commited on
Commit ·
e4f58ae
1
Parent(s): bbde124
go
Browse files- orc_agent_main_cerebras.py +40 -74
orc_agent_main_cerebras.py
CHANGED
|
@@ -1,48 +1,41 @@
|
|
| 1 |
-
import
|
| 2 |
from typing import List, Any
|
|
|
|
|
|
|
|
|
|
| 3 |
from pydantic_ai import Agent
|
|
|
|
|
|
|
|
|
|
| 4 |
from openai import RateLimitError, APIError
|
|
|
|
|
|
|
| 5 |
from csv_service import get_csv_basic_info
|
| 6 |
from orchestrator_functions import csv_chart, csv_chat
|
| 7 |
from cerebras_instance_provider import InstanceProvider
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
|
| 13 |
# Initialize the instance provider
|
| 14 |
instance_provider = InstanceProvider()
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
|
|
|
| 18 |
|
| 19 |
-
# Define the tools
|
| 20 |
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
|
| 21 |
"""
|
| 22 |
-
|
| 23 |
-
It uses the csv_chat function to process each question and return the answers.
|
| 24 |
-
|
| 25 |
-
Args:
|
| 26 |
-
csv_url (str): The URL of the CSV file.
|
| 27 |
-
user_questions (List[str]): A list of user questions.
|
| 28 |
-
|
| 29 |
-
Returns:
|
| 30 |
-
List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.
|
| 31 |
-
|
| 32 |
-
Example:
|
| 33 |
-
[
|
| 34 |
-
{"question": "What is the average age of the customers?", "answer": "The average age is 35."},
|
| 35 |
-
{"question": "What is the most common gender?", "answer": "The most common gender is Male."}
|
| 36 |
-
]
|
| 37 |
"""
|
| 38 |
-
|
| 39 |
logger.info("LLM using the csv chat function....")
|
| 40 |
logger.info(f"CSV URL: {csv_url}")
|
| 41 |
logger.info(f"User question: {user_questions}")
|
| 42 |
|
| 43 |
-
# Create an array to accumulate the answers
|
| 44 |
answers = []
|
| 45 |
-
# Loop through the user questions and generate answers for each
|
| 46 |
for question in user_questions:
|
| 47 |
answer = await csv_chat(csv_url, question)
|
| 48 |
answers.append(dict(question=question, answer=answer))
|
|
@@ -50,38 +43,22 @@ async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
|
|
| 50 |
|
| 51 |
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
|
| 52 |
"""
|
| 53 |
-
|
| 54 |
-
It uses the csv_chart function to process each question and return the chart URLs.
|
| 55 |
-
It returns a list of dictionaries containing the question and chart URL for each question.
|
| 56 |
-
|
| 57 |
-
Args:
|
| 58 |
-
csv_url (str): The URL of the CSV file.
|
| 59 |
-
user_questions (List[str]): A list of user questions.
|
| 60 |
-
chat_id (str): The chat ID for the session.
|
| 61 |
-
|
| 62 |
-
Returns:
|
| 63 |
-
List[Dict[str, Any]]: A list of dictionaries containing the question and chart URL for each question.
|
| 64 |
-
|
| 65 |
-
Example:
|
| 66 |
-
[
|
| 67 |
-
{"question": "What is the average age of the customers?", "chart_url": "https://example.com/chart1.png"},
|
| 68 |
-
{"question": "What is the most common gender?", "chart_url": "https://example.com/chart2.png"}
|
| 69 |
-
]
|
| 70 |
"""
|
| 71 |
-
|
| 72 |
logger.info("LLM using the csv chart function....")
|
| 73 |
logger.info(f"CSV URL: {csv_url}")
|
| 74 |
logger.info(f"User question: {user_questions}")
|
| 75 |
|
| 76 |
-
# Create an array to accumulate the charts
|
| 77 |
charts = []
|
| 78 |
-
# Loop through the user questions and generate charts for each
|
| 79 |
for question in user_questions:
|
| 80 |
chart = await csv_chart(csv_url, question, chat_id)
|
| 81 |
charts.append(dict(question=question, image_url=chart))
|
| 82 |
|
| 83 |
return charts
|
| 84 |
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
|
| 87 |
"""Create a PydanticAI agent configured for CSV analysis using Cerebras"""
|
|
@@ -127,29 +104,22 @@ def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id:
|
|
| 127 |
- **Metadata:** {csv_metadata}
|
| 128 |
- **History:** {conversation_history}
|
| 129 |
- **Chat ID:** {chat_id}
|
| 130 |
-
|
| 131 |
-
## Example Behavior:
|
| 132 |
-
|
| 133 |
-
**Question: "How many rows are in the dataset?"**
|
| 134 |
-
✅ Correct Response: "The dataset contains 1,000 rows."
|
| 135 |
-
❌ Wrong Response: "The dataset contains 1,000 rows. "
|
| 136 |
-
|
| 137 |
-
**Question: "Show me a chart of the distribution of ages"**
|
| 138 |
-
✅ Correct Response: "The age distribution shows... [call generate_chart tool] "
|
| 139 |
-
❌ Wrong Response: "The age distribution shows... "
|
| 140 |
-
|
| 141 |
-
**Remember:**
|
| 142 |
-
- Always generate fresh, tool-assisted responses
|
| 143 |
-
- Never reuse previous answers
|
| 144 |
-
- Never create fake image URLs
|
| 145 |
-
- Only use visualization when explicitly requested
|
| 146 |
"""
|
| 147 |
|
| 148 |
-
#
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
| 151 |
raise RuntimeError("No available API instances")
|
| 152 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
return Agent(
|
| 154 |
model=model,
|
| 155 |
deps_type=str,
|
|
@@ -158,11 +128,13 @@ def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id:
|
|
| 158 |
retries=0
|
| 159 |
)
|
| 160 |
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
|
| 163 |
"""
|
| 164 |
-
CSV orchestrator with automatic failover on 429 errors using InstanceProvider.
|
| 165 |
-
Follows the same pattern as query_csv_agent_cerebras.
|
| 166 |
"""
|
| 167 |
logger.info(f"CSV URL: {csv_url}")
|
| 168 |
logger.info(f"User questions: {user_question}")
|
|
@@ -174,10 +146,10 @@ def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversatio
|
|
| 174 |
try:
|
| 175 |
logger.info(f"Attempt {attempt + 1}/{max_attempts}")
|
| 176 |
|
| 177 |
-
# Create agent
|
| 178 |
agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
|
| 179 |
|
| 180 |
-
# Run the agent
|
| 181 |
result = agent.run_sync(user_question)
|
| 182 |
|
| 183 |
logger.info(f"✓ Success with instance {attempt + 1}")
|
|
@@ -185,30 +157,24 @@ def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversatio
|
|
| 185 |
|
| 186 |
return result.data
|
| 187 |
|
| 188 |
-
except RateLimitError
|
| 189 |
logger.error(f"✗ Rate limit (429) hit for instance {attempt + 1}/{max_attempts}")
|
| 190 |
-
|
| 191 |
if attempt == max_attempts - 1:
|
| 192 |
raise RuntimeError(f"All {max_attempts} instances failed with rate limits")
|
| 193 |
-
|
| 194 |
logger.info("Trying next instance...")
|
| 195 |
continue
|
| 196 |
|
| 197 |
except APIError as e:
|
| 198 |
logger.error(f"✗ API error with instance {attempt + 1}: {str(e)}")
|
| 199 |
-
|
| 200 |
if attempt == max_attempts - 1:
|
| 201 |
raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
|
| 202 |
-
|
| 203 |
logger.info("Trying next instance...")
|
| 204 |
continue
|
| 205 |
|
| 206 |
except Exception as e:
|
| 207 |
logger.error(f"✗ Unexpected error with instance {attempt + 1}: {str(e)}")
|
| 208 |
-
|
| 209 |
if attempt == max_attempts - 1:
|
| 210 |
raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
|
| 211 |
-
|
| 212 |
logger.info("Trying next instance...")
|
| 213 |
continue
|
| 214 |
|
|
|
|
| 1 |
+
import logging
|
| 2 |
from typing import List, Any
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
# Pydantic AI imports
|
| 6 |
from pydantic_ai import Agent
|
| 7 |
+
from pydantic_ai.models.openai import OpenAIModel # <--- Essential fix
|
| 8 |
+
|
| 9 |
+
# OpenAI imports
|
| 10 |
from openai import RateLimitError, APIError
|
| 11 |
+
|
| 12 |
+
# Local application imports
|
| 13 |
from csv_service import get_csv_basic_info
|
| 14 |
from orchestrator_functions import csv_chart, csv_chat
|
| 15 |
from cerebras_instance_provider import InstanceProvider
|
| 16 |
+
|
| 17 |
+
# Setup logging
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
load_dotenv()
|
| 22 |
|
| 23 |
# Initialize the instance provider
|
| 24 |
instance_provider = InstanceProvider()
|
| 25 |
|
| 26 |
+
# ------------------------------------------------------------------
|
| 27 |
+
# 1. DEFINE TOOLS
|
| 28 |
+
# ------------------------------------------------------------------
|
| 29 |
|
|
|
|
| 30 |
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
|
| 31 |
"""
|
| 32 |
+
Generates answers for user questions using the CSV URL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""
|
|
|
|
| 34 |
logger.info("LLM using the csv chat function....")
|
| 35 |
logger.info(f"CSV URL: {csv_url}")
|
| 36 |
logger.info(f"User question: {user_questions}")
|
| 37 |
|
|
|
|
| 38 |
answers = []
|
|
|
|
| 39 |
for question in user_questions:
|
| 40 |
answer = await csv_chat(csv_url, question)
|
| 41 |
answers.append(dict(question=question, answer=answer))
|
|
|
|
| 43 |
|
| 44 |
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
|
| 45 |
"""
|
| 46 |
+
Generates charts for user questions using the CSV URL.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
"""
|
|
|
|
| 48 |
logger.info("LLM using the csv chart function....")
|
| 49 |
logger.info(f"CSV URL: {csv_url}")
|
| 50 |
logger.info(f"User question: {user_questions}")
|
| 51 |
|
|
|
|
| 52 |
charts = []
|
|
|
|
| 53 |
for question in user_questions:
|
| 54 |
chart = await csv_chart(csv_url, question, chat_id)
|
| 55 |
charts.append(dict(question=question, image_url=chart))
|
| 56 |
|
| 57 |
return charts
|
| 58 |
|
| 59 |
+
# ------------------------------------------------------------------
|
| 60 |
+
# 2. AGENT CREATION (FIXED)
|
| 61 |
+
# ------------------------------------------------------------------
|
| 62 |
|
| 63 |
def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
|
| 64 |
"""Create a PydanticAI agent configured for CSV analysis using Cerebras"""
|
|
|
|
| 104 |
- **Metadata:** {csv_metadata}
|
| 105 |
- **History:** {conversation_history}
|
| 106 |
- **Chat ID:** {chat_id}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
"""
|
| 108 |
|
| 109 |
+
# ---------------------------------------------------------
|
| 110 |
+
# FIX: Unpack tuple and use OpenAIModel wrapper
|
| 111 |
+
# ---------------------------------------------------------
|
| 112 |
+
instance_data = instance_provider.get_next_instance()
|
| 113 |
+
|
| 114 |
+
if instance_data is None:
|
| 115 |
raise RuntimeError("No available API instances")
|
| 116 |
|
| 117 |
+
# Unpack the tuple (client, model_name)
|
| 118 |
+
client, model_name = instance_data
|
| 119 |
+
|
| 120 |
+
# Create the Pydantic AI Model using the specific client for this key
|
| 121 |
+
model = OpenAIModel(model_name, openai_client=client)
|
| 122 |
+
|
| 123 |
return Agent(
|
| 124 |
model=model,
|
| 125 |
deps_type=str,
|
|
|
|
| 128 |
retries=0
|
| 129 |
)
|
| 130 |
|
| 131 |
+
# ------------------------------------------------------------------
|
| 132 |
+
# 3. ORCHESTRATOR LOGIC (RETRY/FAILOVER)
|
| 133 |
+
# ------------------------------------------------------------------
|
| 134 |
|
| 135 |
def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
|
| 136 |
"""
|
| 137 |
+
CSV orchestrator with automatic failover on 429/API errors using InstanceProvider.
|
|
|
|
| 138 |
"""
|
| 139 |
logger.info(f"CSV URL: {csv_url}")
|
| 140 |
logger.info(f"User questions: {user_question}")
|
|
|
|
| 146 |
try:
|
| 147 |
logger.info(f"Attempt {attempt + 1}/{max_attempts}")
|
| 148 |
|
| 149 |
+
# Create agent (this internally rotates to the next key)
|
| 150 |
agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
|
| 151 |
|
| 152 |
+
# Run the agent
|
| 153 |
result = agent.run_sync(user_question)
|
| 154 |
|
| 155 |
logger.info(f"✓ Success with instance {attempt + 1}")
|
|
|
|
| 157 |
|
| 158 |
return result.data
|
| 159 |
|
| 160 |
+
except RateLimitError:
|
| 161 |
logger.error(f"✗ Rate limit (429) hit for instance {attempt + 1}/{max_attempts}")
|
|
|
|
| 162 |
if attempt == max_attempts - 1:
|
| 163 |
raise RuntimeError(f"All {max_attempts} instances failed with rate limits")
|
|
|
|
| 164 |
logger.info("Trying next instance...")
|
| 165 |
continue
|
| 166 |
|
| 167 |
except APIError as e:
|
| 168 |
logger.error(f"✗ API error with instance {attempt + 1}: {str(e)}")
|
|
|
|
| 169 |
if attempt == max_attempts - 1:
|
| 170 |
raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
|
|
|
|
| 171 |
logger.info("Trying next instance...")
|
| 172 |
continue
|
| 173 |
|
| 174 |
except Exception as e:
|
| 175 |
logger.error(f"✗ Unexpected error with instance {attempt + 1}: {str(e)}")
|
|
|
|
| 176 |
if attempt == max_attempts - 1:
|
| 177 |
raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
|
|
|
|
| 178 |
logger.info("Trying next instance...")
|
| 179 |
continue
|
| 180 |
|