| import os |
| import uuid |
| import re |
| import io |
| import sys |
| import contextlib |
| import logging |
| import asyncio |
| import traceback |
| from typing import Dict, Optional, List, Tuple |
|
|
| |
| import pandas as pd |
| import numpy as np |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| import datetime as dt |
|
|
| |
| from openai import OpenAI, APITimeoutError |
| from dotenv import load_dotenv |
|
|
| |
# Select the non-interactive Agg backend so figures are rendered straight to
# files and no display is required.
matplotlib.use("Agg")

# Pull settings (API keys, model names, base URLs) from a .env file, if any.
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)

# Folder that receives generated chart images; it is created at import time.
CHART_DIR = "generated_charts"
os.makedirs(CHART_DIR, exist_ok=True)
|
|
| |
| |
| |
|
|
def get_data_context(csv_url: str) -> str:
    """Summarise a CSV's structure so the LLM has ground truth to work from.

    Only the first three rows are read, so the reported dtypes are the ones
    pandas infers from that small sample.

    Args:
        csv_url: Path or URL passed to ``pd.read_csv``.

    Returns:
        A text block listing the column names, the per-column dtypes and the
        sampled rows. If anything goes wrong while reading the CSV, the error
        is logged and a short error string is returned instead.
    """
    logger.info(f"Inspecting CSV structure from: {csv_url}")
    try:
        df = pd.read_csv(csv_url, nrows=3)

        # Report the real per-column dtypes. The first line of ``df.info()``
        # is only the class banner and carries no type information.
        dtypes = {str(col): str(dtype) for col, dtype in df.dtypes.items()}

        try:
            sample = df.to_markdown(index=False)
        except ImportError:
            # ``to_markdown`` needs the optional ``tabulate`` package; a plain
            # text table is better than losing the whole context.
            sample = df.to_string(index=False)

        return (
            f"DATA CONTEXT:\n"
            f"1. Columns: {list(df.columns)}\n"
            f"2. Data Types: {dtypes}\n"
            f"3. Sample Data (First 3 rows):\n{sample}\n"
        )
    except Exception as e:
        logger.error(f"Error inspecting CSV: {e}")
        return f"Error reading CSV: {e}. Assume standard CSV format."
|
|
def execute_and_capture(code: str, csv_url: str) -> Tuple[Optional[str], Optional[str]]:
    """Run generated plotting code and recover the chart path it prints.

    SECURITY: ``code`` is LLM-generated and is run with ``exec`` inside this
    process, with access to ``os`` and everything else the process can reach.
    It is not sandboxed; only call this with input you are prepared to run.

    Args:
        code: Python source to execute.
        csv_url: Exposed to the executed code under the name ``csv_url``.

    Returns:
        ``(path, None)`` when some printed line, scanning from the last line
        backwards, ends with ``.png`` and names an existing file. Otherwise
        ``(None, message)``, where ``message`` is either the formatted
        traceback of the failure or a note that no valid path was printed.
    """
    # A single dict serves as both globals and locals. With separate dicts the
    # code runs as if inside a class body, so functions and lambdas it defines
    # cannot see ``pd``/``plt`` or its own top-level variables (NameError).
    namespace = {
        # Lets a generated ``if __name__ == "__main__":`` guard execute.
        "__name__": "__main__",
        "pd": pd,
        "np": np,
        "plt": plt,
        "sns": sns,
        "dt": dt,
        "uuid": uuid,
        "os": os,
        "csv_url": csv_url,
    }

    captured = io.StringIO()
    try:
        # Start from a clean slate so earlier figures cannot leak into this chart.
        plt.close('all')
        with contextlib.redirect_stdout(captured):
            exec(code, namespace)
    except Exception:
        return None, traceback.format_exc()
    finally:
        # Release whatever figures the generated code left open.
        plt.close('all')

    output = captured.getvalue().strip()
    for line in reversed(output.splitlines()):
        path = line.strip()
        if path.endswith('.png') and os.path.exists(path):
            return path, None

    return None, f"Code ran but no valid file path found in output. Output: {output}"
