nazib61 commited on
Commit
a472bce
·
verified ·
1 Parent(s): 7401237

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +159 -37
app.py CHANGED
@@ -3,70 +3,84 @@ from datasets import load_dataset
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
5
  import torch # Ensure torch is imported
 
 
 
 
 
6
 
7
  # --- Configuration ---
8
- # Use ":memory:" for a temporary, in-memory database.
9
- # Or use a path like "./qdrant_db" to save the data to disk.
10
- # Using a path is better for Spaces as data will be rebuilt only when the code changes.
11
  QDRANT_PATH = "./qdrant_db"
12
  COLLECTION_NAME = "my_text_collection"
13
- MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
14
 
15
  # --- Load Model ---
16
- # Specify that the model should run on the CPU, which is standard for HF Spaces
17
  device = "cpu"
18
  model = SentenceTransformer(MODEL_NAME, device=device)
19
 
20
  # --- Qdrant Client and Collection Setup ---
21
- # Initialize Qdrant client to use a local, on-disk storage
22
- # This avoids the need to run a separate Qdrant server
23
  qdrant_client = QdrantClient(path=QDRANT_PATH)
24
 
25
  # Check if the collection already exists
 
26
  try:
27
  collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
28
  print("Collection already exists.")
 
29
  except Exception as e:
30
- print("Collection not found, creating a new one...")
31
- # --- Load Dataset ---
32
- # We only load the dataset and create embeddings if the collection doesn't exist
 
 
 
33
  dataset = load_dataset("ag_news", split="test")
34
- # Limiting the dataset for a quicker demo setup
35
- data = [item['text'] for item in dataset][:1000]
 
36
 
37
- # Create the collection
 
 
38
  qdrant_client.create_collection(
39
  collection_name=COLLECTION_NAME,
40
- vectors_config=models.VectorParams(size=model.get_sentence_embedding_dimension(), distance=models.Distance.COSINE),
41
  )
42
 
43
- # --- Generate and Index Embeddings ---
44
  print("Generating and indexing embeddings...")
45
- # This can take a moment on the first run
46
- qdrant_client.add(
 
 
 
 
 
 
 
 
 
 
 
 
47
  collection_name=COLLECTION_NAME,
48
- documents=data,
49
- ids=list(range(len(data))), # Simple sequential IDs
50
- embedding_model=model
51
  )
52
  print("Embeddings indexed successfully.")
53
 
54
 
55
  # --- Search Function ---
56
  def search_in_qdrant(query):
57
- """
58
- Takes a user query, generates its embedding, and searches in Qdrant.
59
- """
60
  if not query:
61
  return "Please enter a search query."
62
 
63
- # The client's search function can now take the model directly
 
 
64
  hits = qdrant_client.search(
65
  collection_name=COLLECTION_NAME,
66
- query_text=query,
67
- query_filter=None, # No filters for now
68
- limit=5, # Return the top 5 most similar results
69
- embedding_model=model
70
  )
71
 
72
  results_text = ""
@@ -74,23 +88,131 @@ def search_in_qdrant(query):
74
  return "No results found."
75
 
76
  for hit in hits:
77
- results_text += f"**Score:** {hit.score:.4f}\n"
78
- results_text += f"**Text:** {hit.payload['document']}\n\n" # Payload key is 'document' when using .add()
 
 
 
 
 
79
 
80
  return results_text
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # --- Gradio Interface ---
83
  with gr.Blocks() as demo:
84
  gr.Markdown("# Semantic Search with Qdrant and Gradio")
85
  gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")
86
 
87
- with gr.Row():
88
- search_input = gr.Textbox(label="Search Query", placeholder="e.g., 'Latest news on space exploration'")
89
-
90
- search_button = gr.Button("Search")
91
- search_output = gr.Markdown()
92
-
93
- search_button.click(search_in_qdrant, inputs=search_input, outputs=search_output)
 
 
 
 
 
 
 
 
94
 
95
  if __name__ == "__main__":
96
  demo.launch()
 
3
  from qdrant_client import QdrantClient, models
4
  from sentence_transformers import SentenceTransformer
5
  import torch # Ensure torch is imported
6
+ import os
7
+ import shutil
8
+ import PyPDF2
9
+ from docx import Document
10
+ import pandas as pd
11
 
