FastApi / cerebras_openrouter_chart_generator.py
Soumik Bose
go
03e9433
import os
import uuid
import re
import io
import sys
import contextlib
import logging
import asyncio
import traceback
from typing import Dict, Optional, List, Tuple
# Data Science Stack
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import datetime as dt
# LLM Client
from openai import OpenAI, APITimeoutError
from dotenv import load_dotenv
# --- Configuration & Setup ---
# Select the file-rendering 'Agg' backend for matplotlib; charts in this module
# are written to disk with savefig (see the system prompt in construct_prompt),
# never shown in a window.
matplotlib.use('Agg')
# Load variables from a local .env file, if present. The *_API_KEYS, *_MODEL and
# *_BASE_URL values are later read with os.getenv by the provider functions.
load_dotenv()
# Configure logging to show timestamps
# NOTE: this configures the root logger as an import-time side effect.
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%H:%M:%S'
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# Directory for charts
# NOTE: the system prompt in construct_prompt hard-codes the same
# "generated_charts/" prefix for output files; keep the two in sync.
CHART_DIR = "generated_charts"
# Created at import time (relative to the current working directory) so that
# generated code can save into it without creating it first.
os.makedirs(CHART_DIR, exist_ok=True)
# ==================================================================================
# SHARED HELPER FUNCTIONS
# ==================================================================================
def get_data_context(csv_url: str) -> str:
    """Summarise the CSV's structure so the LLM sees ground truth about the data.

    Only the first three rows are read, so the reported dtypes are what pandas
    infers from that small sample.

    Args:
        csv_url: Path or URL accepted by ``pd.read_csv``.

    Returns:
        A "DATA CONTEXT" text block listing the column names, the per-column
        dtypes and the sampled rows. If the CSV cannot be read, an error string
        is returned instead (this function never raises).
    """
    logger.info(f"Inspecting CSV structure from: {csv_url}")
    try:
        df = pd.read_csv(csv_url, nrows=3)

        # Report the real per-column dtypes. The first line of df.info() is
        # just the DataFrame class name and carries no type information.
        dtypes = {column: str(dtype) for column, dtype in df.dtypes.items()}

        # to_markdown() needs the optional 'tabulate' package; fall back to a
        # plain-text table rather than discarding the whole context.
        try:
            sample = df.to_markdown(index=False)
        except ImportError:
            sample = df.to_string(index=False)

        return (
            f"DATA CONTEXT:\n"
            f"1. Columns: {list(df.columns)}\n"
            f"2. Data Types: {dtypes}\n"
            f"3. Sample Data (First 3 rows):\n{sample}\n"
        )
    except Exception as e:
        # Best-effort: the caller still prompts the LLM, just without context.
        logger.error(f"Error inspecting CSV: {e}")
        return f"Error reading CSV: {e}. Assume standard CSV format."
def execute_and_capture(code: str, csv_url: str) -> Tuple[Optional[str], Optional[str]]:
    """Run generated chart code and return the PNG path it printed.

    SECURITY NOTE(review): ``code`` is LLM-generated and is executed with
    ``exec`` inside this process with full access to ``os`` and builtins.
    Treat it as untrusted; consider running it in a sandboxed subprocess.

    Args:
        code: Python source to execute.
        csv_url: Exposed to the executed code as the variable ``csv_url``.

    Returns:
        ``(path, None)`` when the code printed a line ending in ``.png`` that
        names an existing file (the last such line wins), otherwise
        ``(None, error_text)`` where ``error_text`` is either the traceback of
        the failure or a note that no valid path was printed.
    """
    # One dict serves as BOTH globals and locals. With separate mappings,
    # exec() runs the code as if it were a class body, so functions, lambdas
    # and (before Python 3.12) comprehensions in the generated code cannot see
    # names such as `pd` or `df` and fail with NameError.
    namespace = {
        "pd": pd, "np": np, "plt": plt, "sns": sns, "dt": dt,
        "uuid": uuid, "os": os, "csv_url": csv_url,
        # Lets an `if __name__ == "__main__":` guard in generated code run.
        "__name__": "__main__",
    }
    stdout_capture = io.StringIO()
    try:
        # Start from a clean pyplot state so earlier figures cannot leak in.
        plt.close('all')
        with contextlib.redirect_stdout(stdout_capture):
            exec(code, namespace)
    except (Exception, SystemExit):
        # SystemExit is included so a stray exit() in generated code is
        # reported as a failed attempt instead of terminating the caller.
        return None, traceback.format_exc()
    finally:
        # Release figures created by the generated code.
        plt.close('all')

    output = stdout_capture.getvalue().strip()
    # The prompt requires the final printed line to be the filename; scan from
    # the end so later prints take priority over earlier ones.
    for line in reversed(output.split('\n')):
        path = line.strip()
        if path.endswith('.png') and os.path.exists(path):
            return path, None
    return None, f"Code ran but no valid file path found in output. Output: {output}"
