purajith commited on
Commit
f5f1a85
·
verified ·
1 Parent(s): c40d08d

Upload 5 file

Browse files
Files changed (5) hide show
  1. .env +7 -0
  2. data_extraction.py +171 -0
  3. hybrid_search.py +184 -0
  4. requirements.txt +17 -0
  5. stm.py +70 -0
.env ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
FILESYSTEM_CLOUD=s3
FILESYSTEM_DERIVER=s3
AWS_BUCKET=esg-portal-dev
AWS_DEFAULT_REGION=us-east-1
# SECURITY: live AWS and OpenAI credentials were committed in this file and are
# now exposed in git history — they MUST be rotated immediately. Keep .env out
# of version control (.gitignore) and inject real secrets via the environment.
AWS_ACCESS_KEY_ID=<rotated-key-set-locally>
AWS_SECRET_ACCESS_KEY=<rotated-secret-set-locally>
openai_key=<rotated-openai-key-set-locally>
data_extraction.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from docx import Document as DocxDocument # Avoids conflict with langchain's Document
3
+ import csv
4
+ import fitz # PyMuPDF for text extraction
5
+ import camelot # Table extraction
6
+ from langchain.schema import Document # Structured document format
7
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
8
+ import os
9
+ from dotenv import load_dotenv
10
+ load_dotenv()
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+ # Ensure the API key is properly set
14
+ openai_key = os.getenv("openai_key")
15
+ os.environ["OPENAI_API_KEY"] = openai_key # Ensure 'openai_key' is defined
16
+ # Function to read and process .docx files
17
def extract_text_and_tables(docx_path):
    """Extract plain text and tables from a .docx file.

    Returns a tuple ``(text, tables)``: ``text`` is all non-empty paragraph
    text joined by newlines; ``tables`` is a list of langchain ``Document``
    objects, one per table, whose page_content is the stringified row matrix.
    """
    doc = DocxDocument(docx_path)  # renamed import avoids clashing with langchain Document

    # Keep only paragraphs that contain visible text.
    paragraph_text = "\n".join(
        para.text for para in doc.paragraphs if para.text.strip()
    )

    table_docs = []
    for table in doc.tables:
        # Row-major matrix of stripped cell strings, serialized via str().
        matrix = [[cell.text.strip() for cell in row.cells] for row in table.rows]
        table_docs.append(
            Document(page_content=str(matrix), metadata={"source": docx_path})
        )

    return paragraph_text, table_docs
33
+
34
+ # Function to read and process .xlsx (Excel) files
35
def read_excel(file_path):
    """Read every sheet of an Excel workbook into flat text lines.

    Emits one "Sheet: <name>" header per sheet followed by one
    " | "-separated line per data row (column headers are not included).
    Returns the list of lines.
    """
    print(f"Reading Excel file: {file_path}")
    sheets = pd.read_excel(file_path, sheet_name=None)  # dict: sheet name -> DataFrame

    lines = []
    for sheet_name, frame in sheets.items():
        lines.append(f"Sheet: {sheet_name}")
        # str() each cell so mixed dtypes (and NaN) join cleanly.
        lines.extend(" | ".join(str(cell) for cell in row) for row in frame.values)

    return lines
47
+
48
+ # Function to read and process .csv files
49
def read_csv(file_path):
    """Read a CSV file and return one " | "-joined string per row.

    Fixes: the file is now opened with ``newline=""`` as the csv module
    requires (so quoted fields containing newlines parse correctly) and with
    an explicit UTF-8 encoding instead of the platform default.
    """
    print(f"Reading CSV file: {file_path}")

    text = []
    # newline="" lets csv.reader do its own universal-newline handling.
    with open(file_path, mode='r', newline="", encoding="utf-8") as file:
        for row in csv.reader(file):
            text.append(" | ".join(row))

    return text
60
+
61
+ # Function to extract text from PDFs
62
def extract_text(pdf_path):
    """Extracts text from a PDF file and returns it as a list of Document objects.

    One Document per non-blank page, with 1-based page number in metadata.
    Extraction failures are printed and yield an empty (or partial) list.
    """
    documents = []
    try:
        pdf = fitz.open(pdf_path)
        for page_num, page in enumerate(pdf, start=1):
            content = page.get_text("text").strip()
            if not content:
                continue  # skip pages with no extractable text
            documents.append(Document(
                page_content=content,
                metadata={"source": pdf_path, "page": page_num},
            ))
    except Exception as e:
        print(f"❌ Error extracting text: {e}")
    return documents
