Spaces:
Sleeping
Sleeping
Update vectordb_utils.py
Browse files- vectordb_utils.py +24 -2
vectordb_utils.py
CHANGED
|
@@ -3,20 +3,27 @@
|
|
| 3 |
from qdrant_client import QdrantClient
|
| 4 |
from qdrant_client.models import VectorParams, Distance, PointStruct
|
| 5 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 6 |
import uuid
|
| 7 |
import os
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
os.makedirs(cache_dir, exist_ok=True)
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
qdrant = QdrantClient(":memory:")
|
| 12 |
collection_name = "customer_support_docsv1"
|
| 13 |
|
|
|
|
| 14 |
def init_qdrant_collection():
    """Recreate the module's Qdrant collection.

    Calls ``qdrant.recreate_collection`` for ``collection_name`` with a
    384-dimensional vector configuration using cosine distance.
    """
    # 384 is hard-coded; presumably it matches the sentence encoder's
    # embedding size -- TODO confirm.
    vector_settings = VectorParams(size=384, distance=Distance.COSINE)
    qdrant.recreate_collection(
        collection_name=collection_name,
        vectors_config=vector_settings,
    )
|
| 19 |
|
|
|
|
| 20 |
def add_to_vectordb(query, response):
|
| 21 |
vector = encoder.encode(query).tolist()
|
| 22 |
qdrant.upload_points(
|
|
@@ -28,6 +35,21 @@ def add_to_vectordb(query, response):
|
|
| 28 |
)]
|
| 29 |
)
|
| 30 |
|
|
|
|
| 31 |
def search_vectordb(query, limit=3):
    """Embed *query* and search the collection with the resulting vector.

    Args:
        query: Input passed to ``encoder.encode``.
        limit: Forwarded as ``limit`` to ``qdrant.search`` (default 3).

    Returns:
        Whatever ``qdrant.search`` returns for the query vector.
    """
    embedding = encoder.encode(query)
    query_vector = embedding.tolist()
    return qdrant.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from qdrant_client import QdrantClient
|
| 4 |
from qdrant_client.models import VectorParams, Distance, PointStruct
|
| 5 |
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
import uuid
|
| 8 |
import os
|
| 9 |
+
|
| 10 |
+
# Setup cache dir
|
| 11 |
+
cache_dir = os.environ.get("MODEL_CACHE_DIR", "/app/cache") # Fallback
|
| 12 |
os.makedirs(cache_dir, exist_ok=True)
|
| 13 |
+
|
| 14 |
+
# Encoder and Qdrant config
|
| 15 |
+
encoder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=cache_dir)
|
| 16 |
qdrant = QdrantClient(":memory:")
|
| 17 |
collection_name = "customer_support_docsv1"
|
| 18 |
|
| 19 |
+
# Initialize collection
|
| 20 |
def init_qdrant_collection():
    """Recreate the module's Qdrant collection.

    Calls ``qdrant.recreate_collection`` for ``collection_name`` with a
    384-dimensional vector configuration using cosine distance.
    """
    # 384 is hard-coded; presumably it matches the sentence encoder's
    # embedding size -- TODO confirm.
    vector_settings = VectorParams(size=384, distance=Distance.COSINE)
    qdrant.recreate_collection(
        collection_name=collection_name,
        vectors_config=vector_settings,
    )
|
| 25 |
|
| 26 |
+
# Add a query/response to DB
|
| 27 |
def add_to_vectordb(query, response):
|
| 28 |
vector = encoder.encode(query).tolist()
|
| 29 |
qdrant.upload_points(
|
|
|
|
| 35 |
)]
|
| 36 |
)
|
| 37 |
|
| 38 |
+
# Search DB
|
| 39 |
def search_vectordb(query, limit=3):
    """Embed *query* and search the collection with the resulting vector.

    Args:
        query: Input passed to ``encoder.encode``.
        limit: Forwarded as ``limit`` to ``qdrant.search`` (default 3).

    Returns:
        Whatever ``qdrant.search`` returns for the query vector.
    """
    embedding = encoder.encode(query)
    query_vector = embedding.tolist()
    return qdrant.search(
        collection_name=collection_name,
        query_vector=query_vector,
        limit=limit,
    )
|
| 42 |
+
|
| 43 |
+
# 🆕 Load and populate from Hugging Face dataset
|
| 44 |
+
def populate_vectordb_from_hf():
    """Load the Hugging Face support dataset and add each usable pair to the DB.

    Loads the ``train`` split of ``Talhat/Customer_IT_Support`` and reads the
    ``input`` and ``output`` fields of every record. Records where either
    field is missing, null, or blank after stripping are skipped; the rest
    are stored through ``add_to_vectordb``. Progress is reported with
    ``print``.
    """
    print("Loading dataset from Hugging Face...")
    dataset = load_dataset("Talhat/Customer_IT_Support", split="train")

    print("Populating vector DB...")
    for item in dataset:
        # ``.get(key, "")`` still yields None when the key is present with a
        # null value; ``or ""`` keeps ``.strip()`` from raising on such rows.
        query = (item.get("input") or "").strip()
        response = (item.get("output") or "").strip()
        if not (query and response):
            continue
        add_to_vectordb(query, response)

    print("Vector DB population complete.")
|