# Static instructions sent as the system message for every code-generation
# request. The text is sent verbatim; the `{uuid.uuid4().hex}` placeholder is
# intentionally NOT interpolated here: it is part of the code the LLM is
# asked to write.
_CHART_SYSTEM_PROMPT = """You are a Senior Data Analyst and Python Expert.
TASK: Write Python code to create a chart based on the user query and dataset.
STRICT VISUALIZATION RULES:
1. Use `pd.read_csv(csv_url)` to load data.
2. Use `seaborn` (imported as sns) and `matplotlib.pyplot` (imported as plt).
3. STYLE: Use `sns.set_palette("colorblind")` and `sns.set_style("whitegrid")`.
4. SIZE: `plt.figure(figsize=(12, 7))`
5. FONT: Title size 14, Label size 12. Rotate x-labels 45 degrees if needed.
6. CLEANUP: Handle missing values appropriately.
FILE HANDLING & OUTPUT RULES:
1. Generate a unique filename: `filename = f"generated_charts/chart_{uuid.uuid4().hex}.png"`
2. Save file: `plt.savefig(filename, bbox_inches='tight')`
3. CRITICAL: The FINAL line of code MUST be exactly `print(filename)`
4. Do NOT use plt.show().
RESPONSE FORMAT:
- Return ONLY the Python code inside ```python ... ``` blocks.
"""


def construct_prompt(
    query: str,
    csv_url: str,
    data_context: str,
    error_history: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """Build the system and user prompts for chart-code generation.

    Args:
        query: The user's natural-language chart request.
        csv_url: Location of the dataset, echoed into the user prompt.
        data_context: Dataset summary text appended after the query.
        error_history: Messages from earlier failed attempts. When non-empty
            they are joined with newlines and appended with a request to fix
            the code; ``None`` or an empty list adds nothing.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple.
    """
    user_prompt = f"CSV URL: {csv_url}\nQUERY: {query}\n\n{data_context}"
    if error_history:
        history_str = "\n".join(error_history)
        user_prompt += f"\n\n!!! PREVIOUS ATTEMPTS FAILED !!!\nError Log:\n{history_str}\n\nPlease FIX the code."
    return _CHART_SYSTEM_PROMPT, user_prompt
# Matches a fenced block tagged python/py (case-insensitive). The closing
# fence is optional so a response truncated mid-block is still usable.
_PYTHON_FENCE_RE = re.compile(
    r"```(?:python|py)\b(.*?)(?:```|\Z)",
    re.DOTALL | re.IGNORECASE,
)


def extract_code(content: Optional[str]) -> str:
    """Extract Python source from an LLM response.

    Args:
        content: Raw message text; may be ``None`` or empty.

    Returns:
        The body of the first python-tagged fenced block, stripped. If there
        is no such block, the whole text with any triple-backtick fences
        removed, stripped. An empty string for ``None`` or empty input.
    """
    if not content:
        return ""
    match = _PYTHON_FENCE_RE.search(content)
    if match:
        return match.group(1).strip()
    # No python-tagged fence: assume the response is bare code, possibly
    # wrapped in untagged fences.
    return content.replace("```", "").strip()
# ==================================================================================
# SHARED PROVIDER ORCHESTRATION
# ==================================================================================
def _generate_chart_with_provider(
    provider: str,
    csv_url: str,
    query: str,
    max_retries: int,
    *,
    keys_env_var: str,
    model_name: str,
    base_url: str,
    timeout_seconds: float,
) -> Optional[str]:
    """Generate a chart through one OpenAI-compatible provider, rotating API keys.

    For each key, the LLM is asked for chart code up to ``max_retries`` times.
    Every failure (execution error, timeout, API error) is appended to a
    per-key error history that is fed back into the next prompt.

    Args:
        provider: Display name used in log messages (e.g. "Cerebras").
        csv_url: Dataset location passed to the prompt and the executed code.
        query: The user's chart request.
        max_retries: Attempts per API key.
        keys_env_var: Environment variable holding comma-separated API keys.
        model_name: Model identifier sent to the provider.
        base_url: Provider API base URL.
        timeout_seconds: Per-request timeout given to the client.

    Returns:
        Path of the generated PNG, or ``None`` if no keys are configured or
        every key exhausted its attempts.
    """
    # Split keys by comma and strip whitespace; blank entries are dropped.
    api_keys = [k.strip() for k in os.getenv(keys_env_var, "").split(",") if k.strip()]
    if not api_keys:
        logger.error(f"No {keys_env_var} found.")
        return None

    # Fetched only after the key check so a misconfigured provider does not
    # cost a CSV download.
    data_context = get_data_context(csv_url)

    # Iterate through EVERY available key
    for key_index, api_key in enumerate(api_keys):
        logger.info(f"Attempting {provider} Provider with Key Index [{key_index}]")
        # Strict timeout and NO internal SDK retries: retrying is handled by
        # the attempt loop below, and the timeout log messages rely on this
        # limit actually being enforced.
        client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            timeout=timeout_seconds,
            max_retries=0,
        )
        error_history: List[str] = []
        for attempt in range(1, max_retries + 1):
            try:
                system_prompt, user_prompt = construct_prompt(query, csv_url, data_context, error_history)
                logger.info(f"Key [{key_index}] - Requesting LLM (Attempt {attempt})...")
                response = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.1,
                )
                # The message content may be None; treat that as an empty
                # response so it is reported as a failed attempt.
                code = extract_code(response.choices[0].message.content or "")
                file_path, error = execute_and_capture(code, csv_url)
                if file_path:
                    logger.info(f"{provider} Success using Key [{key_index}]: {file_path}")
                    return file_path
                error_history.append(f"Attempt {attempt} Code Execution Error:\n{error}")
            except APITimeoutError:
                logger.error(
                    f"{provider} Key [{key_index}] TIMED OUT after {timeout_seconds:g}s on Attempt {attempt}."
                )
                error_history.append(f"System Error: API Timeout ({timeout_seconds:g}s limit reached).")
            except Exception as e:
                logger.error(f"{provider} Key [{key_index}] Attempt {attempt} failed: {e}")
                error_history.append(f"System Error: {str(e)}")
        logger.warning(f"{provider} Key [{key_index}] exhausted all retries. Moving to next key...")

    logger.error(f"All {provider} keys failed.")
    return None


