Spaces:
Running
Running
| """Custom actions used within a dashboard.""" | |
| import base64 | |
| import io | |
| import logging | |
| import black | |
| import dash | |
| import dash_bootstrap_components as dbc | |
| import pandas as pd | |
| from _utils import check_file_extension | |
| from dash.exceptions import PreventUpdate | |
| from langchain_openai import ChatOpenAI | |
| from plotly import graph_objects as go | |
| from vizro.models.types import capture | |
| from vizro_ai import VizroAI | |
| try: | |
| from langchain_anthropic import ChatAnthropic | |
| except ImportError: | |
| ChatAnthropic = None | |
| try: | |
| from langchain_mistralai import ChatMistralAI | |
| except ImportError: | |
| ChatMistralAI = None | |
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)  # TODO: remove manual setting and make centrally controlled

# Vendor label -> LangChain chat-model class used to build the LLM client.
# Anthropic and Mistral hold None when their optional package failed to import above;
# xAI reuses the ChatOpenAI class.
SUPPORTED_VENDORS = {
    "OpenAI": ChatOpenAI,
    "Anthropic": ChatAnthropic,
    "Mistral": ChatMistralAI,
    "xAI": ChatOpenAI,
}

# Vendor label -> model names listed for that vendor.
SUPPORTED_MODELS = {
    "OpenAI": ["gpt-4o-mini", "gpt-4o", "gpt-4-turbo"],
    "Anthropic": [
        "claude-3-opus-latest",
        "claude-3-5-sonnet-latest",
        "claude-3-sonnet-20240229",
        "claude-3-haiku-20240307",
    ],
    "Mistral": [
        "mistral-large-latest",
        "open-mistral-nemo",
        "codestral-latest",
    ],
    "xAI": ["grok-beta"],
}

# Sampling temperature passed to every chat-model constructor.
DEFAULT_TEMPERATURE = 0.1
# Forwarded to VizroAI.plot as max_debug_retry.
DEFAULT_RETRY = 3
def get_vizro_ai_plot(user_prompt, df, model, api_key, api_base, vendor_input):
    """Create the vendor's LangChain chat model and ask VizroAI to build a chart for ``df``.

    Args:
        user_prompt: Natural-language description of the desired chart.
        df: pandas DataFrame to plot.
        model: Model name for the chosen vendor.
        api_key: API key for the vendor.
        api_base: Base URL for the vendor API, forwarded to the client unchanged (may be ``None``).
        vendor_input: Vendor label; must be a key of ``SUPPORTED_VENDORS``.

    Returns:
        The result of ``VizroAI.plot(df, user_prompt, ..., return_elements=True)``.

    Raises:
        ValueError: If ``vendor_input`` is not a supported vendor.
        ImportError: If the LangChain integration package for the vendor is not installed.
    """
    # Constructor keyword names differ per vendor; xAI goes through the same class as OpenAI.
    client_kwargs = {
        "OpenAI": {"model_name": model, "openai_api_key": api_key, "openai_api_base": api_base},
        "Anthropic": {"model": model, "anthropic_api_key": api_key, "anthropic_api_url": api_base},
        "Mistral": {"model": model, "mistral_api_key": api_key, "mistral_api_url": api_base},
        "xAI": {"model": model, "openai_api_key": api_key, "openai_api_base": api_base},
    }
    if vendor_input not in client_kwargs or vendor_input not in SUPPORTED_VENDORS:
        raise ValueError(f"Unsupported vendor: {vendor_input!r}")

    vendor = SUPPORTED_VENDORS[vendor_input]
    if vendor is None:
        # The optional import at module load failed and left a None placeholder in SUPPORTED_VENDORS;
        # calling it would only produce "'NoneType' object is not callable".
        raise ImportError(f"The LangChain integration for {vendor_input} is not installed.")

    llm = vendor(**client_kwargs[vendor_input], temperature=DEFAULT_TEMPERATURE)
    vizro_ai = VizroAI(model=llm)
    return vizro_ai.plot(df, user_prompt, max_debug_retry=DEFAULT_RETRY, return_elements=True)
