Seth0330 committed on
Commit
9a66ef3
·
verified ·
1 Parent(s): 901175d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -57
app.py CHANGED
@@ -14,12 +14,10 @@ from pydantic import Field
14
  from langchain_openai import ChatOpenAI
15
  from langchain.prompts import ChatPromptTemplate
16
 
17
- # --- CONFIG ---
18
  DB_PATH = "json_vector.db"
19
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
  EMBEDDING_MODEL = "text-embedding-ada-002"
21
 
22
- # --- State Initialization ---
23
  if "ingested_batches" not in st.session_state:
24
  st.session_state.ingested_batches = 0
25
  if "messages" not in st.session_state:
@@ -38,13 +36,11 @@ uploaded_files = st.file_uploader(
38
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
39
  )
40
 
41
- # --- Enhanced Flattening: extract names from emails/user fields for LLM context
42
  def flatten_json_obj(obj, parent_key="", sep="."):
43
  items = {}
44
  if isinstance(obj, dict):
45
  for k, v in obj.items():
46
  new_key = f"{parent_key}{sep}{k}" if parent_key else k
47
- # Entity extraction: add name(s) from email/user
48
  if (
49
  k.lower() in {"customer", "user", "email", "username"} and
50
  isinstance(v, str) and "@" in v
@@ -125,7 +121,40 @@ def ingest_json_files(files):
125
  if uploaded_files and st.button("Ingest batch to database"):
126
  ingest_json_files(uploaded_files)
127
 
128
- # --- VECTOR RETRIEVAL
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def query_vector_db(user_query, top_k=5):
130
  query_emb = get_embedding(user_query)
131
  conn = sqlite3.connect(DB_PATH)
@@ -151,7 +180,6 @@ def query_vector_db(user_query, top_k=5):
151
  docs.append(Document(page_content=row[4], metadata=meta))
152
  return docs
153
 
154
- # --- PYTHON FUZZY/KEYWORD SEARCH
155
  def python_fuzzy_match(user_query, top_k=5):
156
  query_terms = set(user_query.lower().replace("@", " ").replace(".", " ").split())
157
  conn = sqlite3.connect(DB_PATH)
@@ -177,45 +205,23 @@ def python_fuzzy_match(user_query, top_k=5):
177
  docs.append(Document(page_content=row[4], metadata=meta))
178
  return docs
179
 
180
- # --- HYBRID RETRIEVER
181
- def hybrid_query(user_query, top_k=5):
182
- vector_docs = query_vector_db(user_query, top_k=top_k)
183
- fuzzy_docs = python_fuzzy_match(user_query, top_k=top_k)
184
- seen_ids = set()
185
- all_docs = []
186
- for doc in (vector_docs + fuzzy_docs):
187
- doc_id = doc.metadata.get("id")
188
- if doc_id not in seen_ids:
189
- all_docs.append(doc)
190
- seen_ids.add(doc_id)
191
- return all_docs[:top_k]
192
-
193
  class HybridRetriever(BaseRetriever):
194
  top_k: int = Field(default=5)
195
  def _get_relevant_documents(self, query, run_manager=None, **kwargs):
196
  return hybrid_query(query, self.top_k)
197
 
198
- # --- SYSTEM PROMPT & PROMPT TEMPLATE
199
  system_prompt = (
200
- "You are a JSON data assistant. Always answer using only the provided context records. "
201
- "If a question mentions a person or entity (e.g. 'Johnny'), find all records where any field includes that name "
202
- "(including as part of an email or username, e.g. customer_name: johnny, customer: johnny.appleseed@gmail.com). "
203
- "If you find a relevant record, answer as directly as possible using its fields (e.g. 'Johnny spent $100'). "
204
- "If you cannot find the answer, reply: 'I don’t have that information.'"
205
- "If the user asks for the number or details of items in a list/array (e.g., completed tasks), use 'find_in_arrays'. "
206
- "If the user asks about the sum/total of a field for a name or identifier, use 'sum_field_by_name'. "
207
- "If the user asks about female names, use 'count_female_names'. "
208
- "If the user's query does not mention a key, use 'fuzzy_value_search' to match on any value. "
209
- "If a key is mentioned (like 'apps_installed'), use 'search_all_jsons' for that key and the value. "
210
- "You may use 'list_keys' to help discover the file structure if needed. "
211
- "Always give a direct answer from the data if possible."
212
  )
213
  prompt = ChatPromptTemplate.from_messages([
214
  ("system", system_prompt),
215
  ("human", "Here are the most relevant records:\n{context}\n\nQuestion: {question}")
216
  ])
217
 
218
-
219
  llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
220
 
221
  retriever = HybridRetriever(top_k=5)
@@ -226,7 +232,6 @@ qa_chain = RetrievalQA.from_chain_type(
226
  return_source_documents=True,
227
  )
228
 
229
- # --- Chat UI and Conversation Area ---
230
  st.markdown("### Ask any question about your data, just like ChatGPT.")
231
  for msg in st.session_state.messages:
232
  if msg["role"] == "user":
@@ -236,29 +241,6 @@ for msg in st.session_state.messages:
236
  elif msg["role"] == "function":
237
  st.markdown(f"<details><summary><b>Function Output:</b></summary><pre>{msg['content']}</pre></details>", unsafe_allow_html=True)
238
 
239
- def show_json_links_and_modal():
240
- for msg in reversed(st.session_state.messages):
241
- if msg.get("role") == "function" and msg.get("content"):
242
- try:
243
- docs = json.loads(msg["content"])
244
- if isinstance(docs, list):
245
- for idx, doc in enumerate(docs):
246
- if isinstance(doc, dict) and "record" in doc:
247
- if st.button(f"View JSON: {doc.get('file', 'unknown')} record #{idx+1}", key=f"modal_function_{idx}"):
248
- st.session_state.modal_open = True
249
- st.session_state.modal_content = json.dumps(doc["record"], indent=2)
250
- st.session_state.modal_title = f"{doc.get('file', 'unknown')} record #{idx+1}"
251
- except Exception:
252
- continue
253
- break
254
- if st.session_state.modal_open:
255
- with st.expander(f"JSON Record: {st.session_state.modal_title}", expanded=True):
256
- st.code(st.session_state.modal_content, language="json")
257
- if st.button("Close", key="close_modal"):
258
- st.session_state.modal_open = False
259
-
260
- show_json_links_and_modal()
261
-
262
  def send_message():
263
  user_input = st.session_state.temp_input.strip()
264
  if not user_input:
 
14
  from langchain_openai import ChatOpenAI
15
  from langchain.prompts import ChatPromptTemplate
16
 
 
17
  DB_PATH = "json_vector.db"
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19
  EMBEDDING_MODEL = "text-embedding-ada-002"
20
 
 
21
  if "ingested_batches" not in st.session_state:
22
  st.session_state.ingested_batches = 0
23
  if "messages" not in st.session_state:
 
36
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
37
  )
