dnzblgn commited on
Commit
8b39d97
·
verified ·
1 Parent(s): e987c24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -6
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import os
3
  import docx
4
  import numpy as np
 
 
5
  from sentence_transformers import SentenceTransformer
6
  from sklearn.metrics.pairwise import cosine_similarity
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -11,6 +13,66 @@ from langchain.memory import ConversationBufferMemory
11
  from langchain_community.llms import HuggingFaceEndpoint
12
  from langchain_huggingface import HuggingFaceEmbeddings
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # Initialize semantic model
15
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
16
 
@@ -138,20 +200,51 @@ def handle_query(query, history, retriever, qa_chain, embeddings):
138
  history.append((query, assistant_response))
139
  return history, ""
140
 
 
 
 
 
 
 
 
141
  def demo():
 
 
 
 
 
142
  documents = load_documents()
143
  vector_db, embeddings = create_db(documents)
144
  retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
145
 
146
  with gr.Blocks() as app:
147
- gr.Markdown("### 🤖 Document Question Answering System")
148
-
149
- chatbot = gr.Chatbot()
150
- query_input = gr.Textbox(label="Ask a question about the documents")
151
- query_btn = gr.Button("Submit")
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def user_query_handler(query, history):
154
  return handle_query(query, history, retriever, qa_chain, embeddings)
 
 
 
 
 
 
155
 
156
  query_btn.click(
157
  user_query_handler,
@@ -164,7 +257,7 @@ def demo():
164
  inputs=[query_input, chatbot],
165
  outputs=[chatbot, query_input]
166
  )
167
-
168
  app.launch()
169
 
170
  if __name__ == "__main__":
 
2
  import os
3
  import docx
4
  import numpy as np
5
+ import cv2
6
+ from pathlib import Path
7
  from sentence_transformers import SentenceTransformer
8
  from sklearn.metrics.pairwise import cosine_similarity
9
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
13
  from langchain_community.llms import HuggingFaceEndpoint
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
 
16
+ class SimpleGeometryClassifier:
17
+ def __init__(self):
18
+ self.reference_embeddings = {
19
+ "flat.png": {
20
+ "embedding": None,
21
+ "label": "Flat or Sheet-Based"
22
+ },
23
+ "cylindrical.png": {
24
+ "embedding": None,
25
+ "label": "Cylindrical"
26
+ },
27
+ "complex.png": {
28
+ "embedding": None,
29
+ "label": "Complex Multi Axis Geometry"
30
+ }
31
+ }
32
+
33
+ def compute_embedding(self, image_path):
34
+ img = cv2.imread(image_path)
35
+ img = cv2.resize(img, (224, 224))
36
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
37
+
38
+ win_size = (224, 224)
39
+ cell_size = (8, 8)
40
+ block_size = (16, 16)
41
+ block_stride = (8, 8)
42
+ num_bins = 9
43
+
44
+ hog = cv2.HOGDescriptor(win_size, block_size, block_stride, cell_size, num_bins)
45
+ embedding = hog.compute(img)
46
+
47
+ return embedding.flatten()
48
+
49
+ def initialize_reference_embeddings(self, reference_folder):
50
+ for image_name in self.reference_embeddings.keys():
51
+ image_path = str(Path(reference_folder) / image_name)
52
+ if Path(image_path).exists():
53
+ self.reference_embeddings[image_name]["embedding"] = self.compute_embedding(image_path)
54
+ else:
55
+ print(f"Warning: Reference image {image_path} not found")
56
+
57
+ def find_closest_geometry(self, query_embedding):
58
+ best_similarity = -1
59
+ best_label = None
60
+
61
+ for ref_data in self.reference_embeddings.values():
62
+ if ref_data["embedding"] is not None:
63
+ similarity = np.dot(query_embedding, ref_data["embedding"]) / (
64
+ np.linalg.norm(query_embedding) * np.linalg.norm(ref_data["embedding"])
65
+ )
66
+ if similarity > best_similarity:
67
+ best_similarity = similarity
68
+ best_label = ref_data["label"]
69
+
70
+ return best_label or "Unknown Geometry"
71
+
72
+ def process_image(self, image_path):
73
+ query_embedding = self.compute_embedding(image_path)
74
+ return self.find_closest_geometry(query_embedding)
75
+
76
  # Initialize semantic model
77
  semantic_model = SentenceTransformer("all-MiniLM-L6-v2")
78
 
 
200
  history.append((query, assistant_response))
201
  return history, ""
202
 
203
+ def process_image_and_generate_query(image):
204
+ classifier = SimpleGeometryClassifier()
205
+ classifier.initialize_reference_embeddings("images")
206
+ geometry_type = classifier.process_image(image)
207
+ query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
208
+ return geometry_type, query
209
+
210
  def demo():
211
+ # Initialize classifier
212
+ classifier = SimpleGeometryClassifier()
213
+ classifier.initialize_reference_embeddings("images")
214
+
215
+ # Initialize chatbot components
216
  documents = load_documents()
217
  vector_db, embeddings = create_db(documents)
218
  retriever, qa_chain = initialize_chatbot(vector_db, embeddings)
219
 
220
  with gr.Blocks() as app:
221
+ gr.Markdown("### 🤖 Fastener Agent with Image Recognition 📚")
 
 
 
 
222
 
223
+ with gr.Row():
224
+ with gr.Column(scale=1):
225
+ image_input = gr.Image(type="filepath", label="Upload Geometry Image")
226
+ geometry_label = gr.Textbox(label="Detected Geometry Type", interactive=False)
227
+
228
+ with gr.Column(scale=2):
229
+ chatbot = gr.Chatbot()
230
+ query_input = gr.Textbox(label="Ask a question about the documents")
231
+ query_btn = gr.Button("Submit")
232
+
233
+ def image_upload_handler(image):
234
+ if image is None:
235
+ return "", ""
236
+ geometry_type = classifier.process_image(image)
237
+ suggested_query = f"I have a {geometry_type} geometry, which screw should I use and what is the best machine to use for {geometry_type} geometry?"
238
+ return geometry_type, suggested_query
239
+
240
  def user_query_handler(query, history):
241
  return handle_query(query, history, retriever, qa_chain, embeddings)
242
+
243
+ image_input.change(
244
+ image_upload_handler,
245
+ inputs=[image_input],
246
+ outputs=[geometry_label, query_input]
247
+ )
248
 
249
  query_btn.click(
250
  user_query_handler,
 
257
  inputs=[query_input, chatbot],
258
  outputs=[chatbot, query_input]
259
  )
260
+
261
  app.launch()
262
 
263
  if __name__ == "__main__":