orachamp1981 committed on
Commit
53d89cd
Β·
verified Β·
1 Parent(s): 9762962

Upload 4 files

Browse files
Files changed (3) hide show
  1. data_loader.py +27 -5
  2. model.py +34 -28
  3. sql_templates.py +2 -2
data_loader.py CHANGED
@@ -1,10 +1,32 @@
1
  # data_loader.py
2
 
 
 
3
  def load_rules(file_path="data/train_data.txt"):
4
  data = {}
5
- with open(file_path, "r", encoding="utf-8") as file:
6
- for line in file:
7
- if "=" in line:
8
- key, value = line.strip().split("=", 1)
9
- data[key.strip().lower()] = value.strip()
 
10
  return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # data_loader.py
2
 
3
+ import os
4
+
5
def load_rules(file_path="data/train_data.txt"):
    """Load prompt -> response rules from a ``key=value`` text file.

    Each line containing ``=`` is split on the first ``=``; the key is
    lowercased and stripped so lookups are case-insensitive. Lines
    without ``=`` are ignored. A missing file yields an empty dict
    instead of raising.

    Args:
        file_path: Path to the rules file (UTF-8 text).

    Returns:
        dict mapping lowercased prompt keys to response strings.
    """
    data = {}
    try:
        # EAFP: open directly instead of the original os.path.exists()
        # pre-check — avoids the check-then-open race while keeping the
        # same observable behavior (missing file -> empty rule set).
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                if "=" in line:
                    key, value = line.strip().split("=", 1)
                    data[key.strip().lower()] = value.strip()
    except FileNotFoundError:
        pass  # no rules file -> return empty dict, matching prior behavior
    return data
14
+
15
def detect_domain(prompt):
    """Return the rules file for the first domain whose keywords appear
    in *prompt* (case-insensitive substring match), or None if no
    domain matches. Domains are checked in a fixed priority order:
    finance, then HR, then sales."""
    text = prompt.lower()
    # Ordered (file, keywords) table — order preserves the original
    # finance -> hr -> sales precedence.
    domain_keywords = (
        ("data/finance.txt", ("salary", "financial", "transaction", "ledger")),
        ("data/hr.txt", ("employee", "hr", "hiring")),
        ("data/sales.txt", ("sale", "customer", "order")),
    )
    for rules_file, keywords in domain_keywords:
        if any(keyword in text for keyword in keywords):
            return rules_file
    return None
25
+
26
def load_rules_by_domain(prompt):
    """Look up *prompt* as an exact key in its detected domain's rules
    file. Returns the matching response string, or None when no domain
    matches, the domain file is absent, or the prompt has no entry
    (fallback is handled in the main logic)."""
    rules_path = detect_domain(prompt)
    if rules_path is None or not os.path.exists(rules_path):
        return None
    domain_rules = load_rules(rules_path)
    return domain_rules.get(prompt)  # None when there is no exact match
model.py CHANGED
@@ -1,68 +1,75 @@
1
  from sentence_transformers import SentenceTransformer, util
2
  from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases, conflicting_phrases, greeting_templates
3
- from data_loader import load_rules
4
  import torch
5
-
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
7
 
8
- # πŸ“˜ Load rules
9
- rules = load_rules()
10
-
11
  # πŸ” Load semantic model
12
  model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
13
- train_prompts = list(rules.keys())
14
- train_embeddings = model.encode(train_prompts, convert_to_tensor=True)
15
 
16
  # πŸ€– Load local LLM model
17
  llm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
18
  tokenizer = AutoTokenizer.from_pretrained(llm_name)
19
- llm_model = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
20
- llm_pipeline = pipeline("text-generation", model=llm_model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def oracle_sql_suggester(prompt):
23
  prompt_clean = prompt.strip().lower()
24
 
25
- # βœ… Exact rule
26
- if prompt_clean in rules:
27
- return rules[prompt_clean]
 
 
28
 
29
- # βœ… Greeting handling
30
  for greet_key, greet_reply in greeting_templates.items():
31
  if greet_key in prompt_clean:
32
  return greet_reply
33
-
34
- # βœ… Conflicting phrase
35
  for terms, response in conflicting_phrases.items():
36
  if all(term in prompt_clean for term in terms):
37
  return response
38
 
39
- # βœ… Keyword alias
40
  for word in prompt_clean.split():
41
  if word in sql_keyword_aliases:
42
  mapped_key = sql_keyword_aliases[word]
43
  return sql_templates.get(mapped_key)
44
 
45
- # βœ… Template match
46
  for key, template in sql_templates.items():
47
  if key in prompt_clean or key.replace("_", " ") in prompt_clean:
48
  return template
49
 
50
- # βœ… Fuzzy match
51
  for fuzzy_phrase, mapped_key in fuzzy_aliases.items():
52
  if fuzzy_phrase in prompt_clean:
53
  return sql_templates.get(mapped_key)
54
 
55
- # βœ… Semantic match
56
- user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
57
- cosine_scores = util.cos_sim(user_embedding, train_embeddings)
58
- top_match_index = torch.argmax(cosine_scores).item()
59
- top_score = cosine_scores[0][top_match_index].item()
 
60
 
61
- if top_score >= 0.7:
62
- matched_prompt = train_prompts[top_match_index]
63
- return rules[matched_prompt]
64
 
65
- # βœ… Local LLM fallback
66
  try:
67
  prompt_text = f"Generate an Oracle SQL query or guidance for the following request:\n{prompt}\n\nSQL:"
68
  output = llm_pipeline(prompt_text, max_new_tokens=256, do_sample=True, temperature=0.5)[0]["generated_text"]
@@ -70,4 +77,3 @@ def oracle_sql_suggester(prompt):
70
  except Exception as e:
71
  print("⚠️ Local LLM error:", e)
72
  return "πŸ€– Sorry, I couldn’t process that locally. Please try a simpler prompt."
73
-
 
from sentence_transformers import SentenceTransformer, util
from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases, conflicting_phrases, greeting_templates
from data_loader import load_rules, load_rules_by_domain, detect_domain
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Sentence-embedding model used for semantic similarity matching of prompts.
model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")

# Local causal LLM used as the last-resort fallback generator.
llm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name,
    # Half precision only when a GPU is available; full precision on CPU.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
llm_pipeline = pipeline(
    "text-generation",
    model=llm_model,
    tokenizer=tokenizer,
    # device 0 = first CUDA device; -1 = CPU.
    device=0 if torch.cuda.is_available() else -1
)

# Load global training rules once at import time for semantic matching.
# NOTE(review): this work (model downloads + encoding) runs on import —
# consider lazy initialization if import-time cost matters.
global_rules = load_rules("data/train_data.txt")
train_prompts = list(global_rules.keys())
# None when there are no training prompts, so callers must guard before use.
train_embeddings = model.encode(train_prompts, convert_to_tensor=True) if train_prompts else None
29
  def oracle_sql_suggester(prompt):
30
  prompt_clean = prompt.strip().lower()
31
 
32
+ # βœ… Step 1: Exact match in domain-specific rules
33
+ domain_match = load_rules_by_domain(prompt_clean)
34
+ if domain_match:
35
+ #return domain_match
36
+ return domain_match.replace("\\n", "\n")
37
 
38
+ # βœ… Step 2: Check hardcoded greeting or conflict response
39
  for greet_key, greet_reply in greeting_templates.items():
40
  if greet_key in prompt_clean:
41
  return greet_reply
42
+
 
43
  for terms, response in conflicting_phrases.items():
44
  if all(term in prompt_clean for term in terms):
45
  return response
46
 
47
+ # βœ… Step 3: Aliases and fuzzy matching
48
  for word in prompt_clean.split():
49
  if word in sql_keyword_aliases:
50
  mapped_key = sql_keyword_aliases[word]
51
  return sql_templates.get(mapped_key)
52
 
 
53
  for key, template in sql_templates.items():
54
  if key in prompt_clean or key.replace("_", " ") in prompt_clean:
55
  return template
56
 
 
57
  for fuzzy_phrase, mapped_key in fuzzy_aliases.items():
58
  if fuzzy_phrase in prompt_clean:
59
  return sql_templates.get(mapped_key)
60
 
61
+ # βœ… Step 4: Semantic match against full train_data.txt
62
+ if train_embeddings is not None and len(train_embeddings) > 0:
63
+ user_embedding = model.encode(prompt_clean, convert_to_tensor=True)
64
+ cosine_scores = util.cos_sim(user_embedding, train_embeddings)
65
+ top_match_index = torch.argmax(cosine_scores).item()
66
+ top_score = cosine_scores[0][top_match_index].item()
67
 
68
+ if top_score >= 0.7:
69
+ matched_prompt = train_prompts[top_match_index]
70
+ return global_rules[matched_prompt].replace("\\n", "\n") # ⬅️ Support multiline
71
 
72
+ # βœ… Step 5: LLM Fallback
73
  try:
74
  prompt_text = f"Generate an Oracle SQL query or guidance for the following request:\n{prompt}\n\nSQL:"
75
  output = llm_pipeline(prompt_text, max_new_tokens=256, do_sample=True, temperature=0.5)[0]["generated_text"]
 
77
  except Exception as e:
78
  print("⚠️ Local LLM error:", e)
79
  return "πŸ€– Sorry, I couldn’t process that locally. Please try a simpler prompt."
 
sql_templates.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from collections import defaultdict
2
 
3
  sql_templates = {
@@ -23,7 +25,6 @@ sql_keyword_aliases = {
23
  "delete": "delete"
24
  }
25
 
26
- # 🧠 NEW fuzzy aliases
27
  fuzzy_aliases = {
28
  "grouped result": "group_by",
29
  "combine tables": "join_example",
@@ -46,7 +47,6 @@ conflicting_phrases = {
46
  ("delete", "new"): "⚠️ You cannot delete something that doesn't exist yet.",
47
  }
48
 
49
- # πŸ€– Greeting phrases and responses
50
  greeting_templates = {
51
  "hello": "πŸ‘‹ Hello! How can I assist you with SQL or PL/SQL today?",
52
  "hi": "πŸ‘‹ Hi there! Need help with Oracle SQL or PL/SQL?",
 
1
+ # sql_templates.py
2
+
3
  from collections import defaultdict
4
 
5
  sql_templates = {
 
25
  "delete": "delete"
26
  }
27
 
 
28
  fuzzy_aliases = {
29
  "grouped result": "group_by",
30
  "combine tables": "join_example",
 
47
  ("delete", "new"): "⚠️ You cannot delete something that doesn't exist yet.",
48
  }
49
 
 
50
  greeting_templates = {
51
  "hello": "πŸ‘‹ Hello! How can I assist you with SQL or PL/SQL today?",
52
  "hi": "πŸ‘‹ Hi there! Need help with Oracle SQL or PL/SQL?",