purajith committed on
Commit 2115b69 · verified · 1 Parent(s): 0cf1576

Update hybrid_search.py

Files changed (1)
  1. hybrid_search.py +185 -184
hybrid_search.py CHANGED
@@ -1,184 +1,185 @@
  import os
  from langchain.vectorstores import FAISS
  from langchain.embeddings.openai import OpenAIEmbeddings
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
  from langchain.document_loaders import TextLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.retrievers import BM25Retriever, EnsembleRetriever
  from langchain.schema import Document
  from langchain.chains import ConversationChain
  from langchain.chains.conversation.memory import ConversationBufferWindowMemory
  from langchain.callbacks import get_openai_callback
  from sentence_transformers import CrossEncoder
  from langchain.chat_models import ChatOpenAI
  from sentence_transformers import SentenceTransformer
  from data_extraction import process_files
  from dotenv import load_dotenv
  import warnings
  warnings.filterwarnings("ignore")
  load_dotenv()
  # 🔹 Set OpenAI API key
  all_hybrid_retriever = {}
  file_names = []
  llm_conversations = {}  # {filename: ConversationChain}
  all_result = {}
  al_conversation_sum = {}
  openai_key = os.getenv("openai_key")
  os.environ["OPENAI_API_KEY"] = openai_key  # 'openai_key' must be defined in the environment / .env file

  reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
  def large_model(llm_model):
      llm = ChatOpenAI(openai_api_key=openai_key, model=llm_model)
      return llm

  # 🔹 Choose embedding model
  embedding_option = "open_source"

  if embedding_option == "open_source":
      print("Using open-source MiniLM embeddings")
      embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
  else:
      print("Using OpenAI embeddings")
      embeddings = OpenAIEmbeddings(openai_api_key=openai_key, model="text-embedding-ada-002")

  class ManualMemory:
      def __init__(self, history_length=3):
          self.history = []  # stores chat history
          self.history_length = history_length  # how many interactions to keep

      def add_interaction(self, user_query, llm_response):
          """Add the user's query and the LLM's response to history."""
          # Add the interaction as a tuple (user_query, llm_response)
          self.history.append((user_query, llm_response))
          # Keep only the last 'history_length' interactions
          if len(self.history) > self.history_length:
              self.history.pop(0)

      def get_history(self):
          """Return the current chat history."""
          return "\n".join([f"User: {q}\nLLM: {r}" for q, r in self.history])


  # 🔹 Create a separate LLM + memory pair for each file
  def create_conversation_chain():
      llm = ChatOpenAI(openai_api_key=openai_key, model="gpt-4o-mini")
      memory = ConversationBufferWindowMemory(k=0)  # k=0: the chain itself keeps no history; ManualMemory tracks it
      return ConversationChain(llm=llm, memory=memory)

  def hybrid_retrievers(split_docs):
      # Create vector store and retrievers (uses the module-level embeddings)
      vector_store = FAISS.from_documents(split_docs, embeddings)
      dense_retriever = vector_store.as_retriever(search_kwargs={"k": 5})

      bm25_retriever = BM25Retriever.from_documents(split_docs)
      bm25_retriever.k = 4

      hybrid_retriever = EnsembleRetriever(
          retrievers=[dense_retriever, bm25_retriever],
          weights=[0.5, 0.5])
      return hybrid_retriever

  def rerank_with_cross_encoder(query, documents):
      """Re-rank retrieved documents using a cross-encoder model."""
      input_pairs = [(query, doc.page_content) for doc in documents]
      scores = reranker.predict(input_pairs)
      ranked_results = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
      print("ranked_results", ranked_results)
      return ranked_results


  def count_tokens(chain, query, retriever, memory):
      """Retrieve documents, run the LLM, and count tokens."""
      # Retrieve documents but don't store them in memory
      retrieved_docs = retriever.get_relevant_documents(query)
      reranked_docs = rerank_with_cross_encoder(query, retrieved_docs)
      retrieved_text = "\n\n".join([doc.page_content for doc, _ in reranked_docs])  # extract text

      # Construct the prompt from the chat history and retrieved text
      prompt = f"""You are a cybersecurity expert RAG bot, answering queries using retrieved documents and chat history.
      Retrieved documents: \n{retrieved_text}\n\nQuestion: {query}

      Chat history:
      {memory.get_history()}

      If the documents are relevant, use them to answer.
      If they don't have enough useful information, say:
      "No info."
      Keep your responses clear and accurate."""

      # Generate a response using the LLM and the prompt
      with get_openai_callback() as cb:
          result = chain.run(prompt)  # pass query + retrieved context + chat history as one prompt
          print(f"Spent a total of {cb.total_tokens} tokens")

      # Store the interaction in memory
      memory.add_interaction(query, result)

      return result, reranked_docs



- manual_memory = ManualMemory(history_length=3)
+ # manual_memory = ManualMemory(history_length=3)
  all_manual_memory = {}
  all_retrieved_docs = {}
  all_combined_chunks = {}
  all_hybrid_retriever = {}
  al_conversation_sum = {}

  # Global variables to track the previous file paths and embedding choice
  old_file_paths = []
  old_embeding = None  # initialized before the first call
  def multimodelrag(query, file_paths, embeding, llm_model, conversation=3):
      global old_file_paths, old_embeding
-     global all_manual_memory, all_retrieved_docs, all_combined_chunks, all_hybrid_retriever, al_conversation_sum
+     global all_manual_memory, all_retrieved_docs, all_combined_chunks, all_hybrid_retriever, al_conversation_sum, all_result

      print("query, file_paths, embeding, conversation, llm_model", query, file_paths, embeding, conversation, llm_model)

      if embedding_option == embeding:
          print("Using open-source MiniLM embeddings")
          embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
      else:
          print("Using OpenAI embeddings")
          embeddings = OpenAIEmbeddings(openai_api_key=openai_key, model="text-embedding-ada-002")

      llm = ChatOpenAI(openai_api_key=openai_key, model=llm_model)  # NOTE: local only; see the caveat after the diff

      if (old_file_paths != file_paths) or (old_embeding != embeding):
          # Reset memory only when new files are loaded
          all_manual_memory = {}
          all_retrieved_docs = {}
          all_combined_chunks = {}
          all_hybrid_retriever = {}
          al_conversation_sum = {}
+         all_result = {}

          for file__name in file_paths:
              file = file__name.split("/")[-1]

              print("Processing file:", file)
              old_embeding = embeding
              old_file_paths = file_paths

              combined_chunks = process_files(file__name)

              all_combined_chunks[file] = combined_chunks
              all_hybrid_retriever[file] = hybrid_retrievers(all_combined_chunks[file])
              al_conversation_sum[file] = create_conversation_chain()

-             # ✅ Create a separate memory instance for each file
+             # Create a separate memory instance for each file
              all_manual_memory[file] = ManualMemory(history_length=conversation)

              # Run the query
              all_result[file], all_retrieved_docs[file] = count_tokens(
                  al_conversation_sum[file], query, all_hybrid_retriever[file], all_manual_memory[file]
              )
      else:
          # Reuse existing memory for the same files
          for file__name in file_paths:
              file = file__name.split("/")[-1]
              print("Reusing memory for:", file)

              all_result[file], all_retrieved_docs[file] = count_tokens(
                  al_conversation_sum[file], query, all_hybrid_retriever[file], all_manual_memory[file]
              )

      return all_result
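
For orientation, a minimal driver sketch for multimodelrag follows. The query, file paths, and model name are hypothetical placeholders, and it assumes a .env file defining openai_key plus a data_extraction.process_files that returns LangChain Document chunks, as the module above expects.

# Hypothetical usage sketch -- query, paths, and model name are placeholders.
from hybrid_search import multimodelrag

answers = multimodelrag(
    query="Which mitigations does the report recommend for phishing?",
    file_paths=["docs/threat_report.txt"],   # hypothetical input file
    embeding="open_source",                  # matches embedding_option, so MiniLM embeddings are used
    llm_model="gpt-4o-mini",                 # any chat model name ChatOpenAI accepts
    conversation=3,                          # ManualMemory keeps the last 3 turns per file
)
for fname, answer in answers.items():
    print(fname, "->", answer)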
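
One behavioral caveat: the embeddings and llm built inside multimodelrag are plain locals, so hybrid_retrievers() still indexes with the module-level embeddings and create_conversation_chain() still hard-codes gpt-4o-mini; the embeding and llm_model arguments therefore never reach retrieval or generation. A minimal sketch of one way to thread them through, assuming both signatures are free to change:

# Sketch only: pass the per-call embeddings and LLM explicitly.
def hybrid_retrievers(split_docs, embeddings):
    # Same body as above, but the embedding model is now an explicit argument.
    vector_store = FAISS.from_documents(split_docs, embeddings)
    dense_retriever = vector_store.as_retriever(search_kwargs={"k": 5})
    bm25_retriever = BM25Retriever.from_documents(split_docs)
    bm25_retriever.k = 4
    return EnsembleRetriever(retrievers=[dense_retriever, bm25_retriever],
                             weights=[0.5, 0.5])

def create_conversation_chain(llm):
    # Reuse the caller's LLM instead of hard-coding gpt-4o-mini.
    memory = ConversationBufferWindowMemory(k=0)  # chain keeps no history; ManualMemory does
    return ConversationChain(llm=llm, memory=memory)

# ...and inside multimodelrag:
#     all_hybrid_retriever[file] = hybrid_retrievers(all_combined_chunks[file], embeddings)
#     al_conversation_sum[file] = create_conversation_chain(llm)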