Seth0330 committed on
Commit
a84926c
·
verified ·
1 Parent(s): e145b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -46
app.py CHANGED
@@ -11,18 +11,19 @@ from langchain.chains import RetrievalQA
11
  from langchain.schema import Document
12
  from langchain_core.retrievers import BaseRetriever
13
  from pydantic import Field
14
- from langchain_openai import ChatOpenAI # FIXED: Use ChatOpenAI for chat models
 
15
 
16
  # --- CONFIG ---
17
  DB_PATH = "json_vector.db"
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
19
  EMBEDDING_MODEL = "text-embedding-ada-002"
20
 
21
- # --- Streamlit State Initialization ---
22
  if "ingested_batches" not in st.session_state:
23
  st.session_state.ingested_batches = 0
24
- if "chat_history" not in st.session_state:
25
- st.session_state.chat_history = []
26
  if "modal_open" not in st.session_state:
27
  st.session_state.modal_open = False
28
  if "modal_content" not in st.session_state:
@@ -30,14 +31,13 @@ if "modal_content" not in st.session_state:
30
  if "modal_title" not in st.session_state:
31
  st.session_state.modal_title = ""
32
 
33
- st.set_page_config(page_title="Cumulative JSON Vector Search (SQLite)", layout="wide")
34
- st.title("LLM-Powered Analytics: Cumulative JSON Vector DB (SQLite, Local)")
35
 
36
  uploaded_files = st.file_uploader(
37
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
38
  )
39
 
40
- # --- Helper: Flatten any unstructured JSON (handles dict, list, nested, various keys) ---
41
  def flatten_json_obj(obj, parent_key="", sep="."):
42
  items = {}
43
  if isinstance(obj, dict):
@@ -52,13 +52,11 @@ def flatten_json_obj(obj, parent_key="", sep="."):
52
  items[parent_key] = obj
53
  return items
54
 
55
- # --- Embedding function (openai>=1.0.0 style) ---
56
  def get_embedding(text):
57
  client = openai.OpenAI(api_key=OPENAI_API_KEY)
58
  response = client.embeddings.create(input=[text], model=EMBEDDING_MODEL)
59
  return response.data[0].embedding
60
 
61
- # --- Ensure DB Table (accumulates all uploads, never deletes old data) ---
62
  def ensure_table():
63
  conn = sqlite3.connect(DB_PATH)
64
  cursor = conn.cursor()
@@ -75,7 +73,6 @@ def ensure_table():
75
  conn.commit()
76
  conn.close()
77
 
78
- # --- Ingest and accumulate uploaded files ---
79
  def ingest_json_files(files):
80
  ensure_table()
81
  rows = []
@@ -83,15 +80,11 @@ def ingest_json_files(files):
83
  for file in files:
84
  raw = json.load(file)
85
  source_name = file.name
86
- # Handle top-level list or dict
87
  if isinstance(raw, list):
88
  records = raw
89
  elif isinstance(raw, dict):
90
  main_lists = [v for v in raw.values() if isinstance(v, list)]
91
- if main_lists:
92
- records = main_lists[0]
93
- else:
94
- records = [raw]
95
  else:
96
  records = [raw]
97
  for rec in records:
@@ -104,7 +97,6 @@ def ingest_json_files(files):
104
  df = pd.DataFrame(rows, columns=["batch_time", "source_file", "raw_json", "flat_text"])
105
  st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
106
  df["embedding"] = df["flat_text"].apply(get_embedding)
107
- # Insert into DB
108
  conn = sqlite3.connect(DB_PATH)
109
  cursor = conn.cursor()
110
  for _, row in df.iterrows():
@@ -121,7 +113,6 @@ def ingest_json_files(files):
121
  if uploaded_files and st.button("Ingest batch to database"):
122
  ingest_json_files(uploaded_files)
123
 
124
- # --- Query entire cumulative DB (ALL past and present records) ---
125
  def query_vector_db(user_query, top_k=5):
126
  query_emb = get_embedding(user_query)
127
  conn = sqlite3.connect(DB_PATH)
