Soumik Bose commited on
Commit
60aa09c
·
1 Parent(s): e4f58ae
cerebras_instance_provider.py CHANGED
@@ -1,85 +1,50 @@
1
- # instance_provider.py
2
  import os
3
  import logging
4
- from typing import List, Optional, Tuple
5
- from openai import OpenAI
6
  from dotenv import load_dotenv
7
 
8
  load_dotenv()
9
 
10
- # Setup basic logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
  class InstanceProvider:
15
- """Manages multiple Cerebras/OpenAI clients with simple rotation"""
16
 
17
  def __init__(self):
18
- self.clients: List[OpenAI] = []
19
  self.current_index = 0
20
- self.model_name = os.getenv("CEREBRAS_MODEL") or "llama3.1-70b"
 
21
  self._initialize_instances()
22
 
23
  def _initialize_instances(self):
24
- """Load all API keys and create OpenAI clients"""
25
- # Split keys by comma
26
- api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
27
- base_url = os.getenv("CEREBRAS_BASE_URL")
28
 
29
- for key in api_keys:
30
- key = key.strip()
31
- if key:
32
- try:
33
- # Create a standard OpenAI client for this key
34
- client = OpenAI(
35
- base_url=base_url,
36
- api_key=key
37
- )
38
- self.clients.append(client)
39
- except Exception as e:
40
- logger.error(f"Failed to initialize key {key[:4]}...: {e}")
41
 
42
- def get_next_instance(self) -> Optional[Tuple[OpenAI, str]]:
43
  """
44
- Get next client in rotation.
45
- Returns: Tuple (OpenAI_Client, Model_Name)
46
  """
47
- if not self.clients:
48
  return None
49
 
50
- # Get current client
51
- client = self.clients[self.current_index]
52
 
53
- # Rotate index for the next call (Round Robin)
54
- self.current_index = (self.current_index + 1) % len(self.clients)
55
 
56
- return client, self.model_name
 
 
 
 
57
 
58
  def get_total_instances(self) -> int:
59
- """Return total number of active clients available"""
60
- return len(self.clients)
61
-
62
- def chat_completion_with_retry(self, messages: list, **kwargs):
63
- """
64
- Helper function that automatically retries across all instances
65
- if one fails.
66
- """
67
- total_attempts = self.get_total_instances()
68
-
69
- for attempt in range(total_attempts):
70
- client, model = self.get_next_instance()
71
-
72
- try:
73
- # Execute the API call
74
- response = client.chat.completions.create(
75
- model=model,
76
- messages=messages,
77
- **kwargs
78
- )
79
- return response
80
- except Exception as e:
81
- logger.warning(f"Instance failed (Attempt {attempt+1}/{total_attempts}): {e}")
82
- # Loop continues to next instance automatically
83
- continue
84
-
85
- raise RuntimeError(f"All {total_attempts} instances failed.")
 
 
1
  import os
2
  import logging
3
+ from typing import List, Optional, Dict
 
4
  from dotenv import load_dotenv
5
 
6
  load_dotenv()
7
 
 
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
  class InstanceProvider:
12
+ """Manages multiple Cerebras API keys with simple rotation"""
13
 
14
  def __init__(self):
15
+ self.api_keys: List[str] = []
16
  self.current_index = 0
17
+ self.base_url = os.getenv("CEREBRAS_BASE_URL")
18
+ self.model_name = os.getenv("CEREBRAS_MODEL", "llama3.1-70b")
19
  self._initialize_instances()
20
 
21
  def _initialize_instances(self):
22
+ """Load all API keys into a list"""
23
+ keys_str = os.getenv("CEREBRAS_API_KEYS", "")
24
+ self.api_keys = [k.strip() for k in keys_str.split(",") if k.strip()]
 
25
 
26
+ if not self.api_keys:
27
+ logger.error("No API keys found in CEREBRAS_API_KEYS")
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def get_next_instance(self) -> Optional[Dict[str, str]]:
30
  """
31
+ Returns a dictionary with the credentials for the next instance.
32
+ Returns: {'api_key': str, 'base_url': str, 'model': str}
33
  """
34
+ if not self.api_keys:
35
  return None
36
 
