Seth0330 commited on
Commit
36c52bd
·
verified ·
1 Parent(s): ed0fe96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -146
app.py CHANGED
@@ -1,27 +1,23 @@
1
- import streamlit as st
2
  import os
3
- import json
4
- import re
5
- import sqlite3
6
  import pandas as pd
 
 
 
7
  import numpy as np
8
  import datetime
9
- from typing import List
10
- import openai
11
- from langchain.schema import Document
12
  from langchain.chains import RetrievalQA
13
- from langchain_community.llms import OpenAI as LangOpenAI
14
  from langchain_core.retrievers import BaseRetriever
15
  from pydantic import Field
 
 
16
 
17
- # ---- CONFIG ----
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19
  EMBEDDING_MODEL = "text-embedding-ada-002"
20
- DB_FILE = "json_vector_store.db"
21
 
22
- st.set_page_config(page_title="Chat with Your Vectorized JSON Files", layout="wide")
23
-
24
- # --- Session State ---
25
  if "ingested_batches" not in st.session_state:
26
  st.session_state.ingested_batches = 0
27
  if "messages" not in st.session_state:
@@ -30,17 +26,29 @@ if "json_links" not in st.session_state:
30
  st.session_state.json_links = []
31
  if "json_link_details" not in st.session_state:
32
  st.session_state.json_link_details = {}
33
- if "modal_link" not in st.session_state:
34
- st.session_state.modal_link = None
35
- if "last_entity" not in st.session_state:
36
- st.session_state.last_entity = None
37
 
38
- # ---- Helper: Flatten JSON ----
 
 
 
 
 
 
39
  def flatten_json_obj(obj, parent_key="", sep="."):
40
  items = {}
41
  if isinstance(obj, dict):
42
  for k, v in obj.items():
43
  new_key = f"{parent_key}{sep}{k}" if parent_key else k
 
 
 
 
 
 
 
 
 
 
44
  items.update(flatten_json_obj(v, new_key, sep=sep))
45
  elif isinstance(obj, list):
46
  for i, v in enumerate(obj):
@@ -50,175 +58,218 @@ def flatten_json_obj(obj, parent_key="", sep="."):
50
  items[parent_key] = obj
51
  return items
52
 
53
- # ---- Helper: Get OpenAI Embedding ----
54
  def get_embedding(text):
55
- openai.api_key = OPENAI_API_KEY
56
- resp = openai.embeddings.create(input=[text], model=EMBEDDING_MODEL)
57
- return resp.data[0].embedding
58
 
59
- # ---- SQLite DB Setup ----
60
  def ensure_table():
61
- with sqlite3.connect(DB_FILE) as conn:
62
- c = conn.cursor()
63
- c.execute("""
64
- CREATE TABLE IF NOT EXISTS json_records (
65
- id INTEGER PRIMARY KEY AUTOINCREMENT,
66
- batch_time TEXT,
67
- source_file TEXT,
68
- raw_json TEXT,
69
- flat_text TEXT,
70
- embedding BLOB
71
- )
72
- """)
73
- conn.commit()
74
-
75
- def insert_records(records):
76
- with sqlite3.connect(DB_FILE) as conn:
77
- c = conn.cursor()
78
- c.executemany(
79
- "INSERT INTO json_records (batch_time, source_file, raw_json, flat_text, embedding) VALUES (?, ?, ?, ?, ?)",
80
- records
81
- )
82
- conn.commit()
83
-
84
- def all_records():
85
- with sqlite3.connect(DB_FILE) as conn:
86
- c = conn.cursor()
87
- c.execute("SELECT id, batch_time, source_file, raw_json, flat_text, embedding FROM json_records")
88
- return c.fetchall()
89
-
90
- # ---- Ingest JSON Batch ----
91
  def ingest_json_files(files):
92
  ensure_table()
93
  rows = []
94
  batch_time = datetime.datetime.utcnow().isoformat()
95
  for file in files:
 
96
  raw = json.load(file)
97
  source_name = file.name
98
- if isinstance(raw, list):
99
- records = raw
100
- elif isinstance(raw, dict):
101
- main_lists = [v for v in raw.values() if isinstance(v, list)]
102
- if main_lists:
103
- records = main_lists[0]
104
- else:
105
- records = [raw]
106
- else:
107
- records = [raw]
108
  for rec in records:
109
  flat = flatten_json_obj(rec)
110
- if "customer" in rec and isinstance(rec["customer"], str):
111
- first_name = rec["customer"].split("@")[0].replace(".", " ")
112
- flat["customer_name"] = first_name
113
- flat["customer_all_names"] = first_name.replace(".", " ")
114
  flat_text = "; ".join([f"{k}: {v}" for k, v in flat.items()])
115
- rows.append((batch_time, source_name, json.dumps(rec), flat_text, None))
116
- df = pd.DataFrame(rows, columns=["batch_time", "source_file", "raw_json", "flat_text", "embedding"])
117
- st.write(f"Flattened {len(df)} records. Generating embeddings...")
 
 
 
118
  df["embedding"] = df["flat_text"].apply(get_embedding)
119
- sql_rows = [
120
- (
121
- row.batch_time, row.source_file, row.raw_json, row.flat_text,
122
- sqlite3.Binary(np.array(row.embedding, dtype=np.float32).tobytes())
123
- )
124
- for _, row in df.iterrows()
125
- ]
126
- insert_records(sql_rows)
 
 
127
  st.success(f"Ingested and indexed {len(df)} new records!")
128
  st.session_state.ingested_batches += 1
129
 
130
- # ---- Hybrid Retrieval ----
 
 
131
  def query_vector_db(user_query, top_k=5):
132
- query_emb = np.array(get_embedding(user_query), dtype=np.float32)
 
 
 
133
  results = []
134
- for row in all_records():
135
  db_emb = np.frombuffer(row[5], dtype=np.float32)
136
- if len(db_emb) != len(query_emb):
137
- continue
138
- sim = float(np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb)))
139
  results.append((sim, row))
140
- results = sorted(results, reverse=True, key=lambda x: x[0])[:top_k]
 
141
  docs = []
142
  for sim, row in results:
143
  meta = {
144
  "id": row[0],
145
- "batch_time": row[1],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  "source_file": row[2],
147
- "similarity": f"{sim:.4f}",
148
  "raw_json": row[3],
149
  }
150
  docs.append(Document(page_content=row[4], metadata=meta))
151
  return docs
152
 
153
- # ---- LangChain Retriever Adapter ----
154
- class SQLiteVectorRetriever(BaseRetriever):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  top_k: int = Field(default=5)
 
 
156
 
157
- def get_relevant_documents(self, query: str) -> List[Document]:
158
- return query_vector_db(query, self.top_k)
 
 
 
 
 
 
 
 
 
 
159
 
160
- # ---- LangChain LLM & QA Chain ----
161
- llm = LangOpenAI(model_name="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
162
- retriever = SQLiteVectorRetriever(top_k=5)
163
  qa_chain = RetrievalQA.from_chain_type(
164
  llm=llm,
165
  retriever=retriever,
 
166
  return_source_documents=True,
167
- # chain_type_kwargs={"input_key": "query"} # <--- REMOVED
168
  )
169
 
170
- # ---- Ingestion UI ----
171
- st.title("Chat with Your Vectorized JSON Files (Hybrid Retrieval, SQLite, LLM)")
172
- uploaded_files = st.file_uploader(
173
- "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
174
- )
175
- if uploaded_files and st.button("Ingest batch to database"):
176
- ingest_json_files(uploaded_files)
177
-
178
- # ---- Conversation UI ----
179
  st.markdown("### Ask any question about your data, just like ChatGPT.")
180
 
181
- def update_last_entity(doc):
182
- try:
183
- rec = json.loads(doc.metadata["raw_json"])
184
- if "customer" in rec and "@" in rec["customer"]:
185
- st.session_state.last_entity = rec["customer"]
186
- elif "customer_name" in rec:
187
- st.session_state.last_entity = rec["customer_name"]
188
- except Exception:
189
- pass
190
-
191
- def render_json_links():
192
- for key in st.session_state.json_links:
193
- info = st.session_state.json_link_details[key]
194
- label = info["label"]
195
- rec = info["record"]
196
- if st.button(f"[view JSON] {label}", key=key, help="Show JSON record", use_container_width=False):
197
- st.session_state.modal_link = key
198
- if st.session_state.modal_link:
199
- info = st.session_state.json_link_details[st.session_state.modal_link]
200
- with st.container():
201
- st.code(json.dumps(info["record"], indent=2), language="json")
202
 
203
  def send_message():
204
  user_input = st.session_state.temp_input.strip()
205
  if not user_input:
206
  return
207
- pronoun = re.search(r"\b(he|his|him|her|she|their)\b", user_input, re.I)
208
- if st.session_state.last_entity and pronoun:
209
- q = f"For {st.session_state.last_entity}: {user_input}"
210
- else:
211
- q = user_input
212
  st.session_state.messages.append({"role": "user", "content": user_input})
213
  with st.spinner("Thinking..."):
214
- result = qa_chain({"query": q}) # <---- FIXED: use 'question' as input key
215
  answer = result['result']
216
  st.session_state.messages.append({"role": "assistant", "content": answer})
217
  docs = result['source_documents']
218
  link_keys = []
219
  link_details = {}
220
- if docs:
221
- update_last_entity(docs[0])
222
  for idx, doc in enumerate(docs):
223
  link_key = f"json_{doc.metadata['id']}_{idx}"
224
  rec = json.loads(doc.metadata["raw_json"])
@@ -227,25 +278,13 @@ def send_message():
227
  link_keys.append(link_key)
228
  st.session_state.json_links = link_keys
229
  st.session_state.json_link_details = link_details
230
- st.session_state.modal_link = None
231
  st.session_state.temp_input = ""
232
 
233
- for msg in st.session_state.messages:
234
- if msg["role"] == "user":
235
- st.markdown(f"<b style='color:#3575dd'>User:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)
236
- elif msg["role"] == "assistant":
237
- st.markdown(f"<b style='color:#1c6e4c'>Agent:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)
238
-
239
- if st.session_state.json_links:
240
- st.markdown("<b>Function Output:</b>", unsafe_allow_html=True)
241
- render_json_links()
242
-
243
  st.text_input("Your message:", key="temp_input", on_change=send_message)
 
244
  if st.button("Clear chat"):
245
  st.session_state.messages = []
246
  st.session_state.json_links = []
247
  st.session_state.json_link_details = {}
248
- st.session_state.modal_link = None
249
- st.session_state.last_entity = None
250
 
251
  st.info(f"Batches ingested so far (this session): {st.session_state.ingested_batches}")
 
 
1
  import os
2
+ import streamlit as st
 
 
3
  import pandas as pd
4
+ import openai
5
+ import sqlite3
6
+ import json
7
  import numpy as np
8
  import datetime
9
+ import re
 
 
10
  from langchain.chains import RetrievalQA
11
+ from langchain.schema import Document
12
  from langchain_core.retrievers import BaseRetriever
13
  from pydantic import Field
14
+ from langchain_openai import ChatOpenAI
15
+ from langchain.prompts import ChatPromptTemplate
16
 
17
+ DB_PATH = "json_vector.db"
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19
  EMBEDDING_MODEL = "text-embedding-ada-002"
 
20
 
 
 
 
21
  if "ingested_batches" not in st.session_state:
22
  st.session_state.ingested_batches = 0
23
  if "messages" not in st.session_state:
 
26
  st.session_state.json_links = []
27
  if "json_link_details" not in st.session_state:
28
  st.session_state.json_link_details = {}
 
 
 
 
29
 
30
+ st.set_page_config(page_title="Chat with Your JSON Vectors (Hybrid, Clean)", layout="wide")
31
+ st.title("Chat with Your Vectorized JSON Files")
32
+
33
+ uploaded_files = st.file_uploader(
34
+ "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
35
+ )
36
+
37
  def flatten_json_obj(obj, parent_key="", sep="."):
38
  items = {}
39
  if isinstance(obj, dict):
40
  for k, v in obj.items():
41
  new_key = f"{parent_key}{sep}{k}" if parent_key else k
42
+ if (
43
+ k.lower() in {"customer", "user", "email", "username"} and
44
+ isinstance(v, str) and "@" in v
45
+ ):
46
+ local = v.split("@")[0]
47
+ local_clean = re.sub(r'[^a-zA-Z0-9]', ' ', local)
48
+ parts = [part for part in local_clean.split() if part]
49
+ if parts:
50
+ items[new_key + "_name"] = parts[0].lower()
51
+ items[new_key + "_all_names"] = " ".join(parts).lower()
52
  items.update(flatten_json_obj(v, new_key, sep=sep))
53
  elif isinstance(obj, list):
54
  for i, v in enumerate(obj):
 
58
  items[parent_key] = obj
59
  return items
60
 
 
61
  def get_embedding(text):
62
+ client = openai.OpenAI(api_key=OPENAI_API_KEY)
63
+ response = client.embeddings.create(input=[text], model=EMBEDDING_MODEL)
64
+ return response.data[0].embedding
65
 
 
66
  def ensure_table():
67
+ conn = sqlite3.connect(DB_PATH)
68
+ cursor = conn.cursor()
69
+ cursor.execute("""
70
+ CREATE TABLE IF NOT EXISTS json_records (
71
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
72
+ batch_time TEXT,
73
+ source_file TEXT,
74
+ raw_json TEXT,
75
+ flat_text TEXT,
76
+ embedding BLOB
77
+ )
78
+ """)
79
+ conn.commit()
80
+ conn.close()
81
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def ingest_json_files(files):
83
  ensure_table()
84
  rows = []
85
  batch_time = datetime.datetime.utcnow().isoformat()
86
  for file in files:
87
+ file.seek(0)
88
  raw = json.load(file)
89
  source_name = file.name
90
+ records = raw if isinstance(raw, list) else [raw]
 
 
 
 
 
 
 
 
 
91
  for rec in records:
92
  flat = flatten_json_obj(rec)
 
 
 
 
93
  flat_text = "; ".join([f"{k}: {v}" for k, v in flat.items()])
94
+ rows.append((batch_time, source_name, json.dumps(rec), flat_text))
95
+ if not rows:
96
+ st.warning("No records found in uploaded files!")
97
+ return
98
+ df = pd.DataFrame(rows, columns=["batch_time", "source_file", "raw_json", "flat_text"])
99
+ st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
100
  df["embedding"] = df["flat_text"].apply(get_embedding)
101
+ conn = sqlite3.connect(DB_PATH)
102
+ cursor = conn.cursor()
103
+ for _, row in df.iterrows():
104
+ emb_bytes = np.array(row.embedding, dtype=np.float32).tobytes()
105
+ cursor.execute("""
106
+ INSERT INTO json_records (batch_time, source_file, raw_json, flat_text, embedding)
107
+ VALUES (?, ?, ?, ?, ?)
108
+ """, (row.batch_time, row.source_file, row.raw_json, row.flat_text, emb_bytes))
109
+ conn.commit()
110
+ conn.close()
111
  st.success(f"Ingested and indexed {len(df)} new records!")
112
  st.session_state.ingested_batches += 1
113
 
114
+ if uploaded_files and st.button("Ingest batch to database"):
115
+ ingest_json_files(uploaded_files)
116
+
117
  def query_vector_db(user_query, top_k=5):
118
+ query_emb = get_embedding(user_query)
119
+ conn = sqlite3.connect(DB_PATH)
120
+ cursor = conn.cursor()
121
+ cursor.execute("SELECT id, batch_time, source_file, raw_json, flat_text, embedding FROM json_records")
122
  results = []
123
+ for row in cursor.fetchall():
124
  db_emb = np.frombuffer(row[5], dtype=np.float32)
125
+ if len(db_emb) != len(query_emb): continue
126
+ sim = np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb))
 
127
  results.append((sim, row))
128
+ conn.close()
129
+ results = sorted(results, reverse=True)[:top_k]
130
  docs = []
131
  for sim, row in results:
132
  meta = {
133
  "id": row[0],
134
+ "batch_time": str(row[1]),
135
+ "source_file": row[2],
136
+ "similarity": f"{sim:.4f} (embedding)",
137
+ "raw_json": row[3],
138
+ }
139
+ docs.append(Document(page_content=row[4], metadata=meta))
140
+ return docs
141
+
142
+ def python_fuzzy_match(user_query, top_k=5):
143
+ query_terms = set(user_query.lower().replace("@", " ").replace(".", " ").split())
144
+ conn = sqlite3.connect(DB_PATH)
145
+ cursor = conn.cursor()
146
+ cursor.execute("SELECT id, batch_time, source_file, raw_json, flat_text FROM json_records")
147
+ results = []
148
+ for row in cursor.fetchall():
149
+ flat_text = row[4].lower()
150
+ score = sum(any(term in flat_text for term in query_terms) for term in query_terms)
151
+ if score > 0:
152
+ results.append((score, row))
153
+ conn.close()
154
+ results = sorted(results, reverse=True)[:top_k]
155
+ docs = []
156
+ for score, row in results:
157
+ meta = {
158
+ "id": row[0],
159
+ "batch_time": str(row[1]),
160
  "source_file": row[2],
161
+ "similarity": f"{score} (fuzzy)",
162
  "raw_json": row[3],
163
  }
164
  docs.append(Document(page_content=row[4], metadata=meta))
165
  return docs
166
 
167
+ def extract_main_entity(question):
168
+ import re
169
+ quoted = re.findall(r"['\"]([^'\"]+)['\"]", question)
170
+ if quoted:
171
+ return quoted[0].lower()
172
+ email = re.findall(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", question)
173
+ if email:
174
+ return email[0].lower().split('@')[0]
175
+ tokens = re.findall(r"\b([A-Za-z0-9]+)\b", question)
176
+ stopwords = {"how", "much", "did", "spend", "was", "the", "is", "in", "on", "for", "a", "an", "of", "to", "with"}
177
+ keywords = [t.lower() for t in tokens if t.lower() not in stopwords]
178
+ if not keywords:
179
+ return ""
180
+ return max(keywords, key=len)
181
+
182
+ def filter_records_by_entity(records, entity):
183
+ if not entity:
184
+ return records
185
+ matches = []
186
+ for doc in records:
187
+ if entity in doc.page_content.lower():
188
+ matches.append(doc)
189
+ elif any(entity in v.lower() for v in doc.page_content.split(';')):
190
+ matches.append(doc)
191
+ return matches if matches else records
192
+
193
+ def hybrid_query(user_query, top_k=5):
194
+ vector_docs = query_vector_db(user_query, top_k=top_k)
195
+ fuzzy_docs = python_fuzzy_match(user_query, top_k=top_k)
196
+ all_docs = []
197
+ seen_ids = set()
198
+ for doc in (vector_docs + fuzzy_docs):
199
+ doc_id = doc.metadata.get("id")
200
+ if doc_id not in seen_ids:
201
+ all_docs.append(doc)
202
+ seen_ids.add(doc_id)
203
+ entity = extract_main_entity(user_query)
204
+ entity_docs = filter_records_by_entity(all_docs, entity) if entity else all_docs
205
+ if entity_docs:
206
+ doc = entity_docs[0]
207
+ return [doc]
208
+ else:
209
+ return all_docs[:1]
210
+
211
+ class HybridRetriever(BaseRetriever):
212
  top_k: int = Field(default=5)
213
+ def _get_relevant_documents(self, query, run_manager=None, **kwargs):
214
+ return hybrid_query(query, self.top_k)
215
 
216
+ system_prompt = (
217
+ "You are a JSON data assistant. "
218
+ "If the question mentions a name or email (e.g. Johnny), match it to any field value (even as part of an email) "
219
+ "and answer directly using the record's fields. "
220
+ "For example, if 'customer: johnny.appleseed@gmail.com' and the question is about Johnny, you should use that record."
221
+ "If you can't find the answer, reply: 'I don’t have that information.'"
222
+ "Never make up data. Never ask for clarification."
223
+ )
224
+ prompt = ChatPromptTemplate.from_messages([
225
+ ("system", system_prompt),
226
+ ("human", "Here are the most relevant records:\n{context}\n\nQuestion: {question}")
227
+ ])
228
 
229
+ llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
230
+ retriever = HybridRetriever(top_k=5)
 
231
  qa_chain = RetrievalQA.from_chain_type(
232
  llm=llm,
233
  retriever=retriever,
234
+ chain_type_kwargs={"prompt": prompt},
235
  return_source_documents=True,
 
236
  )
237
 
 
 
 
 
 
 
 
 
 
238
  st.markdown("### Ask any question about your data, just like ChatGPT.")
239
 
240
+ def show_tiny_json_links():
241
+ # Only show for the last assistant answer if there are matching JSONs
242
+ if not st.session_state.json_links:
243
+ return
244
+ for idx, link_key in enumerate(st.session_state.json_links):
245
+ label = st.session_state.json_link_details[link_key]['label']
246
+ rec = st.session_state.json_link_details[link_key]['record']
247
+ expander_label = f"<span style='font-size:11px; color:#444; text-decoration:underline;'>[view JSON]</span> <span style='font-size:10px; color:#aaa'>{label}</span>"
248
+ with st.expander(label="", expanded=False):
249
+ st.markdown(expander_label, unsafe_allow_html=True)
250
+ st.code(json.dumps(rec, indent=2), language="json")
251
+ st.session_state.json_links = []
252
+ st.session_state.json_link_details = {}
253
+
254
+ for msg in st.session_state.messages:
255
+ if msg["role"] == "user":
256
+ st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg['content']}</div>", unsafe_allow_html=True)
257
+ elif msg["role"] == "assistant":
258
+ st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {msg['content']}</div>", unsafe_allow_html=True)
259
+ show_tiny_json_links()
 
260
 
261
  def send_message():
262
  user_input = st.session_state.temp_input.strip()
263
  if not user_input:
264
  return
 
 
 
 
 
265
  st.session_state.messages.append({"role": "user", "content": user_input})
266
  with st.spinner("Thinking..."):
267
+ result = qa_chain({"query": user_input})
268
  answer = result['result']
269
  st.session_state.messages.append({"role": "assistant", "content": answer})
270
  docs = result['source_documents']
271
  link_keys = []
272
  link_details = {}
 
 
273
  for idx, doc in enumerate(docs):
274
  link_key = f"json_{doc.metadata['id']}_{idx}"
275
  rec = json.loads(doc.metadata["raw_json"])
 
278
  link_keys.append(link_key)
279
  st.session_state.json_links = link_keys
280
  st.session_state.json_link_details = link_details
 
281
  st.session_state.temp_input = ""
282
 
 
 
 
 
 
 
 
 
 
 
283
  st.text_input("Your message:", key="temp_input", on_change=send_message)
284
+
285
  if st.button("Clear chat"):
286
  st.session_state.messages = []
287
  st.session_state.json_links = []
288
  st.session_state.json_link_details = {}
 
 
289
 
290
  st.info(f"Batches ingested so far (this session): {st.session_state.ingested_batches}")