@@ -130,7 +121,7 @@ def query_vector_db(user_query, top_k=5):
130
  results = []
131
  for row in cursor.fetchall():
132
  db_emb = np.frombuffer(row[5], dtype=np.float32)
133
- if len(db_emb) != len(query_emb): continue # Skip malformed
134
  sim = np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb))
135
  results.append((sim, row))
136
  conn.close()
@@ -147,33 +138,58 @@ def query_vector_db(user_query, top_k=5):
147
  docs.append(Document(page_content=row[4], metadata=meta))
148
  return docs
149
 
150
- # --- LangChain Retriever (BaseRetriever subclass, Pydantic v2 compliant) ---
151
  class SQLiteVectorRetriever(BaseRetriever):
152
  top_k: int = Field(default=5)
153
-
154
  def _get_relevant_documents(self, query, run_manager=None, **kwargs):
155
  return query_vector_db(query, self.top_k)
156
 
157
- llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0) # FIXED: use ChatOpenAI!
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  retriever = SQLiteVectorRetriever(top_k=5)
159
  qa_chain = RetrievalQA.from_chain_type(
160
  llm=llm,
161
  retriever=retriever,
 
162
  return_source_documents=True,
163
  )
164
 
165
- # --- Chat UI & Conversation Loop (with modal) ---
166
- st.header("Chat with all accumulated records")
 
 
 
 
 
 
 
167
 
168
  def show_json_links_and_modal():
169
- for speaker, msg in reversed(st.session_state.chat_history):
170
- if speaker == "AI_DOCS":
171
- docs = msg
172
- for idx, doc in enumerate(docs):
173
- if st.button(f"View JSON: {doc.metadata['source_file']} (#{doc.metadata['id']})", key=f"modal_{idx}"):
174
- st.session_state.modal_open = True
175
- st.session_state.modal_content = json.dumps(json.loads(doc.metadata["raw_json"]), indent=2)
176
- st.session_state.modal_title = f"{doc.metadata['source_file']} (#{doc.metadata['id']})"
 
 
 
 
 
 
177
  break
178
  if st.session_state.modal_open:
179
  with st.expander(f"JSON Record: {st.session_state.modal_title}", expanded=True):
@@ -181,23 +197,32 @@ def show_json_links_and_modal():
181
  if st.button("Close", key="close_modal"):
182
  st.session_state.modal_open = False
183
 
184
- user_input = st.text_input("Ask a question about ALL data (old and new):", key="user_input")
185
- if st.button("Send") and user_input:
186
- with st.spinner("Thinking..."):
187
- result = qa_chain(user_input)
188
- st.session_state.chat_history.append(("User", user_input))
189
- st.session_state.chat_history.append(("AI", result['result']))
190
- st.session_state.chat_history.append(("AI_DOCS", result['source_documents']))
191
-
192
- for speaker, msg in st.session_state.chat_history:
193
- if speaker == "User":
194
- st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg}</div>", unsafe_allow_html=True)
195
- elif speaker == "AI":
196
- st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {msg}</div>", unsafe_allow_html=True)
197
-
198
  show_json_links_and_modal()
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  if st.button("Clear chat"):
201
- st.session_state.chat_history = []
202
 
203
  st.info(f"Batches ingested so far (this session): {st.session_state.ingested_batches}")
 
11
  from langchain.schema import Document
12
  from langchain_core.retrievers import BaseRetriever
13
  from pydantic import Field
14
+ from langchain_openai import ChatOpenAI
15
+ from langchain.prompts import ChatPromptTemplate
16
 
17
  # --- CONFIG ---
18
  DB_PATH = "json_vector.db"
19
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
20
  EMBEDDING_MODEL = "text-embedding-ada-002"
21
 
22
+ # --- State Initialization ---
23
  if "ingested_batches" not in st.session_state:
24
  st.session_state.ingested_batches = 0
25
+ if "messages" not in st.session_state:
26
+ st.session_state.messages = []
27
  if "modal_open" not in st.session_state:
28
  st.session_state.modal_open = False
29
  if "modal_content" not in st.session_state:
 
31
  if "modal_title" not in st.session_state:
