Soha85 committed on
Commit
91cc1cc
·
verified ·
1 Parent(s): c00b3ab

fixing loading files

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +34 -31
src/streamlit_app.py CHANGED
@@ -11,12 +11,20 @@ from sentence_transformers import CrossEncoder
11
  import pickle
12
  import chromadb
13
  from chromadb.utils import embedding_functions
14
-
 
15
  # Global variables
16
- collected_file = "collected_data.txt"
17
- vector_db_file = "vector_db.faiss"
18
- embedding_file = "embeddings.npy"
19
- def bert_encode(texts, batch_size=300, device="cpu"):
 
 
 
 
 
 
 
20
  model.to(device)
21
  all_embeddings = []
22
  with torch.no_grad():
@@ -35,28 +43,20 @@ tab1, tab2, tab3 = st.tabs(["Collect Data", "DB Formation", "Inquiry Vector DB"]
35
 
36
  with tab1:
37
  st.header("Collect Data")
38
-
39
- uploaded_files = st.file_uploader(
40
- "Upload your .txt files",
41
- type=["txt"],
42
- accept_multiple_files=True
43
- )
44
-
45
  if st.button("Collect") and uploaded_files:
46
  all_texts = []
47
-
48
- # Hugging Face-safe path
49
- collected_file_path = os.path.join("data", collected_file)
50
- os.makedirs("data", exist_ok=True)
51
-
52
  for uploaded_file in uploaded_files:
53
  content = uploaded_file.read().decode("utf-8", errors="ignore")
54
  all_texts.append(content)
55
 
56
  with open(collected_file_path, "w", encoding="utf-8") as f:
57
  f.write("\n".join(all_texts))
58
-
59
  st.success(f"Collected {len(uploaded_files)} files successfully!")
 
60
  # Tab 2: DB Formation
61
  with tab2:
62
  st.header("Vector DB Formation")
@@ -66,8 +66,7 @@ with tab2:
66
  index_choice = st.selectbox("Vector DB", ["FAISS","ChromaDB"])
67
  embeddings = None
68
  if st.button("Create DB"):
69
- with open("data/collected_data.txt", "r", encoding="utf-8") as f:
70
- #with open(collected_file, "r", encoding="utf-8") as f:
71
  text_data = f.read()
72
  chunks = [text_data[i:i+chunk_size] for i in range(0, len(text_data), chunk_size-overlap)]
73
 
@@ -81,7 +80,7 @@ with tab2:
81
  model_name = "bert-base-uncased"
82
  tokenizer = AutoTokenizer.from_pretrained(model_name)
83
  model = AutoModel.from_pretrained(model_name)
84
- embeddings = bert_encode(chunks)
85
 
86
  if index_choice == "FAISS":
87
  dim = len(embeddings[0])
@@ -90,8 +89,12 @@ with tab2:
90
  faiss.write_index(index, vector_db_file)
91
  np.save(embedding_file, embeddings)
92
  else: # ChromaDB
93
- client = chromadb.PersistentClient(path="chroma_db")
94
- client.delete_collection("rag_collection")
 
 
 
 
95
  collection = client.get_or_create_collection("rag_collection")
96
  collection.add(
97
  documents=chunks,
@@ -100,11 +103,11 @@ with tab2:
100
  )
101
 
102
 
103
- with open("chunks.pkl", "wb") as f:
104
  pickle.dump(chunks, f)
105
- with open("embedding_choice.txt", "w") as f:
106
  f.write(embedding_choice)
107
- with open("index_choice.txt", "w") as f:
108
  f.write(index_choice)
109
 
110
  st.write(f"Saved embeddings with shape: {embeddings.shape}")
@@ -120,11 +123,11 @@ with tab3:
120
 
121
  if st.button("Search"):
122
  # Load chunks and embedding choice and index choice
123
- with open("chunks.pkl", "rb") as f:
124
  chunks = pickle.load(f)
125
- with open("embedding_choice.txt", "r") as f:
126
  embedding_choice = f.read().strip()
127
- with open("index_choice.txt", "r") as f:
128
  index_choice = f.read().strip()
129
  #display embedding choice and index choice
130
  st.header(f"Using Embedding: {embedding_choice}, Index: {index_choice}")
@@ -140,7 +143,7 @@ with tab3:
140
  model_name = "bert-base-uncased"
141
  tokenizer = AutoTokenizer.from_pretrained(model_name)
142
  model = AutoModel.from_pretrained(model_name)
143
- query_emb = bert_encode([user_query])
144
 
145
  if index_choice == "ChromaDB":
146
  #display similarity score measure used by chromadb and illustrate what number of score means more similar and its range
@@ -149,7 +152,7 @@ with tab3:
149
  "Cosine similarity scores range from -1 to 1, where 1 indicates perfect similarity, 0 indicates no similarity, and -1 indicates " \
150
  "perfect dissimilarity.")
151
 
152
- client = chromadb.PersistentClient(path="chroma_db")
153
  collection = client.get_or_create_collection("rag_collection")
