dnzblgn commited on
Commit
904ea81
Β·
verified Β·
1 Parent(s): 5657823

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -98
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import os
3
- import docx
4
- import numpy as np
5
  import cv2
 
6
  from pathlib import Path
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
@@ -13,50 +13,53 @@ from langchain.memory import ConversationBufferMemory
13
  from langchain_community.llms import HuggingFaceEndpoint
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
 
 
 
 
 
16
  class SimpleGeometryClassifier:
17
  def __init__(self):
18
  self.reference_embeddings = {
19
- "flat.png": {
20
- "embedding": None,
21
- "label": "Flat or Sheet-Based"
22
- },
23
- "cylindrical.png": {
24
- "embedding": None,
25
- "label": "Cylindrical"
26
- },
27
- "complex.png": {
28
- "embedding": None,
29
- "label": "Complex Multi Axis Geometry"
30
- }
31
  }
32
 
33
  def compute_embedding(self, image_path):
34
- img = cv2.imread(image_path)
35
- img = cv2.resize(img, (224, 224))
36
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
37
-
38
- win_size = (224, 224)
39
  cell_size = (8, 8)
40
- block_size = (16, 16)
41
- block_stride = (8, 8)
42
  num_bins = 9
43
-
44
  hog = cv2.HOGDescriptor(win_size, block_size, block_stride, cell_size, num_bins)
45
  embedding = hog.compute(img)
46
-
 
 
 
47
  return embedding.flatten()
48
 
49
  def initialize_reference_embeddings(self, reference_folder):
 
50
  for image_name in self.reference_embeddings.keys():
51
  image_path = str(Path(reference_folder) / image_name)
52
  if Path(image_path).exists():
53
  self.reference_embeddings[image_name]["embedding"] = self.compute_embedding(image_path)
54
  else:
55
- print(f"Warning: Reference image {image_path} not found")
 
 
 
 
 
56
 
57
  def find_closest_geometry(self, query_embedding):
58
  best_similarity = -1
59
- best_label = None
60
 
61
  for ref_data in self.reference_embeddings.values():
62
  if ref_data["embedding"] is not None:
@@ -67,14 +70,16 @@ class SimpleGeometryClassifier:
67
  best_similarity = similarity
68
  best_label = ref_data["label"]
69
 
70
- return best_label or "Unknown Geometry"
71
 
72
- def process_image(self, image_path):
73
- query_embedding = self.compute_embedding(image_path)
74
- return self.find_closest_geometry(query_embedding)
75
 
76
- # Initialize semantic model
77
- semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
 
78
 
79
  def extract_text_from_docx(file_path):
80
  doc = docx.Document(file_path)
@@ -154,16 +159,23 @@ def validate_query_semantically(query, retrieved_docs):
154
 
155
  return similarity_score >= 0.3
156
 
157
- def initialize_chatbot(vector_db, embeddings):
 
158
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
159
-
 
 
 
 
160
  retriever = vector_db.as_retriever(search_kwargs={"k": 5})
161
-
162
  system_prompt = """You are an AI assistant that answers questions ONLY based on the provided documents.
163
  - If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."
164
- - If the meaning of the query does not match the retrieved documents, say "I couldn't find any relevant information."
165
- - Do NOT attempt to answer from general knowledge."""
166
-
 
 
167
  llm = HuggingFaceEndpoint(
168
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
169
  huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
@@ -172,93 +184,57 @@ def initialize_chatbot(vector_db, embeddings):
172
  task="text-generation",
173
  system_prompt=system_prompt
174
  )
175
-
176
  qa_chain = ConversationalRetrievalChain.from_llm(
177
- llm=llm,
178
- retriever=retriever,
179
- memory=memory,
180
- return_source_documents=True,
181
- verbose=False
182
  )
183
-
184
- return retriever, qa_chain
185
 
186
- def handle_query(query, history, retriever, qa_chain, embeddings):
 
 
 
 
 
 
 
 
 
 
 
187
  retrieved_docs = retrieve_documents(query, retriever, embeddings)
188
 
189
- if not retrieved_docs or not validate_query_semantically(query, retrieved_docs):
190
  return history + [(query, "I couldn't find any relevant information.")], ""
191
 
192
  response = qa_chain.invoke({"question": query, "chat_history": history})
193
  assistant_response = response['answer'].strip()
194
 