32
  st.session_state.modal_title = ""
33
 
34
+ st.set_page_config(page_title="Chat with Your JSON Vectors", layout="wide")
35
+ st.title("Chat with Your Vectorized JSON Files (LangChain, SQLite, LLM)")
36
 
37
  uploaded_files = st.file_uploader(
38
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
39
  )
40
 
 
41
  def flatten_json_obj(obj, parent_key="", sep="."):
42
  items = {}
43
  if isinstance(obj, dict):
 
52
  items[parent_key] = obj
53
  return items
54
 
 
def get_embedding(text):
    """Return the embedding vector for ``text`` from the OpenAI embeddings API.

    A client is built with ``OPENAI_API_KEY`` on each call and the single
    input string is embedded with ``EMBEDDING_MODEL``.

    Args:
        text: The string to embed.

    Returns:
        The ``embedding`` attribute of the first item in the API response.
    """
    embeddings_api = openai.OpenAI(api_key=OPENAI_API_KEY).embeddings
    reply = embeddings_api.create(model=EMBEDDING_MODEL, input=[text])
    first_item = reply.data[0]
    return first_item.embedding
59
 
 
60
  def ensure_table():
61
  conn = sqlite3.connect(DB_PATH)
62
  cursor = conn.cursor()
 
73
  conn.commit()
74
  conn.close()
75
 
 
76
  def ingest_json_files(files):
77
  ensure_table()
78
  rows = []
 
80
  for file in files:
81
  raw = json.load(file)
82
  source_name = file.name
 
83
  if isinstance(raw, list):
84
  records = raw
85
  elif isinstance(raw, dict):
86
  main_lists = [v for v in raw.values() if isinstance(v, list)]
87
+ records = main_lists[0] if main_lists else [raw]
 
 
 
88
  else:
89
  records = [raw]
90
  for rec in records:
 
97
  df = pd.DataFrame(rows, columns=["batch_time", "source_file", "raw_json", "flat_text"])
98
  st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
99
  df["embedding"] = df["flat_text"].apply(get_embedding)
 
100
  conn = sqlite3.connect(DB_PATH)
101
  cursor = conn.cursor()
102
  for _, row in df.iterrows():
 
# The ingest button is only rendered once at least one file has been selected;
# clicking it pushes the selected files through ingest_json_files.
if uploaded_files:
    if st.button("Ingest batch to database"):
        ingest_json_files(uploaded_files)
115
 
 
116
  def query_vector_db(user_query, top_k=5):
117
  query_emb = get_embedding(user_query)
118
  conn = sqlite3.connect(DB_PATH)
 
121
  results = []
122
  for row in cursor.fetchall():
123
  db_emb = np.frombuffer(row[5], dtype=np.float32)
124
+ if len(db_emb) != len(query_emb): continue
125
  sim = np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb))
126
  results.append((sim, row))
127
  conn.close()
 
138
  docs.append(Document(page_content=row[4], metadata=meta))
139
  return docs
140
 
 
class SQLiteVectorRetriever(BaseRetriever):
    """LangChain retriever that delegates document lookup to ``query_vector_db``."""

    # Number of documents requested from query_vector_db for each query.
    top_k: int = Field(default=5)

    def _get_relevant_documents(self, query, run_manager=None, **kwargs):
        """Return the documents ``query_vector_db`` produces for ``query``.

        ``run_manager`` and any extra keyword arguments are accepted but unused.
        """
        limit = self.top_k
        return query_vector_db(user_query=query, top_k=limit)
145
 
# --- System prompt tuned for direct, context-grounded answers ---
# (This is prompt engineering only; no model fine-tuning is involved.)
system_prompt = (
    "You are a JSON data assistant. Always give a direct, concise answer based only on the context provided. "
    "If you do not see the answer in the context, reply: 'I don’t have that information.' "
    "Never make up information. Never ask for clarification."
)

# The "stuff" chain used by RetrievalQA hands the retrieved documents to the
# prompt through the `context` variable and the user's query through
# `question`. Both placeholders must therefore appear in the template;
# without `{context}` the retrieved records never reach the model.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt + "\n\nContext:\n{context}"),
    ("human", "{question}"),
])