154
  results = collection.query(
155
  query_embeddings=query_emb.tolist(),
 
11
  import pickle
12
  import chromadb
13
  from chromadb.utils import embedding_functions
14
+ BASE_DIR = "/tmp/rag_app"
15
+ os.makedirs(BASE_DIR, exist_ok=True)
16
  # Global variables
17
+ collected_file = f"{BASE_DIR}/collected_data.txt"
18
+ vector_db_file = f"{BASE_DIR}/vector_db.faiss"
19
+ embedding_file = f"{BASE_DIR}/embeddings.npy"
20
+ chunks_file = f"{BASE_DIR}/chunks.pkl"
21
+ emb_choice_file = f"{BASE_DIR}/embedding_choice.txt"
22
+ index_choice_file = f"{BASE_DIR}/index_choice.txt"
23
+ chroma_dir = f"{BASE_DIR}/chroma_db"
24
+
25
+ os.makedirs(chroma_dir, exist_ok=True)
26
+
27
+ def bert_encode(model,tokenizer,texts, batch_size=300, device="cpu"):
28
  model.to(device)
29
  all_embeddings = []
30
  with torch.no_grad():
 
43
 
44
  with tab1:
45
  st.header("Collect Data")
46
+ uploaded_files = st.file_uploader("Upload your .txt files",type=["txt"], accept_multiple_files=True)
47
+ collected_file_path = collected_file
48
+
 
 
 
 
49
  if st.button("Collect") and uploaded_files:
50
  all_texts = []
 
 
 
 
 
51
  for uploaded_file in uploaded_files:
52
  content = uploaded_file.read().decode("utf-8", errors="ignore")
53
  all_texts.append(content)
54
 
55
  with open(collected_file_path, "w", encoding="utf-8") as f:
56
  f.write("\n".join(all_texts))
57
+
58
  st.success(f"Collected {len(uploaded_files)} files successfully!")
59
+
60
  # Tab 2: DB Formation
61
  with tab2:
62
  st.header("Vector DB Formation")
 
66
  index_choice = st.selectbox("Vector DB", ["FAISS","ChromaDB"])
67
  embeddings = None
68
  if st.button("Create DB"):
69
+ with open(collected_file, "r", encoding="utf-8") as f:
 
70
  text_data = f.read()
71
  chunks = [text_data[i:i+chunk_size] for i in range(0, len(text_data), chunk_size-overlap)]
72
 
 
80
  model_name = "bert-base-uncased"
81
  tokenizer = AutoTokenizer.from_pretrained(model_name)
82
  model = AutoModel.from_pretrained(model_name)
83
+ embeddings = bert_encode(model,tokenizer,chunks)
84
 
85
  if index_choice == "FAISS":
86
  dim = len(embeddings[0])
 
89
  faiss.write_index(index, vector_db_file)
90
  np.save(embedding_file, embeddings)
91
  else: # ChromaDB
92
+ # client = chromadb.PersistentClient(path="chroma_db")
93
+ client = chromadb.PersistentClient(path=chroma_dir)
94
+ try:
95
+ client.delete_collection("rag_collection")
96
+ except:
97
+ pass
98
  collection = client.get_or_create_collection("rag_collection")
99
  collection.add(
100
  documents=chunks,
 
103
  )
104
 
105
 
106
+ with open(chunks_file, "wb") as f:
107
  pickle.dump(chunks, f)
108
+ with open(emb_choice_file, "w") as f:
109
  f.write(embedding_choice)
110
+ with open(index_choice_file, "w") as f:
111
  f.write(index_choice)
112
 
113
  st.write(f"Saved embeddings with shape: {embeddings.shape}")
 
123
 
124
  if st.button("Search"):
125
  # Load chunks and embedding choice and index choice
126
+ with open(chunks_file, "rb") as f:
127
  chunks = pickle.load(f)
128
+ with open(emb_choice_file, "r") as f:
129
  embedding_choice = f.read().strip()
130
+ with open(index_choice_file, "r") as f:
131
  index_choice = f.read().strip()
132
  #display embedding choice and index choice
133
  st.header(f"Using Embedding: {embedding_choice}, Index: {index_choice}")
 
143
  model_name = "bert-base-uncased"
144
  tokenizer = AutoTokenizer.from_pretrained(model_name)
145
  model = AutoModel.from_pretrained(model_name)
146
+ query_emb = bert_encode(model,tokenizer,[user_query])
147
 
148
  if index_choice == "ChromaDB":
149
  #display similarity score measure used by chromadb and illustrate what number of score means more similar and its range
 
152
  "Cosine similarity scores range from -1 to 1, where 1 indicates perfect similarity, 0 indicates no similarity, and -1 indicates " \
153
  "perfect dissimilarity.")
154
 
155
+ client = chromadb.PersistentClient(path=chroma_dir)
156
  collection = client.get_or_create_collection("rag_collection")
157
  results = collection.query(
158
  query_embeddings=query_emb.tolist(),