12
  # --- Configuration ---
 
 
 
13
  QDRANT_PATH = "./qdrant_db"
14
  COLLECTION_NAME = "my_text_collection"
15
+ MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2' # Better model for semantic similarity
16
 
17
  # --- Load Model ---
 
18
  device = "cpu"
19
  model = SentenceTransformer(MODEL_NAME, device=device)
20
 
21
  # --- Qdrant Client and Collection Setup ---
 
 
22
  qdrant_client = QdrantClient(path=QDRANT_PATH)
23
 
24
  # Check if the collection already exists
25
+ collection_exists = False
26
  try:
27
  collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
28
  print("Collection already exists.")
29
+ collection_exists = True
30
  except Exception as e:
31
+ print(f"Collection not found: {e}, creating a new one...")
32
+ collection_exists = False
33
+
34
+ # If collection doesn't exist, create it and populate with data
35
+ if not collection_exists:
36
+ # Load dataset and convert to a simple list format
37
  dataset = load_dataset("ag_news", split="test")
38
+ # Convert dataset to pandas dataframe to properly access the text column
39
+ df = dataset.to_pandas()
40
+ data = df['text'].tolist()[:1000] # Get first 1000 text entries
41
 
42
+ # Create the collection with proper vector configuration
43
+ # Use the correct vector size for the selected model
44
+ vector_size = model.get_sentence_embedding_dimension() or 768 # Get the actual embedding size of the model, default to 768 for mpnet
45
  qdrant_client.create_collection(
46
  collection_name=COLLECTION_NAME,
47
+ vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
48
  )
49
 
50
+ # Generate embeddings manually to ensure compatibility
51
  print("Generating and indexing embeddings...")
52
+ embeddings = model.encode(data)
53
+
54
+ # Prepare points for insertion
55
+ points = []
56
+ for i, (text, embedding) in enumerate(zip(data, embeddings)):
57
+ point = models.PointStruct(
58
+ id=i,
59
+ vector=embedding.tolist(),
60
+ payload={"document": text}
61
+ )
62
+ points.append(point)
63
+
64
+ # Upload points to the collection
65
+ qdrant_client.upsert(
66
  collection_name=COLLECTION_NAME,
67
+ points=points
 
 
68
  )
69
  print("Embeddings indexed successfully.")
70
 
71
 
72
  # --- Search Function ---
73
  def search_in_qdrant(query):
 
 
 
74
  if not query:
75
  return "Please enter a search query."
76
 
77
+ # Generate embedding for the query
78
+ query_embedding = model.encode([query])[0].tolist()
79
+
80
  hits = qdrant_client.search(
81
  collection_name=COLLECTION_NAME,
82
+ query_vector=query_embedding,
83
+ limit=5,
 
 
84
  )
85
 
86
  results_text = ""
 
88
  return "No results found."
89
 
90
  for hit in hits:
91
+ # Check if payload exists and has the document key
92
+ if hit.payload and 'document' in hit.payload:
93
+ results_text += f"**Score:** {hit.score:.4f}\n"
94
+ results_text += f"**Text:** {hit.payload['document']}\n\n"
95
+ else:
96
+ results_text += f"**Score:** {hit.score:.4f}\n"
97
+ results_text += f"**Text:** [No document content available]\n\n"
98
 
99
  return results_text
100
 
