dnzblgn committed on
Commit
9f7bf05
Β·
verified Β·
1 Parent(s): ab3973b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +188 -64
app.py CHANGED
@@ -10,107 +10,217 @@ from langchain.chains import ConversationalRetrievalChain
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.llms import HuggingFaceEndpoint
12
  from langchain_huggingface import HuggingFaceEmbeddings
 
 
 
 
 
13
 
14
- # Initialize semantic model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
16
 
 
17
  def extract_text_from_docx(file_path):
 
18
  doc = docx.Document(file_path)
19
  extracted_text = []
20
-
21
  for para in doc.paragraphs:
22
  if para.text.strip():
23
  extracted_text.append(para.text.strip())
24
-
25
  for table in doc.tables:
26
  extracted_text.append("πŸ“Œ Table Detected:")
27
  for row in table.rows:
28
  row_text = [cell.text.strip() for cell in row.cells]
29
  if any(row_text):
30
  extracted_text.append(" | ".join(row_text))
31
-
32
  return "\n".join(extracted_text)
33
 
 
34
  def load_documents():
 
35
  file_paths = {
36
  "Fastener_Types_Manual": "Fastener_Types_Manual.docx",
37
  "Manufacturing_Expert_Manual": "Manufacturing Expert Manual.docx"
38
  }
39
-
40
  all_splits = []
41
-
42
  for doc_name, file_path in file_paths.items():
43
  if not os.path.exists(file_path):
44
  raise FileNotFoundError(f"Document not found: {file_path}")
45
-
46
  print(f"Extracting text from {file_path}...")
47
  full_text = extract_text_from_docx(file_path)
48
-
49
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
50
  doc_splits = text_splitter.create_documents([full_text])
51
-
52
  for chunk in doc_splits:
53
  chunk.metadata = {"source": doc_name}
54
-
55
  all_splits.extend(doc_splits)
56
-
57
  return all_splits
58
 
 
59
  def create_db(splits):
 
60
  embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
61
  vectordb = FAISS.from_documents(splits, embeddings)
62
- return vectordb, embeddings
 
63
 
64
  def retrieve_documents(query, retriever, embeddings):
 
65
  query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
66
  results = retriever.invoke(query)
67
-
68
  if not results:
69
  return []
70
-
71
  doc_embeddings = np.array([embeddings.embed_query(doc.page_content) for doc in results])
72
- similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]
73
-
74
- MIN_SIMILARITY = 0.5
75
  filtered_results = [(doc, sim) for doc, sim in zip(results, similarity_scores) if sim >= MIN_SIMILARITY]
76
-
 
77
  print(f"πŸ” Query: {query}")
78
- print(f"πŸ“„ Retrieved Docs: {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")
79
-
 
80
  return [doc for doc, _ in filtered_results] if filtered_results else []
81
 
 
82
  def validate_query_semantically(query, retrieved_docs):
 
83
  if not retrieved_docs:
84
  return False
85
-
86
  combined_text = " ".join([doc.page_content for doc in retrieved_docs])
87
  query_embedding = semantic_model.encode(query, normalize_embeddings=True)
88
  doc_embedding = semantic_model.encode(combined_text, normalize_embeddings=True)
89
-
90
- similarity_score = np.dot(query_embedding, doc_embedding)
 
91
  print(f"πŸ” Semantic Similarity Score: {similarity_score}")
92
-
93
- return similarity_score >= 0.3
94
 
95
- def initialize_chatbot(vector_db, embeddings):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
97
-
98
- retriever = vector_db.as_retriever(search_kwargs={"k": 5})
99
-
100
- system_prompt = """You are an AI assistant that answers questions ONLY based on the provided documents.
101
- - If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."
102
- - If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."
103
- - Do NOT attempt to answer from general knowledge."""
104
-
 
 
 
105
  llm = HuggingFaceEndpoint(
106
- repo_id="mistralai/Mistral-7B-Instruct-v0.3",
107
  huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
108
  temperature=0.1,
109
- max_new_tokens=400,
110
  task="text-generation",
111
  system_prompt=system_prompt
112
  )
113
-
114
  qa_chain = ConversationalRetrievalChain.from_llm(
115
  llm=llm,
116
  retriever=retriever,
@@ -118,40 +228,54 @@ def initialize_chatbot(vector_db, embeddings):
118
  return_source_documents=True,
119
  verbose=False
120
  )
121
-
122
- return retriever, qa_chain
123
 