195
- if not validate_query_semantically(query, retrieved_docs):
196
- assistant_response = "I couldn't find any relevant information."
197
-
198
  assistant_response += f"\n\nπŸ“„ Source: {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
199
 
200
  history.append((query, assistant_response))
201
  return history, ""
202
 
203
- def process_image_and_generate_query(image):
204
- classifier = SimpleGeometryClassifier()
205
- classifier.initialize_reference_embeddings("images")
206
- geometry_type = classifier.process_image(image)
207
- query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
208
- return geometry_type, query
209
 
210
  def demo():
211
- # Initialize classifier
212
- classifier = SimpleGeometryClassifier()
213
- classifier.initialize_reference_embeddings("images")
214
-
215
- # Initialize chatbot components
216
- documents = load_documents()
217
- vector_db, embeddings = create_db(documents)
218
- retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
219
-
220
  with gr.Blocks() as app:
221
- gr.Markdown("### πŸ€– Fastener Agent with Image Recognition πŸ“š")
222
-
223
  with gr.Row():
224
  with gr.Column(scale=1):
225
- image_input = gr.Image(type="filepath", label="Upload Geometry Image")
226
- geometry_label = gr.Textbox(label="Detected Geometry Type", interactive=False)
227
-
228
  with gr.Column(scale=2):
229
  chatbot = gr.Chatbot()
230
- query_input = gr.Textbox(label="Ask a question about the documents")
231
  query_btn = gr.Button("Submit")
232
 
233
- def image_upload_handler(image):
234
- if image is None:
235
- return "", ""
236
- geometry_type = classifier.process_image(image)
237
- suggested_query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
238
- return geometry_type, suggested_query
239
-
240
- def user_query_handler(query, history):
241
- return handle_query(query, history, retriever, qa_chain, embeddings)
242
-
243
- image_input.change(
244
- image_upload_handler,
245
- inputs=[image_input],
246
- outputs=[geometry_label, query_input]
247
- )
248
-
249
- query_btn.click(
250
- user_query_handler,
251
- inputs=[query_input, chatbot],
252
- outputs=[chatbot, query_input]
253
- )
254
-
255
- query_input.submit(
256
- user_query_handler,
257
- inputs=[query_input, chatbot],
258
- outputs=[chatbot, query_input]
259
- )
260
 
261
  app.launch()
262
 
 
263
  if __name__ == "__main__":
264
  demo()
 
1
  import gradio as gr
2
  import os
3
+ import gc
 
4
  import cv2
5
+ import numpy as np
6
  from pathlib import Path
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
 
13
  from langchain_community.llms import HuggingFaceEndpoint
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
 
16
+ # βœ… Semantic model for query validation
17
+ semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
18
+
19
+ # βœ… Optimized Image Classifier
20
  class SimpleGeometryClassifier:
21
  def __init__(self):
22
  self.reference_embeddings = {
23
+ "flat.png": {"embedding": None, "label": "Flat or Sheet-Based"},
24
+ "cylindrical.png": {"embedding": None, "label": "Cylindrical"},
25
+ "complex.png": {"embedding": None, "label": "Complex Multi Axis Geometry"}
 
 
 
 
 
 
 
 
 
26
  }
27
 
28
  def compute_embedding(self, image_path):
29
+ img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
30
+ img = cv2.resize(img, (128, 128))
31
+
32
+ win_size = (128, 128)
 
33
  cell_size = (8, 8)
34
+ block_size = (8, 8)
35
+ block_stride = (4, 4)
36
  num_bins = 9
37
+
38
  hog = cv2.HOGDescriptor(win_size, block_size, block_stride, cell_size, num_bins)
39
  embedding = hog.compute(img)
40
+
41
+ # βœ… Free OpenCV resources
42
+ cv2.destroyAllWindows()
43
+
44
  return embedding.flatten()
45
 
46
  def initialize_reference_embeddings(self, reference_folder):
47
+ """ Load reference embeddings for classification """
48
  for image_name in self.reference_embeddings.keys():
49
  image_path = str(Path(reference_folder) / image_name)
50
  if Path(image_path).exists():
51
  self.reference_embeddings[image_name]["embedding"] = self.compute_embedding(image_path)
52
  else:
53
+ print(f"Warning: Missing reference image: {image_path}")
54
+
55
+ def process_image(self, image_path):
56
+ """ Classify uploaded image """
57
+ query_embedding = self.compute_embedding(image_path)
58
+ return self.find_closest_geometry(query_embedding)
59
 
