Soumik Bose commited on
Commit
3088397
·
1 Parent(s): b0d5ff7

go mdrfkr

Browse files
cerebras_openrouter_chart_generator.py CHANGED
@@ -121,97 +121,75 @@ def extract_code(content: str) -> str:
121
  return match.group(1).strip() if match else content.replace("```", "").strip()
122
 
123
  # ==================================================================================
124
- # CEREBRAS AGENT (MULTI-KEY)
125
  # ==================================================================================
126
 
127
- def generate_cerebras_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
128
- """
129
- Cerebras-specific orchestrator.
130
- Iterates through ALL available Cerebras API keys if failures occur.
131
- """
132
- logger.info("Starting CEREBRAS chart generation...")
133
 
134
- # Get all keys and clean them
135
- raw_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
136
- api_keys = [k.strip() for k in raw_keys if k.strip()]
137
 
138
- if not api_keys:
139
- logger.error("No CEREBRAS_API_KEYS found.")
140
- return None
141
 
142
- base_url = os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1")
143
- model_name = os.getenv("CEREBRAS_CODING_MODEL", "llama3.1-70b")
 
 
144
  data_context = get_data_context(csv_url)
 
 
 
145
 
146
- # LOOP THROUGH EACH KEY
147
- for key_index, api_key in enumerate(api_keys):
148
- logger.info(f"--- Attempting Cerebras with Key Index {key_index + 1}/{len(api_keys)} ---")
149
-
150
  try:
151
- client = OpenAI(base_url=base_url, api_key=api_key)
152
- error_history = []
153
 
154
- # Inner Loop: Attempt to fix code `max_retries` times with CURRENT key
155
- for attempt in range(1, max_retries + 1):
156
- try:
157
- system_prompt, user_prompt = construct_prompt(query, csv_url, data_context, error_history)
158
-
159
- response = client.chat.completions.create(
160
- model=model_name,
161
- messages=[
162
- {"role": "system", "content": system_prompt},
163
- {"role": "user", "content": user_prompt}
164
- ],
165
- temperature=0.1
166
- )
167
-
168
- code = extract_code(response.choices[0].message.content)
169
- file_path, error = execute_and_capture(code, csv_url)
170
-
171
- if file_path:
172
- logger.info(f"Cerebras Success (Key Index {key_index}): {file_path}")
173
- return file_path
174
-
175
- # Logic error in code, record and retry loop
176
- error_history.append(f"Attempt {attempt} Error:\n{error}")
177
- logger.warning(f"Cerebras Key {key_index} - Attempt {attempt} Logic Fail: {error}")
178
-
179
- except Exception as api_e:
180
- # If it's an API error (Auth, Rate Limit, Network), break inner loop to switch keys
181
- logger.error(f"Cerebras API Error with Key {key_index}: {api_e}")
182
- break
183
 
184
- logger.warning(f"Cerebras Key {key_index} exhausted retries or failed. Moving to next key...")
185
-
186
  except Exception as e:
187
- logger.error(f"Fatal error initializing Cerebras Key {key_index}: {e}")
188
- continue
189
-
190
- logger.error("All Cerebras keys failed.")
191
  return None
192
 
193
  # ==================================================================================
194
  # OPENROUTER AGENT (FALLBACK)
195
  # ==================================================================================
196
 
 
 
 
 
 
 
 
197
  def generate_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
198
  """OpenRouter-specific orchestrator (Fallback)."""
199
  logger.info("Starting OPENROUTER chart generation (Fallback)...")
200
-
201
- # We stick to the first key for OpenRouter as per original logic, but can be expanded
202
- api_keys = os.getenv("OPENROUTER_API_KEYS", "").split(",")
203
- if not api_keys or not api_keys[0]:
204
- logger.error("Missing OPENROUTER_API_KEYS")
205
- return None
206
-
207
- client = OpenAI(
208
- base_url=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
209
- api_key=api_keys[0]
210
- )
211
-
212
  model_name = os.getenv("OPENROUTER_MODEL", "openai/gpt-4o")
213
  data_context = get_data_context(csv_url)
214
  error_history = []
 
 
215
 
216
  for attempt in range(1, max_retries + 1):
217
  try:
@@ -240,39 +218,3 @@ def generate_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) ->
240
  error_history.append(f"System Error: {str(e)}")
241
 
242
  return None
