Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import duckdb | |
| import gradio as gr | |
| import pandas as pd | |
| import pandera as pa | |
| from pandera import Column | |
| import ydata_profiling as pp | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
| from prompt import PROMPT_PANDERA, PANDERA_USER_INPUT_PROMPT | |
| from langsmith import traceable | |
| import warnings | |
| warnings.filterwarnings("ignore", category=DeprecationWarning) | |
# Height of the Tabs Text Area
TAB_LINES = 8

# MotherDuck access token, read from the MD_TOKEN environment variable
# (None when the variable is not set).
md_token = os.getenv('MD_TOKEN')

# Human-message template for the automatic rule generation flow.
# `{data}` is filled with a JSON sample of the table rows.
INPUT_PROMPT = '''
Here are the first few samples of data:
<Sample Data>
{data}
</Sample Data>
'''

# Human-message template for the "Text to Validation" flow.
# `{data}` is filled with a JSON sample of the table rows and
# `{user_description}` with the free-text request typed by the user.
USER_INPUT = '''
Here are the first few samples of data:
<Sample Data>
{data}
</Sample Data>
Here is the User Description:
<User Description>
{user_description}
</User Description>
'''
print('Connecting to DB...')
# Connect to DB (read-only MotherDuck connection authenticated with md_token)
conn = duckdb.connect(f"md:my_db?motherduck_token={md_token}", read_only=True)

# Candidate chat models, tried in order; the first one whose endpoint can be
# created and queried is used.
models = ["Qwen/Qwen2.5-72B-Instruct", "meta-llama/Meta-Llama-3-70B-Instruct",
          "meta-llama/Llama-3.1-70B-Instruct"]
model_loaded = False
for model in models:
    try:
        endpoint = HuggingFaceEndpoint(repo_id=model, max_new_tokens=8192)
        # Used as a probe: if this call raises, the next model is tried.
        info = endpoint.client.get_endpoint_info()
        model_loaded = True
        break
    except Exception as e:
        print(f"Error for model {model}: {e}")

# Fail with a clear message instead of a NameError on `endpoint` (or a chat
# model bound to an endpoint whose probe failed) when no candidate worked.
if not model_loaded:
    raise RuntimeError(
        f"Unable to load any of the configured models: {', '.join(models)}"
    )

llm = ChatHuggingFace(llm=endpoint).bind_tools(tools=[], max_tokens=8192)
def get_schemas():
    """Return the distinct schema names visible through ``conn``.

    The 'information_schema' and 'pg_catalog' schemas are filtered out by
    the query itself.
    """
    query = """
    SELECT DISTINCT schema_name
    FROM information_schema.schemata
    WHERE schema_name NOT IN ('information_schema', 'pg_catalog')
    """
    rows = conn.execute(query).fetchall()
    # Each row holds the single selected column.
    return [row[0] for row in rows]
def get_tables_names(schema_name):
    """Return the names of all tables that belong to ``schema_name``.

    The schema name is passed as a bound query parameter rather than being
    interpolated into the SQL text, so a value containing quotes cannot alter
    the statement.
    """
    rows = conn.execute(
        "SELECT table_name FROM information_schema.tables WHERE table_schema = ?",
        [schema_name],
    ).fetchall()
    return [row[0] for row in rows]
def update_table_names(schema_name):
    """Build a Gradio update whose choices are the tables of ``schema_name``."""
    return gr.update(choices=get_tables_names(schema_name))
def get_data_df(schema):
    """Fetch up to 1000 rows of the relation named by ``schema`` as a DataFrame.

    ``schema`` is spliced into the FROM clause as-is, so it must already be a
    valid (optionally qualified) table path.
    """
    print('Getting Dataframe from the Database')
    query = f"SELECT * FROM {schema} LIMIT 1000"
    relation = conn.sql(query)
    return relation.df()
def chat_template(system_prompt, user_prompt, df):
    """Build the [system, human] message pair for the rule-generation call.

    ``user_prompt`` must contain a ``{data}`` placeholder; it is filled with
    the JSON (records orientation) of ``df.head()``.
    """
    system_message = SystemMessage(content=system_prompt)
    sample_json = df.head().to_json(orient='records')
    human_message = HumanMessage(content=user_prompt.format(data=sample_json))
    return [system_message, human_message]