60
  def find_closest_geometry(self, query_embedding):
61
  best_similarity = -1
62
+ best_label = "Unknown Geometry"
63
 
64
  for ref_data in self.reference_embeddings.values():
65
  if ref_data["embedding"] is not None:
 
70
  best_similarity = similarity
71
  best_label = ref_data["label"]
72
 
73
+ return best_label
74
 
 
 
 
75
 
76
+ # βœ… Initialize Image Classifier
77
+ classifier = SimpleGeometryClassifier()
78
+ classifier.initialize_reference_embeddings("images")
79
+
80
+ # βœ… Initialize Chatbot Once
81
+ retriever, qa_chain, embeddings = None, None, None
82
+ retriever, qa_chain, embeddings = initialize_chatbot()
83
 
84
  def extract_text_from_docx(file_path):
85
  doc = docx.Document(file_path)
 
159
 
160
  return similarity_score >= 0.3
161
 
162
+ # βœ… Initialize Chatbot
163
+ def initialize_chatbot():
164
  memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, output_key='answer')
165
+
166
+ documents = load_documents()
167
+ embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
168
+ vector_db = FAISS.from_documents(documents, embeddings)
169
+
170
  retriever = vector_db.as_retriever(search_kwargs={"k": 5})
171
+
172
  system_prompt = """You are an AI assistant that answers questions ONLY based on the provided documents.
173
  - If no relevant documents are retrieved, respond with: "I couldn't find any relevant information."
174
+ - Do NOT answer from general knowledge."""
175
+
176
+ # βœ… Free memory before LLM call
177
+ gc.collect()
178
+
179
  llm = HuggingFaceEndpoint(
180
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
181
  huggingfacehub_api_token=os.environ.get("HUGGINGFACE_API_TOKEN"),
 
184
  task="text-generation",
185
  system_prompt=system_prompt
186
  )
187
+
188
  qa_chain = ConversationalRetrievalChain.from_llm(
189
+ llm=llm, retriever=retriever, memory=memory, return_source_documents=True, verbose=False
 
 
 
 
190
  )
 
 
191
 
192
+ return retriever, qa_chain, embeddings
193
+
194
+ def process_image_and_generate_query(image_path):
195
+ """ Run Image Classification Separately and Generate Query """
196
+ geometry_type = classifier.process_image(image_path)
197
+ query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
198
+
199
+ # βœ… Free up memory **before** calling API
200
+ gc.collect()
201
+ return geometry_type, query
202
+
203
+ def handle_query(query, history):
204
  retrieved_docs = retrieve_documents(query, retriever, embeddings)
205
 
206
+ if not retrieved_docs:
207
  return history + [(query, "I couldn't find any relevant information.")], ""
208
 
209
  response = qa_chain.invoke({"question": query, "chat_history": history})
210
  assistant_response = response['answer'].strip()
211
 
 
 
 
212
  assistant_response += f"\n\nπŸ“„ Source: {', '.join(set(doc.metadata.get('source', 'Unknown') for doc in retrieved_docs))}"
213
 
214
  history.append((query, assistant_response))
215
  return history, ""
216
 
 
 
 
 
 
 
217
 
218
  def demo():
 
 
 
 
 
 
 
 
 
219
  with gr.Blocks() as app:
220
+ gr.Markdown("### πŸ”© Fastener Selection Assistant")
221
+
222
  with gr.Row():
223
  with gr.Column(scale=1):
224
+ image_input = gr.Image(type="numpy", label="Upload Geometry Image")
225
+ geometry_label = gr.Textbox(label="Detected Geometry", interactive=False)
226
+
227
  with gr.Column(scale=2):
228
  chatbot = gr.Chatbot()
229
+ query_input = gr.Textbox(label="Ask a question")
230
  query_btn = gr.Button("Submit")
231
 
232
+ image_input.change(image_upload_handler, inputs=[image_input], outputs=[geometry_label, query_input])
233
+ query_btn.click(handle_query, inputs=[query_input, chatbot], outputs=[chatbot, query_input])
234
+ query_input.submit(handle_query, inputs=[query_input, chatbot], outputs=[chatbot, query_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  app.launch()
237
 
238
+
239
  if __name__ == "__main__":
240
  demo()