Soumik Bose committed on
Commit
e4f58ae
·
1 Parent(s): bbde124
Files changed (1) hide show
  1. orc_agent_main_cerebras.py +40 -74
orc_agent_main_cerebras.py CHANGED
@@ -1,48 +1,41 @@
1
- import os
2
  from typing import List, Any
 
 
 
3
  from pydantic_ai import Agent
 
 
 
4
  from openai import RateLimitError, APIError
 
 
5
  from csv_service import get_csv_basic_info
6
  from orchestrator_functions import csv_chart, csv_chat
7
  from cerebras_instance_provider import InstanceProvider
8
- from dotenv import load_dotenv
9
- import logging
 
 
10
 
11
  load_dotenv()
12
 
13
  # Initialize the instance provider
14
  instance_provider = InstanceProvider()
15
 
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
 
18
 
19
- # Define the tools
20
  async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
21
  """
22
- This function generates answers for the given user questions using the CSV URL.
23
- It uses the csv_chat function to process each question and return the answers.
24
-
25
- Args:
26
- csv_url (str): The URL of the CSV file.
27
- user_questions (List[str]): A list of user questions.
28
-
29
- Returns:
30
- List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.
31
-
32
- Example:
33
- [
34
- {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
35
- {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
36
- ]
37
  """
38
-
39
  logger.info("LLM using the csv chat function....")
40
  logger.info(f"CSV URL: {csv_url}")
41
  logger.info(f"User question: {user_questions}")
42
 
43
- # Create an array to accumulate the answers
44
  answers = []
45
- # Loop through the user questions and generate answers for each
46
  for question in user_questions:
47
  answer = await csv_chat(csv_url, question)
48
  answers.append(dict(question=question, answer=answer))
@@ -50,38 +43,22 @@ async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
50
 
51
  async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
52
  """
53
- This function generates charts for the given user questions using the CSV URL.
54
- It uses the csv_chart function to process each question and return the chart URLs.
55
- It returns a list of dictionaries containing the question and chart URL for each question.
56
-
57
- Args:
58
- csv_url (str): The URL of the CSV file.
59
- user_questions (List[str]): A list of user questions.
60
- chat_id (str): The chat ID for the session.
61
-
62
- Returns:
63
- List[Dict[str, Any]]: A list of dictionaries containing the question and chart URL for each question.
64
-
65
- Example:
66
- [
67
- {"question": "What is the average age of the customers?", "chart_url": "https://example.com/chart1.png"},
68
- {"question": "What is the most common gender?", "chart_url": "https://example.com/chart2.png"}
69
- ]
70
  """
71
-
72
  logger.info("LLM using the csv chart function....")
73
  logger.info(f"CSV URL: {csv_url}")
74
  logger.info(f"User question: {user_questions}")
75
 
76
- # Create an array to accumulate the charts
77
  charts = []
78
- # Loop through the user questions and generate charts for each
79
  for question in user_questions:
80
  chart = await csv_chart(csv_url, question, chat_id)
81
  charts.append(dict(question=question, image_url=chart))
82
 
83
  return charts
84
 
 
 
 
85
 
86
  def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
87
  """Create a PydanticAI agent configured for CSV analysis using Cerebras"""
@@ -127,29 +104,22 @@ def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id:
127
  - **Metadata:** {csv_metadata}
128
  - **History:** {conversation_history}
129
  - **Chat ID:** {chat_id}
130
-
131
- ## Example Behavior:
132
-
133
- **Question: "How many rows are in the dataset?"**
134
- ✅ Correct Response: "The dataset contains 1,000 rows."
135
- ❌ Wrong Response: "The dataset contains 1,000 rows. ![Number of Rows](url)"
136
-
137
- **Question: "Show me a chart of the distribution of ages"**
138
- ✅ Correct Response: "The age distribution shows... [call generate_chart tool] ![Age Distribution](actual_url_from_tool)"
139
- ❌ Wrong Response: "The age distribution shows... ![Age Distribution](https://example.com/chart.png)"
140
-
141
- **Remember:**
142
- - Always generate fresh, tool-assisted responses
143
- - Never reuse previous answers
144
- - Never create fake image URLs
145
- - Only use visualization when explicitly requested
146
  """
147
 
148
- # Get next available model instance
149
- model = instance_provider.get_next_instance()
150
- if model is None:
 
 
 
151
  raise RuntimeError("No available API instances")
152
 
 
 
 
 
 
 
153
  return Agent(
154
  model=model,
155
  deps_type=str,
@@ -158,11 +128,13 @@ def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id:
158
  retries=0
159
  )
160
 
 
 
 
161
 
162
  def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
163
  """
