Seth0330's picture
Update app.py
8fa2c42 verified
raw
history blame
7.86 kB
import os
import streamlit as st
import pandas as pd
import openai
import sqlite3
import json
import numpy as np
import datetime
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.schema import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
# --- CONFIG ---
DB_PATH = "json_vector.db"
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
EMBEDDING_MODEL = "text-embedding-ada-002"

# --- Streamlit State Initialization ---
# Seed each session-state key exactly once so Streamlit reruns keep values.
_STATE_DEFAULTS = {
    "ingested_batches": 0,
    "chat_history": [],
    "modal_open": False,
    "modal_content": "",
    "modal_title": "",
}
for _key, _default in _STATE_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

st.set_page_config(page_title="Cumulative JSON Vector Search (SQLite)", layout="wide")
st.title("LLM-Powered Analytics: Cumulative JSON Vector DB (SQLite, Local)")

uploaded_files = st.file_uploader(
    "Upload JSON files in batches (any structure)", type="json", accept_multiple_files=True
)
# --- Helper: Flatten any unstructured JSON (handles dict, list, nested, various keys) ---
def flatten_json_obj(obj, parent_key="", sep="."):
    """Recursively flatten nested dicts/lists into a single-level mapping.

    Keys become dotted paths (e.g. "a.b.0.c"); leaf values are kept as-is.
    A bare scalar flattens to {parent_key: scalar}.
    """
    if isinstance(obj, dict):
        flat = {}
        for key, value in obj.items():
            path = f"{parent_key}{sep}{key}" if parent_key else key
            flat.update(flatten_json_obj(value, path, sep=sep))
        return flat
    if isinstance(obj, list):
        flat = {}
        for index, value in enumerate(obj):
            path = f"{parent_key}{sep}{index}" if parent_key else str(index)
            flat.update(flatten_json_obj(value, path, sep=sep))
        return flat
    # Leaf value (str / number / bool / None).
    return {parent_key: obj}
# --- Embedding function (updated for openai>=1.0.0) ---
def get_embedding(text):
    """Return the OpenAI embedding vector (list of floats) for *text*.

    The API client is created lazily on first use and cached on the
    function object, so the (one-per-record) calls made during ingestion
    do not rebuild an HTTP client each time.
    """
    client = getattr(get_embedding, "_client", None)
    if client is None:
        client = openai.OpenAI(api_key=OPENAI_API_KEY)
        get_embedding._client = client
    response = client.embeddings.create(input=[text], model=EMBEDDING_MODEL)
    return response.data[0].embedding
# --- Ensure DB Table (accumulates all uploads, never deletes old data) ---
def ensure_table():
    """Create the json_records table if it does not already exist.

    The table is append-only: ingestion only inserts rows, so data
    accumulates across upload batches. Safe to call repeatedly.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS json_records (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                batch_time TEXT,
                source_file TEXT,
                raw_json TEXT,
                flat_text TEXT,
                embedding BLOB
            )
        """)
        conn.commit()
    finally:
        # Close even if the DDL fails so the connection is never leaked.
        conn.close()
# --- Ingest and accumulate uploaded files ---
def ingest_json_files(files):
    """Flatten, embed, and append the uploaded JSON files to the SQLite store.

    Each file may be a top-level list (every element is a record), a dict
    containing at least one list value (the first such list is used as the
    records), or any other JSON value (treated as a single record).
    Shows Streamlit progress/success/warning messages as side effects.
    """
    ensure_table()
    rows = []
    # NOTE(review): naive UTC timestamp (no offset) — kept as-is for
    # compatibility with rows already stored in batch_time.
    batch_time = datetime.datetime.utcnow().isoformat()
    for file in files:
        raw = json.load(file)
        source_name = file.name
        for rec in _extract_records(raw):
            flat = flatten_json_obj(rec)
            flat_text = "; ".join(f"{k}: {v}" for k, v in flat.items())
            rows.append((batch_time, source_name, json.dumps(rec), flat_text))
    if not rows:
        st.warning("No records found in uploaded files!")
        return
    df = pd.DataFrame(rows, columns=["batch_time", "source_file", "raw_json", "flat_text"])
    st.write(f"Flattened {len(df)} records. Generating embeddings (this may take time, please wait)...")
    df["embedding"] = df["flat_text"].apply(get_embedding)
    # Insert into DB in one batched statement; close the connection even on error.
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.executemany(
            """
            INSERT INTO json_records (batch_time, source_file, raw_json, flat_text, embedding)
            VALUES (?, ?, ?, ?, ?)
            """,
            [
                (row.batch_time, row.source_file, row.raw_json, row.flat_text,
                 np.array(row.embedding, dtype=np.float32).tobytes())
                for _, row in df.iterrows()
            ],
        )
        conn.commit()
    finally:
        conn.close()
    st.success(f"Ingested and indexed {len(df)} new records!")
    st.session_state.ingested_batches += 1


def _extract_records(raw):
    """Return the list of records contained in one parsed JSON payload."""
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        main_lists = [v for v in raw.values() if isinstance(v, list)]
        # Use the first top-level list when present, else the dict itself.
        return main_lists[0] if main_lists else [raw]
    return [raw]
