Spaces:
Sleeping
Sleeping
| # libraries to help with the environment variables | |
| import os | |
| from dotenv import load_dotenv | |
| import logging | |
| # libraries to help with the AI model | |
| from langchain_groq import ChatGroq | |
| from langchain_core.messages import HumanMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from utils import helpers | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Value used in the sample .env file; treated the same as an unset key.
_PLACEHOLDER_GROQ_API_KEY = 'your_groq_api_key_here'

# Load credentials from a .env file (if present) into the process environment.
load_dotenv()
# Strip surrounding whitespace so a blank / whitespace-only value is rejected
# here instead of being sent to the API and failing later with an auth error.
GROQ_API_KEY = (os.getenv('GROQ_API_KEY') or '').strip()

# Validate that the API key is present and not a placeholder
if not GROQ_API_KEY or GROQ_API_KEY == _PLACEHOLDER_GROQ_API_KEY:
    raise ValueError(
        "GROQ_API_KEY environment variable is not set or is still the placeholder. "
        "Please update the .env file with your actual Groq API key."
    )

# Shared chat-model client used by every request in this module.
# temperature=0 asks for the least random sampling so generated code is stable.
try:
    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        api_key=GROQ_API_KEY,
        temperature=0,
    )
except Exception as e:
    # Chain the original error so the root cause stays visible in tracebacks.
    raise ValueError(f"Failed to initialize ChatGroq: {str(e)}") from e
