Spaces:
Runtime error
Runtime error
| """ | |
| Olist Text-to-SQL Gradio Application | |
| Gradio interface for the fine-tuned Mistral-7B model. | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| from model_loader import FineTunedModelLoader | |
| from database import DatabaseHandler | |
| import os | |
| from dotenv import load_dotenv | |
# Load environment variables (e.g. from a local .env file) before anything
# below reads them with os.getenv.
load_dotenv()

# Module-level slots for lazy loading: all three start empty and are filled
# in by initialize_components() the first time a query is submitted.
db_handler = model_loader = db_schema = None
def initialize_components():
    """Initialize model and database on first use (lazy loading).

    The first call builds a ``DatabaseHandler`` (path taken from the
    ``DATABASE_PATH`` environment variable, default ``olist.sqlite``), a
    ``FineTunedModelLoader`` (adapter taken from ``ADAPTER_PATH``, default
    ``mhdakmal80/Olist-SQL-Agent-Final``) and reads the database schema.
    Later calls return the cached objects.

    The module globals are only updated once *every* step has succeeded.
    If any step raises, nothing is cached, so the next call retries the
    whole initialization instead of returning a half-built state (for
    example a loaded model together with ``db_schema = None``).

    Returns:
        Tuple of (db_handler, model_loader, db_schema)

    Raises:
        Whatever ``DatabaseHandler``, ``FineTunedModelLoader`` or
        ``get_schema`` raise; the exception is propagated unchanged.
    """
    global db_handler, model_loader, db_schema

    if model_loader is None:
        print(" Initializing model and database...")

        db_path = os.getenv("DATABASE_PATH", "olist.sqlite")
        adapter_path = os.getenv("ADAPTER_PATH", "mhdakmal80/Olist-SQL-Agent-Final")

        # Build into locals first: assigning the globals one by one would
        # leave model_loader set (and the guard above permanently false)
        # if a later step such as get_schema() failed.
        handler = DatabaseHandler(db_path)
        loader = FineTunedModelLoader(adapter_path=adapter_path)
        schema = handler.get_schema()

        # Everything succeeded; publish the fully initialized state.
        db_handler, model_loader, db_schema = handler, loader, schema
        print(" Model and database loaded!")

    return db_handler, model_loader, db_schema
# Example questions shown in the UI. Each entry is a one-element list because
# the examples are bound to a single input component (the question textbox).
EXAMPLES = [
    [question]
    for question in (
        "How many orders are there?",
        "What are the top 5 best-selling products?",
        "Show total revenue by customer state",
        "Which sellers have the highest ratings?",
        "List all orders from São Paulo",
        "What is the average delivery time?",
        "Count customers by state",
        "Show payment types and their usage",
    )
]
def generate_and_execute(question):
    """
    Turn a natural-language question into SQL and run it.

    Args:
        question: Natural language question

    Returns:
        Tuple of (sql_query, results_df, status_message). The SQL is an
        empty string when nothing was generated, and the results are None
        whenever generation or execution did not succeed.
    """
    # Guard clause: missing or whitespace-only input never reaches the model.
    if not (question and question.strip()):
        return "", None, " Please enter a question"

    # Components are created on the first real query (lazy loading).
    handler, loader, schema = initialize_components()

    # Step 1: question -> SQL
    generation = loader.generate_sql(question, schema)
    if not generation['success']:
        return "", None, f" SQL Generation Failed: {generation['error']}"
    sql_query = generation['sql']

    # Step 2: SQL -> rows
    execution = handler.execute_query(sql_query)
    if not execution['success']:
        return sql_query, None, f" Query Execution Failed: {execution['error']}"

    # Step 3: assemble the status text; an optional warning goes on its own line.
    results = execution['data']
    status_lines = [f" Success! Retrieved {execution['row_count']} rows"]
    if execution.get('warning'):
        status_lines.append(f" {execution['warning']}")

    return sql_query, results, "\n".join(status_lines)
# Create Gradio interface
# NOTE(review): the indentation of this block was reconstructed; confirm the
# nesting of the Row/Column/Accordion containers against the running app.
with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:
    # Page header: title plus a short description of the model behind the app.
    gr.Markdown("""
# 🤖 Olist Text-to-SQL Agent
Convert natural language questions into SQL queries using a **fine-tuned Mistral-7B model**.
**Model**: Mistral-7B-Instruct-v0.2 fine-tuned with QLoRA on Olist e-commerce dataset
**Note**: Running on CPU - queries may take 30-60 seconds. For faster performance, the model supports GPU deployment.
""")
    # Input area: a wider column (scale=2) for the question and buttons next
    # to a narrower column (scale=1) holding the examples hint.
    with gr.Row():
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Ask your question",
                placeholder="e.g., What are the top 10 customers by total spending?",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button(" Generate SQL & Execute", variant="primary")
                # Only question_input is registered with the clear button;
                # the output components are not in its list.
                clear_btn = gr.ClearButton([question_input])
        with gr.Column(scale=1):
            gr.Markdown("""
### 💡 Example Questions
Click any example to try it!
""")
    # Output area: generated SQL, then the status message, then the result table.
    with gr.Row():
        sql_output = gr.Code(
            label="Generated SQL Query",
            language="sql",
            lines=5
        )
    with gr.Row():
        status_output = gr.Textbox(
            label="Status",
            lines=2
        )
    with gr.Row():
        results_output = gr.Dataframe(
            label="Query Results",
            wrap=True
        )
    # Examples section
    # The EXAMPLES rows are bound to the question textbox only.
    gr.Examples(
        examples=EXAMPLES,
        inputs=question_input,
        label="Try these examples:"
    )
    # Info section
    # Both accordions start collapsed (open=False) and hold static text.
    with gr.Accordion("ℹ About this app", open=False):
        gr.Markdown("""
### Model Details
- **Base Model**: mistralai/Mistral-7B-Instruct-v0.2
- **Fine-Tuned Model**: [mhdakmal80/Olist-SQL-Agent-Final](https://huggingface.co/mhdakmal80/Olist-SQL-Agent-Final)
- **Training Method**: QLoRA (4-bit quantization)
- **Training Data**: 1000+ synthetic question-SQL pairs
- **Accuracy**: 90% on test set
### Database
- **Dataset**: Olist E-commerce (Brazilian marketplace)
- **Tables**: 9 tables with 100K+ orders
- **Columns**: Customer info, orders, products, payments, reviews, sellers
### Tech Stack
- PyTorch, Transformers, PEFT, BitsAndBytes
- Gradio for UI
- SQLite for database
""")
    # NOTE(review): this text is static; nothing in this block updates it
    # with the real schema after the first query.
    with gr.Accordion("Database Schema", open=False):
        gr.Markdown("""
The database schema will be loaded when you submit your first query.
**Tables**: orders, customers, products, sellers, payments, reviews, etc.
""")
    # Event handlers
    # The button click and the textbox submit event run the same function.
    # The outputs list is ordered to match the tuple it returns:
    # (sql_query, results, status).
    submit_btn.click(
        fn=generate_and_execute,
        inputs=question_input,
        outputs=[sql_output, results_output, status_output]
    )
    question_input.submit(
        fn=generate_and_execute,
        inputs=question_input,
        outputs=[sql_output, results_output, status_output]
    )
# Launch the app
# Guarded so the server only starts when this file is executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    demo.launch()