Pranjal Gupta committed on
Commit c7d967d · 1 Parent(s): d529730

Contextual ChatBot

Files changed (7)
  1. app.py +0 -70
  2. imagequerying.py +50 -0
  3. requirement.txt +30 -0
  4. retrievingQueryResponse.py +152 -0
  5. run.py +166 -0
  6. storeConversation.py +26 -0
  7. storingEmbedding.py +128 -0
app.py DELETED
@@ -1,70 +0,0 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
-
- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()
-
-
- if __name__ == "__main__":
-     demo.launch()
imagequerying.py ADDED
@@ -0,0 +1,50 @@
+ # import cv2
+ # import torch
+ # import ollama
+ # import base64
+ # import os
+ # import time
+ # from sentence_transformers import SentenceTransformer, util
+ # import chromadb
+ # import os
+ # from langchain.schema import Document  # Import the Document class from LangChain
+ # import re
+ # import fitz
+ # from langchain_chroma import Chroma
+ # from chromadb.config import Settings, DEFAULT_DATABASE, DEFAULT_TENANT
+ # from chromadb.utils import embedding_functions
+ # from langchain.text_splitter import RecursiveCharacterTextSplitter
+ # from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
+ # from langchain_huggingface import HuggingFaceEmbeddings
+ # from langchain_core.prompts import PromptTemplate
+ # from langchain_core.output_parsers import StrOutputParser
+ # from langchain_ollama import ChatOllama
+
+
+ # def vision_model(file_path, query):
+ #     """Processes an image and queries the LLaMA vision model."""
+ #     print("<<<<< VISION MODEL STARTED >>>>>")
+
+ #     image = cv2.imread(file_path)
+ #     if image is None:
+ #         return "Error: Failed to load image."
+
+ #     _, buffer = cv2.imencode(".jpg", image)
+ #     image_base64 = base64.b64encode(buffer).decode("utf-8")
+
+ #     prompt = f"""
+ #     Please describe the following image based on the given query.
+ #     If the query is not relevant, respond with:
+ #     "Sorry, I don't have enough information from this specific image."
+
+ #     Query: {query}
+ #     """
+
+ #     try:
+ #         response = ollama.chat(
+ #             model="llama3.2-vision",
+ #             messages=[{"role": "user", "content": prompt, "images": [image_base64]}],
+ #         )
+ #         return response.get("message", {}).get("content", "").strip()
+ #     except Exception as e:
+ #         return f"Error: {str(e)}"
requirement.txt ADDED
@@ -0,0 +1,30 @@
+ # Core LLM / RAG dependencies
+ ollama
+ chromadb
+ langchain
+ langchain-community
+ sentence-transformers
+
+ # For PDF, text & image handling
+ pypdf
+ pdfplumber
+ Pillow
+ pytesseract
+
+ # Web API / Backend
+ flask
+ flask-cors
+ requests
+
+ # Data handling
+ pandas
+ numpy
+
+ # Optional: Streamlit UI
+ streamlit
+
+ # For environment/config management
+ python-dotenv
+
+ # If you use MongoDB for storing docs/chat
+ pymongo
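Setup note: these dependencies are unpinned and would normally be installed with pip install -r requirement.txt (the file is named requirement.txt, not the conventional requirements.txt). Several modules in this commit also import packages not listed here — langchain_chroma, langchain_huggingface, langchain_ollama, fitz (PyMuPDF), and transformers — so those would likely need installing as well.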
retrievingQueryResponse.py ADDED
@@ -0,0 +1,152 @@
+ import chromadb
+ import os
+ from langchain_chroma import Chroma
+ from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT
+ import time
+ import transformers
+ from langchain_community.llms import CTransformers
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_core.prompts import PromptTemplate
+ from transformers import pipeline
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_ollama import ChatOllama
+
+
+ client = chromadb.HttpClient("http://localhost:8000")
+
+
+ def using_ollama_model(retriever, query, results, conversation_history):
+     # Flatten the stored conversation into a plain-text transcript for the prompt.
+     history_text = ""
+     for item in conversation_history:
+         if "question" in item and item["question"]:
+             history_text += f"User: {item['question']}\n"
+         if "answer" in item and item["answer"]:
+             history_text += f"Assistant: {item['answer']}\n"
+
+     print("<<<<<< LLM MODEL STARTED >>>>>>")
+     print(" ========>", history_text)
+     # Ensure the prompt template is well-structured
+     prompt_template = """
+     You are a helpful assistant. Answer the following question using the provided context and previous conversation history.
+     If the context does not contain the answer, only then reply with: "Sorry, I don't have enough information."
+     Conversation History: {history}
+     Context: {results}
+     Question: {query}
+     """
+
+     # Initialize the PromptTemplate
+     template = PromptTemplate(
+         input_variables=["history", "results", "query"], template=prompt_template,
+     )
+
+     # Join the retrieved chunks with newlines to build the prompt context.
+     doc_texts = "\n".join([doc.page_content for doc in results])
+
+     formatted_output = template.format(history=history_text, results=doc_texts, query=query)
+
+     print("<<<<<<<<<<< Formatted Output >>>>>>>>>>>")
+     print(formatted_output)
+     print("type of formatted output is ", type(formatted_output))
+
+     llm = ChatOllama(model="llama3.2", temperature=0.4, num_predict=512)
+
+     rag_chain = template | llm | StrOutputParser()
+
+     answer = rag_chain.invoke({"history": history_text, "results": doc_texts, "query": query})
+
+     return answer
+
+     # # Set up the RAG pipeline
+     # rag_pipeline = RetrievalQAWithSourcesChain.from_chain_type(
+     #     llm=llm, chain_type="stuff", retriever=retriever
+     # )
+     #
+     # try:
+     #     answer = rag_pipeline.invoke(formatted_output)
+     #     return answer
+     # except Exception as e:
+     #     print(f"Error occurred during invocation: {e}")
+     #     return None
+
+
+ def retrievingReponse(docId, query, conversation_history):
+     model_kwargs = {"device": "mps"}
+     encode_kwargs = {"normalize_embeddings": True}
+     embeddings = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/paraphrase-distilroberta-base-v1",
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs,
+     )
+
+     vectorDB = Chroma(
+         collection_name="embeddings",
+         embedding_function=embeddings,
+         persist_directory="MM_CHROMA_DB",
+     )
+
+     # retriever = vectorDB.as_retriever(
+     #     search_type="mmr",
+     #     search_kwargs={
+     #         "k": 6,  # was 5 originally
+     #         "lambda_mult": 1,  # was 0.30 originally
+     #         "filter": {"docId": docId}
+     #     }
+     # )
+     retriever = vectorDB.as_retriever(
+         search_type="similarity",
+         search_kwargs={
+             "k": 4,  # was 5 originally
+             "filter": {"docId": docId}
+         }
+     )
+
+     print("<<<<<<<<<<<<<<<< Retriever >>>>>>>>>>>>>>>>")
+     print("\n")
+
+     results = retriever.invoke(query)
+
+     # De-duplicate retrieved chunks by their raw page content (currently used for
+     # inspection only; the raw Documents are what gets passed to the LLM).
+     unique_results = []
+     seen_texts = set()
+     for result in results:
+         print(result)
+         if result.page_content not in seen_texts:
+             ans = result.page_content.replace("\n", "")  # Clean the content by removing newlines
+             unique_results.append(ans)
+             seen_texts.add(result.page_content)
+
+     os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+     start = time.time()
+     llm_result = using_ollama_model(retriever, query, results, conversation_history)
+     end = time.time()
+     print("Inference Time:>>>>>>> ", end - start)
+     return llm_result
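A minimal usage sketch: it assumes the Chroma store under MM_CHROMA_DB already holds chunks tagged with the given docId and that an Ollama server is serving llama3.2. The docId, question, and history values are placeholders.

from retrievingQueryResponse import retrievingReponse

# Prior turns, in the shape produced by storeConversation.py.
history = [{"question": "What is this document about?", "answer": "It describes a RAG pipeline."}]
answer = retrievingReponse("example-doc-id", "Which embedding model is used?", history)
print(answer)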
+
run.py ADDED
@@ -0,0 +1,166 @@
+ from flask import Flask, request, jsonify
+ from flask_cors import CORS
+ from pymongo import MongoClient
+ import uuid
+ import os
+ from storingEmbedding import process_pdf
+ # from imagequerying import vision_model
+ from retrievingQueryResponse import retrievingReponse
+ from storeConversation import storingConversation
+
+
+ app = Flask(__name__)
+ CORS(app)
+
+ # MongoDB Connection
+ client = MongoClient("mongodb://localhost:27017/")
+ db = client["document_system"]
+ docs_collection = db["documents"]
+ query_collection = db["queryStorage"]
+
+ UPLOAD_FOLDER = "uploads"
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
+ IMAGE_EXTENSIONS = {".png", ".svg", ".jpeg", ".jpg"}
+
+ @app.route("/getDoc", methods=["GET"])
+ def retrieveAllDoc():
+     documents = list(docs_collection.find({}, {"_id": 0}))  # Exclude `_id`
+     return jsonify(documents)
+
+ @app.route("/upload", methods=["POST"])
+ def upload_document():
+     """Upload a document (PDF or image), generate a unique ID, and store metadata."""
+     if 'file' not in request.files:
+         return jsonify({"error": "No file part in the request."}), 400
+
+     file = request.files['file']
+     if file.filename == '':
+         return jsonify({"error": "No file selected."}), 400
+
+     file_ext = os.path.splitext(file.filename)[1].lower()
+     if file_ext not in IMAGE_EXTENSIONS and file_ext != ".pdf":
+         return jsonify({"error": "Unsupported file type."}), 400
+
+     doc_id = str(uuid.uuid4())
+     file_path = os.path.join(UPLOAD_FOLDER, file.filename)
+     file.save(file_path)
+
+     doc_type = "pdf" if file_ext == ".pdf" else "image"
+
+     # Store metadata in MongoDB
+     docs_collection.insert_one({
+         "doc_id": doc_id,
+         "doc_name": file.filename,
+         "doc_type": file_ext,
+         "file_path": file_path,
+         "doc_Category": doc_type
+     })
+
+     if file_ext == ".pdf":
+         process_pdf(doc_id, file_path)
+
+     return jsonify({
+         "message": "Document uploaded successfully.",
+         "doc_id": doc_id,
+         "doc_name": file.filename,
+         "doc_type": file_ext
+     }), 201
+
+ @app.route("/askBot", methods=["POST"])
+ def retrieve_answer():
+     """Retrieve an answer for the given query (text-based or image-based)."""
+     data = request.json
+
+     userId = data.get('userId')
+     userName = data.get('userName')
+     query = data.get('query')
+     docId = data.get('doc_id')
+
+     # Get document details from MongoDB
+     doc_info = docs_collection.find_one({"doc_id": docId})
+     chat_info = query_collection.find_one({"doc_id": docId})
+
+     if not doc_info:
+         return jsonify({"error": "Document ID not found"}), 404
+
+     file_type = doc_info["doc_type"]
+     file_path = doc_info["file_path"]
+     doc_name = doc_info['doc_name']
+     # Guard against the first question on a document, before any chat exists.
+     conversation_history = chat_info['conversation'] if chat_info else []
+
+     if file_type == ".pdf":
+         response = retrievingReponse(docId, query, conversation_history)
+     elif file_type in IMAGE_EXTENSIONS:
+         # Requires re-enabling the commented-out vision_model import above.
+         response = vision_model(file_path, query)
+     else:
+         return jsonify({"error": "Unsupported file type"}), 400
+
+     storingConversation(docId, query, response, doc_name)
+
+     return jsonify({
+         "question": query,
+         "answer": response,
+         "doc_id": docId
+     }), 201
+
+ @app.route("/getChat", methods=["GET"])
+ def get_chats():
+     doc_id = request.args.get("doc_id")
+
+     if doc_id:
+         # Fetch complete chat history for the given doc_id
+         chat_session = query_collection.find_one({"doc_id": doc_id}, {"_id": 0})
+         if not chat_session:
+             return jsonify({"error": "No chat found for this document"}), 404
+         return jsonify(chat_session)
+     else:
+         # Fetch only doc_id, chatHeading, and doc_name for all documents
+         all_chats = list(query_collection.find({}, {"_id": 0, "doc_id": 1, "chatHeading": 1, "doc_name": 1}))
+         return jsonify({"chats": all_chats})
+
+ @app.route("/deleteDoc", methods=["DELETE"])
+ def delete_document():
+     """Delete a document and its associated data."""
+     doc_id = request.args.get("doc_id")
+
+     if not doc_id:
+         return jsonify({"error": "Missing doc_id"}), 400
+
+     doc_info = docs_collection.find_one({"doc_id": doc_id})
+     if not doc_info:
+         return jsonify({"error": "Document not found"}), 404
+
+     # Delete physical file
+     file_path = doc_info.get("file_path")
+     if file_path and os.path.exists(file_path):
+         os.remove(file_path)
+
+     # Delete from MongoDB
+     docs_collection.delete_one({"doc_id": doc_id})
+     query_collection.delete_many({"doc_id": doc_id})  # all chats for that doc
+
+     return jsonify({"message": "Document and related data deleted successfully."}), 200
+
+ @app.route("/viewDoc", methods=["GET"])
+ def view_doc():
+     doc_name = request.args.get("docName")
+     if not doc_name:
+         return jsonify({"error": "Missing doc_name"}), 400
+
+     # Optional: check if file actually exists
+     file_path = os.path.join(UPLOAD_FOLDER, doc_name)
+     if not os.path.isfile(file_path):
+         return jsonify({"error": "File not found"}), 404
+
+     return jsonify({
+         "url": f"/uploads/{doc_name}"
+     })
+
+ if __name__ == "__main__":
+     app.run(debug=True, host='0.0.0.0', port=5001)
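A minimal client sketch against these routes, assuming the server is running on localhost:5001; the file name and question are placeholders (requests is already in requirement.txt).

import requests

BASE = "http://localhost:5001"

# Upload a PDF; the server returns the generated doc_id.
with open("sample.pdf", "rb") as f:
    doc = requests.post(f"{BASE}/upload", files={"file": f}).json()

# Ask a question against that document.
reply = requests.post(f"{BASE}/askBot", json={
    "userId": "u1",
    "userName": "demo",
    "doc_id": doc["doc_id"],
    "query": "Summarize the introduction.",
}).json()
print(reply["answer"])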
storeConversation.py ADDED
@@ -0,0 +1,26 @@
+ from pymongo import MongoClient
+
+
+ client = MongoClient("mongodb://localhost:27017/")  # Update the URI if needed
+ db = client["document_system"]
+ query_collection = db["queryStorage"]
+
+ def storingConversation(doc_id, user_query, model_reply, doc_name):
+     existing_chat = query_collection.find_one({"doc_id": doc_id})
+
+     if not existing_chat:
+         # Create new chat session with the first message as chatHeading
+         chat_session = {
+             "doc_id": doc_id,
+             "doc_name": doc_name,
+             "chatHeading": user_query,  # First question becomes the heading
+             "conversation": []
+         }
+         query_collection.insert_one(chat_session)
+
+     # Update the conversation array in MongoDB
+     query_collection.update_one(
+         {"doc_id": doc_id},
+         {"$push": {"conversation": {"question": user_query, "answer": model_reply}}}
+     )
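For illustration, a hypothetical pair of calls and the queryStorage document they would produce (all values are placeholders):

storingConversation("abc-123", "What is RAG?", "Retrieval-augmented generation ...", "paper.pdf")
storingConversation("abc-123", "Which vector store is used?", "Chroma.", "paper.pdf")

# queryStorage would then hold:
# {
#   "doc_id": "abc-123",
#   "doc_name": "paper.pdf",
#   "chatHeading": "What is RAG?",
#   "conversation": [
#     {"question": "What is RAG?", "answer": "Retrieval-augmented generation ..."},
#     {"question": "Which vector store is used?", "answer": "Chroma."}
#   ]
# }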
storingEmbedding.py ADDED
@@ -0,0 +1,128 @@
+ from sentence_transformers import SentenceTransformer, util
+ import chromadb
+ import os
+ from langchain.schema import Document
+ import re
+ import fitz
+ from langchain_chroma import Chroma
+ # from langchain.utils import embedding_functions
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEmbeddings
+ import shutil
+
+
+ def initialize_chroma_db(collection_name, embeddings, persist_directory):
+     try:
+         print("Trying to load existing Chroma DB...")
+         vectorDB = Chroma(
+             collection_name=collection_name,
+             embedding_function=embeddings,
+             persist_directory=persist_directory,
+         )
+         print("Chroma DB loaded successfully.")
+         return vectorDB
+     except Exception as e:
+         print(f"Error loading Chroma DB: {e}")
+         print("Deleting corrupted persist directory and rebuilding...")
+         if os.path.exists(persist_directory):
+             shutil.rmtree(persist_directory)
+         # Recreate
+         vectorDB = Chroma(
+             collection_name=collection_name,
+             embedding_function=embeddings,
+             persist_directory=persist_directory,
+         )
+         print("New Chroma DB created.")
+         return vectorDB
+
+ # Function to extract text from PDF
+ def extract_text_from_pdf(pdf_file):
+     try:
+         if os.path.exists(pdf_file):
+             doc = fitz.open(pdf_file)
+             text = ""
+             for page in doc:
+                 text += page.get_text("text")
+             return text
+         else:
+             print("No PDF file exists by this name.")
+             return ""
+     except Exception as e:
+         print(e)
+         return ""
+
+ # Function to clean symbols using regex
+ def applying_symbol_regex(text):
+     remove_symbols_text = re.sub(r"""[,._/?''"";{}\-*&^%$#@!,\\|()+=`~<>]""", "", text)
+     return remove_symbols_text
+
+ # Function to collapse whitespace
+ def clean_text(input_text):
+     cleaned_text = re.sub(r"\s+", " ", input_text)
+     cleaned_text = cleaned_text.strip()
+     cleaned_text = cleaned_text.replace("\n", "")
+     return cleaned_text
+
+ # Main processing function
+ def process_pdf(docId, pdf_file_path, collection_name="embeddings", persist_directory="./MM_CHROMA_DB"):
+     print(docId)
+     # Extract text from the PDF
+     pdf_result = extract_text_from_pdf(pdf_file_path)
+
+     # Apply regex to remove symbols
+     regex_result = applying_symbol_regex(pdf_result)
+
+     # Clean text result
+     clean_text_result = clean_text(regex_result)
+     print("Total characters without symbols in the PDF => ", len(clean_text_result))
+
+     document = Document(page_content=clean_text_result)
+
+     # Splitting the document into chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=30)
+     chunks = text_splitter.split_documents([document])
+
+     # Set up the embedding function
+     model_kwargs = {"device": "mps"}
+     encode_kwargs = {"normalize_embeddings": True}
+     embeddings = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/paraphrase-distilroberta-base-v1",
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs,
+     )
+     print("before vectorDB")
+     print("persist_directory exists:", os.path.exists(persist_directory))
+
+     # Set up the Chroma database
+     vectorDB = initialize_chroma_db(collection_name, embeddings, persist_directory)
+     print("after vectorDB")
+
+     # Attach per-chunk metadata; docId lives in the metadata so it can be used
+     # as a retrieval filter later.
+     metadata_chunks = []
+     for i, chunk in enumerate(chunks):
+         metadata = {"source": f"example_source_{i}", "docId": str(docId)}
+         doc_with_metadata = Document(
+             page_content=chunk.page_content, metadata=metadata, id=str(i)
+         )
+         metadata_chunks.append(doc_with_metadata)
+
+     print("Done")
+
+     # Add the documents to the vector database
+     try:
+         vectorDB.add_documents(metadata_chunks)
+     except Exception as e:
+         raise RuntimeError(f"Failed to add documents to the vector database: {e}") from e
+
+     # for i, chunk in enumerate(chunks):
+     #     metadata = {"source": f"example_source_{i}"}
+     #     # Use the same document ID for all chunks
+     #     doc_with_metadata = Document(
+     #         page_content=chunk.page_content, metadata=metadata, id=docId
+     #     )
+     #     print(f"Chunk {i} => {chunk.page_content}")
+     #     print("\n")
+
+     print("Documents have been added to the vector database.")