dnzblgn committed
Commit e987c24 · verified · 1 Parent(s): 6a9127c

Update app.py

Files changed (1):
  1. app.py +64 -182
app.py CHANGED
@@ -8,213 +8,109 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
-import torch
-from PIL import Image
-from torchvision import transforms, models
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain_huggingface import HuggingFaceEmbeddings
-from torchvision import transforms, models
 
-class GeometryImageClassifier:
-    def __init__(self):
-        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # Updated syntax
-        self.model.fc = torch.nn.Identity()
-        self.model = self.model.to('cpu')
-        self.model.eval()
-        self.transform = transforms.Compose([
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ])
-
-        self.reference_embeddings = {
-            "flat.png": {
-                "embedding": None,
-                "label": "Flat or Sheet-Based"
-            },
-            "cylindrical.png": {
-                "embedding": None,
-                "label": "Cylindrical"
-            },
-            "complex.png": {
-                "embedding": None,
-                "label": "Complex Multi Axis Geometry"
-            }
-        }
-
-    def compute_embedding(self, image_path):
-        img = Image.open(image_path).convert('RGB')
-        img_tensor = self.transform(img).unsqueeze(0)
-        with torch.no_grad():
-            embedding = self.model(img_tensor)
-        return embedding.squeeze().cpu().numpy()
-
-    def initialize_reference_embeddings(self, reference_folder):
-        for image_name in self.reference_embeddings.keys():
-            image_path = os.path.join(reference_folder, image_name)
-            if os.path.exists(image_path):
-                self.reference_embeddings[image_name]["embedding"] = self.compute_embedding(image_path)
-            else:
-                print(f"Warning: Reference image {image_path} not found")
-
-    def find_closest_geometry(self, query_embedding):
-        best_similarity = -1
-        best_label = None
-
-        for ref_data in self.reference_embeddings.values():
-            if ref_data["embedding"] is not None:
-                similarity = np.dot(query_embedding, ref_data["embedding"])
-                if similarity > best_similarity:
-                    best_similarity = similarity
-                    best_label = ref_data["label"]
-
-        return best_label or "Unknown Geometry"
-
-    def process_image(self, image_path):
-        query_embedding = self.compute_embedding(image_path)
-        return self.find_closest_geometry(query_embedding)
-
-
-# ✅ Use a strong sentence embedding model
+# Initialize semantic model
 semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 def extract_text_from_docx(file_path):
-    """ ✅ Extracts normal text & tables from a .docx file for better retrieval. """
     doc = docx.Document(file_path)
     extracted_text = []
-
+
     for para in doc.paragraphs:
         if para.text.strip():
             extracted_text.append(para.text.strip())
-
+
     for table in doc.tables:
         extracted_text.append("📌 Table Detected:")
         for row in table.rows:
             row_text = [cell.text.strip() for cell in row.cells]
             if any(row_text):
                 extracted_text.append(" | ".join(row_text))
-
+
     return "\n".join(extracted_text)
 
-
 def load_documents():
-    """ ✅ Loads & processes documents, ensuring table data is properly extracted. """
     file_paths = {
         "Fastener_Types_Manual": "Fastener_Types_Manual.docx",
         "Manufacturing_Expert_Manual": "Manufacturing Expert Manual.docx"
     }
-
+
     all_splits = []
-
+
     for doc_name, file_path in file_paths.items():
         if not os.path.exists(file_path):
            raise FileNotFoundError(f"Document not found: {file_path}")
-
+
         print(f"Extracting text from {file_path}...")
         full_text = extract_text_from_docx(file_path)
-
+
         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
         doc_splits = text_splitter.create_documents([full_text])
-
+
         for chunk in doc_splits:
             chunk.metadata = {"source": doc_name}
-
+
         all_splits.extend(doc_splits)
-
+
     return all_splits
 
-
 def create_db(splits):
-    """ ✅ Creates a FAISS vector database from document splits. """
     embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
     vectordb = FAISS.from_documents(splits, embeddings)
-    return vectordb
-
+    return vectordb, embeddings
 
 def retrieve_documents(query, retriever, embeddings):
-    """ ✅ Retrieves the most relevant documents & filters out low-relevance ones. """
     query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
     results = retriever.invoke(query)
-
+
     if not results:
         return []
-
+
     doc_embeddings = np.array([embeddings.embed_query(doc.page_content) for doc in results])
-    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]  # ✅ Proper cosine similarity
-
-    MIN_SIMILARITY = 0.5  # 🔥 Increased threshold to improve relevance
+    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]
+
+    MIN_SIMILARITY = 0.5
     filtered_results = [(doc, sim) for doc, sim in zip(results, similarity_scores) if sim >= MIN_SIMILARITY]
-
-    # ✅ Debugging log
+
     print(f"🔍 Query: {query}")
-    print(f"📄 Retrieved Docs (before filtering): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in zip(results, similarity_scores)]}")
-    print(f"✅ Filtered Docs (after threshold {MIN_SIMILARITY}): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")
-
+    print(f"📄 Retrieved Docs: {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")
+
     return [doc for doc, _ in filtered_results] if filtered_results else []
 
-
 def validate_query_semantically(query, retrieved_docs):
-    """ ✅ Ensures the query meaning is covered in the retrieved documents. """
     if not retrieved_docs:
         return False
-
+
     combined_text = " ".join([doc.page_content for doc in retrieved_docs])
     query_embedding = semantic_model.encode(query, normalize_embeddings=True)
     doc_embedding = semantic_model.encode(combined_text, normalize_embeddings=True)
-
-    similarity_score = np.dot(query_embedding, doc_embedding)  # ✅ Cosine similarity already normalized
-
+
+    similarity_score = np.dot(query_embedding, doc_embedding)
     print(f"🔍 Semantic Similarity Score: {similarity_score}")
+
+    return similarity_score >= 0.3
 
-    return similarity_score >= 0.3  # 🔥 Stricter threshold to ensure correctness
-
-
-def handle_query(query, history, retriever, qa_chain, embeddings):
-    """ ✅ Handles user queries & prevents hallucination. """
-    retrieved_docs = retrieve_documents(query, retriever, embeddings)
-
-    if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
-        return history + [(query, "I couldn't find any relevant information.")], ""
-
-    response = qa_chain.invoke({"question": query, "chat_history": history})
-    assistant_response = response['answer'].strip()
-
-    # ✅ Final hallucination check
-    if not validate_query_semantically(query, retrieved_docs):
-        assistant_response = "I couldn't find any relevant information."
-
-    assistant_response += f"\n\n📄 **Source:** {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
-
-    # ✅ Debugging logs
-    print(f"🤖 LLM Response: {assistant_response[:300]}")  # ✅ Limit output for debugging
-
-    history.append((query, assistant_response))
-    return history, ""
-
-
-def initialize_chatbot(vector_db):
-    """ ✅ Initializes chatbot with improved retrieval & processing. """
+def initialize_chatbot(vector_db, embeddings):
     memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
-
-    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
-
-    retriever = vector_db.as_retriever(search_kwargs={"k": 5, "search_type": "similarity"})
-
-    system_prompt = """You are an AI assistant that answers questions **ONLY based on the provided documents**.
-    - **If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."**
-    - **If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."**
-    - **Do NOT attempt to answer from general knowledge.**
-    """
-
+
+    retriever = vector_db.as_retriever(search_kwargs={"k": 5})
+
+    system_prompt = """You are an AI assistant that answers questions ONLY based on the provided documents.
+    - If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."
+    - If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."
+    - Do NOT attempt to answer from general knowledge."""
+
     llm = HuggingFaceEndpoint(
-        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
+        repo_id="openai-community/gpt2",
         huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
         temperature=0.1,
-        max_new_tokens=400,
+        max_new_tokens=400,
         task="text-generation",
         system_prompt=system_prompt
     )
-
+
     qa_chain = ConversationalRetrievalChain.from_llm(
         llm=llm,
         retriever=retriever,
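
Note: the hunk above replaces the ResNet-based geometry classifier with a purely text-based retrieval path that gates answers twice: each retrieved document must reach cosine similarity 0.5 against the query, and the surviving text as a whole must score at least 0.3 in a second semantic check. Below is a minimal standalone sketch of that two-stage gate, assuming sentence-transformers is installed; for brevity it uses all-MiniLM-L6-v2 for both stages (the commit embeds the first stage with BAAI/bge-base-en-v1.5), and the query and documents are invented.

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

query = "Which screw should I use for sheet metal?"  # hypothetical query
docs = [
    "Self-tapping screws are recommended for thin sheet metal.",  # relevant
    "The cafeteria opens at nine o'clock.",                       # irrelevant
]

# Stage 1: per-document cosine similarity (normalized vectors, so dot product
# equals cosine similarity), filtered at the commit's MIN_SIMILARITY = 0.5.
q = model.encode(query, normalize_embeddings=True)
d = model.encode(docs, normalize_embeddings=True)
scores = d @ q
kept = [doc for doc, s in zip(docs, scores) if s >= 0.5]

# Stage 2: one score for the concatenated survivors, gated at 0.3.
combined = model.encode(" ".join(kept), normalize_embeddings=True)
print(scores, bool(np.dot(q, combined) >= 0.3))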
@@ -222,54 +118,40 @@ def initialize_chatbot(vector_db):
         return_source_documents=True,
         verbose=False
     )
+
+    return retriever, qa_chain
 
-    return retriever, qa_chain, embeddings
-
-
-def process_image_and_generate_query(image):
-    classifier = GeometryImageClassifier()
-    geometry_type = classifier.process_image(image)
+def handle_query(query, history, retriever, qa_chain, embeddings):
+    retrieved_docs = retrieve_documents(query, retriever, embeddings)
+
+    if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
+        return history + [(query, "I couldn't find any relevant information.")], ""
+
+    response = qa_chain.invoke({"question": query, "chat_history": history})
+    assistant_response = response['answer'].strip()
+
+    if not validate_query_semantically(query, retrieved_docs):
+        assistant_response = "I couldn't find any relevant information."
 
-    query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
-    return geometry_type, query
+    assistant_response += f"\n\n📄 Source: {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
+
+    history.append((query, assistant_response))
+    return history, ""
 
 def demo():
-    # Initialize classifier once at startup
-    classifier = GeometryImageClassifier()
-    classifier.initialize_reference_embeddings("images")
-
-    # Initialize chatbot components
-    retriever, qa_chain, embeddings = initialize_chatbot(create_db(load_documents()))
+    documents = load_documents()
+    vector_db, embeddings = create_db(documents)
+    retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
 
     with gr.Blocks() as app:
-        gr.Markdown("### 🤖 **Fastener Agent with Image Recognition** 📚")
+        gr.Markdown("### 🤖 Document Question Answering System")
+
+        chatbot = gr.Chatbot()
+        query_input = gr.Textbox(label="Ask a question about the documents")
+        query_btn = gr.Button("Submit")
 
-        with gr.Row():
-            with gr.Column(scale=1):
-                image_input = gr.Image(type="filepath", label="Upload Geometry Image")
-                geometry_label = gr.Textbox(label="Detected Geometry Type", interactive=False)
-
-            with gr.Column(scale=2):
-                chatbot = gr.Chatbot()
-                query_input = gr.Textbox(label="Ask me a question")
-                query_btn = gr.Button("Ask")
-
-        def image_upload_handler(image):
-            if image is None:
-                return "", ""
-            # Use the initialized classifier
-            geometry_type = classifier.process_image(image)
-            suggested_query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
-            return geometry_type, suggested_query
-
         def user_query_handler(query, history):
             return handle_query(query, history, retriever, qa_chain, embeddings)
-
-        image_input.change(
-            image_upload_handler,
-            inputs=[image_input],
-            outputs=[geometry_label, query_input]
-        )
 
         query_btn.click(
             user_query_handler,
@@ -282,8 +164,8 @@ def demo():
             inputs=[query_input, chatbot],
             outputs=[chatbot, query_input]
         )
-
+
     app.launch()
 
 if __name__ == "__main__":
-    demo()
+    demo()
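
Note: the click wiring kept in the last two hunks follows the standard Gradio Blocks contract: the handler receives the textbox value and the chat history, and returns the updated history plus an empty string so the textbox clears. A self-contained sketch of just that contract, with an invented echo handler (tuple-style history matches gr.Chatbot prior to Gradio 5):

import gradio as gr

def echo_handler(query, history):
    # Append one (user, assistant) turn; the "" clears the input box.
    return history + [(query, f"echo: {query}")], ""

with gr.Blocks() as app:
    chatbot = gr.Chatbot()
    query_input = gr.Textbox(label="Ask a question about the documents")
    query_btn = gr.Button("Submit")
    query_btn.click(echo_handler, inputs=[query_input, chatbot],
                    outputs=[chatbot, query_input])

if __name__ == "__main__":
    app.launch()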
 
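
Note: the splitter settings the commit keeps (chunk_size=1500, chunk_overlap=200, both in characters) can be sanity-checked in isolation. A toy run, assuming langchain is installed; the filler string stands in for the extracted .docx text:

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
docs = splitter.create_documents(["word " * 1000])  # ~5000 characters of filler
# Chunks are at most 1500 characters; adjacent chunks share roughly 200.
print(len(docs), [len(d.page_content) for d in docs])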