37
+ # Get current key
38
+ key = self.api_keys[self.current_index]
39
 
40
+ # Rotate index for the next call
41
+ self.current_index = (self.current_index + 1) % len(self.api_keys)
42
 
43
+ return {
44
+ "api_key": key,
45
+ "base_url": self.base_url,
46
+ "model": self.model_name
47
+ }
48
 
49
  def get_total_instances(self) -> int:
50
+ return len(self.api_keys)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
orc_agent_main_cerebras.py CHANGED
@@ -1,20 +1,20 @@
 
1
  import logging
2
- from typing import List, Any
3
- from dotenv import load_dotenv
4
 
5
  # Pydantic AI imports
6
  from pydantic_ai import Agent
7
- from pydantic_ai.models.openai import OpenAIModel # <--- Essential fix
8
 
9
- # OpenAI imports
10
  from openai import RateLimitError, APIError
11
 
12
  # Local application imports
13
  from csv_service import get_csv_basic_info
14
  from orchestrator_functions import csv_chart, csv_chat
15
  from cerebras_instance_provider import InstanceProvider
 
16
 
17
- # Setup logging
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
@@ -28,13 +28,8 @@ instance_provider = InstanceProvider()
28
  # ------------------------------------------------------------------
29
 
30
  async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
31
- """
32
- Generates answers for user questions using the CSV URL.
33
- """
34
- logger.info("LLM using the csv chat function....")
35
- logger.info(f"CSV URL: {csv_url}")
36
- logger.info(f"User question: {user_questions}")
37
-
38
  answers = []
39
  for question in user_questions:
40
  answer = await csv_chat(csv_url, question)
@@ -42,140 +37,95 @@ async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
42
  return answers
43
 
44
  async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
45
- """
46
- Generates charts for user questions using the CSV URL.
47
- """
48
- logger.info("LLM using the csv chart function....")
49
- logger.info(f"CSV URL: {csv_url}")
50
- logger.info(f"User question: {user_questions}")
51
-
52
  charts = []
53
  for question in user_questions:
54
- chart = await csv_chart(csv_url, question, chat_id)
55
- charts.append(dict(question=question, image_url=chart))
56
-
57
  return charts
58
 
59
  # ------------------------------------------------------------------
60
- # 2. AGENT CREATION (FIXED)
61
  # ------------------------------------------------------------------
62
 
63
  def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
64
- """Create a PydanticAI agent configured for CSV analysis using Cerebras"""
65
- csv_metadata = get_csv_basic_info(csv_url)
66
 
67
- system_prompt = f"""
68
- # Role: Data Analyst Assistant
69
- **Specialization:** CSV Analysis & Visualization
70
-
71
- ## Critical Rules:
72
-
73
- ### 1. Tool Usage - MANDATORY
74
- - You MUST use `generate_csv_answer` tool for ALL data analysis questions
75
- - You MUST use `generate_chart` tool ONLY when explicitly asked for visualization, graph, chart, or plot
76
- - NEVER generate image markdown syntax (![...](url)) unless you have called `generate_chart` tool and received a real URL
77
- - NEVER fabricate or create placeholder image URLs
78
-
79
- ### 2. When to Generate Visualizations
80
- **ONLY create visualizations when the user explicitly requests:**
81
- - "show me a chart/graph/plot"
82
- - "visualize this data"
83
- - "create a visualization"
84
- - "plot the data"
85
- - Any similar explicit visualization request
86
-
87
- **DO NOT create visualizations for:**
88
- - Simple data retrieval questions (e.g., "how many rows?", "what is the average?")
89
- - Questions that can be answered with text alone
90
- - Questions that don't explicitly ask for visual representation
91
-
92
- ### 3. Response Format
93
- - For questions WITHOUT visualization request: Provide only the textual answer from `generate_csv_answer`
94
- - For questions WITH visualization request: Provide both textual answer AND call `generate_chart`, then include the image using the URL returned by the tool
95
-
96
- ### 4. Output Guidelines
97
- - Use markdown formatting for text responses
98
- - Only include image syntax `![Description](url)` if you actually called `generate_chart` and got a real URL back
99
- - Provide clear, concise answers with explanations
100
- - Never mention tool names to the user
101
-
102
- ## Current Context:
103
- - **Dataset:** {csv_url}
104
- - **Metadata:** {csv_metadata}
105
- - **History:** {conversation_history}
106
- - **Chat ID:** {chat_id}
107
- """
108
 