124
- def handle_query(query, history, retriever, qa_chain, embeddings):
125
- retrieved_docs = retrieve_documents(query, retriever, embeddings)
126
-
127
- if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
128
- return history + [(query, "I couldn't find any relevant information.")], ""
129
-
130
- response = qa_chain.invoke({"question": query, "chat_history": history})
131
- assistant_response = response['answer'].strip()
132
-
133
- if not validate_query_semantically(query, retrieved_docs):
134
- assistant_response = "I couldn't find any relevant information."
135
-
136
- assistant_response += f"\n\nπŸ“„ Source: {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
137
 
138
- history.append((query, assistant_response))
139
- return history, ""
140
 
141
  def demo():
142
- documents = load_documents()
143
- vector_db, embeddings = create_db(documents)
144
- retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
 
 
 
145
 
146
  with gr.Blocks() as app:
147
- gr.Markdown("### πŸ€– Document Question Answering System")
148
-
149
- chatbot = gr.Chatbot()
150
- query_input = gr.Textbox(label="Ask a question about the documents")
151
- query_btn = gr.Button("Submit")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def user_query_handler(query, history):
154
  return handle_query(query, history, retriever, qa_chain, embeddings)
 
 
 
 
 
 
155
 
156
  query_btn.click(
157
  user_query_handler,
@@ -164,8 +288,8 @@ def demo():
164
  inputs=[query_input, chatbot],
165
  outputs=[chatbot, query_input]
166
  )
167
-
168
  app.launch()
169
 
170
  if __name__ == "__main__":
171
- demo()
 
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.llms import HuggingFaceEndpoint
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
+ import torch
14
+ from PIL import Image
15
+ from torchvision import transforms
16
+ from torchvision.models import resnet50, ResNet50_Weights
17
+ from torchvision import transforms, models
18
 
19
+
20
class GeometryImageClassifier:
    """Classify part geometry (flat / cylindrical / complex) by comparing a
    query image's CNN feature embedding against reference-image embeddings.

    Reference embeddings must be populated via
    ``initialize_reference_embeddings()`` before classification is meaningful.
    """

    def __init__(self):
        # ResNet18 backbone: lighter than ResNet50 and sufficient for coarse
        # geometry matching. `weights=` replaces the deprecated
        # `pretrained=True` API (removed in torchvision >= 0.13).
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Strip the classification head so a forward pass yields the
        # penultimate feature vector instead of ImageNet class logits.
        self.model.fc = torch.nn.Identity()
        self.model = self.model.to('cpu')
        self.model.eval()

        # Standard ImageNet preprocessing — must match the pretrained weights.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # Reference images -> human-readable geometry labels. Embeddings are
        # filled in lazily by initialize_reference_embeddings().
        self.reference_embeddings = {
            "flat.png": {
                "embedding": None,
                "label": "Flat or Sheet-Based"
            },
            "cylindrical.png": {
                "embedding": None,
                "label": "Cylindrical"
            },
            "complex.png": {
                "embedding": None,
                "label": "Complex Multi Axis Geometry"
            }
        }

    def compute_embedding(self, image_path):
        """Return the backbone feature vector (1-D numpy array) for *image_path*."""
        img = Image.open(image_path).convert('RGB')
        img_tensor = self.transform(img).unsqueeze(0)
        with torch.no_grad():
            embedding = self.model(img_tensor)
        return embedding.squeeze().cpu().numpy()

    def initialize_reference_embeddings(self, reference_folder):
        """Precompute embeddings for every reference image in *reference_folder*.

        Missing images are skipped with a warning; their entries keep
        ``embedding=None`` and are ignored during matching.
        """
        for image_name in self.reference_embeddings.keys():
            image_path = os.path.join(reference_folder, image_name)
            if os.path.exists(image_path):
                self.reference_embeddings[image_name]["embedding"] = self.compute_embedding(image_path)
            else:
                print(f"Warning: Reference image {image_path} not found")

    def find_closest_geometry(self, query_embedding):
        """Return the label of the reference most similar to *query_embedding*.

        Fix: uses cosine similarity instead of a raw dot product — an
        unnormalized dot product is biased toward reference embeddings with
        large norms. Returns "Unknown Geometry" when no reference embeddings
        are available (or the query embedding is all zeros).
        """
        query = np.asarray(query_embedding, dtype=float)
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            return "Unknown Geometry"

        best_similarity = -1.0
        best_label = None

        for ref_data in self.reference_embeddings.values():
            if ref_data["embedding"] is None:
                continue
            ref = np.asarray(ref_data["embedding"], dtype=float)
            ref_norm = np.linalg.norm(ref)
            if ref_norm == 0:
                continue
            similarity = float(np.dot(query, ref)) / (query_norm * ref_norm)
            if similarity > best_similarity:
                best_similarity = similarity
                best_label = ref_data["label"]

        return best_label or "Unknown Geometry"

    def process_image(self, image_path):
        """Classify the image at *image_path* and return its geometry label."""
        query_embedding = self.compute_embedding(image_path)
        return self.find_closest_geometry(query_embedding)
