mhdakmal80 committed on
Commit
6a096d0
·
verified ·
1 Parent(s): d60cb1f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app_gradio.py +37 -24
  2. model_loader.py +15 -4
app_gradio.py CHANGED
@@ -13,16 +13,27 @@ from dotenv import load_dotenv
13
  # Load environment variables
14
  load_dotenv()
15
 
16
- # Initialize components
17
- print("🔄 Initializing model and database...")
18
- db_path = os.getenv("DATABASE_PATH", "olist.sqlite")
19
- adapter_path = os.getenv("ADAPTER_PATH", "mhdakmal80/Olist-SQL-Agent-Final")
20
 
21
- db_handler = DatabaseHandler(db_path)
22
- model_loader = FineTunedModelLoader(adapter_path=adapter_path)
23
- db_schema = db_handler.get_schema()
24
-
25
- print("✅ Model and database loaded!")
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Example questions
28
  EXAMPLES = [
@@ -47,13 +58,16 @@ def generate_and_execute(question):
47
  Tuple of (sql_query, results_df, status_message)
48
  """
49
  if not question or not question.strip():
50
- return "", None, "⚠️ Please enter a question"
 
 
 
51
 
52
  # Generate SQL
53
  result = model_loader.generate_sql(question, db_schema)
54
 
55
  if not result['success']:
56
- return "", None, f" SQL Generation Failed: {result['error']}"
57
 
58
  sql_query = result['sql']
59
 
@@ -61,15 +75,15 @@ def generate_and_execute(question):
61
  exec_result = db_handler.execute_query(sql_query)
62
 
63
  if not exec_result['success']:
64
- return sql_query, None, f" Query Execution Failed: {exec_result['error']}"
65
 
66
  # Format results
67
  df = exec_result['data']
68
  row_count = exec_result['row_count']
69
 
70
- status = f" Success! Retrieved {row_count} rows"
71
  if exec_result.get('warning'):
72
- status += f"\n⚠️ {exec_result['warning']}"
73
 
74
  return sql_query, df, status
75
 
@@ -83,7 +97,7 @@ with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:
83
 
84
  **Model**: Mistral-7B-Instruct-v0.2 fine-tuned with QLoRA on Olist e-commerce dataset
85
 
86
- ⚠️ **Note**: Running on CPU - queries may take 30-60 seconds. For faster performance, the model supports GPU deployment.
87
  """)
88
 
89
  with gr.Row():
@@ -95,7 +109,7 @@ with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:
95
  )
96
 
97
  with gr.Row():
98
- submit_btn = gr.Button("🚀 Generate SQL & Execute", variant="primary")
99
  clear_btn = gr.ClearButton([question_input])
100
 
101
  with gr.Column(scale=1):
@@ -132,7 +146,7 @@ with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:
132
  )
133
 
134
  # Info section
135
- with gr.Accordion("ℹ️ About this app", open=False):
136
  gr.Markdown("""
137
  ### Model Details
138
  - **Base Model**: mistralai/Mistral-7B-Instruct-v0.2
@@ -152,13 +166,12 @@ with gr.Blocks(title="Olist Text-to-SQL Agent", theme=gr.themes.Soft()) as demo:
152
  - SQLite for database
153
  """)
154
 
155
- with gr.Accordion("🗄️ Database Schema", open=False):
156
- gr.Code(
157
- value=db_schema,
158
- language="sql",
159
- label="Database Schema",
160
- lines=20
161
- )
162
 
163
  # Event handlers
164
  submit_btn.click(
 
13
  # Load environment variables
14
  load_dotenv()
15
 
16
+ # Global variables for lazy loading
17
+ db_handler = None
18
+ model_loader = None
19
+ db_schema = None
20
 
21
+ def initialize_components():
22
+ """Initialize model and database on first use (lazy loading)."""
23
+ global db_handler, model_loader, db_schema
24
+
25
+ if model_loader is None:
26
+ print(" Initializing model and database...")
27
+ db_path = os.getenv("DATABASE_PATH", "olist.sqlite")
28
+ adapter_path = os.getenv("ADAPTER_PATH", "mhdakmal80/Olist-SQL-Agent-Final")
29
+
30
+ db_handler = DatabaseHandler(db_path)
31
+ model_loader = FineTunedModelLoader(adapter_path=adapter_path)
32
+ db_schema = db_handler.get_schema()
33
+
34
+ print(" Model and database loaded!")
35
+
36
+ return db_handler, model_loader, db_schema
37
 
38
  # Example questions
39
  EXAMPLES = [
 
58
  Tuple of (sql_query, results_df, status_message)
59
  """
60
  if not question or not question.strip():
61
+ return "", None, " Please enter a question"
62
+
63
+ # Initialize components on first use (lazy loading)
64
+ db_handler, model_loader, db_schema = initialize_components()
65
 
66
  # Generate SQL
67
  result = model_loader.generate_sql(question, db_schema)
68
 
69
  if not result['success']:
70
+ return "", None, f" SQL Generation Failed: {result['error']}"
71
 
72
  sql_query = result['sql']
73
 
 
75
  exec_result = db_handler.execute_query(sql_query)
76
 
77
  if not exec_result['success']:
78
+ return sql_query, None, f" Query Execution Failed: {exec_result['error']}"
79
 
80
  # Format results
81
  df = exec_result['data']
82
  row_count = exec_result['row_count']
83
 
84
+ status = f" Success! Retrieved {row_count} rows"
85
  if exec_result.get('warning'):
86
+ status += f"\n {exec_result['warning']}"
87
 
88
  return sql_query, df, status
89
 
 
97
 
98
  **Model**: Mistral-7B-Instruct-v0.2 fine-tuned with QLoRA on Olist e-commerce dataset
99
 
100
+ **Note**: Running on CPU - queries may take 30-60 seconds. For faster performance, the model supports GPU deployment.
101
  """)
102
 
103
  with gr.Row():
 
109
  )
