korupolujayanth2004 commited on
Commit
a9bec9f
·
1 Parent(s): 8372873

Update embed_utils.py

Browse files
Files changed (1) hide show
  1. backend/embed_utils.py +33 -27
backend/embed_utils.py CHANGED
@@ -1,13 +1,18 @@
1
  # backend/embed_utils.py
2
-
3
  import os
4
- import uuid # Added for generating point IDs
5
- import time # For upload_timestamp
 
 
 
 
 
 
6
  from sentence_transformers import SentenceTransformer
7
  # Qdrant client models
8
  from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, Distance, VectorParams, NamedVector, ScrollResult
9
- from backend.qdrant_client import qdrant_client, KB_COLLECTION # Import Qdrant client and collection name
10
- from backend.document_loader import Document # Import the Document class definition
11
  from typing import List
12
 
13
  # === Embedding Model Initialization ===
@@ -30,26 +35,26 @@ def embed_and_store_chunks(documents: List[Document], session_id: str):
30
  Each chunk is associated with a session_id.
31
  """
32
  points = []
33
- current_timestamp = str(int(time.time())) # Use a Unix timestamp for when the document was uploaded
34
-
35
  for doc in documents:
36
  # Generate embedding for the chunk's text content
37
  embedding = get_embedding(doc.text)
38
-
39
  # Create payload for Qdrant, including chunk details and session ID
40
  payload = {
41
  "chunk_id": doc.chunk_id,
42
  "text": doc.text,
43
- "metadata": doc.metadata, # Preserve original metadata from document_loader
44
- "session_id": session_id, # CRUCIAL: Associate each chunk with the current session
45
  "upload_timestamp": current_timestamp,
46
- "file_type": doc.metadata.get("file_type", "unknown"), # Get file_type from metadata
47
- "source": doc.metadata.get("source", "unknown") # Get source from metadata
48
  }
49
-
50
  points.append(
51
  PointStruct(
52
- id=str(uuid.uuid4()), # Assign a unique ID for each Qdrant point
53
  vector=embedding,
54
  payload=payload
55
  )
@@ -59,7 +64,7 @@ def embed_and_store_chunks(documents: List[Document], session_id: str):
59
  if points:
60
  qdrant_client.upsert(
61
  collection_name=KB_COLLECTION,
62
- wait=True, # Wait for the operation to complete
63
  points=points
64
  )
65
  print(f"Stored {len(points)} chunks for session '{session_id}' into '{KB_COLLECTION}'.")
@@ -74,10 +79,10 @@ def search_knowledge_base(query_text: str, session_id: str, top_k: int = 5) -> s
74
  Returns a concatenated string of the most relevant text chunks.
75
  """
76
  if not query_text.strip():
77
- return "" # Return empty string if query is empty
78
-
79
  query_embedding = get_embedding(query_text)
80
-
81
  # Construct a filter to ensure we only search within the current session's data
82
  session_filter = Filter(
83
  must=[
@@ -87,30 +92,31 @@ def search_knowledge_base(query_text: str, session_id: str, top_k: int = 5) -> s
87
  )
88
  ]
89
  )
90
-
91
  try:
92
  # Perform the search in Qdrant with the query vector and session filter
93
  search_result: List[ScrollResult] = qdrant_client.search(
94
  collection_name=KB_COLLECTION,
95
  query_vector=query_embedding,
96
- query_filter=session_filter, # Apply the session-specific filter
97
- limit=top_k, # Number of top results to retrieve
98
- with_payload=True # Ensure payload (text and metadata) is returned
99
  )
100
-
101
  context_chunks = []
102
  for hit in search_result:
103
  # Extract the text content from the payload of each relevant hit
104
  if hit.payload and 'text' in hit.payload:
105
  context_chunks.append(hit.payload['text'])
106
- # print(f" Hit: {hit.payload.get('text', '')[:50]}... (Score: {hit.score})") # Debugging line
107
-
108
  if context_chunks:
109
  # Join relevant chunks into a single string to provide to the LLM
110
  return "\n\n".join(context_chunks)
111
  else:
112
  print(f"No relevant context found in KB for session '{session_id}' and query: '{query_text}'")
113
- return "" # No relevant context found for this session
 
114
  except Exception as e:
115
  print(f"Error during knowledge base search for session '{session_id}': {e}")
116
- return "" # Return empty string on error
 
1
  # backend/embed_utils.py
 
2
  import os
3
+ import uuid # Added for generating point IDs
4
+ import time # For upload_timestamp
5
+
6
+ # Fix cache directory permissions for Hugging Face Spaces
7
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache_custom"
8
+ os.environ["HF_HOME"] = "/tmp/transformers_cache_custom"
9
+ os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
10
+
11
  from sentence_transformers import SentenceTransformer
12
  # Qdrant client models
13
  from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, Distance, VectorParams, NamedVector, ScrollResult
14
+ from backend.qdrant_client import qdrant_client, KB_COLLECTION # Import Qdrant client and collection name
15
+ from backend.document_loader import Document # Import the Document class definition
16
  from typing import List
17
 
18
  # === Embedding Model Initialization ===
 
35
  Each chunk is associated with a session_id.
36
  """
37
  points = []
38
+ current_timestamp = str(int(time.time())) # Use a Unix timestamp for when the document was uploaded
39
+
40
  for doc in documents:
41
  # Generate embedding for the chunk's text content
42
  embedding = get_embedding(doc.text)
43
+
44
  # Create payload for Qdrant, including chunk details and session ID
45
  payload = {
46
  "chunk_id": doc.chunk_id,
47
  "text": doc.text,
48
+ "metadata": doc.metadata, # Preserve original metadata from document_loader
49
+ "session_id": session_id, # CRUCIAL: Associate each chunk with the current session
50
  "upload_timestamp": current_timestamp,
51
+ "file_type": doc.metadata.get("file_type", "unknown"), # Get file_type from metadata
52
+ "source": doc.metadata.get("source", "unknown") # Get source from metadata
53
  }
54
+
55
  points.append(
56
  PointStruct(
57
+ id=str(uuid.uuid4()), # Assign a unique ID for each Qdrant point
58
  vector=embedding,
59
  payload=payload
60
  )
 
64
  if points:
65
  qdrant_client.upsert(
66
  collection_name=KB_COLLECTION,
67
+ wait=True, # Wait for the operation to complete
68
  points=points
69
  )
70
  print(f"Stored {len(points)} chunks for session '{session_id}' into '{KB_COLLECTION}'.")
 
79
  Returns a concatenated string of the most relevant text chunks.
80
  """
81
  if not query_text.strip():
82
+ return "" # Return empty string if query is empty
83
+
84
  query_embedding = get_embedding(query_text)
85
+
86
  # Construct a filter to ensure we only search within the current session's data
87
  session_filter = Filter(
88
  must=[
 
92
  )
93
  ]
94
  )
95
+
96
  try:
97
  # Perform the search in Qdrant with the query vector and session filter
98
  search_result: List[ScrollResult] = qdrant_client.search(
99
  collection_name=KB_COLLECTION,
100
  query_vector=query_embedding,
101
+ query_filter=session_filter, # Apply the session-specific filter
102
+ limit=top_k, # Number of top results to retrieve
103
+ with_payload=True # Ensure payload (text and metadata) is returned
104
  )
105
+
106
  context_chunks = []
107
  for hit in search_result:
108
  # Extract the text content from the payload of each relevant hit
109
  if hit.payload and 'text' in hit.payload:
110
  context_chunks.append(hit.payload['text'])
111
+ # print(f" Hit: {hit.payload.get('text', '')[:50]}... (Score: {hit.score})") # Debugging line
112
+
113
  if context_chunks:
114
  # Join relevant chunks into a single string to provide to the LLM
115
  return "\n\n".join(context_chunks)
116
  else:
117
  print(f"No relevant context found in KB for session '{session_id}' and query: '{query_text}'")
118
+ return "" # No relevant context found for this session
119
+
120
  except Exception as e:
121
  print(f"Error during knowledge base search for session '{session_id}': {e}")
122
+ return "" # Return empty string on error