81
+
82
+
83
+ # βœ… Use a strong sentence embedding model
84
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
85
 
86
+
87
def extract_text_from_docx(file_path):
    """Collect paragraph and table text from a .docx file.

    Returns one newline-joined string. Table rows are rendered as
    pipe-separated cells beneath a marker line so downstream chunking keeps
    tabular data together with its context.
    """
    document = docx.Document(file_path)

    parts = [p.text.strip() for p in document.paragraphs if p.text.strip()]

    for table in document.tables:
        parts.append("πŸ“Œ Table Detected:")
        for row in table.rows:
            cells = [cell.text.strip() for cell in row.cells]
            if any(cells):
                parts.append(" | ".join(cells))

    return "\n".join(parts)
104
 
105
+
106
def load_documents():
    """Load both manuals, extract their text, and split it into chunks.

    Raises FileNotFoundError when either .docx file is missing. Every chunk
    is tagged with a "source" metadata key naming its originating manual.
    """
    file_paths = {
        "Fastener_Types_Manual": "Fastener_Types_Manual.docx",
        "Manufacturing_Expert_Manual": "Manufacturing Expert Manual.docx"
    }

    # One splitter instance serves every document (stateless between calls).
    splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
    all_splits = []

    for doc_name, file_path in file_paths.items():
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Document not found: {file_path}")

        print(f"Extracting text from {file_path}...")
        chunks = splitter.create_documents([extract_text_from_docx(file_path)])

        for chunk in chunks:
            chunk.metadata = {"source": doc_name}

        all_splits.extend(chunks)

    return all_splits
131
 
132
+
133
def create_db(splits):
    """Build a FAISS vector store over *splits* using BGE sentence embeddings."""
    return FAISS.from_documents(
        splits,
        HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5"),
    )
138
+
139
 
140
def retrieve_documents(query, retriever, embeddings):
    """Fetch candidate chunks for *query* and drop low-relevance ones.

    Candidates come from the vector retriever, then each is re-scored with
    cosine similarity against the query embedding; only scores >= 0.5 pass.
    Returns the surviving documents (possibly an empty list).
    """
    MIN_SIMILARITY = 0.5  # threshold chosen to improve relevance

    query_embedding = np.array(embeddings.embed_query(query)).reshape(1, -1)
    results = retriever.invoke(query)

    if not results:
        return []

    doc_embeddings = np.array(
        [embeddings.embed_query(doc.page_content) for doc in results]
    )
    similarity_scores = cosine_similarity(query_embedding, doc_embeddings)[0]

    filtered_results = []
    for doc, sim in zip(results, similarity_scores):
        if sim >= MIN_SIMILARITY:
            filtered_results.append((doc, sim))

    # Debugging log
    print(f"πŸ” Query: {query}")
    print(f"πŸ“„ Retrieved Docs (before filtering): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in zip(results, similarity_scores)]}")
    print(f"βœ… Filtered Docs (after threshold {MIN_SIMILARITY}): {[(doc.metadata.get('source', 'Unknown'), sim) for doc, sim in filtered_results]}")

    return [doc for doc, _ in filtered_results] if filtered_results else []
160
 
161
+
162
def validate_query_semantically(query, retrieved_docs):
    """Check that the retrieved documents actually cover the query's meaning.

    Encodes the query and the concatenated document text with the sentence
    transformer and accepts only when their similarity is >= 0.3. Returns
    False immediately when nothing was retrieved.
    """
    if not retrieved_docs:
        return False

    combined_text = " ".join(doc.page_content for doc in retrieved_docs)
    query_vec = semantic_model.encode(query, normalize_embeddings=True)
    doc_vec = semantic_model.encode(combined_text, normalize_embeddings=True)

    # Both vectors are L2-normalized, so the dot product IS the cosine similarity.
    similarity_score = np.dot(query_vec, doc_vec)
    print(f"πŸ” Semantic Similarity Score: {similarity_score}")

    return similarity_score >= 0.3
