Spaces:
Runtime error
Runtime error
import gradio as gr
import os
import sys
import pandas as pd
import sqlite3
from pathlib import Path
import matplotlib
import matplotlib.pyplot as plt
import re

# Charts are only ever saved to PNG files, from Gradio's worker threads, so use
# the non-interactive Agg backend rather than whatever GUI backend matplotlib
# might auto-select when a display is available.
matplotlib.use("Agg")

# For Hugging Face Spaces, the project root is the directory containing this file.
PROJECT_ROOT = Path(__file__).parent.resolve()
# Put the project root FIRST on sys.path. The local package below is named
# `code`, which is also the name of a standard-library module; appending would
# let the stdlib module win whenever the script directory is not already
# sys.path[0]. (If stdlib `code` was already imported, this cannot help.)
sys.path.insert(0, str(PROJECT_ROOT))

# Import model loading and utility functions (project-local package).
from code.train_sqlgen_t5_local import load_model as load_sql_model, generate_sql, get_schema_from_csv
from code.train_intent_classifier_local import load_model as load_intent_model, classify_intent

# Path to the built-in data file in the data folder.
DATA_FILE = str(PROJECT_ROOT / "data" / "testing_sql_data.csv")
# Verify the data file exists BEFORE loading the models, so a missing file
# fails immediately instead of after an expensive model load.
if not os.path.exists(DATA_FILE):
    raise FileNotFoundError(f"Data file not found at {DATA_FILE}. Please ensure testing_sql_data.csv exists in the data folder.")

# Load models.
sql_model, sql_tokenizer, device = load_sql_model()
# NOTE(review): this rebinds `device`, so the intent loader's device is the one
# later passed to BOTH models; presumably both loaders choose the same device —
# confirm.
intent_model, intent_tokenizer, device, label_mapping = load_intent_model()
# Column names the SQL model emits that do not exist in testing_sql_data.csv,
# mapped (and applied in this order) onto columns that do.
_COLUMN_FIXES = (
    ("product_price", "total_price"),
    ("store_name", "store_id"),
    ("sales_method", "date"),
    ("sales", "total_price"),
)


def _normalize_sql(sql_query):
    """Point model-generated SQL at the in-memory ``data`` table and real columns.

    Args:
        sql_query: SQL text produced by ``generate_sql``.

    Returns:
        The SQL with every FROM/JOIN table reference replaced by ``data`` and
        the column names in ``_COLUMN_FIXES`` rewritten.
    """
    # Any table after FROM/JOIN -- bare, "double-quoted" or 'single-quoted' --
    # becomes ``data``, the only table _run_sql creates. The leading \b keeps
    # identifiers that merely end in "from"/"join" untouched.
    sql_query = re.sub(r'\b(FROM|JOIN)\s+\w+', r'\1 data', sql_query, flags=re.IGNORECASE)
    sql_query = re.sub(r'\b(FROM|JOIN)\s+"[^"]+"', r'\1 data', sql_query, flags=re.IGNORECASE)
    # The body class must exclude ' (not "): otherwise the greedy match runs on
    # to the LAST quote in the statement and swallows e.g. WHERE x = 'y'.
    sql_query = re.sub(r"\b(FROM|JOIN)\s+'[^']+'", r'\1 data', sql_query, flags=re.IGNORECASE)
    # Whole-word column renames so unrelated identifiers that merely contain
    # one of these names are left alone.
    for wrong, right in _COLUMN_FIXES:
        sql_query = re.sub(rf'\b{re.escape(wrong)}\b', right, sql_query)
    return sql_query


def _run_sql(sql_query):
    """Load DATA_FILE into an in-memory SQLite table ``data`` and run *sql_query*.

    Returns:
        The query result as a DataFrame.
    """
    df = pd.read_csv(DATA_FILE)
    conn = sqlite3.connect(":memory:")
    try:
        df.to_sql("data", conn, index=False, if_exists="replace")
        return pd.read_sql_query(sql_query, conn)
    finally:
        # Close even when the generated SQL is invalid and the query raises.
        conn.close()


