"""
NoNoQL - Natural Language to SQL/MongoDB Query Generator
Streamlit Frontend Application (HuggingFace Spaces Version)
"""
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import os
import json
from datetime import datetime
# HuggingFace Spaces configuration
# Hub repo id ("user/repo") used as the default model path when the app runs
# on HuggingFace Spaces (see DEFAULT_MODEL_PATH below).
HF_MODEL_REPO = "mohhhhhit/nonoql" # Your HuggingFace model repo
# Detect environment
def is_hf_space():
    """Return True when the SPACE_ID environment variable is set.

    HuggingFace Spaces defines SPACE_ID; its mere presence (even with an empty
    value) is treated as "running on Spaces".
    """
    return "SPACE_ID" in os.environ
# Default model source: the Hub repo on HuggingFace Spaces, otherwise the
# relative local path "models" (resolved against the working directory).
if is_hf_space():
    DEFAULT_MODEL_PATH = HF_MODEL_REPO
else:
    DEFAULT_MODEL_PATH = "models"

# Persistent app data (query history, user schema) lives in a "data" folder
# next to this script.
_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "data")
HISTORY_FILE_PATH = os.path.join(_DATA_DIR, "query_history.json")
SCHEMA_FILE_PATH = os.path.join(_DATA_DIR, "database_schema.txt")
# Fallback schema text (Markdown) for the sidebar "Database Schema" panel.
# load_schema() returns it when the schema file is missing, empty or
# unreadable, and the "Reset to Default" button restores it.
DEFAULT_SCHEMA = """**employees**
- employee_id, name, email
- department, salary, hire_date, age
**departments**
- department_id, department_name
- manager_id, budget, location
**projects**
- project_id, project_name
- start_date, end_date, budget, status
**orders**
- order_id, customer_name
- product_name, quantity
- order_date, total_amount
**products**
- product_id, product_name
- category, price
- stock_quantity, supplier"""
# Page configuration: browser tab title and icon, wide layout, and the
# sidebar opened by default.
_PAGE_CONFIG = dict(
    page_title="NoNoQL - Natural Language to SQL/MongoDB Query Generator",
    page_icon="๐",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.set_page_config(**_PAGE_CONFIG)
# Custom CSS: paints the "NoNoQL" title into Streamlit's header bar via a
# ::before pseudo-element, styles the .query-box wrapper used around the
# user's query, and restyles every st.button. unsafe_allow_html=True is what
# lets the raw <style> tag through. NOTE(review): .success-box and
# .example-query appear unused in this file.
st.markdown("""
<style>
/* Inject title into Streamlit header bar */
header[data-testid="stHeader"] {
background-color: rgba(14, 17, 23, 0.95) !important;
}
header[data-testid="stHeader"]::before {
content: "NoNoQL";
color: white;
font-size: 1.3rem;
font-weight: 600;
position: absolute;
left: 1rem;
top: 50%;
transform: translateY(-50%);
z-index: 999;
}
.query-box {
background-color: #f0f2f6;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
border-left: 5px solid #1E88E5;
}
.success-box {
background-color: #d4edda;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
border-left: 5px solid #28a745;
}
.example-query {
background-color: #fff3cd;
border-radius: 5px;
padding: 10px;
margin: 5px 0;
cursor: pointer;
}
.example-query:hover {
background-color: #ffe69c;
}
.stButton>button {
width: 100%;
background-color: #1E88E5;
color: white;
font-size: 1.1rem;
padding: 0.5rem 1rem;
border-radius: 10px;
border: none;
margin-top: 1rem;
}
.stButton>button:hover {
background-color: #1565C0;
}
</style>
""", unsafe_allow_html=True)
def extract_columns_from_nl(natural_language_query):
    """Extract the table name and requested column names from a request.

    Recognises "table|collection [named|called] NAME" for the table, and
    "columns [as|named|like|called] A, B and C" (or, failing that,
    "add|with [columns] A, B, C") for the columns.

    Args:
        natural_language_query: The user's request in plain English.

    Returns:
        (table_name, columns): table_name is a lower-cased word or None;
        columns is a list of lower-cased, stripped names (possibly empty).
        Trailing sentence punctuation such as the final "." is removed so
        the names can be used as identifiers.
    """
    import re

    def _split_columns(col_text):
        # Split on commas or the word "and"; trim whitespace and trailing
        # punctuation ("email." -> "email") and drop empty pieces.
        parts = re.split(r',|\s+and\s+', col_text)
        cleaned = (part.strip().rstrip('.;:!?').strip() for part in parts)
        return [name for name in cleaned if name]

    nl = natural_language_query.lower().strip()

    table_match = re.search(r'(?:table|collection)\s+(?:named|called)?\s*(\w+)', nl)
    table_name = table_match.group(1) if table_match else None

    columns = []
    # Pattern 1: "columns as/named X, Y, Z"
    col_match = re.search(r'columns?\s+(?:as|named|like|called)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
    if col_match:
        columns = _split_columns(col_match.group(1))

    # Pattern 2: "add columns X, Y, Z" / "with X, Y, Z"
    if not columns:
        col_match = re.search(r'(?:add|with)\s+(?:columns?)?\s*([^,]+(?:,\s*[^,]+)*)', nl)
        if col_match:
            columns = _split_columns(col_match.group(1))

    return table_name, columns
def fix_create_table_sql(generated_sql, table_name, requested_columns):
    """Rebuild a CREATE TABLE statement from the columns the user asked for.

    The model often hallucinates columns, so when ``generated_sql`` is a
    CREATE TABLE for ``table_name`` the column list is replaced by
    ``requested_columns`` with types inferred from their names. Only the first
    id-like column becomes the PRIMARY KEY; later ones are plain INT, because
    a table may have only one PRIMARY KEY column constraint.

    Args:
        generated_sql: Raw SQL produced by the model.
        table_name: Table name extracted from the request.
        requested_columns: Column names extracted from the request.

    Returns:
        The rebuilt ``CREATE TABLE [IF NOT EXISTS] name (...);`` statement, or
        ``generated_sql`` unchanged when it is not a CREATE TABLE for
        ``table_name`` or no usable column names were given.
    """
    import re

    if not table_name or not requested_columns:
        return generated_sql
    if not re.search(r'CREATE\s+TABLE', generated_sql, re.IGNORECASE):
        return generated_sql

    def is_id_column(col_lower):
        # Match "id" as a whole token ("id", "user_id", "id_number") rather
        # than as a substring, which would also catch "paid" or "provider".
        return 'id' in re.split(r'[\W_]+', col_lower)

    def infer_type(col_name):
        # Default data types for common (non-id) column name patterns.
        col_lower = col_name.lower()
        if any(word in col_lower for word in ['name', 'title', 'description', 'address', 'city']):
            return 'VARCHAR(100)'
        elif any(word in col_lower for word in ['email']):
            return 'VARCHAR(100)'
        elif any(word in col_lower for word in ['phone', 'contact', 'mobile']):
            return 'VARCHAR(20)'
        elif any(word in col_lower for word in ['date', 'created', 'updated']):
            return 'DATE'
        elif any(word in col_lower for word in ['price', 'salary', 'amount', 'cost']):
            return 'DECIMAL(10,2)'
        elif any(word in col_lower for word in ['age', 'quantity', 'count', 'stock']):
            return 'INT'
        elif any(word in col_lower for word in ['status', 'type', 'category']):
            return 'VARCHAR(50)'
        return 'VARCHAR(100)'

    col_defs = []
    has_primary_key = False
    for col in requested_columns:
        col_clean = col.strip()
        if not col_clean:
            continue
        if is_id_column(col_clean.lower()):
            col_type = 'INT' if has_primary_key else 'INT PRIMARY KEY'
            has_primary_key = True
        else:
            col_type = infer_type(col_clean)
        col_defs.append(f"{col_clean} {col_type}")

    # Every requested name was blank: "CREATE TABLE x ();" would be invalid.
    if not col_defs:
        return generated_sql

    # Rebuild a clean statement so malformed model output (e.g. extra columns
    # outside the parentheses) cannot leak through. Keep the model's own
    # "CREATE TABLE [IF NOT EXISTS] name" prefix, including its casing.
    if_not_exists_match = re.search(
        r'CREATE\s+TABLE\s+IF\s+NOT\s+EXISTS\s+' + re.escape(table_name),
        generated_sql,
        re.IGNORECASE
    )
    if if_not_exists_match:
        create_clause = if_not_exists_match.group(0)
    else:
        create_match = re.search(
            r'CREATE\s+TABLE\s+' + re.escape(table_name),
            generated_sql,
            re.IGNORECASE
        )
        if not create_match:
            return generated_sql
        create_clause = create_match.group(0)

    return f"{create_clause} ({', '.join(col_defs)});"
def fix_create_collection_mongo(generated_mongo, table_name, requested_columns):
    """Replace a "create collection" result with a deterministic command.

    The model output is discarded. When column names were requested, the
    result is an ``insertOne`` of a sample document whose fields are those
    names (inserting a document implicitly creates the collection).
    Otherwise it is a plain ``createCollection``.

    Args:
        generated_mongo: Raw MongoDB command from the model.
        table_name: Collection name extracted from the request.
        requested_columns: Field names extracted from the request (may be
            empty or None).

    Returns:
        The rebuilt command, or ``generated_mongo`` unchanged when no
        collection name is available.
    """
    import re

    if not table_name:
        return generated_mongo

    doc_fields = []
    for col in requested_columns or ():
        col_clean = col.strip()
        if not col_clean:
            continue
        col_lower = col_clean.lower()
        # Example values are chosen from the field name. "id" is matched as
        # a whole token so names like "provider" are not mistaken for ids.
        if 'id' in re.split(r'[\W_]+', col_lower):
            sample = '1'
        elif any(word in col_lower for word in ['name', 'title']):
            sample = '"sample_name"'
        elif 'email' in col_lower:
            sample = '"user@example.com"'
        elif any(word in col_lower for word in ['phone', 'contact']):
            sample = '"1234567890"'
        else:
            sample = '"sample_value"'
        doc_fields.append(f'"{col_clean}": {sample}')

    if doc_fields:
        return f"db.{table_name}.insertOne({{{', '.join(doc_fields)}}});"
    return f"db.createCollection('{table_name}');"
def detect_comparison_operator(natural_language_query):
    """Detect the comparison operator a natural-language request implies.

    Inclusive phrases ("greater than or equal to", "at least", "minimum" and
    their "<=" counterparts) are checked before the strict ones. Each
    inclusive phrase contains its strict form ("greater than"), so checking
    strict forms first would make ">=" and "<=" unreachable for them.

    Args:
        natural_language_query: The user's request.

    Returns:
        One of '>=', '<=', '>', '<', '=' or None when no comparison wording
        is found.
    """
    import re

    nl = natural_language_query.lower()
    checks = (
        (r'\b((?:greater|more) than or equal to|at least|minimum)\b', '>='),
        (r'\b((?:less|fewer) than or equal to|at most|maximum)\b', '<='),
        (r'\b(greater than|more than|above|exceeds?)\b', '>'),
        (r'\b(less than|fewer than|below|under)\b', '<'),
        (r'\b(equals?|is|=)\b', '='),
    )
    for pattern, operator in checks:
        if re.search(pattern, nl):
            return operator
    return None
def fix_sql_operation_type(generated_sql, natural_language_query):
    """Turn ``SELECT * FROM t [WHERE ...]`` into ``DELETE FROM t [WHERE ...]``.

    Applies only when the request mentions "delete" or "remove" as a whole
    word and the generated SQL starts with ``SELECT * FROM``. Anything after
    the table name other than a WHERE clause is dropped. All other SQL is
    returned untouched.
    """
    import re

    if not re.search(r'\b(delete|remove)\b', natural_language_query.lower()):
        return generated_sql
    if not re.match(r'SELECT\s+\*\s+FROM', generated_sql, re.IGNORECASE):
        return generated_sql

    parts = re.search(r'SELECT\s+\*\s+FROM\s+(\w+)(\s+WHERE\s+.+)?', generated_sql, re.IGNORECASE)
    if parts is None:
        return generated_sql

    table, where = parts.groups()
    return "DELETE FROM " + table + (where or '')
def fix_mongodb_operation_type(generated_mongo, natural_language_query):
    """Force a ``deleteMany`` call when the request asks to delete/remove.

    Every ``.find(``, ``.findOne(``, ``.insertOne(`` or ``.deleteOne(`` call
    (optional whitespace before the parenthesis) is rewritten to
    ``.deleteMany(``. Requests without "delete"/"remove" are left alone.
    """
    import re

    if not re.search(r'\b(delete|remove)\b', natural_language_query.lower()):
        return generated_mongo
    # re.sub leaves the text unchanged when nothing matches, so no separate
    # existence check is needed.
    wrong_call = r'\.(find|findOne|insertOne|deleteOne)\s*\('
    return re.sub(wrong_call, '.deleteMany(', generated_mongo)
def fix_mongodb_missing_braces(generated_mongo):
    """Wrap bare ``field: value`` call arguments in a query object.

    Example: ``db.c.find("salary": 50000)`` -> ``db.c.find({"salary": 50000})``.
    Quoted-field forms are fixed if any are present; otherwise unquoted-field
    forms (``.find(salary: 50000)``) are fixed. Each occurrence is rebuilt
    from its own field and value.

    Args:
        generated_mongo: Raw MongoDB command from the model.

    Returns:
        The command with braces added, or unchanged if nothing matched.
    """
    import re

    quoted_field = re.compile(r'(\.\w+)\(\"(\w+)\":\s*([^)]+)\)')
    bare_field = re.compile(r'(\.\w+)\((\w+):\s*([^)]+)\)')

    def _rebuild(match, quote):
        method, field = match.group(1), match.group(2)
        value = match.group(3).strip().rstrip(';')
        return f'{method}({{{quote}{field}{quote}: {value}}})'

    # A callable replacement builds each rewrite from its own match, and the
    # value is inserted literally. A string template would process
    # backslashes, so regex values like /^J\w+/ would raise re.error.
    if quoted_field.search(generated_mongo):
        return quoted_field.sub(lambda m: _rebuild(m, '"'), generated_mongo)
    return bare_field.sub(lambda m: _rebuild(m, ''), generated_mongo)
def fix_comparison_operator_sql(generated_sql, natural_language_query):
    """Replace ``WHERE column = `` with the comparison the request implies.

    Uses detect_comparison_operator(). Nothing changes when it finds no
    operator or finds plain equality. Only the first column after each WHERE
    keyword is affected.
    """
    import re

    operator = detect_comparison_operator(natural_language_query)
    if not operator or operator == '=':
        return generated_sql

    return re.sub(
        r'(WHERE\s+\w+)\s*=\s*',
        lambda m: f"{m.group(1)} {operator} ",
        generated_sql,
        flags=re.IGNORECASE,
    )
def fix_comparison_operator_mongodb(generated_mongo, natural_language_query):
    """Turn the first ``{field: value}`` filter into ``{field: {$op: value}}``.

    The operator comes from detect_comparison_operator(): '>' -> $gt,
    '<' -> $lt, '>=' -> $gte, '<=' -> $lte. Equality or no detected operator
    leaves the query unchanged. Quoted-field objects (``{"f": v}``) are
    tried before unquoted ones (``{f: v}``); only the first match is
    rewritten.

    Args:
        generated_mongo: Raw MongoDB command from the model.
        natural_language_query: The user's request.

    Returns:
        The rewritten command, or the input unchanged.
    """
    import re

    mongo_op = {
        '>': '$gt',
        '<': '$lt',
        '>=': '$gte',
        '<=': '$lte',
    }.get(detect_comparison_operator(natural_language_query))
    if mongo_op is None:
        return generated_mongo

    candidates = (
        (r'\{"(\w+)":\s*([^,}{]+)\}', '"'),  # {"field": value}
        (r'\{(\w+):\s*([^,}{]+)\}', ''),     # {field: value}
    )
    for pattern, quote in candidates:
        match = re.search(pattern, generated_mongo)
        if not match:
            continue
        field, value = match.group(1), match.group(2).strip()
        replacement = '{' + quote + field + quote + ': {' + mongo_op + ': ' + value + '}}'
        # Splice at the match span rather than passing the replacement as a
        # re.sub template, so backslashes in the value stay literal.
        start, end = match.span()
        return generated_mongo[:start] + replacement + generated_mongo[end:]
    return generated_mongo
def parse_update_query(natural_language_query):
    """Parse an UPDATE request written in plain English.

    Supported forms (case-insensitive):
        "update X set Y to Z where A is|equals|= B"
        "update X set Y = Z where A = B"

    Values wrapped in matching single or double quotes are unquoted, since
    callers add their own quoting. Trailing sentence punctuation is stripped
    from the WHERE value, which ends the sentence.

    Args:
        natural_language_query: The user's request.

    Returns:
        (table, set_column, set_value, where_column, where_value) with the
        original casing preserved, or None if no form matches.
    """
    import re

    patterns = (
        # "update X set Y to Z where A is B"
        r'update\s+(\w+)\s+set\s+(\w+)\s+to\s+([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)\s+(?:is|equals?|=)\s+(.+)',
        # "update X set Y = Z where A = B"
        r'update\s+(\w+)\s+set\s+(\w+)\s*=\s*([^\s]+(?:\s+[^\s]+)*?)\s+where\s+(\w+)\s*=\s*(.+)',
    )

    def _unquote(value):
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            return value[1:-1]
        return value

    for pattern in patterns:
        match = re.search(pattern, natural_language_query, re.IGNORECASE)
        if match:
            table_name, set_column, set_value, where_column, where_value = match.groups()
            set_value = _unquote(set_value.strip())
            where_value = _unquote(where_value.strip().rstrip('.;!?').rstrip())
            return (table_name, set_column, set_value, where_column, where_value)
    return None
def fix_update_query_sql(generated_sql, natural_language_query):
    """Rebuild an UPDATE statement when the model's output is not one.

    If the request mentions "update" but ``generated_sql`` contains no
    ``UPDATE <table> SET``, the statement is rebuilt from
    parse_update_query(). Values that look like SQL numeric literals are
    emitted bare. Everything else becomes a single-quoted string with
    embedded quotes doubled, so "O'Brien" yields 'O''Brien'.

    Returns:
        The rebuilt statement, or ``generated_sql`` unchanged.
    """
    import re

    numeric_literal = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?')

    def _sql_literal(value):
        # Stricter than float(): rejects "inf", "nan" and "1_000", which are
        # not valid SQL numbers.
        if numeric_literal.fullmatch(value):
            return value
        return "'" + value.replace("'", "''") + "'"

    if 'update' not in natural_language_query.lower():
        return generated_sql
    if re.search(r'UPDATE\s+\w+\s+SET', generated_sql, re.IGNORECASE):
        return generated_sql

    parsed = parse_update_query(natural_language_query)
    if not parsed:
        return generated_sql

    table, set_col, set_val, where_col, where_val = parsed
    return (
        f"UPDATE {table} SET {set_col} = {_sql_literal(set_val)} "
        f"WHERE {where_col} = {_sql_literal(where_val)};"
    )
def fix_update_query_mongodb(generated_mongo, natural_language_query):
    """Rebuild an ``updateMany`` call when the model's output is not an update.

    If the request mentions "update" but ``generated_mongo`` has no
    ``.update`` call, the command is rebuilt from parse_update_query().
    Numeric-looking values are emitted bare. Other values are emitted as
    JSON strings, with embedded quotes and backslashes escaped.

    Returns:
        The rebuilt command, or ``generated_mongo`` unchanged.
    """
    import re

    numeric_literal = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?')

    def _mongo_literal(value):
        if numeric_literal.fullmatch(value):
            return value
        # json.dumps quotes and escapes correctly; ensure_ascii=False keeps
        # non-ASCII text readable.
        return json.dumps(value, ensure_ascii=False)

    if 'update' not in natural_language_query.lower():
        return generated_mongo
    if re.search(r'\.update', generated_mongo, re.IGNORECASE):
        return generated_mongo

    parsed = parse_update_query(natural_language_query)
    if not parsed:
        return generated_mongo

    table, set_col, set_val, where_col, where_val = parsed
    return (
        f"db.{table}.updateMany({{{where_col}: {_mongo_literal(where_val)}}}, "
        f"{{$set: {{{set_col}: {_mongo_literal(set_val)}}}}});"
    )
class TexQLModel:
    """Unified seq2seq wrapper that generates SQL or MongoDB queries.

    One checkpoint serves both targets. The target is selected by prefixing
    the input with ``"translate to {target_type}: "``. Raw model output is then
    passed through this module's rule-based post-processing helpers.
    """

    def __init__(self, model_path):
        """Load the tokenizer and model for inference.

        Args:
            model_path: HuggingFace Hub repo id (contains '/') or a local
                directory, passed to ``from_pretrained``.

        Loading errors are shown in the Streamlit UI instead of being raised;
        check ``self.loaded`` before generating.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.loaded = False
        try:
            source = 'HuggingFace Hub' if '/' in model_path else 'local path'
            with st.spinner(f"Loading model from {source}..."):
                self.tokenizer = AutoTokenizer.from_pretrained(model_path)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
                self.model.to(self.device)
                self.model.eval()
                self.loaded = True
            st.success(f"✅ Model loaded successfully on {self.device.upper()}")
        except Exception as e:
            # UI boundary: report the failure and leave self.loaded False so
            # generate_query() short-circuits instead of crashing the app.
            st.error(f"โ Error loading model: {str(e)}")
            if is_hf_space():
                st.info("๐ก Model is loading from HuggingFace Hub - this may take a moment on first run")

    def generate_query(self, natural_language_query, target_type='sql', temperature=0.3,
                       num_beams=10, repetition_penalty=1.2, length_penalty=0.8):
        """Generate a SQL or MongoDB query from natural language.

        Args:
            natural_language_query: The user's natural language query.
            target_type: 'sql' or 'mongodb'. Post-processing runs only for
                these two values.
            temperature: Accepted for API compatibility only. Decoding uses
                deterministic beam/greedy search (``do_sample=False``), where
                temperature has no effect, so it is not forwarded to
                ``generate`` (forwarding it only triggers a warning).
            num_beams: Number of beams for beam search.
            repetition_penalty: >1.0 discourages repeating tokens.
            length_penalty: >1.0 favours longer, <1.0 shorter outputs.

        Returns:
            The post-processed query string, or "Model not loaded" if
            initialisation failed.
        """
        if not self.loaded:
            return "Model not loaded"

        input_text = f"translate to {target_type}: {natural_language_query}"
        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=256,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=512,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                no_repeat_ngram_size=3,  # prevent repeating 3-grams
                early_stopping=True,
                do_sample=False  # greedy/beam search: deterministic output
            )

        generated_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return self._post_process(generated_query, natural_language_query, target_type)

    @staticmethod
    def _post_process(generated_query, natural_language_query, target_type):
        """Apply the rule-based fixes for ``target_type`` in a fixed order."""
        nl_lower = natural_language_query.lower()
        is_sql = target_type == 'sql'
        is_mongo = target_type == 'mongodb'

        # CREATE: replace hallucinated columns with the ones the user named.
        if any(word in nl_lower for word in ['create', 'add columns']):
            table_name, requested_columns = extract_columns_from_nl(natural_language_query)
            if table_name and requested_columns:
                if is_sql:
                    generated_query = fix_create_table_sql(generated_query, table_name, requested_columns)
                elif is_mongo:
                    generated_query = fix_create_collection_mongo(generated_query, table_name, requested_columns)

        # UPDATE: rebuild malformed statements from the parsed request.
        if 'update' in nl_lower and 'set' in nl_lower:
            if is_sql:
                generated_query = fix_update_query_sql(generated_query, natural_language_query)
            elif is_mongo:
                generated_query = fix_update_query_mongodb(generated_query, natural_language_query)

        # Operation type (SELECT vs DELETE), MongoDB braces, then comparison
        # operators (>, <, >=, <=) -- same order as the per-target fixes expect.
        if is_sql:
            generated_query = fix_sql_operation_type(generated_query, natural_language_query)
            generated_query = fix_comparison_operator_sql(generated_query, natural_language_query)
        elif is_mongo:
            generated_query = fix_mongodb_operation_type(generated_query, natural_language_query)
            generated_query = fix_mongodb_missing_braces(generated_query)
            generated_query = fix_comparison_operator_mongodb(generated_query, natural_language_query)

        return generated_query
@st.cache_resource
def load_model(model_path):
    """Load the unified NoNoQL model (cached by ``st.cache_resource``).

    Args:
        model_path: HuggingFace Hub repo id ("user/repo") or a local path.

    Returns:
        A TexQLModel (check ``.loaded`` for success), or None when
        ``model_path`` has no '/' and does not exist on disk.
    """
    # Values containing '/' may be Hub repo ids, whose existence cannot be
    # checked locally, so they are always attempted. Anything else must be an
    # existing local path. The old code's two branches covered every case,
    # which left the "not found" message unreachable.
    if '/' in model_path or os.path.exists(model_path):
        return TexQLModel(model_path)
    st.error(f"โ Model path not found: {model_path}")
    return None
def save_query_history(nl_query, sql_query, mongodb_query, max_history=500):
    """Append a generated query pair to the session history and persist it.

    Args:
        nl_query: The user's natural-language request.
        sql_query: The generated SQL.
        mongodb_query: The generated MongoDB query.
        max_history: Maximum number of entries kept; the oldest are dropped
            first. A value <= 0 leaves the history empty.
    """
    if 'history' not in st.session_state:
        st.session_state.history = []

    history = st.session_state.history
    history.append({
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'natural_language': nl_query,
        'sql': sql_query,
        'mongodb': mongodb_query
    })

    if max_history <= 0:
        # history[-0:] is the whole list, so a zero limit needs its own case.
        history = []
    elif len(history) > max_history:
        history = history[-max_history:]

    st.session_state.history = history
    persist_query_history(history)
def delete_history_entry(index):
    """Remove the history entry at ``index`` and persist the shortened list.

    Out-of-range (including negative) indices and a missing history are
    ignored.
    """
    if 'history' not in st.session_state:
        return
    entries = st.session_state.history
    if not 0 <= index < len(entries):
        return
    entries.pop(index)
    persist_query_history(entries)
def load_query_history(path=None):
    """Load query history from disk.

    Args:
        path: JSON file to read. Defaults to HISTORY_FILE_PATH.

    Returns:
        A list of history entries. It is empty when the file is missing,
        unreadable, not valid JSON, or not a JSON list. Entries that are not
        dicts with a 'timestamp' key and a string 'natural_language' value
        are dropped, because the history UI indexes those keys directly.
    """
    if path is None:
        path = HISTORY_FILE_PATH
    try:
        with open(path, "r", encoding="utf-8") as history_file:
            history = json.load(history_file)
    except (OSError, ValueError):
        # Missing/unreadable file or corrupt JSON (JSONDecodeError and
        # UnicodeDecodeError are ValueErrors): start with an empty history.
        return []

    if not isinstance(history, list):
        return []
    return [
        entry for entry in history
        if isinstance(entry, dict)
        and 'timestamp' in entry
        and isinstance(entry.get('natural_language'), str)
    ]
def persist_query_history(history, path=None):
    """Persist query history to disk as indented JSON (best effort).

    The JSON is serialised before the file is touched, then written to a
    temporary file in the same directory and moved into place with
    ``os.replace``. A failure therefore never leaves a truncated history
    file, which load_query_history would discard entirely.

    Args:
        history: List of history entries (JSON-serialisable dicts).
        path: Destination file. Defaults to HISTORY_FILE_PATH.
    """
    import tempfile

    if path is None:
        path = HISTORY_FILE_PATH
    try:
        payload = json.dumps(history, indent=2)
        directory = os.path.dirname(path) or "."
        os.makedirs(directory, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".history-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as tmp_file:
                tmp_file.write(payload)
            os.replace(tmp_path, path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    except Exception:
        pass  # Silently fail on HF Spaces (read-only filesystem)
def load_schema():
    """Return the saved schema text, falling back to DEFAULT_SCHEMA.

    The fallback is used when the schema file does not exist, contains only
    whitespace, or cannot be read for any reason.
    """
    try:
        if os.path.exists(SCHEMA_FILE_PATH):
            with open(SCHEMA_FILE_PATH, "r", encoding="utf-8") as fh:
                content = fh.read()
            if content.strip():
                return content
    except Exception:
        pass
    return DEFAULT_SCHEMA
def persist_schema(schema):
    """Write ``schema`` to SCHEMA_FILE_PATH, creating its folder if needed.

    Best effort: any error is ignored, e.g. on HF Spaces where the filesystem
    may be read-only.
    """
    try:
        schema_dir = os.path.dirname(SCHEMA_FILE_PATH)
        os.makedirs(schema_dir, exist_ok=True)
        with open(SCHEMA_FILE_PATH, "w", encoding="utf-8") as out:
            out.write(schema)
    except Exception:
        return
def _init_session_state():
    """Seed per-session state (history, schema, edit flag) on first use."""
    if 'history' not in st.session_state:
        st.session_state.history = load_query_history()
    if 'schema' not in st.session_state:
        st.session_state.schema = load_schema()
    if 'schema_edit_mode' not in st.session_state:
        st.session_state.schema_edit_mode = False


def _render_schema_panel():
    """Render the sidebar schema section: view/edit toggle, editor, save/reset."""
    st.subheader("๐ Database Schema")

    col1, col2 = st.columns([1, 3])
    with col1:
        if st.button("โ๏ธ Edit" if not st.session_state.schema_edit_mode else "๐๏ธ View"):
            st.session_state.schema_edit_mode = not st.session_state.schema_edit_mode
            st.rerun()
    with col2:
        if st.session_state.schema_edit_mode:
            st.info("โ๏ธ Editing Mode")
        else:
            st.caption("View your database tables and columns")

    if not st.session_state.schema_edit_mode:
        with st.expander("View Available Tables", expanded=False):
            st.markdown(st.session_state.schema)
        return

    edited_schema = st.text_area(
        "Edit Database Schema",
        value=st.session_state.schema,
        height=300,
        help="Define your database tables and columns. Use Markdown format."
    )
    col1, col2 = st.columns(2)
    with col1:
        if st.button("๐พ Save Schema", use_container_width=True):
            st.session_state.schema = edited_schema
            persist_schema(edited_schema)
            if is_hf_space():
                st.warning("โ ๏ธ Schema saved to session only (HF Spaces has read-only filesystem)")
            else:
                st.success("Schema saved!")
            st.session_state.schema_edit_mode = False
            st.rerun()
    with col2:
        if st.button("๐ Reset to Default", use_container_width=True):
            st.session_state.schema = DEFAULT_SCHEMA
            persist_schema(DEFAULT_SCHEMA)
            st.success("Schema reset to default!")
            st.rerun()


def _render_model_status(model):
    """Show whether the model loaded and on which device."""
    if model and model.loaded:
        device_info = "๐ฎ GPU" if model.device == "cuda" else "๐ป CPU"
        st.success(f"✅ Model Loaded ({device_info})")
        st.info("๐ก This model generates both SQL and MongoDB queries")
    else:
        st.error("โ ๏ธ Model Not Available - Please check the model path")


def _render_sidebar():
    """Render the sidebar controls and load the model.

    Returns:
        (model, generation_params, max_history_size). generation_params maps
        TexQLModel.generate_query keyword names to the slider values.
    """
    with st.sidebar:
        st.header("โ๏ธ Configuration")

        st.subheader("Model Path")
        model_path = st.text_input(
            "NoNoQL Model Path",
            value=DEFAULT_MODEL_PATH,
            help="HuggingFace repo (user/repo) or local path"
        )
        if '/' in model_path:
            st.caption(f"๐ฅ Loading from HuggingFace: [{model_path}](https://huggingface.co/{model_path})")
        else:
            st.caption(f"๐ Loading from local path: {model_path}")

        st.subheader("Generation Parameters")
        # NOTE(review): generate_query decodes with do_sample=False, so the
        # temperature slider does not change the output.
        temperature = st.slider(
            "Temperature",
            min_value=0.1,
            max_value=1.0,
            value=0.3,  # lower default = less hallucination
            step=0.1,
            help="Lower = more focused, Higher = more creative"
        )
        num_beams = st.slider(
            "Beam Search Width",
            min_value=1,
            max_value=10,
            value=10,  # wider beam = more accurate results
            help="Higher values improve accuracy (recommended: keep at 10)"
        )
        repetition_penalty = st.slider(
            "Repetition Penalty",
            min_value=1.0,
            max_value=2.0,
            value=1.2,  # discourages adding extra unwanted columns
            step=0.1,
            help="Higher = less repetition (prevents hallucinating extra columns)"
        )
        length_penalty = st.slider(
            "Length Penalty",
            min_value=0.5,
            max_value=1.5,
            value=0.8,  # prefer shorter outputs
            step=0.1,
            help="Lower = prefer shorter outputs, Higher = prefer longer outputs"
        )

        if st.button("๐ Load/Reload Models"):
            st.cache_resource.clear()
            st.rerun()

        st.subheader("๐ History Settings")
        max_history_size = st.number_input(
            "Max History Entries",
            min_value=10,
            max_value=1000,
            value=500,
            step=10,
            help="Maximum number of queries to keep in history"
        )

        _render_schema_panel()

        with st.spinner("Loading model..."):
            model = load_model(model_path)
        _render_model_status(model)

    generation_params = {
        'temperature': temperature,
        'num_beams': num_beams,
        'repetition_penalty': repetition_penalty,
        'length_penalty': length_penalty,
    }
    return model, generation_params, max_history_size


def _render_query_input():
    """Render the example picker and the query text area; return the text."""
    st.subheader("๐ค Enter Your Query")

    with st.expander("๐ก Example Queries - Click to expand"):
        examples = [
            "Show all employees",
            "Find employees where salary is greater than 50000",
            "Get all departments with budget more than 100000",
            "Insert a new employee with name John Doe, email john@example.com, department Engineering",
            "Update employees set department to Sales where employee_id is 101",
            "Delete orders with total_amount less than 1000",
            "Count all products in Electronics category",
            "Show top 10 employees ordered by salary",
        ]
        selected_example = st.selectbox(
            "Choose an example query:",
            [""] + examples,
            index=0,
            format_func=lambda x: "Select an example..." if x == "" else x
        )
        if selected_example and st.button("๐ Use This Example", use_container_width=True):
            st.session_state.user_query = selected_example
            st.rerun()

    return st.text_area(
        "or",
        value=st.session_state.get('user_query', ''),
        height=100,
        placeholder="write your query here..."
    )


def _render_results(result, just_generated):
    """Render the most recent generated SQL/MongoDB pair.

    Args:
        result: dict with 'natural_language', 'sql' and 'mongodb' strings.
        just_generated: True only on the run that produced ``result``; the
            success banner is shown then and not on later reruns.
    """
    st.markdown("---")
    if just_generated:
        st.success("✅ Queries Generated Successfully!")

    st.markdown('<div class="query-box">', unsafe_allow_html=True)
    st.markdown("**๐ Your Query:**")
    st.code(result['natural_language'], language="text")
    st.markdown('</div>', unsafe_allow_html=True)

    # st.code renders its own copy-to-clipboard button, so no separate copy
    # widgets are needed here.
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("### ๐๏ธ SQL Query")
        st.code(result['sql'], language="sql")
    with col2:
        st.markdown("### ๐ MongoDB Query")
        st.code(result['mongodb'], language="javascript")


def _history_entry_matches(entry, needle):
    """Return True if lower-case ``needle`` occurs in any text field of ``entry``."""
    return any(
        needle in str(entry.get(field) or '').lower()
        for field in ('natural_language', 'sql', 'mongodb')
    )


def _render_history():
    """Render search/sort/limit controls and the saved query history."""
    history = st.session_state.history if 'history' in st.session_state else []
    if not history:
        return

    st.markdown("---")
    st.subheader("๐ Query History")

    col1, col2, col3 = st.columns([2, 1, 1])
    with col1:
        search_term = st.text_input(
            "๐ Search History",
            placeholder="Search in queries...",
            label_visibility="collapsed"
        )
    with col2:
        sort_order = st.selectbox(
            "Sort",
            ["Newest First", "Oldest First"],
            label_visibility="collapsed"
        )
    with col3:
        show_limit = st.number_input(
            "Show",
            min_value=5,
            max_value=100,
            value=10,
            step=5,
            label_visibility="collapsed"
        )

    col1, col2 = st.columns(2)
    with col1:
        if st.button("๐๏ธ Clear All History"):
            st.session_state.history = []
            persist_query_history(st.session_state.history)
            st.rerun()
    with col2:
        # Offer the download directly; nested under an st.button it needed
        # an extra click and rerun before the file could be saved.
        st.download_button(
            label="๐พ Export History",
            data=json.dumps(history, indent=2),
            file_name=f"nonoql_history_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
            mime="application/json"
        )

    # Carry each entry's position in the full history. Rerun/Delete keys then
    # stay unique, and deletion hits the right entry even when two entries are
    # identical (list.index() would return the first match for both).
    matches = list(enumerate(history))
    if search_term:
        needle = search_term.lower()
        matches = [(idx, entry) for idx, entry in matches if _history_entry_matches(entry, needle)]

    if sort_order == "Oldest First":
        display_history = matches[:show_limit]
    else:
        display_history = list(reversed(matches[-show_limit:]))

    st.markdown(f"**Showing {len(display_history)} of {len(matches)} queries** (Total: {len(history)})")
    if not display_history:
        st.info("No queries found matching your search.")

    for actual_idx, entry in display_history:
        with st.expander(
            f"๐ {entry['timestamp']} - {entry['natural_language'][:60]}...",
            expanded=False
        ):
            col1, col2, col3 = st.columns([3, 1, 1])
            with col1:
                st.markdown("**Natural Language Query:**")
                st.info(entry['natural_language'])
            with col2:
                if st.button("๐ Rerun", key=f"rerun_{actual_idx}"):
                    st.session_state.user_query = entry['natural_language']
                    st.rerun()
            with col3:
                if st.button("๐๏ธ Delete", key=f"del_{actual_idx}"):
                    delete_history_entry(actual_idx)
                    st.rerun()

            col1, col2 = st.columns(2)
            with col1:
                st.markdown("**SQL Query:**")
                if entry.get('sql'):
                    st.code(entry['sql'], language="sql")
                else:
                    st.text("N/A")
            with col2:
                st.markdown("**MongoDB Query:**")
                if entry.get('mongodb'):
                    st.code(entry['mongodb'], language="javascript")
                else:
                    st.text("N/A")


def _render_footer():
    """Render the static page footer."""
    st.markdown("---")
    st.markdown("""
<div style='text-align: center; color: #666; padding: 2rem;'>
<p>NoNoQL - Natural Language to Query Generator</p>
<p>Powered by T5 Transformer Models | Built with Streamlit</p>
</div>
""", unsafe_allow_html=True)


def main():
    """Streamlit entry point: sidebar, query form, results, history, footer."""
    if is_hf_space():
        st.info("๐ค Running on HuggingFace Spaces - Model loaded from Hub")

    _init_session_state()
    model, generation_params, max_history_size = _render_sidebar()
    user_query = _render_query_input()

    just_generated = False
    if st.button("๐ Generate Queries"):
        if not user_query.strip():
            st.warning("Please enter a query")
        elif not model or not model.loaded:
            st.error("Model is not loaded. Please check the model path and reload.")
        else:
            with st.spinner("Generating queries..."):
                # One model produces both targets; only the prefix differs.
                sql_query = model.generate_query(user_query, target_type='sql', **generation_params)
                mongodb_query = model.generate_query(user_query, target_type='mongodb', **generation_params)
                save_query_history(user_query, sql_query, mongodb_query, max_history_size)
            # Keep the result in session state so later reruns (history
            # search, sidebar changes) do not make it disappear.
            st.session_state.last_result = {
                'natural_language': user_query,
                'sql': sql_query,
                'mongodb': mongodb_query,
            }
            just_generated = True

    last_result = st.session_state.get('last_result')
    if last_result:
        _render_results(last_result, just_generated)

    _render_history()
    _render_footer()
# Entry point when this file is executed as the main module.
if __name__ == "__main__":
    main()