def chat_template_user(system_prompt, user_prompt, user_description, df):
    """Build the [system, human] message pair for a user-described rule.

    ``user_prompt`` must contain ``{data}`` and ``{user_description}``
    placeholders. Only the first row of ``df`` is serialized (records
    orientation) as the data sample.
    """
    system_message = SystemMessage(content=system_prompt)
    first_row_json = df.head(1).to_json(orient='records')
    human_text = user_prompt.format(
        data=first_row_json,
        user_description=user_description,
    )
    return [system_message, HumanMessage(content=human_text)]
def run_llm(messages):
    """Invoke ``llm`` with ``messages`` and parse its reply as JSON.

    Returns the decoded JSON value on success. Any exception raised while
    invoking the model or decoding the reply is *returned* (not raised), so
    callers must check the result with ``isinstance(result, Exception)``.
    """
    try:
        reply = llm.invoke(messages)
        print(reply.content)
        return json.loads(reply.content)
    except Exception as error:
        return error
def get_table_schema(table):
    """Return the fully qualified ``database.schema.table`` path of ``table``.

    The database and schema are looked up in ``duckdb_tables()``; when several
    tables share the name, the first row returned is used. The table name is
    passed as a bound parameter instead of being interpolated into the SQL.

    Raises:
        ValueError: if no table with that name exists (for example when no
            table has been selected yet).
    """
    result = conn.execute(
        "SELECT database_name, schema_name FROM duckdb_tables() WHERE table_name = ?",
        [table],
    ).df()
    if result.empty:
        raise ValueError(f"Table {table!r} was not found in the connected database.")
    parent_database = result.iloc[0, 0]
    schema_name = result.iloc[0, 1]
    return f"{parent_database}.{schema_name}.{table}"
def _describe_subset(df, dtypes):
    """Describe the columns of ``df`` matching ``dtypes``, one row per column."""
    subset = df.select_dtypes(include=dtypes)
    # DataFrame.describe() raises ValueError on a frame without columns, so a
    # table with no matching columns yields an empty summary instead.
    if subset.shape[1] == 0:
        return pd.DataFrame(columns=['column'])
    summary = subset.describe().T.reset_index()
    return summary.rename(columns={'index': 'column'})


def describe(df):
    """Summarize the numeric and the object-typed columns of ``df``.

    Returns:
        A ``(numerical_info, categorical_info)`` tuple of DataFrames. Each has
        one row per described column, with the column name stored under
        'column' and the ``describe()`` statistics as the remaining columns.
        A summary is empty (only the 'column' header) when ``df`` has no
        columns of that kind.
    """
    numerical_info = _describe_subset(df, ['number'])
    categorical_info = _describe_subset(df, ['object'])
    return numerical_info, categorical_info
def validate_pandera(tests, df):
    """Run every generated rule against ``df`` and report pass/fail per rule.

    Args:
        tests: iterable of dicts, each with a 'column_name' key and a
            'pandera_rule' key holding a Python expression that evaluates to
            a callable accepting a one-column DataFrame.
        df: the DataFrame to validate.

    Returns:
        A DataFrame with columns 'Columns' and 'Result', one row per entry of
        ``tests``. A rule passes when calling it does not raise; a malformed
        entry (missing key, unknown column, invalid expression) is reported
        as a failure for that row instead of aborting the whole run.
    """
    validation_results = []
    for test in tests:
        column_name = test.get('column_name', '<missing column_name>')
        try:
            # SECURITY: eval() executes model-generated code with this
            # module's globals. Only acceptable while the LLM output is
            # trusted; replace with a restricted rule parser before exposing
            # this to untrusted prompts.
            rule = eval(test['pandera_rule'])
            rule(df[[column_name]])
        except Exception as e:
            outcome = f"❌ Fail - {str(e)}"
        else:
            outcome = "✅ Pass"
        validation_results.append({"Columns": column_name, "Result": outcome})
    # Explicit columns keep the headers even when there are no rules.
    return pd.DataFrame(validation_results, columns=["Columns", "Result"])