def _render_chart(result_df, intent, question, chart_type):
    """Plot column 0 (x / labels) against column 1 and save it as a PNG.

    Args:
        result_df: Query result with at least two columns.
        intent: Classifier intent; only consulted when *chart_type* is "auto".
        question: Used as the chart title.
        chart_type: "auto", "bar", "line" or "pie". "auto" picks a line chart
            for the "trend" intent and a bar chart otherwise; an unrecognised
            value also falls back to a bar chart.

    Returns:
        Path of the saved PNG (``PROJECT_ROOT/chart.png``).
    """
    if chart_type == "auto":
        chart_type = "line" if intent == "trend" else "bar"
    x_col, y_col = result_df.columns[0], result_df.columns[1]
    # Create the figure explicitly and hand its axes to pandas; without `ax`,
    # DataFrame.plot opens a separate new figure, so a size set on a
    # pre-created figure would be ignored and that figure never closed.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if chart_type == "line":
            result_df.plot(kind="line", x=x_col, y=y_col, marker='o', ax=ax)
        elif chart_type == "pie":
            result_df.plot(kind="pie", y=y_col, labels=result_df[x_col], ax=ax)
        else:
            result_df.plot(kind="bar", x=x_col, y=y_col, ax=ax)
        ax.set_title(question)
        fig.tight_layout()
        # NOTE(review): one shared file; concurrent requests would overwrite
        # each other's chart if this handler ever runs in parallel.
        chart_path = os.path.join(PROJECT_ROOT, "chart.png")
        fig.savefig(chart_path)
        return chart_path
    finally:
        # Release the figure on every path, including when plotting raised.
        plt.close(fig)


def process_query(question, chart_type="auto"):
    """Answer a natural-language question against the built-in CSV data.

    Generates SQL with the SQL model, classifies the question's intent, runs the
    SQL over DATA_FILE loaded as table ``data``, then renders a chart and text
    insights from the first two result columns.

    Args:
        question: The user's natural-language question.
        chart_type: "auto", "bar", "line" or "pie".

    Returns:
        A 5-tuple ``(result_df, intent, sql_query, chart_path, insights)``.
        ``chart_path`` is None when the result has fewer than two columns, is
        empty, or cannot be plotted. On any other error the tuple is
        ``(None, "Error", <message>, None, "Error: <message>")`` -- the message
        occupies the SQL slot.
    """
    try:
        schema = get_schema_from_csv(DATA_FILE)
        sql_query = _normalize_sql(generate_sql(question, schema, sql_model, sql_tokenizer, device))
        intent = classify_intent(question, intent_model, intent_tokenizer, device, label_mapping)
        result_df = _run_sql(sql_query)

        if result_df.empty or len(result_df.columns) < 2:
            insights = "No results or not enough columns to display chart/insights."
            return result_df, intent, sql_query, None, insights

        try:
            chart_path = _render_chart(result_df, intent, question, chart_type)
        except (TypeError, ValueError, KeyError):
            # The SQL already succeeded; a result pandas cannot plot (e.g. a
            # non-numeric value column, or negative values in a pie chart)
            # should not throw away the table and insights.
            chart_path = None

        insights = generate_insights(result_df, intent, question)
        return result_df, intent, sql_query, chart_path, insights
    except Exception as e:
        # Top-level UI boundary: surface the failure in the output widgets
        # instead of crashing the Gradio handler.
        return None, "Error", str(e), None, f"Error: {str(e)}"
