DOMMETI committed on
Commit
f31b2d0
·
verified ·
1 Parent(s): c4deb04

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +63 -21
src/streamlit_app.py CHANGED
@@ -6,31 +6,72 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
6
 
7
  st.set_page_config(page_title="RAG Search", page_icon="🔍")
8
 
9
- # --- 1️⃣ Unzip chroma_db.zip if not already extracted ---
10
- ZIP_PATH = os.path.join(os.path.dirname(__file__), "..", "chroma_db.zip")
11
- DB_PATH = os.path.join(os.path.dirname(__file__), "..", "chroma_db")
12
-
13
- if not os.path.exists(DB_PATH):
14
- st.info("📦 Extracting Chroma DB for first-time setup...")
15
- with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
16
- zip_ref.extractall(DB_PATH)
17
- st.success("✅ Database extracted successfully!")
18
-
19
- # --- 2️⃣ Initialize embedding model ---
20
- embeddings = HuggingFaceEmbeddings(
21
- model_name="mixedbread-ai/mxbai-embed-large-v1",
22
- model_kwargs={"device": "cpu"} # Force CPU for Hugging Face Spaces
23
- )
24
-
25
- # --- 3️⃣ Load Chroma database ---
26
- vectordb = Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
27
-
28
- # --- 4️⃣ Query input & results ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  query = st.text_input("Enter your query:", "What is SystemVerilog interface?")
30
 
31
  if st.button("Search"):
32
  st.write("🔎 Searching your local vector database...")
33
  results = vectordb.similarity_search(query, k=3)
 
34
  if results:
35
  for i, doc in enumerate(results):
36
  st.subheader(f"Result {i+1}")
@@ -38,4 +79,5 @@ if st.button("Search"):
38
  st.caption(doc.metadata)
39
  st.markdown("---")
40
  else:
41
- st.warning("⚠️ No results found.")
 
 
6
 
7
  st.set_page_config(page_title="RAG Search", page_icon="🔍")
8
 
9
+ # ---------------------------------------------
10
+ # 1️⃣ Locate your Chroma DB zip
11
+ # ---------------------------------------------
12
+ possible_paths = [
13
+ "chroma_db.zip",
14
+ os.path.join("src", "chroma_db.zip"),
15
+ os.path.join(os.path.dirname(__file__), "chroma_db.zip"),
16
+ os.path.join(os.path.dirname(__file__), "..", "chroma_db.zip"),
17
+ "/app/chroma_db.zip",
18
+ ]
19
+
20
+ ZIP_PATH = None
21
+ for p in possible_paths:
22
+ if os.path.exists(p):
23
+ ZIP_PATH = p
24
+ break
25
+
26
+ if ZIP_PATH is None:
27
+ st.error("❌ Could not find 'chroma_db.zip'. Please ensure it's in your repo root.")
28
+ st.stop()
29
+
30
+ DB_PATH = "chroma_db"
31
+
32
+ # ---------------------------------------------
33
+ # 2️⃣ Extract only once per app session
34
+ # ---------------------------------------------
35
+ if "db_ready" not in st.session_state:
36
+ if not os.path.exists(DB_PATH):
37
+ st.info("📦 Extracting Chroma DB for the first time...")
38
+ with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
39
+ zip_ref.extractall(DB_PATH)
40
+ st.success("✅ Database extracted successfully!")
41
+ else:
42
+ st.info("✅ Chroma DB folder already exists.")
43
+ st.session_state.db_ready = True # Mark as ready
44
+
45
+ # ---------------------------------------------
46
+ # 3️⃣ Load embeddings (CPU safe)
47
+ # ---------------------------------------------
48
+ @st.cache_resource(show_spinner=False)
49
+ def load_embeddings():
50
+ return HuggingFaceEmbeddings(
51
+ model_name="mixedbread-ai/mxbai-embed-large-v1",
52
+ model_kwargs={"device": "cpu"}
53
+ )
54
+
55
+ embeddings = load_embeddings()
56
+
57
+ # ---------------------------------------------
58
+ # 4️⃣ Load Chroma DB (cached)
59
+ # ---------------------------------------------
60
+ @st.cache_resource(show_spinner=False)
61
+ def load_vectordb():
62
+ return Chroma(persist_directory=DB_PATH, embedding_function=embeddings)
63
+
64
+ vectordb = load_vectordb()
65
+
66
+ # ---------------------------------------------
67
+ # 5️⃣ Query + Display Results
68
+ # ---------------------------------------------
69
  query = st.text_input("Enter your query:", "What is SystemVerilog interface?")
70
 
71
  if st.button("Search"):
72
  st.write("🔎 Searching your local vector database...")
73
  results = vectordb.similarity_search(query, k=3)
74
+
75
  if results:
76
  for i, doc in enumerate(results):
77
  st.subheader(f"Result {i+1}")
 
79
  st.caption(doc.metadata)
80
  st.markdown("---")
81
  else:
82
+ st.warning("⚠️ No matching results found.")
83
+