101
+ # --- Upload Function ---
102
+ def extract_text_from_file(file_path):
103
+ """Extract text from various file types"""
104
+ file_extension = file_path.lower().split('.')[-1]
105
+
106
+ if file_extension == 'txt':
107
+ with open(file_path, 'r', encoding='utf-8') as f:
108
+ return f.read()
109
+ elif file_extension == 'pdf':
110
+ text = ""
111
+ with open(file_path, 'rb') as f:
112
+ pdf_reader = PyPDF2.PdfReader(f)
113
+ for page in pdf_reader.pages:
114
+ text += page.extract_text() + "\n"
115
+ return text
116
+ elif file_extension in ['docx', 'doc']:
117
+ doc = Document(file_path)
118
+ text = ""
119
+ for paragraph in doc.paragraphs:
120
+ text += paragraph.text + "\n"
121
+ return text
122
+ elif file_extension in ['csv', 'xlsx', 'xls']:
123
+ if file_extension == 'csv':
124
+ df = pd.read_csv(file_path)
125
+ else:
126
+ df = pd.read_excel(file_path)
127
+ # Convert the entire dataframe to text
128
+ return df.to_string()
129
+ else:
130
+ # Try to read as plain text
131
+ try:
132
+ with open(file_path, 'r', encoding='utf-8') as f:
133
+ return f.read()
134
+ except UnicodeDecodeError:
135
+ # If UTF-8 fails, try with different encoding
136
+ try:
137
+ with open(file_path, 'r', encoding='latin-1') as f:
138
+ return f.read()
139
+ except:
140
+ return "Could not read file: unsupported format or encoding issue"
141
+
142
+ def upload_to_qdrant(text_content, file_upload=None):
143
+ if not text_content and not file_upload:
144
+ return "Please provide text content or upload a file."
145
+
146
+ documents_to_add = []
147
+
148
+ # Add text content if provided
149
+ if text_content:
150
+ documents_to_add.append(text_content)
151
+
152
+ # Process uploaded file if provided
153
+ if file_upload:
154
+ try:
155
+ content = extract_text_from_file(file_upload.name)
156
+ documents_to_add.append(content)
157
+ except Exception as e:
158
+ return f"Error reading file: {str(e)}"
159
+
160
+ if not documents_to_add:
161
+ return "No content to upload."
162
+
163
+ # Get the next available ID by checking the current max ID in the collection
164
+ # For simplicity, we'll just get the count of existing records and start from there
165
+ max_id = 0 # Default to 0 if we can't get the count
166
+ try:
167
+ collection_info = qdrant_client.get_collection(collection_name=COLLECTION_NAME)
168
+ if hasattr(collection_info, 'points_count') and collection_info.points_count is not None:
169
+ current_count = collection_info.points_count
170
+ max_id = current_count # Start from the current count
171
+ except:
172
+ max_id = 0 # If there's an error, start with 0
173
+
174
+ # Generate embeddings for the new documents
175
+ embeddings = model.encode(documents_to_add)
176
+
177
+ # Prepare points for insertion
178
+ points = []
179
+ for i, (doc, embedding) in enumerate(zip(documents_to_add, embeddings)):
180
+ point_id = max_id + i + 1 # IDs will be automatically converted as needed by Qdrant
181
+ point = models.PointStruct(
182
+ id=point_id,
183
+ vector=embedding.tolist(),
184
+ payload={"document": doc}
185
+ )
186
+ points.append(point)
187
+
188
+ # Upload points to the collection
189
+ qdrant_client.upsert(
190
+ collection_name=COLLECTION_NAME,
191
+ points=points
192
+ )
193
+
194
+ return f"Successfully added {len(documents_to_add)} document(s) to the collection."
195
+
196
  # --- Gradio Interface ---
197
  with gr.Blocks() as demo:
198
  gr.Markdown("# Semantic Search with Qdrant and Gradio")
199
  gr.Markdown("Enter a query to search for similar news articles from the AG News dataset.")
200
 
201
+ with gr.Tab("Search"):
202
+ with gr.Row():
203
+ search_input = gr.Textbox(label="Search Query", placeholder="e.g., 'Latest news on space exploration'")
204
+ search_button = gr.Button("Search")
205
+ search_output = gr.Markdown()
206
+ search_button.click(search_in_qdrant, inputs=search_input, outputs=search_output)
207
+
208
+ with gr.Tab("Upload"):
209
+ with gr.Row():
210
+ text_input = gr.Textbox(label="Text Content", placeholder="Enter text to add to the collection", lines=5)
211
+ with gr.Row():
212
+ file_input = gr.File(label="Or Upload a File", file_types=['.txt', '.pdf', '.docx', '.csv', '.xlsx', '.xls', '.md'])
213
+ upload_button = gr.Button("Upload to Collection")
214
+ upload_output = gr.Textbox(label="Upload Status", interactive=False)
215
+ upload_button.click(upload_to_qdrant, inputs=[text_input, file_input], outputs=upload_output)
216
 
217
  if __name__ == "__main__":
218
  demo.launch()