def generate_insights(result_df, intent, question):
    """Build a short bulleted text summary of a query result.

    Only the first two columns are used: column 0 as labels and column 1 as
    values. Value cells that cannot be parsed as numbers are ignored.

    Args:
        result_df: Query result DataFrame, or None.
        intent: Classifier intent. "summary" adds the column total,
            "comparison" the max-vs-min percentage, and "trend" the
            first-to-last percentage change, each only when it is computable.
        question: The user's question (currently unused; kept for callers).

    Returns:
        Bullet lines ("• ...") joined by newlines, or
        "No data available for insights." when there is no usable data.
    """
    if result_df is None or result_df.empty or len(result_df.columns) < 2:
        return "No data available for insights."

    value_col = result_df.columns[1]
    # Positional access so duplicate column names cannot yield DataFrames, and
    # numeric coercion so text/None cells become NaN instead of raising.
    labels = result_df.iloc[:, 0].reset_index(drop=True)
    values = pd.to_numeric(result_df.iloc[:, 1], errors="coerce").reset_index(drop=True)

    insights = []
    if intent == "summary":
        if values.notna().any():
            insights.append(f"Total {value_col}: {values.sum():,.2f}")
    elif intent == "comparison":
        valid = values.dropna()
        if len(valid) >= 2:
            # Locate the real max/min rather than trusting row order: the
            # generated SQL is not guaranteed to sort descending.
            hi_pos, lo_pos = valid.idxmax(), valid.idxmin()
            hi_val, lo_val = valid.loc[hi_pos], valid.loc[lo_pos]
            # A percentage over a zero/negative base is meaningless (or inf).
            if hi_pos != lo_pos and lo_val > 0:
                diff = (hi_val / lo_val - 1) * 100
                insights.append(f"{labels.loc[hi_pos]} is {diff:.1f}% higher than {labels.loc[lo_pos]}")
    elif intent == "trend":
        if len(values) >= 2:
            first, last = values.iloc[0], values.iloc[-1]
            if pd.notna(first) and pd.notna(last) and first != 0:
                # abs() keeps the sign of the change correct for a negative base.
                change = (last - first) / abs(first) * 100
                insights.append(f"Overall change: {change:+.1f}%")

    insights.append(f"Analysis covers {len(result_df)} records")
    if "category" in result_df.columns:
        insights.append(f"Number of categories: {result_df['category'].nunique()}")
    return "\n".join(f"• {insight}" for insight in insights)
# Clickable FAQs (6 only): example questions offered as radio options; picking
# one copies its text into the question textbox.
faqs = [
    "What are the top 5 products by quantity sold?",
    "What is the total sales amount for each category?",
    "Which store had the highest total sales?",
    "What are the most popular payment methods?",
    "What is the sales trend over time?",
    "What is the average transaction value?",
]


def fill_question(faq):
    """Return a Gradio update that sets the question textbox to *faq*.

    Args:
        faq: The FAQ text selected in the radio group.

    Returns:
        A ``gr.update`` carrying the selected text as the new value.
    """
    textbox_update = gr.update(value=faq)
    return textbox_update
# Gradio UI. Left column: question box, FAQ shortcuts, chart-type selector and
# the Generate button. Right column: collapsible SQL/intent details, result
# table, chart image and insights text.
with gr.Blocks(title="RetailGenie - Natural Language to SQL") as demo:
    gr.Markdown("""
    # RetailGenie - Natural Language to SQL
    Ask questions in natural language to generate SQL queries and visualizations. Using retail dataset with product sales information.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Enter your question",
                placeholder="What is the total sales amount for each product category?",
            )
            faq_radio = gr.Radio(faqs, label="FAQs (click to autofill)", interactive=True)
            chart_type = gr.Radio(["auto", "bar", "line", "pie"], label="Chart Type", value="auto")
            submit_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=2):
            with gr.Accordion("SQL and Intent Details", open=False):
                intent_output = gr.Textbox(label="Predicted Intent")
                sql_output = gr.Textbox(label="Generated SQL", lines=3)
            results_df = gr.DataFrame(label="Query Results")
            chart_output = gr.Image(label="Chart")
            insights_output = gr.Textbox(label="Insights", lines=5)

    # Event wiring, registered in the same order as before: FAQ autofill
    # first, then the main query handler.
    faq_radio.change(fn=fill_question, inputs=faq_radio, outputs=question)
    submit_btn.click(
        fn=process_query,
        inputs=[question, chart_type],
        outputs=[results_df, intent_output, sql_output, chart_output, insights_output],
    )

if __name__ == "__main__":
    demo.launch()