243
-
244
- # ==================================================================================
245
- # MAIN ORCHESTRATOR
246
- # ==================================================================================
247
-
248
- def generate_cerebras_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) -> Dict[str, str]:
249
- """
250
- Orchestrator that:
251
- 1. Tries all Cerebras keys.
252
- 2. If all fail, tries OpenRouter.
253
- 3. Returns result dictionary.
254
- """
255
-
256
- # 1. Try Cerebras (Handles key rotation internally)
257
- chart_path = generate_cerebras_chart(csv_url, query, max_retries)
258
-
259
- # 2. Fallback to OpenRouter if Cerebras failed
260
- if not chart_path:
261
- logger.warning("Cerebras failed (all keys). Switching to OpenRouter...")
262
- chart_path = generate_openrouter_chart(csv_url, query, max_retries)
263
-
264
- # 3. Construct Response
265
- if chart_path:
266
- return {
267
- "image_url": chart_path,
268
- "message": "Chart generated successfully."
269
- }
270
- else:
271
- return {
272
- "image_url": "",
273
- "message": "Failed to generate chart after trying all providers."
274
- }
275
-
276
- # ==================================================================================
277
- # ENTRY POINT
278
- # ==================================================================================
 
121
  return match.group(1).strip() if match else content.replace("```", "").strip()
122
 
123
  # ==================================================================================
124
+ # CEREBRAS AGENT
125
  # ==================================================================================
126
 
127
+ def get_cerebras_client() -> OpenAI:
128
+ api_keys = os.getenv("CEREBRAS_API_KEYS", "").split(",")
129
+ base_url = os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1")
 
 
 
130
 
131
+ if not api_keys or not api_keys[0]:
132
+ raise ValueError("Missing CEREBRAS_API_KEYS")
 
133
 
134
+ # Use the first key (can implement rotation if needed)
135
+ return OpenAI(base_url=base_url, api_key=api_keys[0])
 
136
 
137
+ def generate_cerebras_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
138
+ """Cerebras-specific orchestrator."""
139
+ logger.info("Starting CEREBRAS chart generation...")
140
+ model_name = os.getenv("CEREBRAS_MODEL", "llama3.1-70b")
141
  data_context = get_data_context(csv_url)
142
+ error_history = []
143
+
144
+ client = get_cerebras_client()
145
 
146
+ for attempt in range(1, max_retries + 1):
 
 
 
147
  try:
148
+ system_prompt, user_prompt = construct_prompt(query, csv_url, data_context, error_history)
 
149
 
150
+ response = client.chat.completions.create(
151
+ model=model_name,
152
+ messages=[
153
+ {"role": "system", "content": system_prompt},
154
+ {"role": "user", "content": user_prompt}
155
+ ],
156
+ temperature=0.1
157
+ )
158
+
159
+ code = extract_code(response.choices[0].message.content)
160
+ file_path, error = execute_and_capture(code, csv_url)
161
+
162
+ if file_path:
163
+ logger.info(f"Cerebras Success: {file_path}")
164
+ return file_path
165
+
166
+ error_history.append(f"Attempt {attempt} Error:\n{error}")
 
 
 
 
 
 
 
 
 
 
 
 
167
 
 
 
168
  except Exception as e:
169
+ logger.error(f"Cerebras Attempt {attempt} failed: {e}")
170
+ error_history.append(f"System Error: {str(e)}")
171
+
 
172
  return None
173
 
174
  # ==================================================================================
175
  # OPENROUTER AGENT (FALLBACK)
176
  # ==================================================================================
177
 
178
+ def get_openrouter_client() -> OpenAI:
179
+ api_keys = os.getenv("OPENROUTER_API_KEYS", "").split(",")
180
+ base_url = os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1")
181
+ if not api_keys or not api_keys[0]:
182
+ raise ValueError("Missing OPENROUTER_API_KEYS")
183
+ return OpenAI(base_url=base_url, api_key=api_keys[0])
184
+
185
  def generate_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
186
  """OpenRouter-specific orchestrator (Fallback)."""
187
  logger.info("Starting OPENROUTER chart generation (Fallback)...")
 
 
 
 
 
 
 
 
 
 
 
 
188
  model_name = os.getenv("OPENROUTER_MODEL", "openai/gpt-4o")
189
  data_context = get_data_context(csv_url)
190
  error_history = []
191
+
192
+ client = get_openrouter_client()
193
 
194
  for attempt in range(1, max_retries + 1):
195
  try:
 
218
  error_history.append(f"System Error: {str(e)}")
219
 