77
+
78
+ # Function to extract tables from PDFs
79
def extract_tables(pdf_path):
    """Extracts tables from a PDF using Camelot and returns them as Document objects.

    Returns one Document per detected table (1-based table_index). When no
    table is found, or extraction raises, a single placeholder Document is
    returned so downstream chunking always has input.
    """
    table_documents = []
    try:
        tables = camelot.read_pdf(pdf_path, pages="all", flavor="stream")

        if tables.n == 0:
            print(f"⚠️ No tables found in {pdf_path}. Adding dummy data for testing.")
            return [Document(
                page_content="Dummy Table: No real data found",
                metadata={"source": pdf_path, "table_index": 0},
            )]

        for index in range(tables.n):
            table_documents.append(Document(
                page_content=tables[index].df.to_string(),
                metadata={"source": pdf_path, "table_index": index + 1},
            ))
    except Exception as e:
        print(f"❌ Error extracting tables from {pdf_path}: {e}")
        # table_index -1 marks the error placeholder.
        return [Document(
            page_content="Dummy Table: Extraction error",
            metadata={"source": pdf_path, "table_index": -1},
        )]

    return table_documents
101
+
102
+ # Function to chunk tables (for docx and pdf)
103
def chunk_table(documents, chunk_size=2):
    """Chunks table data row-wise from Document objects.

    Splits each Document's page_content on newlines and regroups the lines
    into Documents of at most ``chunk_size`` rows, preserving metadata.
    Non-Document entries are skipped.
    """
    chunks = []
    for doc in documents:
        if not isinstance(doc, Document):
            continue  # ignore anything that is not a langchain Document
        rows = doc.page_content.split("\n")
        chunks.extend(
            Document(
                page_content="\n".join(rows[start:start + chunk_size]),
                metadata=doc.metadata,  # keep the source/table_index of the parent
            )
            for start in range(0, len(rows), chunk_size)
        )
    return chunks
116
+
117
+ # Function to process .docx, .xlsx, .csv, and PDF files
118
def process_files(file, text_chunk_size=1000, chunk_overlap=40, table_chunk_size=2):
    """Extract and chunk one file's content by extension.

    Dispatches on the file suffix (.docx / .xlsx / .xls / .csv / .pdf),
    collects free text and table Documents, then returns the concatenation of
    recursively-split text chunks and row-wise table chunks.
    """
    text_parts = []
    table_docs = []

    # .docx → paragraphs plus table Documents
    if file.endswith(".docx"):
        docx_text, docx_tables = extract_text_and_tables(file)
        text_parts.append(docx_text)
        table_docs.extend(docx_tables)

    # Excel → flat text lines per sheet/row
    if file.endswith((".xlsx", ".xls")):
        text_parts.extend(read_excel(file))

    # CSV → one line per row
    if file.endswith(".csv"):
        text_parts.extend(read_csv(file))

    # PDF → per-page text plus camelot tables
    if file.endswith(".pdf"):
        text_parts.extend(doc.page_content for doc in extract_text(file))
        pdf_tables = extract_tables(file)
        if pdf_tables:
            table_docs.extend(pdf_tables)
        else:
            print(f"⚠️ No tables found in {file}, skipping table embeddings.")

    # Chunk tables only when any were extracted.
    table_chunks = chunk_table(table_docs, chunk_size=table_chunk_size) if table_docs else []

    # Chunk the accumulated text.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=text_chunk_size, chunk_overlap=chunk_overlap
    )
    text_chunks = (
        splitter.split_documents([Document(page_content=t) for t in text_parts])
        if text_parts
        else []
    )

    # Both lists may be empty; the concatenation is then simply [].
    return text_chunks + table_chunks