def get_prompt_text() -> str:
    """
    Get the system prompt for data visualization generation.

    The text is a LangChain (f-string style) template. It must contain no literal
    curly braces other than these placeholders, which get_response fills in:
    data_visualization_best_practices, dumbbell_charts_section,
    polar_charts_section, example_subplots1, example_subplots2,
    example_subplots3, name_of_file and data.

    Returns:
        str: The system prompt template
    """
    return """You are a data visualization expert and you only use the graphing library Plotly.
CRITICAL VALIDATION RULES - EXECUTE BEFORE GENERATING ANY CODE:
1. RELEVANCE CHECK: Before generating any code, you MUST verify that the user's request is relevant to the provided dataset.
2. COLUMN VERIFICATION: Analyze the first 5 rows of data provided. If the user explicitly mentions column names that do NOT exist in the dataset, you MUST return an error message instead of code.
3. DATA CONTEXT VERIFICATION: If the user's request asks about metrics, categories, or data points that are clearly incompatible with the dataset columns shown, you MUST return an error message instead of code.
4. NON-VISUALIZATION REQUESTS: If the user's request is not about data visualization (e.g., asking for text generation, general questions, unrelated tasks), you MUST return an error message instead of code.
ERROR MESSAGE FORMAT - Use this EXACT format when validation fails:
ERROR: The request appears to be unrelated to the provided dataset. Please rephrase your request to refer to the actual columns and data available in your file. Available columns are: [list the column names from the data provided].
IMPORTANT: Only generate Python code if ALL of the following are true:
- The request is about creating a data visualization
- The request refers to columns, metrics, or patterns that could reasonably exist in the provided dataset
- The user has not explicitly mentioned column names that don't exist in the dataset
If any validation rule fails, return ONLY the error message in the format specified above. Do NOT generate any Python code.
IF VALIDATION PASSES, PROCEED WITH CODE GENERATION:
PANDAS DATA HANDLING BEST PRACTICES:
- Always use .copy() when creating a new dataframe from a subset or filtered view to avoid SettingWithCopyWarning.
- Example: df_filtered = df[df['column'] > 0].copy()
- When modifying data, always work on explicit copies, not chained indexing.
- Use .loc[] for setting values: df.loc[condition, 'column'] = value
- Avoid chained assignment like df[condition]['column'] = value
Ensure that before performing any data manipulation or plotting, the code checks for column data types and converts them if necessary.
For example, numeric columns should be converted to floats or integers using pd.to_numeric(), and non-numeric columns should be excluded from numeric operations.
Before creating any visualizations, ensure that any rows with NaN or missing values in the relevant columns are removed. Additionally,
handle missing values appropriately based on the context, ensuring cleaner visualizations.
For example, use df.dropna(subset=[column_name]) for data cleaning. Never use this statement: df.dropna(inplace=True).
The graphs you plot shall always have a white background and shall follow data visualization best practices.
Do not ignore any of the following visualization best practices:
{data_visualization_best_practices}
If the user requests a single visualization, set the figure height to 800.
Ensure that the graph is clearly labeled with a title, x-axis label, y-axis label, and legend.
SPECIFIC CHART TYPE INSTRUCTIONS:
CHOROPLETH MAPS:
CRITICAL: When creating a choropleth map of the United States, you MUST include ALL of the following parameters:
- locations: Set to the column containing two-letter state abbreviations (e.g., 'AL', 'NY', 'CA', 'TX')
- locationmode: MUST be set to 'USA-states' (this is CRITICAL - without it, the map will be blank)
- scope: Set to 'usa'
Example:
fig = px.choropleth(df,
                    locations='state_code_column',
                    locationmode='USA-states',
                    scope='usa',
                    color='value_column',
                    title='Map Title')
The locations parameter should reference the column with state codes, not the column with full state names.
Always verify that locationmode='USA-states' is present in the code.
CHOROPLETH MAPS IN SUBPLOTS:
When creating a choropleth map of the USA as part of subplots using graph_objects (go.Choropleth), you MUST include the following:
1. After creating all traces with fig.add_trace(), add:
fig.update_layout(title_text='Your Title', height=800, plot_bgcolor='white', paper_bgcolor='white')
Note: Generic background settings don't apply to choropleths in subplots, so this is necessary.
2. Force the USA scope on the specific choropleth subplot using:
fig.update_geos(scope="usa", projection_type="albers usa", bgcolor="white", row=X, col=Y)
Replace X and Y with the actual row and column numbers where the choropleth is located.
Example:
# After adding all traces
fig.update_layout(title_text='E-commerce Data Visualization', height=800, plot_bgcolor='white', paper_bgcolor='white')
# Force USA scope on the choropleth subplot (adjust row, col as needed)
fig.update_geos(scope="usa", projection_type="albers usa", bgcolor="white", row=2, col=2)
{dumbbell_charts_section}
{polar_charts_section}
If the user requests multiple visualizations, create a subplot for each visualization.
The libraries required for multiple visualizations are: import plotly.graph_objects as go and from plotly.subplots import make_subplots.
Utilize the make_subplots() function from plotly.subplots to create subplots, specifying the number of rows and columns,
and the specs parameter to define what type of graph will be present in each subplot to accommodate all requested visualizations.
Then, use the add_trace() method to add each graph to the appropriate subplot.
When generating subplots that include pie charts and xy plots (like bar or scatter), ensure that pie charts are assigned a separate 'domain' subplot type.
Use the make_subplots() function with the specs argument correctly set for pie charts and other plots.
For example, use make_subplots(rows=1, cols=2, specs=[[dict(type='domain'), dict(type='xy')]]) for a pie chart and a bar plot.
Before returning the final code, verify that all trace types are compatible with the assigned subplot types,
particularly ensuring that pie charts are in domain-type subplots. If an error is detected, correct the subplot type automatically.
Validate the layout before adding traces.
Ensure each subplot is clearly labeled and formatted according to best practices.
Here are examples of how to create multiple visualizations in a single figure:
Example 1: \n
{example_subplots1}
Example 2: \n
{example_subplots2}
Example 3: \n
{example_subplots3}
The height of the figure (fig) should be set to 800.
Suppose that the data is provided as a {name_of_file} file.
Here are the first 5 rows of the data set: {data}. Follow the user's indications when creating the graph.
There should be no natural language text in the python code block.
REMINDER: Your code MUST end with fig.show() to display the visualization."""
| def _should_include_dumbbell_examples(user_input: str) -> bool: | |
| """ | |
| Check if user's request is about dumbbell charts or comparison visualizations. | |
| Args: | |
| user_input: User's visualization request | |
| Returns: | |
| bool: True if dumbbell chart examples should be included | |
| """ | |
| dumbbell_keywords = [ | |
| 'dumbbell', 'dumb bell', 'dumbell', 'dumbel', 'comparison', 'before and after', 'before after', | |
| 'start and end', 'start end', 'range', 'difference', 'gap', 'change over' | |
| ] | |
| user_input_lower = user_input.lower() | |
| return any(keyword in user_input_lower for keyword in dumbbell_keywords) | |
| def _should_include_polar_examples(user_input: str) -> bool: | |
| """ | |
| Check if user's request is about polar charts, calendar views, or circular visualizations. | |
| Args: | |
| user_input: User's visualization request | |
| Returns: | |
| bool: True if polar chart examples should be included | |
| """ | |
| polar_keywords = [ | |
| 'polar', 'circular', 'radial', 'circular fashion', 'radar', 'rose' | |
| ] | |
| user_input_lower = user_input.lower() | |
| return any(keyword in user_input_lower for keyword in polar_keywords) | |
# Prefix the system prompt instructs the model to use when it refuses a request.
_LLM_ERROR_PREFIX = "ERROR:"


def _read_asset(asset_file_name: str) -> str:
    """Read a text file from the app's ``assets`` directory via the project helpers."""
    return helpers.read_doc(helpers.get_app_file_path("assets", asset_file_name))


def _build_dumbbell_section() -> str:
    """Return the dumbbell-plot prompt section, including the bundled example code."""
    dumbbell_example = _read_asset("example_dumbbell_chart.txt")
    return f"""
DUMBBELL PLOTS:
When creating a dumbbell plot, use plotly.graph_objects (go) instead of plotly.express (px).
Use go.Figure() and add two go.Scatter traces for the two data points, and a go.Scatter trace for the lines connecting them.
Ensure proper labeling of axes and title for clarity.
Example: \n
{dumbbell_example}
"""


