Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- model.py +28 -19
- sql_templates.py +31 -0
model.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# model.py
|
| 2 |
from sentence_transformers import SentenceTransformer, util
|
| 3 |
-
from sql_templates import sql_templates
|
| 4 |
import torch
|
| 5 |
|
| 6 |
# Load training rules (string-to-SQL map)
|
|
@@ -16,35 +16,44 @@ train_embeddings = model.encode(train_prompts, convert_to_tensor=True)
|
|
| 16 |
|
| 17 |
def oracle_sql_suggester(prompt):
|
| 18 |
prompt_clean = prompt.strip().lower()
|
| 19 |
-
|
| 20 |
# Try direct rule match
|
| 21 |
if prompt_clean in rules:
|
| 22 |
return rules[prompt_clean]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
if top_score >= 0.7:
|
| 32 |
-
matched_prompt = train_prompts[top_match_index]
|
| 33 |
-
return rules[matched_prompt]
|
| 34 |
|
| 35 |
-
|
| 36 |
-
for key in sql_templates:
|
| 37 |
-
if key.replace("_", " ") in prompt_clean or key in prompt_clean:
|
| 38 |
-
return sql_templates[key]
|
| 39 |
-
|
| 40 |
-
# Semantic match
|
| 41 |
user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
|
| 42 |
cosine_scores = util.cos_sim(user_embedding, train_embeddings)
|
|
|
|
| 43 |
top_match_index = torch.argmax(cosine_scores).item()
|
| 44 |
top_score = cosine_scores[0][top_match_index].item()
|
| 45 |
|
| 46 |
if top_score >= 0.7:
|
| 47 |
matched_prompt = train_prompts[top_match_index]
|
| 48 |
return rules[matched_prompt]
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
| 1 |
# model.py
|
| 2 |
from sentence_transformers import SentenceTransformer, util
|
| 3 |
+
from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases
|
| 4 |
import torch
|
| 5 |
|
| 6 |
# Load training rules (string-to-SQL map)
|
|
|
|
| 16 |
|
| 17 |
def oracle_sql_suggester(prompt):
    """Suggest an SQL statement or template for a natural-language prompt.

    The prompt is stripped and lower-cased, then resolved by the first
    strategy that produces a result:

    1. Exact match of the normalized prompt in ``rules``.
    2. Any single word of the prompt (surrounding punctuation ignored)
       found in ``sql_keyword_aliases``.
    3. A ``sql_templates`` key contained in the prompt, either verbatim or
       with underscores replaced by spaces.
    4. A ``fuzzy_aliases`` phrase contained in the prompt.
    5. The training prompt with the highest cosine similarity to the
       prompt embedding, if that similarity is at least 0.7.
    6. A fixed apology message.

    Args:
        prompt: Free-text request from the user. ``None`` is treated like
            an empty string.

    Returns:
        The matched SQL string, or the apology message when nothing matches.
    """
    fallback = "🤖 Sorry, I couldn’t understand that. Please try rephrasing your request."
    similarity_threshold = 0.7

    # Treat a missing prompt like an empty one instead of raising on None.
    prompt_clean = (prompt or "").strip().lower()

    # 1. Direct rule match.
    if prompt_clean in rules:
        return rules[prompt_clean]

    # Nothing left to match on; skip the embedding model for empty input.
    if not prompt_clean:
        return fallback

    # 2. Word-level keyword aliases. Punctuation glued to a word ("join.",
    #    "delete?") is stripped so it does not hide the keyword. An alias
    #    whose target is absent from sql_templates is skipped rather than
    #    raising KeyError.
    for word in prompt_clean.split():
        keyword = word.strip(".,;:!?\"'()")
        template = sql_templates.get(sql_keyword_aliases.get(keyword))
        if template is not None:
            return template

    # 3. Template names mentioned in the prompt ("group_by" or "group by").
    for key, template in sql_templates.items():
        if key.replace("_", " ") in prompt_clean or key in prompt_clean:
            return template

    # 4. Multi-word fuzzy phrases (substring match on the whole prompt).
    for fuzzy_phrase, mapped_key in fuzzy_aliases.items():
        if fuzzy_phrase in prompt_clean and mapped_key in sql_templates:
            return sql_templates[mapped_key]

    # 5. Semantic matching against the pre-computed training embeddings.
    user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
    cosine_scores = util.cos_sim(user_embedding, train_embeddings)

    top_match_index = torch.argmax(cosine_scores).item()
    top_score = cosine_scores[0][top_match_index].item()

    if top_score >= similarity_threshold:
        matched_prompt = train_prompts[top_match_index]
        return rules[matched_prompt]

    # 6. Fallback.
    return fallback
|
| 59 |
+
|
sql_templates.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
sql_templates = {
|
| 2 |
"basic_select": "SELECT column1, column2 FROM table_name;",
|
| 3 |
"select_where": "SELECT column1 FROM table_name WHERE condition;",
|
|
@@ -8,3 +10,32 @@ sql_templates = {
|
|
| 8 |
"update": "UPDATE table_name SET col1 = val1 WHERE condition;",
|
| 9 |
"delete": "DELETE FROM table_name WHERE condition;"
|
| 10 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
|
| 3 |
sql_templates = {
|
| 4 |
"basic_select": "SELECT column1, column2 FROM table_name;",
|
| 5 |
"select_where": "SELECT column1 FROM table_name WHERE condition;",
|
|
|
|
| 10 |
"update": "UPDATE table_name SET col1 = val1 WHERE condition;",
|
| 11 |
"delete": "DELETE FROM table_name WHERE condition;"
|
| 12 |
}
|
| 13 |
+
|
| 14 |
+
# Keyword -> name of an entry in ``sql_templates``.
# NOTE(review): "group by" contains a space, so a lookup that splits the
# prompt on whitespace and tests one word at a time can never hit it; the
# single-word "group" entry covers that case. Kept for any other consumers.
sql_keyword_aliases = {
    "select": "basic_select",
    "where": "select_where",
    "join": "join_example",
    "group": "group_by",
    "group by": "group_by",
    "having": "having",
    "insert": "insert",
    "update": "update",
    "delete": "delete",
}
|
| 25 |
+
|
| 26 |
+
# 🧠 NEW fuzzy aliases
|
| 27 |
+
# Free-form phrase -> name of an entry in ``sql_templates``.
# Insertion order matters to any consumer that scans the phrases in order
# and returns on the first one found in the prompt.
fuzzy_aliases = {
    "grouped result": "group_by",
    "combine tables": "join_example",
    "add new row": "insert",
    "modify records": "update",
    "remove entry": "delete",
    "get rows": "basic_select",
    "filter records": "select_where",
    "apply condition": "select_where",
    "summarize": "group_by",
    "count groups": "group_by",
    "condition on groups": "having",
    "change row": "update",
    "erase record": "delete",
}
|