Seth0330 committed on
Commit
3f6d044
·
verified ·
1 Parent(s): 2d11ab5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -34
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import streamlit as st
3
  import pandas as pd
4
  import openai
5
- import pyodbc
6
  import json
7
  import numpy as np
8
  import datetime
@@ -11,9 +11,9 @@ from langchain.llms import OpenAI
11
  from langchain.schema import Document
12
 
13
  # --- CONFIG ---
14
- AZURE_SQL_CONN_STR = "DRIVER={ODBC Driver 17 for SQL Server};SERVER=sethsrv.database.windows.net;DATABASE=sethdb;UID=seth;PWD=Senth@mil123"
15
- OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") # Or paste your key here
16
- EMBEDDING_MODEL = "text-embedding-ada-002" # Or your Azure embedding model
17
 
18
  # --- Streamlit State Initialization ---
19
  if "ingested_batches" not in st.session_state:
@@ -27,8 +27,8 @@ if "modal_content" not in st.session_state:
27
  if "modal_title" not in st.session_state:
28
  st.session_state.modal_title = ""
29
 
30
- st.set_page_config(page_title="Cumulative JSON Vector Search", layout="wide")
31
- st.title("LLM-Powered Analytics: Cumulative JSON Vector DB (Azure SQL)")
32
 
