Soumik555 committed on
Commit
88f85df
·
1 Parent(s): a7c62a2

changed model to .env gemini-flash-2.0

Browse files
controller.py CHANGED
@@ -28,7 +28,8 @@ import matplotlib
28
  import seaborn as sns
29
  from gemini_report_generator import generate_csv_report_gemini
30
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
31
- from orchestrator_agent import csv_orchestrator_chat
 
32
  from python_code_executor_service import CsvChatResult, PythonExecutor
33
  from supabase_service import upload_file_to_supabase
34
  from cerebras_csv_agent import query_csv_agent
@@ -401,7 +402,7 @@ async def csv_chat(request: Dict, authorization: str = Header(None)):
401
  logger.info("Processing detailed answer with orchestrator...")
402
  try:
403
  orchestrator_answer = await asyncio.to_thread(
404
- csv_orchestrator_chat, decoded_url, query, conversation_history, chat_id
405
  )
406
  if orchestrator_answer is not None:
407
  logger.info(f"Orchestrator answer successful: {str(orchestrator_answer)[:200]}...")
@@ -676,7 +677,7 @@ async def csv_chart(request: dict, authorization: str = Header(None)):
676
  # Use orchestrator to handle the user's chart query first
677
  if detailed_answer is True:
678
  orchestrator_answer = await asyncio.to_thread(
679
- csv_orchestrator_chat, csv_url, query, conversation_history, chat_id
680
  )
681
 
682
  if orchestrator_answer is not None:
 
28
  import seaborn as sns
29
  from gemini_report_generator import generate_csv_report_gemini
30
  from intitial_q_handler import if_initial_chart_question, if_initial_chat_question
31
+ from orc_agent_main_cerebras import csv_orchestrator_chat_cerebras
32
+ from orchestrator_agent import csv_orchestrator_chat_gemini
33
  from python_code_executor_service import CsvChatResult, PythonExecutor
34
  from supabase_service import upload_file_to_supabase
35
  from cerebras_csv_agent import query_csv_agent
 
402
  logger.info("Processing detailed answer with orchestrator...")
403
  try:
404
  orchestrator_answer = await asyncio.to_thread(
405
+ csv_orchestrator_chat_cerebras, decoded_url, query, conversation_history, chat_id
406
  )
407
  if orchestrator_answer is not None:
408
  logger.info(f"Orchestrator answer successful: {str(orchestrator_answer)[:200]}...")
 
677
  # Use orchestrator to handle the user's chart query first
678
  if detailed_answer is True:
679
  orchestrator_answer = await asyncio.to_thread(
680
+ csv_orchestrator_chat_gemini, csv_url, query, conversation_history, chat_id
681
  )
682
 
683
  if orchestrator_answer is not None:
orc_agent_main_cerebras.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Any
3
+ from pydantic_ai import Agent
4
+ from pydantic_ai.models.openai import OpenAIChatModel
5
+ from openai import RateLimitError, APIError
6
+ from csv_service import get_csv_basic_info
7
+ from orchestrator_functions import csv_chart, csv_chat
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
# Cerebras connection settings, read from the environment once at import time.
# CEREBRAS_API_KEYS is expected to be a comma-separated list of keys. Whitespace
# around each key is stripped and empty entries are dropped, so an unset or
# blank variable yields an empty list instead of [''].
CEREBRAS_API_KEYS = [
    key.strip()
    for key in os.getenv("CEREBRAS_API_KEYS", "").split(",")
    if key.strip()
]
CEREBRAS_BASE_URL = os.getenv("CEREBRAS_BASE_URL")  # Cerebras API base URL (None when unset)
CEREBRAS_MODEL = os.getenv("CEREBRAS_MODEL")  # Default Cerebras model (None when unset)
16
+
17
+ # Function to initialize the model with a specific API key
18
def initialize_model(api_key: str) -> OpenAIChatModel:
    """Initialize the Cerebras model using the OpenAI-compatible interface.

    Args:
        api_key: Cerebras API key to use for this model instance.

    Returns:
        An ``OpenAIChatModel`` for ``CEREBRAS_MODEL`` that talks to
        ``CEREBRAS_BASE_URL`` with the given key.
    """
    # Imported locally so this block does not rely on a new file-level import.
    from pydantic_ai.providers.openai import OpenAIProvider

    # OpenAIChatModel receives endpoint and credentials through a provider
    # object; it does not accept base_url/api_key keyword arguments directly.
    provider = OpenAIProvider(base_url=CEREBRAS_BASE_URL, api_key=api_key)
    return OpenAIChatModel(CEREBRAS_MODEL, provider=provider)