176
+
177
+
178
def handle_query(query, history, retriever, qa_chain, embeddings):
    """Answer a user query from the documents, refusing when nothing relevant exists.

    Returns (updated_history, "") — the empty string clears the input box in
    the Gradio UI. Retrieval + semantic validation acts as the
    anti-hallucination gate before the LLM is ever invoked.
    """
    retrieved_docs = retrieve_documents(query, retriever, embeddings)

    # Refuse up front when retrieval finds nothing semantically relevant.
    # NOTE: the original re-ran validate_query_semantically() with identical
    # arguments after the LLM call; that deterministic check could never
    # differ from this one, so the dead branch has been removed.
    if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
        return history + [(query, "I couldn't find any relevant information.")], ""

    response = qa_chain.invoke({"question": query, "chat_history": history})
    assistant_response = response['answer'].strip()

    sources = ', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))
    assistant_response += f"\n\nπŸ“„ **Source:** {sources}"

    # Debugging log (truncated so long answers don't flood the console)
    print(f"πŸ€– LLM Response: {assistant_response[:300]}")

    history.append((query, assistant_response))
    return history, ""
199
+
200
+
201
+ def initialize_chatbot(vector_db):
202
+ """ βœ… Initializes chatbot with improved retrieval & processing. """
203
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
204
+
205
+ embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
206
+
207
+ retriever = vector_db.as_retriever(search_kwargs={"k": 5, "search_type": "similarity"})
208
+
209
+ system_prompt = """You are an AI assistant that answers questions **ONLY based on the provided documents**.
210
+ - **If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."**
211
+ - **If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."**
212
+ - **Do NOT attempt to answer from general knowledge.**
213
+ """
214
+
215
  llm = HuggingFaceEndpoint(
216
+ repo_id="mistralai/Mistral-7B-Instruct-v0.2",
217
  huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
218
  temperature=0.1,
219
+ max_new_tokens=400,
220
  task="text-generation",
221
  system_prompt=system_prompt
222
  )
223
+
224
  qa_chain = ConversationalRetrievalChain.from_llm(
225
  llm=llm,
226
  retriever=retriever,
 
228
  return_source_documents=True,
229
  verbose=False
230
  )
 
 
231
 
232
+ return retriever, qa_chain, embeddings
233
+
234
+
235
def process_image_and_generate_query(image):
    """Classify *image*'s geometry and build a suggested question for the chatbot.

    Returns (geometry_label, suggested_query).
    """
    classifier = GeometryImageClassifier()
    # Bug fix: without loading the reference embeddings the classifier has
    # nothing to compare against and always returned "Unknown Geometry".
    # "images" matches the reference folder used in demo().
    classifier.initialize_reference_embeddings("images")
    geometry_type = classifier.process_image(image)

    query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
    return geometry_type, query
241
 
242
  def demo():
243
+ # Initialize classifier once at startup
244
+ classifier = GeometryImageClassifier()
245
+ classifier.initialize_reference_embeddings("images")
246
+
247
+ # Initialize chatbot components
248
+ retriever, qa_chain, embeddings = initialize_chatbot(create_db(load_documents()))
249
 
250
  with gr.Blocks() as app:
251
+ gr.Markdown("### πŸ€– **Fastener Agent with Image Recognition** πŸ“š")
 
 
 
 
252
 
253
+ with gr.Row():
254
+ with gr.Column(scale=1):
255
+ image_input = gr.Image(type="filepath", label="Upload Geometry Image")
256
+ geometry_label = gr.Textbox(label="Detected Geometry Type", interactive=False)
257
+
258
+ with gr.Column(scale=2):
259
+ chatbot = gr.Chatbot()
260
+ query_input = gr.Textbox(label="Ask me a question")
261
+ query_btn = gr.Button("Ask")
262
+
263
+ def image_upload_handler(image):
264
+ if image is None:
265
+ return "", ""
266
+ # Use the initialized classifier
267
+ geometry_type = classifier.process_image(image)
268
+ suggested_query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
269
+ return geometry_type, suggested_query
270
+
271
  def user_query_handler(query, history):
272
  return handle_query(query, history, retriever, qa_chain, embeddings)
273
+
274
+ image_input.change(
275
+ image_upload_handler,
276
+ inputs=[image_input],
277
+ outputs=[geometry_label, query_input]
278
+ )
279
 
280
  query_btn.click(
281
  user_query_handler,
 
288
  inputs=[query_input, chatbot],
289
  outputs=[chatbot, query_input]
290
  )
291
+
292
  app.launch()
293
 
294
  if __name__ == "__main__":
295
+ demo()