Seth0330 commited on
Commit
e604e69
·
verified ·
1 Parent(s): 7bbdd37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -7
app.py CHANGED
@@ -11,9 +11,8 @@ import openai
11
  from langchain.schema import Document
12
  from langchain.chains import RetrievalQA
13
  from langchain_community.llms import OpenAI as LangOpenAI
14
-
15
- # --- FIX: Import correct BaseRetriever
16
  from langchain_core.retrievers import BaseRetriever
 
17
 
18
  # ---- CONFIG ----
19
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
@@ -96,7 +95,6 @@ def ingest_json_files(files):
96
  for file in files:
97
  raw = json.load(file)
98
  source_name = file.name
99
- # Handle top-level list/dict
100
  if isinstance(raw, list):
101
  records = raw
102
  elif isinstance(raw, dict):
@@ -109,7 +107,6 @@ def ingest_json_files(files):
109
  records = [raw]
110
  for rec in records:
111
  flat = flatten_json_obj(rec)
112
- # Add entity keys if found
113
  if "customer" in rec and isinstance(rec["customer"], str):
114
  first_name = rec["customer"].split("@")[0].replace(".", " ")
115
  flat["customer_name"] = first_name
@@ -155,9 +152,7 @@ def query_vector_db(user_query, top_k=5):
155
 
156
  # ---- LangChain Retriever Adapter ----
157
  class SQLiteVectorRetriever(BaseRetriever):
158
- def __init__(self, top_k=5):
159
- self.top_k = top_k
160
- super().__init__()
161
 
162
  def get_relevant_documents(self, query: str) -> List[Document]:
163
  return query_vector_db(query, self.top_k)
 
11
  from langchain.schema import Document
12
  from langchain.chains import RetrievalQA
13
  from langchain_community.llms import OpenAI as LangOpenAI
 
 
14
  from langchain_core.retrievers import BaseRetriever
15
+ from pydantic import Field
16
 
17
  # ---- CONFIG ----
18
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 
95
  for file in files:
96
  raw = json.load(file)
97
  source_name = file.name
 
98
  if isinstance(raw, list):
99
  records = raw
100
  elif isinstance(raw, dict):
 
107
  records = [raw]
108
  for rec in records:
109
  flat = flatten_json_obj(rec)
 
110
  if "customer" in rec and isinstance(rec["customer"], str):
111
  first_name = rec["customer"].split("@")[0].replace(".", " ")
112
  flat["customer_name"] = first_name
 
152
 
153
  # ---- LangChain Retriever Adapter ----
154
  class SQLiteVectorRetriever(BaseRetriever):
155
+ top_k: int = Field(default=5)
 
 
156
 
157
  def get_relevant_documents(self, query: str) -> List[Document]:
158
  return query_vector_db(query, self.top_k)