llm = ChatOpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)

retriever = SQLiteVectorRetriever(top_k=5)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    chain_type="stuff",  # explicit: the prompt above is written for this chain type
    chain_type_kwargs={"prompt": prompt},
    return_source_documents=True,
)
167
 
# --- Conversation Area ---
# Message text comes from the user, the LLM and uploaded JSON records, and it is
# embedded in raw HTML below (unsafe_allow_html=True). Escape it first so that
# markup inside a message is displayed as text instead of being rendered.
import html

# HTML wrapper per message role; "{}" receives the escaped message content.
_MESSAGE_TEMPLATES = {
    "user": "<div style='color: #4F8BF9;'><b>User:</b> {}</div>",
    "assistant": "<div style='color: #1C6E4C;'><b>Agent:</b> {}</div>",
    "function": "<details><summary><b>Function Output:</b></summary><pre>{}</pre></details>",
}

st.markdown("### Ask any question about your data, just like ChatGPT.")
for msg in st.session_state.messages:
    template = _MESSAGE_TEMPLATES.get(msg["role"])
    if template is None:
        # Unknown roles are not displayed.
        continue
    safe_content = html.escape(str(msg["content"]))
    st.markdown(template.format(safe_content), unsafe_allow_html=True)
177
 
178
  def show_json_links_and_modal():
179
+ # Look for last function message (top results) and display view buttons
180
+ for msg in reversed(st.session_state.messages):
181
+ if msg.get("role") == "function" and msg.get("content"):
182
+ try:
183
+ docs = json.loads(msg["content"])
184
+ if isinstance(docs, list):
185
+ for idx, doc in enumerate(docs):
186
+ if isinstance(doc, dict) and "record" in doc:
187
+ if st.button(f"View JSON: {doc.get('file', 'unknown')} record #{idx+1}", key=f"modal_function_{idx}"):
188
+ st.session_state.modal_open = True
189
+ st.session_state.modal_content = json.dumps(doc["record"], indent=2)
190
+ st.session_state.modal_title = f"{doc.get('file', 'unknown')} record #{idx+1}"
191
+ except Exception:
192
+ continue
193
  break
194
  if st.session_state.modal_open:
195
  with st.expander(f"JSON Record: {st.session_state.modal_title}", expanded=True):
 
197
  if st.button("Close", key="close_modal"):
198
  st.session_state.modal_open = False
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  show_json_links_and_modal()
201
 
def send_message():
    """Handle a submitted chat message (``on_change`` callback of ``temp_input``).

    Reads the text box value from ``st.session_state.temp_input``; blank input
    is ignored. Otherwise the question is run through ``qa_chain`` and three
    entries are appended to ``st.session_state.messages``: the user message,
    the assistant answer, and a "function" message holding the source records
    as a JSON string. The text box is cleared afterwards.
    """
    user_input = st.session_state.temp_input.strip()
    if not user_input:
        return
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.spinner("Thinking..."):
        # RetrievalQA takes the user's text under its "query" input key; the
        # chain forwards it to the prompt's {question} variable itself.
        result = qa_chain({"query": user_input})
    st.session_state.messages.append({"role": "assistant", "content": result["result"]})
    # Keep the retrieved records as a JSON string alongside the answer.
    doc_list = [
        {
            "file": doc.metadata["source_file"],
            "id": doc.metadata["id"],
            "record": json.loads(doc.metadata["raw_json"]),
        }
        for doc in result["source_documents"]
    ]
    st.session_state.messages.append({"role": "function", "content": json.dumps(doc_list, indent=2)})
    st.session_state.temp_input = ""
222
+
# Chat entry box: a change to its value triggers send_message.
st.text_input("Your message:", on_change=send_message, key="temp_input")

# Replace the stored conversation with an empty list when the button is clicked.
clear_requested = st.button("Clear chat")
if clear_requested:
    st.session_state["messages"] = []

# Show the ingested-batch counter kept in session state.
batch_count = st.session_state["ingested_batches"]
st.info(f"Batches ingested so far (this session): {batch_count}")