|
|
def construct_prompt(
    query: str,
    csv_url: str,
    data_context: str,
    error_history: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """Build the system and user prompts for chart-code generation.

    Args:
        query: The user's natural-language chart request.
        csv_url: Location of the CSV, repeated verbatim in the user prompt.
        data_context: Dataset description appended to the user prompt.
        error_history: Messages from earlier failed attempts. When non-empty
            they are joined with newlines and appended together with a
            request to fix the code. ``None`` or an empty list adds nothing.

    Returns:
        ``(system_prompt, user_prompt)``.
    """
    system_prompt = """You are a Senior Data Analyst and Python Expert.

TASK: Write Python code to create a chart based on the user query and dataset.

STRICT VISUALIZATION RULES:
1. Use `pd.read_csv(csv_url)` to load data.
2. Use `seaborn` (imported as sns) and `matplotlib.pyplot` (imported as plt).
3. STYLE: Use `sns.set_palette("colorblind")` and `sns.set_style("whitegrid")`.
4. SIZE: `plt.figure(figsize=(12, 7))`
5. FONT: Title size 14, Label size 12. Rotate x-labels 45 degrees if needed.
6. CLEANUP: Handle missing values appropriately.

FILE HANDLING & OUTPUT RULES:
1. Generate a unique filename: `filename = f"generated_charts/chart_{uuid.uuid4().hex}.png"`
2. Save file: `plt.savefig(filename, bbox_inches='tight')`
3. CRITICAL: The FINAL line of code MUST be exactly `print(filename)`
4. Do NOT use plt.show().

RESPONSE FORMAT:
- Return ONLY the Python code inside ```python ... ``` blocks.
"""

    user_prompt = f"CSV URL: {csv_url}\nQUERY: {query}\n\n{data_context}"

    if error_history:
        history_str = "\n".join(error_history)
        user_prompt += (
            f"\n\n!!! PREVIOUS ATTEMPTS FAILED !!!\nError Log:\n{history_str}"
            f"\n\nPlease FIX the code."
        )

    return system_prompt, user_prompt
|
|
# Fence explicitly tagged as Python; tolerates a missing closing fence, which
# happens when the model's reply is cut off.
_PYTHON_FENCE_RE = re.compile(
    r"```[ \t]*(?:python3?|py)\b(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)
# Any other fenced block; the info string on the opening line is skipped.
_ANY_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)(?:```|\Z)", re.DOTALL)


def extract_code(content: str) -> str:
    """Extract the Python source from an LLM reply.

    The first fenced block tagged ``python``/``py`` wins. Failing that, the
    first fenced block of any kind is used. If the reply has no fence at all,
    the whole text is returned with stray backticks removed.

    Args:
        content: Raw reply text. ``None`` or an empty string yields ``""``.

    Returns:
        The extracted code with surrounding whitespace stripped.
    """
    if not content:
        return ""

    for pattern in (_PYTHON_FENCE_RE, _ANY_FENCE_RE):
        match = pattern.search(content)
        if match:
            return match.group(1).strip()

    return content.replace("```", "").strip()
|
|
| |
| |
| |
|
|
def _generate_chart_with_provider(
    csv_url: str,
    query: str,
    max_retries: int,
    *,
    provider: str,
    keys_env_var: str,
    model_name: str,
    base_url: str,
    timeout_seconds: int,
) -> Optional[str]:
    """Shared key-rotation and retry loop behind the provider entry points.

    Each API key gets up to ``max_retries`` attempts. Failures within one key
    are accumulated and fed back into the next prompt; the history is reset
    when moving to the next key.

    Args:
        csv_url: Location of the CSV to chart.
        query: The user's chart request.
        max_retries: Attempts per API key.
        provider: Human-readable provider name used in log messages.
        keys_env_var: Environment variable holding comma-separated API keys.
        model_name: Model identifier sent with each request.
        base_url: Base URL of the OpenAI-compatible endpoint.
        timeout_seconds: Per-request timeout handed to the client.

    Returns:
        Path of the generated chart, or ``None`` when no keys are configured
        or every key and attempt failed.
    """
    api_keys = [k.strip() for k in os.getenv(keys_env_var, "").split(",") if k.strip()]
    if not api_keys:
        logger.error(f"No {keys_env_var} found.")
        return None

    # Inspect the CSV only once we know a request can actually be made.
    data_context = get_data_context(csv_url)

    for key_index, api_key in enumerate(api_keys):
        logger.info(f"Attempting {provider} Provider with Key Index [{key_index}]")

        # The timeout must be set explicitly; otherwise the SDK default applies
        # and the limits reported in the log messages below would be untrue.
        # NOTE: the SDK may retry a timed-out request internally before
        # raising, so the total wait can exceed this per-request limit.
        client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            timeout=float(timeout_seconds),
        )
        error_history: List[str] = []

        for attempt in range(1, max_retries + 1):
            try:
                system_prompt, user_prompt = construct_prompt(
                    query, csv_url, data_context, error_history
                )

                logger.info(f"Key [{key_index}] - Requesting LLM (Attempt {attempt})...")

                response = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.1,
                )

                # ``content`` can be None (e.g. an empty completion).
                content = response.choices[0].message.content or ""
                code = extract_code(content)
                file_path, error = execute_and_capture(code, csv_url)

                if file_path:
                    logger.info(f"{provider} Success using Key [{key_index}]: {file_path}")
                    return file_path

                error_history.append(f"Attempt {attempt} Code Execution Error:\n{error}")

            except APITimeoutError:
                logger.error(
                    f"{provider} Key [{key_index}] TIMED OUT after "
                    f"{timeout_seconds}s on Attempt {attempt}."
                )
                error_history.append(
                    f"System Error: API Timeout ({timeout_seconds}s limit reached)."
                )
            except Exception as e:
                logger.error(f"{provider} Key [{key_index}] Attempt {attempt} failed: {e}")
                error_history.append(f"System Error: {str(e)}")

        logger.warning(
            f"{provider} Key [{key_index}] exhausted all retries. Moving to next key..."
        )

    logger.error(f"All {provider} keys failed.")
    return None