# ==================================================================================
# CEREBRAS AGENT
# ==================================================================================
def generate_cerebras_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
    """Cerebras-specific orchestrator with Multi-Key Rotation.

    Reads ``CEREBRAS_API_KEYS`` (comma-separated), ``CEREBRAS_CODING_MODEL``
    and ``CEREBRAS_BASE_URL`` from the environment and uses a 20-second
    request timeout.

    Args:
        csv_url: Dataset location.
        query: The user's chart request.
        max_retries: Attempts per API key.

    Returns:
        Path of the generated PNG, or ``None`` on failure.
    """
    logger.info("Starting CEREBRAS chart generation...")
    return _generate_chart_with_provider(
        "Cerebras",
        csv_url,
        query,
        max_retries,
        keys_env_var="CEREBRAS_API_KEYS",
        model_name=os.getenv("CEREBRAS_CODING_MODEL", "llama3.1-70b"),
        base_url=os.getenv("CEREBRAS_BASE_URL", "https://api.cerebras.ai/v1"),
        timeout_seconds=20.0,
    )


# ==================================================================================
# OPENROUTER AGENT (FALLBACK)
# ==================================================================================
def generate_openrouter_chart(csv_url: str, query: str, max_retries: int = 3) -> Optional[str]:
    """OpenRouter-specific orchestrator (Fallback) with Multi-Key Rotation.

    Reads ``OPENROUTER_API_KEYS`` (comma-separated), ``OPENROUTER_MODEL`` and
    ``OPENROUTER_BASE_URL`` from the environment and uses a 30-second request
    timeout.

    Args:
        csv_url: Dataset location.
        query: The user's chart request.
        max_retries: Attempts per API key.

    Returns:
        Path of the generated PNG, or ``None`` on failure.
    """
    logger.info("Starting OPENROUTER chart generation (Fallback)...")
    return _generate_chart_with_provider(
        "OpenRouter",
        csv_url,
        query,
        max_retries,
        keys_env_var="OPENROUTER_API_KEYS",
        model_name=os.getenv("OPENROUTER_MODEL", "openai/gpt-4o"),
        base_url=os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"),
        timeout_seconds=30.0,
    )