def statistics(df):
    """Profile ``df`` with ydata-profiling and tabulate its overview and alerts.

    Returns:
        A ``(df_statistics, df_alerts)`` tuple:
        * ``df_statistics`` has columns 'Statistic Description' and 'Value',
          one row per entry of the profile's table statistics followed by the
          per-type column counts that are present.
        * ``df_alerts`` has columns 'Data Quality Issue' and 'Category', one
          row per alert raised by the profile.
    """
    profile = pp.ProfileReport(df)
    report_dict = profile.get_description()
    description, alerts = report_dict.table, report_dict.alerts

    # Statistics
    mapping = {
        'n': 'Number of observations',
        'n_var': 'Number of variables',
        'n_cells_missing': 'Number of cells missing',
        'n_vars_with_missing': 'Number of columns with missing data',
        'n_vars_all_missing': 'Columns with all missing data',
        'p_cells_missing': 'Missing cells (%)',
        'n_duplicates': 'Duplicated rows',
        'p_duplicates': 'Duplicated rows (%)',
    }
    # These entries are labelled as percentages; ydata-profiling computes
    # them as ratios (count / total), so casting to int would truncate
    # nearly every value to 0. They are scaled to percent instead.
    fraction_keys = {'p_cells_missing', 'p_duplicates'}
    type_labels = {
        'Text': 'Number of text columns',
        'Categorical': 'Number of categorical columns',
        'Numeric': 'Number of numeric columns',
        'DateTime': 'Number of datetime columns',
    }

    labels, values = [], []
    for key, value in description.items():
        if key == 'types':
            continue
        labels.append(mapping.get(key, key))
        if key in fraction_keys:
            values.append(round(float(value) * 100, 2))
        else:
            values.append(int(value))

    # Add flattened types information
    type_counts = description.get('types', {})
    for type_name, label in type_labels.items():
        if type_name in type_counts:
            labels.append(label)
            values.append(int(type_counts[type_name]))

    # dtype=object keeps counts as ints next to the float percentages.
    df_statistics = pd.DataFrame({
        'Statistic Description': labels,
        'Value': pd.Series(values, dtype=object),
    })

    # Alerts
    alerts_list = [
        (str(alert).replace('[', '').replace(']', ''), alert.alert_type_name)
        for alert in alerts
    ]
    df_alerts = pd.DataFrame(alerts_list, columns=['Data Quality Issue', 'Category'])
    return df_statistics, df_alerts
def main(table):
    """Profile ``table``, generate validation rules with the LLM and run them.

    Returns a 7-tuple matching the outputs wired to the "Validate Data"
    button: (first 10 rows, overview statistics, alerts, categorical summary,
    numerical summary, generated rules, validation results). When rule
    generation fails, the rules slot holds a one-row error frame and the
    results slot an empty frame.
    """
    schema = get_table_schema(table)
    df = get_data_df(schema)
    df_statistics, df_alerts = statistics(df)
    num_summary, cat_summary = describe(df)
    messages = chat_template(system_prompt=PROMPT_PANDERA, user_prompt=INPUT_PROMPT, df=df)
    tests = run_llm(messages)
    print(tests)
    # run_llm returns the exception object instead of raising it.
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return df.head(10), df_statistics, df_alerts, cat_summary, num_summary, tests, pd.DataFrame([])
    tests_df = pd.DataFrame(tests)
    # Rename the leading columns by position. zip() stops at the shorter
    # sequence, so a reply with fewer than three fields (or an empty list)
    # no longer raises IndexError.
    display_names = ('Column', 'Rule Name', 'Rules')
    tests_df = tests_df.rename(columns=dict(zip(tests_df.columns, display_names)))
    pandera_results = validate_pandera(tests, df)
    return df.head(10), df_statistics, df_alerts, cat_summary, num_summary, tests_df, pandera_results
def user_results(table, text_query):
    """Generate and run validation rules from a free-text description.

    Args:
        table: name of the table selected in the UI.
        text_query: the user's natural-language description of the check.

    Returns:
        A ``(rules, results)`` pair of DataFrames. When rule generation
        fails, ``rules`` is a one-row error frame and ``results`` is empty.
    """
    schema = get_table_schema(table)
    df = get_data_df(schema)
    messages = chat_template_user(system_prompt=PANDERA_USER_INPUT_PROMPT,
                                  user_prompt=USER_INPUT, user_description=text_query,
                                  df=df)
    print(messages)
    tests = run_llm(messages)
    print(f'Generated Tests from user input: {tests}')
    # run_llm returns the exception object instead of raising it.
    if isinstance(tests, Exception):
        tests = pd.DataFrame([{"error": f"❌ Unable to generate tests. {tests}"}])
        return tests, pd.DataFrame([])
    tests_df = pd.DataFrame(tests)
    # Rename the leading columns by position. zip() stops at the shorter
    # sequence, so a reply with fewer than three fields (or an empty list)
    # no longer raises IndexError.
    display_names = ('Column', 'Rule Name', 'Rules')
    tests_df = tests_df.rename(columns=dict(zip(tests_df.columns, display_names)))
    pandera_results = validate_pandera(tests, df)
    return tests_df, pandera_results