164
- CSV orchestrator with automatic failover on 429 errors using InstanceProvider.
165
- Follows the same pattern as query_csv_agent_cerebras.
166
  """
167
  logger.info(f"CSV URL: {csv_url}")
168
  logger.info(f"User questions: {user_question}")
@@ -174,10 +146,10 @@ def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversatio
174
  try:
175
  logger.info(f"Attempt {attempt + 1}/{max_attempts}")
176
 
177
- # Create agent with next instance
178
  agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
179
 
180
- # Run the agent - this is where rate limits typically occur
181
  result = agent.run_sync(user_question)
182
 
183
  logger.info(f"✓ Success with instance {attempt + 1}")
@@ -185,30 +157,24 @@ def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversatio
185
 
186
  return result.data
187
 
188
- except RateLimitError as e:
189
  logger.error(f"✗ Rate limit (429) hit for instance {attempt + 1}/{max_attempts}")
190
-
191
  if attempt == max_attempts - 1:
192
  raise RuntimeError(f"All {max_attempts} instances failed with rate limits")
193
-
194
  logger.info("Trying next instance...")
195
  continue
196
 
197
  except APIError as e:
198
  logger.error(f"✗ API error with instance {attempt + 1}: {str(e)}")
199
-
200
  if attempt == max_attempts - 1:
201
  raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
202
-
203
  logger.info("Trying next instance...")
204
  continue
205
 
206
  except Exception as e:
207
  logger.error(f"✗ Unexpected error with instance {attempt + 1}: {str(e)}")
208
-
209
  if attempt == max_attempts - 1:
210
  raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
211
-
212
  logger.info("Trying next instance...")
213
  continue
214
 
 
1
+ import logging
2
  from typing import List, Any
3
+ from dotenv import load_dotenv
4
+
5
+ # Pydantic AI imports
6
  from pydantic_ai import Agent
7
+ from pydantic_ai.models.openai import OpenAIModel # <--- Essential fix
8
+
9
+ # OpenAI imports
10
  from openai import RateLimitError, APIError
11
+
12
+ # Local application imports
13
  from csv_service import get_csv_basic_info
14
  from orchestrator_functions import csv_chart, csv_chat
15
  from cerebras_instance_provider import InstanceProvider
16
+
17
+ # Setup logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
 
21
  load_dotenv()
22
 
23
  # Initialize the instance provider
24
  instance_provider = InstanceProvider()
25
 
26
+ # ------------------------------------------------------------------
27
+ # 1. DEFINE TOOLS
28
+ # ------------------------------------------------------------------
29
 
 
30
  async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
31
  """
32
+ Generates answers for user questions using the CSV URL.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  """
 
34
  logger.info("LLM using the csv chat function....")
35
  logger.info(f"CSV URL: {csv_url}")
36
  logger.info(f"User question: {user_questions}")
37
 
 
38
  answers = []
 
39
  for question in user_questions:
40
  answer = await csv_chat(csv_url, question)
41
  answers.append(dict(question=question, answer=answer))
 
43
 
44
  async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
45
  """
46
+ Generates charts for user questions using the CSV URL.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  """
 
48
  logger.info("LLM using the csv chart function....")
49
  logger.info(f"CSV URL: {csv_url}")
50
  logger.info(f"User question: {user_questions}")
51
 
 
52
  charts = []
 
53
  for question in user_questions:
54
  chart = await csv_chart(csv_url, question, chat_id)
55
  charts.append(dict(question=question, image_url=chart))
56
 
57
  return charts
58
 
59
+ # ------------------------------------------------------------------
60
+ # 2. AGENT CREATION (FIXED)
61
+ # ------------------------------------------------------------------
62
 
63
  def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
64
  """Create a PydanticAI agent configured for CSV analysis using Cerebras"""
 
104
  - **Metadata:** {csv_metadata}
105
  - **History:** {conversation_history}
106
  - **Chat ID:** {chat_id}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  """
108
 
109
+ # ---------------------------------------------------------
110
+ # FIX: Unpack tuple and use OpenAIModel wrapper
111
+ # ---------------------------------------------------------
112
+ instance_data = instance_provider.get_next_instance()
113
+
114
+ if instance_data is None:
115
  raise RuntimeError("No available API instances")
116
 
117
+ # Unpack the tuple (client, model_name)
118
+ client, model_name = instance_data
119
+
120
+ # Create the Pydantic AI Model using the specific client for this key
121
+ model = OpenAIModel(model_name, openai_client=client)
122
+
123
  return Agent(
124
  model=model,
125
  deps_type=str,
 
128
  retries=0
129
  )
130
 
131
+ # ------------------------------------------------------------------
132
+ # 3. ORCHESTRATOR LOGIC (RETRY/FAILOVER)
133
+ # ------------------------------------------------------------------
134
 
135
  def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
136
  """
137
+ CSV orchestrator with automatic failover on 429/API errors using InstanceProvider.
 
138
  """
139
  logger.info(f"CSV URL: {csv_url}")
140
  logger.info(f"User questions: {user_question}")
 
146
  try:
147
  logger.info(f"Attempt {attempt + 1}/{max_attempts}")
148
 
149
+ # Create agent (this internally rotates to the next key)
150
  agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
151
 
152
+ # Run the agent
153
  result = agent.run_sync(user_question)
154
 
155
  logger.info(f"✓ Success with instance {attempt + 1}")
 
157
 
158
  return result.data
159
 
160
+ except RateLimitError:
161
  logger.error(f"✗ Rate limit (429) hit for instance {attempt + 1}/{max_attempts}")
 
162
  if attempt == max_attempts - 1:
163
  raise RuntimeError(f"All {max_attempts} instances failed with rate limits")
 
164
  logger.info("Trying next instance...")
165
  continue
166
 
167
  except APIError as e:
168
  logger.error(f"✗ API error with instance {attempt + 1}: {str(e)}")
 
169
  if attempt == max_attempts - 1:
170
  raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
 
171
  logger.info("Trying next instance...")
172
  continue
173
 
174
  except Exception as e:
175
  logger.error(f"✗ Unexpected error with instance {attempt + 1}: {str(e)}")
 
176
  if attempt == max_attempts - 1:
177
  raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
 
178
  logger.info("Trying next instance...")
179
  continue
180