33
  uploaded_files = st.file_uploader(
34
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
@@ -57,18 +57,17 @@ def get_embedding(text):
57
 
58
  # --- Ensure DB Table (accumulates all uploads, never deletes old data) ---
59
  def ensure_table():
60
- conn = pyodbc.connect(AZURE_SQL_CONN_STR)
61
  cursor = conn.cursor()
62
  cursor.execute("""
63
- IF OBJECT_ID('dbo.json_records', 'U') IS NULL
64
- CREATE TABLE json_records (
65
- id INT PRIMARY KEY IDENTITY,
66
- batch_time DATETIME,
67
- source_file NVARCHAR(255),
68
- raw_json NVARCHAR(MAX),
69
- flat_text NVARCHAR(MAX),
70
- embedding VARBINARY(MAX)
71
- );
72
  """)
73
  conn.commit()
74
  conn.close()
@@ -77,7 +76,7 @@ def ensure_table():
77
  def ingest_json_files(files):
78
  ensure_table()
79
  rows = []
80
- batch_time = datetime.datetime.utcnow()
81
  for file in files:
82
  raw = json.load(file)
83
  source_name = file.name
@@ -85,7 +84,6 @@ def ingest_json_files(files):
85
  if isinstance(raw, list):
86
  records = raw
87
  elif isinstance(raw, dict):
88
- # If nested records (like {"people": [...]})
89
  main_lists = [v for v in raw.values() if isinstance(v, list)]
90
  if main_lists:
91
  records = main_lists[0]
@@ -104,10 +102,10 @@ def ingest_json_files(files):
104
  st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
105
  df["embedding"] = df["flat_text"].apply(get_embedding)
106
  # Insert into DB
107
- conn = pyodbc.connect(AZURE_SQL_CONN_STR)
108
  cursor = conn.cursor()
109
  for _, row in df.iterrows():
110
- emb_bytes = bytearray(np.array(row.embedding, dtype=np.float32).tobytes())
111
  cursor.execute("""
112
  INSERT INTO json_records (batch_time, source_file, raw_json, flat_text, embedding)
113
  VALUES (?, ?, ?, ?, ?)
@@ -123,12 +121,12 @@ if uploaded_files and st.button("Ingest batch to database"):
123
  # --- Query entire cumulative DB (ALL past and present records) ---
124
  def query_vector_db(user_query, top_k=5):
125
  query_emb = get_embedding(user_query)
126
- conn = pyodbc.connect(AZURE_SQL_CONN_STR)
127
  cursor = conn.cursor()
128
  cursor.execute("SELECT id, batch_time, source_file, raw_json, flat_text, embedding FROM json_records")
129
  results = []
130
  for row in cursor.fetchall():
131
- db_emb = np.frombuffer(row.embedding, dtype=np.float32)
132
  if len(db_emb) != len(query_emb): continue # Skip malformed
133
  sim = np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb))
134
  results.append((sim, row))
@@ -137,35 +135,34 @@ def query_vector_db(user_query, top_k=5):
137
  docs = []
138
  for sim, row in results:
139
  meta = {
140
- "id": row.id,
141
- "batch_time": str(row.batch_time),
142
- "source_file": row.source_file,
143
  "similarity": f"{sim:.4f}",
144
- "raw_json": row.raw_json,
145
  }
146
- docs.append(Document(page_content=row.flat_text, metadata=meta))
147
  return docs
148
 
149
  # --- LangChain Retriever ---
150
- class AzureSQLVectorRetriever:
151
  def __init__(self, top_k=5):
152
  self.top_k = top_k
153
  def get_relevant_documents(self, query):
154
  return query_vector_db(query, self.top_k)
155
 
156
- llm = OpenAI(model="gpt-4o", openai_api_key=OPENAI_API_KEY, temperature=0)
157
- retriever = AzureSQLVectorRetriever(top_k=5)
158
  qa_chain = RetrievalQA.from_chain_type(
159
  llm=llm,
160
  retriever=retriever,
161
  return_source_documents=True,
162
  )
163
 
164
- # --- Chat UI & Conversation Loop (preserves your history/modal system) ---
165
  st.header("Chat with all accumulated records")
166
 
167
  def show_json_links_and_modal():
168
- # Scan last result for JSON modal links
169
  for speaker, msg in reversed(st.session_state.chat_history):
170
  if speaker == "AI_DOCS":
171
  docs = msg
@@ -181,7 +178,6 @@ def show_json_links_and_modal():
181
  if st.button("Close", key="close_modal"):
182
  st.session_state.modal_open = False
183
 
184
- # --- Chat input ---
185
  user_input = st.text_input("Ask a question about ALL data (old and new):", key="user_input")
186
  if st.button("Send") and user_input:
187
  with st.spinner("Thinking..."):
@@ -190,7 +186,6 @@ if st.button("Send") and user_input:
190
  st.session_state.chat_history.append(("AI", result['result']))
191
  st.session_state.chat_history.append(("AI_DOCS", result['source_documents']))
192
 
193
- # --- Show conversation ---
194
  for speaker, msg in st.session_state.chat_history:
195
  if speaker == "User":
196
  st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg}</div>", unsafe_allow_html=True)
 
2
  import streamlit as st
3
  import pandas as pd
4
  import openai
5
+ import sqlite3
6
  import json
7
  import numpy as np
8
  import datetime
 
11
  from langchain.schema import Document
12
 
13
  # --- CONFIG ---
14
+ DB_PATH = "json_vector.db"
15
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
16
+ EMBEDDING_MODEL = "text-embedding-ada-002"
17
 
18
  # --- Streamlit State Initialization ---
19
  if "ingested_batches" not in st.session_state:
 
27
  if "modal_title" not in st.session_state:
28
  st.session_state.modal_title = ""
29
 
30
+ st.set_page_config(page_title="Cumulative JSON Vector Search (SQLite)", layout="wide")
31
+ st.title("LLM-Powered Analytics: Cumulative JSON Vector DB (SQLite, Local)")
32
 
33
  uploaded_files = st.file_uploader(
34
  "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
 
57
 
58
  # --- Ensure DB Table (accumulates all uploads, never deletes old data) ---
59
  def ensure_table():
60
+ conn = sqlite3.connect(DB_PATH)
61
  cursor = conn.cursor()
62
  cursor.execute("""
63
+ CREATE TABLE IF NOT EXISTS json_records (
64
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
65
+ batch_time TEXT,
66
+ source_file TEXT,
67
+ raw_json TEXT,
68
+ flat_text TEXT,
69
+ embedding BLOB
70
+ )
 
71
  """)
72
  conn.commit()
73
  conn.close()
 
76
  def ingest_json_files(files):
77
  ensure_table()
78
  rows = []
79
+ batch_time = datetime.datetime.utcnow().isoformat()
80
  for file in files:
81
  raw = json.load(file)
82
  source_name = file.name
 
84
  if isinstance(raw, list):
85
  records = raw
86
  elif isinstance(raw, dict):
 
87
  main_lists = [v for v in raw.values() if isinstance(v, list)]
88
  if main_lists:
89
  records = main_lists[0]
 
102
  st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
103
  df["embedding"] = df["flat_text"].apply(get_embedding)
104
  # Insert into DB
105
+ conn = sqlite3.connect(DB_PATH)
106
  cursor = conn.cursor()
107
  for _, row in df.iterrows():
108
+ emb_bytes = np.array(row.embedding, dtype=np.float32).tobytes()
109
  cursor.execute("""
110
  INSERT INTO json_records (batch_time, source_file, raw_json, flat_text, embedding)
111
  VALUES (?, ?, ?, ?, ?)
 
121
  # --- Query entire cumulative DB (ALL past and present records) ---
122
  def query_vector_db(user_query, top_k=5):
123
  query_emb = get_embedding(user_query)
124
+ conn = sqlite3.connect(DB_PATH)
125
  cursor = conn.cursor()
126
  cursor.execute("SELECT id, batch_time, source_file, raw_json, flat_text, embedding FROM json_records")
127
  results = []
128
  for row in cursor.fetchall():
129
+ db_emb = np.frombuffer(row[5], dtype=np.float32)
130
  if len(db_emb) != len(query_emb): continue # Skip malformed
131
  sim = np.dot(query_emb, db_emb) / (np.linalg.norm(query_emb) * np.linalg.norm(db_emb))
132
  results.append((sim, row))
 
135
  docs = []
136
  for sim, row in results:
137
  meta = {
138
+ "id": row[0],
139
+ "batch_time": str(row[1]),
140
+ "source_file": row[2],
141
  "similarity": f"{sim:.4f}",
142
+ "raw_json": row[3],
143
  }
144
+ docs.append(Document(page_content=row[4], metadata=meta))
145
  return docs
146
 
147
  # --- LangChain Retriever ---
148
+ class SQLiteVectorRetriever:
149
  def __init__(self, top_k=5):
150
  self.top_k = top_k
151
  def get_relevant_documents(self, query):
152
  return query_vector_db(query, self.top_k)
153
 
154
+ llm = OpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
155
+ retriever = SQLiteVectorRetriever(top_k=5)
156
  qa_chain = RetrievalQA.from_chain_type(
157
  llm=llm,
158
  retriever=retriever,
159
  return_source_documents=True,
160
  )
161
 
162
+ # --- Chat UI & Conversation Loop (with modal) ---
163
  st.header("Chat with all accumulated records")
164
 
165
  def show_json_links_and_modal():
 
166
  for speaker, msg in reversed(st.session_state.chat_history):
167
  if speaker == "AI_DOCS":
168
  docs = msg
 
178
  if st.button("Close", key="close_modal"):
179
  st.session_state.modal_open = False
180
 
 
181
  user_input = st.text_input("Ask a question about ALL data (old and new):", key="user_input")
182
  if st.button("Send") and user_input:
183
  with st.spinner("Thinking..."):
 
186
  st.session_state.chat_history.append(("AI", result['result']))
187
  st.session_state.chat_history.append(("AI_DOCS", result['source_documents']))
188
 
 
189
  for speaker, msg in st.session_state.chat_history:
190
  if speaker == "User":
191
  st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {msg}</div>", unsafe_allow_html=True)