Spaces:
Sleeping
Sleeping
korupolujayanth2004
committed on
Commit
·
a9bec9f
1
Parent(s):
8372873
Update embed_utils.py
Browse files
- backend/embed_utils.py +33 -27
backend/embed_utils.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
| 1 |
# backend/embed_utils.py
|
| 2 |
-
|
| 3 |
import os
|
| 4 |
-
import uuid
|
| 5 |
-
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
# Qdrant client models
|
| 8 |
from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, Distance, VectorParams, NamedVector, ScrollResult
|
| 9 |
-
from backend.qdrant_client import qdrant_client, KB_COLLECTION
|
| 10 |
-
from backend.document_loader import Document
|
| 11 |
from typing import List
|
| 12 |
|
| 13 |
# === Embedding Model Initialization ===
|
|
@@ -30,26 +35,26 @@ def embed_and_store_chunks(documents: List[Document], session_id: str):
|
|
| 30 |
Each chunk is associated with a session_id.
|
| 31 |
"""
|
| 32 |
points = []
|
| 33 |
-
current_timestamp = str(int(time.time()))
|
| 34 |
-
|
| 35 |
for doc in documents:
|
| 36 |
# Generate embedding for the chunk's text content
|
| 37 |
embedding = get_embedding(doc.text)
|
| 38 |
-
|
| 39 |
# Create payload for Qdrant, including chunk details and session ID
|
| 40 |
payload = {
|
| 41 |
"chunk_id": doc.chunk_id,
|
| 42 |
"text": doc.text,
|
| 43 |
-
"metadata": doc.metadata,
|
| 44 |
-
"session_id": session_id,
|
| 45 |
"upload_timestamp": current_timestamp,
|
| 46 |
-
"file_type": doc.metadata.get("file_type", "unknown"),
|
| 47 |
-
"source": doc.metadata.get("source", "unknown")
|
| 48 |
}
|
| 49 |
-
|
| 50 |
points.append(
|
| 51 |
PointStruct(
|
| 52 |
-
id=str(uuid.uuid4()),
|
| 53 |
vector=embedding,
|
| 54 |
payload=payload
|
| 55 |
)
|
|
@@ -59,7 +64,7 @@ def embed_and_store_chunks(documents: List[Document], session_id: str):
|
|
| 59 |
if points:
|
| 60 |
qdrant_client.upsert(
|
| 61 |
collection_name=KB_COLLECTION,
|
| 62 |
-
wait=True,
|
| 63 |
points=points
|
| 64 |
)
|
| 65 |
print(f"Stored {len(points)} chunks for session '{session_id}' into '{KB_COLLECTION}'.")
|
|
@@ -74,10 +79,10 @@ def search_knowledge_base(query_text: str, session_id: str, top_k: int = 5) -> s
|
|
| 74 |
Returns a concatenated string of the most relevant text chunks.
|
| 75 |
"""
|
| 76 |
if not query_text.strip():
|
| 77 |
-
return ""
|
| 78 |
-
|
| 79 |
query_embedding = get_embedding(query_text)
|
| 80 |
-
|
| 81 |
# Construct a filter to ensure we only search within the current session's data
|
| 82 |
session_filter = Filter(
|
| 83 |
must=[
|
|
@@ -87,30 +92,31 @@ def search_knowledge_base(query_text: str, session_id: str, top_k: int = 5) -> s
|
|
| 87 |
)
|
| 88 |
]
|
| 89 |
)
|
| 90 |
-
|
| 91 |
try:
|
| 92 |
# Perform the search in Qdrant with the query vector and session filter
|
| 93 |
search_result: List[ScrollResult] = qdrant_client.search(
|
| 94 |
collection_name=KB_COLLECTION,
|
| 95 |
query_vector=query_embedding,
|
| 96 |
-
query_filter=session_filter,
|
| 97 |
-
limit=top_k,
|
| 98 |
-
with_payload=True
|
| 99 |
)
|
| 100 |
-
|
| 101 |
context_chunks = []
|
| 102 |
for hit in search_result:
|
| 103 |
# Extract the text content from the payload of each relevant hit
|
| 104 |
if hit.payload and 'text' in hit.payload:
|
| 105 |
context_chunks.append(hit.payload['text'])
|
| 106 |
-
# print(f" Hit: {hit.payload.get('text', '')[:50]}... (Score: {hit.score})")
|
| 107 |
-
|
| 108 |
if context_chunks:
|
| 109 |
# Join relevant chunks into a single string to provide to the LLM
|
| 110 |
return "\n\n".join(context_chunks)
|
| 111 |
else:
|
| 112 |
print(f"No relevant context found in KB for session '{session_id}' and query: '{query_text}'")
|
| 113 |
-
return ""
|
|
|
|
| 114 |
except Exception as e:
|
| 115 |
print(f"Error during knowledge base search for session '{session_id}': {e}")
|
| 116 |
-
return ""
|
|
|
|
| 1 |
# backend/embed_utils.py
|
|
|
|
| 2 |
import os
|
| 3 |
+
import uuid # Added for generating point IDs
|
| 4 |
+
import time # For upload_timestamp
|
| 5 |
+
|
| 6 |
+
# Fix cache directory permissions for Hugging Face Spaces
|
| 7 |
+
os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache_custom"
|
| 8 |
+
os.environ["HF_HOME"] = "/tmp/transformers_cache_custom"
|
| 9 |
+
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
|
| 10 |
+
|
| 11 |
from sentence_transformers import SentenceTransformer
|
| 12 |
# Qdrant client models
|
| 13 |
from qdrant_client.http.models import PointStruct, Filter, FieldCondition, MatchValue, Distance, VectorParams, NamedVector, ScrollResult
|
| 14 |
+
from backend.qdrant_client import qdrant_client, KB_COLLECTION # Import Qdrant client and collection name
|
| 15 |
+
from backend.document_loader import Document # Import the Document class definition
|
| 16 |
from typing import List
|
| 17 |
|
| 18 |
# === Embedding Model Initialization ===
|
|
|
|
| 35 |
Each chunk is associated with a session_id.
|
| 36 |
"""
|
| 37 |
points = []
|
| 38 |
+
current_timestamp = str(int(time.time())) # Use a Unix timestamp for when the document was uploaded
|
| 39 |
+
|
| 40 |
for doc in documents:
|
| 41 |
# Generate embedding for the chunk's text content
|
| 42 |
embedding = get_embedding(doc.text)
|
| 43 |
+
|
| 44 |
# Create payload for Qdrant, including chunk details and session ID
|
| 45 |
payload = {
|
| 46 |
"chunk_id": doc.chunk_id,
|
| 47 |
"text": doc.text,
|
| 48 |
+
"metadata": doc.metadata, # Preserve original metadata from document_loader
|
| 49 |
+
"session_id": session_id, # CRUCIAL: Associate each chunk with the current session
|
| 50 |
"upload_timestamp": current_timestamp,
|
| 51 |
+
"file_type": doc.metadata.get("file_type", "unknown"), # Get file_type from metadata
|
| 52 |
+
"source": doc.metadata.get("source", "unknown") # Get source from metadata
|
| 53 |
}
|
| 54 |
+
|
| 55 |
points.append(
|
| 56 |
PointStruct(
|
| 57 |
+
id=str(uuid.uuid4()), # Assign a unique ID for each Qdrant point
|
| 58 |
vector=embedding,
|
| 59 |
payload=payload
|
| 60 |
)
|
|
|
|
| 64 |
if points:
|
| 65 |
qdrant_client.upsert(
|
| 66 |
collection_name=KB_COLLECTION,
|
| 67 |
+
wait=True, # Wait for the operation to complete
|
| 68 |
points=points
|
| 69 |
)
|
| 70 |
print(f"Stored {len(points)} chunks for session '{session_id}' into '{KB_COLLECTION}'.")
|
|
|
|
| 79 |
Returns a concatenated string of the most relevant text chunks.
|
| 80 |
"""
|
| 81 |
if not query_text.strip():
|
| 82 |
+
return "" # Return empty string if query is empty
|
| 83 |
+
|
| 84 |
query_embedding = get_embedding(query_text)
|
| 85 |
+
|
| 86 |
# Construct a filter to ensure we only search within the current session's data
|
| 87 |
session_filter = Filter(
|
| 88 |
must=[
|
|
|
|
| 92 |
)
|
| 93 |
]
|
| 94 |
)
|
| 95 |
+
|
| 96 |
try:
|
| 97 |
# Perform the search in Qdrant with the query vector and session filter
|
| 98 |
search_result: List[ScrollResult] = qdrant_client.search(
|
| 99 |
collection_name=KB_COLLECTION,
|
| 100 |
query_vector=query_embedding,
|
| 101 |
+
query_filter=session_filter, # Apply the session-specific filter
|
| 102 |
+
limit=top_k, # Number of top results to retrieve
|
| 103 |
+
with_payload=True # Ensure payload (text and metadata) is returned
|
| 104 |
)
|
| 105 |
+
|
| 106 |
context_chunks = []
|
| 107 |
for hit in search_result:
|
| 108 |
# Extract the text content from the payload of each relevant hit
|
| 109 |
if hit.payload and 'text' in hit.payload:
|
| 110 |
context_chunks.append(hit.payload['text'])
|
| 111 |
+
# print(f" Hit: {hit.payload.get('text', '')[:50]}... (Score: {hit.score})") # Debugging line
|
| 112 |
+
|
| 113 |
if context_chunks:
|
| 114 |
# Join relevant chunks into a single string to provide to the LLM
|
| 115 |
return "\n\n".join(context_chunks)
|
| 116 |
else:
|
| 117 |
print(f"No relevant context found in KB for session '{session_id}' and query: '{query_text}'")
|
| 118 |
+
return "" # No relevant context found for this session
|
| 119 |
+
|
| 120 |
except Exception as e:
|
| 121 |
print(f"Error during knowledge base search for session '{session_id}': {e}")
|
| 122 |
+
return "" # Return empty string on error
|