159
+
160
+ # Function to process multiple files
161
+ # def data_processing(file_paths):
162
+ # all_combined_chunks = {}
163
+ # for file in file_paths:
164
+ # print(f"Processing file: {file.split('/')[-1]}")
165
+ # combined_chunks = process_files(file)
166
+ # all_combined_chunks[file] = combined_chunks
167
+ # return all_combined_chunks
168
+
169
+ # # Example usage
170
+ # file_paths = ["/content/Acceptable Use Policy.docx","/content/RiskAnalysisGuide.pdf"]
171
+ # all_combined_chunks = data_processing(file_paths)
hybrid_search.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain.vectorstores import FAISS
3
+ from langchain.embeddings.openai import OpenAIEmbeddings
4
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
5
+ from langchain.document_loaders import TextLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
8
+ from langchain.schema import Document
9
+ from langchain.chains import ConversationChain
10
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
11
+ from langchain.callbacks import get_openai_callback
12
+ from sentence_transformers import CrossEncoder
13
+ from langchain.chat_models import ChatOpenAI
14
+ from sentence_transformers import SentenceTransformer
15
+ from data_extraction import process_files
16
+ from dotenv import load_dotenv
17
+ import warnings
18
+ warnings.filterwarnings("ignore")
19
+ load_dotenv()
20
+ # 🔹 Set OpenAI API Key
21
+ all_hybrid_retriever = {}
22
+ file_names = []
23
+ llm_conversations = {} # {filename: ConversationChain}
24
+ all_result = {}
25
+ al_conversation_sum = {}
26
+ openai_key = os.getenv("openai_key")
27
+ os.environ["OPENAI_API_KEY"] = openai_key # Ensure 'openai_key' is defined
28
+
29
+ reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
30
def large_model(llm_model):
    """Return a ChatOpenAI client for the requested model name.

    Bug fix: the model was previously the literal string "llm_model" instead
    of the ``llm_model`` argument, so the caller's model choice was ignored
    (and the literal is not a valid OpenAI model id).
    """
    llm = ChatOpenAI(openai_api_key=openai_key, model=llm_model)
    return llm
33
+
34
# 🔹 Choose Embedding Model
# Module-level default; multimodelrag() compares its `embeding` argument
# against this value to pick the embedding backend.
embedding_option = "open_source"

if embedding_option == "open_source":
    # NOTE(review): the message says "BGE-M3" but the model actually loaded is
    # sentence-transformers/all-MiniLM-L6-v2 — confirm which one is intended.
    print("Using BGE-M3 Embeddings")
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
else:
    print("Using OpenAI Embeddings")
    embeddings = OpenAIEmbeddings(openai_api_key=openai_key, model="text-embedding-ada-002")
43
+
44
class ManualMemory:
    """A fixed-size sliding window of (user query, LLM response) pairs."""

    def __init__(self, history_length=3):
        # Oldest interactions are evicted once the window exceeds history_length.
        self.history = []
        self.history_length = history_length

    def add_interaction(self, user_query, llm_response):
        """Record one exchange, dropping the oldest entry when over capacity."""
        self.history.append((user_query, llm_response))
        while len(self.history) > self.history_length:
            del self.history[0]

    def get_history(self):
        """Render the window as alternating 'User:' / 'LLM:' lines."""
        rendered = [f"User: {q}\nLLM: {r}" for q, r in self.history]
        return "\n".join(rendered)
60
+
61
+
62
+ # 🔹 Function to Create Separate LLM + Memory for Each File
63
def create_conversation_chain():
    """Create an isolated LLM + windowed-memory ConversationChain for one file.

    Bug fix: the window was ``k=0`` (no turns retained) while the comment
    claimed the last 3 interactions were stored; ``k=3`` makes the memory
    match the stated intent and the ManualMemory default used elsewhere.
    """
    llm = ChatOpenAI(openai_api_key=openai_key, model="gpt-4o-mini")
    memory = ConversationBufferWindowMemory(k=3)  # keep the last 3 interactions per file
    return ConversationChain(llm=llm, memory=memory)
67
+
68
def hybrid_retrievers(split_docs):
    """Build a 50/50 dense (FAISS) + sparse (BM25) ensemble retriever.

    Uses the module-level ``embeddings`` for the dense index; the dense side
    returns up to 5 hits, the BM25 side up to 4.
    """
    faiss_store = FAISS.from_documents(split_docs, embeddings)
    dense = faiss_store.as_retriever(search_kwargs={"k": 5})

    sparse = BM25Retriever.from_documents(split_docs)
    sparse.k = 4

    return EnsembleRetriever(
        retrievers=[dense, sparse],
        weights=[0.5, 0.5],
    )