110
 
111
  with gr.Row():
112
+ submit_btn = gr.Button(" Generate SQL & Execute", variant="primary")
113
  clear_btn = gr.ClearButton([question_input])
114
 
115
  with gr.Column(scale=1):
 
146
  )
147
 
148
  # Info section
149
+ with gr.Accordion(" About this app", open=False):
150
  gr.Markdown("""
151
  ### Model Details
152
  - **Base Model**: mistralai/Mistral-7B-Instruct-v0.2
 
166
  - SQLite for database
167
  """)
168
 
169
+ with gr.Accordion("Database Schema", open=False):
170
+ gr.Markdown("""
171
+ The database schema will be loaded when you submit your first query.
172
+
173
+ **Tables**: orders, customers, products, sellers, payments, reviews, etc.
174
+ """)
 
175
 
176
  # Event handlers
177
  submit_btn.click(
model_loader.py CHANGED
@@ -31,25 +31,36 @@ class FineTunedModelLoader:
31
  def _load_model(self):
32
  """Load the base model and LoRA adapters."""
33
 
34
- # Configure 4-bit quantization if enabled
35
- if self.use_4bit:
 
 
 
 
 
 
 
 
36
  bnb_config = BitsAndBytesConfig(
37
  load_in_4bit=True,
38
  bnb_4bit_quant_type="nf4",
39
  bnb_4bit_compute_dtype=torch.bfloat16,
40
  bnb_4bit_use_double_quant=False,
41
  )
 
42
  else:
43
  bnb_config = None
 
44
 
45
  # Load base model
46
  print(f" Loading base model: {self.base_model_name}")
47
  base_model = AutoModelForCausalLM.from_pretrained(
48
  self.base_model_name,
49
- quantization_config=bnb_config if self.use_4bit else None,
50
- torch_dtype=torch.bfloat16 if not self.use_4bit else None,
51
  device_map="auto",
52
  trust_remote_code=True,
 
53
  )
54
 
55
  # Load tokenizer
 
31
  def _load_model(self):
32
  """Load the base model and LoRA adapters."""
33
 
34
+ # Check if GPU is available
35
+ has_gpu = torch.cuda.is_available()
36
+
37
+ if not has_gpu:
38
+ print(" ⚠️ No GPU detected - loading model on CPU (this will be slow)")
39
+ print(" ⚠️ Disabling 4-bit quantization (requires GPU)")
40
+ self.use_4bit = False # Force disable 4-bit on CPU
41
+
42
+ # Configure 4-bit quantization only if GPU available
43
+ if self.use_4bit and has_gpu:
44
  bnb_config = BitsAndBytesConfig(
45
  load_in_4bit=True,
46
  bnb_4bit_quant_type="nf4",
47
  bnb_4bit_compute_dtype=torch.bfloat16,
48
  bnb_4bit_use_double_quant=False,
49
  )
50
+ print(" ✅ Using 4-bit quantization (GPU)")
51
  else:
52
  bnb_config = None
53
+ print(" ℹ️ Using float32 (CPU mode)")
54
 
55
  # Load base model
56
  print(f" Loading base model: {self.base_model_name}")
57
  base_model = AutoModelForCausalLM.from_pretrained(
58
  self.base_model_name,
59
+ quantization_config=bnb_config if (self.use_4bit and has_gpu) else None,
60
+ torch_dtype=torch.float32 if not has_gpu else torch.bfloat16, # float32 for CPU
61
  device_map="auto",
62
  trust_remote_code=True,
63
+ low_cpu_mem_usage=True, # Optimize CPU memory
64
  )
65
 
66
  # Load tokenizer