def run_vizro_ai(user_prompt, n_clicks, data, model, api_key, api_base, vendor_input):  # noqa: PLR0913
    """Ask VizroAI for a chart and return the outputs for the response window, graph and store.

    Args:
        user_prompt: Natural-language chart request forwarded to VizroAI.
        n_clicks: Click count of the triggering button; a falsy value aborts the update.
        data: Upload store content. Presumably as written by ``data_upload_action``:
            ``{"data": records, "filename": ...}`` on success, ``{"error_message": ...}`` on failure.
        model: Model name passed to the vendor's chat-model class.
        api_key: API key for the vendor.
        api_base: Optional base URL for the vendor API.
        vendor_input: Vendor label selecting the LLM client.

    Returns:
        A 3-tuple ``(ai_response, figure, {"ai_outputs": ...})``: the text for the response window
        (a fenced python code block on success, a message otherwise), a Plotly figure (empty on
        failure), and the "vizro"/"plotly" code and figure JSON (``None`` on failure).

    Raises:
        PreventUpdate: If the button has not been clicked.
    """

    def create_response(ai_response, figure, ai_outputs):
        return (ai_response, figure, {"ai_outputs": ai_outputs})

    def error_response(message):
        # Every failure shows the message beside an empty figure and clears the stored outputs.
        return create_response(message, go.Figure(), ai_outputs=None)

    if not n_clicks:
        raise PreventUpdate
    # A failed upload stores {"error_message": ...} without a "data" key, and an empty file stores no
    # records; both would otherwise surface below as an opaque KeyError / empty-frame error.
    if not data or not data.get("data"):
        return error_response("Please upload data to proceed!")
    if not api_key:
        return error_response("API key not found. Make sure you enter your API key!")
    if api_key.startswith('"'):
        return error_response("Make sure you enter your API key without quotes!")
    if api_base is not None and api_base.startswith('"'):
        return error_response("Make sure you enter your API base without quotes!")

    try:
        logger.info("Attempting chart code.")
        df = pd.DataFrame(data["data"])
        ai_outputs = get_vizro_ai_plot(
            user_prompt=user_prompt,
            df=df,
            model=model,
            api_key=api_key,
            api_base=api_base,
            vendor_input=vendor_input,
        )
        figure_vizro = ai_outputs.get_fig_object(data_frame=df, vizro=True)
        figure_plotly = ai_outputs.get_fig_object(data_frame=df, vizro=False)
        formatted_code = black.format_str(ai_outputs.code_vizro, mode=black.Mode(line_length=100))
        ai_code_outputs = {
            "vizro": {"code": ai_outputs.code_vizro, "fig": figure_vizro.to_json()},
            "plotly": {"code": ai_outputs.code, "fig": figure_plotly.to_json()},
        }
        ai_response = "\n".join(["```python", formatted_code, "```"])
        logger.info("Successful query produced.")
        return create_response(ai_response, figure_vizro, ai_outputs=ai_code_outputs)
    except Exception as exc:  # callback boundary: report any failure to the user instead of crashing
        logger.debug(exc, exc_info=True)
        logger.info("Chart creation failed.")
        return error_response(f"Sorry, I can't do that. Following Error occurred: {exc}")
def data_upload_action(contents, filename):
    """Decode an uploaded CSV or Excel file into records for the data store.

    Args:
        contents: Upload contents, presumably a data-URL string ``"<header>,<base64 payload>"``.
        filename: Name of the uploaded file; a ``.csv`` suffix selects the CSV parser,
            anything else accepted by ``check_file_extension`` is read as Excel.

    Returns:
        A 3-tuple ``(store_data, style, style)``. On success ``store_data`` is
        ``{"data": records, "filename": filename}``; on failure it is ``{"error_message": ...}``.
        The two style dicts go to the callback's other outputs (their targets are not visible here).

    Raises:
        PreventUpdate: If nothing has been uploaded.
    """
    if not contents:
        raise PreventUpdate
    if not check_file_extension(filename=filename):
        return (
            {"error_message": "Unsupported file extension. Make sure to upload either csv or an excel file."},
            {"color": "gray"},
            {"display": "none"},
        )

    try:
        # Split only on the first comma: everything after it is the base64 payload. Kept inside the
        # try so malformed contents produce the error message instead of crashing the callback.
        _, content_string = contents.split(",", 1)
        decoded = base64.b64decode(content_string)
        if filename.lower().endswith(".csv"):
            # utf-8-sig drops a leading byte-order mark (common in Excel-exported CSVs) that plain
            # utf-8 would leave glued to the first column name; BOM-less input decodes identically.
            df = pd.read_csv(io.StringIO(decoded.decode("utf-8-sig")))
        else:
            df = pd.read_excel(io.BytesIO(decoded))
        data = df.to_dict("records")
        return {"data": data, "filename": filename}, {"cursor": "pointer"}, {}
    except Exception as e:  # callback boundary: any parse failure becomes a user-facing message
        logger.debug(e, exc_info=True)
        return (
            {"error_message": "There was an error processing this file."},
            {"color": "gray", "cursor": "default"},
            {"display": "none"},
        )
def display_filename(data):
    """Return the text shown for the upload store: the filename label, or the stored error message.

    Raises:
        PreventUpdate: If the store is still empty (``None``).
    """
    if data is None:
        raise PreventUpdate
    if "filename" not in data:
        # Failed uploads carry only an error message; show it verbatim.
        return data.get("error_message")
    shown = data["filename"] or data.get("error_message")
    return f"Uploaded file name: '{shown}'"
def update_table(data):
    """Build a preview table of randomly sampled rows of the uploaded data, plus the modal title.

    Args:
        data: Upload store content; presumably ``{"data": records, "filename": ...}`` on success or
            ``{"error_message": ...}`` on failure, as written by ``data_upload_action``.

    Returns:
        ``(table, modal_title)`` with up to 5 randomly sampled rows, or ``dash.no_update`` when the
        store holds no uploaded records.
    """
    # A failed upload stores only an error message; indexing data["data"] would raise KeyError.
    if not data or "data" not in data:
        return dash.no_update
    df = pd.DataFrame(data["data"])
    filename = data.get("filename") or data.get("error_message")
    modal_title = f"Data sample preview for {filename} file"
    # DataFrame.sample(n) raises ValueError when n exceeds the row count, so cap it for small files.
    df_sample = df.sample(min(5, len(df)))
    table = dbc.Table.from_dataframe(df_sample, striped=False, bordered=True, hover=True)
    return table, modal_title