orachamp1981 committed on
Commit
641c418
Β·
verified Β·
1 Parent(s): bc655a0

Upload 12 files

Browse files
app.py CHANGED
@@ -1,90 +1,101 @@
1
- # app.py
2
-
3
- import gradio as gr
4
- from model import oracle_sql_suggester
5
-
6
- def chat_fn(message, history):
7
- response = oracle_sql_suggester(message)
8
- history = history or []
9
- history.append((message, response))
10
- return history, ""
11
-
12
- def retry_last(history):
13
- if history:
14
- last_user_msg = history[-1][0]
15
- response = oracle_sql_suggester(last_user_msg)
16
- history[-1] = (last_user_msg, response)
17
- return history
18
-
19
- def undo_last(history):
20
- if history:
21
- history.pop()
22
- return history
23
-
24
- with gr.Blocks(
25
- css="""
26
- body, html, .gradio-container {
27
- background-color: #121212 !important;
28
- color: #ffffff !important;
29
- font-family: 'Times New Roman', serif !important;
30
- }
31
-
32
- .gr-chatbot {
33
- background-color: #1e1e2f !important;
34
- color: #ffffff !important;
35
- }
36
-
37
- .gr-button, .gr-textbox textarea {
38
- font-family: 'Times New Roman', serif !important;
39
- }
40
-
41
- .message.user {
42
- background-color: #2a2a3b !important;
43
- color: #e0e0e0 !important;
44
- }
45
-
46
- .message.bot {
47
- background-color: #333344 !important;
48
- color: #ffffff !important;
49
- }
50
-
51
- .gr-input, .gr-textbox, textarea {
52
- background-color: #2a2a2a !important;
53
- color: #ffffff !important;
54
- font-family: 'Times New Roman', serif !important;
55
- }
56
-
57
- #chatbox-style {
58
- background-color: #ffffff !important; /* white background */
59
- color: #000000 !important; /* black text */
60
- font-family: "Times New Roman", serif;
61
- }
62
- #chatbox-style .message.bot {
63
- background-color: #f5f5f5 !important; /* light gray for bot bubbles */
64
- color: #000000 !important;
65
- }
66
- #chatbox-style .message.user {
67
- background-color: #e0e0e0 !important; /* light gray for user bubbles */
68
- color: #000000 !important;
69
- }
70
- """
71
- ) as demo:
72
- gr.Markdown("<h2 style='color: #ffffff; font-family: Times New Roman; text-align: center;'>🧠 Oracle SQL and PL/SQL Assistant</h2>")
73
- chatbot = gr.Chatbot(show_copy_button=True, height=450, elem_id="chatbox-style")
74
-
75
-
76
- with gr.Row():
77
- txt = gr.Textbox(placeholder="Type your SQL or PL/SQL question here...", lines=2, scale=8)
78
- submit_btn = gr.Button("➑️ Submit", scale=1)
79
- retry_btn = gr.Button("πŸ” Retry", scale=1)
80
- undo_btn = gr.Button("↩️ Undo", scale=1)
81
- clear_btn = gr.Button("🧹 Clear", scale=1)
82
-
83
- submit_btn.click(chat_fn, [txt, chatbot], [chatbot, txt])
84
- txt.submit(chat_fn, [txt, chatbot], [chatbot, txt])
85
-
86
- retry_btn.click(retry_last, [chatbot], [chatbot])
87
- undo_btn.click(undo_last, [chatbot], [chatbot])
88
- clear_btn.click(lambda: [], None, chatbot)
89
-
90
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- PATCHED app.py ---
2
+
3
+ import gradio as gr
4
+ from model import oracle_sql_suggester
5
+
6
def chat_fn(message, history):
    """Handle one chat turn: query the suggester and extend the messages-format history.

    Returns (history, "") so the bound Textbox is cleared after submit.
    """
    reply = oracle_sql_suggester(message)
    history = history or []
    history.append({"role": "user", "content": message})

    # Debug print for safety
    print("🛠️ Assistant Response:", repr(reply))

    if isinstance(reply, list):
        # Already a list of message dicts — append as-is.
        history.extend(reply)
    else:
        # Ensure the response is always a string before wrapping it.
        content = reply if isinstance(reply, str) else str(reply)
        history.append({"role": "assistant", "content": content})

    return history, ""
24
+
25
def retry_last(history):
    """Regenerate the assistant reply for the most recent user message.

    History entries are {"role", "content"} dicts (gr.Chatbot is constructed
    with type="messages").  The previous implementation indexed
    history[-1][0], which assumed the legacy (user, bot) tuple format and
    failed on dict entries.
    """
    if not history:
        return history
    # Walk back to the last user message and rebuild the reply from it.
    for idx in range(len(history) - 1, -1, -1):
        entry = history[idx]
        if isinstance(entry, dict) and entry.get("role") == "user":
            response = oracle_sql_suggester(entry["content"])
            if not isinstance(response, str):
                response = str(response)
            # Replace everything after that user turn with the fresh reply.
            history[idx + 1:] = [{"role": "assistant", "content": response}]
            return history
    return history
31
+
32
def undo_last(history):
    """Remove the most recent user turn from a messages-format history.

    chat_fn stores {"role", "content"} dicts (gr.Chatbot type="messages"), so
    one turn is a user message plus the assistant reply(ies) after it.  The
    previous version popped a single entry, leaving an orphaned user message.
    Legacy tuple-format entries still get the old single-pop behavior.
    """
    if not history:
        return history
    if not isinstance(history[-1], dict):
        # Legacy (user, bot) tuple pairs: one pop removes the whole turn.
        history.pop()
        return history
    # Walk back to the user message that started the last turn.
    idx = len(history) - 1
    while idx >= 0 and history[idx].get("role") != "user":
        idx -= 1
    # Drop the user message and everything after it (idx < 0 clears any
    # stray assistant-only history).
    del history[max(idx, 0):]
    return history
36
+
37
def process_upload(file):
    """Load prompt/answer pairs from an uploaded file and report how many were read."""
    # Imported lazily so the UI starts even if optional parser deps are missing.
    from data_loader import load_prompts_from_file
    pairs = load_prompts_from_file(file.name)
    return f"✅ Uploaded {len(pairs)} prompt pairs!"
