mhdakmal80 commited on
Commit
d60cb1f
·
verified ·
1 Parent(s): 64a5793

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. README.md +61 -13
  3. app_gradio.py +178 -0
  4. database.py +205 -0
  5. model_loader.py +199 -0
  6. olist.sqlite +3 -0
  7. requirements.txt +12 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ olist.sqlite filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,61 @@
1
- ---
2
- title: Olist Text2sql
3
- emoji: 🐢
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 6.0.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Olist Text-to-SQL Agent
3
+ emoji: 🤖
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app_gradio.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # 🤖 Olist Text-to-SQL Agent
14
+
15
+ Convert natural language questions into SQL queries using a **fine-tuned Mistral-7B model**.
16
+
17
+ ## 🎯 Features
18
+
19
+ - **Fine-Tuned Model**: Mistral-7B-Instruct-v0.2 fine-tuned with QLoRA on Olist e-commerce dataset
20
+ - **Natural Language to SQL**: Ask questions in plain English, get executable SQL queries
21
+ - **Real Database**: Query against actual Olist e-commerce data (100K+ orders)
22
+ - **Interactive UI**: Built with Gradio for easy interaction
23
+
24
+ ## 🚀 How to Use
25
+
26
+ 1. Type your question in natural language
27
+ 2. Click "Generate SQL & Execute"
28
+ 3. View the generated SQL query and results
29
+
30
+ ## 💡 Example Questions
31
+
32
+ - "How many orders are there?"
33
+ - "What are the top 5 best-selling products?"
34
+ - "Show total revenue by customer state"
35
+ - "Which sellers have the highest ratings?"
36
+ - "List all orders from São Paulo"
37
+
38
+ ## 🛠️ Tech Stack
39
+
40
+ - **Model**: Mistral-7B-Instruct-v0.2 (fine-tuned with QLoRA)
41
+ - **Frontend**: Gradio
42
+ - **Database**: SQLite (Olist e-commerce dataset)
43
+ - **ML Libraries**: PyTorch, Transformers, PEFT, BitsAndBytes
44
+
45
+ ## 📊 Model Details
46
+
47
+ - **Base Model**: mistralai/Mistral-7B-Instruct-v0.2
48
+ - **Fine-Tuned Model**: [mhdakmal80/Olist-SQL-Agent-Final](https://huggingface.co/mhdakmal80/Olist-SQL-Agent-Final)
49
+ - **Training Method**: QLoRA (4-bit quantization)
50
+ - **Training Data**: 1000+ synthetic question-SQL pairs
51
+ - **Accuracy**: 90% on test set
52
+
53
+ ## 🎓 About
54
+
55
+ This project demonstrates:
56
+ - Fine-tuning large language models (7B parameters)
57
+ - Parameter-efficient fine-tuning with QLoRA
58
+ - Production deployment of ML models
59
+ - Full-stack application development
60
+
61
+ Built by [mhdakmal80](https://huggingface.co/mhdakmal80)
app_gradio.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Olist Text-to-SQL Gradio Application
3
+ Gradio interface for the fine-tuned Mistral-7B model.
4
+ """
5
+
6
+ import gradio as gr
7
+ import pandas as pd
8
+ from model_loader import FineTunedModelLoader
9
+ from database import DatabaseHandler
10
+ import os
11
+ from dotenv import load_dotenv
12
+
13
+ # Load environment variables
14
+ load_dotenv()
15
+
16
+ # Initialize components
17
+ print("🔄 Initializing model and database...")
18
+ db_path = os.getenv("DATABASE_PATH", "olist.sqlite")
19
+ adapter_path = os.getenv("ADAPTER_PATH", "mhdakmal80/Olist-SQL-Agent-Final")
20
+
21
+ db_handler = DatabaseHandler(db_path)
22
+ model_loader = FineTunedModelLoader(adapter_path=adapter_path)
23
+ db_schema = db_handler.get_schema()
24
+
25
+ print("✅ Model and database loaded!")
26
+
27
+ # Example questions
28
+ EXAMPLES = [
29
+ ["How many orders are there?"],
30
+ ["What are the top 5 best-selling products?"],
31
+ ["Show total revenue by customer state"],
32
+ ["Which sellers have the highest ratings?"],
33
+ ["List all orders from São Paulo"],
34
+ ["What is the average delivery time?"],
35
+ ["Count customers by state"],
36
+ ["Show payment types and their usage"],
37
+ ]
38
+
39
def generate_and_execute(question):
    """Turn a natural-language question into SQL, run it, and report status.

    Args:
        question: Natural language question

    Returns:
        Tuple of (sql_query, results_df, status_message)
    """
    # Guard: reject empty or whitespace-only input up front.
    if not (question and question.strip()):
        return "", None, "⚠️ Please enter a question"

    # Step 1: ask the fine-tuned model for a SQL query.
    generation = model_loader.generate_sql(question, db_schema)
    if not generation['success']:
        return "", None, f"❌ SQL Generation Failed: {generation['error']}"

    query = generation['sql']

    # Step 2: run the generated query against the SQLite database.
    execution = db_handler.execute_query(query)
    if not execution['success']:
        return query, None, f"❌ Query Execution Failed: {execution['error']}"

    # Step 3: assemble the status message, appending any truncation warning.
    frame = execution['data']
    status_lines = [f"✅ Success! Retrieved {execution['row_count']} rows"]
    if execution.get('warning'):
        status_lines.append(f"⚠️ {execution['warning']}")

    return query, frame, "\n".join(status_lines)
75
+
76
# Create Gradio interface.
# Layout: intro markdown, input row (question + examples hint), three output
# rows (SQL, status, results), clickable examples, two info accordions, and
# event wiring for both button click and textbox Enter.
with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:

    # Header / intro text.
    gr.Markdown("""
    # 🤖 Olist Text-to-SQL Agent

    Convert natural language questions into SQL queries using a **fine-tuned Mistral-7B model**.

    **Model**: Mistral-7B-Instruct-v0.2 fine-tuned with QLoRA on Olist e-commerce dataset

    ⚠️ **Note**: Running on CPU - queries may take 30-60 seconds. For faster performance, the model supports GPU deployment.
    """)

    with gr.Row():
        # Left column (wider): question entry plus action buttons.
        with gr.Column(scale=2):
            question_input = gr.Textbox(
                label="Ask your question",
                placeholder="e.g., What are the top 10 customers by total spending?",
                lines=3
            )

            with gr.Row():
                submit_btn = gr.Button("🚀 Generate SQL & Execute", variant="primary")
                clear_btn = gr.ClearButton([question_input])

        # Right column: pointer to the examples block below.
        with gr.Column(scale=1):
            gr.Markdown("""
            ### 💡 Example Questions
            Click any example to try it!
            """)

    # Generated SQL, shown with SQL syntax highlighting.
    with gr.Row():
        sql_output = gr.Code(
            label="Generated SQL Query",
            language="sql",
            lines=5
        )

    # Success / failure message from generate_and_execute.
    with gr.Row():
        status_output = gr.Textbox(
            label="Status",
            lines=2
        )

    # Query results table.
    with gr.Row():
        results_output = gr.Dataframe(
            label="Query Results",
            wrap=True,
            max_height=400
        )

    # Clickable examples that fill question_input (EXAMPLES defined above).
    gr.Examples(
        examples=EXAMPLES,
        inputs=question_input,
        label="Try these examples:"
    )

    # Collapsed-by-default info section.
    with gr.Accordion("ℹ️ About this app", open=False):
        gr.Markdown("""
        ### Model Details
        - **Base Model**: mistralai/Mistral-7B-Instruct-v0.2
        - **Fine-Tuned Model**: [mhdakmal80/Olist-SQL-Agent-Final](https://huggingface.co/mhdakmal80/Olist-SQL-Agent-Final)
        - **Training Method**: QLoRA (4-bit quantization)
        - **Training Data**: 1000+ synthetic question-SQL pairs
        - **Accuracy**: 90% on test set

        ### Database
        - **Dataset**: Olist E-commerce (Brazilian marketplace)
        - **Tables**: 9 tables with 100K+ orders
        - **Columns**: Customer info, orders, products, payments, reviews, sellers

        ### Tech Stack
        - PyTorch, Transformers, PEFT, BitsAndBytes
        - Gradio for UI
        - SQLite for database
        """)

    # Collapsed-by-default schema viewer (db_schema extracted at startup).
    with gr.Accordion("🗄️ Database Schema", open=False):
        gr.Code(
            value=db_schema,
            language="sql",
            label="Database Schema",
            lines=20
        )

    # Event handlers: button click and Enter in the textbox run the same
    # pipeline and write to the same three outputs.
    submit_btn.click(
        fn=generate_and_execute,
        inputs=question_input,
        outputs=[sql_output, results_output, status_output]
    )

    question_input.submit(
        fn=generate_and_execute,
        inputs=question_input,
        outputs=[sql_output, results_output, status_output]
    )

# Launch the app (Spaces executes this file directly).
if __name__ == "__main__":
    demo.launch()
database.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
import sqlite3
from typing import Any, Dict, List, Optional

import pandas as pd
4
+
5
+
6
class DatabaseHandler:
    """Handles all database operations for the Olist database.

    Only read-only statements (SELECT, or WITH for CTEs) are ever executed;
    everything else is rejected by ``_validate_query``.
    """

    # Statements must begin with one of these read-only entry points.
    _ALLOWED_PREFIX = re.compile(r"^\s*(SELECT|WITH)\b", re.IGNORECASE)

    # Keywords that make a query unsafe. Matched on word boundaries so that
    # legitimate identifiers such as "created_at" / "updated_at" are not
    # falsely rejected (a plain substring check blocks them via CREATE/UPDATE).
    _DANGEROUS_KEYWORDS = (
        "DROP", "DELETE", "INSERT", "UPDATE",
        "ALTER", "CREATE", "TRUNCATE", "REPLACE",
    )

    def __init__(self, db_path: str = "olist.sqlite"):
        """
        Initialize database handler.

        Args:
            db_path: Path to SQLite database file

        Raises:
            FileNotFoundError: If the database file is missing or unreadable.
        """
        self.db_path = db_path
        self._verify_database()

    def _verify_database(self):
        """Verify database exists and is accessible.

        ``sqlite3.connect`` silently *creates* a missing file, so an explicit
        existence check is required to actually detect a bad path.
        """
        if not os.path.isfile(self.db_path):
            raise FileNotFoundError(f"Database not found at {self.db_path}")
        try:
            conn = sqlite3.connect(self.db_path)
            conn.close()
        except Exception as e:
            raise FileNotFoundError(
                f"Database not found at {self.db_path}: {str(e)}"
            ) from e

    def get_schema(self) -> str:
        """
        Extract and format database schema.

        Returns:
            Formatted schema string with all tables and columns, or an
            error description string on failure.
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                tables = [row[0] for row in cursor.fetchall()]

                schema_parts = []
                for table_name in tables:
                    # PRAGMA cannot take bound parameters; quote the table
                    # name and escape embedded double quotes instead.
                    quoted = table_name.replace('"', '""')
                    cursor.execute(f'PRAGMA table_info("{quoted}");')
                    columns = cursor.fetchall()

                    schema_parts.append(f"\nTable: {table_name}")
                    schema_parts.append("Columns:")
                    for col in columns:
                        # table_info row: (cid, name, type, notnull, dflt_value, pk)
                        is_pk = " (PRIMARY KEY)" if col[5] else ""
                        schema_parts.append(f"  - {col[1]} ({col[2]}){is_pk}")

                return "\n".join(schema_parts)
            finally:
                # Close even when a PRAGMA/select raises (no connection leak).
                conn.close()
        except Exception as e:
            return f"Error extracting schema: {str(e)}"

    def execute_query(self, sql: str, max_rows: int = 1000) -> Dict[str, Any]:
        """
        Execute SQL query and return results.

        Args:
            sql: SQL query to execute
            max_rows: Maximum number of rows to return

        Returns:
            Dictionary with:
            - success: Boolean indicating success
            - data: Pandas DataFrame with results (None on failure)
            - row_count: Number of rows returned
            - error: Error message if failed
            - warning: Truncation notice (only present on success)
        """
        # Validate query first; never hand an unsafe statement to SQLite.
        if not self._validate_query(sql):
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": "Query validation failed: Only SELECT queries are allowed"
            }

        try:
            conn = sqlite3.connect(self.db_path)
            try:
                df = pd.read_sql_query(sql, conn)
            finally:
                # Close even when the query itself raises.
                conn.close()

            # Cap the result size for the UI.
            warning = None
            if len(df) > max_rows:
                df = df.head(max_rows)
                warning = f"Results limited to {max_rows} rows"

            return {
                "success": True,
                "data": df,
                "row_count": len(df),
                "error": None,
                "warning": warning
            }

        except Exception as e:
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": f"Query execution error: {str(e)}"
            }

    def _validate_query(self, sql: str) -> bool:
        """
        Validate SQL query for safety.

        Accepts only read-only statements (SELECT, or WITH for CTEs) and
        rejects any query containing a dangerous keyword as a whole word.

        Args:
            sql: SQL query to validate

        Returns:
            True if query is safe, False otherwise
        """
        if not self._ALLOWED_PREFIX.match(sql):
            return False

        sql_upper = sql.upper()
        # \b matching avoids false positives on columns like CREATED_AT.
        return not any(
            re.search(rf"\b{keyword}\b", sql_upper)
            for keyword in self._DANGEROUS_KEYWORDS
        )

    def get_table_names(self) -> List[str]:
        """
        Get list of all table names in database.

        Returns:
            List of table names (empty list on error)
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [row[0] for row in cursor.fetchall()]
            finally:
                conn.close()
        except Exception as e:
            print(f"Error getting table names: {e}")
            return []

    def get_table_preview(self, table_name: str, limit: int = 5) -> Optional[pd.DataFrame]:
        """
        Get preview of table data.

        Args:
            table_name: Name of table to preview
            limit: Number of rows to return

        Returns:
            DataFrame with sample data or None if error
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                # Identifiers cannot be bound as parameters; quote the table
                # name and force the limit to an int to keep the SQL safe.
                quoted = table_name.replace('"', '""')
                return pd.read_sql_query(
                    f'SELECT * FROM "{quoted}" LIMIT {int(limit)};', conn
                )
            finally:
                conn.close()
        except Exception as e:
            print(f"Error previewing table {table_name}: {e}")
            return None
+
189
+
190
+ # Test function
191
+ if __name__ == "__main__":
192
+ # Quick test
193
+ db = DatabaseHandler("olist.sqlite")
194
+
195
+ print("=== Database Schema ===")
196
+ print(db.get_schema())
197
+
198
+ print("\n=== Table Names ===")
199
+ print(db.get_table_names())
200
+
201
+ print("\n=== Test Query ===")
202
+ result = db.execute_query("SELECT COUNT(*) as total_orders FROM orders;")
203
+ print(f"Success: {result['success']}")
204
+ if result['success']:
205
+ print(result['data'])
model_loader.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3
+ from peft import PeftModel
4
+ from typing import Dict, Any, Optional
5
+ import re
6
+
7
+
8
class FineTunedModelLoader:
    """Loads and manages the fine-tuned Mistral-7B model.

    The base model is (optionally) loaded in 4-bit NF4 via bitsandbytes and
    the QLoRA adapter is applied on top with PEFT.
    """

    def __init__(self,
                 base_model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
                 adapter_path: str = "mhdakmal80/Olist-SQL-Agent-Final",
                 use_4bit: bool = True):
        """
        Initialize the fine-tuned model.

        Args:
            base_model_name: HuggingFace model name
            adapter_path: Path to LoRA adapter weights
            use_4bit: Whether to use 4-bit quantization
        """
        self.base_model_name = base_model_name
        self.adapter_path = adapter_path
        self.use_4bit = use_4bit

        print(" Loading fine-tuned model...")
        self.model, self.tokenizer = self._load_model()
        print(" Model loaded successfully!")

    def _load_model(self):
        """Load the base model and LoRA adapters.

        Returns:
            Tuple of (model, tokenizer).
        """
        # Configure 4-bit (NF4) quantization if enabled.
        if self.use_4bit:
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
            )
        else:
            bnb_config = None

        # Load base model. bnb_config is already None when 4-bit is off, so
        # it can be passed straight through without a second conditional.
        print(f" Loading base model: {self.base_model_name}")
        base_model = AutoModelForCausalLM.from_pretrained(
            self.base_model_name,
            quantization_config=bnb_config,
            torch_dtype=None if self.use_4bit else torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

        # Load tokenizer.
        print(f" Loading tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            self.base_model_name,
            trust_remote_code=True
        )
        # Mistral's tokenizer has no pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"

        # Apply the LoRA adapter on top of the (quantized) base model.
        print(f" Loading LoRA adapter from: {self.adapter_path}")
        model = PeftModel.from_pretrained(base_model, self.adapter_path)

        return model, tokenizer

    def generate_sql(self, question: str, schema: str) -> Dict[str, Any]:
        """
        Generate SQL query from natural language question.

        Args:
            question: User's natural language question
            schema: Database schema as string

        Returns:
            Dictionary with 'sql', 'success', and 'error' keys
        """
        # Format prompt (kept byte-identical to the fine-tuning format).
        prompt = f"""[INST]You are a SQL expert. Generate a valid SQLite query using ONLY the columns and tables listed below.
Don't ever use columns that is not in the schema (this need to be followed strictly).Always try to come up the
solution based on provided schema only.

### Available Tables and Columns:

{schema}

### IMPORTANT:
- Use ONLY the column names listed above
- Do NOT invent column names
- Do NOT use columns that don't exist

### Question:
{question}

### Generate SQL using only the columns listed above:
[/INST]```sql
"""

        try:
            # Tokenize. max_length must cover the full schema + question:
            # the old limit of 512 truncated the END of the prompt, dropping
            # the question and the [/INST] marker whenever the schema was
            # long. Mistral supports long contexts, so allow 4096 tokens.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                truncation=True,
                max_length=4096
            )
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

            # Greedy decoding (do_sample=False) keeps the SQL deterministic.
            # The previous temperature=0.1 was ignored with sampling off and
            # triggers a transformers warning, so it is no longer passed.
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=False,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

            # Decode the full sequence (prompt + completion).
            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract SQL from response.
            sql_query = self._extract_sql(generated_text, prompt)

            return {
                "sql": sql_query,
                "success": True,
                "error": None
            }

        except Exception as e:
            return {
                "sql": "",
                "success": False,
                "error": f"Model Error: {str(e)}"
            }

    def _extract_sql(self, generated_text: str, prompt: str) -> str:
        """
        Extract SQL query from generated text.

        Args:
            generated_text: Full generated text from model
            prompt: Original prompt (to remove from output)

        Returns:
            Cleaned SQL query, whitespace-normalized and semicolon-terminated
        """
        # Remove the prompt from the generated text.
        sql = generated_text.replace(prompt, "").strip()

        # Try progressively looser patterns: explicit marker, fenced block,
        # then any bare SELECT statement.
        patterns = [
            r"### SQL Query:\s*(.+?)(?:###|$)",
            r"```sql\s*(.+?)\s*```",
            r"SELECT\s+.+",
        ]

        for pattern in patterns:
            match = re.search(pattern, sql, re.IGNORECASE | re.DOTALL)
            if match:
                # Patterns without a capture group return the whole match.
                sql = match.group(1) if match.lastindex else match.group(0)
                break

        # Clean up: strip any remaining fences and collapse whitespace.
        sql = sql.replace("```sql", "").replace("```", "")
        sql = " ".join(sql.split())
        sql = sql.strip()

        # Ensure it ends with semicolon.
        if not sql.endswith(";"):
            sql += ";"

        return sql
179
+
180
+
181
# Test function
if __name__ == "__main__":
    # Quick smoke test: loads the full 7B model (slow, especially on CPU)
    # and generates SQL against a toy one-table schema.
    model_loader = FineTunedModelLoader()

    test_schema = """
    Table: orders
    Columns: order_id, customer_id, order_status, order_purchase_timestamp
    """

    result = model_loader.generate_sql(
        "How many orders are there?",
        test_schema
    )

    print(f"\nSuccess: {result['success']}")
    print(f"SQL: {result['sql']}")
    if result['error']:
        print(f"Error: {result['error']}")
olist.sqlite ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49446afd935721ee12fc95316fbee9666a3e1bd4872dfa194fe4625d6762a81a
3
+ size 112701440
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ gradio>=4.0.0
3
+ python-dotenv==1.0.0
4
+ pandas==2.1.4
5
+
6
+ # ML/AI dependencies for fine-tuned model
7
+ torch>=2.0.0
8
+ transformers>=4.35.0
9
+ accelerate>=0.24.0
10
+ peft>=0.6.0
11
+ bitsandbytes>=0.41.0
12
+ sentencepiece>=0.1.99