80
+
81
def rerank_with_cross_encoder(query, documents):
    """Re-rank retrieved documents using a cross-encoder model.

    Scores each (query, page_content) pair with the module-level ``reranker``
    and returns (document, score) tuples sorted best-first.
    """
    pairs = [(query, doc.page_content) for doc in documents]
    scores = reranker.predict(pairs)
    ranked_results = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    print("ranked_results", ranked_results)  # debug output retained from original
    return ranked_results
88
+
89
+
90
def count_tokens(chain, query, retriever, memory):
    """Retrieve documents, run LLM, and count tokens.

    Pipeline: hybrid retrieval → cross-encoder re-rank → prompt assembled
    from the ranked documents plus ``memory``'s chat history → LLM call
    inside an OpenAI callback (token usage printed) → interaction recorded.

    Returns:
        (result, reranked_docs): the LLM answer string and the re-ranked
        (document, score) tuples used to build the prompt.
    """
    # Retrieve documents but don't store them in memory
    retrieved_docs = retriever.get_relevant_documents(query)
    reranked_docs = rerank_with_cross_encoder(query, retrieved_docs)
    retrieved_text = "\n\n".join([doc.page_content for doc, _ in reranked_docs])  # Extract text

    # Construct the prompt using the chat history and retrieved text.
    # NOTE(review): all ranked documents are injected, not a top-k slice —
    # long files may overflow the model's context window; confirm intent.
    prompt = f"""You are a cybersecurity expert RAG bot, answering queries using retrieved documents and Chat history.
Retrieved documents: \n{retrieved_text}\n\nQuestion: {query}

Chat history:
{memory.get_history()}

If the documents are relevant, use them to answer.
If they don’t have enough useful information, say:
"No info."
Keep your responses clear and accurate."""

    # Generate response using the LLM and the prompt
    with get_openai_callback() as cb:
        result = chain.run(prompt)  # Pass query + retrieved context + chat history as prompt
        print(f"Spent a total of {cb.total_tokens} tokens")

    # Store the interaction in memory (the chain's own window memory also sees it)
    memory.add_interaction(query, result)

    return result, reranked_docs
118
+
119
+
120
+
121
# Per-file pipeline state, all keyed by file basename.
manual_memory = ManualMemory(history_length=3)  # NOTE(review): appears unused — per-file memories are created in multimodelrag
all_manual_memory = {}      # {filename: ManualMemory}
all_retrieved_docs = {}     # {filename: re-ranked (Document, score) tuples}
all_combined_chunks = {}    # {filename: chunked Documents from process_files}
all_hybrid_retriever = {}   # {filename: EnsembleRetriever}
al_conversation_sum = {}    # {filename: ConversationChain}

# Global variables to track previous file paths and embeddings, so state is
# only rebuilt when the uploaded file set or embedding choice changes.
old_file_paths = []
old_embeding = None  # Initialize properly
131
def multimodelrag(query, file_paths, embeding, llm_model, conversation=3):
    """Answer ``query`` against each uploaded file via the hybrid RAG pipeline.

    Args:
        query: the user question.
        file_paths: local paths of the uploaded files.
        embeding: embedding choice compared against module-level ``embedding_option``.
        llm_model: model name. NOTE(review): the ``llm`` built below is never
            used — answers come from create_conversation_chain(), which
            hardcodes "gpt-4o-mini"; confirm which model should respond.
        conversation: ManualMemory window size per file.

    Returns:
        The module-level ``all_result`` dict mapping file basename → answer.
    """
    global old_file_paths, old_embeding
    global all_manual_memory, all_retrieved_docs, all_combined_chunks, all_hybrid_retriever, al_conversation_sum

    print("query, file_paths, embeding, conversation, llm_model", query, file_paths, embeding, conversation, llm_model)

    # NOTE(review): this local ``embeddings`` shadows the module-level variable
    # and is never passed to hybrid_retrievers(), which reads the module-level
    # one — so this selection currently has no effect on the index built below.
    if embedding_option == embeding:
        print("Using BGE-M3 Embeddings")
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    else:
        print("Using OpenAI Embeddings")
        embeddings = OpenAIEmbeddings(openai_api_key=openai_key, model="text-embedding-ada-002")

    # NOTE(review): unused (see docstring).
    llm = ChatOpenAI(openai_api_key=openai_key, model=llm_model)

    # Rebuild all per-file state when the file set or embedding choice changed.
    if (old_file_paths != file_paths) or (old_embeding != embeding):
        # Reset memory only when new files are loaded
        all_manual_memory = {}
        all_retrieved_docs = {}
        all_combined_chunks = {}
        all_hybrid_retriever = {}
        al_conversation_sum = {}

        for file__name in file_paths:
            file = file__name.split("/")[-1]  # basename used as the per-file key

            print("Processing file:", file)
            old_embeding = embeding
            old_file_paths = file_paths

            combined_chunks = process_files(file__name)

            all_combined_chunks[file] = combined_chunks
            all_hybrid_retriever[file] = hybrid_retrievers(all_combined_chunks[file])
            al_conversation_sum[file] = create_conversation_chain()

            # ✅ Create a separate memory instance for each file
            all_manual_memory[file] = ManualMemory(history_length=conversation)

            # Using query
            all_result[file], all_retrieved_docs[file] = count_tokens(
                al_conversation_sum[file], query, all_hybrid_retriever[file], all_manual_memory[file]
            )
    else:
        # Reuse existing memory for the same file
        for file__name in file_paths:
            file = file__name.split("/")[-1]
            print("Reusing memory for:", file)

            all_result[file], all_retrieved_docs[file] = count_tokens(
                al_conversation_sum[file], query, all_hybrid_retriever[file], all_manual_memory[file]
            )

    return all_result
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
python-docx
PyMuPDF          # provides the `fitz` module — do NOT also install the unrelated PyPI `fitz` package (it breaks PyMuPDF imports)
pandas           # used by data_extraction.read_excel (was missing)
python-dotenv    # used by load_dotenv() (was missing)
langchain
langchain_openai
langchain-community
openai==0.28     # single pinned entry; was previously listed twice (pinned AND unpinned)
faiss-cpu
tiktoken
tools
rank_bm25
sentence-transformers
camelot-py
streamlit
frontend         # listed once; was duplicated
stm.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import os
import shutil
from hybrid_search import multimodelrag
import warnings