220
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
orchestrator_functions.py CHANGED
@@ -12,7 +12,7 @@ from pandasai import SmartDataframe
12
  from langchain_groq.chat_models import ChatGroq
13
  from dotenv import load_dotenv
14
  from pydantic import BaseModel
15
- from cerebras_openrouter_chart_generator import generate_cerebras_openrouter_chart
16
  from csv_service import clean_data, extract_chart_filenames
17
  from langchain_groq import ChatGroq
18
  import pandas as pd
@@ -395,21 +395,61 @@ async def csv_chart(csv_url: str, query: str, chat_id: str) -> Dict[str, str]:
395
  """
396
  Generate a chart based on the provided CSV URL and query.
397
  Strategy:
398
- 1. Try Cerebras Agent (All Keys).
399
  2. If failed/empty, try OpenRouter Agent.
400
- 3. Return result.
401
  """
402
  logger.info(f"Received csv_chart request. Chat ID: {chat_id}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  try:
404
- logger.info("Attempting Chart Generation...")
405
- # Run synchronous generation logic in a thread
406
- csv_chart_response = await asyncio.to_thread(
407
- generate_cerebras_openrouter_chart, csv_url, query, 3
408
  )
409
 
410
- logger.info(f"Chart Response: {csv_chart_response}")
411
- return csv_chart_response
 
 
 
 
412
  except Exception as e:
413
- error_message = f"Chart Generation Error: {str(e)}"
414
- logger.error(error_message)
415
- return {"image_url": "", "message": error_message}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from langchain_groq.chat_models import ChatGroq
13
  from dotenv import load_dotenv
14
  from pydantic import BaseModel
15
+ from cerebras_openrouter_chart_generator import generate_cerebras_chart, generate_openrouter_chart
16
  from csv_service import clean_data, extract_chart_filenames
17
  from langchain_groq import ChatGroq
18
  import pandas as pd
 
395
  """
396
  Generate a chart based on the provided CSV URL and query.
397
  Strategy:
398
+ 1. Try Cerebras Agent.
399
  2. If failed/empty, try OpenRouter Agent.
400
+ 3. Upload and return.
401
  """
402
  logger.info(f"Received csv_chart request. Chat ID: {chat_id}")
403
+ error_messages = []
404
+
405
+ async def upload_and_return(image_path: str) -> Dict[str, str]:
406
+ """Handle image upload and return public URL"""
407
+ unique_name = f'{uuid.uuid4()}.png'
408
+ # Ensure upload_file_to_supabase is imported in your file
409
+ public_url = await upload_file_to_supabase(image_path, unique_name, chat_id)
410
+
411
+ logger.info(f"Uploaded chart to: {public_url}")
412
+ try:
413
+ os.remove(image_path)
414
+ except OSError as e:
415
+ logger.warning(f"Could not delete temp file {image_path}: {str(e)}")
416
+ return {"image_url": public_url}
417
+
418
+ # --- 1. Primary Attempt: Cerebras ---
419
  try:
420
+ logger.info("Attempting with Cerebras...")
421
+ cerebras_path = await asyncio.to_thread(
422
+ generate_cerebras_chart, csv_url, query, 3
 
423
  )
424
 
425
+ if cerebras_path:
426
+ return await upload_and_return(cerebras_path)
427
+
428
+ error_messages.append("Cerebras returned no valid image.")
429
+ logger.warning("Cerebras failed. Switching to fallback.")
430
+
431
  except Exception as e:
432
+ msg = f"Cerebras critical error: {str(e)}"
433
+ error_messages.append(msg)
434
+ logger.error(msg)
435
+
436
+ # --- 2. Fallback Attempt: OpenRouter ---
437
+ try:
438
+ logger.info("Attempting with OpenRouter...")
439
+ openrouter_path = await asyncio.to_thread(
440
+ generate_openrouter_chart, csv_url, query, 3
441
+ )
442
+
443
+ if openrouter_path:
444
+ return await upload_and_return(openrouter_path)
445
+
446
+ error_messages.append("OpenRouter fallback returned no valid image.")
447
+
448
+ except Exception as e:
449
+ msg = f"OpenRouter critical error: {str(e)}"
450
+ error_messages.append(msg)
451
+ logger.error(msg)
452
+
453
+ # --- Final Error Handling ---
454
+ logger.error(f"All chart generation providers failed. Errors: {'; '.join(error_messages)}")
455
+ return {"error": "Could not generate chart. Both primary and fallback agents failed."}