# The ingest button is only rendered once at least one file was uploaded
# (same short-circuit behavior as `uploaded_files and st.button(...)`).
if uploaded_files:
    if st.button("Ingest batch to database"):
        ingest_json_files(uploaded_files)
# --- Query entire cumulative DB (ALL past and present records) ---
def query_vector_db(user_query, top_k=5):
    """Return the top_k most similar stored records as LangChain Documents.

    Brute-force cosine-similarity scan over every stored embedding, so cost
    grows linearly with the number of accumulated records.
    """
    query_emb = get_embedding(user_query)
    # The query norm is loop-invariant — compute it once.
    query_norm = np.linalg.norm(query_emb)
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, batch_time, source_file, raw_json, flat_text, embedding FROM json_records"
        )
        results = []
        for row in cursor.fetchall():
            db_emb = np.frombuffer(row[5], dtype=np.float32)
            if len(db_emb) != len(query_emb):
                continue  # Skip malformed / differently-sized embeddings
            sim = np.dot(query_emb, db_emb) / (query_norm * np.linalg.norm(db_emb))
            results.append((sim, row))
    finally:
        conn.close()
    # Sort on similarity only: plain tuple sorting would also compare the
    # row payloads whenever two scores tie, which is wasteful.
    results.sort(key=lambda pair: pair[0], reverse=True)
    docs = []
    for sim, row in results[:top_k]:
        meta = {
            "id": row[0],
            "batch_time": str(row[1]),
            "source_file": row[2],
            "similarity": f"{sim:.4f}",
            "raw_json": row[3],
        }
        docs.append(Document(page_content=row[4], metadata=meta))
    return docs
# --- LangChain Retriever (BaseRetriever subclass, Pydantic v2 compliant) ---
class SQLiteVectorRetriever(BaseRetriever):
    """Retriever that answers queries via the local SQLite vector scan."""

    # Number of top-scoring documents returned per query.
    top_k: int = Field(default=5)

    def _get_relevant_documents(self, query, run_manager=None, **kwargs):
        # Delegates entirely to the module-level brute-force similarity search.
        return query_vector_db(query, self.top_k)
# NOTE(review): langchain.llms.OpenAI targets the *completions* endpoint,
# while "gpt-4.1" is a chat model — confirm this call succeeds at runtime,
# or switch to a chat-model wrapper.
llm = OpenAI(model="gpt-4.1", openai_api_key=OPENAI_API_KEY, temperature=0)
# Retriever over the cumulative SQLite store; top 5 matches per question.
retriever = SQLiteVectorRetriever(top_k=5)
# RetrievalQA: retrieve top-k records, answer with the LLM; source documents
# are returned so the chat UI can link back to the raw JSON records.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    retriever=retriever,
    return_source_documents=True,
)
# --- Chat UI & Conversation Loop (with modal) ---
st.header("Chat with all accumulated records")


def show_json_links_and_modal():
    """Render "View JSON" buttons for the newest answer's sources and the modal viewer."""
    # Walk the history newest-first; only the most recent AI_DOCS entry
    # gets buttons, then we stop.
    for speaker, payload in reversed(st.session_state.chat_history):
        if speaker != "AI_DOCS":
            continue
        for idx, doc in enumerate(payload):
            label = f"View JSON: {doc.metadata['source_file']} (#{doc.metadata['id']})"
            if st.button(label, key=f"modal_{idx}"):
                st.session_state.modal_open = True
                st.session_state.modal_content = json.dumps(
                    json.loads(doc.metadata["raw_json"]), indent=2
                )
                st.session_state.modal_title = (
                    f"{doc.metadata['source_file']} (#{doc.metadata['id']})"
                )
        break
    if st.session_state.modal_open:
        with st.expander(f"JSON Record: {st.session_state.modal_title}", expanded=True):
            st.code(st.session_state.modal_content, language="json")
            if st.button("Close", key="close_modal"):
                st.session_state.modal_open = False
user_input = st.text_input("Ask a question about ALL data (old and new):", key="user_input")

if st.button("Send") and user_input:
    with st.spinner("Thinking..."):
        result = qa_chain(user_input)
    # Record the turn: the question, the answer, and the source documents
    # (rendered separately as "View JSON" buttons).
    st.session_state.chat_history.append(("User", user_input))
    st.session_state.chat_history.append(("AI", result['result']))
    st.session_state.chat_history.append(("AI_DOCS", result['source_documents']))

for speaker, message in st.session_state.chat_history:
    if speaker == "User":
        st.markdown(f"<div style='color: #4F8BF9;'><b>User:</b> {message}</div>", unsafe_allow_html=True)
    elif speaker == "AI":
        st.markdown(f"<div style='color: #1C6E4C;'><b>Agent:</b> {message}</div>", unsafe_allow_html=True)

show_json_links_and_modal()

if st.button("Clear chat"):
    st.session_state.chat_history = []

st.info(f"Batches ingested so far (this session): {st.session_state.ingested_batches}")