Update app.py
Browse files
app.py
CHANGED
|
@@ -11,9 +11,8 @@ import openai
|
|
| 11 |
from langchain.schema import Document
|
| 12 |
from langchain.chains import RetrievalQA
|
| 13 |
from langchain_community.llms import OpenAI as LangOpenAI
|
| 14 |
-
|
| 15 |
-
# --- FIX: Import correct BaseRetriever
|
| 16 |
from langchain_core.retrievers import BaseRetriever
|
|
|
|
| 17 |
|
| 18 |
# ---- CONFIG ----
|
| 19 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
@@ -96,7 +95,6 @@ def ingest_json_files(files):
|
|
| 96 |
for file in files:
|
| 97 |
raw = json.load(file)
|
| 98 |
source_name = file.name
|
| 99 |
-
# Handle top-level list/dict
|
| 100 |
if isinstance(raw, list):
|
| 101 |
records = raw
|
| 102 |
elif isinstance(raw, dict):
|
|
@@ -109,7 +107,6 @@ def ingest_json_files(files):
|
|
| 109 |
records = [raw]
|
| 110 |
for rec in records:
|
| 111 |
flat = flatten_json_obj(rec)
|
| 112 |
-
# Add entity keys if found
|
| 113 |
if "customer" in rec and isinstance(rec["customer"], str):
|
| 114 |
first_name = rec["customer"].split("@")[0].replace(".", " ")
|
| 115 |
flat["customer_name"] = first_name
|
|
@@ -155,9 +152,7 @@ def query_vector_db(user_query, top_k=5):
|
|
| 155 |
|
| 156 |
# ---- LangChain Retriever Adapter ----
|
| 157 |
class SQLiteVectorRetriever(BaseRetriever):
|
| 158 |
-
|
| 159 |
-
self.top_k = top_k
|
| 160 |
-
super().__init__()
|
| 161 |
|
| 162 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
| 163 |
return query_vector_db(query, self.top_k)
|
|
|
|
| 11 |
from langchain.schema import Document
|
| 12 |
from langchain.chains import RetrievalQA
|
| 13 |
from langchain_community.llms import OpenAI as LangOpenAI
|
|
|
|
|
|
|
| 14 |
from langchain_core.retrievers import BaseRetriever
|
| 15 |
+
from pydantic import Field
|
| 16 |
|
| 17 |
# ---- CONFIG ----
|
| 18 |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
|
|
|
| 95 |
for file in files:
|
| 96 |
raw = json.load(file)
|
| 97 |
source_name = file.name
|
|
|
|
| 98 |
if isinstance(raw, list):
|
| 99 |
records = raw
|
| 100 |
elif isinstance(raw, dict):
|
|
|
|
| 107 |
records = [raw]
|
| 108 |
for rec in records:
|
| 109 |
flat = flatten_json_obj(rec)
|
|
|
|
| 110 |
if "customer" in rec and isinstance(rec["customer"], str):
|
| 111 |
first_name = rec["customer"].split("@")[0].replace(".", " ")
|
| 112 |
flat["customer_name"] = first_name
|
|
|
|
| 152 |
|
| 153 |
# ---- LangChain Retriever Adapter ----
|
| 154 |
class SQLiteVectorRetriever(BaseRetriever):
|
| 155 |
+
top_k: int = Field(default=5)
|
|
|
|
|
|
|
| 156 |
|
| 157 |
def get_relevant_documents(self, query: str) -> List[Document]:
|
| 158 |
return query_vector_db(query, self.top_k)
|