# ChaRtBot / utils/prompt.py
# Last commit by Deepa Shalini (b89c575): "gallery view, updated prompt for subplots choropleth"
# libraries to help with the environment variables
import os
from dotenv import load_dotenv
import logging
# libraries to help with the AI model
from langchain_groq import ChatGroq
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from utils import helpers
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# get the credentials from .env
load_dotenv()
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
# Validate that the API key is present and not a placeholder
if not GROQ_API_KEY or GROQ_API_KEY == 'your_groq_api_key_here':
    raise ValueError(
        "GROQ_API_KEY environment variable is not set or is still the placeholder. "
        "Please update the .env file with your actual Groq API key."
    )
# define connectivity to the llm
try:
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        api_key=GROQ_API_KEY,
        temperature=0
    )
except Exception as e:
    # Chain the original error so the root cause stays visible in tracebacks
    raise ValueError(f"Failed to initialize ChatGroq: {e}") from e
def get_prompt_text() -> str:
    """
    Get the system prompt for data visualization generation.

    Returns:
        str: The system prompt template
    """
    return """You are a data visualization expert and you only use the graphing library Plotly.
CRITICAL VALIDATION RULES - EXECUTE BEFORE GENERATING ANY CODE:
1. RELEVANCE CHECK: Before generating any code, you MUST verify that the user's request is relevant to the provided dataset.
2. COLUMN VERIFICATION: Analyze the first 5 rows of data provided. If the user explicitly mentions column names that do NOT exist in the dataset, you MUST return an error message instead of code.
3. DATA CONTEXT VERIFICATION: If the user's request asks about metrics, categories, or data points that are clearly incompatible with the dataset columns shown, you MUST return an error message instead of code.
4. NON-VISUALIZATION REQUESTS: If the user's request is not about data visualization (e.g., asking for text generation, general questions, unrelated tasks), you MUST return an error message instead of code.
ERROR MESSAGE FORMAT - Use this EXACT format when validation fails:
ERROR: The request appears to be unrelated to the provided dataset. Please rephrase your request to refer to the actual columns and data available in your file. Available columns are: [list the column names from the data provided].
IMPORTANT: Only generate Python code if ALL of the following are true:
- The request is about creating a data visualization
- The request refers to columns, metrics, or patterns that could reasonably exist in the provided dataset
- The user has not explicitly mentioned column names that don't exist in the dataset
If any validation rule fails, return ONLY the error message in the format specified above. Do NOT generate any Python code.
IF VALIDATION PASSES, PROCEED WITH CODE GENERATION:
PANDAS DATA HANDLING BEST PRACTICES:
- Always use .copy() when creating a new dataframe from a subset or filtered view to avoid SettingWithCopyWarning.
- Example: df_filtered = df[df['column'] > 0].copy()
- When modifying data, always work on explicit copies, not chained indexing.
- Use .loc[] for setting values: df.loc[condition, 'column'] = value
- Avoid chained assignment like df[condition]['column'] = value
Ensure that before performing any data manipulation or plotting, the code checks for column data types and converts them if necessary.
For example, numeric columns should be converted to floats or integers using pd.to_numeric(), and non-numeric columns should be excluded from numeric operations.
Before creating any visualizations, remove rows with NaN or missing values in the relevant columns, and otherwise handle missing values as the context requires so that the visualizations stay clean.
For example, use df.dropna(subset=[column_name]) for data cleaning. Never use this statement: df.dropna(inplace=True).
The graphs you plot shall always have a white background and shall follow data visualization best practices.
Do not ignore any of the following visualization best practices:
{data_visualization_best_practices}
If the user requests a single visualization, set the figure height to 800.
Ensure that the graph is clearly labeled with a title, x-axis label, y-axis label, and legend.
SPECIFIC CHART TYPE INSTRUCTIONS:
CHOROPLETH MAPS:
CRITICAL: When creating a choropleth map of the United States, you MUST include ALL of the following parameters:
- locations: Set to the column containing two-letter state abbreviations (e.g., 'AL', 'NY', 'CA', 'TX')
- locationmode: MUST be set to 'USA-states' (this is CRITICAL - without it, the map will be blank)
- scope: Set to 'usa'
Example:
fig = px.choropleth(df,
locations='state_code_column',
locationmode='USA-states',
scope='usa',
color='value_column',
title='Map Title')
The locations parameter should reference the column with state codes, not the column with full state names.
Always verify that locationmode='USA-states' is present in the code.
CHOROPLETH MAPS IN SUBPLOTS:
When creating a choropleth map of the USA as part of subplots using graph_objects (go.Choropleth), you MUST include the following:
1. After creating all traces with fig.add_trace(), add:
fig.update_layout(title_text='Your Title', height=800, plot_bgcolor='white', paper_bgcolor='white')
Note: Generic background settings don't apply to choropleths in subplots, so this is necessary.
2. Force the USA scope on the specific choropleth subplot using:
fig.update_geos(scope="usa", projection_type="albers usa", bgcolor="white", row=X, col=Y)
Replace X and Y with the actual row and column numbers where the choropleth is located.
Example:
# After adding all traces
fig.update_layout(title_text='E-commerce Data Visualization', height=800, plot_bgcolor='white', paper_bgcolor='white')
# Force USA scope on the choropleth subplot (adjust row, col as needed)
fig.update_geos(scope="usa", projection_type="albers usa", bgcolor="white", row=2, col=2)
{dumbbell_charts_section}
{polar_charts_section}
If the user requests multiple visualizations, create a subplot for each visualization.
The libraries required for multiple visualizations are: import plotly.graph_objects as go and from plotly.subplots import make_subplots.
Use the make_subplots() function from plotly.subplots to create subplots, specifying the number of rows and columns,
and the specs parameter to define what type of graph will be present in each subplot to accommodate all requested visualizations.
Then, use the add_trace() method to add each graph to the appropriate subplot.
When generating subplots that include pie charts and xy plots (like bar or scatter), ensure that pie charts are assigned a separate 'domain' subplot type.
Use the make_subplots() function with the specs argument correctly set for pie charts and other plots.
For example, use make_subplots(rows=1, cols=2, specs=[[dict(type='domain'), dict(type='xy')]]) for a pie chart and a bar plot.
Before returning the final code, verify that all trace types are compatible with the assigned subplot types,
particularly ensuring that pie charts are in domain-type subplots. If an error is detected, correct the subplot type automatically.
Validate the layout before adding traces.
Ensure each subplot is clearly labeled and formatted according to best practices.
Here are examples of how to create multiple visualizations in a single figure:
Example 1: \n
{example_subplots1}
Example 2: \n
{example_subplots2}
Example 3: \n
{example_subplots3}
The height of the figure (fig) should be set to 800.
Suppose that the data is provided as a {name_of_file} file.
Here are the first 5 rows of the data set: {data}. Follow the user's instructions when creating the graph.
There should be no natural language text in the python code block.
REMINDER: Your code MUST end with fig.show() to display the visualization."""
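

# A minimal sketch (hypothetical helpers, not part of the original module).
# ChatPromptTemplate fills the {placeholders} in get_prompt_text() from the keys
# passed to chain.invoke(), and it raises an error when one is missing.
# Assuming the template only uses plain str.format-style fields, the standard
# library's string.Formatter can list those fields so they can be compared with
# the invoke parameters before the model is called.
import string


def _template_fields(template: str) -> set:
    """Return the names of the str.format-style fields used in template."""
    fields = set()
    for _literal, field_name, _spec, _conversion in string.Formatter().parse(template):
        # field_name is None for trailing literal text and '' for bare {} fields
        if field_name:
            fields.add(field_name)
    return fields


def _missing_template_fields(template: str, params: dict) -> set:
    """Return the template fields that have no matching key in params."""
    return _template_fields(template) - set(params)

# Usage (assumption): _missing_template_fields(get_prompt_text(), invoke_params)
# should return an empty set. Extra keys such as "messages" are ignored.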
def _should_include_dumbbell_examples(user_input: str) -> bool:
    """
    Check whether the user's request is about dumbbell charts or comparison visualizations.

    Args:
        user_input: User's visualization request

    Returns:
        bool: True if dumbbell chart examples should be included
    """
    # Common misspellings of "dumbbell" are included on purpose
    dumbbell_keywords = [
        'dumbbell', 'dumb bell', 'dumbell', 'dumbel', 'comparison', 'before and after', 'before after',
        'start and end', 'start end', 'range', 'difference', 'gap', 'change over'
    ]
    user_input_lower = user_input.lower()
    return any(keyword in user_input_lower for keyword in dumbbell_keywords)
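

# A hedged sketch (hypothetical helper, not used by this module).
# The keyword checks use plain substring tests, so short keywords can match
# inside unrelated words: 'range' matches 'orange' and 'gap' matches
# 'singapore'. Matching on word boundaries with the re module is one way to
# reduce such false positives while keeping multi-word phrases working.
import re


def _matches_any_keyword(user_input: str, keywords: list) -> bool:
    """Return True if any keyword appears in user_input as a whole word or phrase."""
    text = user_input.lower()
    for keyword in keywords:
        # \b anchors the match at word boundaries; re.escape guards regex metacharacters
        if re.search(r"\b" + re.escape(keyword.lower()) + r"\b", text):
            return True
    return False

# Design note: whole-word matching also stops plurals such as "dumbbells" from
# matching "dumbbell", so the keyword lists would need those variants added.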
def _should_include_polar_examples(user_input: str) -> bool:
    """
    Check whether the user's request is about polar, radial, or other circular visualizations.

    Args:
        user_input: User's visualization request

    Returns:
        bool: True if polar chart examples should be included
    """
    polar_keywords = [
        'polar', 'circular', 'radial', 'circular fashion', 'radar', 'rose'
    ]
    user_input_lower = user_input.lower()
    return any(keyword in user_input_lower for keyword in polar_keywords)
def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> str:
    """
    Get a response from the LLM for creating data visualizations.

    Args:
        user_input: User's request for visualization
        data_top5_csv_string: CSV string of first 5 rows of data
        file_name: Name of the data file

    Returns:
        str: LLM response text, expected to contain Python code

    Raises:
        Exception: If the API call fails or the LLM reports that the request fails validation
    """
    try:
        # Determine if dumbbell chart examples should be included
        include_dumbbell = _should_include_dumbbell_examples(user_input)
        # Determine if polar chart examples should be included
        include_polar = _should_include_polar_examples(user_input)

        # Build dumbbell charts section conditionally
        dumbbell_charts_section = ""
        if include_dumbbell:
            dumbbell_example = helpers.read_doc(
                helpers.get_app_file_path("assets", "example_dumbbell_chart.txt")
            )
            dumbbell_charts_section = f"""
DUMBBELL PLOTS:
When creating a dumbbell plot, use plotly.graph_objects (go) instead of plotly.express (px).
Use go.Figure() and add two go.Scatter traces for the two data points, and a go.Scatter trace for the lines connecting them.
Ensure proper labeling of axes and title for clarity.
Example: \n
{dumbbell_example}
"""

        # Build polar charts section conditionally
        polar_charts_section = ""
        if include_polar:
            polar_bar_example = helpers.read_doc(
                helpers.get_app_file_path("assets", "example_polar_bar.txt")
            )
            polar_scatter_example = helpers.read_doc(
                helpers.get_app_file_path("assets", "example_polar_scatter.txt")
            )
            polar_charts_section = f"""
POLAR CHARTS (RADIAL/CIRCULAR VISUALIZATIONS):
Polar charts are effective for displaying calendar views, weekly patterns, or circular data distributions.
Use them for innovative visualizations of time-based or cyclical data.
Example 1 - Polar Calendar with Cells (Barpolar):
{polar_bar_example}
Example 2 - Polar Calendar with Scatter:
{polar_scatter_example}
Use polar charts when the user requests:
- Calendar-like views
- Weekly or cyclical patterns
- Circular representations of data
- Radial visualizations
"""

        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", get_prompt_text()),
                MessagesPlaceholder(variable_name="messages")
            ]
        )
        chain = prompt | llm

        invoke_params = {
            "messages": [HumanMessage(content=user_input)],
            "data_visualization_best_practices": helpers.read_doc(
                helpers.get_app_file_path("assets", "data_viz_best_practices.txt")
            ),
            "example_subplots1": helpers.read_doc(
                helpers.get_app_file_path("assets", "example_subplots1.txt")
            ),
            "example_subplots2": helpers.read_doc(
                helpers.get_app_file_path("assets", "example_subplots2.txt")
            ),
            "example_subplots3": helpers.read_doc(
                helpers.get_app_file_path("assets", "example_subplots3.txt")
            ),
            "dumbbell_charts_section": dumbbell_charts_section,
            "polar_charts_section": polar_charts_section,
            "data": data_top5_csv_string,
            "name_of_file": file_name
        }
        response = chain.invoke(invoke_params)

        # Check if the response is an error message instead of code
        response_text = response.content.strip()
        if response_text.startswith("ERROR:"):
            # Strip only the leading "ERROR:" prefix and raise a validation error
            error_message = response_text[len("ERROR:"):].strip()
            raise ValueError(error_message)
        return response_text
    except ValueError as ve:
        # This is our custom validation error from the LLM
        # Re-raise with user-friendly message
        raise Exception(f"Unable to process your request: {str(ve)}")
    except Exception as e:
        error_msg = str(e)
        # DEBUG: Log the actual error to understand what's happening
        logger.info(f"DEBUG - Caught exception type: {type(e).__name__}")
        logger.info(f"DEBUG - Error message: {error_msg}")
        # Check for specific API errors (these are real API issues, not validation errors)
        if "rate_limit" in error_msg.lower() or "429" in error_msg:
            raise Exception("Rate limit exceeded. Please wait a moment and try again.")
        elif "authentication" in error_msg.lower() or "401" in error_msg or "api_key" in error_msg.lower():
            raise Exception("We're having trouble generating your visualization.")
        elif "timeout" in error_msg.lower():
            raise Exception("Request timed out. Please try again.")
        else:
            raise Exception(f"Unable to process your request: {error_msg}")
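

# A minimal sketch (hypothetical helper, not part of the original flow) of the
# "ERROR:" check in get_response(). The system prompt's validation rules ask
# the model to start its reply with "ERROR:" when a request does not fit the
# dataset. Slicing off only the leading prefix, rather than calling
# str.replace(), keeps any later "ERROR:" text in the message intact.
from typing import Optional

_ERROR_PREFIX = "ERROR:"


def _extract_validation_error(response_text: str) -> Optional[str]:
    """Return the validation message if the reply is an error, else None."""
    text = response_text.strip()
    if not text.startswith(_ERROR_PREFIX):
        return None
    # Drop only the leading prefix, then surrounding whitespace
    return text[len(_ERROR_PREFIX):].strip()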
def get_python_exception_prompt_text() -> str:
    """
    Get the system prompt for fixing Python code errors.

    Returns:
        str: The system prompt for error fixing
    """
    return """The Python code you provided:
{code}

raised the following error:
{exception}"""
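

# A hedged sketch (hypothetical helper; the app's actual caller is not shown in
# this module) of one way get_python_exception_response() could be used. The
# loop runs the generated code, and on failure it passes the code and the
# exception text to a fixer, retrying a bounded number of times. The fixer is a
# parameter so the loop can be exercised without calling the LLM.
# Note: exec() runs arbitrary code; this sketch assumes the code is trusted or
# sandboxed elsewhere.
from typing import Callable


def _run_with_repair(code: str, fix: Callable[[str, str], str], max_attempts: int = 3) -> str:
    """Execute code, asking fix(code, error_text) for a new version after each failure.

    Returns the code that ran successfully. Re-raises the last error if every
    attempt fails.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    for attempt in range(1, max_attempts + 1):
        try:
            exec(code, {})  # run in a fresh, empty namespace
            return code
        except Exception as exc:
            if attempt == max_attempts:
                raise  # out of attempts: surface the last error
            code = fix(code, f"{type(exc).__name__}: {exc}")
    raise RuntimeError("unreachable")  # the loop always returns or raises

# Usage note (assumption): the LLM reply may wrap the fixed code in a Markdown
# fence, so a real fixer would extract the code before returning it.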
def get_python_exception_response(code: str, exception: str) -> str:
    """
    Get a response from the LLM to fix Python code errors.

    Args:
        code: The Python code that has errors
        exception: The exception message

    Returns:
        LLM response with fixed code

    Raises:
        Exception: If API call fails
    """
    try:
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", get_python_exception_prompt_text()),
                MessagesPlaceholder(variable_name="messages")
            ]
        )
        chain = prompt | llm
        response = chain.invoke(
            {
                "messages": [HumanMessage(
                    content="Rewrite the entire Python code so that it does not contain any errors. "
                            "The code should be able to run without any errors."
                )],
                "code": code,
                "exception": exception
            }
        )
        return response.content
    except Exception as e:
        error_msg = str(e)
        # Log the complete error message
        logger.info(f"Exception fixing failed - Exception type: {type(e).__name__}")
        logger.info(f"Exception fixing failed - Error message: {error_msg}")
        if "rate_limit" in error_msg.lower() or "429" in error_msg:
            raise Exception("Rate limit exceeded. Please wait a moment and try again.")
        elif "authentication" in error_msg.lower() or "401" in error_msg or "api_key" in error_msg.lower():
            raise Exception("We're having trouble generating your visualization.")
        elif "timeout" in error_msg.lower():
            raise Exception("Request timed out. Please try again.")
        else:
            raise Exception(f"Unable to process your request: {error_msg}")
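

# A minimal sketch (hypothetical helper, not in the original module). It
# collects the API-error checks duplicated in get_response() and
# get_python_exception_response() into one function, so both call sites map
# rate-limit, authentication, and timeout failures to the same user-facing
# messages. The substring checks mirror the ones above.


def _friendly_api_error(error_msg: str) -> str:
    """Map a raw API error message to the user-facing message used by this module."""
    lowered = error_msg.lower()
    if "rate_limit" in lowered or "429" in error_msg:
        return "Rate limit exceeded. Please wait a moment and try again."
    if "authentication" in lowered or "401" in error_msg or "api_key" in lowered:
        return "We're having trouble generating your visualization."
    if "timeout" in lowered:
        return "Request timed out. Please try again."
    return f"Unable to process your request: {error_msg}"

# Usage (assumption): each except block could then end with
#     raise Exception(_friendly_api_error(str(e))) from e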