bharatcoder commited on
Commit
da13ac2
·
verified ·
1 Parent(s): 68a5585

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
 
2
  try:
 
3
  import gradio as gr
4
  import torch
5
  from sentence_transformers import SentenceTransformer
@@ -241,6 +242,111 @@ class EmbeddingGemmaPrompts:
241
 
242
 
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  def slice_list(lst: list, start: int, end: int) -> list:
245
  """
246
  A tool that slices a list given a start and end index.
 
1
 
2
  try:
3
+ import os
4
  import gradio as gr
5
  import torch
6
  from sentence_transformers import SentenceTransformer
 
242
 
243
 
244
 
245
+ def search_knowledge_base(
246
+ query: str,
247
+ num_results: int = 5,
248
+ source_filter: Optional[str] = None,
249
+ task_type: str = "search"
250
+ ) -> Dict[str, Any]:
251
+ """
252
+ Search the RS Studies knowledge base using semantic similarity
253
+
254
+ Args:
255
+ query: The search query
256
+ num_results: Number of results to return
257
+ source_filter: Optional source folder filter
258
+ task_type: Type of task for query formatting
259
+
260
+ Returns:
261
+ Dictionary with search results and metadata
262
+ """
263
+ if not ensure_initialized():
264
+ return {"error": "Server not properly initialized", "results": []}
265
+
266
+ try:
267
+ # Create query embedding with task-specific formatting using EmbeddingGemmaPrompts
268
+ query_formatted = EmbeddingGemmaPrompts.encode_query(query, task_type)
269
+ query_embedding = model.encode([query_formatted], device=device)
270
+
271
+ # Prepare search parameters
272
+ search_params = {
273
+ "query_embeddings": query_embedding.tolist(),
274
+ "n_results": min(num_results, config.MAX_NUM_RESULTS),
275
+ "include": ["documents", "metadatas", "distances"]
276
+ }
277
+
278
+ # Add source filter if specified
279
+ if source_filter and source_filter in config.VALID_SOURCES:
280
+ search_params["where"] = {"source_folder": {"$eq": source_filter}}
281
+
282
+ # Perform search
283
+ results = collection.query(**search_params)
284
+
285
+ # Format results
286
+ formatted_results = []
287
+ if results["documents"] and len(results["documents"]) > 0:
288
+ for i in range(len(results["documents"][0])):
289
+ result = {
290
+ "rank": i + 1,
291
+ "content": results["documents"][0][i],
292
+ "source_folder": results["metadatas"][0][i].get("source_folder", "unknown"),
293
+ "chunk_file": results["metadatas"][0][i].get("chunk_file", "unknown"),
294
+ "chunk_number": results["metadatas"][0][i].get("chunk_number", "unknown"),
295
+ "similarity_score": float(1 - results["distances"][0][i]),
296
+ "distance": float(results["distances"][0][i]),
297
+ "chunk_length": results["metadatas"][0][i].get("chunk_length", 0),
298
+ "metadata": results["metadatas"][0][i]
299
+ }
300
+ formatted_results.append(result)
301
+
302
+ return {
303
+ "query": query,
304
+ "task_type": task_type,
305
+ "num_results": len(formatted_results),
306
+ "source_filter": source_filter,
307
+ "results": formatted_results,
308
+ "success": True
309
+ }
310
+
311
+ except Exception as e:
312
+ return {"error": f"Search failed: {str(e)}", "results": [], "success": False}
313
+
314
+ def get_available_sources() -> Dict[str, Any]:
315
+ """Get list of available source folders in the knowledge base"""
316
+ if not ensure_initialized():
317
+ return {"error": "Server not properly initialized", "sources": []}
318
+
319
+ try:
320
+ # Get all metadata to find unique source folders
321
+ all_results = collection.get(include=["metadatas"])
322
+ sources = set()
323
+
324
+ for metadata in all_results["metadatas"]:
325
+ source = metadata.get("source_folder")
326
+ if source:
327
+ sources.add(source)
328
+
329
+ # Get statistics for each source
330
+ source_stats = {}
331
+ for source in sources:
332
+ source_results = collection.get(
333
+ where={"source_folder": {"$eq": source}},
334
+ include=["metadatas"]
335
+ )
336
+ source_stats[source] = len(source_results["metadatas"])
337
+
338
+ return {
339
+ "sources": sorted(list(sources)),
340
+ "source_stats": source_stats,
341
+ "total_sources": len(sources),
342
+ "total_chunks": collection.count(),
343
+ "success": True
344
+ }
345
+
346
+ except Exception as e:
347
+ return {"error": f"Failed to get sources: {str(e)}", "sources": [], "success": False}
348
+
349
+
350
  def slice_list(lst: list, start: int, end: int) -> list:
351
  """
352
  A tool that slices a list given a start and end index.