warnings.filterwarnings("ignore")

# --- Streamlit UI -----------------------------------------------------------
st.set_page_config(layout="wide")
st.title("AI Document Processor with Conversational RAG")

# Initialize conversation history in session state (survives reruns).
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []

# Sidebar: file upload and settings.
with st.sidebar:
    uploaded_files = st.file_uploader(
        "Upload multiple files (PDF, DOCX, Excel, CSV)",
        type=["pdf", "docx", "xlsx", "csv"],
        accept_multiple_files=True
    )

    embeding = st.selectbox("Select Memory Mode", ["open_source", "openai"], index=0)
    conversation = st.selectbox("Number of conversation", [2, 4, 6], index=0)
    # Bug fix: the options were "GPT-4o"/"GPT-4o-mini", which are not valid
    # OpenAI model ids (they are lowercase) and are passed straight through to
    # ChatOpenAI(model=...), so every request would fail.
    llm_option = st.selectbox("Select LLM Model", ["gpt-4o", "gpt-4o-mini"], index=1)

    temp_dir = "temp_uploaded_files"

    # Bug fix: file_paths was only bound inside `if uploaded_files:`, so
    # pressing the button without an upload raised NameError below.
    file_paths = []

    # Clear the previous uploads when new files are uploaded.
    if uploaded_files:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)  # drop the old directory and its contents
        os.makedirs(temp_dir)        # fresh directory for this upload batch

        for file in uploaded_files:
            file_path = os.path.join(temp_dir, file.name)
            with open(file_path, "wb") as f:
                f.write(file.read())  # persist the upload locally
            file_paths.append(file_path)
            st.write(f"✅ Saved: {file.name}")

# Chat interface.
st.write("### Chat Interface")

user_input = st.text_input("Ask a question:")
llm_model = llm_option

if st.button("Retrieve and Answer"):
    # Bug fix: the original `if user_input or uploaded_files:` allowed the
    # no-upload path through to multimodelrag with an undefined file list.
    if user_input and file_paths:
        answer = multimodelrag(user_input, file_paths, embeding, llm_model, conversation)

        # Update conversation history.
        st.session_state.conversation_history.append(f"User: {user_input}")
        st.session_state.conversation_history.append(f"AI: {answer}")

        # Refresh chat display.
        chat_display = "\n".join(st.session_state.conversation_history)
        st.text_area("Conversation History", chat_display, height=400, disabled=True)

        st.write("### Answer:")
        st.write(answer)
    else:
        st.warning("Please upload at least one file and enter a question.")