# Custom CSS styling passed to gr.Blocks(css=...).
# A stray Python print(...) line used to sit at the top of this string; it was
# not valid CSS and corrupted the selector of the first rule, so it is removed.
custom_css = """
.gradio-container {
background-color: #f0f4f8;
}
.logo {
max-width: 200px;
margin: 20px auto;
display: block;
}
.gr-button {
background-color: #4a90e2 !important;
}
.gr-button:hover {
background-color: #3a7bc8 !important;
}
"""
# NOTE(review): the container nesting below follows the order of the
# components; confirm it against the intended layout.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="purple", secondary_hue="indigo"),
    css=custom_css,
) as demo:
    # Header: logo and title banner.
    gr.Image("logo.png", label=None, show_label=False, container=False, height=100)
    gr.Markdown("""
<div style='text-align: center;'>
<strong style='font-size: 36px;'>Dataset Test Workflow</strong>
<br>
<span style='font-size: 20px;'>Implement and Automate Data Validation Processes.</span>
</div>
""")

    with gr.Row():
        # Left column: schema/table pickers and the run button.
        with gr.Column(scale=1):
            schema_dropdown = gr.Dropdown(
                choices=get_schemas(),
                label="Select Schema",
                interactive=True,
            )
            tables_dropdown = gr.Dropdown(
                choices=[],
                label="Available Tables",
                value=None,
            )
            with gr.Row():
                generate_result = gr.Button("Validate Data", variant="primary")

        # Right column: result tabs.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("Description"):
                    with gr.Row():
                        with gr.Column():
                            data_description = gr.DataFrame(
                                label="Data Description", value=[], interactive=False
                            )
                    with gr.Row():
                        with gr.Column():
                            describe_cat = gr.DataFrame(
                                label="Categorical Information", value=[], interactive=False
                            )
                        with gr.Column():
                            describe_num = gr.DataFrame(
                                label="Numerical Information", value=[], interactive=False
                            )

                with gr.Tab("Alerts"):
                    data_alerts = gr.DataFrame(label="Alerts", value=[], interactive=False)

                with gr.Tab("Rules & Validations"):
                    tests_output = gr.DataFrame(
                        label="Validation Rules", value=[], interactive=False
                    )
                    test_result_output = gr.DataFrame(
                        label="Validation Result", value=[], interactive=False
                    )

                with gr.Tab("Data"):
                    result_output = gr.DataFrame(
                        label="Dataframe (10 Rows)", value=[], interactive=False
                    )

                with gr.Tab('Text to Validation'):
                    with gr.Row():
                        query_input = gr.Textbox(
                            lines=5,
                            label="Text Query",
                            placeholder="Enter Text Query to Generate Validation e.g. Validate that the incident_zip column contains valid 5-digit ZIP codes.",
                        )
                    with gr.Row():
                        # Empty column sitting before the button's column.
                        with gr.Column():
                            pass
                        with gr.Column(scale=1, min_width=50):
                            user_generate_result = gr.Button("Validate Data", variant="primary")
                    with gr.Row():
                        with gr.Column():
                            query_tests = gr.DataFrame(
                                label="Validation Rules", value=[], interactive=False
                            )
                        with gr.Column():
                            query_result = gr.DataFrame(
                                label="Validation Result", value=[], interactive=False
                            )

    # Event wiring.
    schema_dropdown.change(
        update_table_names,
        inputs=schema_dropdown,
        outputs=tables_dropdown,
    )
    generate_result.click(
        main,
        inputs=[tables_dropdown],
        outputs=[
            result_output,
            data_description,
            data_alerts,
            describe_cat,
            describe_num,
            tests_output,
            test_result_output,
        ],
    )
    user_generate_result.click(
        user_results,
        inputs=[tables_dropdown, query_input],
        outputs=[query_tests, query_result],
    )

if __name__ == "__main__":
    demo.launch(debug=True)