38
 
 
39
  def flatten_json_obj(obj, parent_key="", sep="."):
40
  items = {}
41
  if isinstance(obj, dict):
42
  for k, v in obj.items():
43
  new_key = f"{parent_key}{sep}{k}" if parent_key else k
 
44
  if (
45
  k.lower() in {"customer", "user", "email", "username"} and
46
  isinstance(v, str) and "@" in v
 
121
  if uploaded_files and st.button("Ingest batch to database"):
122
  ingest_json_files(uploaded_files)
123
 
124
# --- Improved entity search/filter
def extract_main_entity(question):
    """Pick the most likely entity keyword from a natural-language question.

    Tokenizes the question, drops common English stop words, and returns the
    first surviving token lower-cased (e.g. "How much did Johnny spend"
    -> "johnny"). Returns None when no token survives filtering.

    Note: matching is case-insensitive — capitalization is NOT required
    (the original comment claiming "first capitalized word" was wrong).
    """
    stop_words = {
        "how", "much", "did", "spend", "was", "the", "is", "in",
        "on", "for", "a", "an", "of", "to", "with",
    }
    tokens = re.findall(r"\b([A-Za-z0-9]+)\b", question)
    keywords = [t.lower() for t in tokens if t.lower() not in stop_words]
    return keywords[0] if keywords else None
131
+
132
def filter_records_by_entity(records, entity):
    """Narrow *records* to those whose text mentions *entity*.

    *entity* is compared (already lower-cased) against each document's
    lower-cased page_content. If *entity* is falsy, or no document matches,
    the original list is returned unchanged so retrieval never comes back
    empty just because filtering was too strict.
    """
    if not entity:
        return records
    hits = [doc for doc in records if entity in doc.page_content.lower()]
    return hits or records
138
+
139
def hybrid_query(user_query, top_k=5):
    """Merge vector-similarity and fuzzy-keyword retrieval for *user_query*.

    Results from both retrievers are de-duplicated by metadata "id"
    (vector hits take precedence since they come first), optionally
    filtered down to documents mentioning the query's main entity, and
    the entity is bold-highlighted in each document's text for the LLM.
    Returns at most *top_k* Documents.

    NOTE(review): highlighting mutates doc.page_content in place on the
    returned Document objects — callers share these objects.
    """
    vector_docs = query_vector_db(user_query, top_k=top_k)
    fuzzy_docs = python_fuzzy_match(user_query, top_k=top_k)
    all_docs = []
    seen_ids = set()
    for doc in vector_docs + fuzzy_docs:
        doc_id = doc.metadata.get("id")
        if doc_id not in seen_ids:
            all_docs.append(doc)
            seen_ids.add(doc_id)
    # Filter for entity match if possible
    entity = extract_main_entity(user_query)
    entity_docs = filter_records_by_entity(all_docs, entity) if entity else all_docs
    # Highlight the entity in the flat_text for the LLM.
    # Invariant check and regex compilation hoisted out of the loop
    # (the original re-tested `if entity` and rebuilt the pattern per doc).
    if entity:
        pattern = re.compile(rf"({re.escape(entity)})", re.IGNORECASE)
        for doc in entity_docs:
            doc.page_content = pattern.sub(r"**\1**", doc.page_content)
    return entity_docs[:top_k]
157
+
158
  def query_vector_db(user_query, top_k=5):
159
  query_emb = get_embedding(user_query)
160
  conn = sqlite3.connect(DB_PATH)
 
180
  docs.append(Document(page_content=row[4], metadata=meta))
181
  return docs
182
 
 
183
  def python_fuzzy_match(user_query, top_k=5):
184
  query_terms = set(user_query.lower().replace("@", " ").replace(".", " ").split())
185
  conn = sqlite3.connect(DB_PATH)
 
205
  docs.append(Document(page_content=row[4], metadata=meta))
206
  return docs
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
class HybridRetriever(BaseRetriever):
    """LangChain retriever that delegates to the module-level hybrid_query.

    Combines vector-similarity and fuzzy-keyword search results; `top_k`
    caps how many Documents are returned per query.
    """
    top_k: int = Field(default=5)

    def _get_relevant_documents(self, query, run_manager=None, **kwargs):
        # Delegate straight to the hybrid pipeline; run_manager/kwargs unused.
        return hybrid_query(query, self.top_k)
212
 
213
# --- Prompt (explicitly tells LLM what to do)
# System turn: constrains the assistant to the retrieved records and tells it
# how to match names embedded in emails/usernames.
system_prompt = (
    "You are a JSON data assistant. "
    "If a question mentions a name (like Johnny), find any record where that name appears as part of any field value (including emails or usernames). "
    "Use the provided records to answer directly. If you can't find the answer, reply: 'I don’t have that information.' "
    "Never make up data. Never ask for clarification."
)

# Human turn: retrieved context is injected as {context}, the user's
# question as {question}.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    ("human", "Here are the most relevant records:\n{context}\n\nQuestion: {question}"),
])
 
 
225
# Deterministic chat model (temperature=0) so answers stay grounded in the
# retrieved records rather than varying between runs.
llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)

# Hybrid (vector + fuzzy) retriever feeding the QA chain.
retriever = HybridRetriever(top_k=5)
 
232
  return_source_documents=True,
233
  )
234
 
 
235
  st.markdown("### Ask any question about your data, just like ChatGPT.")
236
  for msg in st.session_state.messages:
237
  if msg["role"] == "user":
 
241
  elif msg["role"] == "function":
242
  st.markdown(f"<details><summary><b>Function Output:</b></summary><pre>{msg['content']}</pre></details>", unsafe_allow_html=True)
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def send_message():
245
  user_input = st.session_state.temp_input.strip()
246
  if not user_input: