orachamp1981 committed on
Commit
0bfab9f
·
verified ·
1 Parent(s): 4a09457

Upload 2 files

Browse files
Files changed (2) hide show
  1. data_loader.py +12 -1
  2. model.py +6 -3
data_loader.py CHANGED
@@ -2,6 +2,17 @@
2
 
3
  import os
4
 
 
 
 
 
 
 
 
 
 
 
 
5
  def load_rules(file_path="data/train_data.txt"):
6
  data = {}
7
  if os.path.exists(file_path):
@@ -9,7 +20,7 @@ def load_rules(file_path="data/train_data.txt"):
9
  for line in file:
10
  if "=" in line:
11
  key, value = line.strip().split("=", 1)
12
- data[key.strip().lower()] = value.strip()
13
  return data
14
 
15
  def detect_domain(prompt):
 
2
 
3
  import os
4
 
5
+ def clean_sql_output(raw_text):
6
+ return (
7
+ raw_text.strip()
8
+ .replace("\\n", "\n") # Handle escaped newlines
9
+ .replace(";\n", ";\n") # Normalize semicolon-linebreak
10
+ .replace(";", ";\n") # Add line breaks after semicolons
11
+ .replace("\n\n", "\n") # Remove double line breaks
12
+ .replace(";\\n", ".;\\n") # Remove double line breaks
13
+ .strip()
14
+ )
15
+
16
  def load_rules(file_path="data/train_data.txt"):
17
  data = {}
18
  if os.path.exists(file_path):
 
20
  for line in file:
21
  if "=" in line:
22
  key, value = line.strip().split("=", 1)
23
+ data[key.strip().lower()] = clean_sql_output(value)
24
  return data
25
 
26
  def detect_domain(prompt):
model.py CHANGED
@@ -4,6 +4,10 @@ from data_loader import load_rules, load_rules_by_domain, detect_domain
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
 
 
 
 
 
7
  # 🔍 Load semantic model
8
  model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
9
 
@@ -32,8 +36,7 @@ def oracle_sql_suggester(prompt):
32
  # ✅ Step 1: Exact match in domain-specific rules
33
  domain_match = load_rules_by_domain(prompt_clean)
34
  if domain_match:
35
- #return domain_match
36
- return domain_match.replace("\\n", "\n")
37
 
38
  # ✅ Step 2: Check hardcoded greeting or conflict response
39
  for greet_key, greet_reply in greeting_templates.items():
@@ -67,7 +70,7 @@ def oracle_sql_suggester(prompt):
67
 
68
  if top_score >= 0.7:
69
  matched_prompt = train_prompts[top_match_index]
70
- return global_rules[matched_prompt].replace("\\n", "\n") # ⬅️ Support multiline
71
 
72
  # ✅ Step 5: LLM Fallback
73
  try:
 
4
  import torch
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
 
7
+ # 🔧 Clean up response formatting
8
+ def clean_response(text):
9
+ return text.replace("\\n", "\n").replace(";;", ";")
10
+
11
  # 🔍 Load semantic model
12
  model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
13
 
 
36
  # ✅ Step 1: Exact match in domain-specific rules
37
  domain_match = load_rules_by_domain(prompt_clean)
38
  if domain_match:
39
+ return clean_response(domain_match)
 
40
 
41
  # ✅ Step 2: Check hardcoded greeting or conflict response
42
  for greet_key, greet_reply in greeting_templates.items():
 
70
 
71
  if top_score >= 0.7:
72
  matched_prompt = train_prompts[top_match_index]
73
+ return clean_response(global_rules[matched_prompt])
74
 
75
  # ✅ Step 5: LLM Fallback
76
  try: