Seth0330 committed on
Commit
7bbdd37
·
verified ·
1 Parent(s): 7dee79b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -6,12 +6,15 @@ import sqlite3
6
  import pandas as pd
7
  import numpy as np
8
  import datetime
9
- from typing import List, Dict
10
  import openai
11
  from langchain.schema import Document
12
  from langchain.chains import RetrievalQA
13
  from langchain_community.llms import OpenAI as LangOpenAI
14
 
 
 
 
15
  # ---- CONFIG ----
16
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
17
  EMBEDDING_MODEL = "text-embedding-ada-002"
@@ -35,7 +38,6 @@ if "last_entity" not in st.session_state:
35
 
36
  # ---- Helper: Flatten JSON ----
37
  def flatten_json_obj(obj, parent_key="", sep="."):
38
- """Flatten nested JSON objects/lists with dot notation."""
39
  items = {}
40
  if isinstance(obj, dict):
41
  for k, v in obj.items():
@@ -98,7 +100,6 @@ def ingest_json_files(files):
98
  if isinstance(raw, list):
99
  records = raw
100
  elif isinstance(raw, dict):
101
- # If dict with a single main list, use it
102
  main_lists = [v for v in raw.values() if isinstance(v, list)]
103
  if main_lists:
104
  records = main_lists[0]
@@ -108,7 +109,7 @@ def ingest_json_files(files):
108
  records = [raw]
109
  for rec in records:
110
  flat = flatten_json_obj(rec)
111
- # Heuristic: add top-level "name"/"customer" fields for entity tracking
112
  if "customer" in rec and isinstance(rec["customer"], str):
113
  first_name = rec["customer"].split("@")[0].replace(".", " ")
114
  flat["customer_name"] = first_name
@@ -139,7 +140,6 @@ def query_vector_db(user_query, top_k=5):
139
  continue
140
  sim = float(np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb)))
141
  results.append((sim, row))
142
- # Top K by similarity
143
  results = sorted(results, reverse=True, key=lambda x: x[0])[:top_k]
144
  docs = []
145
  for sim, row in results:
@@ -154,13 +154,17 @@ def query_vector_db(user_query, top_k=5):
154
  return docs
155
 
156
  # ---- LangChain Retriever Adapter ----
157
- class SQLiteVectorRetriever:
158
- def get_relevant_documents(self, query):
159
- return query_vector_db(query, top_k=5)
 
 
 
 
160
 
161
  # ---- LangChain LLM & QA Chain ----
162
  llm = LangOpenAI(model_name="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
163
- retriever = SQLiteVectorRetriever()
164
  qa_chain = RetrievalQA.from_chain_type(
165
  llm=llm,
166
  retriever=retriever,
@@ -190,7 +194,6 @@ def update_last_entity(doc):
190
  pass
191
 
192
  def render_json_links():
193
- # Tiny inline [view JSON] links, expands in-place on click
194
  for key in st.session_state.json_links:
195
  info = st.session_state.json_link_details[key]
196
  label = info["label"]
@@ -206,7 +209,6 @@ def send_message():
206
  user_input = st.session_state.temp_input.strip()
207
  if not user_input:
208
  return
209
- # Entity resolution for pronouns (he, his, etc.)
210
  pronoun = re.search(r"\b(he|his|him|her|she|their)\b", user_input, re.I)
211
  if st.session_state.last_entity and pronoun:
212
  q = f"For {st.session_state.last_entity}: {user_input}"
@@ -230,10 +232,9 @@ def send_message():
230
  link_keys.append(link_key)
231
  st.session_state.json_links = link_keys
232
  st.session_state.json_link_details = link_details
233
- st.session_state.modal_link = None # reset on every new message
234
  st.session_state.temp_input = ""
235
 
236
- # ---- Chat Conversation Rendering ----
237
  for msg in st.session_state.messages:
238
  if msg["role"] == "user":
239
  st.markdown(f"<b style='color:#3575dd'>User:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)
 
6
  import pandas as pd
7
  import numpy as np
8
  import datetime
9
+ from typing import List
10
  import openai
11
  from langchain.schema import Document
12
  from langchain.chains import RetrievalQA
13
  from langchain_community.llms import OpenAI as LangOpenAI
14
 
15
+ # --- FIX: Import correct BaseRetriever
16
+ from langchain_core.retrievers import BaseRetriever
17
+
18
  # ---- CONFIG ----
19
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
  EMBEDDING_MODEL = "text-embedding-ada-002"
 
38
 
39
  # ---- Helper: Flatten JSON ----
40
  def flatten_json_obj(obj, parent_key="", sep="."):
 
41
  items = {}
42
  if isinstance(obj, dict):
43
  for k, v in obj.items():
 
100
  if isinstance(raw, list):
101
  records = raw
102
  elif isinstance(raw, dict):
 
103
  main_lists = [v for v in raw.values() if isinstance(v, list)]
104
  if main_lists:
105
  records = main_lists[0]
 
109
  records = [raw]
110
  for rec in records:
111
  flat = flatten_json_obj(rec)
112
+ # Add entity keys if found
113
  if "customer" in rec and isinstance(rec["customer"], str):
114
  first_name = rec["customer"].split("@")[0].replace(".", " ")
115
  flat["customer_name"] = first_name
 
140
  continue
141
  sim = float(np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb)))
142
  results.append((sim, row))
 
143
  results = sorted(results, reverse=True, key=lambda x: x[0])[:top_k]
144
  docs = []
145
  for sim, row in results:
 
154
  return docs
155
 
156
  # ---- LangChain Retriever Adapter ----
157
+ class SQLiteVectorRetriever(BaseRetriever):
158
+ def __init__(self, top_k=5):
159
+ self.top_k = top_k
160
+ super().__init__()
161
+
162
+ def get_relevant_documents(self, query: str) -> List[Document]:
163
+ return query_vector_db(query, self.top_k)
164
 
165
  # ---- LangChain LLM & QA Chain ----
166
  llm = LangOpenAI(model_name="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
167
+ retriever = SQLiteVectorRetriever(top_k=5)
168
  qa_chain = RetrievalQA.from_chain_type(
169
  llm=llm,
170
  retriever=retriever,
 
194
  pass
195
 
196
  def render_json_links():
 
197
  for key in st.session_state.json_links:
198
  info = st.session_state.json_link_details[key]
199
  label = info["label"]
 
209
  user_input = st.session_state.temp_input.strip()
210
  if not user_input:
211
  return
 
212
  pronoun = re.search(r"\b(he|his|him|her|she|their)\b", user_input, re.I)
213
  if st.session_state.last_entity and pronoun:
214
  q = f"For {st.session_state.last_entity}: {user_input}"
 
232
  link_keys.append(link_key)
233
  st.session_state.json_links = link_keys
234
  st.session_state.json_link_details = link_details
235
+ st.session_state.modal_link = None
236
  st.session_state.temp_input = ""
237
 
 
238
  for msg in st.session_state.messages:
239
  if msg["role"] == "user":
240
  st.markdown(f"<b style='color:#3575dd'>User:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)