orachamp1981 committed on
Commit
79367ca
·
verified ·
1 Parent(s): 1164f42

Upload 6 files

Browse files
Files changed (2) hide show
  1. model.py +28 -19
  2. sql_templates.py +31 -0
model.py CHANGED
@@ -1,6 +1,6 @@
1
  # model.py
2
  from sentence_transformers import SentenceTransformer, util
3
- from sql_templates import sql_templates # new import
4
  import torch
5
 
6
  # Load training rules (string-to-SQL map)
@@ -16,35 +16,44 @@ train_embeddings = model.encode(train_prompts, convert_to_tensor=True)
16
 
17
  def oracle_sql_suggester(prompt):
18
  prompt_clean = prompt.strip().lower()
19
-
20
  # Try direct rule match
21
  if prompt_clean in rules:
22
  return rules[prompt_clean]
 
 
 
 
 
 
23
 
24
- # Semantic matching
25
- user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
26
- cosine_scores = util.cos_sim(user_embedding, train_embeddings)
27
-
28
- top_match_index = torch.argmax(cosine_scores).item()
29
- top_score = cosine_scores[0][top_match_index].item()
 
 
 
 
 
 
 
 
30
 
31
- if top_score >= 0.7:
32
- matched_prompt = train_prompts[top_match_index]
33
- return rules[matched_prompt]
34
 
35
- # Check template keywords
36
- for key in sql_templates:
37
- if key.replace("_", " ") in prompt_clean or key in prompt_clean:
38
- return sql_templates[key]
39
-
40
- # Semantic match
41
  user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
42
  cosine_scores = util.cos_sim(user_embedding, train_embeddings)
 
43
  top_match_index = torch.argmax(cosine_scores).item()
44
  top_score = cosine_scores[0][top_match_index].item()
45
 
46
  if top_score >= 0.7:
47
  matched_prompt = train_prompts[top_match_index]
48
  return rules[matched_prompt]
49
-
50
- return "🤖 Sorry, I couldn’t understand that. Please try rephrasing your request."
 
 
 
1
  # model.py
2
  from sentence_transformers import SentenceTransformer, util
3
+ from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases
4
  import torch
5
 
6
  # Load training rules (string-to-SQL map)
 
16
 
17
  def oracle_sql_suggester(prompt):
18
  prompt_clean = prompt.strip().lower()
19
+
20
  # Try direct rule match
21
  if prompt_clean in rules:
22
  return rules[prompt_clean]
23
+
24
+ # Check template keywords (first!)
25
+ for word in prompt_clean.split():
26
+ if word in sql_keyword_aliases:
27
+ mapped_key = sql_keyword_aliases[word]
28
+ return sql_templates[mapped_key]
29
 
30
+ for key, template in sql_templates.items():
31
+ if key.replace("_", " ") in prompt_clean or key in prompt_clean:
32
+ return template
33
+ # Try keyword aliases (word-level match)
34
+ for word in prompt_clean.split():
35
+ if word in sql_keyword_aliases:
36
+ return sql_templates[sql_keyword_aliases[word]]
37
+
38
+ # Try fuzzy alias matches
39
+ for fuzzy_phrase, mapped_key in fuzzy_aliases.items():
40
+ if fuzzy_phrase in prompt_clean:
41
+ return sql_templates[mapped_key]
42
+
43
+
44
 
 
 
 
45
 
46
+ # Semantic matching
 
 
 
 
 
47
  user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
48
  cosine_scores = util.cos_sim(user_embedding, train_embeddings)
49
+
50
  top_match_index = torch.argmax(cosine_scores).item()
51
  top_score = cosine_scores[0][top_match_index].item()
52
 
53
  if top_score >= 0.7:
54
  matched_prompt = train_prompts[top_match_index]
55
  return rules[matched_prompt]
56
+
57
+ # Fallback
58
+ return "🤖 Sorry, I couldn’t understand that. Please try rephrasing your request."
59
+
sql_templates.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  sql_templates = {
2
  "basic_select": "SELECT column1, column2 FROM table_name;",
3
  "select_where": "SELECT column1 FROM table_name WHERE condition;",
@@ -8,3 +10,32 @@ sql_templates = {
8
  "update": "UPDATE table_name SET col1 = val1 WHERE condition;",
9
  "delete": "DELETE FROM table_name WHERE condition;"
10
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+
3
  sql_templates = {
4
  "basic_select": "SELECT column1, column2 FROM table_name;",
5
  "select_where": "SELECT column1 FROM table_name WHERE condition;",
 
10
  "update": "UPDATE table_name SET col1 = val1 WHERE condition;",
11
  "delete": "DELETE FROM table_name WHERE condition;"
12
  }
13
+
14
# Maps a keyword to the name of a template in ``sql_templates``.
# NOTE(review): the "group by" key contains a space, so a consumer that
# compares whitespace-split tokens against these keys will never hit it.
sql_keyword_aliases = dict(
    (
        ("select", "basic_select"),
        ("where", "select_where"),
        ("join", "join_example"),
        ("group", "group_by"),
        ("group by", "group_by"),
        ("having", "having"),
        ("insert", "insert"),
        ("update", "update"),
        ("delete", "delete"),
    )
)
25
+
26
+ # 🧠 NEW fuzzy aliases
27
# Maps a natural-language phrase to the name of a template in
# ``sql_templates``. Entry order is kept as-is: a consumer that scans the
# phrases in dict order and stops at the first hit depends on it.
fuzzy_aliases = dict(
    (
        ("grouped result", "group_by"),
        ("combine tables", "join_example"),
        ("add new row", "insert"),
        ("modify records", "update"),
        ("remove entry", "delete"),
        ("get rows", "basic_select"),
        ("filter records", "select_where"),
        ("apply condition", "select_where"),
        ("summarize", "group_by"),
        ("count groups", "group_by"),
        ("condition on groups", "having"),
        ("change row", "update"),
        ("erase record", "delete"),
    )
)