def generate_cerebras_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
    """Generate a chart via Cerebras, rotating through the configured keys.

    Reads ``CEREBRAS_API_KEYS`` (comma-separated), ``CEREBRAS_CODING_MODEL``
    and ``CEREBRAS_BASE_URL`` from the environment and uses a 20-second
    request timeout.

    Args:
        csv_url: Location of the CSV to chart.
        query: The user's chart request.
        max_retries: Attempts per API key.

    Returns:
        Path of the generated chart, or ``None`` on failure.
    """
    logger.info("Starting CEREBRAS chart generation...")
    return _generate_chart_with_provider(
        csv_url,
        query,
        max_retries,
        provider="Cerebras",
        keys_env_var="CEREBRAS_API_KEYS",
        model_name=os.getenv("CEREBRAS_CODING_MODEL", "llama3.1-70b"),
        base_url=os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1"),
        timeout_seconds=20,
    )


def generate_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
    """Generate a chart via OpenRouter (fallback), rotating through the keys.

    Reads ``OPENROUTER_API_KEYS`` (comma-separated), ``OPENROUTER_MODEL`` and
    ``OPENROUTER_BASE_URL`` from the environment and uses a 30-second request
    timeout.

    Args:
        csv_url: Location of the CSV to chart.
        query: The user's chart request.
        max_retries: Attempts per API key.

    Returns:
        Path of the generated chart, or ``None`` on failure.
    """
    logger.info("Starting OPENROUTER chart generation (Fallback)...")
    return _generate_chart_with_provider(
        csv_url,
        query,
        max_retries,
        provider="OpenRouter",
        keys_env_var="OPENROUTER_API_KEYS",
        model_name=os.getenv("OPENROUTER_MODEL", "openai/gpt-4o"),
        base_url=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
        timeout_seconds=30,
    )
|
|