mohhhhhit commited on
Commit
c657ef6
·
verified ·
1 Parent(s): 745ca15

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +1036 -35
src/streamlit_app.py CHANGED
@@ -1,40 +1,1041 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
4
  import streamlit as st
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
 
 
 
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ """
2
+ NoNoQL - Natural Language to SQL/MongoDB Query Generator
3
+ Streamlit Frontend Application
4
+ """
5
+
6
  import streamlit as st
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
9
+ import os
10
+ import json
11
+ from datetime import datetime
12
 
13
+ HISTORY_FILE_PATH = os.path.join(
14
+ os.path.dirname(os.path.abspath(__file__)),
15
+ "data",
16
+ "query_history.json"
17
+ )
18
 
19
+ SCHEMA_FILE_PATH = os.path.join(
20
+ os.path.dirname(os.path.abspath(__file__)),
21
+ "data",
22
+ "database_schema.txt"
23
+ )
24
+
25
+ DEFAULT_SCHEMA = """**employees**
26
+ - employee_id, name, email
27
+ - department, salary, hire_date, age
28
+
29
+ **departments**
30
+ - department_id, department_name
31
+ - manager_id, budget, location
32
+
33
+ **projects**
34
+ - project_id, project_name
35
+ - start_date, end_date, budget, status
36
+
37
+ **orders**
38
+ - order_id, customer_name
39
+ - product_name, quantity
40
+ - order_date, total_amount
41
+
42
+ **products**
43
+ - product_id, product_name
44
+ - category, price
45
+ - stock_quantity, supplier"""
46
+
47
+ # Page configuration
48
+ st.set_page_config(
49
+ page_title="NoNoQL - Natural Language to SQL/MongoDB Query Generator",
50
+ page_icon="🔍",
51
+ layout="wide",
52
+ initial_sidebar_state="expanded"
53
+ )
54
+
55
+ # Custom CSS
56
+ st.markdown("""
57
+ <style>
58
+ /* Inject title into Streamlit header bar */
59
+ header[data-testid="stHeader"] {
60
+ background-color: rgba(14, 17, 23, 0.95) !important;
61
+ }
62
+
63
+ header[data-testid="stHeader"]::before {
64
+ content: "NoNoQL";
65
+ color: white;
66
+ font-size: 1.3rem;
67
+ font-weight: 600;
68
+ position: absolute;
69
+ left: 1rem;
70
+ top: 50%;
71
+ transform: translateY(-50%);
72
+ z-index: 999;
73
+ }
74
+
75
+ .query-box {
76
+ background-color: #f0f2f6;
77
+ border-radius: 10px;
78
+ padding: 20px;
79
+ margin: 10px 0;
80
+ border-left: 5px solid #1E88E5;
81
+ }
82
+ .success-box {
83
+ background-color: #d4edda;
84
+ border-radius: 10px;
85
+ padding: 20px;
86
+ margin: 10px 0;
87
+ border-left: 5px solid #28a745;
88
+ }
89
+ .example-query {
90
+ background-color: #fff3cd;
91
+ border-radius: 5px;
92
+ padding: 10px;
93
+ margin: 5px 0;
94
+ cursor: pointer;
95
+ }
96
+ .example-query:hover {
97
+ background-color: #ffe69c;
98
+ }
99
+ .stButton>button {
100
+ width: 100%;
101
+ background-color: #1E88E5;
102
+ color: white;
103
+ font-size: 1.1rem;
104
+ padding: 0.5rem 1rem;
105
+ border-radius: 10px;
106
+ border: none;
107
+ margin-top: 1rem;
108
+ }
109
+ .stButton>button:hover {
110
+ background-color: #1565C0;
111
+ }
112
+ </style>
113
+ """, unsafe_allow_html=True)
114
+
115
+
116
+ def extract_columns_from_nl(natural_language_query):
117
+ """Extract table name and column names from natural language query"""
118
+ import re
119
+
120
+ nl = natural_language_query.lower().strip()
121
+
122
+ # Extract table name
123
+ table_match = re.search(r'(?:table|collection)\s+(?:named|called)?\s*(\w+)', nl)
124
+ table_name = table_match.group(1) if table_match else None
125
+
126
+ # Extract column names - look for patterns like "columns as X, Y, Z" or "with X, Y, Z"
127
+ columns = []
128
+
129
+ # Pattern 1: "columns as/named X, Y, Z"
130
+ col_match = re.search(r'columns?\s+(?:as|named|like|called)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
131
+ if col_match:
132
+ col_text = col_match.group(1)
133
+ # Split by comma or 'and'
134
+ columns = re.split(r',|\s+and\s+', col_text)
135
+ columns = [c.strip() for c in columns if c.strip()]
136
+
137
+ # Pattern 2: "add columns X, Y, Z"
138
+ if not columns:
139
+ col_match = re.search(r'(?:add|with)\s+(?:columns?)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
140
+ if col_match:
141
+ col_text = col_match.group(1)
142
+ columns = re.split(r',|\s+and\s+', col_text)
143
+ columns = [c.strip() for c in columns if c.strip()]
144
+
145
+ return table_name, columns
146
+
147
+
148
+ def fix_create_table_sql(generated_sql, table_name, requested_columns):
149
+ """Replace hallucinated columns with actual requested columns in CREATE TABLE"""
150
+ import re
151
+
152
+ if not table_name or not requested_columns:
153
+ return generated_sql
154
+
155
+ # Check if it's a CREATE TABLE query
156
+ if not re.search(r'CREATE\s+TABLE', generated_sql, re.IGNORECASE):
157
+ return generated_sql
158
+
159
+ # Default data types for common column patterns
160
+ def infer_type(col_name):
161
+ col_lower = col_name.lower()
162
+ if 'id' in col_lower:
163
+ return 'INT PRIMARY KEY'
164
+ elif any(word in col_lower for word in ['name', 'title', 'description', 'address', 'city']):
165
+ return 'VARCHAR(100)'
166
+ elif any(word in col_lower for word in ['email']):
167
+ return 'VARCHAR(100)'
168
+ elif any(word in col_lower for word in ['phone', 'contact', 'mobile']):
169
+ return 'VARCHAR(20)'
170
+ elif any(word in col_lower for word in ['date', 'created', 'updated']):
171
+ return 'DATE'
172
+ elif any(word in col_lower for word in ['price', 'salary', 'amount', 'cost']):
173
+ return 'DECIMAL(10,2)'
174
+ elif any(word in col_lower for word in ['age', 'quantity', 'count', 'stock']):
175
+ return 'INT'
176
+ elif any(word in col_lower for word in ['status', 'type', 'category']):
177
+ return 'VARCHAR(50)'
178
+ else:
179
+ return 'VARCHAR(100)'
180
+
181
+ # Build column definitions
182
+ col_defs = []
183
+ for col in requested_columns:
184
+ col_clean = col.strip()
185
+ if col_clean:
186
+ col_type = infer_type(col_clean)
187
+ col_defs.append(f"{col_clean} {col_type}")
188
+
189
+ # Rebuild a clean CREATE TABLE statement from requested columns.
190
+ # This avoids malformed model output leaking extra columns outside parentheses.
191
+ if_not_exists_match = re.search(
192
+ r'CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+' + re.escape(table_name),
193
+ generated_sql,
194
+ re.IGNORECASE
195
+ )
196
+ if if_not_exists_match:
197
+ create_clause = if_not_exists_match.group(0)
198
+ else:
199
+ create_match = re.search(
200
+ r'CREATE\s+TABLE\s+' + re.escape(table_name),
201
+ generated_sql,
202
+ re.IGNORECASE
203
+ )
204
+ if not create_match:
205
+ return generated_sql
206
+ create_clause = create_match.group(0)
207
+
208
+ new_columns = ', '.join(col_defs)
209
+ return f"{create_clause} ({new_columns});"
210
+
211
+
212
+ def fix_create_collection_mongo(generated_mongo, table_name, requested_columns):
213
+ """Fix MongoDB createCollection to use correct collection name and sample document"""
214
+ if not table_name:
215
+ return generated_mongo
216
+
217
+ # Build sample document with requested columns
218
+ doc_fields = []
219
+ for col in requested_columns:
220
+ col_clean = col.strip()
221
+ if col_clean:
222
+ # Provide example values based on column name
223
+ if 'id' in col_clean.lower():
224
+ doc_fields.append(f'"{col_clean}": 1')
225
+ elif any(word in col_clean.lower() for word in ['name', 'title']):
226
+ doc_fields.append(f'"{col_clean}": "sample_name"')
227
+ elif 'email' in col_clean.lower():
228
+ doc_fields.append(f'"{col_clean}": "user@example.com"')
229
+ elif any(word in col_clean.lower() for word in ['phone', 'contact']):
230
+ doc_fields.append(f'"{col_clean}": "1234567890"')
231
+ else:
232
+ doc_fields.append(f'"{col_clean}": "sample_value"')
233
+
234
+ # Create proper MongoDB command
235
+ if doc_fields:
236
+ fixed_mongo = f"db.{table_name}.insertOne({{{', '.join(doc_fields)}}});"
237
+ else:
238
+ fixed_mongo = f"db.createCollection('{table_name}');"
239
+
240
+ return fixed_mongo
241
+
242
+
243
+ def detect_comparison_operator(natural_language_query):
244
+ """Detect comparison operator from natural language
245
+
246
+ Returns: operator string ('>', '<', '>=', '<=', '=') or None
247
+ """
248
+ import re
249
+
250
+ nl = natural_language_query.lower()
251
+
252
+ # Check for comparison keywords
253
+ if re.search(r'\b(greater than|more than|above|exceeds?)\b', nl):
254
+ return '>'
255
+ elif re.search(r'\b(less than|fewer than|below|under)\b', nl):
256
+ return '<'
257
+ elif re.search(r'\b(greater than or equal to|at least|minimum)\b', nl):
258
+ return '>='
259
+ elif re.search(r'\b(less than or equal to|at most|maximum)\b', nl):
260
+ return '<='
261
+ elif re.search(r'\b(equals?|is|=)\b', nl):
262
+ return '='
263
+
264
+ return None
265
+
266
+
267
+ def fix_sql_operation_type(generated_sql, natural_language_query):
268
+ """Fix SQL queries with wrong operation type (SELECT vs DELETE vs UPDATE vs INSERT)"""
269
+ import re
270
+
271
+ nl = natural_language_query.lower()
272
+
273
+ # Detect intended operation from natural language
274
+ if re.search(r'\b(delete|remove)\b', nl):
275
+ # Should be DELETE, not SELECT
276
+ if re.match(r'SELECT\s+\*\s+FROM', generated_sql, re.IGNORECASE):
277
+ # Extract table and WHERE clause
278
+ match = re.search(r'SELECT\s+\*\s+FROM\s+(\w+)(\s+WHERE\s+.+)?', generated_sql, re.IGNORECASE)
279
+ if match:
280
+ table = match.group(1)
281
+ where_clause = match.group(2) if match.group(2) else ''
282
+ generated_sql = f"DELETE FROM {table}{where_clause}"
283
+
284
+ return generated_sql
285
+
286
+
287
+ def fix_mongodb_operation_type(generated_mongo, natural_language_query):
288
+ """Fix MongoDB queries with wrong operation type"""
289
+ import re
290
+
291
+ nl = natural_language_query.lower()
292
+
293
+ # Detect intended operation from natural language
294
+ if re.search(r'\b(delete|remove)\b', nl):
295
+ # Should be deleteMany, not find, insertOne, or deleteOne
296
+ if re.search(r'\.(find|findOne|insertOne|deleteOne)\s*\(', generated_mongo):
297
+ # Replace with deleteMany
298
+ generated_mongo = re.sub(
299
+ r'\.(find|findOne|insertOne|deleteOne)\s*\(',
300
+ '.deleteMany(',
301
+ generated_mongo
302
+ )
303
+
304
+ return generated_mongo
305
+
306
+
307
+ def fix_mongodb_missing_braces(generated_mongo):
308
+ """Fix MongoDB queries that are missing curly braces around query objects
309
+
310
+ Example: db.collection.find("field": value) -> db.collection.find({"field": value})
311
+ """
312
+ import re
313
+
314
+ # Pattern: .method("field": value) or .method(field: value)
315
+ # Missing the outer { } around the query object
316
+
317
+ # Pattern 1: .find("field": value) -> .find({"field": value})
318
+ pattern1 = r'(\.\w+)\(\"(\w+)\":\s*([^)]+)\)'
319
+ match = re.search(pattern1, generated_mongo)
320
+ if match:
321
+ method = match.group(1) # e.g., .find
322
+ field = match.group(2) # e.g., salary
323
+ value = match.group(3).strip() # e.g., 50000
324
+ # Remove trailing semicolon if present
325
+ value = value.rstrip(';')
326
+ # Reconstruct with proper braces
327
+ generated_mongo = re.sub(
328
+ pattern1,
329
+ method + '({"' + field + '": ' + value + '})',
330
+ generated_mongo
331
+ )
332
+ else:
333
+ # Pattern 2: .find(field: value) -> .find({field: value})
334
+ pattern2 = r'(\.\w+)\((\w+):\s*([^)]+)\)'
335
+ match = re.search(pattern2, generated_mongo)
336
+ if match:
337
+ method = match.group(1)
338
+ field = match.group(2)
339
+ value = match.group(3).strip()
340
+ value = value.rstrip(';')
341
+ generated_mongo = re.sub(
342
+ pattern2,
343
+ method + '({' + field + ': ' + value + '})',
344
+ generated_mongo
345
+ )
346
+
347
+ return generated_mongo
348
+
349
+
350
+ def fix_comparison_operator_sql(generated_sql, natural_language_query):
351
+ """Fix SQL queries with wrong comparison operators"""
352
+ import re
353
+
354
+ correct_op = detect_comparison_operator(natural_language_query)
355
+
356
+ if correct_op and correct_op != '=':
357
+ # Replace = with correct operator in WHERE clause
358
+ # Pattern: WHERE column = value
359
+ generated_sql = re.sub(
360
+ r'(WHERE\s+\w+)\s*=\s*',
361
+ r'\1 ' + correct_op + ' ',
362
+ generated_sql,
363
+ flags=re.IGNORECASE
364
+ )
365
+
366
+ return generated_sql
367
+
368
+
369
+ def fix_comparison_operator_mongodb(generated_mongo, natural_language_query):
370
+ """Fix MongoDB queries with wrong comparison operators"""
371
+ import re
372
+
373
+ correct_op = detect_comparison_operator(natural_language_query)
374
+
375
+ if correct_op and correct_op != '=':
376
+ # Map SQL operators to MongoDB operators
377
+ mongo_op_map = {
378
+ '>': '$gt',
379
+ '<': '$lt',
380
+ '>=': '$gte',
381
+ '<=': '$lte'
382
+ }
383
+
384
+ mongo_op = mongo_op_map.get(correct_op)
385
+
386
+ if mongo_op:
387
+ # More robust pattern matching for MongoDB queries
388
+ # Handles: db.collection.operation({"field": value}) or db.collection.operation({field: value})
389
+
390
+ # Pattern 1: {"field": value} - quoted field name
391
+ pattern1 = r'\{"(\w+)":\s*([^,}{]+)\}'
392
+ match = re.search(pattern1, generated_mongo)
393
+ if match:
394
+ field = match.group(1)
395
+ value = match.group(2).strip()
396
+ # Replace with comparison operator
397
+ replacement = '{"' + field + '": {' + mongo_op + ': ' + value + '}}'
398
+ generated_mongo = re.sub(pattern1, replacement, generated_mongo, count=1)
399
+ else:
400
+ # Pattern 2: {field: value} - unquoted field name
401
+ pattern2 = r'\{(\w+):\s*([^,}{]+)\}'
402
+ match = re.search(pattern2, generated_mongo)
403
+ if match:
404
+ field = match.group(1)
405
+ value = match.group(2).strip()
406
+ # Replace with comparison operator
407
+ replacement = '{' + field + ': {' + mongo_op + ': ' + value + '}}'
408
+ generated_mongo = re.sub(pattern2, replacement, generated_mongo, count=1)
409
+
410
+ return generated_mongo
411
+
412
+
413
+ def parse_update_query(natural_language_query):
414
+ """Parse UPDATE query from natural language
415
+
416
+ Example: "Update employees set department to Sales where employee_id is 101"
417
+ Returns: (table, set_column, set_value, where_column, where_value)
418
+ """
419
+ import re
420
+
421
+ # Use case-insensitive matching but preserve original values
422
+
423
+ # Pattern 1: "update X set Y to Z where A is B"
424
+ match = re.search(
425
+ r'update\s+(\w+)\s+set\s+(\w+)\s+to\s+([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)\s+(?:is|equals?|=)\s+(.+)',
426
+ natural_language_query,
427
+ re.IGNORECASE
428
+ )
429
+
430
+ if match:
431
+ table_name = match.group(1)
432
+ set_column = match.group(2)
433
+ set_value = match.group(3).strip()
434
+ where_column = match.group(4)
435
+ where_value = match.group(5).strip()
436
+ return (table_name, set_column, set_value, where_column, where_value)
437
+
438
+ # Pattern 2: "update X set Y = Z where A = B"
439
+ match = re.search(
440
+ r'update\s+(\w+)\s+set\s+(\w+)\s*=\s*([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)\s*=\s*(.+)',
441
+ natural_language_query,
442
+ re.IGNORECASE
443
+ )
444
+
445
+ if match:
446
+ table_name = match.group(1)
447
+ set_column = match.group(2)
448
+ set_value = match.group(3).strip()
449
+ where_column = match.group(4)
450
+ where_value = match.group(5).strip()
451
+ return (table_name, set_column, set_value, where_column, where_value)
452
+
453
+ return None
454
+
455
+
456
+ def fix_update_query_sql(generated_sql, natural_language_query):
457
+ """Fix malformed UPDATE SQL queries"""
458
+ import re
459
+
460
+ # Check if model generated garbage for UPDATE
461
+ if 'update' in natural_language_query.lower():
462
+ # If output doesn't look like proper SQL UPDATE
463
+ if not re.search(r'UPDATE\s+\w+\s+SET', generated_sql, re.IGNORECASE):
464
+ parsed = parse_update_query(natural_language_query)
465
+ if parsed:
466
+ table, set_col, set_val, where_col, where_val = parsed
467
+
468
+ # Determine if value should be quoted (string vs number)
469
+ try:
470
+ # Try to parse as number
471
+ float(set_val)
472
+ set_val_quoted = set_val
473
+ except:
474
+ set_val_quoted = f"'{set_val}'"
475
+
476
+ try:
477
+ float(where_val)
478
+ where_val_quoted = where_val
479
+ except:
480
+ where_val_quoted = f"'{where_val}'"
481
+
482
+ # Reconstruct proper SQL
483
+ return f"UPDATE {table} SET {set_col} = {set_val_quoted} WHERE {where_col} = {where_val_quoted};"
484
+
485
+ return generated_sql
486
+
487
+
488
+ def fix_update_query_mongodb(generated_mongo, natural_language_query):
489
+ """Fix malformed UPDATE MongoDB queries"""
490
+ import re
491
+
492
+ # Check if model generated garbage for UPDATE
493
+ if 'update' in natural_language_query.lower():
494
+ # If output doesn't look like proper MongoDB update
495
+ if not re.search(r'\.update', generated_mongo, re.IGNORECASE):
496
+ parsed = parse_update_query(natural_language_query)
497
+ if parsed:
498
+ table, set_col, set_val, where_col, where_val = parsed
499
+
500
+ # Determine if value should be quoted
501
+ try:
502
+ float(set_val)
503
+ set_val_formatted = set_val
504
+ except:
505
+ set_val_formatted = f'"{set_val}"'
506
+
507
+ try:
508
+ float(where_val)
509
+ where_val_formatted = where_val
510
+ except:
511
+ where_val_formatted = f'"{where_val}"'
512
+
513
+ # Reconstruct proper MongoDB
514
+ return f"db.{table}.updateMany({{{where_col}: {where_val_formatted}}}, {{$set: {{{set_col}: {set_val_formatted}}}}});"
515
+
516
+ return generated_mongo
517
+
518
+
519
+ class TexQLModel:
520
+ """Unified model wrapper for SQL/MongoDB generation"""
521
+
522
+ def __init__(self, model_path):
523
+ """Initialize the model for inference"""
524
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
525
+
526
+ try:
527
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
528
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
529
+ self.model.to(self.device)
530
+ self.model.eval()
531
+ self.loaded = True
532
+ except Exception as e:
533
+ st.error(f"Error loading model: {str(e)}")
534
+ self.loaded = False
535
+
536
+ def generate_query(self, natural_language_query, target_type='sql', temperature=0.3,
537
+ num_beams=10, repetition_penalty=1.2, length_penalty=0.8):
538
+ """Generate SQL or MongoDB query from natural language
539
+
540
+ Args:
541
+ natural_language_query: The user's natural language query
542
+ target_type: 'sql' or 'mongodb' to specify output format
543
+ temperature: Sampling temperature (lower = more focused)
544
+ num_beams: Number of beams for beam search
545
+ repetition_penalty: Penalty for repeating tokens (>1.0 discourages repetition)
546
+ length_penalty: Penalty for length (>1.0 encourages longer, <1.0 encourages shorter)
547
+ """
548
+ if not self.loaded:
549
+ return "Model not loaded"
550
+
551
+ input_text = f"translate to {target_type}: {natural_language_query}"
552
+
553
+ inputs = self.tokenizer(
554
+ input_text,
555
+ return_tensors="pt",
556
+ max_length=256,
557
+ truncation=True
558
+ ).to(self.device)
559
+
560
+ with torch.no_grad():
561
+ outputs = self.model.generate(
562
+ **inputs,
563
+ max_length=512,
564
+ num_beams=num_beams,
565
+ temperature=temperature,
566
+ repetition_penalty=repetition_penalty,
567
+ length_penalty=length_penalty,
568
+ no_repeat_ngram_size=3, # Prevent repeating 3-grams
569
+ early_stopping=True,
570
+ do_sample=False # Use greedy/beam search (more deterministic)
571
+ )
572
+
573
+ generated_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
574
+
575
+ # ✅ POST-PROCESSING: Fix hallucinated columns in CREATE queries
576
+ if any(word in natural_language_query.lower() for word in ['create', 'add columns']):
577
+ table_name, requested_columns = extract_columns_from_nl(natural_language_query)
578
+
579
+ if table_name and requested_columns:
580
+ if target_type == 'sql':
581
+ generated_query = fix_create_table_sql(generated_query, table_name, requested_columns)
582
+ elif target_type == 'mongodb':
583
+ generated_query = fix_create_collection_mongo(generated_query, table_name, requested_columns)
584
+
585
+ # ✅ POST-PROCESSING: Fix malformed UPDATE queries
586
+ if 'update' in natural_language_query.lower() and 'set' in natural_language_query.lower():
587
+ if target_type == 'sql':
588
+ generated_query = fix_update_query_sql(generated_query, natural_language_query)
589
+ elif target_type == 'mongodb':
590
+ generated_query = fix_update_query_mongodb(generated_query, natural_language_query)
591
+
592
+ # ✅ POST-PROCESSING: Fix wrong operation type (SELECT vs DELETE, etc.)
593
+ if target_type == 'sql':
594
+ generated_query = fix_sql_operation_type(generated_query, natural_language_query)
595
+ elif target_type == 'mongodb':
596
+ generated_query = fix_mongodb_operation_type(generated_query, natural_language_query)
597
+
598
+ # ✅ POST-PROCESSING: Fix missing curly braces in MongoDB queries
599
+ if target_type == 'mongodb':
600
+ generated_query = fix_mongodb_missing_braces(generated_query)
601
+
602
+ # ✅ POST-PROCESSING: Fix comparison operators (>, <, >=, <=)
603
+ if target_type == 'sql':
604
+ generated_query = fix_comparison_operator_sql(generated_query, natural_language_query)
605
+ elif target_type == 'mongodb':
606
+ generated_query = fix_comparison_operator_mongodb(generated_query, natural_language_query)
607
+
608
+ return generated_query
609
+
610
+
611
+ @st.cache_resource
612
+ def load_model(model_path):
613
+ """Load the unified NoNoQL model (cached)"""
614
+ model = None
615
+
616
+ if os.path.exists(model_path):
617
+ model = TexQLModel(model_path)
618
+
619
+ return model
620
+
621
+
622
+ def save_query_history(nl_query, sql_query, mongodb_query, max_history=500):
623
+ """Save query to history with size limit"""
624
+ if 'history' not in st.session_state:
625
+ st.session_state.history = []
626
+
627
+ st.session_state.history.append({
628
+ 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
629
+ 'natural_language': nl_query,
630
+ 'sql': sql_query,
631
+ 'mongodb': mongodb_query
632
+ })
633
+
634
+ # Keep only the most recent entries
635
+ if len(st.session_state.history) > max_history:
636
+ st.session_state.history = st.session_state.history[-max_history:]
637
+
638
+ persist_query_history(st.session_state.history)
639
+
640
+
641
+ def delete_history_entry(index):
642
+ """Delete a specific history entry"""
643
+ if 'history' in st.session_state and 0 <= index < len(st.session_state.history):
644
+ st.session_state.history.pop(index)
645
+ persist_query_history(st.session_state.history)
646
+
647
+
648
+ def load_query_history():
649
+ """Load query history from disk"""
650
+ try:
651
+ if not os.path.exists(HISTORY_FILE_PATH):
652
+ return []
653
+
654
+ with open(HISTORY_FILE_PATH, "r", encoding="utf-8") as history_file:
655
+ history = json.load(history_file)
656
+
657
+ if isinstance(history, list):
658
+ return history
659
+ return []
660
+ except Exception:
661
+ return []
662
+
663
+
664
+ def persist_query_history(history):
665
+ """Persist query history to disk"""
666
+ os.makedirs(os.path.dirname(HISTORY_FILE_PATH), exist_ok=True)
667
+ with open(HISTORY_FILE_PATH, "w", encoding="utf-8") as history_file:
668
+ json.dump(history, history_file, indent=2)
669
+
670
+
671
+ def load_schema():
672
+ """Load database schema from disk"""
673
+ try:
674
+ if not os.path.exists(SCHEMA_FILE_PATH):
675
+ return DEFAULT_SCHEMA
676
+
677
+ with open(SCHEMA_FILE_PATH, "r", encoding="utf-8") as schema_file:
678
+ schema = schema_file.read()
679
+
680
+ return schema if schema.strip() else DEFAULT_SCHEMA
681
+ except Exception:
682
+ return DEFAULT_SCHEMA
683
+
684
+
685
+ def persist_schema(schema):
686
+ """Persist database schema to disk"""
687
+ os.makedirs(os.path.dirname(SCHEMA_FILE_PATH), exist_ok=True)
688
+ with open(SCHEMA_FILE_PATH, "w", encoding="utf-8") as schema_file:
689
+ schema_file.write(schema)
690
+
691
+
692
+ def main():
693
+ if 'history' not in st.session_state:
694
+ st.session_state.history = load_query_history()
695
+
696
+ if 'schema' not in st.session_state:
697
+ st.session_state.schema = load_schema()
698
+
699
+ if 'schema_edit_mode' not in st.session_state:
700
+ st.session_state.schema_edit_mode = False
701
+
702
+ # Sidebar
703
+ with st.sidebar:
704
+ st.header("⚙️ Configuration")
705
+
706
+ # Model path
707
+ st.subheader("Model Path")
708
+ model_path = st.text_input(
709
+ "NoNoQL Model Path",
710
+ value="models",
711
+ help="Path to the unified NoNoQL model (generates both SQL and MongoDB)"
712
+ )
713
+
714
+ # Generation parameters
715
+ st.subheader("Generation Parameters")
716
+ temperature = st.slider(
717
+ "Temperature",
718
+ min_value=0.1,
719
+ max_value=1.0,
720
+ value=0.3, # ✅ Lower default = less hallucination
721
+ step=0.1,
722
+ help="Lower = more focused, Higher = more creative"
723
+ )
724
+ num_beams = st.slider(
725
+ "Beam Search Width",
726
+ min_value=1,
727
+ max_value=10,
728
+ value=10, # ✅ Higher value = more accurate results
729
+ help="Higher values improve accuracy (recommended: keep at 10)"
730
+ )
731
+ repetition_penalty = st.slider(
732
+ "Repetition Penalty",
733
+ min_value=1.0,
734
+ max_value=2.0,
735
+ value=1.2, # ✅ Discourages adding extra unwanted columns
736
+ step=0.1,
737
+ help="Higher = less repetition (prevents hallucinating extra columns)"
738
+ )
739
+ length_penalty = st.slider(
740
+ "Length Penalty",
741
+ min_value=0.5,
742
+ max_value=1.5,
743
+ value=0.8, # ✅ Prefer shorter outputs
744
+ step=0.1,
745
+ help="Lower = prefer shorter outputs, Higher = prefer longer outputs"
746
+ )
747
+
748
+ # Load models button
749
+ if st.button("🔄 Load/Reload Models"):
750
+ st.cache_resource.clear()
751
+ st.rerun()
752
+
753
+ # History management
754
+ st.subheader("📚 History Settings")
755
+ max_history_size = st.number_input(
756
+ "Max History Entries",
757
+ min_value=10,
758
+ max_value=1000,
759
+ value=500,
760
+ step=10,
761
+ help="Maximum number of queries to keep in history"
762
+ )
763
+
764
+ # Database schema info
765
+ st.subheader("📊 Database Schema")
766
+
767
+ # Toggle edit mode
768
+ col1, col2 = st.columns([1, 3])
769
+ with col1:
770
+ if st.button("✏️ Edit" if not st.session_state.schema_edit_mode else "👁️ View"):
771
+ st.session_state.schema_edit_mode = not st.session_state.schema_edit_mode
772
+ st.rerun()
773
+ with col2:
774
+ if st.session_state.schema_edit_mode:
775
+ st.info("✏️ Editing Mode")
776
+ else:
777
+ st.caption("View your database tables and columns")
778
+
779
+ if st.session_state.schema_edit_mode:
780
+ # Edit mode - text area
781
+ edited_schema = st.text_area(
782
+ "Edit Database Schema",
783
+ value=st.session_state.schema,
784
+ height=300,
785
+ help="Define your database tables and columns. Use Markdown format."
786
+ )
787
+
788
+ col1, col2 = st.columns(2)
789
+ with col1:
790
+ if st.button("💾 Save Schema", use_container_width=True):
791
+ st.session_state.schema = edited_schema
792
+ persist_schema(edited_schema)
793
+ st.success("Schema saved!")
794
+ st.session_state.schema_edit_mode = False
795
+ st.rerun()
796
+
797
+ with col2:
798
+ if st.button("🔄 Reset to Default", use_container_width=True):
799
+ st.session_state.schema = DEFAULT_SCHEMA
800
+ persist_schema(DEFAULT_SCHEMA)
801
+ st.success("Schema reset to default!")
802
+ st.rerun()
803
+ else:
804
+ # View mode - expandable display
805
+ with st.expander("View Available Tables", expanded=False):
806
+ st.markdown(st.session_state.schema)
807
+
808
+ # Load model
809
+ with st.spinner("Loading model..."):
810
+ model = load_model("mohhhhhit/nonoql") # Load directly from HF!
811
+
812
+ # Model status
813
+ if model and model.loaded:
814
+ device_info = "🎮 GPU" if model.device == "cuda" else "💻 CPU"
815
+ st.success(f"✅ Model Loaded ({device_info})")
816
+ st.info("💡 This model generates both SQL and MongoDB queries")
817
+ else:
818
+ st.error("⚠️ Model Not Available - Please check the model path")
819
+
820
+ # Query input
821
+ st.subheader("🔤 Enter Your Query")
822
+
823
+ # Example queries dropdown
824
+ with st.expander("💡 Example Queries - Click to expand"):
825
+ examples = [
826
+ "Show all employees",
827
+ "Find employees where salary is greater than 50000",
828
+ "Get all departments with budget more than 100000",
829
+ "Insert a new employee with name John Doe, email john@example.com, department Engineering",
830
+ "Update employees set department to Sales where employee_id is 101",
831
+ "Delete orders with total_amount less than 1000",
832
+ "Count all products in Electronics category",
833
+ "Show top 10 employees ordered by salary",
834
+ ]
835
+
836
+ selected_example = st.selectbox(
837
+ "Choose an example query:",
838
+ [""] + examples,
839
+ index=0,
840
+ format_func=lambda x: "Select an example..." if x == "" else x
841
+ )
842
+
843
+ if selected_example and st.button("📝 Use This Example", use_container_width=True):
844
+ st.session_state.user_query = selected_example
845
+ st.rerun()
846
+
847
+ user_query = st.text_area(
848
+ "or",
849
+ value=st.session_state.get('user_query', ''),
850
+ height=100,
851
+ placeholder="write your query here..."
852
+ )
853
+
854
+ # Generate button
855
+ if st.button("🚀 Generate Queries"):
856
+ if not user_query.strip():
857
+ st.warning("Please enter a query")
858
+ elif not model or not model.loaded:
859
+ st.error("Model is not loaded. Please check the model path and reload.")
860
+ else:
861
+ with st.spinner("Generating queries..."):
862
+ # Generate both SQL and MongoDB from the same model
863
+ sql_query = model.generate_query(
864
+ user_query,
865
+ target_type='sql',
866
+ temperature=temperature,
867
+ num_beams=num_beams,
868
+ repetition_penalty=repetition_penalty,
869
+ length_penalty=length_penalty
870
+ )
871
+
872
+ mongodb_query = model.generate_query(
873
+ user_query,
874
+ target_type='mongodb',
875
+ temperature=temperature,
876
+ num_beams=num_beams,
877
+ repetition_penalty=repetition_penalty,
878
+ length_penalty=length_penalty
879
+ )
880
+
881
+ # Save to history
882
+ save_query_history(user_query, sql_query, mongodb_query, max_history_size)
883
+
884
+ # Display results
885
+ st.markdown("---")
886
+ st.success("✅ Queries Generated Successfully!")
887
+
888
+ # Input query
889
+ st.markdown('<div class="query-box">', unsafe_allow_html=True)
890
+ st.markdown("**📝 Your Query:**")
891
+ st.code(user_query, language="text")
892
+ st.markdown('</div>', unsafe_allow_html=True)
893
+
894
+ # Results in columns
895
+ col1, col2 = st.columns(2)
896
+
897
+ with col1:
898
+ st.markdown("### 🗄️ SQL Query")
899
+ st.code(sql_query, language="sql")
900
+
901
+ # Copy button
902
+ if st.button("📋 Copy SQL", key="copy_sql"):
903
+ st.session_state.clipboard = sql_query
904
+ st.success("Copied to clipboard!")
905
+
906
+ with col2:
907
+ st.markdown("### 🍃 MongoDB Query")
908
+ st.code(mongodb_query, language="javascript")
909
+
910
+ # Copy button
911
+ if st.button("📋 Copy MongoDB", key="copy_mongo"):
912
+ st.session_state.clipboard = mongodb_query
913
+ st.success("Copied to clipboard!")
914
+
915
+ # Query history
916
+ if 'history' in st.session_state and st.session_state.history:
917
+ st.markdown("---")
918
+ st.subheader("📚 Query History")
919
+
920
+ # History management controls
921
+ col1, col2, col3 = st.columns([2, 1, 1])
922
+
923
+ with col1:
924
+ search_term = st.text_input(
925
+ "🔍 Search History",
926
+ placeholder="Search in queries...",
927
+ label_visibility="collapsed"
928
+ )
929
+
930
+ with col2:
931
+ sort_order = st.selectbox(
932
+ "Sort",
933
+ ["Newest First", "Oldest First"],
934
+ label_visibility="collapsed"
935
+ )
936
+
937
+ with col3:
938
+ show_limit = st.number_input(
939
+ "Show",
940
+ min_value=5,
941
+ max_value=100,
942
+ value=10,
943
+ step=5,
944
+ label_visibility="collapsed"
945
+ )
946
+
947
+ # Action buttons
948
+ col1, col2 = st.columns(2)
949
+ with col1:
950
+ if st.button("🗑️ Clear All History"):
951
+ st.session_state.history = []
952
+ persist_query_history(st.session_state.history)
953
+ st.rerun()
954
+
955
+ with col2:
956
+ if st.button("💾 Export History"):
957
+ history_json = json.dumps(st.session_state.history, indent=2)
958
+ st.download_button(
959
+ label="Download History (JSON)",
960
+ data=history_json,
961
+ file_name=f"nonoql_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
962
+ mime="application/json"
963
+ )
964
+
965
+ # Filter history
966
+ filtered_history = st.session_state.history
967
+ if search_term:
968
+ search_lower = search_term.lower()
969
+ filtered_history = [
970
+ entry for entry in st.session_state.history
971
+ if search_lower in entry['natural_language'].lower() or
972
+ search_lower in entry.get('sql', '').lower() or
973
+ search_lower in entry.get('mongodb', '').lower()
974
+ ]
975
+
976
+ # Sort history
977
+ if sort_order == "Oldest First":
978
+ display_history = filtered_history[:show_limit]
979
+ else:
980
+ display_history = list(reversed(filtered_history[-show_limit:]))
981
+
982
+ # Display count
983
+ st.markdown(f"**Showing {len(display_history)} of {len(filtered_history)} queries** (Total: {len(st.session_state.history)})")
984
+
985
+ if not display_history:
986
+ st.info("No queries found matching your search.")
987
+
988
+ # Display history entries
989
+ for display_idx, entry in enumerate(display_history):
990
+ # Find actual index in original history for deletion
991
+ actual_idx = st.session_state.history.index(entry)
992
+
993
+ with st.expander(
994
+ f"🕐 {entry['timestamp']} - {entry['natural_language'][:60]}...",
995
+ expanded=False
996
+ ):
997
+ # Action buttons for this entry
998
+ col1, col2, col3 = st.columns([3, 1, 1])
999
+
1000
+ with col1:
1001
+ st.markdown(f"**Natural Language Query:**")
1002
+ st.info(entry['natural_language'])
1003
+
1004
+ with col2:
1005
+ if st.button("🔄 Rerun", key=f"rerun_{actual_idx}"):
1006
+ st.session_state.user_query = entry['natural_language']
1007
+ st.rerun()
1008
+
1009
+ with col3:
1010
+ if st.button("🗑️ Delete", key=f"del_{actual_idx}"):
1011
+ delete_history_entry(actual_idx)
1012
+ st.rerun()
1013
+
1014
+ # Display queries
1015
+ col1, col2 = st.columns(2)
1016
+ with col1:
1017
+ st.markdown("**SQL Query:**")
1018
+ if entry.get('sql'):
1019
+ st.code(entry['sql'], language="sql")
1020
+ else:
1021
+ st.text("N/A")
1022
+
1023
+ with col2:
1024
+ st.markdown("**MongoDB Query:**")
1025
+ if entry.get('mongodb'):
1026
+ st.code(entry['mongodb'], language="javascript")
1027
+ else:
1028
+ st.text("N/A")
1029
+
1030
+ # Footer
1031
+ st.markdown("---")
1032
+ st.markdown("""
1033
+ <div style='text-align: center; color: #666; padding: 2rem;'>
1034
+ <p>NoNoQL - Natural Language to Query Generator</p>
1035
+ <p>Powered by T5 Transformer Models | Built with Streamlit</p>
1036
+ </div>
1037
+ """, unsafe_allow_html=True)
1038
 
 
 
1039
 
1040
+ if __name__ == "__main__":
1041
+ main()