42
+
43
# --- Gradio UI wiring (module-level script) ---
# Dark app shell via global CSS; the chat panel (#chatbox-style) is overridden
# to a light background for readability.
with gr.Blocks(
    css="""
    body, html, .gradio-container {
        background-color: #121212 !important;
        color: #ffffff !important;
        font-family: 'Times New Roman', serif !important;
    }
    .gr-chatbot {
        background-color: #1e1e2f !important;
        color: #ffffff !important;
    }
    .gr-button, .gr-textbox textarea {
        font-family: 'Times New Roman', serif !important;
    }
    .message.user {
        background-color: #2a2a3b !important;
        color: #e0e0e0 !important;
    }
    .message.bot {
        background-color: #333344 !important;
        color: #ffffff !important;
    }
    .gr-input, .gr-textbox, textarea {
        background-color: #2a2a2a !important;
        color: #ffffff !important;
        font-family: 'Times New Roman', serif !important;
    }
    #chatbox-style {
        background-color: #ffffff !important;
        color: #000000 !important;
        font-family: "Times New Roman", serif;
    }
    #chatbox-style .message.bot {
        background-color: #f5f5f5 !important;
        color: #000000 !important;
    }
    #chatbox-style .message.user {
        background-color: #e0e0e0 !important;
        color: #000000 !important;
    }
    """
) as demo:
    gr.Markdown("<h2 style='color: #ffffff; font-family: Times New Roman; text-align: center;'>🧠 Oracle SQL and PL/SQL Assistant</h2>")
    # type="messages" => history is a list of {"role", "content"} dicts,
    # matching what chat_fn appends.
    chatbot = gr.Chatbot(show_copy_button=True, height=450, elem_id="chatbox-style", type="messages")

    with gr.Row():
        txt = gr.Textbox(placeholder="Type your SQL or PL/SQL question here...", lines=2, scale=8)
        submit_btn = gr.Button("➡️ Submit", scale=1)
        retry_btn = gr.Button("🔁 Retry", scale=1)
        undo_btn = gr.Button("↩️ Undo", scale=1)
        clear_btn = gr.Button("🧹 Clear", scale=1)

    # chat_fn returns (history, "") so the textbox clears after each send.
    submit_btn.click(chat_fn, [txt, chatbot], [chatbot, txt])
    txt.submit(chat_fn, [txt, chatbot], [chatbot, txt])
    retry_btn.click(retry_last, [chatbot], [chatbot])
    undo_btn.click(undo_last, [chatbot], [chatbot])
    clear_btn.click(lambda: [], None, chatbot)

demo.launch()
concept_library.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "project lifecycle": ["Initiation", "Planning", "Execution", "Closure"],
3
+ "approval flow": ["Request", "Review", "Decision", "Notification"],
4
+ "inventory tracking": ["Stock In", "Stock Out", "Audit", "Reorder"],
5
+ "recruitment process": ["Job Posting", "Screening", "Interview", "Offer Letter"],
6
+ "onboarding process": ["Document Collection", "Induction", "System Setup", "Training", "Access Provision"]
7
+
8
+ }
data_loader.py CHANGED
@@ -1,22 +1,35 @@
1
- # data_loader.py
2
 
3
  import os
 
 
 
4
  import codecs
 
 
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
6
  def clean_sql_output(raw_text):
7
  try:
8
- # Decode escaped characters like \\n to real \n
9
  decoded = codecs.decode(raw_text.strip(), 'unicode_escape')
10
- # Clean up formatting
11
- return (
12
- decoded.replace(";;", ";")
13
- .replace("\n\n\n", "\n\n") # In case of extra breaks
14
- .strip()
15
- )
16
  except Exception as e:
17
- print("⚠️ Cleaning error:", e)
18
  return raw_text.strip()
19
 
 
20
  def load_rules(file_path="data/train_data.txt"):
21
  data = {}
22
  if os.path.exists(file_path):
@@ -27,6 +40,7 @@ def load_rules(file_path="data/train_data.txt"):
27
  data[key.strip().lower()] = clean_sql_output(value)
28
  return data
29
 
 
30
  def detect_domain(prompt):
31
  prompt = prompt.lower()
32
  if any(word in prompt for word in ["salary", "financial", "transaction", "ledger"]):
@@ -44,4 +58,183 @@ def load_rules_by_domain(prompt):
44
  domain_rules = load_rules(domain_file)
45
  if prompt in domain_rules:
46
  return domain_rules[prompt]
47
- return None # fallback will be handled in main logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #data_loader
2
 
3
  import os
4
+ import re
5
+ import json
6
+ import csv
7
  import codecs
8
+ import requests
9
+ import PyPDF2
10
+ from docx import Document
11
+ import openpyxl
12
+ from bs4 import BeautifulSoup
13
 
14
+ # ? Normalize utility
15
def normalize_prompt(text):
    """Normalize a user prompt: lowercase, strip punctuation and filler words."""
    cleaned = text.strip().lower()
    # Trailing sentence punctuation first, then any remaining non-word chars.
    cleaned = re.sub(r"[;:,.!?]+$", "", cleaned)
    cleaned = re.sub(r"[^\w\s]", "", cleaned)
    # Filler words that carry no search intent.
    filler = r"\b(what|actually|please|tell|me|about|can|you|explain|is|the|do|does|give|show)\b(?!\w)"
    cleaned = re.sub(filler, "", cleaned)
    # Collapse the gaps left by removed words.
    return re.sub(r"\s+", " ", cleaned).strip()
22
+
23
+ # ? Output cleanup for SQL responses
24
def clean_sql_output(raw_text):
    """Decode escaped sequences (e.g. a literal \\n) and tidy SQL formatting."""
    stripped = raw_text.strip()
    try:
        decoded = codecs.decode(stripped, 'unicode_escape')
    except Exception as e:
        # Fall back to the raw (stripped) text if decoding fails.
        print("?? Cleaning error:", e)
        return stripped
    decoded = decoded.replace(";;", ";")
    decoded = decoded.replace("\n\n\n", "\n\n")
    return decoded.strip()
31
 
32
+ # ? Existing basic rule loader
33
  def load_rules(file_path="data/train_data.txt"):
34
  data = {}
35
  if os.path.exists(file_path):
 
40
  data[key.strip().lower()] = clean_sql_output(value)
41
  return data
42
 
43
+ # ? Domain routing logic
44
  def detect_domain(prompt):
45
  prompt = prompt.lower()
46
  if any(word in prompt for word in ["salary", "financial", "transaction", "ledger"]):
 
58
  domain_rules = load_rules(domain_file)
59
  if prompt in domain_rules:
60
  return domain_rules[prompt]