25
+
26
+ # Define the tools
27
async def generate_csv_answer(csv_url: str, user_questions: List[str]) -> Any:
    """
    This function generates answers for the given user questions using the CSV URL.
    It uses the csv_chat function to process each question and return the answers.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the question and answer for each question.

    Example:
        [
            {"question": "What is the average age of the customers?", "answer": "The average age is 35."},
            {"question": "What is the most common gender?", "answer": "The most common gender is Male."}
        ]
    """
    # NOTE(review): this function is registered as an agent tool, so the
    # docstring above is presumably surfaced to the model as the tool
    # description - its wording is deliberately left unchanged.
    print("LLM using the csv chat function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Each question is awaited in turn (no concurrency), preserving input order.
    return [
        {"question": item, "answer": await csv_chat(csv_url, item)}
        for item in user_questions
    ]
57
+
58
async def generate_chart(csv_url: str, user_questions: List[str], chat_id: str) -> Any:
    """
    This function generates charts for the given user questions using the CSV URL.
    It uses the csv_chart function to process each question and return the chart URLs.
    It returns a list of dictionaries containing the question and image URL for each question.

    Args:
        csv_url (str): The URL of the CSV file.
        user_questions (List[str]): A list of user questions.
        chat_id (str): The chat ID for the session.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, each holding the question
        under "question" and the chart URL under "image_url".

    Example:
        [
            {"question": "What is the average age of the customers?", "image_url": "https://example.com/chart1.png"},
            {"question": "What is the most common gender?", "image_url": "https://example.com/chart2.png"}
        ]
    """
    print("LLM using the csv chart function....")
    print("CSV URL:", csv_url)
    print("User question:", user_questions)

    # Charts are produced one question at a time, in the order given.
    charts = []
    for question in user_questions:
        chart = await csv_chart(csv_url, question, chat_id)
        # The key is "image_url"; the docstring above now matches it so the
        # documented result shape agrees with what is actually returned.
        charts.append({"question": question, "image_url": chart})

    return charts
91
+
92
+ # Function to create an agent with a specific CSV URL
93
+ def create_agent(csv_url: str, api_key: str, conversation_history: List, chat_id: str) -> Agent:
93
+ """Create a PydanticAI agent configured for CSV analysis using Cerebras.
+
+ Looks up basic information about the CSV with get_csv_basic_info(csv_url) and
+ interpolates it, together with csv_url, conversation_history and chat_id,
+ into the system prompt text.
+
+ Args:
+ csv_url: URL of the CSV file; passed to get_csv_basic_info and shown in the prompt.
+ api_key: Cerebras API key handed to initialize_model.
+ conversation_history: Prior conversation, embedded in the prompt via f-string formatting.
+ chat_id: Chat session identifier, embedded in the prompt.
+
+ Returns:
+ An Agent built on initialize_model(api_key) with the generate_csv_answer and
+ generate_chart tools and the system prompt assembled below.
+ """
95
+ csv_metadata = get_csv_basic_info(csv_url)
96
+
97
+ system_prompt = f"""
98
+ # Role: Data Analyst Assistant
99
+ **Specialization:** CSV Analysis & Visualization
100
+ **Powered by:** Cerebras AI
101
+
102
+ ## Key Rules:
103
+ 1. **Always provide both:**
104
+ - Complete textual answer with explanations
105
+ - Visualization when applicable
106
+ 2. **Output Format:** Markdown compatible (visualizations as `![Image Description](url generated by tool)`)
107
+ 3. **Tool Handling:**
108
+ - Use `generate_csv_answer` for analysis
109
+ - Use `generate_chart` for visuals
110
+ - Never disclose tool names
111
+ 4. **Visualization Fallback:**
112
+ - If requested library (plotly, bokeh etc.) isn't available:
113
+ - Provide closest alternative
114
+ - Explain the limitation
115
+
116
+ ## Current Context:
117
+ - **Dataset:** {csv_url}
118
+ - **Metadata:** {csv_metadata}
119
+ - **History:** {conversation_history}
120
+ - **Chat ID:** {chat_id}
121
+
122
+ ## Required Output:
123
+ For every question return:
124
+ 1. Clear analysis answer
125
+ 2. Visualization (when possible, in markdown format)
126
+ 3. Follow-up suggestions
127
+
128
+ **Critical:**
129
+ - Never return partial responses - always combine both textual answers and visualizations when applicable.
130
+ - Always generate a fresh, tool-assisted response for every query, regardless of its similarity to any prior questions. Never reuse or return a previous answer.
131
+ - Leverage Cerebras's fast inference capabilities for efficient data analysis.
132
+ """
133
+
134
+ return Agent(
135
+ model=initialize_model(api_key),
+ # NOTE(review): deps_type is declared, but the visible caller invokes run_sync without deps - confirm it is needed.
136
+ deps_type=str,
137
+ tools=[generate_csv_answer, generate_chart],
138
+ system_prompt=system_prompt
139
+ )
140
+
141
def csv_orchestrator_chat_cerebras(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> "str | None":
    """
    Main orchestrator function that processes CSV analysis requests using Cerebras AI.

    Each configured API key is tried in order; the first successful agent run
    wins. Any failure with one key (rate limit, API error, or anything else)
    is logged and the next key is tried.

    Args:
        csv_url (str): URL of the CSV file to analyze
        user_question (str): User's question about the CSV data
        conversation_history (List): Previous conversation context
        chat_id (str): Unique chat session identifier

    Returns:
        The agent's output for the question, or None when no API key is
        configured or every key failed. Callers test the result with
        ``is not None``, so failures must not be reported as answer strings.
    """
    print("CSV URL:", csv_url)
    print("User questions:", user_question)

    # Keep each usable key with its 1-based position in the configured list so
    # log messages still point at the right entry of CEREBRAS_API_KEYS.
    candidates = [
        (position, raw_key.strip())
        for position, raw_key in enumerate(CEREBRAS_API_KEYS, start=1)
        if raw_key.strip()
    ]
    if not candidates:
        print("Error: No Cerebras API keys found. Please set CEREBRAS_API_KEYS environment variable.")
        return None

    for position, api_key in candidates:
        try:
            print(f"Attempting with Cerebras API key #{position}")
            agent = create_agent(csv_url, api_key, conversation_history, chat_id)
            result = agent.run_sync(user_question)
            # The run result exposes the final answer as `.output`.
            answer = result.output
        except RateLimitError:
            print(f"Rate limit exceeded for API key #{position}. Switching to the next key.")
            continue
        except APIError as e:
            print(f"API error with key #{position}: {e}")
            continue
        except Exception as e:
            # Deliberate best-effort boundary: any failure with one key should
            # fall through to the next key rather than abort the request.
            print(f"Unexpected error with API key #{position}: {e}")
            continue
        else:
            print("Orchestrator Result:", answer)
            return answer

    # If all keys are exhausted or fail
    print("All Cerebras API keys have been exhausted or failed. Please check your API keys and quotas.")
    return None
orchestrator_agent.py CHANGED
@@ -134,7 +134,7 @@ For every question return:
134
  system_prompt=system_prompt
135
  )
136
 
137
- def csv_orchestrator_chat(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
138
  print("CSV URL:", csv_url)
139
  print("User questions:", user_question)
140
 
 
134
  system_prompt=system_prompt
135
  )
136
 
137
+ def csv_orchestrator_chat_gemini(csv_url: str, user_question: str, conversation_history: List, chat_id: str) -> str:
138
  print("CSV URL:", csv_url)
139
  print("User questions:", user_question)
140