109
- # ---------------------------------------------------------
110
- # FIX: Unpack tuple and use OpenAIModel wrapper
111
- # ---------------------------------------------------------
112
- instance_data = instance_provider.get_next_instance()
113
 
114
- if instance_data is None:
115
- raise RuntimeError("No available API instances")
 
 
 
 
 
116
 
117
- # Unpack the tuple (client, model_name)
118
- client, model_name = instance_data
119
 
120
- # Create the Pydantic AI Model using the specific client for this key
121
- model = OpenAIModel(model_name, openai_client=client)
 
 
 
 
 
 
 
 
 
122
 
123
  return Agent(
124
  model=model,
125
  deps_type=str,
126
  tools=[generate_csv_answer, generate_chart],
127
  system_prompt=system_prompt,
128
- retries=0
129
  )
130
 
131
  # ------------------------------------------------------------------
132
- # 3. ORCHESTRATOR LOGIC (RETRY/FAILOVER)
133
  # ------------------------------------------------------------------
134
 
135
  def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
136
  """
137
- CSV orchestrator with automatic failover on 429/API errors using InstanceProvider.
138
  """
139
- logger.info(f"CSV URL: {csv_url}")
140
- logger.info(f"User questions: {user_question}")
141
 
142
  max_attempts = instance_provider.get_total_instances()
143
-
144
- # Try with different instances until one works
 
 
145
  for attempt in range(max_attempts):
146
  try:
147
- logger.info(f"Attempt {attempt + 1}/{max_attempts}")
148
-
149
- # Create agent (this internally rotates to the next key)
150
  agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
151
 
152
- # Run the agent
 
 
153
  result = agent.run_sync(user_question)
154
 
155
- logger.info(f"✓ Success with instance {attempt + 1}")
156
- logger.info(f"Orchestrator Result: {result.data}")
157
-
158
  return result.data
159
 
160
- except RateLimitError:
161
- logger.error(f" Rate limit (429) hit for instance {attempt + 1}/{max_attempts}")
162
- if attempt == max_attempts - 1:
163
- raise RuntimeError(f"All {max_attempts} instances failed with rate limits")
164
- logger.info("Trying next instance...")
165
- continue
166
-
167
- except APIError as e:
168
- logger.error(f"✗ API error with instance {attempt + 1}: {str(e)}")
169
- if attempt == max_attempts - 1:
170
- raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
171
- logger.info("Trying next instance...")
172
  continue
173
 
174
  except Exception as e:
175
- logger.error(f"✗ Unexpected error with instance {attempt + 1}: {str(e)}")
176
- if attempt == max_attempts - 1:
177
- raise RuntimeError(f"All {max_attempts} instances failed. Last error: {str(e)}")
178
- logger.info("Trying next instance...")
179
  continue
180
 
181
- raise RuntimeError(f"Failed after {max_attempts} attempts")
 
1
+ import os
2
  import logging
3
+ from typing import List, Any, Dict
 
4
 
5
  # Pydantic AI imports
6
  from pydantic_ai import Agent
7
+ from pydantic_ai.models.openai import OpenAIModel
8
 
9
+ # Error handling
10
  from openai import RateLimitError, APIError
11
 
12
  # Local application imports
13
  from csv_service import get_csv_basic_info
14
  from orchestrator_functions import csv_chart, csv_chat
15
  from cerebras_instance_provider import InstanceProvider
16
+ from dotenv import load_dotenv
17
 
 
18
  logging.basicConfig(level=logging.INFO)
19
  logger = logging.getLogger(__name__)
20
 
 
28
  # ------------------------------------------------------------------
29
 
30
  async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
31
+ """Generates answers for user questions using the CSV URL."""
32
+ logger.info(f"Tool: generate_csv_answer | Questions: {user_questions}")
 
 
 
 
 
33
  answers = []
34
  for question in user_questions:
35
  answer = await csv_chat(csv_url, question)
 