61
+ return None
62
+
63
+ # ? Extended loaders for structured files
64
def load_txt(path):
    """Parse 'prompt=answer' lines from a UTF-8 text file into (prompt, answer) pairs."""
    result = []
    with open(path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if '=' not in raw_line:
                continue
            question, reply = raw_line.split('=', 1)
            result.append((normalize_prompt(question), reply.strip()))
    return result
72
+
73
def load_json(path):
    """Load [{'prompt': ..., 'answer': ...}] entries from a JSON file as pairs."""
    with open(path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)
    return [(normalize_prompt(item['prompt']), item['answer'].strip()) for item in entries]
79
+
80
def load_csv(path):
    """Read prompt/answer pairs from a CSV with 'prompt' and 'answer' columns."""
    result = []
    with open(path, newline='', encoding='utf-8') as handle:
        for record in csv.DictReader(handle):
            # Rows missing either column are silently skipped.
            if 'prompt' in record and 'answer' in record:
                result.append((normalize_prompt(record['prompt']), record['answer'].strip()))
    return result
88
+
89
def load_pdf(path):
    """Extract 'prompt=answer' lines from the text of every PDF page."""
    result = []
    with open(path, 'rb') as handle:
        reader = PyPDF2.PdfReader(handle)
        # Join non-empty page texts into one searchable blob.
        text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
        for raw_line in text.split("\n"):
            if '=' not in raw_line:
                continue
            question, reply = raw_line.split('=', 1)
            result.append((normalize_prompt(question), reply.strip()))
    return result
99
+
100
+
101
+
102
def load_docx(path):
    """Extract 'prompt=answer' pairs from the paragraphs of a .docx document."""
    result = []
    for paragraph in Document(path).paragraphs:
        if "=" not in paragraph.text:
            continue
        question, reply = paragraph.text.split("=", 1)
        result.append((normalize_prompt(question), reply.strip()))
    return result
110
+
111
def load_xlsx(path):
    """Read prompt/answer pairs from an .xlsx workbook.

    Accepts either two-column rows (prompt, answer) or a first cell that
    itself contains 'prompt=answer'.
    """
    result = []
    workbook = openpyxl.load_workbook(path)
    for worksheet in workbook.worksheets:
        for row in worksheet.iter_rows(values_only=True):
            if not row or len(row) < 2:
                continue
            first, second = row[0], row[1]
            if isinstance(first, str) and isinstance(second, str) and "=" not in first:
                result.append((normalize_prompt(first), second.strip()))
            elif isinstance(first, str) and "=" in first:
                question, reply = first.split("=", 1)
                result.append((normalize_prompt(question), reply.strip()))
    return result
125
+
126
+
127
+
128
+ # ? Load from GitHub/HuggingFace (TXT/JSON)
129
def fetch_text_from_url(url):
    """Download a remote text resource; return '' on any failure."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
    except Exception as e:
        # Best-effort: log and return empty so callers can skip this source.
        print(f"?? Error reading remote file {url}: {e}")
        return ""
    return response.text
137
+ # ? Dispatcher for local files
138
def load_prompts_from_file(path):
    """Dispatch to the loader matching the file extension; return [] if unsupported."""
    loaders = (
        ('.txt', load_txt),
        ('.json', load_json),
        ('.csv', load_csv),
        ('.pdf', load_pdf),
        ('.docx', load_docx),
        ('.xlsx', load_xlsx),
    )
    for suffix, loader in loaders:
        if path.endswith(suffix):
            return loader(path)
    print(f"? Unsupported format: {path}")
    return []
154
+
155
def load_prompts_from_url(url):
    """Fetch a remote .txt or .json resource and parse prompt/answer pairs."""
    text = fetch_text_from_url(url)
    if not text:
        return []

    result = []
    if url.endswith(".txt"):
        # One 'prompt=answer' pair per line.
        for raw_line in text.splitlines():
            if '=' in raw_line:
                question, reply = raw_line.split('=', 1)
                result.append((normalize_prompt(question), reply.strip()))
    elif url.endswith(".json"):
        try:
            for entry in json.loads(text):
                result.append((normalize_prompt(entry['prompt']), entry['answer'].strip()))
        except Exception as e:
            print(f"?? JSON parsing failed: {e}")
    return result
174
+
175
def load_prompt_pairs(path):
    """Load (normalized_prompt, answer) pairs from a local path or an http(s) URL.

    Remote sources support json/csv/txt/pdf by extension; any local file is
    parsed as 'prompt=answer' lines.
    """
    import json, csv
    import requests
    import io
    import PyPDF2

    def is_url(p): return p.startswith("http")

    def pairs_from_lines(lines):
        # Shared 'prompt=answer' line parser.
        out = []
        for line in lines:
            if "=" in line:
                left, right = line.split("=", 1)
                out.append((normalize_prompt(left), right.strip()))
        return out

    ext = path.split(".")[-1].lower()

    if not is_url(path):
        # Local files of any extension are treated as plain 'prompt=answer' text.
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return pairs_from_lines(line.strip() for line in f.readlines())

    response = requests.get(path)
    response.raise_for_status()
    content = response.content

    if ext == "json":
        parsed = json.loads(content.decode("utf-8"))
        return [(normalize_prompt(entry['prompt']), entry['answer'].strip()) for entry in parsed]
    if ext == "csv":
        reader = csv.DictReader(io.StringIO(content.decode("utf-8")))
        return [(normalize_prompt(row['prompt']), row['answer'].strip()) for row in reader]
    if ext == "txt":
        return pairs_from_lines(content.decode("utf-8", errors="replace").splitlines())
    if ext == "pdf":
        data = []
        reader = PyPDF2.PdfReader(io.BytesIO(content))
        for page in reader.pages:
            text = page.extract_text()
            if text:
                data.extend(pairs_from_lines(text.splitlines()))
        return data
    # Unknown remote extension: nothing to parse.
    return []
222
+
223
+
224
def list_files_from_github_folder(github_folder_url):
    """List raw.githubusercontent.com URLs for data files in a GitHub folder page.

    Scrapes anchor tags from the rendered folder page and rewrites each file
    href to its raw-content URL.

    Bug fixed: the previous reconstruction
    (``raw_base.split('/', 2)[-1].split('/')[0]`` + ``href.split('/', 2)[-1]``)
    dropped the repository-owner path segment and kept ``/blob/`` in the path,
    so every generated link 404'd.  A file href has the shape
    ``/owner/repo/blob/branch/path``; the raw URL is that href on the raw host
    with ``/blob/`` collapsed to ``/``.

    NOTE(review): GitHub folder pages are now rendered client-side, so the
    'a.js-navigation-open' selector may match nothing on current github.com —
    consider the REST contents API instead.
    """
    supported = (".txt", ".json", ".csv", ".pdf", ".docx", ".xlsx")
    try:
        html = requests.get(github_folder_url, timeout=10).text
        soup = BeautifulSoup(html, "lxml")
        file_links = []
        for link in soup.select("a.js-navigation-open"):
            href = link.get("href", "")
            if href.endswith(supported):
                # '/owner/repo/blob/branch/path' -> raw content URL.
                file_links.append("https://raw.githubusercontent.com" + href.replace("/blob/", "/"))
        return file_links
    except Exception as e:
        print("⚠️ GitHub scan error:", e)
        return []
239
+
240
+
diagram_generator.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #diagram_generator.py
2
+
3
+ import graphviz
4
+ import os
5
+
6
def generate_plsql_structure_chart(output_path="output/plsql_structure"):
    """Render a simple PL/SQL block-structure flowchart and return the PNG path.

    Bug fixed: ``graphviz.Digraph.render`` already returns the output filename
    including the format extension (e.g. 'output/plsql_structure.png'), so the
    previous ``output_file + ".png"`` returned a non-existent '…png.png' path.
    """
    chart = graphviz.Digraph(format="png")
    chart.attr(rankdir='TB', bgcolor="lightyellow", fontname="Arial")

    # Ovals for entry/exit, boxes for the declarative/exception sections.
    node_specs = [
        ("START", "BEGIN", "oval", "lightgreen"),
        ("DECLARE", "DECLARE", "box", "lightblue"),
        ("EXCEPTION", "EXCEPTION", "box", "orange"),
        ("END", "END", "oval", "lightgreen"),
    ]
    for node_id, label, shape, color in node_specs:
        chart.node(node_id, label, shape=shape, style="filled", fillcolor=color)

    for tail, head in (("START", "DECLARE"), ("DECLARE", "EXCEPTION"), ("EXCEPTION", "END")):
        chart.edge(tail, head)

    # cleanup=True removes the intermediate DOT source after rendering.
    return chart.render(output_path, cleanup=True)
model.py CHANGED
@@ -1,193 +1,190 @@
1
- # model.py
2
-
3
- import os
4
- import re
5
- import torch
6
- import random
7
- import requests
8
- from sentence_transformers import SentenceTransformer, util
9
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
- from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases, conflicting_phrases, greeting_templates
11
-
12
- DATA_DIR = [
13
- "data", # local folder
14
- "https://raw.githubusercontent.com/orachamp1981" # GitHub base (MUST be raw)
15
- ]
16
- DOMAIN_FILES = {
17
- "sql": "sql.txt",
18
- "plsql": "plsql.txt",
19
- "oracle_forms": "oracle_forms.txt",
20
- "oracle_reports": "oracle_reports.txt",
21
- "sql_plsql_interview": "interview_sql_pl_sql_question.txt",
22
- "pl_sql": "PL-SQL-Development/orachamp1981-patch-1/email_send.txt"
23
- }
24
-
25
- FALLBACK_FILE = "train_data.txt"
26
-
27
- # βœ… Semantic model
28
- model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
29
-
30
- # βœ… Normalize input
31
- def normalize_prompt(text):
32
- text = text.strip().lower()
33
- text = re.sub(r"[;:,.!?]+$", "", text) # ⬅️ Remove trailing punctuation
34
- text = re.sub(r"[^\w\s]", "", text) # Remove all other punctuation
35
- text = re.sub(r"\b(what|actually|please|tell|me|about|can|you|explain|is|the|do)\b", "", text)
36
- return re.sub(r"\s+", " ", text).strip()
37
-
38
- # βœ… Output cleaner
39
- def clean_response(text):
40
- return text.replace("\\n", "\n").replace(";;", ";").strip()
41
-
42
- # βœ… Load multi-prompt rules
43
- def load_multi_prompt_file(file_path):
44
- items = []
45
- if not os.path.exists(file_path):
46
- return items
47
- with open(file_path, "r", encoding="utf-8") as f:
48
- for line in f:
49
- line = line.strip()
50
- if not line or line.startswith("#") or "=" not in line:
51
- continue
52
- prompts_raw, answer = line.split("=", 1)
53
- for p in prompts_raw.split(","):
54
- norm = normalize_prompt(p)
55
- if norm:
56
- items.append((norm, answer.strip()))
57
- return items
58
-
59
-
60
- def load_multi_prompt_file_from_url(url):
61
- items = []
62
- try:
63
- response = requests.get(url)
64
- response.raise_for_status()
65
- lines = response.text.splitlines()
66
- for line in lines:
67
- line = line.strip()
68
- if not line or line.startswith("#") or "=" not in line:
69
- continue
70
- prompts_raw, answer = line.split("=", 1)
71
- for p in prompts_raw.split(","):
72
- norm = normalize_prompt(p)
73
- if norm:
74
- items.append((norm, answer.strip()))
75
- except Exception as e:
76
- print(f"⚠️ Error reading {url}:", e)
77
- return items
78
-
79
- # Build URL Building
80
- def is_url(path):
81
- return path.startswith("http://") or path.startswith("https://")
82
- # βœ… Build embeddings from all files
83
-
84
- def load_all_embeddings():
85
- all_data = []
86
- for _, rel_path in DOMAIN_FILES.items():
87
- file_loaded = False
88
- for base in DATA_DIR:
89
- if base.startswith("http"):
90
- full_url = f"{base}/{rel_path}"
91
- data = load_multi_prompt_file_from_url(full_url)
92
- else:
93
- full_path = os.path.join(base, rel_path)
94
- data = load_multi_prompt_file(full_path)
95
-
96
- if data:
97
- all_data.extend(data)
98
- file_loaded = True
99
- break # stop after first successful load
100
-
101
- if not file_loaded:
102
- print(f"⚠️ Could not load file: {rel_path}")
103
-
104
- # Fallback train_data.txt (assumed local)
105
- all_data.extend(load_multi_prompt_file(os.path.join("data", FALLBACK_FILE)))
106
-
107
- if not all_data:
108
- return [], None
109
-
110
- prompts = [p[0] for p in all_data]
111
- embeddings = model.encode(prompts, convert_to_tensor=True)
112
- return all_data, embeddings
113
- ALL_PAIRS, ALL_EMBEDDINGS = load_all_embeddings()
114
-
115
-
116
- # πŸ€– Local LLM
117
- llm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
118
- tokenizer = AutoTokenizer.from_pretrained(llm_name)
119
- llm_model = AutoModelForCausalLM.from_pretrained(
120
- llm_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
121
- )
122
- llm_pipeline = pipeline("text-generation", model=llm_model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
123
-
124
- # πŸ” Main suggestor function
125
- def oracle_sql_suggester(prompt):
126
- norm_prompt = normalize_prompt(prompt)
127
-
128
- # Step 1: Greeting or Conflicts
129
- for greet_key, greet_reply in greeting_templates.items():
130
- if greet_key in norm_prompt:
131
- return greet_reply
132
- for terms, response in conflicting_phrases.items():
133
- if all(term in norm_prompt for term in terms):
134
- return response
135
-
136
-
137
- # Step X: Simple acknowledgment replies
138
- acknowledgment_prompts = ["okay", "ok", "got it", "thanks", "thank you", "cool", "alright", "great"]
139
- acknowledgment_replies = [
140
- "πŸ‘ Great! Let me know if you want to continue or explore another topic.",
141
- "πŸ‘Œ Got it! I'm here if you need help with anything else.",
142
- "βœ… Understood. Feel free to ask the next question whenever you're ready.",
143
- "Glad to hear that! Would you like to dive deeper or move on?",
144
- "Perfect! Let me know what you'd like to explore next."
145
- ]
146
- if norm_prompt in acknowledgment_prompts:
147
- return random.choice(acknowledgment_replies)
148
-
149
-
150
- # βœ… Step 1.5: Dynamic vague prompt detection
151
- if len(norm_prompt.split()) <= 3:
152
- user_embedding = model.encode(norm_prompt, convert_to_tensor=True)
153
- if ALL_EMBEDDINGS is not None:
154
- cosine_scores = util.cos_sim(user_embedding, ALL_EMBEDDINGS)[0]
155
- top_score = torch.max(cosine_scores).item()
156
- if top_score < 0.55:
157
- return (
158
- "πŸ€– Your question seems a bit broad or unclear.\n\n"
159
- "Could you please clarify what you're asking?\n\n"
160
- "- Are you referring to a query structure, data model, or system design?\n"
161
- "- Is this related to SQL, PL/SQL, or business process?\n\n"
162
- "Once you confirm, I'll provide the best possible answer!"
163
- )
164
-
165
- # Step 2: Semantic match across all domains + fallback
166
- if ALL_EMBEDDINGS is not None:
167
- user_embedding = model.encode(norm_prompt, convert_to_tensor=True)
168
- cosine_scores = util.cos_sim(user_embedding, ALL_EMBEDDINGS)[0]
169
- best_idx = torch.argmax(cosine_scores).item()
170
- best_score = cosine_scores[best_idx].item()
171
- if best_score >= 0.6:
172
- return clean_response(ALL_PAIRS[best_idx][1])
173
-
174
- # Step 3: Template, fuzzy, keyword aliases
175
- for word in norm_prompt.split():
176
- if word in sql_keyword_aliases:
177
- return sql_templates.get(sql_keyword_aliases[word])
178
- for key, template in sql_templates.items():
179
- if key in norm_prompt or key.replace("_", " ") in norm_prompt:
180
- return template
181
- for fuzzy_phrase, mapped_key in fuzzy_aliases.items():
182
- if fuzzy_phrase in norm_prompt:
183
- return sql_templates.get(mapped_key)
184
-
185
- # Step 4: Fallback LLM
186
- try:
187
- prompt_text = f"Generate an Oracle SQL query or explanation for the following:\n{prompt}\n\nSQL:"
188
- output = llm_pipeline(prompt_text, max_new_tokens=256, do_sample=True, temperature=0.5)[0]["generated_text"]
189
- return "πŸ€– (LLM): " + output.split("SQL:")[-1].strip()
190
- except Exception as e:
191
- print("⚠️ LLM fallback error:", e)
192
- return "πŸ€– Sorry, I couldn’t process that locally. Please try a simpler prompt."
193
-
 
1
+ # --- PATCHED model.py ---
2
+
3
+ import os
4
+ import re
5
+ import torch
6
+ import gradio as gr
7
+ import random
8
+ from tree_builder import generate_tree_for_prompt
9
+ from query_generator import generate_dynamic_query
10
+ from sentence_transformers import SentenceTransformer, util
11
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
12
+ from sql_templates import sql_templates, sql_keyword_aliases, fuzzy_aliases, conflicting_phrases, greeting_templates
13
+ from data_loader import (
14
+ load_prompt_pairs,
15
+ load_prompts_from_file,
16
+ load_prompts_from_url,
17
+ normalize_prompt,
18
+ clean_sql_output,
19
+ detect_domain,
20
+ load_rules,
21
+ load_rules_by_domain
22
+ )
23
+
24
# ========== Load Resources ==========
# Search roots for domain knowledge files: the local data/ folder first, then
# the raw GitHub mirror.
DATA_DIR = [
    "data",
    "https://raw.githubusercontent.com/orachamp1981/PL-SQL-Development/orachamp1981-patch-1"
]

# Domain name -> knowledge file, resolved relative to each DATA_DIR base.
DOMAIN_FILES = {
    "sql": "sql.txt",
    "plsql": "plsql.txt",
    "sql_plsql_interview": "interview_sql_pl_sql_question.txt",
    "pl_sql": "email_send.txt",
    "faq": "oracle_faq.json",
    "guides": "best_practices.pdf"
}

# Merged into the knowledge base regardless of which domain files load.
FALLBACK_FILE = "train_data.txt"
# Sentence-embedding model used for semantic prompt matching.
model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
41
+
42
+ # -- NEW: Sanitize broken unicode --
43
def sanitize_unicode(text):
    """Round-trip a string through UTF-16 to replace unpaired surrogates.

    Non-string values are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    # surrogatepass tolerates lone surrogates on encode; 'replace' swaps
    # anything undecodable for U+FFFD on the way back.
    return text.encode('utf-16', 'surrogatepass').decode('utf-16', 'replace')
47
+
48
def clean_response(text):
    """Convert literal \\n escapes to newlines and collapse doubled semicolons."""
    result = text.replace("\\n", "\n")
    result = result.replace(";;", ";")
    return result.strip()
50
+
51
def is_definition_prompt(prompt: str) -> bool:
    """Return True when *prompt* asks for a definition or comparison.

    Bug fixed: the previous version tested ``normalize_prompt(prompt)``, but
    normalize_prompt strips filler words including 'what' and 'is', so
    ``startswith("what is")`` could never match.  Check the raw lowercased
    prompt instead.
    """
    text = prompt.strip().lower()
    return (
        text.startswith("what is")
        or text.startswith("define")
        or "difference between" in text
    )
58
+
59
def is_dynamic_sql_prompt(prompt: str) -> bool:
    """Heuristic: True when the prompt asks for generated SQL over tables/modules/schema."""
    text = normalize_prompt(prompt)
    dynamic_keywords = ("tables", "schema", "design", "query", "join", "reports", "forms", "structure", "modules")
    has_keyword = any(word in text for word in dynamic_keywords)
    mentions_target = "sql" in text or "module" in text or "table" in text
    return has_keyword and mentions_target
63
+
64
def load_all_embeddings():
    """Load prompt/answer pairs from every configured domain file and embed the prompts.

    Returns (pairs, embeddings) where pairs is a list of
    (normalized_prompt, answer) tuples and embeddings is a tensor aligned
    with pairs, or ([], None) when nothing could be loaded.
    """
    all_data = []
    # NOTE(review): local to this call, so it only avoids duplicate fetches
    # within a single invocation.
    failed_urls = set()

    for _, rel_path in DOMAIN_FILES.items():
        file_loaded = False
        # Prefer the local data/ copy; fall back to the remote bases in DATA_DIR.
        local_path = os.path.join("data", rel_path)
        if os.path.exists(local_path):
            data = load_prompts_from_file(local_path)
            all_data.extend(data)
            file_loaded = True
        else:
            for base in DATA_DIR:
                if base.startswith("http"):
                    full_url = f"{base}/{rel_path}"
                    # PDFs are skipped remotely (load_prompts_from_url only parses txt/json).
                    if full_url in failed_urls or full_url.endswith(".pdf"):
                        continue
                    data = load_prompts_from_url(full_url)
                    if data:
                        all_data.extend(data)
                        file_loaded = True
                        break
                    else:
                        failed_urls.add(full_url)

        if not file_loaded:
            print(f"⚠️ Could not load file: {rel_path}")

    # Always merge the local fallback training data.
    all_data.extend(load_prompts_from_file(os.path.join("data", FALLBACK_FILE)))
    if not all_data:
        return [], None

    prompts = [p[0] for p in all_data]
    embeddings = model.encode(prompts, convert_to_tensor=True)
    return all_data, embeddings

# Knowledge base built once at import time.
ALL_PAIRS, ALL_EMBEDDINGS = load_all_embeddings()
101
+
102
# ========== Load LLM ==========
# Small local chat model used as a last-resort generator when no KB or
# template match is found.
llm_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
# fp16 on GPU, fp32 on CPU.
llm_model = AutoModelForCausalLM.from_pretrained(
    llm_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
# device=0 selects the first GPU; -1 runs on CPU.
llm_pipeline = pipeline("text-generation", model=llm_model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
109
+
110
+ # ========== Oracle SQL Assistant ==========
111
def oracle_sql_suggester(prompt):
    """Answer an Oracle SQL / PL/SQL question; always returns a sanitized string.

    Resolution order: greetings -> acknowledgements -> conflict phrases ->
    definition lookup -> dynamic SQL generation -> workflow tree ->
    semantic KB match -> keyword/fuzzy templates -> local LLM fallback.

    Fixes over the previous version: the extra greetings are merged into the
    shared module-level dict only once (it was mutated on every call), and
    the duplicated semantic-search code is factored into one local helper.
    """
    norm_prompt = normalize_prompt(prompt)

    # One-time extension of the shared greeting table.
    if "how are you" not in greeting_templates:
        greeting_templates.update({
            "how are you": "🤖 I'm just code, but thanks for asking! Ready to help with Oracle SQL or PL/SQL.",
            "how r u": "🤖 I'm doing great in the cloud ☁️. Let's solve some SQL problems!"
        })

    for greet_key, greet_reply in greeting_templates.items():
        if greet_key in norm_prompt:
            return sanitize_unicode(greet_reply)

    # Bare acknowledgements get a canned reply.
    ack_inputs = {"okay", "ok", "got it", "thanks", "thank you", "cool", "alright", "great"}
    if norm_prompt in ack_inputs:
        ack_replies = [
            "👍 Great! Let me know if you want to continue or explore another topic.",
            "👌 Got it! I'm here if you need help with anything else.",
            "✅ Understood. Feel free to ask the next question whenever you're ready.",
            "Glad to hear that! Would you like to dive deeper or move on?",
            "Perfect! Let me know what you'd like to explore next."
        ]
        return sanitize_unicode(random.choice(ack_replies))

    # Conflicting/ambiguous phrase combinations get a fixed clarification.
    for terms, response in conflicting_phrases.items():
        if all(term in norm_prompt for term in terms):
            return sanitize_unicode(response)

    def _semantic_match():
        # Best KB answer above the 0.6 cosine threshold, else None.
        if ALL_EMBEDDINGS is None:
            return None
        user_embedding = model.encode(norm_prompt, convert_to_tensor=True)
        scores = util.cos_sim(user_embedding, ALL_EMBEDDINGS)[0]
        best_idx = torch.argmax(scores).item()
        if scores[best_idx].item() >= 0.6:
            return clean_response(ALL_PAIRS[best_idx][1])
        return None

    # Definition-style questions: answer from the KB or ask to rephrase.
    if is_definition_prompt(prompt):
        match = _semantic_match()
        if match is not None:
            return sanitize_unicode(match)
        return sanitize_unicode("🤖 I couldn't find a strong match. Try rephrasing or ask something more specific.")

    # Requests for generated SQL over known modules/tables.
    if is_dynamic_sql_prompt(prompt):
        result = generate_dynamic_query(prompt)
        if result:
            return sanitize_unicode(f"🤖 (Dynamic SQL):\n{result}")

    # Business-process prompts may map to a workflow tree.
    tree = generate_tree_for_prompt(prompt)
    if tree:
        return sanitize_unicode(str(tree))

    # General semantic lookup across all loaded domains.
    match = _semantic_match()
    if match is not None:
        return sanitize_unicode(match)

    # Keyword aliases, direct template keys, then fuzzy aliases.
    for word in norm_prompt.split():
        if word in sql_keyword_aliases:
            return sanitize_unicode(sql_templates.get(sql_keyword_aliases[word]))
    for key, val in sql_templates.items():
        if key in norm_prompt or key.replace("_", " ") in norm_prompt:
            return sanitize_unicode(val)
    for fuzzy, target_key in fuzzy_aliases.items():
        if fuzzy in norm_prompt:
            return sanitize_unicode(sql_templates.get(target_key))

    # Last resort: the local LLM.
    try:
        prompt_text = f"Generate an Oracle SQL query or explanation for the following:\n{prompt}\n\nSQL:"
        output = llm_pipeline(prompt_text, max_new_tokens=256, do_sample=True, temperature=0.5)[0]["generated_text"]
        return sanitize_unicode("🤖 (LLM): " + output.split("SQL:")[-1].strip())
    except Exception as e:
        print("⚠️ LLM fallback error:", e)
        return sanitize_unicode("🤖 Sorry, I couldn’t process that locally. Please try a simpler prompt.")
 
 
 
query_generator.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # query_generator.py
2
+
3
+ import re
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+ model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
7
+
8
# Sample domain entities
# Maps a module keyword to the tables belonging to it. generate_join_query
# uses the FIRST table of a module as the join anchor, so ordering matters.
MODULE_TABLES = {
    "hr": ["employees", "departments", "payroll", "leave", "attendance"],
    "payroll": ["payroll", "salary", "deductions", "benefits"],
    "inventory": ["items", "stock", "suppliers", "reorder"],
    "finance": ["accounts", "transactions", "expenses", "budget"],
    "employee": ["employees", "leave", "payroll", "attendance"]
}

# Maps a module keyword to its known form/report object names;
# generate_forms_reports_query turns these into a UNION ALL of SELECTs.
FORMS_REPORTS = {
    "hr": ["employee_form", "leave_application", "recruitment_form"],
    "payroll": ["salary_report", "deduction_report", "payslip_form"],
    "inventory": ["stock_report", "item_entry_form"],
    "finance": ["budget_report", "expense_form"]
}
23
+
24
def normalize(text):
    """Lowercase *text*, drop every character that is not alphanumeric
    or whitespace, and trim surrounding whitespace."""
    lowered = text.lower()
    cleaned = re.sub(r"[^a-zA-Z0-9\s]", "", lowered)
    return cleaned.strip()
26
+
27
def get_best_match(prompt, candidates):
    """Return the candidate most similar to *prompt* plus its cosine score.

    Encodes the prompt and every candidate with the module-level
    sentence-transformer; result is ``(best_candidate, score)``.
    """
    query_vec = model.encode(prompt, convert_to_tensor=True)
    candidate_vecs = model.encode(candidates, convert_to_tensor=True)
    similarities = util.cos_sim(query_vec, candidate_vecs)[0]
    winner = similarities.argmax().item()
    return candidates[winner], similarities[winner].item()
33
+
34
def extract_module_from_prompt(prompt):
    """Pick the ERP module a prompt refers to, or None.

    Tries a literal substring match against the known module names first;
    failing that, falls back to semantic matching and accepts the best
    candidate only when its cosine score reaches 0.4.
    """
    cleaned = normalize(prompt)
    direct_hits = [name for name in MODULE_TABLES if name in cleaned]
    if direct_hits:
        return direct_hits[0]
    candidate, score = get_best_match(cleaned, list(MODULE_TABLES))
    if score >= 0.4:
        return candidate
    return None
41
+
42
def infer_fields_from_prompt(prompt):
    """Infer the column names a prompt is asking for.

    Scans the normalized prompt for known field keywords and returns the
    matching column names, or ["*"] when nothing matches.

    The two-letter keyword "id" is matched on a word boundary so that
    unrelated words containing "id" (e.g. "provide", "paid") do not
    spuriously add emp_id; longer keywords keep substring matching so
    plurals such as "departments" still hit.
    """
    prompt = normalize(prompt)
    fields = []
    if "name" in prompt: fields.append("emp_name")
    if "salary" in prompt: fields.append("salary")
    if "leave" in prompt: fields.append("leave_days")
    if "department" in prompt: fields.append("dept_name")
    if re.search(r"\bid\b", prompt): fields.append("emp_id")
    return fields or ["*"]
51
+
52
def generate_join_query(module, fields):
    """Build a heuristic SELECT ... JOIN statement for *module*.

    The first table of the module acts as the join anchor; every other
    table is joined on a guessed key (dept_id for ``*dept*`` tables,
    emp_id for payroll/leave tables, otherwise a cartesian ``1=1`` join).
    The requested *fields* are projected from every table of the module.
    """
    tables = MODULE_TABLES.get(module, [])
    if not tables:
        return "SELECT * FROM some_table"

    anchor = tables[0]

    columns = []
    for table in tables:
        if fields == ["*"]:
            columns.append(f"{table}.*")
        else:
            columns.extend(f"{table}.{field}" for field in fields)

    join_clauses = []
    for other in tables[1:]:
        if "dept" in other:
            condition = f"{anchor}.dept_id = {other}.dept_id"
        elif "payroll" in other or "leave" in other:
            condition = f"{anchor}.emp_id = {other}.emp_id"
        else:
            condition = "1=1"
        join_clauses.append(f"JOIN {other} ON {condition}")

    return f"SELECT {', '.join(columns)} FROM {anchor} " + " ".join(join_clauses)
79
+
80
+
81
def generate_forms_reports_query(module):
    """Return a UNION ALL query over every form/report table registered
    for *module*, or an SQL comment when the module has none."""
    sources = FORMS_REPORTS.get(module, [])
    if not sources:
        return f"-- No forms or reports found for module '{module}'"
    return "\nUNION ALL\n".join(f"SELECT * FROM {src}" for src in sources)
87
+
88
def generate_dynamic_query(prompt):
    """Route *prompt* to the matching SQL generator.

    Returns a formatted SQL suggestion string, or None so the caller
    (model.py) can fall through to its other strategies.

    NOTE(review): the returned string already carries the
    "πŸ€– (Dynamic SQL):" prefix and the caller in model.py wraps the result
    in the same prefix again — confirm which side should own the prefix
    and drop the other.
    """
    prompt_norm = normalize(prompt)

    # Forms/reports requests take priority over schema-style requests.
    if any(kw in prompt_norm for kw in ["form", "report", "forms", "reports"]):
        module = extract_module_from_prompt(prompt)
        return "πŸ€– (Dynamic SQL):\n" + generate_forms_reports_query(module)

    # The schema, join and generic-query branches were three identical
    # copies of the same code path; they are collapsed into one check.
    query_keywords = [
        "table", "tables", "schema", "structure", "design",
        "join",
        "query", "select", "show", "get", "fetch",
    ]
    if any(kw in prompt_norm for kw in query_keywords):
        module = extract_module_from_prompt(prompt)
        fields = infer_fields_from_prompt(prompt)
        return "πŸ€– (Dynamic SQL):\n" + generate_join_query(module, fields)

    return None  # let model.py handle fallback if not match
requirements.txt CHANGED
@@ -1,5 +1,12 @@
1
- gradio==3.50.2
2
- sentence-transformers
3
- torch
4
- transformers
5
- accelerate
 
 
 
 
 
 
 
 
1
+ # REQUIREMENT
2
+
3
+ gradio>=4.14.0
4
5
+ torch
6
+ sentence-transformers
7
+ transformers
8
+ accelerate
9
+ graphviz
10
+ PyPDF2
11
+ python-docx
12
+ openpyxl
semantic_tree.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # semantic_tree.py
2
+
3
+ import os
4
+ import json
5
+ import re
6
+ from sentence_transformers import SentenceTransformer, util
7
+
8
+ # Load embedding model
9
+ model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
10
+
11
# Path to concept memory file
# (lives next to this module so it can be reused across runs)
CONCEPT_DB = os.path.join(os.path.dirname(__file__), "concept_library.json")

# Load or initialize concept memory.
# concept_map maps a concept phrase to the value rendered as its child
# nodes by generate_semantic_tree; it starts empty when no
# concept_library.json exists yet.
if os.path.exists(CONCEPT_DB):
    with open(CONCEPT_DB, "r", encoding="utf-8") as f:
        concept_map = json.load(f)
else:
    concept_map = {}
20
+
21
# Utility: Format dictionary as tree
def format_tree(title, tree_dict):
    """Render a two-level {parent: [children]} mapping as a text tree.

    *title* becomes the root line; each parent hangs off the root with
    β”œβ”€β”€/└── markers and its children are indented one level further.
    """
    lines = [title]
    entries = list(tree_dict.items())
    for idx, (parent, children) in enumerate(entries):
        parent_is_last = idx == len(entries) - 1
        lines.append(f"{'└──' if parent_is_last else 'β”œβ”€β”€'} {parent}")
        child_prefix = " " if parent_is_last else "β”‚ "
        for child_idx, child in enumerate(children):
            child_is_last = child_idx == len(children) - 1
            lines.append(f"{child_prefix}{'└──' if child_is_last else 'β”œβ”€β”€'} {child}")
    return "\n".join(lines)
34
+
35
+ # Extract base concept phrases
36
+
37
def normalize_prompt(prompt):
    """Lowercase *prompt*, remove non-alphanumeric characters (keeping
    whitespace) and trim the ends."""
    simplified = prompt.lower()
    simplified = re.sub(r"[^a-zA-Z0-9\s]", "", simplified)
    return simplified.strip()
39
+
40
+ # Semantic search over known concepts
41
def match_concept(prompt, concept_map, threshold=0.65):
    """Return the stored concept closest to *prompt*, or None.

    Embeds the prompt and every concept key, then accepts the highest
    cosine-similarity key only when its score reaches *threshold*.
    """
    if not concept_map:
        return None
    known_concepts = list(concept_map.keys())
    concept_vecs = model.encode(known_concepts, convert_to_tensor=True)
    prompt_vec = model.encode(prompt, convert_to_tensor=True)
    scores = util.cos_sim(prompt_vec, concept_vecs)[0]
    winner = scores.argmax().item()
    return known_concepts[winner] if scores[winner] >= threshold else None
52
+
53
+ # Generate tree from prompt meaning
54
def generate_semantic_tree(prompt):
    """Build a tree for *prompt* from the concept library, with a crude
    keyword-split fallback when no stored concept matches."""
    cleaned = normalize_prompt(prompt)

    known = match_concept(cleaned, concept_map)
    if known:
        label = known.title()
        return format_tree(label, {label: concept_map[known]})

    # No known match: fallback clustering logic (primitive expansion) —
    # split the longer keywords into two halves presented as two
    # placeholder concepts.
    keywords = [word for word in cleaned.split() if len(word) > 2]
    midpoint = len(keywords) // 2 or 1
    fallback = {
        "Concept A": keywords[:midpoint],
        "Concept B": keywords[midpoint:],
    }
    return format_tree("Inferred Structure", fallback)
69
+
70
+ # Optional: expand concept map dynamically
71
+ # Could be added later to "learn" from corrections
tree_builder.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # tree_builder.py
2
+
3
+ from tree_synthesizer import generate_step_tree
4
+
5
def generate_tree_for_prompt(prompt: str) -> str:
    """Thin faΓ§ade over tree_synthesizer.generate_step_tree.

    Returns the synthesized workflow tree for *prompt*, or the empty
    string / greeting text that generate_step_tree emits for prompts it
    declines to draw.
    """
    return generate_step_tree(prompt)
tree_cache.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "workflow of inventory tracking": "You're an expert at designing hierarchies and tree structures.\nYour task is to create a clear, non-repetitive, 2\u20133 level tree-style diagram for the following topic:\n\"workflow of inventory tracking\"\nFormatting rules:\n- Use \"\u251c\u2500\u2500\" for branches and \"\u2514\u2500\u2500\" for final child.\n- Do NOT repeat the topic name in multiple branches.\n- Do NOT output extra text like introductions or notes.\nExamples:\nERP System\n\u251c\u2500\u2500 Finance\n\u251c\u2500\u2500 Inventory\n\u251c\u2500\u2500 Sales\n\u2514\u2500\u2500 HR\nHospital Admissions\n\u251c\u2500\u2500 Patient Details\n\u2502 \u251c\u2500\u2500 Demographics\n\u2502 \u251c\u2500\u2500 Insurance\n\u2502 \u2514\u2500\u2500 Emergency Contact\n\u251c\u2500\u2500 Admission Records\n\u2514\u2500\u2500 Discharge Process\nNow generate the tree for:\nExpected Output:\n```\n\u2502 \u251c\u2500\u2500 Inventory\n\u2502 \u2502 \u251c\u2500\u2500 Stock Levels\n\u2502 \u2502 \u251c\u2500\u2500 Purchase Orders\n\u2502 \u2502 \u251c\u2500\u2500 Sales Orders\n\u2502 \u2502 \u2514\u2500\u2500 Inventory Reports\n\u2502 \u2514\u2500\u2500 Ledger\n\u2502 \u251c\u2500\u2500 Cash\n\u2502 \u251c\u2500\u2500 Bank Accounts\n\u2502 \u2514\u2500\u2500 Payroll\n\u2502 \u251c\u2500\u2500 Stock Levels\n\u2502 \u251c\u2500\u2500 Sales Orders\n\u2502 \u2502 \u251c\u2500\u2500 Inventory Reports\n\u2502 \u2502 \u2514\u2500\u2500 Sales Reports\n\u251c\u2500\u2500 HR\n\u2502 \u251c\u2500\u2500 Employee Data\n\u2502 \u2502 \u251c\u2500\u2500 Employee Handbook\n\u2502 \u2502 \u2514\u2500\u2500 Evaluation Forms\n\u2502 \u2514\u2500\u2500 Leave Requests\n\u2514\u2500\u2500 Admissions\n \u251c\u2500\u2500 Ad"
3
+ }
tree_synthesizer.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tree_synthesizer.py
2
+
3
+ import re
4
+ from sentence_transformers import SentenceTransformer, util
5
+
6
+ model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
7
+
8
# Generic process-step vocabulary: select_relevant_steps embeds these and
# picks the top-k entries most similar to the user's prompt.
COMMON_STEPS = [
    "Requirements Gathering", "Design", "Development", "Testing", "Deployment",
    "Monitoring", "Approval", "Rejection", "Review", "Validation", "Planning",
    "Execution", "Documentation", "Evaluation", "Feedback", "Support", "Training",
    "Payment", "Notification", "Registration", "Submission", "Completion"
]
14
+
15
def normalize_prompt(text):
    """Reduce *text* to trimmed, lowercase alphanumerics and spaces."""
    kept = re.sub(r"[^a-zA-Z0-9\s]", "", text.lower())
    return kept.strip()
17
+
18
def extract_domain(prompt):
    """Strip generic workflow words from *prompt*, leaving the domain
    nouns (e.g. "workflow of inventory tracking" -> "inventory tracking")."""
    filler = {"workflow", "steps", "process", "structure", "of", "in", "the", "how", "does"}
    meaningful = []
    for token in normalize_prompt(prompt).split():
        if token not in filler:
            meaningful.append(token)
    return " ".join(meaningful)
22
+
23
def select_relevant_steps(prompt, top_k=5):
    """Return the *top_k* COMMON_STEPS most semantically similar to *prompt*."""
    step_vecs = model.encode(COMMON_STEPS, convert_to_tensor=True)
    prompt_vec = model.encode(prompt, convert_to_tensor=True)
    similarity = util.cos_sim(prompt_vec, step_vecs)[0]
    ranked = similarity.argsort(descending=True)[:top_k]
    return [COMMON_STEPS[int(idx)] for idx in ranked]
29
+
30
def format_step_tree(title, steps):
    """Render *steps* as a flat tree under the title-cased *title*."""
    lines = [title.title()]
    for position, step in enumerate(steps):
        marker = "└──" if position == len(steps) - 1 else "β”œβ”€β”€"
        lines.append(f"{marker} {step}")
    return "\n".join(lines)
36
+
37
def is_definition_prompt(prompt):
    """True when *prompt* asks for a definition/explanation rather than a
    workflow, so the tree synthesizer should stay silent."""
    text = normalize_prompt(prompt)
    asks_definition = (
        text.startswith(("what is", "define"))
        or "difference between" in text
    )
    if asks_definition:
        return True
    # Expanded direct checks for SQL/PLSQL terms even in short form
    direct_terms = [
        "sql", "plsql", "pl sql", "oracle sql", "commit", "rollback", "group by", "function"
    ]
    return any(term in text for term in direct_terms)
46
+
47
def is_schema_design_prompt(prompt):
    """True when *prompt* is about tables/forms/schema design, which the
    SQL generators handle instead of the workflow-tree path."""
    design_keywords = (
        "table", "tables", "schema", "structure", "data model", "sql tables",
        "entity", "forms", "reports", "design", "layout"
    )
    text = normalize_prompt(prompt)
    return any(keyword in text for keyword in design_keywords)
54
+
55
+
56
def generate_step_tree(prompt):
    """Synthesize a workflow tree for *prompt*, or bail out early.

    Returns:
        * a greeting for short "hi"/"hello"-style prompts,
        * "" for prompts that are too short, definitional, or about
          schema design (other components handle those), and
        * otherwise "<Domain> Workflow" with the top matching steps.
    """
    text = normalize_prompt(prompt)

    # Short greetings get a friendly nudge instead of a tree.
    if len(text.split()) <= 2 and any(greet in text for greet in {"hi", "hello", "hey"}):
        return "πŸ‘‹ Hi there! Need help with Oracle SQL or PL/SQL?"

    # Plain "what is ..." questions without workflow wording are
    # definitions, not processes.
    if text.startswith("what is") and not any(term in text for term in {"workflow", "flow", "structure", "steps"}):
        return ""

    # The original re-normalized the prompt into a second variable here;
    # the value is identical to `text`, so `text` is reused directly.
    if len(text.split()) <= 2:
        return ""

    if is_definition_prompt(prompt):
        return ""

    if is_schema_design_prompt(prompt):
        return ""

    domain = extract_domain(prompt).title() or "Process"
    steps = select_relevant_steps(prompt)
    return format_step_tree(f"{domain} Workflow", steps)