mgetz committed
Commit 0ab8474 · verified · 1 Parent(s): be6f1eb

I added the semantic search feature to this chatbot

Files changed (1)
  1. app.py +111 -1
app.py CHANGED
@@ -1,6 +1,108 @@
 import gradio as gr
 import random
 from huggingface_hub import InferenceClient
+# STEP 1: Import the Sentence Transformers library and Torch
+from sentence_transformers import SentenceTransformer
+import torch
+
+# ===== LOAD & PROCESS YOUR NEW CONTENT =====
+# STEP 2: Load and process the text file
+# Open tooth_brushin_text.txt in read mode with UTF-8 encoding
+with open("tooth_brushin_text.txt", "r", encoding="utf-8") as file:
+    # Read the entire contents of the file and store it in a variable
+    tooth_brushin_text = file.read()
+
+# Print the text
+print(tooth_brushin_text)
+
+# ===== APPLY THE COMPLETE WORKFLOW =====
+# STEP 3: Split the text into chunks (by sentence) and clean/strip each chunk
+def preprocess_text(text):
+    # Strip extra whitespace from the beginning and end of the text
+    cleaned_text = text.strip()
+
+    # Split the cleaned text at every period
+    chunks = cleaned_text.split(".")
+
+    # Create an empty list to store cleaned chunks
+    cleaned_chunks = []
+
+    # Strip each chunk and keep only the non-empty ones
+    for chunk in chunks:
+        stripped_chunk = chunk.strip()
+        if len(stripped_chunk) > 0:
+            cleaned_chunks.append(stripped_chunk)
+
+    # Print cleaned_chunks
+    print(cleaned_chunks)
+
+    num_of_chunks = len(cleaned_chunks)
+
+    # Print the number of chunks
+    print(f"There are {num_of_chunks} chunks.")
+
+    # Return the cleaned chunks
+    return cleaned_chunks
+
+# Call preprocess_text and store the result in a cleaned_chunks variable
+cleaned_chunks = preprocess_text(tooth_brushin_text)
+
+# STEP 4: Convert the chunks into vectors
+# Load the pre-trained embedding model that converts text to vectors
+model = SentenceTransformer('all-MiniLM-L6-v2')
+
+def create_embeddings(text_chunks):
+    # Convert each text chunk into a vector embedding and store as a tensor
+    chunk_embeddings = model.encode(text_chunks, convert_to_tensor=True)
+
+    # Print the chunk embeddings
+    print(chunk_embeddings)
+
+    # Print the shape of chunk_embeddings
+    print(chunk_embeddings.shape)
+
+    # Return the chunk embeddings
+    return chunk_embeddings
+
+# Call create_embeddings and store the result in a chunk_embeddings variable
+chunk_embeddings = create_embeddings(cleaned_chunks)
+
+# STEP 5: Embed the query, find the 3 most relevant chunks, and return them as text
+# Find the most relevant text chunks for a given query
+def get_top_chunks(query, chunk_embeddings, text_chunks):
+    # Convert the query text into a vector embedding
+    query_embedding = model.encode(query, convert_to_tensor=True)
+
+    # Normalize the query embedding to unit length for accurate similarity comparison
+    query_embedding_normalized = query_embedding / query_embedding.norm()
+
+    # Normalize all chunk embeddings to unit length for consistent comparison
+    chunk_embeddings_normalized = chunk_embeddings / chunk_embeddings.norm(dim=1, keepdim=True)
+
+    # Cosine similarity between the query and every chunk via matrix multiplication
+    similarities = torch.matmul(chunk_embeddings_normalized, query_embedding_normalized)
+
+    # Print the similarities
+    print(similarities)
+
+    # Find the indices of the 3 chunks with the highest similarity scores
+    top_indices = torch.topk(similarities, k=3).indices
+
+    # Print the top indices
+    print(top_indices)
+
+    # Create an empty list to store the most relevant chunks
+    top_chunks = []
+
+    # Loop through the top indices and retrieve the corresponding text chunks
+    for index in top_indices:
+        relevant_text_chunk = text_chunks[index]
+        top_chunks.append(relevant_text_chunk)
+
+    # Return the list of most relevant chunks
+    return top_chunks
+
+# STEP 6:
+
 
 client = InferenceClient("Qwen/Qwen2.5-7B-Instruct-1M")
 
@@ -23,7 +125,15 @@ def echo(message, history):
 def yes_no(message, history):
     responses = ["Yes", "No"]
     return random.choice(responses)
-
+
 chatbot = gr.ChatInterface(respond, type="messages")
 
+# Try the retrieval pipeline with a sample query
+# (the chat handler should pass the user's actual message here instead)
+top_results = get_top_chunks("How long should I brush my teeth?", chunk_embeddings, cleaned_chunks)
+
+# Print the top results
+print(top_results)
+
+
 chatbot.launch()
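
The # STEP 6: marker above is still empty, so the retrieved chunks are printed but never reach the model. Below is a minimal sketch of what that step could look like, wiring get_top_chunks into the respond handler that gr.ChatInterface already uses. The handler's current body isn't shown in this diff, so its signature, the system-prompt wording, and the chat_completion call are assumptions rather than the committed code:

def respond(message, history):
    # Retrieve the 3 chunks most relevant to the user's message
    top_chunks = get_top_chunks(message, chunk_embeddings, cleaned_chunks)

    # Hand the retrieved chunks to the model as grounding context (assumed prompt format)
    context = "\n".join(top_chunks)
    messages = [{"role": "system",
                 "content": f"Use this context to answer when it is relevant:\n{context}"}]

    # With type="messages", Gradio passes history as {"role": ..., "content": ...} dicts
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    # Ask the hosted Qwen model for a reply
    response = client.chat_completion(messages, max_tokens=512)
    return response.choices[0].message.content

Defined above the existing chatbot = gr.ChatInterface(respond, type="messages") line, this version would be picked up without any other change.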
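
Separately, the normalize-then-matmul inside get_top_chunks is cosine similarity written out by hand. If a sanity check is ever wanted, the cos_sim helper that ships in sentence_transformers.util should produce the same scores (the query string here is only a sample):

from sentence_transformers import util

# cos_sim returns a (1, num_chunks) matrix of cosine similarities
query_embedding = model.encode("How long should I brush my teeth?", convert_to_tensor=True)
similarities = util.cos_sim(query_embedding, chunk_embeddings)[0]

# Should match the manual top-3 picked in get_top_chunks
top_indices = torch.topk(similarities, k=3).indices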