37
  return answers
38
 
39
  async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
40
+ """Generates charts for user questions using the CSV URL."""
41
+ logger.info(f"Tool: generate_chart | Questions: {user_questions}")
 
 
 
 
 
42
  charts = []
43
  for question in user_questions:
44
+ chart_url = await csv_chart(csv_url, question, chat_id)
45
+ charts.append(dict(question=question, image_url=chart_url))
 
46
  return charts
47
 
48
  # ------------------------------------------------------------------
49
+ # 2. AGENT CREATION
50
  # ------------------------------------------------------------------
51
 
52
  def create_orchestrator_agent(csv_url: str, conversation_history: List, chat_id: str) -> Agent:
53
+ """Create a PydanticAI agent with a specific API Key instance"""
 
54
 
55
+ # 1. Get credentials dictionary from provider
56
+ instance_config = instance_provider.get_next_instance()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ if instance_config is None:
59
+ raise RuntimeError("No available API instances (Check CEREBRAS_API_KEYS)")
 
 
60
 
61
+ # 2. Create the Model using standard arguments.
62
+ # We pass api_key and base_url directly. PydanticAI will handle the client creation.
63
+ model = OpenAIModel(
64
+ instance_config['model'],
65
+ base_url=instance_config['base_url'],
66
+ api_key=instance_config['api_key'],
67
+ )
68
 
69
+ csv_metadata = get_csv_basic_info(csv_url)
 
70
 
71
+ system_prompt = f"""
72
+ # Role: Data Analyst Assistant
73
+ **Context:** Analyzing CSV: {csv_url}
74
+ **Metadata:** {csv_metadata}
75
+ **Chat ID:** {chat_id}
76
+
77
+ ## Rules:
78
+ 1. Use `generate_csv_answer` for text questions.
79
+ 2. Use `generate_chart` ONLY if explicitly asked for visual/plot/graph.
80
+ 3. Output format: Markdown. If chart generated, use ![Desc](url).
81
+ """
82
 
83
  return Agent(
84
  model=model,
85
  deps_type=str,
86
  tools=[generate_csv_answer, generate_chart],
87
  system_prompt=system_prompt,
88
+ retries=0 # We handle retries manually in the loop below
89
  )
90
 
91
  # ------------------------------------------------------------------
92
+ # 3. ORCHESTRATOR LOGIC (RETRY LOOP)
93
  # ------------------------------------------------------------------
94
 
95
  def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
96
  """
97
+ Orchestrator that rebuilds the agent with a new key upon failure.
98
  """
99
+ logger.info(f"Starting Orchestrator | Query: {user_question}")
 
100
 
101
  max_attempts = instance_provider.get_total_instances()
102
+ if max_attempts == 0:
103
+ return "System Error: No API keys configured."
104
+
105
+ # Loop through available keys
106
  for attempt in range(max_attempts):
107
  try:
108
+ # 1. Create a NEW agent (this fetches the NEXT key automatically)
 
 
109
  agent = create_orchestrator_agent(csv_url, conversation_history, chat_id)
110
 
111
+ logger.info(f"Attempt {attempt + 1}/{max_attempts} using key ending in ...{agent.model.client.api_key[-4:] if hasattr(agent.model, 'client') else '****'}")
112
+
113
+ # 2. Run the agent
114
  result = agent.run_sync(user_question)
115
 
116
+ logger.info(f"✓ Success on attempt {attempt + 1}")
 
 
117
  return result.data
118
 
119
+ except (RateLimitError, APIError) as e:
120
+ logger.warning(f" API Error on attempt {attempt + 1}: {e}")
121
+ logger.info("Rotating to next instance...")
122
+ # The loop continues, calling create_orchestrator_agent() again, getting the next key.
 
 
 
 
 
 
 
 
123
  continue
124
 
125
  except Exception as e:
126
+ # Catch unexpected Pydantic/Python errors
127
+ logger.error(f"✗ Unexpected Error on attempt {attempt + 1}: {e}")
128
+ logger.info("Rotating to next instance...")
 
129
  continue
130
 
131
+ raise RuntimeError(f"Failed to generate response after {max_attempts} attempts.")