Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import duckdb | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from src.client import LLMChain | |
| from src.models import Charts, TableData | |
| from src.pipelines import SQLVizChain | |
| from src.utils import plot_chart | |
# MotherDuck token taken from the environment (None when the variable is unset).
MD_TOKEN = os.getenv("MD_TOKEN")
# Read-only connection shared by every query helper in this module.
conn = duckdb.connect(f"md:my_db?motherduck_token={MD_TOKEN}", read_only=True)

# WARNING when ENV is exactly "PROD", INFO in every other case.
LEVEL = "WARNING" if os.getenv("ENV") == "PROD" else "INFO"
# Number of visible lines in the "Generated SQL" textbox.
TAB_LINES = 8

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=getattr(logging, LEVEL, logging.INFO),
)
logger = logging.getLogger(__name__)
def _load_pipeline():
    """Create the SQLVizChain bound to the module connection and a new LLMChain."""
    llm_chain = LLMChain()
    return SQLVizChain(duckdb=conn, chain=llm_chain)


# One pipeline instance, built at import time and reused by main().
pipeline = _load_pipeline()
def get_schemas():
    """Return the distinct schema names, excluding information_schema and pg_catalog."""
    query = """
        SELECT DISTINCT schema_name
        FROM information_schema.schemata
        WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    # Each row holds exactly the one selected column.
    return [name for (name,) in rows]
def get_tables(schema_name):
    """Return the names of the tables that live in ``schema_name``.

    Args:
        schema_name: Schema to list tables for. An empty value or ``None``
            (nothing selected) yields an empty list.

    Returns:
        A list of table-name strings, possibly empty.
    """
    if not schema_name:
        return []
    # Bind the schema name as a parameter instead of interpolating it into
    # the SQL text, so quotes in the value cannot alter the statement.
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
def update_tables(schema_name):
    """Return a Gradio update whose choices are the tables of ``schema_name``."""
    return gr.update(choices=get_tables(schema_name))
def get_table_schema(table):
    """Return the CREATE statement of ``table`` rewritten with its full path.

    The statement is read from ``duckdb_tables()`` and the table reference in
    it is replaced by ``<database>.<schema>.<table>``.

    Args:
        table: Name of the table to look up.

    Returns:
        A ``(ddl_create, full_path)`` tuple: the rewritten CREATE statement
        and the ``database.schema.table`` path.

    Raises:
        ValueError: If ``table`` is empty or no table with that name exists.
    """
    if not table:
        raise ValueError("No table selected.")

    # Parameter binding keeps the table name out of the SQL text.
    row = conn.execute(
        "SELECT sql, database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).fetchone()
    if row is None:
        raise ValueError(f"Table {table!r} was not found.")

    # NOTE(review): the lookup matches on table name only; if the same name
    # exists in several schemas the first returned row is used.
    ddl_create, parent_database, schema_name = row
    full_path = f"{parent_database}.{schema_name}.{table}"

    # Tables outside "main" are looked for as "schema.table" in the
    # statement; tables in "main" are looked for by bare name.
    old_path = f"{schema_name}.{table}" if schema_name != "main" else table

    # NOTE(review): plain substring replacement - every occurrence of
    # old_path in the statement is rewritten, not only the table reference.
    ddl_create = ddl_create.replace(old_path, full_path)
    return ddl_create, full_path
def main(table, text_query):
    """Run the text-to-SQL pipeline for ``table`` and build the UI outputs.

    Args:
        table: Name of the table selected in the UI.
        text_query: Natural-language question entered by the user.

    Returns:
        A ``(figure, sql, dataframe)`` tuple for the Plot, SQL and Data tabs.
        ``figure`` is a blank, axis-free placeholder when there is no chart;
        ``sql`` and ``dataframe`` are ``None`` when they are unavailable.
        Any failure is logged, surfaced through ``gr.Warning`` and reported
        as ``(placeholder, None, None)``.
    """
    # Blank figure shown whenever no chart can be drawn.
    fig, ax = plt.subplots()
    ax.set_axis_off()

    try:
        if not table:
            raise ValueError("No table selected.")

        # Inside the try block so lookup failures reach the user as a
        # warning instead of escaping as an unhandled exception.
        schema, _ = get_table_schema(table)

        results = pipeline.run(user_question=text_query, context=schema)
        chart_data = results["chart_data"]
        chart_config = results["chart_config"]
        chart_type = results["chart_type"]
        sql_text = results["sql_config"]["sql_query"]

        if chart_data is None:
            return fig, sql_text, None

        if not chart_type:
            # No chart requested: show the tabular result only.
            if isinstance(chart_data, TableData):
                frame = pd.DataFrame(chart_data.model_dump(exclude_none=True))
                return fig, sql_text, frame
        elif isinstance(chart_data, Charts):
            chart_dict = chart_data.model_dump(exclude_none=True).get(chart_type)
            frame = pd.DataFrame(chart_dict["data"])
            chart = plot_chart(chart_type=chart_type, data=frame, **chart_config)
            if chart is not fig:
                # pyplot keeps every figure it created until it is closed;
                # release the placeholder now that it has been replaced.
                plt.close(fig)
            return chart, sql_text, frame

        # Unexpected chart_type / chart_data combination: still return the
        # generated SQL rather than falling off the end of the function.
        return fig, sql_text, None
    except Exception as e:
        # UI boundary: keep the traceback in the log, warn the user.
        logger.exception("Unable to generate the visualization")
        gr.Warning(f"❌ Unable to generate the visualization. {e}")
        return fig, None, None
# Stylesheet passed to gr.Blocks(css=...): container background, logo sizing
# and button colours (normal and hover).
custom_css = """
.gradio-container {
    background-color: #f0f4f8;
}
.logo {
    max-width: 200px;
    margin: 20px auto;
    display: block;
}
.gr-button {
    background-color: #4a90e2 !important;
}
.gr-button:hover {
    background-color: #3a7bc8 !important;
}
"""
# Build the Gradio UI: header, schema/table pickers, query box, run button
# and the Plot / SQL / Data result tabs.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"), css=custom_css
) as demo:
    # Header: logo image followed by the centred title and tagline.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
    <div style='text-align: center;'>
    <strong style='font-size: 36px;'>DataViz Agent</strong>
    <br>
    <span style='font-size: 20px;'>Visualize SQL queries based on a given text for the dataset.</span>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Schema choices are queried once, while the UI is being built.
            schema_dropdown = gr.Dropdown(
                choices=get_schemas(), label="Select Schema", interactive=True
            )
            # Starts empty; update_tables fills it when a schema is chosen.
            tables_dropdown = gr.Dropdown(
                choices=[], label="Available Tables", value=None
            )
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                lines=3, label="Text Query", placeholder="Enter your text query here..."
            )
    with gr.Row():
        # Empty column - presumably a spacer that pushes the button to the
        # right; confirm the intended layout.
        with gr.Column(scale=7):
            pass
        with gr.Column(scale=1):
            generate_query_button = gr.Button("Run Query", variant="primary")
    # One tab per output of main(): figure, generated SQL, result data.
    with gr.Tabs():
        with gr.Tab("Plot"):
            result_plot = gr.Plot()
        with gr.Tab("SQL"):
            generated_sql = gr.Textbox(
                lines=TAB_LINES,
                label="Generated SQL",
                value="",
                interactive=False,
                autoscroll=False,
            )
        with gr.Tab("Data"):
            data = gr.Dataframe(label="Data", interactive=False)
    # Selecting a schema refreshes the table choices.
    schema_dropdown.change(
        update_tables, inputs=schema_dropdown, outputs=tables_dropdown
    )
    # The button runs main(table, text_query) and fills the three tabs.
    generate_query_button.click(
        main,
        inputs=[tables_dropdown, query_input],
        outputs=[result_plot, generated_sql, data],
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)