Update app.py
Browse files
app.py
CHANGED
|
@@ -6,12 +6,15 @@ import sqlite3
|
|
| 6 |
import pandas as pd
|
| 7 |
import numpy as np
|
| 8 |
import datetime
|
| 9 |
-
from typing import List
|
| 10 |
import openai
|
| 11 |
from langchain.schema import Document
|
| 12 |
from langchain.chains import RetrievalQA
|
| 13 |
from langchain_community.llms import OpenAI as LangOpenAI
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
# ---- CONFIG ----
|
| 16 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 17 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
|
@@ -35,7 +38,6 @@ if "last_entity" not in st.session_state:
|
|
| 35 |
|
| 36 |
# ---- Helper: Flatten JSON ----
|
| 37 |
def flatten_json_obj(obj, parent_key="", sep="."):
|
| 38 |
-
"""Flatten nested JSON objects/lists with dot notation."""
|
| 39 |
items = {}
|
| 40 |
if isinstance(obj, dict):
|
| 41 |
for k, v in obj.items():
|
|
@@ -98,7 +100,6 @@ def ingest_json_files(files):
|
|
| 98 |
if isinstance(raw, list):
|
| 99 |
records = raw
|
| 100 |
elif isinstance(raw, dict):
|
| 101 |
-
# If dict with a single main list, use it
|
| 102 |
main_lists = [v for v in raw.values() if isinstance(v, list)]
|
| 103 |
if main_lists:
|
| 104 |
records = main_lists[0]
|
|
@@ -108,7 +109,7 @@ def ingest_json_files(files):
|
|
| 108 |
records = [raw]
|
| 109 |
for rec in records:
|
| 110 |
flat = flatten_json_obj(rec)
|
| 111 |
-
#
|
| 112 |
if "customer" in rec and isinstance(rec["customer"], str):
|
| 113 |
first_name = rec["customer"].split("@")[0].replace(".", " ")
|
| 114 |
flat["customer_name"] = first_name
|
|
@@ -139,7 +140,6 @@ def query_vector_db(user_query, top_k=5):
|
|
| 139 |
continue
|
| 140 |
sim = float(np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb)))
|
| 141 |
results.append((sim, row))
|
| 142 |
-
# Top K by similarity
|
| 143 |
results = sorted(results, reverse=True, key=lambda x: x[0])[:top_k]
|
| 144 |
docs = []
|
| 145 |
for sim, row in results:
|
|
@@ -154,13 +154,17 @@ def query_vector_db(user_query, top_k=5):
|
|
| 154 |
return docs
|
| 155 |
|
| 156 |
# ---- LangChain Retriever Adapter ----
|
| 157 |
-
class SQLiteVectorRetriever:
|
| 158 |
-
def
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
# ---- LangChain LLM & QA Chain ----
|
| 162 |
llm = LangOpenAI(model_name="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
|
| 163 |
-
retriever = SQLiteVectorRetriever()
|
| 164 |
qa_chain = RetrievalQA.from_chain_type(
|
| 165 |
llm=llm,
|
| 166 |
retriever=retriever,
|
|
@@ -190,7 +194,6 @@ def update_last_entity(doc):
|
|
| 190 |
pass
|
| 191 |
|
| 192 |
def render_json_links():
|
| 193 |
-
# Tiny inline [view JSON] links, expands in-place on click
|
| 194 |
for key in st.session_state.json_links:
|
| 195 |
info = st.session_state.json_link_details[key]
|
| 196 |
label = info["label"]
|
|
@@ -206,7 +209,6 @@ def send_message():
|
|
| 206 |
user_input = st.session_state.temp_input.strip()
|
| 207 |
if not user_input:
|
| 208 |
return
|
| 209 |
-
# Entity resolution for pronouns (he, his, etc.)
|
| 210 |
pronoun = re.search(r"\b(he|his|him|her|she|their)\b", user_input, re.I)
|
| 211 |
if st.session_state.last_entity and pronoun:
|
| 212 |
q = f"For {st.session_state.last_entity}: {user_input}"
|
|
@@ -230,10 +232,9 @@ def send_message():
|
|
| 230 |
link_keys.append(link_key)
|
| 231 |
st.session_state.json_links = link_keys
|
| 232 |
st.session_state.json_link_details = link_details
|
| 233 |
-
st.session_state.modal_link = None
|
| 234 |
st.session_state.temp_input = ""
|
| 235 |
|
| 236 |
-
# ---- Chat Conversation Rendering ----
|
| 237 |
for msg in st.session_state.messages:
|
| 238 |
if msg["role"] == "user":
|
| 239 |
st.markdown(f"<b style='color:#3575dd'>User:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
import numpy as np
|
| 8 |
import datetime
|
| 9 |
+
from typing import List
|
| 10 |
import openai
|
| 11 |
from langchain.schema import Document
|
| 12 |
from langchain.chains import RetrievalQA
|
| 13 |
from langchain_community.llms import OpenAI as LangOpenAI
|
| 14 |
|
| 15 |
+
# --- FIX: Import correct BaseRetriever
|
| 16 |
+
from langchain_core.retrievers import BaseRetriever
|
| 17 |
+
|
| 18 |
# ---- CONFIG ----
|
| 19 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 20 |
EMBEDDING_MODEL = "text-embedding-ada-002"
|
|
|
|
| 38 |
|
| 39 |
# ---- Helper: Flatten JSON ----
|
| 40 |
def flatten_json_obj(obj, parent_key="", sep="."):
|
|
|
|
| 41 |
items = {}
|
| 42 |
if isinstance(obj, dict):
|
| 43 |
for k, v in obj.items():
|
|
|
|
| 100 |
if isinstance(raw, list):
|
| 101 |
records = raw
|
| 102 |
elif isinstance(raw, dict):
|
|
|
|
| 103 |
main_lists = [v for v in raw.values() if isinstance(v, list)]
|
| 104 |
if main_lists:
|
| 105 |
records = main_lists[0]
|
|
|
|
| 109 |
records = [raw]
|
| 110 |
for rec in records:
|
| 111 |
flat = flatten_json_obj(rec)
|
| 112 |
+
# Add entity keys if found
|
| 113 |
if "customer" in rec and isinstance(rec["customer"], str):
|
| 114 |
first_name = rec["customer"].split("@")[0].replace(".", " ")
|
| 115 |
flat["customer_name"] = first_name
|
|
|
|
| 140 |
continue
|
| 141 |
sim = float(np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb)))
|
| 142 |
results.append((sim, row))
|
|
|
|
| 143 |
results = sorted(results, reverse=True, key=lambda x: x[0])[:top_k]
|
| 144 |
docs = []
|
| 145 |
for sim, row in results:
|
|
|
|
| 154 |
return docs
|
| 155 |
|
| 156 |
# ---- LangChain Retriever Adapter ----
|
| 157 |
+
class SQLiteVectorRetriever(BaseRetriever):
    """Adapter exposing the SQLite-backed vector store as a LangChain retriever.

    ``BaseRetriever`` is a Pydantic model, so configuration must be declared
    as a class-level field.  The previous ``__init__`` assigned
    ``self.top_k = top_k`` *before* calling ``super().__init__()``, which
    raises at construction time: Pydantic rejects attribute assignment on an
    uninitialized model, and ``top_k`` was never declared as a field.
    ``SQLiteVectorRetriever(top_k=5)`` keeps working — Pydantic accepts the
    declared field as a constructor keyword argument.
    """

    # Number of nearest-neighbour documents returned per query.
    top_k: int = 5

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        """Return the ``top_k`` documents most similar to ``query``.

        Implements the modern ``BaseRetriever`` hook; the inherited public
        ``get_relevant_documents`` / ``invoke`` methods delegate here (the
        base class passes a callback ``run_manager``, which this retriever
        does not use).
        """
        return query_vector_db(query, self.top_k)
|
| 164 |
|
| 165 |
# ---- LangChain LLM & QA Chain ----
|
| 166 |
llm = LangOpenAI(model_name="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
|
| 167 |
+
retriever = SQLiteVectorRetriever(top_k=5)
|
| 168 |
qa_chain = RetrievalQA.from_chain_type(
|
| 169 |
llm=llm,
|
| 170 |
retriever=retriever,
|
|
|
|
| 194 |
pass
|
| 195 |
|
| 196 |
def render_json_links():
|
|
|
|
| 197 |
for key in st.session_state.json_links:
|
| 198 |
info = st.session_state.json_link_details[key]
|
| 199 |
label = info["label"]
|
|
|
|
| 209 |
user_input = st.session_state.temp_input.strip()
|
| 210 |
if not user_input:
|
| 211 |
return
|
|
|
|
| 212 |
pronoun = re.search(r"\b(he|his|him|her|she|their)\b", user_input, re.I)
|
| 213 |
if st.session_state.last_entity and pronoun:
|
| 214 |
q = f"For {st.session_state.last_entity}: {user_input}"
|
|
|
|
| 232 |
link_keys.append(link_key)
|
| 233 |
st.session_state.json_links = link_keys
|
| 234 |
st.session_state.json_link_details = link_details
|
| 235 |
+
st.session_state.modal_link = None
|
| 236 |
st.session_state.temp_input = ""
|
| 237 |
|
|
|
|
| 238 |
for msg in st.session_state.messages:
|
| 239 |
if msg["role"] == "user":
|
| 240 |
st.markdown(f"<b style='color:#3575dd'>User:</b> <span style='color:#111'>{msg['content']}</span>", unsafe_allow_html=True)
|