def _build_polar_section() -> str:
    """Return the polar-chart prompt section, including both bundled example snippets."""
    polar_bar_example = _read_asset("example_polar_bar.txt")
    polar_scatter_example = _read_asset("example_polar_scatter.txt")
    return f"""
POLAR CHARTS (RADIAL/CIRCULAR VISUALIZATIONS):
Polar charts are effective for displaying calendar views, weekly patterns, or circular data distributions.
Use them for innovative visualizations of time-based or cyclical data.
Example 1 - Polar Calendar with Cells (Barpolar):
{polar_bar_example}
Example 2 - Polar Calendar with Scatter:
{polar_scatter_example}
Use polar charts when the user requests:
- Calendar-like views
- Weekly or cyclical patterns
- Circular representations of data
- Radial visualizations
"""


def get_response(user_input: str, data_top5_csv_string: str, file_name: str) -> str:
    """
    Get a response from the LLM for creating data visualizations.

    The dumbbell and polar prompt sections are only added when the request
    mentions one of their trigger keywords, which keeps the prompt shorter
    otherwise.

    Args:
        user_input: User's request for visualization
        data_top5_csv_string: CSV string of first 5 rows of data
        file_name: Name of the data file
    Returns:
        The stripped LLM response content (expected to contain Python code)
    Raises:
        Exception: With a user-facing message when the LLM call fails, or when
            the model refuses the request with an "ERROR:" response.
    """
    try:
        dumbbell_charts_section = (
            _build_dumbbell_section() if _should_include_dumbbell_examples(user_input) else ""
        )
        polar_charts_section = (
            _build_polar_section() if _should_include_polar_examples(user_input) else ""
        )
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", get_prompt_text()),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        chain = prompt | llm
        invoke_params = {
            "messages": [HumanMessage(content=user_input)],
            "data_visualization_best_practices": _read_asset("data_viz_best_practices.txt"),
            "example_subplots1": _read_asset("example_subplots1.txt"),
            "example_subplots2": _read_asset("example_subplots2.txt"),
            "example_subplots3": _read_asset("example_subplots3.txt"),
            "dumbbell_charts_section": dumbbell_charts_section,
            "polar_charts_section": polar_charts_section,
            "data": data_top5_csv_string,
            "name_of_file": file_name,
        }
        response = chain.invoke(invoke_params)
        response_text = response.content.strip()
    except Exception as e:
        error_msg = str(e)
        lowered = error_msg.lower()
        logger.info(f"DEBUG - Caught exception type: {type(e).__name__}")
        logger.info(f"DEBUG - Error message: {error_msg}")
        # Map known API failures to friendly messages; chain the cause for debugging.
        if "rate_limit" in lowered or "429" in error_msg:
            raise Exception("Rate limit exceeded. Please wait a moment and try again.") from e
        elif "authentication" in lowered or "401" in error_msg or "api_key" in lowered:
            raise Exception("We're having trouble generating your visualization.") from e
        elif "timeout" in lowered:
            raise Exception("Request timed out. Please try again.") from e
        else:
            raise Exception(f"Unable to process your request: {error_msg}") from e

    # The refusal check happens outside the try, so only the model's own
    # "ERROR:" response bypasses the API-error mapping above. Only the leading
    # prefix is removed; any later "ERROR:" text in the message is kept.
    if response_text.startswith(_LLM_ERROR_PREFIX):
        error_message = response_text[len(_LLM_ERROR_PREFIX):].strip()
        raise Exception(f"Unable to process your request: {error_message}")
    return response_text
def get_python_exception_prompt_text() -> str:
    """
    Build the system prompt used when asking the LLM to repair failing code.

    The text is a LangChain template with two placeholders, ``{code}`` and
    ``{exception}``, which get_python_exception_response fills in at invoke time.

    Returns:
        str: The error-fixing system prompt template
    """
    repair_template = "The Python code you provided {code} has an error {exception}"
    return repair_template
def get_python_exception_response(code: str, exception: str) -> str:
    """
    Ask the LLM to rewrite code that raised an exception.

    The failing code and its exception text are placed into the repair system
    prompt. A fixed human message then asks for a full, error-free rewrite.

    Args:
        code: The Python code that has errors
        exception: The exception message
    Returns:
        The raw content of the LLM reply (expected to contain the fixed code)
    Raises:
        Exception: With a user-facing message if the LLM call fails.
    """
    try:
        repair_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", get_python_exception_prompt_text()),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        repair_chain = repair_prompt | llm
        rewrite_request = HumanMessage(
            content="Rewrite the entire Python code so that it does not contain any errors. "
            "The code should be able to run without any errors."
        )
        reply = repair_chain.invoke(
            {"messages": [rewrite_request], "code": code, "exception": exception}
        )
        return reply.content
    except Exception as exc:
        detail = str(exc)
        lowered = detail.lower()
        # Log the complete error before translating it for the user.
        logger.info(f"Exception fixing failed - Exception type: {type(exc).__name__}")
        logger.info(f"Exception fixing failed - Error message: {detail}")
        if "rate_limit" in lowered or "429" in detail:
            friendly = "Rate limit exceeded. Please wait a moment and try again."
        elif "authentication" in lowered or "401" in detail or "api_key" in lowered:
            friendly = "We're having trouble generating your visualization."
        elif "timeout" in lowered:
            friendly = "Request timed out. Please try again."
        else:
            friendly = f"Unable to process your request: {detail}"
        raise Exception(friendly)