Shashwat-18 committed
Commit 3a15a02 · verified · 1 Parent(s): 85bc6f6

Upload 13 files

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/olist_analytics.db filter=lfs diff=lfs merge=lfs -text
+data/olist_geolocation_dataset.csv filter=lfs diff=lfs merge=lfs -text
+data/olist_order_items_dataset.csv filter=lfs diff=lfs merge=lfs -text
+data/olist_order_reviews_dataset.csv filter=lfs diff=lfs merge=lfs -text
+data/olist_orders_dataset.csv filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,2652 @@
+# %%
+import os
+import uuid
+import sqlite3
+import datetime
+import json
+import re
+import time
+from typing import Dict, Any, List, Tuple, Optional
+import hashlib
+
+import pandas as pd
+from groq import Groq
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+# %%
+# 0. Global config
+DB_PATH = "olist.db"
+DATA_DIR = "data"
+
+# Groq Model
+GROQ_MODEL_NAME = "llama-3.3-70b-versatile"
+
+GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+if not GROQ_API_KEY:
+    raise ValueError("GROQ_API_KEY not found.")
+
+groq_client = Groq(api_key=GROQ_API_KEY)
+
+# Embedding model
+EMBED_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
+
+# Logging paths
+PROMPT_LOG_PATH = "prompt_logs.txt"
+RUN_LOG_PATH = "run_logs.txt"
+
+# %%
+
+import os
+import chromadb
+from chromadb.config import Settings
+
+CHROMA_PERSIST_DIR = os.path.abspath("./chroma_data")
+os.makedirs(CHROMA_PERSIST_DIR, exist_ok=True)
+
+_CHROMA_CLIENT = None
+
+def get_chroma_client():
+    global _CHROMA_CLIENT
+    if _CHROMA_CLIENT is None:
+        _CHROMA_CLIENT = chromadb.PersistentClient(
+            path=CHROMA_PERSIST_DIR,
+            settings=Settings(
+                anonymized_telemetry=False,
+            ),
+        )
+    return _CHROMA_CLIENT
+
+
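The lazy module-level singleton used by `get_chroma_client` is worth noting for reviewers: every caller shares one `PersistentClient`, so the store is opened once per process. A minimal sketch of the same pattern, with a stand-in `FakeClient` class (hypothetical, used here only so the sketch runs without chromadb):

```python
# Lazy module-level singleton: pay construction cost once, reuse afterwards.
class FakeClient:
    """Stand-in for chromadb.PersistentClient (hypothetical)."""
    def __init__(self, path: str):
        self.path = path

_CLIENT = None

def get_client(path: str = "./chroma_data"):
    global _CLIENT
    if _CLIENT is None:  # only the first call constructs the client
        _CLIENT = FakeClient(path)
    return _CLIENT
```

Repeated calls return the identical object, which is what makes it safe to call `get_chroma_client()` from several stores below.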
+# %%
+
+def log_prompt(tag: str, prompt: str) -> None:
+    """
+    Append the full prompt to a log file, with a tag and timestamp.
+    """
+    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+    header = f"\n\n================ {tag} @ {timestamp} ================\n"
+    with open(PROMPT_LOG_PATH, "a", encoding="utf-8") as f:
+        f.write(header)
+        f.write(prompt)
+        f.write("\n================ END PROMPT ================\n")
+
+
+def log_run_event(tag: str, content: str) -> None:
+    """
+    Append model response, final SQL, and error info into a run log.
+    """
+    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+    header = f"\n\n================ {tag} @ {timestamp} ================\n"
+    with open(RUN_LOG_PATH, "a", encoding="utf-8") as f:
+        f.write(header)
+        f.write(content)
+        f.write("\n================ END EVENT ================\n")
+
+
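Both loggers share the same append-only, tag-delimited format. A standalone sketch of that pattern (temp file and `append_log` name are illustrative, not part of the commit):

```python
import datetime
import os
import tempfile

def append_log(path: str, tag: str, body: str) -> None:
    # Same append-only pattern as log_prompt / log_run_event above:
    # tagged, timestamped header followed by the raw payload.
    ts = datetime.datetime.now().isoformat(timespec="seconds")
    with open(path, "a", encoding="utf-8") as f:
        f.write(f"\n================ {tag} @ {ts} ================\n")
        f.write(body)
        f.write("\n================ END ================\n")

log_path = os.path.join(tempfile.mkdtemp(), "demo.log")
append_log(log_path, "SQL_GEN", "SELECT 1;")
append_log(log_path, "SQL_GEN", "SELECT 2;")
with open(log_path, encoding="utf-8") as f:
    text = f.read()
```

Appending (mode `"a"`) means earlier events are never overwritten, which keeps the prompt/run logs usable as an audit trail across app restarts.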
+# %%
+from pathlib import Path
+
+BASE_DIR = Path.cwd()
+
+ANALYTICS_DB_PATH = BASE_DIR / "data" / "olist_analytics.db"
+ANALYTICS_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+CHAT_DB_PATH = BASE_DIR / "chat_data" / "chat_history.db"
+CHAT_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+
+# %%
+
+# %%
+
+def get_analytics_conn() -> sqlite3.Connection:
+    return sqlite3.connect(
+        ANALYTICS_DB_PATH,
+        check_same_thread=False
+    )
+
+def get_chat_conn() -> sqlite3.Connection:
+    conn = sqlite3.connect(CHAT_DB_PATH)
+    conn.row_factory = sqlite3.Row
+    conn.execute("PRAGMA foreign_keys = ON;")
+    return conn
+
+
+# %%
+# ==============================
+# Chat message helpers
+# ==============================
+
+def assistant_msg(text: str):
+    return (None, text)
+
+def user_msg(text: str):
+    return (text, None)
+
+def add_assistant_state(history, text):
+    history.append(assistant_msg(text))
+    return history
+
+
+# %%
+
+
+def df_to_state(df: pd.DataFrame) -> dict:
+    if df is None or not isinstance(df, pd.DataFrame) or df.empty:
+        return {"columns": [], "rows": []}
+
+    cols = list(df.columns)
+    rows = df.astype(str).values.tolist()
+
+    if not rows:
+        return {"columns": [], "rows": []}
+
+    return {"columns": cols, "rows": rows}
+
+
+
+def state_to_df(state: dict) -> pd.DataFrame:
+    if not isinstance(state, dict):
+        return pd.DataFrame()
+
+    cols = state.get("columns")
+    rows = state.get("rows")
+
+    if not cols or rows is None:
+        return pd.DataFrame()
+
+    return pd.DataFrame(rows, columns=cols)
+
+
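One behavior of the `df_to_state` / `state_to_df` pair worth pinning down: the round trip is lossy with respect to dtypes, because `df.astype(str)` stringifies every cell before storage. A minimal standalone sketch (the DataFrame contents are made up; the functions mirror the helpers above):

```python
import pandas as pd

def df_to_state(df):
    # Mirrors the helper above: stringify cells for a JSON-safe snapshot.
    if df is None or not isinstance(df, pd.DataFrame) or df.empty:
        return {"columns": [], "rows": []}
    return {"columns": list(df.columns), "rows": df.astype(str).values.tolist()}

def state_to_df(state):
    if not isinstance(state, dict) or not state.get("columns") or state.get("rows") is None:
        return pd.DataFrame()
    return pd.DataFrame(state["rows"], columns=state["columns"])

df = pd.DataFrame({"order_id": ["a1", "b2"], "price": [10.5, 20.0]})
state = df_to_state(df)
restored = state_to_df(state)
# restored["price"] now holds strings ("10.5", "20.0"), not floats
```

If numeric dtypes matter downstream, callers would need to re-cast after `state_to_df`; as written the snapshot is purely for display/persistence.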
+# %%
+
+def init_feedback_table(conn: sqlite3.Connection) -> None:
+    """
+    Create (or upgrade) a table to capture user feedback on model answers.
+    """
+    conn.execute("""
+        CREATE TABLE IF NOT EXISTS user_feedback (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            created_at TEXT NOT NULL,
+            question TEXT NOT NULL,
+            generated_sql TEXT,
+            model_answer TEXT,
+            rating TEXT CHECK(rating IN ('good','bad')) NOT NULL,
+            comment TEXT,
+            corrected_sql TEXT
+        )
+    """)
+    conn.commit()
+
+
+def record_feedback(
+    conn: sqlite3.Connection,
+    question: str,
+    generated_sql: str,
+    model_answer: str,
+    rating: str,  # "good" or "bad"
+    comment: Optional[str] = None,
+    corrected_sql: Optional[str] = None,
+) -> None:
+    """
+    Store user feedback about a particular model answer / SQL query.
+    If corrected_sql is provided, it is treated as an external correction.
+    """
+    rating = rating.lower()
+    if rating not in ("good", "bad"):
+        raise ValueError("rating must be 'good' or 'bad'")
+
+    ts = datetime.datetime.now().isoformat(timespec="seconds")
+    conn.execute(
+        """
+        INSERT INTO user_feedback (
+            created_at, question, generated_sql, model_answer,
+            rating, comment, corrected_sql
+        )
+        VALUES (?, ?, ?, ?, ?, ?, ?)
+        """,
+        (ts, question, generated_sql, model_answer, rating, comment, corrected_sql),
+    )
+    conn.commit()
+
+
+def get_last_feedback_for_question(
+    conn: sqlite3.Connection,
+    question: str,
+) -> Optional[Dict[str, Any]]:
+    """
+    Return the most recent feedback row for this question (if any).
+    """
+    cur = conn.cursor()
+    cur.execute(
+        """
+        SELECT created_at, generated_sql, model_answer,
+               rating, comment, corrected_sql
+        FROM user_feedback
+        WHERE question = ?
+        ORDER BY created_at DESC
+        LIMIT 1
+        """,
+        (question,),
+    )
+    row = cur.fetchone()
+    if not row:
+        return None
+
+    return {
+        "created_at": row[0],
+        "generated_sql": row[1],
+        "model_answer": row[2],
+        "rating": row[3],
+        "comment": row[4],
+        "corrected_sql": row[5],
+    }
+
+
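The `CHECK(rating IN ('good','bad'))` constraint means invalid ratings are rejected at the database level, not just by the Python guard in `record_feedback`. A small in-memory sketch of both layers (trimmed table, made-up question text):

```python
import datetime
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE user_feedback (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        created_at TEXT NOT NULL,
        question TEXT NOT NULL,
        rating TEXT CHECK(rating IN ('good','bad')) NOT NULL
    )
""")
ts = datetime.datetime.now().isoformat(timespec="seconds")
conn.execute(
    "INSERT INTO user_feedback (created_at, question, rating) VALUES (?, ?, ?)",
    (ts, "top sellers by revenue?", "good"),
)
# The CHECK constraint rejects anything outside {'good', 'bad'}.
try:
    conn.execute(
        "INSERT INTO user_feedback (created_at, question, rating) VALUES (?, ?, ?)",
        (ts, "some question", "meh"),
    )
    rejected = False
except sqlite3.IntegrityError:
    rejected = True
```

Defense in depth: even if a caller bypasses `record_feedback`, bad ratings cannot reach the table.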
+# %%
+def init_db() -> None:
+    with get_analytics_conn() as conn:
+        csv_files = [
+            "olist_customers_dataset.csv",
+            "olist_orders_dataset.csv",
+            "olist_order_items_dataset.csv",
+            "olist_products_dataset.csv",
+            "olist_order_reviews_dataset.csv",
+            "olist_order_payments_dataset.csv",
+            "product_category_name_translation.csv",
+            "olist_sellers_dataset.csv",
+            "olist_geolocation_dataset.csv",
+        ]
+
+        for fname in csv_files:
+            path = os.path.join(DATA_DIR, fname)
+            if not os.path.exists(path):
+                continue
+
+            table_name = os.path.splitext(fname)[0]
+            df = pd.read_csv(path)
+            df.to_sql(table_name, conn, if_exists="replace", index=False)
+
+        init_feedback_table(conn)
+
+
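Because `init_db` uses `if_exists="replace"`, re-running it is idempotent: each CSV's table is dropped and recreated rather than appended to, so rows never duplicate across restarts. A self-contained sketch with a made-up two-row frame:

```python
import sqlite3
import pandas as pd

conn = sqlite3.connect(":memory:")
df = pd.DataFrame({
    "order_id": ["o1", "o2"],
    "price": [49.9, 120.0],
})
# if_exists="replace" drops and recreates the table, as in init_db above.
df.to_sql("olist_order_items_dataset", conn, if_exists="replace", index=False)
# A second load does not double the row count.
df.to_sql("olist_order_items_dataset", conn, if_exists="replace", index=False)
n = conn.execute("SELECT COUNT(*) FROM olist_order_items_dataset").fetchone()[0]
```

With `if_exists="append"` instead, the second call would leave four rows; `"replace"` is what makes `init_db()` safe to call unconditionally at import time.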
+# %%
+# 4. Manual docs for Olist tables
+
+OLIST_DOCS: Dict[str, Dict[str, Any]] = {
+    "olist_customers_dataset": {
+        "description": "Customer master data, one row per customer_id (which can change over time for the same end-user).",
+        "columns": {
+            "customer_id": "Primary key for this table. Unique technical identifier for a customer at a point in time. Used to join with olist_orders_dataset.customer_id.",
+            "customer_unique_id": "Stable unique identifier for the end-user. A single customer_unique_id can map to multiple customer_id records over time.",
+            "customer_zip_code_prefix": "Customer ZIP/postal code prefix. Used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
+            "customer_city": "Customer's city as captured at the time of the order or registration.",
+            "customer_state": "Customer's state (two-letter Brazilian state code, e.g. SP, RJ)."
+        }
+    },
+    "olist_orders_dataset": {
+        "description": "Customer orders placed on the Olist marketplace, one row per order.",
+        "columns": {
+            "order_id": "Primary key. Unique identifier for each order. Used to join with items, payments, and reviews.",
+            "customer_id": "Foreign key to olist_customers_dataset.customer_id indicating who placed the order.",
+            "order_status": "Current lifecycle status of the order (e.g. created, shipped, delivered, canceled, unavailable).",
+            "order_purchase_timestamp": "Timestamp when the customer completed the purchase (event time for order placement).",
+            "order_approved_at": "Timestamp when the payment was approved by the system or financial gateway.",
+            "order_delivered_carrier_date": "Timestamp when the order was handed over by the seller to the carrier/logistics provider.",
+            "order_delivered_customer_date": "Timestamp when the carrier reported the order as delivered to the final customer.",
+            "order_estimated_delivery_date": "Estimated delivery date promised to the customer at checkout."
+        }
+    },
+    "olist_order_items_dataset": {
+        "description": "Order line items, one row per product per order.",
+        "columns": {
+            "order_id": "Foreign key to olist_orders_dataset.order_id. Multiple order_items can belong to the same order.",
+            "order_item_id": "Sequential item number within an order (1, 2, 3, ...). Uniquely identifies a line inside an order.",
+            "product_id": "Foreign key to olist_products_dataset.product_id representing the purchased product.",
+            "seller_id": "Foreign key to olist_sellers_dataset.seller_id representing the seller that fulfilled this item.",
+            "shipping_limit_date": "Deadline for the seller to hand the item over to the carrier for shipping.",
+            "price": "Item price paid by the customer for this line (in BRL, not including freight).",
+            "freight_value": "Freight (shipping) cost attributed to this line item (in BRL)."
+        }
+    },
+    "olist_products_dataset": {
+        "description": "Product catalog with physical and category attributes, one row per product.",
+        "columns": {
+            "product_id": "Primary key. Unique identifier for each product. Used to join with olist_order_items_dataset.product_id.",
+            "product_category_name": "Product category name in Portuguese. Join to product_category_name_translation.product_category_name for English.",
+            "product_name_lenght": "Number of characters in the product name (field name misspelled as 'lenght' in the original dataset).",
+            "product_description_lenght": "Number of characters in the product description (also misspelled as 'lenght').",
+            "product_photos_qty": "Number of product images associated with the listing.",
+            "product_weight_g": "Product weight in grams.",
+            "product_length_cm": "Product length in centimeters (package dimension).",
+            "product_height_cm": "Product height in centimeters (package dimension).",
+            "product_width_cm": "Product width in centimeters (package dimension)."
+        }
+    },
+    "olist_order_reviews_dataset": {
+        "description": "Post-purchase customer reviews and satisfaction scores, one row per review.",
+        "columns": {
+            "review_id": "Primary key. Unique identifier for each review record.",
+            "order_id": "Foreign key to olist_orders_dataset.order_id for the reviewed order.",
+            "review_score": "Star rating given by the customer on a 1–5 scale (5 = very satisfied, 1 = very dissatisfied).",
+            "review_comment_title": "Optional short text title or summary of the review.",
+            "review_comment_message": "Optional detailed free-text comment describing the customer experience.",
+            "review_creation_date": "Date when the customer created the review.",
+            "review_answer_timestamp": "Timestamp when Olist or the seller responded to the review (if applicable)."
+        }
+    },
+    "olist_order_payments_dataset": {
+        "description": "Payments associated with orders, one row per payment record (order can have multiple payments).",
+        "columns": {
+            "order_id": "Foreign key to olist_orders_dataset.order_id.",
+            "payment_sequential": "Sequence number for multiple payments of the same order (1 for first payment, 2 for second, etc.).",
+            "payment_type": "Payment method used (e.g. credit_card, boleto, voucher, debit_card).",
+            "payment_installments": "Number of installments chosen by the customer for this payment.",
+            "payment_value": "Monetary amount paid in this payment record (in BRL)."
+        }
+    },
+    "product_category_name_translation": {
+        "description": "Lookup table mapping Portuguese product category names to English equivalents.",
+        "columns": {
+            "product_category_name": "Product category name in Portuguese as used in olist_products_dataset.",
+            "product_category_name_english": "Translated product category name in English."
+        }
+    },
+    "olist_sellers_dataset": {
+        "description": "Seller master data, one row per seller operating on the Olist marketplace.",
+        "columns": {
+            "seller_id": "Primary key. Unique identifier for each seller. Used to join with olist_order_items_dataset.seller_id.",
+            "seller_zip_code_prefix": "Seller ZIP/postal code prefix, used to join with olist_geolocation_dataset.geolocation_zip_code_prefix.",
+            "seller_city": "City where the seller is located.",
+            "seller_state": "State where the seller is located (two-letter Brazilian state code)."
+        }
+    },
+    "olist_geolocation_dataset": {
+        "description": "Geolocation reference data for ZIP code prefixes in Brazil, not unique per prefix.",
+        "columns": {
+            "geolocation_zip_code_prefix": "ZIP/postal code prefix, used to link customers and sellers via zip code.",
+            "geolocation_lat": "Latitude coordinate of the location.",
+            "geolocation_lng": "Longitude coordinate of the location.",
+            "geolocation_city": "City name for the location.",
+            "geolocation_state": "State code for the location (two-letter Brazilian state code)."
+        }
+    }
+}
+
+init_db()
+
+# %%
+
+# ----------------------------------
+# Configuration
+# ----------------------------------
+from pathlib import Path
+
+BASE_DIR = Path.cwd()
+DB_PATH = BASE_DIR / "chat_data" / "chat_history.db"
+DB_PATH.parent.mkdir(parents=True, exist_ok=True)
+
+
+# ----------------------------------
+# SQL schema
+# ----------------------------------
+SCHEMA_SQL = """
+-- ===============================
+-- Conversations (chat threads)
+-- ===============================
+CREATE TABLE IF NOT EXISTS conversations (
+    conversation_id TEXT PRIMARY KEY,
+    title TEXT,
+    created_at TEXT NOT NULL,
+    updated_at TEXT NOT NULL
+);
+
+-- ===============================
+-- Messages inside conversations
+-- ===============================
+CREATE TABLE IF NOT EXISTS conversation_messages (
+    message_id INTEGER PRIMARY KEY AUTOINCREMENT,
+    conversation_id TEXT NOT NULL,
+    role TEXT CHECK(role IN ('user', 'assistant')) NOT NULL,
+    content TEXT NOT NULL,
+    created_at TEXT NOT NULL,
+    FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id)
+);
+
+-- ===============================
+-- Snapshot of last known state
+-- ===============================
+CREATE TABLE IF NOT EXISTS conversation_state (
+    conversation_id TEXT PRIMARY KEY,
+    state_json TEXT NOT NULL,
+    updated_at TEXT NOT NULL,
+    FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id)
+);
+"""
+
+# ----------------------------------
+# Execute
+# ----------------------------------
+conn = sqlite3.connect(DB_PATH)
+conn.execute("PRAGMA foreign_keys = ON;")
+conn.executescript(SCHEMA_SQL)
+conn.commit()
+conn.close()
+
+print("Chat history tables initialized at:", DB_PATH.resolve())
+
+
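A subtlety in this schema: SQLite only enforces the `FOREIGN KEY` clauses while `PRAGMA foreign_keys = ON` is active on the connection (it is off by default), which is why `get_chat_conn` above sets it on every connection. A trimmed in-memory sketch showing the enforcement:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON;")  # off by default in SQLite
conn.executescript("""
CREATE TABLE conversations (conversation_id TEXT PRIMARY KEY);
CREATE TABLE conversation_messages (
    message_id INTEGER PRIMARY KEY AUTOINCREMENT,
    conversation_id TEXT NOT NULL,
    content TEXT NOT NULL,
    FOREIGN KEY (conversation_id) REFERENCES conversations(conversation_id)
);
""")
conn.execute("INSERT INTO conversations VALUES ('c1')")
conn.execute(
    "INSERT INTO conversation_messages (conversation_id, content) VALUES (?, ?)",
    ("c1", "hello"),
)
# With the pragma on, an orphan message (no parent conversation) is rejected.
try:
    conn.execute(
        "INSERT INTO conversation_messages (conversation_id, content) VALUES (?, ?)",
        ("missing", "orphan"),
    )
    rejected = False
except sqlite3.IntegrityError:
    rejected = True
```

If a code path ever opens the chat DB without the pragma, orphan messages would silently be accepted, so routing everything through `get_chat_conn` matters.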
+# %%
+
+import json
+import uuid
+from pathlib import Path
+
+DB_PATH = Path("chat_data/chat_history.db")
+
+# ------------------------------
+# Connection helper
+# ------------------------------
+def get_conn():
+    return get_chat_conn()
+
+# ------------------------------
+# Conversation lifecycle
+# ------------------------------
+def create_conversation(title: str | None = None) -> str:
+    conversation_id = uuid.uuid4().hex
+    now = datetime.datetime.utcnow().isoformat()
+
+    if not title:
+        title = "New conversation"
+
+    with get_conn() as conn:
+        conn.execute(
+            """
+            INSERT INTO conversations (conversation_id, title, created_at, updated_at)
+            VALUES (?, ?, ?, ?)
+            """,
+            (conversation_id, title, now, now),
+        )
+
+    return conversation_id
+
+
+def list_conversations(limit: int = 50):
+    with get_conn() as conn:
+        rows = conn.execute(
+            """
+            SELECT conversation_id, title, updated_at
+            FROM conversations
+            ORDER BY updated_at DESC
+            LIMIT ?
+            """,
+            (limit,),
+        ).fetchall()
+
+    return [dict(r) for r in rows]
+
+
+def rename_conversation(conversation_id: str, new_title: str):
+    now = datetime.datetime.utcnow().isoformat()
+    with get_conn() as conn:
+        conn.execute(
+            """
+            UPDATE conversations
+            SET title = ?, updated_at = ?
+            WHERE conversation_id = ?
+            """,
+            (new_title, now, conversation_id),
+        )
+
+# ------------------------------
+# Messages
+# ------------------------------
+def append_message(conversation_id: str, role: str, content: str):
+    now = datetime.datetime.utcnow().isoformat()
+
+    with get_conn() as conn:
+        conn.execute(
+            """
+            INSERT INTO conversation_messages
+                (conversation_id, role, content, created_at)
+            VALUES (?, ?, ?, ?)
+            """,
+            (conversation_id, role, content, now),
+        )
+
+        conn.execute(
+            """
+            UPDATE conversations
+            SET updated_at = ?
+            WHERE conversation_id = ?
+            """,
+            (now, conversation_id),
+        )
+
+
+def load_messages(conversation_id: str):
+    with get_conn() as conn:
+        rows = conn.execute(
+            """
+            SELECT role, content
+            FROM conversation_messages
+            WHERE conversation_id = ?
+            ORDER BY message_id ASC
+            """,
+            (conversation_id,),
+        ).fetchall()
+
+    return [{"role": r["role"], "content": r["content"]} for r in rows]
+
+# ------------------------------
+# State snapshot
+# ------------------------------
+def save_state(conversation_id: str, state: dict):
+    now = datetime.datetime.utcnow().isoformat()
+    payload = json.dumps(state)
+
+    with get_conn() as conn:
+        conn.execute(
+            """
+            INSERT INTO conversation_state (conversation_id, state_json, updated_at)
+            VALUES (?, ?, ?)
+            ON CONFLICT(conversation_id)
+            DO UPDATE SET
+                state_json = excluded.state_json,
+                updated_at = excluded.updated_at
+            """,
+            (conversation_id, payload, now),
+        )
+
+
+def load_state(conversation_id: str) -> dict:
+    with get_conn() as conn:
+        row = conn.execute(
+            """
+            SELECT state_json
+            FROM conversation_state
+            WHERE conversation_id = ?
+            """,
+            (conversation_id,),
+        ).fetchone()
+
+    if not row:
+        return {}
+
+    return json.loads(row["state_json"])
+
+
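`save_state` relies on SQLite's UPSERT (`INSERT ... ON CONFLICT DO UPDATE`, available since SQLite 3.24) so one statement handles both the first snapshot and every later overwrite, keeping exactly one row per conversation. A minimal in-memory sketch of the same statement (made-up payloads and timestamps):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE conversation_state (
        conversation_id TEXT PRIMARY KEY,
        state_json TEXT NOT NULL,
        updated_at TEXT NOT NULL
    )
""")

def save_state(cid: str, payload: str, now: str) -> None:
    # One statement covers both create and overwrite; `excluded` refers
    # to the row that failed to insert because of the PRIMARY KEY conflict.
    conn.execute(
        """
        INSERT INTO conversation_state (conversation_id, state_json, updated_at)
        VALUES (?, ?, ?)
        ON CONFLICT(conversation_id)
        DO UPDATE SET state_json = excluded.state_json,
                      updated_at = excluded.updated_at
        """,
        (cid, payload, now),
    )

save_state("c1", '{"step": 1}', "t1")
save_state("c1", '{"step": 2}', "t2")  # overwrites, does not duplicate
row = conn.execute(
    "SELECT state_json, updated_at FROM conversation_state WHERE conversation_id = 'c1'"
).fetchone()
```

The alternative `INSERT OR REPLACE` would also work here but deletes-and-reinserts the row, which can fire foreign-key cascades; `ON CONFLICT DO UPDATE` updates in place.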
+# %%
+
+def extract_schema_metadata(connection: sqlite3.Connection) -> Dict[str, Any]:
+    """
+    Introspect SQLite tables, columns and foreign key relationships.
+    """
+    cursor = connection.cursor()
+    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+    tables = [row[0] for row in cursor.fetchall()]
+
+    metadata: Dict[str, Any] = {"tables": {}}
+
+    for table in tables:
+        cursor.execute(f"PRAGMA table_info('{table}')")
+        cols = cursor.fetchall()
+
+        # Manual docs for the table
+        table_docs = OLIST_DOCS.get(table, {})
+        table_desc: str = table_docs.get(
+            "description",
+            f"Table '{table}' from Olist dataset."
+        )
+        column_docs: Dict[str, str] = table_docs.get("columns", {})
+
+        columns_meta: Dict[str, str] = {}
+        for c in cols:
+            col_name = c[1]
+            col_type = c[2] or "TEXT"
+
+            if col_name in column_docs:
+                columns_meta[col_name] = column_docs[col_name]
+            else:
+                columns_meta[col_name] = f"Column '{col_name}' of type {col_type}"
+
+        cursor.execute(f"PRAGMA foreign_key_list('{table}')")
+        fk_rows = cursor.fetchall()
+        relationships: List[str] = []
+        for fk in fk_rows:
+            ref_table = fk[2]
+            from_col = fk[3]
+            to_col = fk[4]
+            relationships.append(
+                f"{table}.{from_col} → {ref_table}.{to_col} (foreign key)"
+            )
+
+        metadata["tables"][table] = {
+            "description": table_desc,
+            "columns": columns_meta,
+            "relationships": relationships,
+        }
+
+    return metadata
+
+with get_analytics_conn() as conn:
+    schema_metadata = extract_schema_metadata(conn)
+
+
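The introspection above leans on `PRAGMA table_info`, which returns one row per column in the shape `(cid, name, type, notnull, dflt_value, pk)`; that is why the code reads `c[1]` for the name and `c[2]` for the type. A quick standalone check against a throwaway table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE olist_sellers_dataset (seller_id TEXT PRIMARY KEY, seller_city TEXT)"
)
cols = conn.execute("PRAGMA table_info('olist_sellers_dataset')").fetchall()
# Each row: (cid, name, type, notnull, dflt_value, pk)
names = [c[1] for c in cols]
types = [c[2] for c in cols]
pks = [c[5] for c in cols]
```

Note that tables created via pandas `to_sql` (as in `init_db`) carry no declared foreign keys, so `PRAGMA foreign_key_list` will return nothing for them; the join hints in this app come from `OLIST_DOCS`, not from the physical schema.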
+# %%
+
+def build_schema_yaml(metadata: Dict[str, Any]) -> str:
+    """
+    Render metadata dict into a YAML-style string.
+    """
+    lines: List[str] = ["tables:"]
+    for tname, tinfo in metadata["tables"].items():
+        lines.append(f"  {tname}:")
+        desc = tinfo.get("description", "").replace('"', "'")
+        lines.append(f'    description: "{desc}"')
+        lines.append("    columns:")
+        for col_name, col_desc in tinfo.get("columns", {}).items():
+            col_desc_clean = col_desc.replace('"', "'")
+            lines.append(f'      {col_name}: "{col_desc_clean}"')
+        rels = tinfo.get("relationships", [])
+        if rels:
+            lines.append("    relationships:")
+            for rel in rels:
+                rel_clean = rel.replace('"', "'")
+                lines.append(f'      - "{rel_clean}"')
+    return "\n".join(lines)
+
+
+schema_yaml = build_schema_yaml(schema_metadata)
+
+
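To make the renderer's output concrete, here is the same function run on a one-table metadata dict (table name and docs are made up for the example; the function body mirrors `build_schema_yaml` above):

```python
from typing import Any, Dict, List

def build_schema_yaml(metadata: Dict[str, Any]) -> str:
    # YAML-style rendering: quote descriptions, swap embedded double
    # quotes for single quotes so the output stays parseable.
    lines: List[str] = ["tables:"]
    for tname, tinfo in metadata["tables"].items():
        lines.append(f"  {tname}:")
        desc = tinfo.get("description", "").replace('"', "'")
        lines.append(f'    description: "{desc}"')
        lines.append("    columns:")
        for col, cdesc in tinfo.get("columns", {}).items():
            lines.append(f'      {col}: "{cdesc.replace(chr(34), chr(39))}"')
    return "\n".join(lines)

yaml_text = build_schema_yaml({
    "tables": {
        "olist_sellers_dataset": {
            "description": "Seller master data.",
            "columns": {"seller_id": "Primary key."},
        }
    }
})
```

The quote-swapping is a pragmatic escape hatch rather than full YAML escaping, which is fine here since the string is consumed by an LLM prompt, not a YAML parser.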
+# %%
+# 6. Build schema documents for RAG (taking samples from the table)
+
+def build_schema_documents(
+    schema_metadata: Dict[str, Any],
+    sample_rows: int = 5,
+) -> Tuple[List[str], List[Dict[str, Any]]]:
+    """
+    Build one rich RAG document per table, using schema_metadata.
+    """
+
+    with get_analytics_conn() as connection:
+        cursor = connection.cursor()
+        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
+        tables = [row[0] for row in cursor.fetchall()]
+
+        docs: List[str] = []
+        metadatas: List[Dict[str, Any]] = []
+
+        for table in tables:
+            if table not in schema_metadata["tables"]:
+                continue
+
+            tmeta = schema_metadata["tables"][table]
+            table_desc = tmeta.get("description", "")
+            columns_meta = tmeta.get("columns", {})
+            relationships = tmeta.get("relationships", [])
+
+            cursor.execute(f"PRAGMA table_info('{table}')")
+            cols = cursor.fetchall()
+
+            col_lines = []
+            for c in cols:
+                col_name = c[1]
+                col_type = c[2] or "TEXT"
+                col_desc = columns_meta.get(
+                    col_name,
+                    f"Column '{col_name}' of type {col_type}"
+                )
+                col_lines.append(f"- {col_name} ({col_type}): {col_desc}")
+
+            try:
+                sample_df = pd.read_sql_query(
+                    f"SELECT * FROM '{table}' LIMIT {sample_rows}",
+                    connection,
+                )
+                sample_text = sample_df.to_markdown(index=False)
+            except Exception:
+                sample_text = "(could not fetch sample rows)"
+
+            rel_block = ""
+            if relationships:
+                rel_block = "Relationships:\n" + "\n".join(
+                    f"- {rel}" for rel in relationships
+                ) + "\n"
+
+            doc_text = (
+                f"Table: {table}\n"
+                f"Description: {table_desc}\n\n"
+                f"Columns:\n" + "\n".join(col_lines) + "\n\n"
+                f"{rel_block}\n"
+                f"Example rows:\n{sample_text}\n"
+            )
+
+            docs.append(doc_text)
+            metadatas.append({
+                "doc_type": "table_schema",
+                "table_name": table,
+            })
+
+    return docs, metadatas
+
+
+# Build RAG texts + metadata
+schema_docs, schema_doc_metas = build_schema_documents(schema_metadata)
+
+RAG_TEXTS: List[str] = []
+RAG_METADATAS: List[Dict[str, Any]] = []
+
+# 1) Per-table docs
+RAG_TEXTS.extend(schema_docs)
+RAG_METADATAS.extend(schema_doc_metas)
+
+# 2) Global YAML as a separate doc
+RAG_TEXTS.append("SCHEMA_METADATA_YAML:\n" + schema_yaml)
+RAG_METADATAS.append({"doc_type": "global_schema"})
+
+
+# %%
+def build_store_final():
+    embedding_model = HuggingFaceEmbeddings(
+        model_name=EMBED_MODEL_NAME,
+        encode_kwargs={"normalize_embeddings": True},
+    )
+
+    rag_store = Chroma(
+        client=get_chroma_client(),
+        collection_name="rag_schema_store",
+        embedding_function=embedding_model,
+    )
+
+    if rag_store._collection.count() == 0:
+        rag_store.add_texts(
+            texts=RAG_TEXTS,
+            metadatas=RAG_METADATAS,
+        )
+
+    rag_retriever = rag_store.as_retriever(
+        search_kwargs={"k": 3, "filter": {"doc_type": "table_schema"}}
+    )
+
+    return rag_store, rag_retriever
+
+# Initialize once
+rag_store, rag_retriever = build_store_final()
+
+
788
+ # %%
789
+ from langchain_community.vectorstores import Chroma
790
+ from langchain_community.embeddings import HuggingFaceEmbeddings
791
+
792
+
793
+ _sql_cache_store = None
794
+ _sql_embedding_fn = None
795
+
796
+ SQL_CACHE_COLLECTION = "sql_cache_mpnet"
797
+
798
+ def get_sql_cache_store():
799
+ global _sql_cache_store, _sql_embedding_fn
800
+
801
+ if _sql_cache_store is not None:
802
+ return _sql_cache_store
803
+
804
+ if _sql_embedding_fn is None:
805
+ _sql_embedding_fn = HuggingFaceEmbeddings(
806
+ model_name=EMBED_MODEL_NAME,
807
+ encode_kwargs={"normalize_embeddings": True},
808
+ )
809
+
810
+ _sql_cache_store = Chroma(
811
+ client=get_chroma_client(),
812
+ collection_name=SQL_CACHE_COLLECTION,
813
+ embedding_function=_sql_embedding_fn,
814
+ )
815
+
816
+ return _sql_cache_store
817
+
818
+
819
+ # %%
820
+ def sanitize_metadata_for_chroma(metadata: dict) -> dict:
821
+ safe = {}
822
+ for k, v in (metadata or {}).items():
823
+ if v is None:
824
+ continue
825
+ if isinstance(v, (int, float, bool)):
826
+ safe[k] = v
827
+ elif isinstance(v, str):
828
+ safe[k] = v
829
+ else:
830
+ safe[k] = str(v)
831
+ return safe
832
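Chroma only accepts scalar metadata values (str, int, float, bool). As a quick, standalone sketch of what this helper does (re-implemented here so the snippet runs on its own):

```python
# Condensed mirror of sanitize_metadata_for_chroma above: None values are
# dropped, scalars pass through, everything else is stringified.
def sanitize(metadata: dict) -> dict:
    safe = {}
    for k, v in (metadata or {}).items():
        if v is None:
            continue
        safe[k] = v if isinstance(v, (int, float, bool, str)) else str(v)
    return safe

md = {"views": 3, "tags": ["sql", "cache"], "note": None, "ok": True}
print(sanitize(md))  # {'views': 3, 'tags': "['sql', 'cache']", 'ok': True}
```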
+
833
+
834
+ # %%
835
+ def normalize_question_strict(q: str) -> str:
836
+ """
837
+ Deterministic normalization for exact cache hits.
838
+ """
839
+ if not q:
840
+ return ""
841
+ q = q.lower().strip()
842
+ q = re.sub(r"[^\w\s]", "", q)
843
+ q = re.sub(r"\s+", " ", q)
844
+ return q
845
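For illustration, the same normalization re-implemented standalone, showing how two phrasings that differ only in case, punctuation, and spacing collapse to one exact-match cache key:

```python
import re

def normalize_question_strict(q: str) -> str:
    # Mirror of the helper above: lowercase, strip punctuation, collapse whitespace.
    if not q:
        return ""
    q = q.lower().strip()
    q = re.sub(r"[^\w\s]", "", q)
    q = re.sub(r"\s+", " ", q)
    return q

a = normalize_question_strict("How many orders were placed in 2018?")
b = normalize_question_strict("  how many orders were placed in 2018 ")
print(a)       # how many orders were placed in 2018
print(a == b)  # True
```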
+
846
+
847
+ # %%
848
+ # Helper: normalize question
849
+
850
+ def normalize_question_text(q: str) -> str:
851
+ if not q:
852
+ return ""
853
+ q = q.strip().lower()
854
+ q = re.sub(r"[^\w\s]", " ", q)
855
+ q = re.sub(r"\s+", " ", q).strip()
856
+ return q
857
+
858
+ # Helper: compute success rate
859
+ def compute_success_rate(md: dict) -> float:
860
+ sc = md.get("success_count", 0) or 0
861
+ tf = md.get("total_feedbacks", 0) or 0
862
+ if tf <= 0:
863
+ return 0.0
864
+ return float(sc) / float(tf)
865
+
866
+ # Insert initial cache entry (no feedback yet)
867
+ def cache_sql_answer_initial(question: str, sql: str, answer_md: str, store=None, extra_metadata: dict = None):
868
+ """
869
+ Insert a cached entry when you run a query and want to cache it regardless of feedback.
870
+ Initial metrics: views=1, success_count=0, total_feedbacks=0, success_rate=0.5.
871
+ """
872
+ if store is None:
873
+ store = get_sql_cache_store()
874
+
875
+ ident = uuid.uuid4().hex
876
+ norm = normalize_question_text(question)
877
+ md = {
878
+ "id": ident,
879
+ "normalized_question": norm,
880
+ "sql": sql,
881
+ "answer_md": answer_md,
882
+ "saved_at": time.time(),
883
+ "views": 1,
884
+ "success_count": 0,
885
+ "total_feedbacks": 0,
886
+ "success_rate": 0.5,
887
+ }
888
+ if extra_metadata:
889
+ md.update(extra_metadata)
890
+
891
+ # Persist via the store's add_texts API (no explicit id, so duplicates are possible).
892
+ store.add_texts([question], metadatas=[md])
893
+ return ident
894
+
895
+
896
+
897
+ # %%
898
+ import time
899
+ import logging
900
+ from typing import Optional, List, Dict, Any
901
+ from difflib import SequenceMatcher
902
+
903
+ _logger = logging.getLogger(__name__)
904
+
905
+ # -------------------------------------------------------------------------
906
+ # Utility helpers
907
+ # -------------------------------------------------------------------------
908
+
909
+ def _now_ts() -> float:
910
+ return time.time()
911
+
912
+ def similarity_score(a: str, b: str) -> float:
913
+ return SequenceMatcher(None, (a or ""), (b or "")).ratio()
914
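`SequenceMatcher.ratio()` returns a value in [0, 1], so identical strings score 1.0 and near-duplicates score close to it. A standalone illustration (the example questions are made up):

```python
from difflib import SequenceMatcher

def similarity_score(a: str, b: str) -> float:
    # Mirror of the helper above.
    return SequenceMatcher(None, (a or ""), (b or "")).ratio()

print(similarity_score("count orders in 2018", "count orders in 2018"))  # 1.0
# One differing character out of 20 -> ratio = 2*19 / 40 = 0.95
print(similarity_score("count orders in 2018", "count orders in 2017"))  # 0.95
```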
+
915
+ # -------------------------------------------------------------------------
916
+ # Persist helper
917
+ # -------------------------------------------------------------------------
918
+ def langchain_upsert(
919
+ store,
920
+ text: str,
921
+ metadata: dict,
922
+ cache_id: str,
923
+ ):
924
+ safe_md = sanitize_metadata_for_chroma(metadata)
925
+
926
+ try:
927
+ store.add_texts(
928
+ texts=[text],
929
+ metadatas=[safe_md],
930
+ ids=[cache_id],
931
+ )
932
+ except TypeError:
933
+ store.add_texts(
934
+ texts=[text],
935
+ metadatas=[safe_md],
936
+ )
937
+
938
+
939
+ # -------------------------------------------------------------------------
940
+ # Cache insert / update
941
+ # -------------------------------------------------------------------------
942
+ def cache_sql_answer_dedup(
943
+ question: str,
944
+ sql: str,
945
+ answer_md: str,
946
+ metadata: dict,
947
+ store,
948
+ ):
949
+ norm_q_semantic = normalize_text(question)
950
+ norm_q_exact = normalize_question_strict(question)
951
+
952
+ cache_id = generate_cache_id(question, sql)
953
+ now = _now_ts()
954
+
955
+ md = {
956
+ # identity
957
+ "cache_id": cache_id,
958
+ "sql": sql,
959
+ "answer_md": answer_md,
960
+
961
+ # exact match key
962
+ "normalized_question": norm_q_exact,
963
+
964
+ # timestamps
965
+ "saved_at": metadata.get("saved_at", now),
966
+ "last_updated_at": now,
967
+ "last_viewed_at": metadata.get("last_viewed_at", 0),
968
+
969
+ # metrics
970
+ "good_count": metadata.get("good_count", 0),
971
+ "bad_count": metadata.get("bad_count", 0),
972
+ "total_feedbacks": metadata.get("total_feedbacks", 0),
973
+ "success_rate": metadata.get("success_rate", 0.5),
974
+ "views": metadata.get("views", 0),
975
+ }
976
+
977
+ langchain_upsert(
978
+ store=store,
979
+ text=norm_q_semantic, # semantic vector
980
+ metadata=md,
981
+ cache_id=cache_id,
982
+ )
983
+
984
+ return {
985
+ "question": question,
986
+ "sql": sql,
987
+ "answer_md": answer_md,
988
+ "metadata": md,
989
+ }
990
+
991
+ # -------------------------------------------------------------------------
992
+ # Find exact cached entry (question + SQL)
993
+ # -------------------------------------------------------------------------
994
+
995
+ def find_cached_doc_by_sql(question: str, sql: str, store):
996
+ cache_id = generate_cache_id(question, sql)
997
+ coll = getattr(store, "_collection", None)
998
+
999
+ if coll and hasattr(coll, "get"):
1000
+ try:
1001
+ res = coll.get(ids=[cache_id])
1002
+ if res and res.get("metadatas"):
1003
+ md = res["metadatas"][0]
1004
+ return {
1005
+ "id": cache_id,
1006
+ "question": question,
1007
+ "sql": md.get("sql"),
1008
+ "answer_md": md.get("answer_md"),
1009
+ "metadata": md,
1010
+ }
1011
+ except Exception:
1012
+ pass
1013
+
1014
+ return None
1015
+
1016
+ # -------------------------------------------------------------------------
1017
+ # Retrieve cached answers ranked primarily by success rate
1018
+ # -------------------------------------------------------------------------
1019
+
1020
+ import re
1021
+ import unicodedata
1022
+
1023
+
1024
+ def normalize_text(text: str) -> str:
1025
+ if not text:
1026
+ return ""
1027
+ text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
1028
+ text = text.lower()
1029
+ text = re.sub(r"[^\w\s]", " ", text)
1030
+ text = re.sub(r"\s+", " ", text).strip()
1031
+ return text
1032
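A standalone illustration of this accent-stripping normalization (the sample string is made up; since Olist is Brazilian data, NFKD + ASCII folding matters for city names):

```python
import re
import unicodedata

def normalize_text(text: str) -> str:
    # Mirror of the helper above: fold accents, lowercase, drop punctuation.
    if not text:
        return ""
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = re.sub(r"[^\w\s]", " ", text.lower())
    return re.sub(r"\s+", " ", text).strip()

print(normalize_text("Orders from São Paulo, 2018!"))  # orders from sao paulo 2018
```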
+
1033
+ import hashlib
1034
+
1035
+ def generate_cache_id(question: str, sql: str) -> str:
1036
+ q = normalize_text(question)
1037
+ s = (sql or "").strip()
1038
+ key = f"{q}||{s}".encode("utf-8")
1039
+ return hashlib.sha1(key).hexdigest()
1040
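Because the cache id is a SHA-1 over the normalized question plus the raw SQL, cosmetic variants of the same question map to the same entry, while a different SQL creates a new one. A standalone sketch (helpers re-implemented to keep it self-contained):

```python
import hashlib
import re
import unicodedata

def normalize_text(text: str) -> str:
    # Mirror of the helper above.
    if not text:
        return ""
    text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode("ascii")
    text = re.sub(r"[^\w\s]", " ", text.lower())
    return re.sub(r"\s+", " ", text).strip()

def generate_cache_id(question: str, sql: str) -> str:
    # Mirror of the helper above: stable SHA-1 of normalized question + SQL.
    key = f"{normalize_text(question)}||{(sql or '').strip()}".encode("utf-8")
    return hashlib.sha1(key).hexdigest()

sql = "SELECT COUNT(*) FROM orders"
# Case/punctuation variants of the question collapse to the same id...
assert generate_cache_id("How many orders?", sql) == generate_cache_id("how many ORDERS", sql)
# ...while a different SQL yields a distinct cache entry.
assert generate_cache_id("How many orders?", sql) != generate_cache_id("How many orders?", sql + " WHERE 1=1")
```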
+
1041
+
1042
+ def rank_cached_candidates(candidates: list[dict]) -> dict:
1043
+ """
1044
+ Deterministic quality-first ranking.
1045
+ """
1046
+ candidates.sort(
1047
+ key=lambda c: (
1048
+ -float(c["metadata"].get("success_rate", 0.5)),
1049
+ -int(c["metadata"].get("good_count", 0)),
1050
+ int(c["metadata"].get("bad_count", 0)),
1051
+ -float(c["metadata"].get("last_updated_at", 0)),
1052
+ )
1053
+ )
1054
+ return candidates[0]
1055
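A toy demonstration of this quality-first ordering (candidate metadata values are made up). Higher success rate wins first; among ties, more good votes, then fewer bad votes, then the most recent update:

```python
# Same sort key as rank_cached_candidates above, applied to mock candidates.
candidates = [
    {"sql": "A", "metadata": {"success_rate": 0.6, "good_count": 3, "bad_count": 2, "last_updated_at": 100}},
    {"sql": "B", "metadata": {"success_rate": 0.9, "good_count": 9, "bad_count": 1, "last_updated_at": 50}},
    {"sql": "C", "metadata": {"success_rate": 0.9, "good_count": 4, "bad_count": 1, "last_updated_at": 200}},
]
candidates.sort(
    key=lambda c: (
        -float(c["metadata"].get("success_rate", 0.5)),
        -int(c["metadata"].get("good_count", 0)),
        int(c["metadata"].get("bad_count", 0)),
        -float(c["metadata"].get("last_updated_at", 0)),
    )
)
# B and C tie on success_rate, but B has more good votes.
print(candidates[0]["sql"])  # B
```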
+
1056
+
1057
+
1058
+ def retrieve_exact_cached_sql(question: str, store):
1059
+ """
1060
+ Exact question match, but still quality-ranked.
1061
+ """
1062
+ norm_q = normalize_question_strict(question)
1063
+ coll = store._collection
1064
+
1065
+ try:
1066
+ res = coll.get(where={"normalized_question": norm_q})
1067
+ if not res or not res.get("metadatas"):
1068
+ return None
1069
+
1070
+ candidates = []
1071
+ for md in res["metadatas"]:
1072
+ candidates.append({
1073
+ "matched_question": question,
1074
+ "sql": md["sql"],
1075
+ "answer_md": md.get("answer_md", ""),
1076
+ "distance": 0.0,
1077
+ "metadata": md,
1078
+ })
1079
+
1080
+ return rank_cached_candidates(candidates)
1081
+
1082
+ except Exception:
1083
+ return None
1084
+
1085
+ def retrieve_best_cached_sql(
1086
+ question: str,
1087
+ store,
1088
+ max_distance: float = 0.25,
1089
+ ):
1090
+ norm_q = normalize_text(question)
1091
+ results = store.similarity_search_with_score(norm_q, k=20)
1092
+
1093
+ candidates = []
1094
+
1095
+ for doc, score in results:
1096
+ distance = float(score)
1097
+ if distance > max_distance:
1098
+ continue
1099
+
1100
+ md = doc.metadata or {}
1101
+ if "sql" not in md:
1102
+ continue
1103
+
1104
+ candidates.append({
1105
+ "matched_question": doc.page_content,
1106
+ "sql": md["sql"],
1107
+ "answer_md": md.get("answer_md", ""),
1108
+ "distance": distance,
1109
+ "metadata": md,
1110
+
1111
+ # ranking signals (FORCED numeric)
1112
+ "success_rate": float(md.get("success_rate", 0.5)),
1113
+ "good": int(md.get("good_count", 0)),
1114
+ "bad": int(md.get("bad_count", 0)),
1115
+ "views": int(md.get("views", 0)),
1116
+ "last_updated": float(md.get("last_updated_at", 0)),
1117
+ })
1118
+
1119
+ if not candidates:
1120
+ return None
1121
+
1122
+ # QUALITY-FIRST RANKING
1123
+ candidates.sort(
1124
+ key=lambda c: (
1125
+ -c["success_rate"],
1126
+ -c["good"],
1127
+ c["bad"],
1128
+ c["distance"],
1129
+ -c["last_updated"],
1130
+ )
1131
+ )
1132
+
1133
+ return candidates[0]
1134
+
1135
+
1136
+
1137
+ # -------------------------------------------------------------------------
1138
+ # Increment views
1139
+ # -------------------------------------------------------------------------
1140
+
1141
+ def increment_cache_views(metadata: dict, store):
1142
+ if not metadata:
1143
+ return False
1144
+
1145
+ cache_id = metadata.get("cache_id")
1146
+ if not cache_id:
1147
+ return False
1148
+
1149
+ md = dict(metadata)
1150
+ md["views"] = int(md.get("views", 0)) + 1
1151
+ md["last_viewed_at"] = _now_ts()
1152
+ md["last_updated_at"] = _now_ts()
1153
+ md["saved_at"] = md.get("saved_at", _now_ts())
1154
+
1155
+ try:
1156
+ langchain_upsert(
1157
+ store=store,
1158
+ text=normalize_text(md.get("normalized_question", "")),
1159
+ metadata=md,
1160
+ cache_id=cache_id,
1161
+ )
1162
+ return True
1163
+ except Exception:
1164
+ _logger.exception("increment_cache_views failed")
1165
+ return False
1166
+
1167
+
1168
+ # Update metrics on feedback
1169
+ def update_cache_on_feedback(
1170
+ question: str,
1171
+ original_doc_md: dict,
1172
+ user_marked_good: bool,
1173
+ llm_corrected_sql: str | None,
1174
+ llm_corrected_answer_md: str | None,
1175
+ store,
1176
+ ):
1177
+ if not original_doc_md:
1178
+ return
1179
+
1180
+ md = dict(original_doc_md["metadata"])
1181
+ cache_id = md["cache_id"]
1182
+
1183
+ # ---- feedback counts ----
1184
+ if user_marked_good:
1185
+ md["good_count"] = md.get("good_count", 0) + 1
1186
+ else:
1187
+ md["bad_count"] = md.get("bad_count", 0) + 1
1188
+
1189
+ md["total_feedbacks"] = md.get("total_feedbacks", 0) + 1
1190
+ md["success_rate"] = (
1191
+ md["good_count"] / md["total_feedbacks"]
1192
+ if md["total_feedbacks"] > 0 else 0.5
1193
+ )
1194
+
1195
+ # ---- timestamps ----
1196
+ md["saved_at"] = md.get("saved_at", _now_ts()) # preserve
1197
+ md["last_updated_at"] = _now_ts()
1198
+
1199
+ langchain_upsert(
1200
+ store=store,
1201
+ text=normalize_text(question),
1202
+ metadata=md,
1203
+ cache_id=cache_id,
1204
+ )
1205
+
1206
+ # -------------------------
1207
+ # Corrected SQL --> NEW ENTRY
1208
+ # -------------------------
1209
+ if llm_corrected_sql and llm_corrected_answer_md:
1210
+ cache_sql_answer_dedup(
1211
+ question=question,
1212
+ sql=llm_corrected_sql,
1213
+ answer_md=llm_corrected_answer_md,
1214
+ metadata={
1215
+ "good_count": 1,
1216
+ "bad_count": 0,
1217
+ "total_feedbacks": 1,
1218
+ "success_rate": 1.0,  # 1 good / 1 total feedback
1219
+ "views": 0,
1220
+ "saved_at": _now_ts(),
1221
+ },
1222
+ store=store,
1223
+ )
1224
+
1225
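The feedback bookkeeping above keeps `success_rate` as the running fraction of good feedbacks (with a neutral 0.5 prior before any feedback exists). A standalone sketch of the arithmetic:

```python
# Start from the neutral prior used for fresh cache entries.
md = {"good_count": 0, "bad_count": 0, "total_feedbacks": 0, "success_rate": 0.5}

def apply_feedback(md: dict, good: bool) -> dict:
    # Same update rule as update_cache_on_feedback above.
    md = dict(md)
    md["good_count" if good else "bad_count"] += 1
    md["total_feedbacks"] += 1
    md["success_rate"] = md["good_count"] / md["total_feedbacks"]
    return md

md = apply_feedback(md, good=True)   # 1 good / 1 total -> 1.0
md = apply_feedback(md, good=False)  # 1 good / 2 total -> 0.5
md = apply_feedback(md, good=True)   # 2 good / 3 total -> 0.666...
print(round(md["success_rate"], 2))  # 0.67
```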
+
1226
+ # %%
1227
+
1228
+ def log_pipeline(event: str, detail: str = ""):
1229
+ """
1230
+ Centralized pipeline logger.
1231
+ Shows in console + can be routed to file later.
1232
+ """
1233
+ ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
1234
+ msg = f"[{ts}] PIPELINE::{event}"
1235
+ if detail:
1236
+ msg += f" | {detail}"
1237
+ print(msg)
1238
+
1239
+ # ### 8. Groq LLM via LangChain
1240
+
1241
+ from langchain_groq import ChatGroq
1242
+ import re
1243
+ import gradio as gr
1244
+
1245
+ llm = ChatGroq(model=GROQ_MODEL_NAME, groq_api_key=GROQ_API_KEY)
1246
+
1247
+ # %%
1248
+ def get_rag_context(question: str) -> str:
1249
+ """
1250
+ Retrieve the most relevant schema documents for the question.
1251
+ """
1252
+ docs = rag_retriever.invoke(question)
1253
+ return "\n\n---\n\n".join(d.page_content for d in docs)
1254
+
1255
+ def clean_sql(sql: str) -> str:
1256
+ sql = sql.strip()
1257
+ if "```" in sql:
1258
+ sql = sql.replace("```sql", "").replace("```", "").strip()
1259
+ return sql
1260
+
1261
+ def extract_sql_from_markdown(text: str) -> str:
1262
+ """
1263
+ Extract the first ```sql ... ``` block from LLM output.
1264
+ If not found, return the whole text.
1265
+ """
1266
+ match = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
1267
+ if match:
1268
+ return match.group(1).strip()
1269
+ return text.strip()
1270
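A standalone illustration of the fenced-block extraction (the model reply is a made-up example):

```python
import re

def extract_sql_from_markdown(text: str) -> str:
    # Mirror of the helper above: first ```sql fenced block, else the whole text.
    match = re.search(r"```sql(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else text.strip()

reply = "Here is the query:\n```sql\nSELECT COUNT(*) FROM orders;\n```\nHope that helps."
print(extract_sql_from_markdown(reply))  # SELECT COUNT(*) FROM orders;
```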
+
1271
+ def extract_explanation_after_marker(text: str, marker: str = "EXPLANATION:") -> str:
1272
+ """
1273
+ After the given marker, return the rest of the text as explanation.
1274
+ """
1275
+ idx = text.upper().find(marker.upper())
1276
+ if idx == -1:
1277
+ return text.strip()
1278
+ return text[idx + len(marker):].strip()
1279
+
1280
+ # 6b. General / descriptive questions
1281
+
1282
+ GENERAL_DESC_KEYWORDS = [
1283
+ "what is this dataset about",
1284
+ "what is this data about",
1285
+ "describe this dataset",
1286
+ "describe the dataset",
1287
+ "dataset overview",
1288
+ "data overview",
1289
+ "summary of the dataset",
1290
+ "explain this dataset",
1291
+ ]
1292
+
1293
+
1294
+ def is_general_question(question: str) -> bool:
1295
+ """
1296
+ Detect high-level descriptive questions where we should answer
1297
+ directly from schema context instead of generating SQL.
1298
+ """
1299
+ q = question.lower().strip()
1300
+ return any(key in q for key in GENERAL_DESC_KEYWORDS)
1301
+
1302
+
1303
+ def answer_general_question(question: str) -> str:
1304
+ """
1305
+ Use the RAG schema docs to generate a rich, high-level description
1306
+ of the Olist dataset for conceptual questions.
1307
+ """
1308
+ rag_context = get_rag_context(question)
1309
+
1310
+ system_instructions = """
1311
+ You are a data documentation expert.
1312
+
1313
+ You will be given:
1314
+ - Schema documentation for the Olist dataset (tables, descriptions, columns, relationships).
1315
+ - A high-level user question like "what is this dataset about?".
1316
+
1317
+ Your job:
1318
+ - Write a clear, structured overview of the dataset.
1319
+ - Explain the main entities (customers, orders, items, products, sellers, payments, reviews, geolocation).
1320
+ - Mention typical analysis use-cases (delivery performance, customer behavior, seller performance, product/category analysis, etc.).
1321
+ - Target a non-technical person.
1322
+
1323
+ Do NOT write SQL. Answer in Markdown.
1324
+ """
1325
+
1326
+ prompt = (
1327
+ system_instructions
1328
+ + "\n\n=== SCHEMA CONTEXT ===\n"
1329
+ + rag_context
1330
+ + "\n\n=== USER QUESTION ===\n"
1331
+ + question
1332
+ + "\n\nDetailed dataset overview:"
1333
+ )
1334
+
1335
+ log_prompt("GENERAL_DATASET_QUESTION", prompt)
1336
+
1337
+ response = llm.invoke(prompt)
1338
+ log_run_event("RAW_MODEL_RESPONSE_GENERAL_DATASET", response.content)
1339
+
1340
+ return response.content.strip()
1341
+
1342
+ # %%
1343
+ # ### 9. SQL execution / validation
1344
+
1345
+ def execute_sql(sql: str, connection: sqlite3.Connection) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
1346
+ """
1347
+ Execute SQL on SQLite and return a DataFrame, else return an error.
1348
+ """
1349
+ try:
1350
+ df = pd.read_sql_query(sql, connection)
1351
+ return df, None
1352
+ except Exception as e:
1353
+ return None, str(e)
1354
+
1355
+ def validate_sql(sql: str, connection: sqlite3.Connection) -> Tuple[bool, Optional[str]]:
1356
+ """
1357
+ Basic SQL validator:
1358
+ - Uses EXPLAIN QUERY PLAN to detect syntax or schema issues.
1359
+ """
1360
+ try:
1361
+ cursor = connection.cursor()
1362
+ cursor.execute(f"EXPLAIN QUERY PLAN {sql}")
1363
+ return True, None
1364
+ except Exception as e:
1365
+ return False, str(e)
1366
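`EXPLAIN QUERY PLAN` makes SQLite parse and plan a statement without executing it, which is why it works as a cheap validity check. A self-contained sketch against a throwaway in-memory table (the table here is illustrative, not the real Olist schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (order_id TEXT, order_status TEXT)")

def validate(sql: str):
    # Same idea as validate_sql above: plan the query, surface syntax/schema errors.
    try:
        conn.execute(f"EXPLAIN QUERY PLAN {sql}")
        return True, None
    except Exception as e:
        return False, str(e)

print(validate("SELECT COUNT(*) FROM orders"))   # (True, None)
ok, err = validate("SELECT * FROM no_such_table")
print(ok, err)  # False no such table: no_such_table
```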
+
1367
+ # %%
1368
+ # ### 10. SQL Generation + Repair with LLM
1369
+ def review_and_correct_sql_with_llm(
1370
+ question: str,
1371
+ generated_sql: str,
1372
+ user_feedback_comment: str,
1373
+ validation_payload: dict,
1374
+ decision_reason: str,
1375
+ rag_context: str,
1376
+ ) -> tuple[str, str]:
1377
+ """
1378
+ Fix SQL ONLY after LLM has decided it is necessary.
1379
+ """
1380
+
1381
+ log_pipeline("SQL_CORRECTION_START")
1382
+
1383
+ evidence_md = ""
1384
+ for chk in validation_payload.get("equivalence", []):
1385
+ evidence_md += chk["df"].to_markdown(index=False) + "\n\n"
1386
+
1387
+ prompt = f"""
1388
+ Original question:
1389
+ {question}
1390
+
1391
+ Original SQL:
1392
+ ```sql
1393
+ {generated_sql}
+ ```
1394
+
1395
+ User feedback:
1396
+ {user_feedback_comment}
1397
+
1398
+ Reason SQL is considered incorrect:
1399
+ {decision_reason}
1400
+
1401
+ Validation evidence:
1402
+ {evidence_md}
1403
+
1404
+ Schema:
1405
+ {rag_context}
1406
+
1407
+ TASK:
1408
+
1409
+ - Fix the SQL ONLY if necessary.
1410
+ - If the SQL is already correct, return it unchanged.
1411
+ - Explain your reasoning.
1412
+
1413
+ OUTPUT FORMAT:
1414
+ - Return the SQL (corrected or unchanged) in a ```sql code block.
1415
+ - After the code block, write "EXPLANATION:" followed by your reasoning; this exact marker is parsed downstream.
1416
+ """
1417
+
1418
+ resp = llm.invoke(prompt)
1419
+
1420
+ corrected_sql = extract_sql_from_markdown(resp.content) or generated_sql
1421
+ corrected_sql = clean_sql(corrected_sql)
1422
+
1423
+ explanation = extract_explanation_after_marker(resp.content, "EXPLANATION:")
1424
+
1425
+ log_pipeline(
1426
+ "SQL_CORRECTION_DONE",
1427
+ f"sql_changed={corrected_sql.strip() != generated_sql.strip()}"
1428
+ )
1429
+
1430
+ return corrected_sql, explanation
1431
+
1432
+ # Generate SQL
1433
+
1434
+ def generate_sql(question: str, rag_context: str) -> str:
1435
+ """
1436
+ Generate SQL using LLM, but also pull in any past user feedback
1437
+ (corrected SQL) for this question as external guidance.
1438
+ """
1439
+ # External correction from past feedback, if any
1440
+ with get_analytics_conn() as conn:
1441
+ try:
1442
+ last_fb = get_last_feedback_for_question(conn, question)
1443
+ except sqlite3.OperationalError:
1444
+ init_feedback_table(conn)
1445
+ last_fb = None
1446
+
1447
+
1448
+ if last_fb and last_fb.get("corrected_sql"):
1449
+ previous_feedback_block = f"""
1450
+ EXTERNAL USER FEEDBACK FROM PAST RUNS:
1451
+
1452
+ Previous generated SQL:
1453
+ {last_fb['generated_sql']}
1454
+
1455
+ Corrected SQL (preferred reference):
1456
+ {last_fb['corrected_sql']}
1457
+
1458
+ User comment & prior explanation:
1459
+ {last_fb.get('comment') or '(none)'}
1460
+
1461
+ You must avoid repeating the same mistake and should follow the logic of
1462
+ the corrected SQL when appropriate, while still reasoning from the schema.
1463
+ """
1464
+ else:
1465
+ previous_feedback_block = ""
1466
+
1467
+ system_instructions = f"""
1468
+ You are a senior data analyst writing SQL for a SQLite database.
1469
+
1470
+ You will be given:
1471
+
1472
+ - A description of available tables, columns, and relationships (schema + YAML metadata).
1473
+ - A natural language question from the user.
1474
+
1475
+ Your job:
1476
+
1477
+ - Write ONE valid SQLite SQL query that answers the question.
1478
+ - ONLY use tables and columns that exist in the schema_context.
1479
+ - Use correct JOINs (INNER, LEFT, etc.) with explicit ON conditions; older SQLite versions do not support RIGHT or FULL OUTER JOIN.
1480
+ - Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
1481
+ - Always use floating-point division for percentage calculations using 1.0 * numerator / denominator,
1482
+ and round to 2 decimals when appropriate.
1483
+
1484
+ {previous_feedback_block}
1485
+ """
1486
+
1487
+ prompt = (
1488
+ system_instructions
1489
+ + "\n\n=== RAG CONTEXT ===\n"
1490
+ + rag_context
1491
+ + "\n\n=== USER QUESTION ===\n"
1492
+ + question
1493
+ + "\n\nSQL query:"
1494
+ )
1495
+
1496
+ log_prompt("SQL_GENERATION", prompt)
1497
+ response = llm.invoke(prompt)
1498
+ log_run_event("RAW_MODEL_RESPONSE_SQL_GENERATION", response.content)
1499
+
1500
+ sql = clean_sql(response.content)
1501
+ return sql
1502
+
1503
+
1504
+ # %%
1505
+ # Repair SQL
1506
+
1507
+ def repair_sql(
1508
+ question: str,
1509
+ rag_context: str,
1510
+ bad_sql: str,
1511
+ error_message: str,
1512
+ ) -> str:
1513
+ """
1514
+ Ask the LLM to correct a failing SQL query.
1515
+ """
1516
+ system_instructions = """
1517
+ You are a senior data analyst fixing an existing SQL query for a SQLite database.
1518
+
1519
+ You will be given:
1520
+
1521
+ - Schema context (tables, columns, relationships).
1522
+ - The user's question.
1523
+ - A previously generated SQL query that failed.
1524
+ - The SQLite error message.
1525
+
1526
+ Your job:
1527
+
1528
+ - Diagnose why the query failed.
1529
+ - Rewrite ONE valid SQLite SQL query that answers the question.
1530
+ - ONLY use tables and columns that exist in the schema_context.
1531
+ - Use correct JOINs (INNER, LEFT, etc.) with explicit ON conditions; older SQLite versions do not support RIGHT or FULL OUTER JOIN.
1532
+ - Do not use DROP, INSERT, UPDATE, DELETE or other destructive operations.
1533
+ - Return ONLY the corrected SQL query, no explanation or markdown.
1534
+ """
1535
+
1536
+ prompt = (
1537
+ system_instructions
1538
+ + "\n\n=== RAG CONTEXT ===\n"
1539
+ + rag_context
1540
+ + "\n\n=== USER QUESTION ===\n"
1541
+ + question
1542
+ + "\n\n=== PREVIOUS (FAILING) SQL ===\n"
1543
+ + bad_sql
1544
+ + "\n\n=== SQLITE ERROR ===\n"
1545
+ + error_message
1546
+ + "\n\nCorrected SQL query:"
1547
+ )
1548
+
1549
+ log_prompt("SQL_REPAIR", prompt)
1550
+ response = llm.invoke(prompt)
1551
+ log_run_event("RAW_MODEL_RESPONSE_SQL_REPAIR", response.content)
1552
+
1553
+ sql = clean_sql(response.content)
1554
+ return sql
1555
+
1556
+
1557
+ # %%
1558
+ ### 11. Result summarization
1559
+
1560
+ def summarize_results(
1561
+ question: str,
1562
+ sql: str,
1563
+ df: Optional[pd.DataFrame],
1564
+ rag_context: str,
1565
+ error: Optional[str] = None,
1566
+ ) -> str:
1567
+ """
1568
+ Ask the LLM to produce a concise, human-readable answer.
1569
+ """
1570
+ system_instructions = """
1571
+ You are a senior data analyst.
1572
+
1573
+ You will be given:
1574
+
1575
+ The user's question.
1576
+ The final SQL that was executed.
1577
+ A small preview of the query result (as a Markdown table, if available).
1578
+ Optional error information if the query failed.
1579
+
1580
+ Your job:
1581
+
1582
+ Provide a clear, concise answer in Markdown.
1583
+ If the result is numeric / aggregated, explain what it means in business terms.
1584
+ If there was an error, explain it simply and suggest how the user could rephrase.
1585
+ Do NOT show raw SQL unless it is helpful to the user.
1586
+ """
1587
+
1588
+ # Build a markdown table preview if we have data
1589
+ if df is not None and not df.empty:
1590
+ preview_rows = min(len(df), 50)
1591
+ df_preview_md = df.head(preview_rows).to_markdown(index=False)
1592
+ else:
1593
+ df_preview_md = "(no rows returned)"
1594
+
1595
+ prompt = (
1596
+ system_instructions
1597
+ + "\n\n=== USER QUESTION ===\n"
1598
+ + question
1599
+ + "\n\n=== EXECUTED SQL ===\n"
1600
+ + sql
1601
+ + "\n\n=== QUERY RESULT PREVIEW ===\n"
1602
+ + df_preview_md
1603
+ + "\n\n=== RAG CONTEXT (schema) ===\n"
1604
+ + rag_context
1605
+ )
1606
+
1607
+ if error:
1608
+ prompt += "\n\n=== ERROR ===\n" + error
1609
+
1610
+ # Logging helpers assumed to exist
1611
+ log_prompt("RESULT_SUMMARY", prompt)
1612
+
1613
+ response = llm.invoke(prompt)
1614
+ log_run_event("RAW_MODEL_RESPONSE_RESULT_SUMMARY", response.content)
1615
+
1616
+ return response.content.strip()
1617
+
1618
+
1619
+ # %%
1620
+ def backend_pipeline(question: str):
1621
+ """
1622
+ STRICT cache-first backend.
1623
+ ALWAYS returns EXACTLY 10 values.
1624
+ """
1625
+
1626
+ if not question or not question.strip():
1627
+ return (
1628
+ "Please type a question.",
1629
+ pd.DataFrame(),
1630
+ "",
1631
+ "",
1632
+ "",
1633
+ "",
1634
+ {},
1635
+ 4,
1636
+ "**Feedback attempts remaining: 4**",
1637
+ "",
1638
+ )
1639
+
1640
+ attempts_left = 4
1641
+ attempts_text = f"**Feedback attempts remaining: {attempts_left}**"
1642
+ store = get_sql_cache_store()
1643
+
1644
+ # ---------------- CACHE LOOKUP ----------------
1645
+ cached = retrieve_exact_cached_sql(question, store)
1646
+ if cached is None:
1647
+ cached = retrieve_best_cached_sql(question, store, max_distance=0.25)
1648
+
1649
+ # ---------------- CACHE HIT ----------------
1650
+ if cached:
1651
+ increment_cache_views(cached["metadata"], store)
1652
+
1653
+ rag_context = get_rag_context(question)
1654
+ sql = cached["sql"]
1655
+
1656
+ with get_analytics_conn() as conn:
1657
+ df, err = execute_sql(sql, conn)
1658
+
1659
+ if err:
1660
+ df = pd.DataFrame()
1661
+
1662
+ md = cached["metadata"]
1663
+
1664
+ answer_md = cached.get("answer_md", "")
1665
+
1666
+ debug_md = f"""
1667
+ ### Query Provenance (Cache)
1668
+
1669
+ **Matched question:** `{cached.get("matched_question", "")}`
1670
+ **Success rate:** {md.get("success_rate", 0.5):.2f}
1671
+ **Good / Bad:** {md.get("good_count", 0)} / {md.get("bad_count", 0)}
1672
+ **Views:** {md.get("views", 0)}
1673
+
1674
+ **SQL used:**
1675
+ ```sql
+ {sql}
+ ```
1676
+
1677
+
1678
+ """
1679
+
1680
+ return (
1681
+ answer_md,
1682
+ df,
1683
+ sql,
1684
+ rag_context,
1685
+ question,
1686
+ answer_md,
1687
+ df_to_state(df),
1688
+ attempts_left,
1689
+ attempts_text,
1690
+ debug_md,
1691
+ )
1692
+
1693
+ # ---------------- LLM FLOW ----------------
1694
+ rag_context = get_rag_context(question)
1695
+ sql = generate_sql(question, rag_context)
1696
+
1697
+ with get_analytics_conn() as conn:
1698
+ is_valid, err = validate_sql(sql, conn)
1699
+
1700
+ if not is_valid and err:
1701
+ sql = repair_sql(question, rag_context, sql, err)
1702
+
1703
+ with get_analytics_conn() as conn:
1704
+ df, exec_error = execute_sql(sql, conn)
1705
+ if exec_error:
1706
+ df = pd.DataFrame()
1707
+
1708
+ answer_md = summarize_results(
1709
+ question=question,
1710
+ sql=sql,
1711
+ df=df if not exec_error else None,
1712
+ rag_context=rag_context,
1713
+ error=exec_error,
1714
+ )
1715
+
1716
+ # Cache LLM result
1717
+ cache_sql_answer_dedup(
1718
+ question=question,
1719
+ sql=sql,
1720
+ answer_md=answer_md,
1721
+ metadata={
1722
+ "good_count": 0,
1723
+ "bad_count": 0,
1724
+ "total_feedbacks": 0,
1725
+ "success_rate": 0.5,
1726
+ "views": 1,
1727
+ "saved_at": time.time(),
1728
+ },
1729
+ store=store,
1730
+ )
1731
+
1732
+ debug_md = f"""
1733
+
1734
+ ### Query Provenance (LLM Generated)
1735
+
1736
+ **Source:** LLM
1737
+ **Execution status:** {"Error" if exec_error else "Success"}
1738
+
1739
+ **SQL used:**
1740
+ ```sql
1741
+ {sql}
+ ```
1742
+
1743
+
1744
+ """
1745
+
1746
+ return (
1747
+ answer_md,
1748
+ df,
1749
+ sql,
1750
+ rag_context,
1751
+ question,
1752
+ answer_md,
1753
+ df_to_state(df),
1754
+ attempts_left,
1755
+ attempts_text,
1756
+ debug_md,
1757
+ )
1758
+
1759
+ # %%
1760
+ import re
1761
+
1762
+ def looks_like_sql(text: str) -> bool:
1763
+ if not text:
1764
+ return False
1765
+ return bool(re.search(
1766
+ r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b",
1767
+ text,
1768
+ flags=re.I,
1769
+ ))
1770
+
1771
+
1772
+ def is_feedback_sufficient(feedback_text: str) -> bool:
1773
+ """
1774
+ Decide whether feedback is actionable WITHOUT LLM.
1775
+ """
1776
+ if not feedback_text:
1777
+ return False
1778
+
1779
+ text = feedback_text.strip().lower()
1780
+
1781
+ # Long feedback → sufficient
1782
+ if len(text) >= 60:
1783
+ return True
1784
+
1785
+ # SQL pasted
1786
+ if looks_like_sql(text):
1787
+ return True
1788
+
1789
+ # Numbers + comparison language → sufficient
1790
+ if re.search(r"\b\d+\b", text) and any(
1791
+ k in text for k in [
1792
+ "last year", "previous", "this year",
1793
+ "increase", "decrease", "higher", "lower"
1794
+ ]
1795
+ ):
1796
+ return True
1797
+
1798
+ # Analytical keywords
1799
+ signal_words = [
1800
+ "filter", "where", "year", "month", "should", "instead",
1801
+ "expected", "wrong", "missing", "count", "distinct",
1802
+ "join", "exclude", "include", "only"
1803
+ ]
1804
+
1805
+ signals = sum(1 for w in signal_words if w in text)
1806
+ if signals >= 1 and len(text) >= 20:
1807
+ return True
1808
+
1809
+ return False
1810
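A condensed, standalone mirror of the two heuristics above, exercised with a few illustrative feedback strings:

```python
import re

def looks_like_sql(text: str) -> bool:
    return bool(text) and bool(re.search(
        r"\bselect\b|\bfrom\b|\bwhere\b|\bjoin\b|\bgroup by\b|\border by\b",
        text, flags=re.I,
    ))

def is_feedback_sufficient(text: str) -> bool:
    # Condensed mirror of the heuristic above.
    if not text:
        return False
    text = text.strip().lower()
    if len(text) >= 60 or looks_like_sql(text):
        return True
    if re.search(r"\b\d+\b", text) and any(
        k in text for k in ["last year", "previous", "this year",
                            "increase", "decrease", "higher", "lower"]
    ):
        return True
    signal_words = ["filter", "where", "year", "month", "should", "instead",
                    "expected", "wrong", "missing", "count", "distinct",
                    "join", "exclude", "include", "only"]
    return sum(w in text for w in signal_words) >= 1 and len(text) >= 20

print(is_feedback_sufficient("wrong"))                                # False: too short, no detail
print(is_feedback_sufficient("should exclude canceled orders only"))  # True: signal words + length
print(is_feedback_sufficient("SELECT COUNT(*) FROM orders"))          # True: pasted SQL
```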
+
1811
+
1812
+ # %%
1813
+ def build_followup_prompt_for_user(sample_feedback: str = "") -> str:
1814
+ """
1815
+ Deterministic follow-up question to ask the user when feedback is vague.
1816
+ Returns a friendly prompt that the UI can display to the user.
1817
+ """
1818
+ base = (
1819
+ "Thanks — I need a bit more detail to act on this feedback.\n\n"
1820
+ "Please tell me one (or more) of the following so I can check and correct the result:\n\n"
1821
+ "1. Which part looks wrong — the **numbers**, the **aggregation** (sum/count/avg),\n"
1822
+ " the **time range** (year/month), or the **filters** applied?\n"
1823
+ "2. If you expected a different number, what was the expected number (and how was it computed)?\n"
1824
+ "3. If you have a corrected SQL snippet, paste it (I can run and compare it).\n\n"
1825
+ "Examples you can copy-paste:\n"
1826
+ )
1827
+ examples = (
1828
+ "- \"I think the query should count DISTINCT customer_unique_id, not customer_id.\"\n"
1829
+ "- \"This looks off for year 2018 — I expected the count for 2018 to be ~40k.\"\n"
1830
+ "- \"Please exclude canceled orders (order_status = 'canceled').\"\n"
1831
+ "- \"SELECT COUNT(DISTINCT customer_unique_id) FROM olist_customers_dataset;\"\n"
1832
+ )
1833
+ hint = "\nIf you prefer, just paste a corrected SQL snippet and I'll run it and compare."
1834
+ prompt = base + examples + hint
1835
+ if sample_feedback:
1836
+ prompt = f"I saw your feedback: \"{sample_feedback}\"\n\n" + prompt
1837
+ return prompt
1838
+
1839
+
1840
+
1841
+ # %%
1842
+ def interpret_feedback_intent(question: str, feedback_text: str) -> dict:
1843
+ """
1844
+ Convert business feedback into a testable hypothesis.
1845
+ NO SQL generation here.
1846
+ """
1847
+
1848
+ system_prompt = """
1849
+ You are a senior analytics reviewer.
1850
+
1851
+ Your task:
1852
+ - Identify what the user is CLAIMING.
1853
+ - Do NOT assume they are correct.
1854
+ - Do NOT write SQL.
1855
+
1856
+ Possible intents:
1857
+ - comparison_claim (mentions last year / previous / higher / lower)
1858
+ - overcount
1859
+ - undercount
1860
+ - missing_entities
1861
+ - wrong_filter
1862
+ - wrong_time_range
1863
+ - unclear
1864
+
1865
+ Return JSON ONLY.
1866
+ """
1867
+
1868
+ prompt = f"""
1869
+ USER QUESTION:
1870
+ {question}
1871
+
1872
+ USER FEEDBACK:
1873
+ {feedback_text}
1874
+
1875
+ Return JSON exactly like this:
1876
+ {{
1877
+ "intent": "comparison_claim | overcount | undercount | missing_entities | wrong_filter | wrong_time_range | unclear",
1878
+ "confidence": "low | medium | high"
1879
+ }}
1880
+ """
1881
+
1882
+ resp = llm.invoke(system_prompt + "\n\n" + prompt)
1883
+ log_run_event("RAW_MODEL_RESPONSE_FEEDBACK_INTENT", resp.content)
1884
+
1885
+ try:
1886
+ return json.loads(resp.content)
1887
+ except Exception:
1888
+ return {
1889
+ "intent": "unclear",
1890
+ "confidence": "low",
1891
+ }
1892
+
1893
+
1894
+ # %%
+ def run_validation_queries(
+     intent_info: dict,
+     question: str,
+     original_sql: str,
+     user_feedback: str,
+     rag_context: str,
+ ) -> dict:
+     """
+     Generate AND EXECUTE validation SQL.
+     Separates equivalence checks vs context checks.
+     Uses USER FEEDBACK explicitly for context validation.
+     """
+
+     log_pipeline(
+         "VALIDATION_START",
+         f"intent={intent_info.get('intent')}"
+     )
+
+     system_prompt = f"""
+ You are validating a business claim.
+
+ User intent:
+ {intent_info.get("intent")}
+
+ Original question:
+ {question}
+
+ User feedback (IMPORTANT):
+ {user_feedback}
+
+ Original SQL:
+ {original_sql}
+
+ Schema context:
+ {rag_context}
+
+ TASK:
+ - Generate 2-4 SQL queries.
+ - Mark EACH query explicitly as one of:
+   - "equivalence":
+     * Alternative formulations of the SAME question
+     * Used to verify SQL correctness (joins, filters, logic)
+   - "context":
+     * Used to VERIFY or CHALLENGE the user's feedback
+     * If the user mentions years, comparisons, growth, or missing entities,
+       generate year-over-year or comparative queries.
+ - DO NOT fix or modify the original SQL.
+ - These queries are ONLY for validation and analysis.
+
+ IMPORTANT:
+ - Context queries MUST be driven by the user's feedback.
+ - If the user claims a specific number (for example, a previous year's value), generate queries that can verify that claim from the data.
+ - If the user provides a SQL query in their feedback, include it as a check as well; do not assume it is correct.
+
+ Return JSON ONLY in this format:
+ {{
+   "checks": [
+     {{
+       "type": "equivalence | context",
+       "sql": "SELECT ...",
+       "label": "short description of what this validates"
+     }}
+   ]
+ }}
+ """
+
+     resp = llm.invoke(system_prompt)
+     raw = resp.content.strip()
+     print(raw)
+
+     # Strip ```json fences if present
+     if raw.startswith("```"):
+         raw = raw.strip("`")
+         raw = raw.replace("json", "", 1).strip()
+
+     try:
+         parsed = json.loads(raw)
+     except Exception as e:
+         log_pipeline("VALIDATION_PARSE_FAILED", str(e))
+         return {
+             "equivalence": [],
+             "context": [],
+         }
+
+     equivalence = []
+     context = []
+
+     for chk in parsed.get("checks", []):
+         sql = chk.get("sql")
+         chk_type = chk.get("type")
+         label = chk.get("label", "unnamed check")
+
+         if not sql or chk_type not in ("equivalence", "context"):
+             continue
+
+         with get_analytics_conn() as conn:
+             df, err = execute_sql(sql, conn)
+
+         print(df)
+
+         log_pipeline(
+             "VALIDATION_SQL_EXECUTED",
+             f"type={chk_type} | label={label}"
+         )
+
+         if err or df is None or df.empty:
+             continue
+
+         payload = {
+             "label": label,
+             "sql": sql,
+             "df": df,
+         }
+
+         if chk_type == "equivalence":
+             equivalence.append(payload)
+         else:
+             context.append(payload)
+
+     log_pipeline(
+         "VALIDATION_DONE",
+         f"equivalence={len(equivalence)}, context={len(context)}"
+     )
+
+     return {
+         "equivalence": equivalence,
+         "context": context,
+     }
+
+
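The fence-stripping and check-bucketing steps in `run_validation_queries` can be exercised on their own. Below is a minimal standalone sketch (duplicating those lines, with a canned reply in place of `llm.invoke` and without executing any SQL) showing how a fenced JSON response is parsed and split into equivalence vs. context checks, and how entries with an unknown type are dropped:

```python
import json

FENCE = "`" * 3  # a literal ``` delimiter, built without embedding backticks here

# Canned LLM reply wrapped in a json code fence, as run_validation_queries may receive
raw = FENCE + "json\n" + json.dumps({
    "checks": [
        {"type": "equivalence", "sql": "SELECT COUNT(*) FROM olist_orders_dataset", "label": "recount orders"},
        {"type": "context", "sql": "SELECT 1", "label": "year-over-year check"},
        {"type": "bogus", "sql": "SELECT 2", "label": "dropped: unknown type"},
    ]
}) + "\n" + FENCE

# Same fence handling as in run_validation_queries
if raw.startswith(FENCE):
    raw = raw.strip("`")
    raw = raw.replace("json", "", 1).strip()

parsed = json.loads(raw)

# Same bucketing rule: checks with a missing sql or an unknown type are skipped
equivalence = [c for c in parsed["checks"] if c.get("sql") and c.get("type") == "equivalence"]
context = [c for c in parsed["checks"] if c.get("sql") and c.get("type") == "context"]

print(len(equivalence), len(context))
```

Note that `raw.strip("`")` would also strip backticks that legitimately end the JSON payload; the separate `safe_json_load` helper defined below this function exists as the more tolerant fallback.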
+ # %%
+ def safe_json_load(text: str) -> dict | None:
+     try:
+         return json.loads(text)
+     except Exception:
+         pass
+
+     # Try extracting first JSON block
+     match = re.search(r"\{[\s\S]*\}", text)
+     if match:
+         try:
+             return json.loads(match.group(0))
+         except Exception:
+             return None
+     return None
+
+ # %%
+ def decide_after_validation_with_llm(
+     question: str,
+     original_sql: str,
+     user_feedback: str,
+     intent_info: dict,
+     original_df: pd.DataFrame,
+     validation_payload: dict,
+     rag_context: str,
+ ) -> dict:
+     """
+     Use LLM to decide whether SQL must change or explanation is sufficient.
+     Returns a structured decision.
+     """
+
+     log_pipeline("DECISION_LLM_START")
+
+     # Build evidence
+     evidence_md = "### Original Result\n"
+     evidence_md += original_df.to_markdown(index=False)
+
+     for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
+         evidence_md += f"\n\n### Equivalence Check {i}: {chk['label']}\n"
+         evidence_md += chk["df"].to_markdown(index=False)
+
+     for i, chk in enumerate(validation_payload.get("context", []), 1):
+         evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
+         evidence_md += chk["df"].to_markdown(index=False)
+
+     system_prompt = """
+ You are a senior analytics and data analyst.
+
+ Your task:
+ - Decide whether the ORIGINAL SQL is logically correct, given the original question, the initial result, the user feedback, and the equivalence and context evidence shown to you.
+ - Do NOT judge based on expectations or growth alone.
+ - Different definitions (e.g. customers vs customers-with-orders) are NOT errors.
+ - Only mark SQL incorrect if the evidence shows a true flaw in the results or the logic.
+
+ CRITICAL RULES:
+ - You must return VALID JSON.
+ - DO NOT add explanations outside JSON.
+ - DO NOT use markdown.
+ - DO NOT include extra text.
+
+ Return JSON ONLY in this format:
+ {
+   "decision": "correct_sql | not_correct_sql",
+   "confidence": "high | medium | low",
+   "reason": "provide an explanation here"
+ }
+ """
+
+     prompt = f"""
+ Original question:
+ {question}
+
+ Original SQL:
+ {original_sql}
+
+ User feedback:
+ {user_feedback}
+
+ Interpreted intent:
+ {intent_info}
+
+ Schema context:
+ {rag_context}
+
+ Below is the validation evidence.
+ """
+
+     resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md)
+
+     parsed = safe_json_load(resp.content)
+
+     if not parsed:
+         decision = {
+             "decision": "correct_sql",
+             "confidence": "low",
+             "reason": "LLM response was not valid JSON",
+         }
+     else:
+         decision = parsed
+
+     log_pipeline(
+         "DECISION_LLM_DONE",
+         f"decision={decision.get('decision')} | confidence={decision.get('confidence')} | reason={decision.get('reason')}"
+     )
+
+     return decision
+
+
+ # %%
+ def explain_validation_with_llm(
+     question: str,
+     feedback: str,
+     original_df: pd.DataFrame,
+     validation_payload: dict,
+     schema_context: str,
+ ) -> str:
+     """
+     Business explanation of validation results.
+     """
+
+     log_pipeline("EXPLAIN_START")
+
+     evidence_md = (
+         "## Pipeline: Feedback -> Validation -> Explanation\n\n"
+         "### Baseline — Original Result\n"
+     )
+     evidence_md += original_df.to_markdown(index=False)
+
+     for i, chk in enumerate(validation_payload.get("context", []), 1):
+         evidence_md += f"\n\n### Context Check {i}: {chk['label']}\n"
+         evidence_md += chk["df"].to_markdown(index=False)
+
+     for i, chk in enumerate(validation_payload.get("equivalence", []), 1):
+         evidence_md += f"\n\n### Verification Check {i}: {chk['label']}\n"
+         evidence_md += chk["df"].to_markdown(index=False)
+
+     system_prompt = """
+ You are a senior data analyst explaining results to a business user.
+
+ Rules:
+ - DO NOT mention SQL mechanics.
+ - Compare numbers explicitly.
+ - Explain trends year-over-year.
+ - Explain what the actual numbers are and what they mean for the business.
+ - If the user's feedback questions the SQL or the numbers, use the verification and context checks to explain the results to them.
+ """
+
+     prompt = f"""
+ Original question:
+ {question}
+
+ User feedback:
+ {feedback}
+
+ Schema of the data:
+ {schema_context}
+ """
+
+     resp = llm.invoke(system_prompt + "\n\n" + prompt + "\n\n" + evidence_md)
+
+     log_pipeline("EXPLAIN_DONE")
+
+     return resp.content.strip()
+
+
+ # %%
2192
+
2193
+
2194
+ # %%
2195
+
2196
+ def feedback_pipeline_interactive(
2197
+ feedback_rating: str,
2198
+ feedback_comment: str,
2199
+ last_sql: str,
2200
+ last_rag_context: str,
2201
+ last_question: str,
2202
+ last_answer_md: str,
2203
+ last_df_state: dict,
2204
+ feedback_sql: str,
2205
+ attempts_left: int,
2206
+ ):
2207
+ """
2208
+ ALWAYS returns EXACTLY 10 values
2209
+ NO booleans
2210
+ followup_signal ∈ {"FOLLOWUP", "NONE"}
2211
+ """
2212
+
2213
+ last_df = state_to_df(last_df_state)
2214
+ rating = (feedback_rating or "").strip().lower()
2215
+ comment = (feedback_comment or "").strip()
2216
+ attempts_left = max(0, int(attempts_left or 0))
2217
+
2218
+ # ---------- INVALID RATING ----------
2219
+ if rating not in {"correct", "wrong"}:
2220
+ return (
2221
+ last_answer_md,
2222
+ last_df,
2223
+ last_sql,
2224
+ last_rag_context,
2225
+ last_question,
2226
+ last_answer_md,
2227
+ df_to_state(last_df),
2228
+ "NONE",
2229
+ "",
2230
+ attempts_left,
2231
+ )
2232
+
2233
+ # ---------- CORRECT ----------
2234
+ if rating == "correct":
2235
+ update_cache_on_feedback(
2236
+ question=last_question,
2237
+ original_doc_md=find_cached_doc_by_sql(
2238
+ last_question, last_sql, get_sql_cache_store()
2239
+ ),
2240
+ user_marked_good=True,
2241
+ llm_corrected_sql=None,
2242
+ llm_corrected_answer_md=None,
2243
+ store=get_sql_cache_store(),
2244
+ )
2245
+
2246
+ # ADD: persist feedback
2247
+ with get_analytics_conn() as conn:
2248
+ record_feedback(
2249
+ conn=conn,
2250
+ question=last_question,
2251
+ generated_sql=last_sql,
2252
+ model_answer=last_answer_md,
2253
+ rating="good",
2254
+ comment=comment or None,
2255
+ corrected_sql=None,
2256
+ )
2257
+
2258
+ md = "**Confirmed correct by user**\n\n" + last_answer_md
2259
+ return (
2260
+ md,
2261
+ last_df,
2262
+ last_sql,
2263
+ last_rag_context,
2264
+ last_question,
2265
+ md,
2266
+ df_to_state(last_df),
2267
+ "NONE",
2268
+ "",
2269
+ attempts_left,
2270
+ )
2271
+
2272
+ # ---------- WRONG ----------
2273
+ attempts_left = max(0, attempts_left - 1)
2274
+
2275
+ if not is_feedback_sufficient(comment):
2276
+ return (
2277
+ last_answer_md,
2278
+ last_df,
2279
+ last_sql,
2280
+ last_rag_context,
2281
+ last_question,
2282
+ last_answer_md,
2283
+ df_to_state(last_df),
2284
+ "FOLLOWUP",
2285
+ build_followup_prompt_for_user(comment),
2286
+ attempts_left,
2287
+ )
2288
+
2289
+ # ---------- VALIDATION ----------
2290
+ intent_info = interpret_feedback_intent(last_question, comment)
2291
+
2292
+ validation = run_validation_queries(
2293
+ intent_info=intent_info,
2294
+ question=last_question,
2295
+ original_sql=last_sql,
2296
+ user_feedback=comment,
2297
+ rag_context=last_rag_context,
2298
+ )
2299
+
2300
+ decision = decide_after_validation_with_llm(
2301
+ question=last_question,
2302
+ original_sql=last_sql,
2303
+ user_feedback=comment,
2304
+ intent_info=intent_info,
2305
+ original_df=last_df,
2306
+ validation_payload=validation,
2307
+ rag_context=last_rag_context,
2308
+ )
2309
+
2310
+ # ---------- EXPLAIN ----------
2311
+ if decision.get("decision") == "correct_sql":
2312
+ explanation = explain_validation_with_llm(
2313
+ question=last_question,
2314
+ feedback=comment,
2315
+ original_df=last_df,
2316
+ validation_payload=validation,
2317
+ schema_context=last_rag_context,
2318
+ )
2319
+
2320
+ return (
2321
+ explanation,
2322
+ last_df,
2323
+ last_sql,
2324
+ last_rag_context,
2325
+ last_question,
2326
+ explanation,
2327
+ df_to_state(last_df),
2328
+ "NONE",
2329
+ "",
2330
+ attempts_left,
2331
+ )
2332
+
2333
+ # ---------- SQL CORRECTION ----------
2334
+ corrected_sql, explanation = review_and_correct_sql_with_llm(
2335
+ question=last_question,
2336
+ generated_sql=last_sql,
2337
+ user_feedback_comment=comment,
2338
+ validation_payload=validation,
2339
+ decision_reason=decision.get("reason", ""),
2340
+ rag_context=last_rag_context,
2341
+ )
2342
+
2343
+ with get_analytics_conn() as conn:
2344
+ df_new, err = execute_sql(corrected_sql, conn)
2345
+
2346
+ # ADD: persist negative feedback + corrected SQL
2347
+ record_feedback(
2348
+ conn=conn,
2349
+ question=last_question,
2350
+ generated_sql=last_sql,
2351
+ model_answer=last_answer_md,
2352
+ rating="bad",
2353
+ comment=comment,
2354
+ corrected_sql=corrected_sql,
2355
+ )
2356
+
2357
+ final_md = "**SQL corrected**\n\n" + summarize_results(
2358
+ question=last_question,
2359
+ sql=corrected_sql,
2360
+ df=df_new if not err else None,
2361
+ rag_context=last_rag_context,
2362
+ error=err,
2363
+ )
2364
+
2365
+ return (
2366
+ final_md,
2367
+ df_new if not err else pd.DataFrame(),
2368
+ corrected_sql,
2369
+ last_rag_context,
2370
+ last_question,
2371
+ final_md,
2372
+ df_to_state(df_new if not err else pd.DataFrame()),
2373
+ "NONE",
2374
+ "",
2375
+ attempts_left,
2376
+ )
2377
+
+
+ # %%
+
+
+ # %%
+ import gradio as gr
+
+ # =========================================================
+ # Helpers
+ # =========================================================
+ def attempts_text(attempts: int) -> str:
+     if attempts <= 0:
+         return "**Feedback exhausted. Please ask a new question.**"
+     return f"**Feedback attempts remaining: {attempts}**"
+
+
+ def conversation_choices():
+     convs = list_conversations()
+     return [(c["title"], c["conversation_id"]) for c in convs]
+
+
+ def load_conversation_dropdown():
+     return gr.update(choices=conversation_choices())
+
+
+ # =========================================================
+ # CHAT TURN
+ # =========================================================
+ def chat_turn(user_input, history, state):
+     history = history or []
+
+     # --- Create conversation if new ---
+     if state.get("conversation_id") is None:
+         ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M")
+         title = f"{ts} · {user_input[:50]}"
+         cid = create_conversation(title=title)
+         state["conversation_id"] = cid
+     else:
+         cid = state["conversation_id"]
+
+     append_message(cid, "user", user_input)
+
+     history.append({"role": "user", "content": user_input})
+     history.append({"role": "assistant", "content": "🔎 Thinking…"})
+
+     (
+         answer_md,
+         df,
+         sql,
+         rag,
+         q,
+         ans_state,
+         df_state,
+         attempts,
+         attempts_md,
+         debug_md,
+     ) = backend_pipeline(user_input)
+
+     history.pop()
+     history.append({"role": "assistant", "content": answer_md})
+
+     append_message(cid, "assistant", answer_md)
+
+     new_state = {
+         "conversation_id": cid,
+         "last_sql": sql,
+         "last_rag": rag,
+         "last_question": q,
+         "last_answer": ans_state,
+         "last_df": df_state,
+         "attempts": attempts,
+     }
+
+     save_state(cid, new_state)
+
+     return (
+         history,
+         new_state,
+         attempts_md,
+         debug_md,
+         df,  # refresh the results table
+         gr.update(choices=conversation_choices()),
+     )
+
+
+ # =========================================================
+ # FEEDBACK TURN
+ # =========================================================
+ def feedback_turn(feedback_rating, feedback_text, history, state):
+     history = history or []
+     cid = state.get("conversation_id")
+
+     # ---------- Attempts exhausted ----------
+     if state["attempts"] <= 0:
+         history.append({
+             "role": "assistant",
+             "content": "Feedback attempts are exhausted. Please ask a new question."
+         })
+         return history, state, attempts_text(0), pd.DataFrame()
+
+     # ---------- BAD without text ----------
+     if feedback_rating == "wrong" and not (feedback_text or "").strip():
+         new_attempts = max(0, state["attempts"] - 1)
+         history.append({
+             "role": "assistant",
+             "content": "Please provide details before submitting negative feedback."
+         })
+         state["attempts"] = new_attempts
+         save_state(cid, state)
+         return history, state, attempts_text(new_attempts), state_to_df(state["last_df"])
+
+     history.append({"role": "user", "content": f"{feedback_rating.upper()} feedback"})
+     history.append({"role": "assistant", "content": "🧪 Validating feedback…"})
+
+     (
+         answer_md,
+         df,
+         sql_new,
+         rag_new,
+         q_new,
+         ans_state,
+         df_state,
+         followup_flag,
+         followup_prompt,
+         attempts_new,
+     ) = feedback_pipeline_interactive(
+         feedback_rating=feedback_rating,
+         feedback_comment=feedback_text,
+         last_sql=state["last_sql"],
+         last_rag_context=state["last_rag"],
+         last_question=state["last_question"],
+         last_answer_md=state["last_answer"],
+         last_df_state=state["last_df"],
+         feedback_sql="",
+         attempts_left=state["attempts"],
+     )
+
+     history.pop()
+     history.append({
+         "role": "assistant",
+         "content": followup_prompt if followup_flag == "FOLLOWUP" else answer_md
+     })
+
+     append_message(cid, "assistant", answer_md)
+
+     new_state = {
+         "conversation_id": cid,
+         "last_sql": sql_new,
+         "last_rag": rag_new,
+         "last_question": q_new,
+         "last_answer": ans_state,
+         "last_df": df_state,
+         "attempts": attempts_new,
+     }
+
+     save_state(cid, new_state)
+
+     return history, new_state, attempts_text(attempts_new), df
+
+
+ # =========================================================
+ # RESTORE / NEW CHAT
+ # =========================================================
+ def restore_conversation(conversation_id):
+     msgs = load_messages(conversation_id)
+     st = load_state(conversation_id)
+     st["conversation_id"] = conversation_id
+     df = state_to_df(st.get("last_df", {}))
+     return msgs, st, attempts_text(st.get("attempts", 4)), df
+
+
+ def new_chat():
+     cid = create_conversation("New conversation")
+     return [], {
+         "conversation_id": cid,
+         "last_sql": "",
+         "last_rag": "",
+         "last_question": "",
+         "last_answer": "",
+         "last_df": {},
+         "attempts": 4,
+     }, attempts_text(4), pd.DataFrame()
+
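The `df_to_state` / `state_to_df` pair used throughout these handlers is defined earlier in app.py (outside this excerpt); their role is to round-trip a DataFrame through a JSON-safe dict so it can live in `gr.State` and the chat-history store. A minimal round-trip sketch consistent with how they are called here might look like this (the exact serialization format is an assumption, not necessarily the app's actual implementation):

```python
import pandas as pd


def df_to_state(df: pd.DataFrame) -> dict:
    # Hypothetical: serialize columns + rows into a JSON-safe dict;
    # an empty or missing DataFrame becomes an empty dict
    if df is None or df.empty:
        return {}
    return {"columns": list(df.columns), "data": df.values.tolist()}


def state_to_df(state: dict) -> pd.DataFrame:
    # Hypothetical inverse: an empty/missing state becomes an empty DataFrame,
    # matching how the handlers fall back to pd.DataFrame()
    if not state:
        return pd.DataFrame()
    return pd.DataFrame(state["data"], columns=state["columns"])


df = pd.DataFrame({"year": [2017, 2018], "orders": [100, 200]})
restored = state_to_df(df_to_state(df))
print(restored.equals(df))
```

Whatever the real implementation, it must satisfy this round-trip property, since `feedback_pipeline_interactive` immediately re-serializes the DataFrame it just deserialized.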
+
+ # =========================================================
+ # UI
+ # =========================================================
+ with gr.Blocks(analytics_enabled=False) as demo:
+     gr.Markdown("# Analytics Assistant Chatbot")
+
+     with gr.Row():
+         conversation_selector = gr.Dropdown(
+             label="Conversation history",
+             choices=[],
+             interactive=True,
+         )
+         new_chat_btn = gr.Button("New chat")
+
+     chat = gr.Chatbot(height=520)
+     attempts_md = gr.Markdown(value=attempts_text(4))
+
+     show_debug = gr.Checkbox(label="Show SQL / Cache info", value=False)
+     debug_panel = gr.Markdown(visible=False)
+
+     result_df = gr.Dataframe(
+         label="Query Results",
+         interactive=False,
+         wrap=True,
+     )
+
+     state = gr.State({
+         "conversation_id": None,
+         "last_sql": "",
+         "last_rag": "",
+         "last_question": "",
+         "last_answer": "",
+         "last_df": {},
+         "attempts": 4,
+     })
+
+     user_input = gr.Textbox(placeholder="Ask a question…", lines=1, show_label=False)
+     feedback_input = gr.Textbox(
+         placeholder="Details (required if result is wrong)…",
+         lines=2,
+         show_label=False,
+     )
+
+     with gr.Row():
+         gr.Column(scale=6)
+         good_btn = gr.Button("Good", size="sm")
+         bad_btn = gr.Button("Bad", size="sm")
+
+     show_debug.change(
+         lambda show, md: gr.update(value=md, visible=show),
+         inputs=[show_debug, debug_panel],
+         outputs=debug_panel,
+     )
+
+     user_input.submit(
+         chat_turn,
+         inputs=[user_input, chat, state],
+         outputs=[chat, state, attempts_md, debug_panel, result_df, conversation_selector],
+         queue=False,
+     )
+
+     good_btn.click(
+         lambda fb, h, s: feedback_turn("correct", fb, h, s),
+         inputs=[feedback_input, chat, state],
+         outputs=[chat, state, attempts_md, result_df],
+     )
+
+     bad_btn.click(
+         lambda fb, h, s: feedback_turn("wrong", fb, h, s),
+         inputs=[feedback_input, chat, state],
+         outputs=[chat, state, attempts_md, result_df],
+     )
+
+     conversation_selector.change(
+         restore_conversation,
+         inputs=conversation_selector,
+         outputs=[chat, state, attempts_md, result_df],
+     )
+
+     new_chat_btn.click(
+         new_chat,
+         outputs=[chat, state, attempts_md, result_df],
+     )
+
+     demo.load(load_conversation_dropdown, outputs=conversation_selector)
+
+
+ # %%
+ if __name__ == "__main__":
+     demo.launch()
chat_data/chat_history.db ADDED
Binary file (41 kB)

data/olist_analytics.db ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a8d45ae56701986b8ed859c234e89d1cab7054f4e093a90f76ff5ad23924be79
+ size 111755264
data/olist_customers_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/olist_geolocation_dataset.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b514f6fc991b9566aeba02aa5d67e2c3630f034b60a0e05aa0d082a3b66d88d6
+ size 61273883
data/olist_order_items_dataset.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0bc4d068c4fe38cbb01bd90e8746e3c613fe7b4baef75fab7b0e329701c3e279
+ size 15438671
data/olist_order_payments_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/olist_order_reviews_dataset.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:012b61c7593e34f51fa614efdf802b9c7056ce6aae5307ddb93236e7cfc797d7
+ size 14451670
data/olist_orders_dataset.csv ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8df58ef3d2d7e9944010f7beecd9b75367f5588ec6e3c91cec19ae3345ef9ecf
+ size 17654914
data/olist_products_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/olist_sellers_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/product_category_name_translation.csv ADDED
@@ -0,0 +1,72 @@
+ product_category_name,product_category_name_english
+ beleza_saude,health_beauty
+ informatica_acessorios,computers_accessories
+ automotivo,auto
+ cama_mesa_banho,bed_bath_table
+ moveis_decoracao,furniture_decor
+ esporte_lazer,sports_leisure
+ perfumaria,perfumery
+ utilidades_domesticas,housewares
+ telefonia,telephony
+ relogios_presentes,watches_gifts
+ alimentos_bebidas,food_drink
+ bebes,baby
+ papelaria,stationery
+ tablets_impressao_imagem,tablets_printing_image
+ brinquedos,toys
+ telefonia_fixa,fixed_telephony
+ ferramentas_jardim,garden_tools
+ fashion_bolsas_e_acessorios,fashion_bags_accessories
+ eletroportateis,small_appliances
+ consoles_games,consoles_games
+ audio,audio
+ fashion_calcados,fashion_shoes
+ cool_stuff,cool_stuff
+ malas_acessorios,luggage_accessories
+ climatizacao,air_conditioning
+ construcao_ferramentas_construcao,construction_tools_construction
+ moveis_cozinha_area_de_servico_jantar_e_jardim,kitchen_dining_laundry_garden_furniture
+ construcao_ferramentas_jardim,costruction_tools_garden
+ fashion_roupa_masculina,fashion_male_clothing
+ pet_shop,pet_shop
+ moveis_escritorio,office_furniture
+ market_place,market_place
+ eletronicos,electronics
+ eletrodomesticos,home_appliances
+ artigos_de_festas,party_supplies
+ casa_conforto,home_confort
+ construcao_ferramentas_ferramentas,costruction_tools_tools
+ agro_industria_e_comercio,agro_industry_and_commerce
+ moveis_colchao_e_estofado,furniture_mattress_and_upholstery
+ livros_tecnicos,books_technical
+ casa_construcao,home_construction
+ instrumentos_musicais,musical_instruments
+ moveis_sala,furniture_living_room
+ construcao_ferramentas_iluminacao,construction_tools_lights
+ industria_comercio_e_negocios,industry_commerce_and_business
+ alimentos,food
+ artes,art
+ moveis_quarto,furniture_bedroom
+ livros_interesse_geral,books_general_interest
+ construcao_ferramentas_seguranca,construction_tools_safety
+ fashion_underwear_e_moda_praia,fashion_underwear_beach
+ fashion_esporte,fashion_sport
+ sinalizacao_e_seguranca,signaling_and_security
+ pcs,computers
+ artigos_de_natal,christmas_supplies
+ fashion_roupa_feminina,fashio_female_clothing
+ eletrodomesticos_2,home_appliances_2
+ livros_importados,books_imported
+ bebidas,drinks
+ cine_foto,cine_photo
+ la_cuisine,la_cuisine
+ musica,music
+ casa_conforto_2,home_comfort_2
+ portateis_casa_forno_e_cafe,small_appliances_home_oven_and_coffee
+ cds_dvds_musicais,cds_dvds_musicals
+ dvds_blu_ray,dvds_blu_ray
+ flores,flowers
+ artes_e_artesanato,arts_and_craftmanship
+ fraldas_higiene,diapers_and_hygiene
+ fashion_roupa_infanto_juvenil,fashion_childrens_clothes
+ seguros_e_servicos,security_and_services
requirements.txt ADDED
@@ -0,0 +1,26 @@
+ # Core
+ pandas==2.2.2
+ numpy==1.26.4
+ matplotlib==3.9.0
+
+ # LangChain and Groq
+ langchain==0.3.8
+ langchain-community==0.3.2
+ langchain-groq==0.2.1
+
+ # Vector Store (Chroma)
+ chromadb==0.5.18
+
+ # Embeddings
+ sentence-transformers==3.2.1
+
+ # UI
+ gradio==6.2.0
+
+ # Support libs
+ pydantic==2.9.2
+ tabulate==0.9.0
+ pyyaml==6.0.2
+
+ # Optional
+ tqdm==4.66.5