Minh commited on
Commit
6912ad8
·
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ .env
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: HF Vector Search
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app.py
9
+ pinned: true
10
+ ---
11
+
12
+ # HF Vector Search
13
+ Dự án tìm kiếm Vector sử dụng Qdrant (Deploy via Gradio SDK).
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import json
3
+ # import gradio as gr
4
+
5
+ # from uuid import uuid4
6
+ # from pprint import pprint
7
+ # from dotenv import load_dotenv
8
+ # from qdrant_client import QdrantClient
9
+ # from fastembed import TextEmbedding
10
+ # from langchain_core.documents import Document
11
+ # from src.utils.qdrant_vector_store import QdrantVectorStore, RetrievalMode
12
+ # from src.utils.fastembed_manager import add_custom_embedding_model
13
+ # from src.utils.fastembed_sparse import FastEmbedSparse
14
+
15
+
16
+
17
+ # from qdrant_client import QdrantClient
18
+ # from qdrant_client.http import models
19
+ # load_dotenv()
20
+
21
+
22
+ # COLLECTION_NAME = "test_collection"
23
+ # qdrant_api_key = os.getenv("QDRANT_API_KEY")  # SECURITY: a live Qdrant API key was hard-coded here and committed; it has been redacted — rotate the leaked key immediately.
24
+ # qdrant_endpoint = os.getenv("QDRANT_ENDPOINT")  # endpoint moved to env var alongside the key
25
+ # qdrant_client = QdrantClient(
26
+ # url=qdrant_endpoint,
27
+ # api_key=qdrant_api_key,
28
+ # prefer_grpc=True,
29
+ # )
30
+ # sparse_embeddings = FastEmbedSparse(model_name="Qdrant/BM25")
31
+ # embedding = add_custom_embedding_model(
32
+ # model_name="models/Vietnamese_Embedding_OnnX_Quantized",
33
+ # source_model="Mint1456/Vietnamese_Embedding_OnnX_Quantized",
34
+ # dim=1024,
35
+ # source_file="model.onnx"
36
+ # )
37
+ # client = QdrantVectorStore(
38
+ # client=qdrant_client,
39
+ # collection_name=COLLECTION_NAME,
40
+ # embedding=embedding,
41
+ # sparse_embedding=sparse_embeddings,
42
+ # retrieval_mode=RetrievalMode.HYBRID,
43
+ # )
44
+
45
+ # def search_document(query, top_k, search_type, slider_lambda):
46
+ # if not query.strip():
47
+ # return "⚠️ Enter query to look up!"
48
+ # try:
49
+ # if search_type == "Default":
50
+ # hits = client.similarity_search_with_score(query=query,k=top_k)
51
+ # else:
52
+
53
+ # hits = client.max_marginal_relevance_search_with_score(query=query, k=top_k, lambda_mult=slider_lambda)
54
+ # except Exception as e:
55
+ # print("error", e)
56
+
57
+ # total_found = len(hits)
58
+ # if total_found == 0:
59
+ # return json.dumps([], indent=2)
60
+
61
+ # # Nếu tìm được 10 mà đòi 15 -> chỉ lấy 10. Nếu tìm được 100 mà đòi 15 -> lấy 15
62
+ # safe_k = min(top_k, total_found)
63
+ # results = []
64
+ # for i in range(safe_k):
65
+ # hit = hits[i]
66
+ # if hit[0].metadata.get('parent_chunking', None) is not None:
67
+ # content = hit[0].metadata['parent_chunking']
68
+ # elif hit[0].metadata.get('type', None) == "intro":
69
+ # content = hit[0].page_content
70
+ # else:
71
+ # content = None
72
+ # results.append({
73
+ # "Score": round(hit[1], 4),
74
+ # "Content": content,
75
+ # # "Metadata:": {k: v for k, v in hit[0].metadata.items() if k != "page_content"}
76
+ # })
77
+
78
+ # return json.dumps(results, indent=2, ensure_ascii=False)
79
+
80
+ # # --- GIAO DIỆN GRADIO ---
81
+
82
+ # with gr.Blocks(title="Qdrant Vector DB Demo") as demo:
83
+ # gr.Markdown("# 🚀 Demo Qdrant Vector Search")
84
+ # gr.Markdown("Tool test nhanh khả năng thêm dữ liệu và tìm kiếm ngữ nghĩa (Semantic Search).")
85
+
86
+
87
+ # with gr.Tab("2. Tìm Kiếm (Search)"):
88
+ # with gr.Row():
89
+ # with gr.Column(scale=1):
90
+ # txt_query = gr.Textbox(label="Câu truy vấn", placeholder="Ví dụ: Tìm về một số thông tin trên website Bệnh Viện Tâm Anh", lines=2)
91
+ # gr.Examples(
92
+ # examples=[
93
+ # "Rủi ro khi khâu cổ tử cung",
94
+ # "Biến chứng của tràn dịch phổi",
95
+ # "Triệu chứng của viêm phế quản",
96
+ # "Phòng ngừa đau tim"
97
+ # ],
98
+ # inputs=txt_query,
99
+ # label="Ví dụ mẫu (Click để chọn)"
100
+ # )
101
+ # # Component mới: Chọn thuật toán
102
+ # radio_type = gr.Radio(
103
+ # choices=["Default", "MMR"],
104
+ # value="Default",
105
+ # label="Search Type",
106
+ # info="Default: Giống nhất | MMR: Đa dạng kết quả"
107
+ # )
108
+
109
+ # # Component mới: Slider cho MMR
110
+ # # visible=False mặc định, sẽ hiện khi chọn MMR (nếu bạn muốn làm xịn, ở đây để luôn True cho dễ)
111
+ # slider_lambda = gr.Slider(
112
+ # minimum=0.0, maximum=1.0, value=0.5, step=0.1,
113
+ # label="Độ đa dạng (Lambda)",
114
+ # info="1.0 = Chính xác nhất (như Default), 0.0 = Đa dạng nhất"
115
+ # )
116
+
117
+ # slider_k = gr.Slider(minimum=1, maximum=20, value=3, step=1, label="Số lượng kết quả (Top K)")
118
+
119
+ # btn_search = gr.Button("🔍 Tìm kiếm ngay", variant="primary")
120
+
121
+ # with gr.Column(scale=2):
122
+ # out_search = gr.Code(label="Kết quả trả về (JSON)", language="json")
123
+
124
+ # # Cập nhật inputs truyền vào hàm search
125
+ # btn_search.click(
126
+ # search_document,
127
+ # inputs=[txt_query, slider_k, radio_type, slider_lambda],
128
+ # outputs=out_search
129
+ # )
130
import gradio as gr

# Minimal Gradio Space entry point: the full search UI above is disabled,
# leaving only a placeholder page so the Space boots cleanly.
with gr.Blocks(title="Qdrant Vector DB Demo") as demo:
    gr.Markdown("# 🚀 Demo Qdrant Vector Search")
    gr.Markdown("Tool test nhanh khả năng thêm dữ liệu và tìm kiếm ngữ nghĩa (Semantic Search).")

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ fast-json-repair>=0.2.0
2
+ fastembed>=0.3.0
3
+ spaces
4
+ gradio
src/utils/embed_manager.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import uuid
3
+ import base64
4
+ import json
5
+
6
+ from bs4 import BeautifulSoup
7
+ from langchain_core.documents import Document
8
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
9
+
10
def uuid64():
    """Return a random UUID4 encoded as a 22-character URL-safe base64 string.

    The trailing ``=`` padding is stripped, so the result is safe for use in
    URLs and as a compact document identifier.
    """
    raw_bytes = uuid.uuid4().bytes
    token = base64.urlsafe_b64encode(raw_bytes).rstrip(b"=")
    return token.decode("ascii")
14
+
15
async def clean_text(text: str) -> str:
    """Normalize scraped article text.

    Strips WordPress ``[caption]`` shortcodes, re-inserts the missing space
    after a period glued to a following capitalized word, and collapses all
    whitespace runs to single spaces. Returns "" for falsy input.
    """
    if not text:
        return ""

    # Remove whole [caption]...[/caption] blocks, then any stray caption tags.
    text = re.sub(r'\[caption[^\]]*\].*?\[/caption\]', '', text, flags=re.IGNORECASE | re.DOTALL)
    text = re.sub(r'\[/?caption[^\]]*\]', '', text, flags=re.IGNORECASE)
    # Fix "end.Next" -> "end. Next" (partial uppercase class first, then the
    # full Vietnamese uppercase range as a fallback).
    text = re.sub(r'\.(?=[A-ZĂÂÁÀẢÃẠ...])', '. ', text)
    text = re.sub(r'\.([A-ZÀ-Ỹ])', r'. \1', text)
    # Collapse whitespace and trim the ends.
    return re.sub(r'\s+', ' ', text).strip()
26
+
27
async def load_json_data(file_path):
    """Read *file_path* as UTF-8 JSON and return the parsed payload.

    Prints progress messages to keep the CLI output of the pipeline unchanged.
    """
    print(f"Loading data from {file_path}...")
    with open(file_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    print(f"[OK] Loaded {len(payload)} entries")
    return payload
34
+
35
async def create_qdrant_collection(client, collection_name: str, vector_size: int):
    """Create a Qdrant collection with cosine distance if it does not exist.

    Bug fix: the previous version's ``else`` branch called
    ``create_collection`` again when the collection ALREADY existed, which
    fails (or clobbers data) on every run after the first. Creation is now a
    no-op for an existing collection.

    Args:
        client: A connected ``QdrantClient``.
        collection_name: Target collection name.
        vector_size: Dimensionality of the dense vectors.

    Raises:
        Exception: re-raised after logging if collection creation fails.
    """
    from qdrant_client.http.models import VectorParams, Distance

    if client.collection_exists(collection_name):
        # Nothing to do — keep the existing collection and its data.
        print(f"Collection '{collection_name}' already exists. Skipping creation.")
        return

    try:
        print(f"Collection '{collection_name}' does not exist. Creating...")
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
    except Exception as e:
        print(f"Error creating collection '{collection_name}': {e}")
        raise e
54
+
55
+
56
async def init_qdrant_client(endpoint: str, api_key: str):
    """Build and return a ``QdrantClient`` for *endpoint*.

    Logs and re-raises any initialization error so callers can abort early.
    """
    from qdrant_client import QdrantClient

    try:
        qdrant = QdrantClient(url=endpoint, api_key=api_key)
        print("Qdrant client initialized successfully.")
        return qdrant
    except Exception as exc:
        print(f"Error initializing Qdrant client: {exc}")
        raise exc
69
+
70
+
71
async def parse_html_to_sections(html: str, data_json):
    """Split an article's HTML body into langchain ``Document`` objects.

    The first ``<p>`` becomes an "intro" document; each ``<h2>`` plus the
    content up to the next ``<h2>`` becomes one parent-section document whose
    full text is also stored under ``parent_chunking`` for later child
    chunking.

    Bug fix: removed a stray ``cleaned_text = await clean_text(first_p...)``
    line inside the h2 loop — it was unused, and crashed when the page had no
    ``<p>`` (``first_p`` is ``None``) or after ``first_p.decompose()``.

    Args:
        html: Raw HTML of the article body.
        data_json: Source record; must provide "site", "url" and
            ``event_time["$date"]``.

    Returns:
        list[Document]: intro document (if present) followed by one document
        per ``<h2>`` section.
    """
    soup = BeautifulSoup(html, "html.parser")
    documents = []

    # --- 1. First <p> becomes the intro document ---
    first_p = soup.find("p")
    if first_p:
        cleaned_text = await clean_text(first_p.get_text(separator=" ", strip=True))
        documents.append(
            Document(
                page_content=cleaned_text,
                metadata={
                    "site": data_json["site"],
                    "url": data_json["url"],
                    "date_created": data_json["event_time"]["$date"],
                    "document_id": uuid64(),
                    "type": "intro",
                },
            )
        )
        first_p.decompose()  # remove so the intro is not duplicated in a section

    # --- 2. One document per <h2> section ---
    for h2 in soup.find_all("h2"):
        header = await clean_text(h2.get_text(separator=" ", strip=True))
        contents = []
        for sib in h2.next_siblings:
            # Stop at the next section header.
            if getattr(sib, "name", None) == "h2":
                break
            if hasattr(sib, "get_text"):
                text = await clean_text(sib.get_text(separator=" ", strip=True))
                if text:
                    contents.append(text)

        parent_text = header + "\n" + "\n".join(contents)
        documents.append(
            Document(
                page_content=parent_text,
                metadata={
                    "site": data_json["site"],
                    "url": data_json["url"],
                    "date_created": data_json["event_time"]["$date"],
                    "header": header,
                    "parent_id": uuid64(),
                    "parent_chunking": parent_text,
                },
            )
        )

    return documents
126
+
127
+
128
async def chunk_documents(docs, chunk_size=500, chunk_overlap=50):
    """Split parent-section Documents into overlapping child chunks.

    Documents marked ``type == "intro"`` are passed through untouched; every
    chunk of a headed section is prefixed with its section header and given a
    fresh ``document_id`` while inheriting the parent metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )

    out = []
    for doc in docs:
        # Only headed sections are chunked; intros are kept whole.
        if doc.metadata.get("type") == "intro":
            out.append(doc)
            continue

        pieces = splitter.split_text(doc.page_content)
        print("chunk=", len(pieces))
        header = doc.metadata.get("header")

        for piece in pieces:
            out.append(
                Document(
                    page_content=header + "\n " + piece,
                    metadata={**doc.metadata, "document_id": uuid64()},
                )
            )

    return out
src/utils/embeddings.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import asyncio
4
+ import re
5
+ import uuid
6
+ import base64
7
+ import json
8
+
9
+ from bs4 import BeautifulSoup
10
+ from typing import List, Dict, Tuple, Optional, Any, Protocol, Literal
11
+ from langchain_core.documents import Document
12
+ from fastembed_manager import add_custom_embedding_model
13
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
14
+ from tqdm.asyncio import tqdm_asyncio
15
+ from asyncio import Semaphore
16
+ from fastembed_manager import add_custom_embedding_model
17
+ sem = Semaphore(10)
18
+
19
def resolve_user_path(path: str) -> str:
    """Expand a leading ``~`` or ``~user`` in *path* to the home directory."""
    expanded = os.path.expanduser(path)
    return expanded
21
+
22
def load_json_data(file_path: str) -> List[Dict[str, Any]]:
    """Read *file_path* as UTF-8 JSON and return the parsed payload.

    Fix: dropped the redundant function-local ``import json`` — the module is
    already imported at the top of this file.

    Args:
        file_path: Path to a JSON file containing a list of record dicts.

    Returns:
        The deserialized JSON content.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
27
+
28
+
29
def uuid64():
    """Return a random UUID4 as a 22-char URL-safe base64 string (no padding)."""
    random_id = uuid.uuid4()
    compact = base64.urlsafe_b64encode(random_id.bytes).rstrip(b"=")
    return compact.decode("ascii")
33
+
34
def clean_text(text: str) -> str:
    """Normalize scraped article text for chunking and embedding.

    Steps (in order):
      1. Delete entire WordPress ``[caption]...[/caption]`` blocks, including
         their inner content (``re.DOTALL`` so the match spans newlines).
      2. Delete any stray, unmatched caption shortcode tags left behind.
      3. Re-insert the missing space when a period is glued to a following
         capitalized word (e.g. "tiêu biến.Ống" -> "tiêu biến. Ống") —
         a narrow uppercase class first, then the full Vietnamese range.
      4. Collapse all whitespace runs to single spaces and trim.

    Returns "" for falsy input.
    """
    if not text:
        return ""

    text = re.sub(r'\[caption[^\]]*\].*?\[/caption\]', '', text, flags=re.IGNORECASE | re.DOTALL)
    text = re.sub(r'\[/?caption[^\]]*\]', '', text, flags=re.IGNORECASE)
    text = re.sub(r'\.(?=[A-ZĂÂÁÀẢÃẠ...])', '. ', text)
    text = re.sub(r'\.([A-ZÀ-Ỹ])', r'. \1', text)
    return re.sub(r'\s+', ' ', text).strip()
57
+
58
def parse_html_to_sections(html: str, data_json):
    """Split an article's HTML body into langchain ``Document`` objects.

    The first ``<p>`` becomes an "intro" document; each ``<h2>`` together with
    the content up to the next ``<h2>`` becomes a parent-section document,
    carrying the full section text under ``parent_chunking`` for later child
    chunking.
    """
    soup = BeautifulSoup(html, "html.parser")
    documents = []

    # Intro: the first paragraph, removed afterwards to avoid duplication.
    first_p = soup.find("p")
    if first_p:
        intro_text = clean_text(first_p.get_text(separator=" ", strip=True))
        documents.append(
            Document(
                page_content=intro_text,
                metadata={
                    "site": data_json["site"],
                    "url": data_json["url"],
                    "date_created": data_json["event_time"]["$date"],
                    "document_id": uuid64(),
                    "type": "intro",
                },
            )
        )
        first_p.decompose()

    # Sections: everything between one <h2> and the next.
    for h2 in soup.find_all("h2"):
        header = clean_text(h2.get_text(separator=" ", strip=True))
        body_parts = []
        for sib in h2.next_siblings:
            if getattr(sib, "name", None) == "h2":
                break
            if hasattr(sib, "get_text"):
                part = clean_text(sib.get_text(separator=" ", strip=True))
                if part:
                    body_parts.append(part)

        section_text = header + "\n" + "\n".join(body_parts)
        documents.append(
            Document(
                page_content=section_text,
                metadata={
                    "site": data_json["site"],
                    "url": data_json["url"],
                    "date_created": data_json["event_time"]["$date"],
                    "header": header,
                    "parent_id": uuid64(),
                    "parent_chunking": section_text,
                },
            )
        )

    return documents
110
+
111
+
112
def chunk_documents(docs, chunk_size=500, chunk_overlap=50):
    """Split parent-section Documents into overlapping child chunks.

    Documents tagged ``type == "intro"`` pass through unchanged; every chunk
    of a headed section is prefixed with its section header and assigned a new
    ``document_id`` while inheriting the rest of the parent metadata.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )

    children = []
    for doc in docs:
        # Only headed sections get chunked; intros are kept whole.
        if doc.metadata.get("type") == "intro":
            children.append(doc)
            continue

        header = doc.metadata.get("header")
        for piece in splitter.split_text(doc.page_content):
            children.append(
                Document(
                    page_content=header + "\n " + piece,
                    metadata={**doc.metadata, "document_id": uuid64()},
                )
            )

    return children
146
+
147
async def process_single_data(data_json) -> List[Document]:
    """Parse one scraped record's HTML body into chunked Documents.

    Runs the CPU-bound parsing/chunking in worker threads; concurrency is
    capped by the module-level semaphore ``sem``.

    Fix: return annotation corrected from ``Document`` to ``List[Document]``
    — the function returns the list produced by ``chunk_documents``.

    Raises:
        ValueError: if the record has no non-empty "body" field.
    """
    async with sem:
        html_text = data_json.get("body", "")
        if not html_text:
            raise ValueError("No 'body' field in JSON data")
        sections = await asyncio.to_thread(parse_html_to_sections, html_text, data_json)
        return await asyncio.to_thread(chunk_documents, sections)
155
+
156
async def processing_json_file(file_path: str) -> List[Document]:
    """Load a scraped-articles JSON dump and return every chunked Document.

    Each record is processed concurrently (bounded by the module semaphore)
    with a tqdm progress bar; the per-record chunk lists are flattened.
    """
    print("Loading JSON data from:", file_path)
    records = load_json_data(file_path)

    tasks = [process_single_data(record) for record in records]
    per_record_chunks = await tqdm_asyncio.gather(*tasks)

    return [doc for chunk_list in per_record_chunks for doc in chunk_list]
166
+
167
def embedding_documents(documents: List[Document]):
    """Embed *documents* (dense + BM25 sparse) and upsert them into Qdrant.

    Qdrant credentials are read from the ``QDRANT_API_KEY`` and
    ``QDRANT_ENDPOINT`` environment variables (loaded via ``.env``).

    Fix: the populated store is now returned instead of being discarded —
    callers that previously ignored the (``None``) return are unaffected.

    Returns:
        QdrantVectorStore: the store holding the uploaded documents.
    """
    from fastembed_sparse import FastEmbedSparse
    from qdrant_vector_store import QdrantVectorStore, RetrievalMode
    from dotenv import load_dotenv

    load_dotenv()
    sparse_embeddings = FastEmbedSparse(model_name="Qdrant/BM25")
    embed = add_custom_embedding_model(
        model_name="models/Vietnamese_Embedding_OnnX_Quantized",
        source_model="Mint1456/Vietnamese_Embedding_OnnX_Quantized",
        dim=1024,
        source_file="model.onnx",
    )
    qdrant_api_key = os.getenv("QDRANT_API_KEY")
    qdrant_endpoint = os.getenv("QDRANT_ENDPOINT")

    store = QdrantVectorStore.from_documents(
        documents=documents,
        embedding=embed,
        sparse_embedding=sparse_embeddings,
        api_key=qdrant_api_key,
        url=qdrant_endpoint,
        collection_name="test_collection",
        retrieval_mode=RetrievalMode.HYBRID,
        force_recreate=False,
    )
    return store
192
+
193
if __name__ == "__main__":
    # Entry point: process the scraped dump, then push embeddings to Qdrant.
    data_path = r"D:\Project\Data\flask_chatai.web_data 1.json"
    documents = asyncio.run(processing_json_file(data_path))
    embedding_documents(documents)
src/utils/fastembed_manager.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastembed import TextEmbedding
2
+ from fastembed.common.model_description import PoolingType, ModelSource
3
+ from huggingface_hub import snapshot_download
4
+ import time
5
+
6
+ # def download_model_from_hf(model_name: str, save_path: str):
7
+ # try:
8
+ # snapshot_download(
9
+ # repo_id=model_name,
10
+ # local_dir=save_path,
11
+ # allow_patterns=["onnx/*"],
12
+ # local_dir_use_symlinks=False,
13
+ # )
14
+ # except Exception as e:
15
+ # print(f"Error downloading model from Hugging Face: {e}")
16
+ # raise e
17
+
18
+
19
def add_custom_embedding_model(
    model_name: str, source_model: str, source_file: str, dim: int, from_hf: bool = True
):
    """Register a custom ONNX model with FastEmbed and return a TextEmbedding.

    Fix: the two branches were near-identical copies differing only in which
    ``ModelSource`` field was set; the duplication is collapsed while keeping
    every log and error message byte-for-byte identical.

    Args:
        model_name: Name the model is registered under locally.
        source_model: HF repo id (when ``from_hf``) or a URL to model files
            (can point at private storage).
        source_file: Path of the ONNX file within the source, e.g.
            ``onnx/model_O4.onnx`` for an alternative quantization.
        dim: Output embedding dimension.
        from_hf: Whether *source_model* is a Hugging Face repo id.

    Raises:
        Exception: re-raised after logging if registration or loading fails.
    """
    # Only the ModelSource field and the log wording differ per branch.
    source = ModelSource(hf=source_model) if from_hf else ModelSource(url=source_model)
    origin = "Hugging Face" if from_hf else "local file"

    try:
        TextEmbedding.add_custom_model(
            model=model_name,
            pooling=PoolingType.MEAN,
            normalization=True,
            sources=source,
            dim=dim,
            model_file=source_file,
        )
        print(f"Successfully added model '{model_name}' from {origin}.")
        return TextEmbedding(model_name=model_name)
    except Exception as e:
        print(f"Error adding model from {origin}: {e}")
        raise e
53
+
54
if __name__ == "__main__":
    # Smoke test: register the quantized Vietnamese embedding model from the
    # Hub and time a single embedding call.
    TextEmbedding.add_custom_model(
        model="Mint1456/Vietnamese_Embedding_OnnX_Quantized",
        pooling=PoolingType.MEAN,
        normalization=True,
        sources=ModelSource(hf="Mint1456/Vietnamese_Embedding_OnnX_Quantized"),
        dim=1024,
        model_file="model.onnx",
    )
    model = TextEmbedding(model_name="Mint1456/Vietnamese_Embedding_OnnX_Quantized")

    start = time.perf_counter()
    embeddings = list(model.embed("define artificial intelligence"))
    print(f"len embeding {len(embeddings[0])}, time taken: {time.perf_counter() - start} seconds")
src/utils/fastembed_sparse.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from abc import ABC, abstractmethod
6
+ from langchain_core.runnables.config import run_in_executor
7
+ from pydantic import BaseModel, Field
8
+
9
+ if TYPE_CHECKING:
10
+ from collections.abc import Sequence
11
+
12
class SparseVector(BaseModel, extra="forbid"):
    """A sparse vector as parallel index/value lists (extra fields rejected)."""

    # Both fields are required; indices must be unique and values must be the
    # same length as indices (enforced by the descriptions' contract, not code).
    indices: list[int] = Field(..., description="indices must be unique")
    values: list[float] = Field(
        ..., description="values and indices must be the same length"
    )
19
+
20
+
21
class SparseEmbeddings(ABC):
    """Abstract interface for sparse embedding models usable with Qdrant."""

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Embed a batch of documents for indexing."""

    @abstractmethod
    def embed_query(self, text: str) -> SparseVector:
        """Embed a single query string."""

    async def aembed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Async wrapper around :meth:`embed_documents` (runs in an executor)."""
        return await run_in_executor(None, self.embed_documents, texts)

    async def aembed_query(self, text: str) -> SparseVector:
        """Async wrapper around :meth:`embed_query` (runs in an executor)."""
        return await run_in_executor(None, self.embed_query, text)
39
+
40
class FastEmbedSparse(SparseEmbeddings):
    """Sparse encoder backed by FastEmbed's ``SparseTextEmbedding``.

    Uses [FastEmbed](https://qdrant.github.io/fastembed/) for sparse text
    embeddings; see the Qdrant docs for the list of supported models.
    """

    def __init__(
        self,
        model_name: str = "Qdrant/bm25",
        batch_size: int = 256,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[Any] | None = None,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the FastEmbed-backed sparse encoder.

        Args:
            model_name: FastEmbed sparse model to load.
            batch_size: Batch size for encoding.
            cache_dir: Model cache directory (or ``FASTEMBED_CACHE_PATH`` env var).
            threads: Threads for the onnxruntime session.
            providers: ONNX execution providers.
            parallel: ``>1`` → data-parallel encoding; ``0`` → all cores;
                ``None`` → default onnxruntime threading.
            kwargs: Extra options forwarded to ``fastembed.SparseTextEmbedding``.

        Raises:
            ValueError: if the ``fastembed`` package is not installed.
        """
        try:
            from fastembed import (  # type: ignore[import-not-found] # noqa: PLC0415
                SparseTextEmbedding,
            )
        except ImportError as err:
            msg = (
                "The 'fastembed' package is not installed. "
                "Please install it with "
                "`pip install fastembed` or `pip install fastembed-gpu`."
            )
            raise ValueError(msg) from err

        self._batch_size = batch_size
        self._parallel = parallel
        self._model = SparseTextEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            **kwargs,
        )

    def embed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Embed *texts*, yielding one :class:`SparseVector` per input."""
        embedded = self._model.embed(
            texts, batch_size=self._batch_size, parallel=self._parallel
        )
        return [
            SparseVector(indices=item.indices.tolist(), values=item.values.tolist())
            for item in embedded
        ]

    def embed_query(self, text: str) -> SparseVector:
        """Embed a single query string as a :class:`SparseVector`."""
        first = next(self._model.embed(text))
        return SparseVector(
            indices=first.indices.tolist(), values=first.values.tolist()
        )
src/utils/qdrant_vector_store.py ADDED
@@ -0,0 +1,1112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import uuid
4
+ import numpy as np
5
+ from collections.abc import Callable
6
+ from enum import Enum
7
+ from itertools import islice
8
+ from operator import itemgetter
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ )
13
+
14
+ from langchain_core.documents import Document
15
+ from fastembed import TextEmbedding
16
+ from langchain_core.vectorstores import VectorStore
17
+ from qdrant_client import QdrantClient, models
18
+
19
+ if TYPE_CHECKING:
20
+ from collections.abc import Generator, Iterable, Sequence
21
+
22
+ from qdrant_sparse_embeddings import SparseEmbeddings
23
+
24
+
25
class QdrantVectorStoreError(Exception):
    """Base exception raised for `QdrantVectorStore`-related failures."""
27
+
28
+
29
class RetrievalMode(str, Enum):
    """How vectors are retrieved from Qdrant: dense only, sparse only, or both."""

    DENSE = "dense"    # dense (neural) vectors only
    SPARSE = "sparse"  # sparse (e.g. BM25) vectors only
    HYBRID = "hybrid"  # fuse dense and sparse results
35
+
36
+
37
class QdrantVectorStore(VectorStore):
    """Qdrant-backed `VectorStore` supporting dense, sparse and hybrid retrieval."""

    # Default payload key under which the raw text is stored.
    CONTENT_KEY: str = "page_content"
    # Default payload key under which the metadata dict is stored.
    METADATA_KEY: str = "metadata"
    # "" selects Qdrant's single unnamed (default) dense vector.
    VECTOR_NAME: str = ""
    # NOTE(review): "test_collection" reads like a collection name or a
    # leftover test value, not a sparse-vector name — confirm against the
    # collections this code must reuse before changing the default.
    SPARSE_VECTOR_NAME: str = "test_collection"

    def __init__(
        self,
        client: QdrantClient,
        collection_name: str,
        embedding: TextEmbedding | None = None,
        retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
        vector_name: str = VECTOR_NAME,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        distance: models.Distance = models.Distance.COSINE,
        sparse_embedding: SparseEmbeddings | None = None,
        sparse_vector_name: str = SPARSE_VECTOR_NAME,
        validate_embeddings: bool = True,
        validate_collection_config: bool = True,
    ) -> None:
        """Initialize a new instance of `QdrantVectorStore`.

        ```python
        qdrant = QdrantVectorStore(
            client=client,
            collection_name="my-collection",
            embedding=OpenAIEmbeddings(),
            retrieval_mode=RetrievalMode.HYBRID,
            sparse_embedding=FastEmbedSparse(),
        )
        ```
        """
        # Fail fast when the configured embeddings do not match the mode.
        if validate_embeddings:
            self._validate_embeddings(retrieval_mode, embedding, sparse_embedding)

        # Check the existing collection's vector config against the arguments.
        if validate_collection_config:
            self._validate_collection_config(
                client,
                collection_name,
                retrieval_mode,
                vector_name,
                sparse_vector_name,
                distance,
                embedding,
            )

        self._client = client
        self.collection_name = collection_name
        self._embeddings = embedding
        self.retrieval_mode = retrieval_mode
        self.vector_name = vector_name
        self.content_payload_key = content_payload_key
        self.metadata_payload_key = metadata_payload_key
        self.distance = distance
        self._sparse_embeddings = sparse_embedding
        self.sparse_vector_name = sparse_vector_name
94
+
95
+ @property
96
+ def client(self) -> QdrantClient:
97
+ """Get the Qdrant client instance that is being used.
98
+
99
+ Returns:
100
+ QdrantClient: An instance of `QdrantClient`.
101
+
102
+ """
103
+ return self._client
104
+
105
+ @property
106
+ def embeddings(self) -> TextEmbedding | None:
107
+ """Get the dense embeddings instance that is being used.
108
+
109
+ Returns:
110
+ Embeddings: An instance of `TextEmbedding`, or None for SPARSE mode.
111
+
112
+ """
113
+ return self._embeddings
114
+
115
+ def _get_retriever_tags(self) -> list[str]:
116
+ """Get tags for retriever.
117
+
118
+ Override the base class method to handle SPARSE mode where embeddings can be
119
+ None. In SPARSE mode, embeddings is None, so we don't include embeddings class
120
+ name in tags. In DENSE/HYBRID modes, embeddings is not None, so we include
121
+ embeddings class name.
122
+ """
123
+ tags = [self.__class__.__name__]
124
+
125
+ # Handle different retrieval modes
126
+ if self.retrieval_mode == RetrievalMode.SPARSE:
127
+ # SPARSE mode: no dense embeddings, so no embeddings class name in tags
128
+ pass
129
+ # DENSE/HYBRID modes: include embeddings class name if available
130
+ elif self.embeddings is not None:
131
+ tags.append(self.embeddings.__class__.__name__)
132
+
133
+ return tags
134
+
135
+ def _require_embeddings(self, operation: str) -> TextEmbedding:
136
+ """Require embeddings for operations that need them.
137
+
138
+ Args:
139
+ operation: Description of the operation requiring embeddings.
140
+
141
+ Returns:
142
+ The embeddings instance.
143
+
144
+ Raises:
145
+ ValueError: If embeddings are None and required for the operation.
146
+ """
147
+ if self.embeddings is None:
148
+ msg = f"Embeddings are required for {operation}"
149
+ raise ValueError(msg)
150
+ return self.embeddings
151
+
152
+ @property
153
+ def sparse_embeddings(self) -> SparseEmbeddings:
154
+ """Get the sparse embeddings instance that is being used.
155
+
156
+ Raises:
157
+ ValueError: If sparse embeddings are `None`.
158
+
159
+ Returns:
160
+ SparseEmbeddings: An instance of `SparseEmbeddings`.
161
+
162
+ """
163
+ if self._sparse_embeddings is None:
164
+ msg = (
165
+ "Sparse embeddings are `None`. "
166
+ "Please set using the `sparse_embedding` parameter."
167
+ )
168
+ raise ValueError(msg)
169
+ return self._sparse_embeddings
170
+
171
    @classmethod
    def from_texts(
        cls: type[QdrantVectorStore],
        texts: list[str],
        embedding: TextEmbedding | None = None,
        metadatas: list[dict] | None = None,
        ids: Sequence[str | int] | None = None,
        collection_name: str | None = None,
        location: str | None = None,
        url: str | None = None,
        port: int | None = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: bool | None = None,
        api_key: str | None = None,
        prefix: str | None = None,
        timeout: int | None = None,
        host: str | None = None,
        path: str | None = None,
        distance: models.Distance = models.Distance.COSINE,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        vector_name: str = VECTOR_NAME,
        retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
        sparse_embedding: SparseEmbeddings | None = None,
        sparse_vector_name: str = SPARSE_VECTOR_NAME,
        collection_create_options: dict[str, Any] | None = None,
        vector_params: dict[str, Any] | None = None,
        sparse_vector_params: dict[str, Any] | None = None,
        batch_size: int = 64,
        force_recreate: bool = False,
        validate_embeddings: bool = True,
        validate_collection_config: bool = True,
        **kwargs: Any,
    ) -> QdrantVectorStore:
        """Construct an instance of `QdrantVectorStore` from a list of texts.

        Connects to Qdrant with the given connection arguments, creates (or
        reuses) the target collection via `construct_instance`, then embeds and
        upserts `texts` in batches of `batch_size`.

        Returns:
            A ready-to-query `QdrantVectorStore` containing `texts`.
        """
        # Normalize the optional dict arguments so they can be mutated safely.
        if sparse_vector_params is None:
            sparse_vector_params = {}
        if vector_params is None:
            vector_params = {}
        if collection_create_options is None:
            collection_create_options = {}
        # Connection arguments are forwarded verbatim to QdrantClient.
        client_options = {
            "location": location,
            "url": url,
            "port": port,
            "grpc_port": grpc_port,
            "prefer_grpc": prefer_grpc,
            "https": https,
            "api_key": api_key,
            "prefix": prefix,
            "timeout": timeout,
            "host": host,
            "path": path,
            **kwargs,
        }

        qdrant = cls.construct_instance(
            embedding,
            retrieval_mode,
            sparse_embedding,
            client_options,
            collection_name,
            distance,
            content_payload_key,
            metadata_payload_key,
            vector_name,
            sparse_vector_name,
            force_recreate,
            collection_create_options,
            vector_params,
            sparse_vector_params,
            validate_embeddings,
            validate_collection_config,
        )
        qdrant.add_texts(texts, metadatas, ids, batch_size)
        return qdrant
250
+
251
+ def add_documents(
252
+ self,
253
+ documents: Sequence[Document],
254
+ ids: Sequence[str | int] | None = None,
255
+ batch_size: int = 64,
256
+ **kwargs: Any,
257
+ ) -> list[str | int]:
258
+
259
+ texts = [doc.page_content for doc in documents]
260
+
261
+ metadatas = [doc.metadata if doc.metadata is not None else {} for doc in documents]
262
+
263
+ return self.add_texts(
264
+ texts=texts,
265
+ metadatas=metadatas,
266
+ ids=ids,
267
+ batch_size=batch_size,
268
+ **kwargs,
269
+ )
270
+
271
    @classmethod
    def from_documents(
        cls,
        documents: list[Document],
        embedding: TextEmbedding,
        **kwargs: Any,
    ):
        """Return `VectorStore` initialized from documents and embeddings.

        Args:
            documents: List of `Document` objects to add to the `VectorStore`.
            embedding: Embedding function to use.
            **kwargs: Additional keyword arguments.

        Returns:
            `VectorStore` initialized from documents and embeddings.
        """
        texts = [d.page_content for d in documents]
        metadatas = [d.metadata for d in documents]

        # Derive point ids from the documents' "chunk_id" metadata unless the
        # caller supplied explicit ids.
        if "ids" not in kwargs:
            ids = [doc.metadata.get("chunk_id") for doc in documents]

            # If there's at least one valid ID, we'll assume that IDs
            # should be used.
            # NOTE(review): if only *some* documents carry "chunk_id", this
            # forwards a list containing None entries as point ids — confirm
            # that callers guarantee all-or-none, otherwise the upsert fails.
            if any(ids):
                kwargs["ids"] = ids

        return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)
300
+
301
    @classmethod
    def from_existing_collection(
        cls: type[QdrantVectorStore],
        collection_name: str,
        embedding: TextEmbedding | None = None,
        retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
        location: str | None = None,
        url: str | None = None,
        port: int | None = 6333,
        grpc_port: int = 6334,
        prefer_grpc: bool = False,
        https: bool | None = None,
        api_key: str | None = None,
        prefix: str | None = None,
        timeout: int | None = None,
        host: str | None = None,
        path: str | None = None,
        distance: models.Distance = models.Distance.COSINE,
        content_payload_key: str = CONTENT_KEY,
        metadata_payload_key: str = METADATA_KEY,
        vector_name: str = VECTOR_NAME,
        sparse_vector_name: str = SPARSE_VECTOR_NAME,
        sparse_embedding: SparseEmbeddings | None = None,
        validate_embeddings: bool = True,
        validate_collection_config: bool = True,
        **kwargs: Any,
    ) -> QdrantVectorStore:
        """Construct `QdrantVectorStore` from existing collection without adding data.

        Unlike `from_texts`, this never creates or recreates a collection; the
        constructor's validation (if enabled) checks the existing collection.

        Returns:
            QdrantVectorStore: A new instance of `QdrantVectorStore`.
        """
        # Connection arguments are forwarded verbatim to QdrantClient.
        client = QdrantClient(
            location=location,
            url=url,
            port=port,
            grpc_port=grpc_port,
            prefer_grpc=prefer_grpc,
            https=https,
            api_key=api_key,
            prefix=prefix,
            timeout=timeout,
            host=host,
            path=path,
            **kwargs,
        )

        return cls(
            client=client,
            collection_name=collection_name,
            embedding=embedding,
            retrieval_mode=retrieval_mode,
            content_payload_key=content_payload_key,
            metadata_payload_key=metadata_payload_key,
            distance=distance,
            vector_name=vector_name,
            sparse_embedding=sparse_embedding,
            sparse_vector_name=sparse_vector_name,
            validate_embeddings=validate_embeddings,
            validate_collection_config=validate_collection_config,
        )
362
+
363
+ def add_texts( # type: ignore[override]
364
+ self,
365
+ texts: Iterable[str],
366
+ metadatas: list[dict] | None = None,
367
+ ids: Sequence[str | int] | None = None,
368
+ batch_size: int = 64,
369
+ **kwargs: Any,
370
+ ) -> list[str | int]:
371
+ """Add texts with embeddings to the `VectorStore`.
372
+
373
+ Returns:
374
+ List of ids from adding the texts into the `VectorStore`.
375
+
376
+ """
377
+ added_ids = []
378
+ for batch_ids, points in self._generate_batches(
379
+ texts, metadatas, ids, batch_size
380
+ ):
381
+ self.client.upsert(
382
+ collection_name=self.collection_name, points=points, **kwargs
383
+ )
384
+ added_ids.extend(batch_ids)
385
+
386
+ return added_ids
387
+
388
+ def similarity_search(
389
+ self,
390
+ query: str,
391
+ k: int = 4,
392
+ filter: models.Filter | None = None,
393
+ search_params: models.SearchParams | None = None,
394
+ offset: int = 0,
395
+ score_threshold: float | None = None,
396
+ consistency: models.ReadConsistency | None = None,
397
+ hybrid_fusion: models.FusionQuery | None = None,
398
+ **kwargs: Any,
399
+ ) -> list[Document]:
400
+ """Return docs most similar to query.
401
+
402
+ Returns:
403
+ List of `Document` objects most similar to the query.
404
+
405
+ """
406
+ results = self.similarity_search_with_score(
407
+ query,
408
+ k,
409
+ filter=filter,
410
+ search_params=search_params,
411
+ offset=offset,
412
+ score_threshold=score_threshold,
413
+ consistency=consistency,
414
+ hybrid_fusion=hybrid_fusion,
415
+ **kwargs,
416
+ )
417
+ return list(map(itemgetter(0), results))
418
+
419
    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: models.Filter | None = None,
        search_params: models.SearchParams | None = None,
        offset: int = 0,
        score_threshold: float | None = None,
        consistency: models.ReadConsistency | None = None,
        hybrid_fusion: models.FusionQuery | None = None,
        **kwargs: Any,
    ) -> list[tuple[Document, float]]:
        """Return docs most similar to query.

        Builds a Qdrant `query_points` request whose shape depends on
        `retrieval_mode`: a dense vector query, a sparse vector query, or a
        hybrid prefetch pair fused server-side (RRF unless `hybrid_fusion`
        overrides it).

        Returns:
            List of documents most similar to the query text and distance for each.

        Raises:
            ValueError: If `retrieval_mode` is not DENSE/SPARSE/HYBRID, or if
                dense embeddings are required but missing.
        """
        # Options shared by all three request shapes below.
        query_options = {
            "collection_name": self.collection_name,
            "query_filter": filter,
            "search_params": search_params,
            "limit": k,
            "offset": offset,
            "with_payload": True,
            "with_vectors": False,
            "score_threshold": score_threshold,
            "consistency": consistency,
            **kwargs,
        }
        if self.retrieval_mode == RetrievalMode.DENSE:
            embeddings = self._require_embeddings("DENSE mode")
            # `embed` yields embeddings lazily; take the single query vector.
            query_dense_embedding = list(embeddings.embed(query))[0]
            results = self.client.query_points(
                query=query_dense_embedding,
                using=self.vector_name,
                **query_options,
            ).points

        elif self.retrieval_mode == RetrievalMode.SPARSE:
            query_sparse_embedding = self.sparse_embeddings.embed_query(query)
            results = self.client.query_points(
                query=models.SparseVector(
                    indices=query_sparse_embedding.indices,
                    values=query_sparse_embedding.values,
                ),
                using=self.sparse_vector_name,
                **query_options,
            ).points

        elif self.retrieval_mode == RetrievalMode.HYBRID:
            embeddings = self._require_embeddings("HYBRID mode")
            query_dense_embedding = list(embeddings.embed(query))[0]
            query_sparse_embedding = self.sparse_embeddings.embed_query(query)
            # Fetch k candidates per modality, then fuse them server-side.
            results = self.client.query_points(
                prefetch=[
                    models.Prefetch(
                        using=self.vector_name,
                        query=query_dense_embedding,
                        filter=filter,
                        limit=k,
                        params=search_params,
                    ),
                    models.Prefetch(
                        using=self.sparse_vector_name,
                        query=models.SparseVector(
                            indices=query_sparse_embedding.indices,
                            values=query_sparse_embedding.values,
                        ),
                        filter=filter,
                        limit=k,
                        params=search_params,
                    ),
                ],
                query=hybrid_fusion or models.FusionQuery(fusion=models.Fusion.RRF),
                **query_options,
            ).points

        else:
            msg = f"Invalid retrieval mode. {self.retrieval_mode}."
            raise ValueError(msg)
        return [
            (
                self._document_from_point(
                    result,
                    self.collection_name,
                    self.content_payload_key,
                    self.metadata_payload_key,
                ),
                result.score,
            )
            for result in results
        ]
512
+
513
+ def max_marginal_relevance_search(
514
+ self,
515
+ query: str,
516
+ k: int = 4,
517
+ fetch_k: int = 20,
518
+ lambda_mult: float = 0.5,
519
+ filter: models.Filter | None = None,
520
+ search_params: models.SearchParams | None = None,
521
+ score_threshold: float | None = None,
522
+ consistency: models.ReadConsistency | None = None,
523
+ **kwargs: Any,
524
+ ) -> list[Document]:
525
+ """Return docs selected using the maximal marginal relevance with dense vectors.
526
+
527
+ Maximal marginal relevance optimizes for similarity to query AND diversity
528
+ among selected documents.
529
+
530
+ Returns:
531
+ List of `Document` objects selected by maximal marginal relevance.
532
+
533
+ """
534
+ self._validate_collection_for_dense(
535
+ self.client,
536
+ self.collection_name,
537
+ self.vector_name,
538
+ self.distance,
539
+ self.embeddings,
540
+ )
541
+
542
+ embeddings = self._require_embeddings("max_marginal_relevance_search")
543
+ query_embedding = list(embeddings.embed(query))[0]
544
+ return self.max_marginal_relevance_search_by_vector(
545
+ query_embedding,
546
+ k=k,
547
+ fetch_k=fetch_k,
548
+ lambda_mult=lambda_mult,
549
+ filter=filter,
550
+ search_params=search_params,
551
+ score_threshold=score_threshold,
552
+ consistency=consistency,
553
+ **kwargs,
554
+ )
555
+
556
+ def max_marginal_relevance_search_by_vector(
557
+ self,
558
+ embedding: list[float],
559
+ k: int = 4,
560
+ fetch_k: int = 20,
561
+ lambda_mult: float = 0.5,
562
+ filter: models.Filter | None = None, # noqa: A002
563
+ search_params: models.SearchParams | None = None,
564
+ score_threshold: float | None = None,
565
+ consistency: models.ReadConsistency | None = None,
566
+ **kwargs: Any,
567
+ ) -> list[Document]:
568
+ """Return docs selected using the maximal marginal relevance with dense vectors.
569
+
570
+ Maximal marginal relevance optimizes for similarity to query AND diversity
571
+ among selected documents.
572
+
573
+ Returns:
574
+ List of `Document` objects selected by maximal marginal relevance.
575
+
576
+ """
577
+ results = self.max_marginal_relevance_search_with_score_by_vector(
578
+ embedding,
579
+ k=k,
580
+ fetch_k=fetch_k,
581
+ lambda_mult=lambda_mult,
582
+ filter=filter,
583
+ search_params=search_params,
584
+ score_threshold=score_threshold,
585
+ consistency=consistency,
586
+ **kwargs,
587
+ )
588
+ return list(map(itemgetter(0), results))
589
+
590
    def max_marginal_relevance_search_with_score_by_vector(
        self,
        embedding: list[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: models.Filter | None = None,  # noqa: A002
        search_params: models.SearchParams | None = None,
        score_threshold: float | None = None,
        consistency: models.ReadConsistency | None = None,
        **kwargs: Any,
    ) -> list[tuple[Document, float]]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents. Uses Qdrant's server-side MMR: the server
        considers `fetch_k` nearest candidates and returns `k` diverse points.

        Returns:
            List of `Document` objects selected by maximal marginal relevance and
            distance for each.
        """
        # NOTE(review): LangChain's `lambda_mult=1` means *minimum* diversity,
        # while Qdrant's `Mmr.diversity` increases diversity as it grows —
        # confirm whether this should pass `1 - lambda_mult` instead.
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=models.NearestQuery(
                nearest=embedding,
                mmr=models.Mmr(diversity=lambda_mult, candidates_limit=fetch_k),
            ),
            query_filter=filter,
            search_params=search_params,
            limit=k,
            with_payload=True,
            with_vectors=True,
            score_threshold=score_threshold,
            consistency=consistency,
            using=self.vector_name,
            **kwargs,
        ).points

        return [
            (
                self._document_from_point(
                    result,
                    self.collection_name,
                    self.content_payload_key,
                    self.metadata_payload_key,
                ),
                result.score,
            )
            for result in results
        ]
640
+
641
+ def max_marginal_relevance_search_with_score(
642
+ self,
643
+ query: str,
644
+ k: int = 4,
645
+ fetch_k: int = 20,
646
+ lambda_mult: float = 0.5,
647
+ filter: models.Filter | None = None,
648
+ search_params: models.SearchParams | None = None,
649
+ score_threshold: float | None = None,
650
+ consistency: models.ReadConsistency | None = None,
651
+ **kwargs: Any,
652
+ ) -> list[tuple[Document, float]]:
653
+ """Return docs selected using the maximal marginal relevance with dense vectors.
654
+
655
+ Maximal marginal relevance optimizes for similarity to query AND diversity
656
+ among selected documents.
657
+
658
+ Returns:
659
+ List of `Document` objects selected by maximal marginal relevance.
660
+
661
+ """
662
+ self._validate_collection_for_dense(
663
+ self.client,
664
+ self.collection_name,
665
+ self.vector_name,
666
+ self.distance,
667
+ self.embeddings,
668
+ )
669
+
670
+ embeddings = self._require_embeddings("max_marginal_relevance_search")
671
+ query_embedding = list(embeddings.embed(query))[0]
672
+ return self.max_marginal_relevance_search_with_score_by_vector(
673
+ query_embedding,
674
+ k=k,
675
+ fetch_k=fetch_k,
676
+ lambda_mult=lambda_mult,
677
+ filter=filter,
678
+ search_params=search_params,
679
+ score_threshold=score_threshold,
680
+ consistency=consistency,
681
+ **kwargs,
682
+ )
683
+ # TO-DO
684
+ # def delete(
685
+ # self,
686
+ # ids: list[str | int] | None = None,
687
+ # **kwargs: Any,
688
+ # ) -> bool | None:
689
+ # """Delete documents by their ids.
690
+
691
+ # Args:
692
+ # ids: List of ids to delete.
693
+ # **kwargs: Other keyword arguments that subclasses might use.
694
+
695
+ # Returns:
696
+ # True if deletion is successful, `False` otherwise.
697
+
698
+ # """
699
+ # result = self.client.delete(
700
+ # collection_name=self.collection_name,
701
+ # points_selector=ids,
702
+ # )
703
+ # return result.status == models.UpdateStatus.COMPLETED
704
+
705
+
706
+ @classmethod
707
+ def construct_instance(
708
+ cls: type[QdrantVectorStore],
709
+ embedding: TextEmbedding | None = None,
710
+ retrieval_mode: RetrievalMode = RetrievalMode.DENSE,
711
+ sparse_embedding: SparseEmbeddings | None = None,
712
+ client_options: dict[str, Any] | None = None,
713
+ collection_name: str | None = None,
714
+ distance: models.Distance = models.Distance.COSINE,
715
+ content_payload_key: str = CONTENT_KEY,
716
+ metadata_payload_key: str = METADATA_KEY,
717
+ vector_name: str = VECTOR_NAME,
718
+ sparse_vector_name: str = SPARSE_VECTOR_NAME,
719
+ force_recreate: bool = False,
720
+ collection_create_options: dict[str, Any] | None = None,
721
+ vector_params: dict[str, Any] | None = None,
722
+ sparse_vector_params: dict[str, Any] | None = None,
723
+ validate_embeddings: bool = True,
724
+ validate_collection_config: bool = True,
725
+ ) -> QdrantVectorStore:
726
+ if sparse_vector_params is None:
727
+ sparse_vector_params = {}
728
+ if vector_params is None:
729
+ vector_params = {}
730
+ if collection_create_options is None:
731
+ collection_create_options = {}
732
+ if client_options is None:
733
+ client_options = {}
734
+ if validate_embeddings:
735
+ cls._validate_embeddings(retrieval_mode, embedding, sparse_embedding)
736
+ collection_name = collection_name or uuid.uuid4().hex
737
+ client = QdrantClient(**client_options)
738
+
739
+ collection_exists = client.collection_exists(collection_name)
740
+
741
+ if collection_exists and force_recreate:
742
+ client.delete_collection(collection_name)
743
+ collection_exists = False
744
+ if collection_exists:
745
+ if validate_collection_config:
746
+ cls._validate_collection_config(
747
+ client,
748
+ collection_name,
749
+ retrieval_mode,
750
+ vector_name,
751
+ sparse_vector_name,
752
+ distance,
753
+ embedding,
754
+ )
755
+ else:
756
+ vectors_config, sparse_vectors_config = {}, {}
757
+ if retrieval_mode == RetrievalMode.DENSE:
758
+ partial_embeddings = list(embedding.embed(["dummy_text"]))
759
+
760
+ vector_params["size"] = len(partial_embeddings[0])
761
+ vector_params["distance"] = distance
762
+
763
+ vectors_config = {
764
+ vector_name: models.VectorParams(
765
+ **vector_params,
766
+ )
767
+ }
768
+
769
+ elif retrieval_mode == RetrievalMode.SPARSE:
770
+ sparse_vectors_config = {
771
+ sparse_vector_name: models.SparseVectorParams(
772
+ **sparse_vector_params
773
+ )
774
+ }
775
+
776
+ elif retrieval_mode == RetrievalMode.HYBRID:
777
+ partial_embeddings = list(embedding.embed(["dummy_text"]))
778
+
779
+ vector_params["size"] = len(partial_embeddings[0])
780
+ vector_params["distance"] = distance
781
+
782
+ vectors_config = {
783
+ vector_name: models.VectorParams(
784
+ **vector_params,
785
+ )
786
+ }
787
+
788
+ sparse_vectors_config = {
789
+ sparse_vector_name: models.SparseVectorParams(
790
+ **sparse_vector_params
791
+ )
792
+ }
793
+
794
+ collection_create_options["collection_name"] = collection_name
795
+ collection_create_options["vectors_config"] = vectors_config
796
+ collection_create_options["sparse_vectors_config"] = sparse_vectors_config
797
+
798
+ client.create_collection(**collection_create_options)
799
+
800
+ return cls(
801
+ client=client,
802
+ collection_name=collection_name,
803
+ embedding=embedding,
804
+ retrieval_mode=retrieval_mode,
805
+ content_payload_key=content_payload_key,
806
+ metadata_payload_key=metadata_payload_key,
807
+ distance=distance,
808
+ vector_name=vector_name,
809
+ sparse_embedding=sparse_embedding,
810
+ sparse_vector_name=sparse_vector_name,
811
+ validate_embeddings=False,
812
+ validate_collection_config=False,
813
+ )
814
+
815
+ @staticmethod
816
+ def _cosine_relevance_score_fn(distance: float) -> float:
817
+ """Normalize the distance to a score on a scale `[0, 1]`."""
818
+ return (distance + 1.0) / 2.0
819
+
820
+ def _select_relevance_score_fn(self) -> Callable[[float], float]:
821
+ """Your "correct" relevance function may differ depending on a few things.
822
+
823
+ Including:
824
+ - The distance / similarity metric used by the VectorStore
825
+ - The scale of your embeddings (OpenAI's are unit normed. Many others are not!)
826
+ - Embedding dimensionality
827
+ - etc.
828
+ """
829
+ if self.distance == models.Distance.COSINE:
830
+ return self._cosine_relevance_score_fn
831
+ if self.distance == models.Distance.DOT:
832
+ return self._max_inner_product_relevance_score_fn
833
+ if self.distance == models.Distance.EUCLID:
834
+ return self._euclidean_relevance_score_fn
835
+ msg = "Unknown distance strategy, must be COSINE, DOT, or EUCLID."
836
+ raise ValueError(msg)
837
+
838
+ @classmethod
839
+ def _document_from_point(
840
+ cls,
841
+ scored_point: Any,
842
+ collection_name: str,
843
+ content_payload_key: str,
844
+ metadata_payload_key: str,
845
+ ) -> Document:
846
+ metadata = scored_point.payload.get(metadata_payload_key) or {}
847
+ metadata["_id"] = scored_point.id
848
+ metadata["_collection_name"] = collection_name
849
+ return Document(
850
+ page_content=scored_point.payload.get(content_payload_key, ""),
851
+ metadata=metadata,
852
+ )
853
+
854
+ def _generate_batches(
855
+ self,
856
+ texts: Iterable[str],
857
+ metadatas: list[dict] | None = None,
858
+ ids: Sequence[str | int] | None = None,
859
+ batch_size: int = 64,
860
+ ) -> Generator[tuple[list[str | int], list[models.PointStruct]], Any, None]:
861
+ texts_iterator = iter(texts)
862
+ metadatas_iterator = iter(metadatas or [])
863
+ ids_iterator = iter(ids or [uuid.uuid4().hex for _ in iter(texts)])
864
+
865
+ while batch_texts := list(islice(texts_iterator, batch_size)):
866
+ batch_metadatas = list(islice(metadatas_iterator, batch_size)) or None
867
+ batch_ids = list(islice(ids_iterator, batch_size))
868
+ points = [
869
+ models.PointStruct(
870
+ id=point_id,
871
+ vector=vector,
872
+ payload=payload,
873
+ )
874
+ for point_id, vector, payload in zip(
875
+ batch_ids,
876
+ self._build_vectors(batch_texts),
877
+ self._build_payloads(
878
+ batch_texts,
879
+ batch_metadatas,
880
+ self.content_payload_key,
881
+ self.metadata_payload_key,
882
+ ),
883
+ strict=False,
884
+ )
885
+ ]
886
+
887
+ yield batch_ids, points
888
+
889
+ @staticmethod
890
+ def _build_payloads(
891
+ texts: Iterable[str],
892
+ metadatas: list[dict] | None,
893
+ content_payload_key: str,
894
+ metadata_payload_key: str,
895
+ ) -> list[dict]:
896
+ payloads = []
897
+ for i, text in enumerate(texts):
898
+ if text is None:
899
+ msg = (
900
+ "At least one of the texts is None. Please remove it before "
901
+ "calling .from_texts or .add_texts."
902
+ )
903
+ raise ValueError(msg)
904
+ metadata = metadatas[i] if metadatas is not None else None
905
+ payloads.append(
906
+ {
907
+ content_payload_key: text,
908
+ metadata_payload_key: metadata,
909
+ }
910
+ )
911
+
912
+ return payloads
913
+
914
    def _build_vectors(
        self,
        texts: Iterable[str],
    ) -> list[models.VectorStruct]:
        """Embed `texts` into per-point vector dicts keyed by vector name.

        The shape of each dict depends on `retrieval_mode`: dense only, sparse
        only, or both for hybrid.

        Raises:
            ValueError: If the retrieval mode is unknown, if required
                embeddings are missing, or (HYBRID) if the dense and sparse
                embedding counts disagree.
        """
        if self.retrieval_mode == RetrievalMode.DENSE:
            embeddings = self._require_embeddings("DENSE mode")
            batch_embeddings = list(embeddings.embed(list(texts)))
            return [
                {
                    self.vector_name: vector,
                }
                for vector in batch_embeddings
            ]

        if self.retrieval_mode == RetrievalMode.SPARSE:
            batch_sparse_embeddings = self.sparse_embeddings.embed_documents(
                list(texts)
            )
            return [
                {
                    self.sparse_vector_name: models.SparseVector(
                        values=vector.values, indices=vector.indices
                    )
                }
                for vector in batch_sparse_embeddings
            ]

        if self.retrieval_mode == RetrievalMode.HYBRID:
            embeddings = self._require_embeddings("HYBRID mode")
            dense_embeddings = list(embeddings.embed(list(texts)))
            sparse_embeddings = self.sparse_embeddings.embed_documents(list(texts))

            # Both models must have produced one vector per input text.
            if len(dense_embeddings) != len(sparse_embeddings):
                msg = "Mismatched length between dense and sparse embeddings."
                raise ValueError(msg)

            return [
                {
                    self.vector_name: dense_vector,
                    self.sparse_vector_name: models.SparseVector(
                        values=sparse_vector.values, indices=sparse_vector.indices
                    ),
                }
                for dense_vector, sparse_vector in zip(
                    dense_embeddings, sparse_embeddings, strict=False
                )
            ]

        msg = f"Unknown retrieval mode. {self.retrieval_mode} to build vectors."
        raise ValueError(msg)
964
+
965
+ @classmethod
966
+ def _validate_collection_config(
967
+ cls: type[QdrantVectorStore],
968
+ client: QdrantClient,
969
+ collection_name: str,
970
+ retrieval_mode: RetrievalMode,
971
+ vector_name: str,
972
+ sparse_vector_name: str,
973
+ distance: models.Distance,
974
+ embedding: TextEmbedding | None,
975
+ ) -> None:
976
+ if retrieval_mode == RetrievalMode.DENSE:
977
+ cls._validate_collection_for_dense(
978
+ client, collection_name, vector_name, distance, embedding
979
+ )
980
+
981
+ elif retrieval_mode == RetrievalMode.SPARSE:
982
+ cls._validate_collection_for_sparse(
983
+ client, collection_name, sparse_vector_name
984
+ )
985
+
986
+ elif retrieval_mode == RetrievalMode.HYBRID:
987
+ cls._validate_collection_for_dense(
988
+ client, collection_name, vector_name, distance, embedding
989
+ )
990
+ cls._validate_collection_for_sparse(
991
+ client, collection_name, sparse_vector_name
992
+ )
993
+
994
    @classmethod
    def _validate_collection_for_dense(
        cls: type[QdrantVectorStore],
        client: QdrantClient,
        collection_name: str,
        vector_name: str,
        distance: models.Distance,
        dense_embeddings: TextEmbedding | list[float] | None,
    ) -> None:
        """Check the collection's dense vector config against the arguments.

        Verifies that the named dense vector exists, that its dimensionality
        matches the embedding model (probed with a dummy text) or the given
        raw vector, and that its distance metric matches `distance`.

        Raises:
            QdrantVectorStoreError: On any mismatch with the existing config.
            ValueError: If the collection reports no vector params.
            TypeError: If `dense_embeddings` is neither a model nor a vector.
        """
        collection_info = client.get_collection(collection_name=collection_name)
        vector_config = collection_info.config.params.vectors

        if isinstance(vector_config, dict):
            # vector_config is a Dict[str, VectorParams]
            if vector_name not in vector_config:
                msg = (
                    f"Existing Qdrant collection {collection_name} does not "
                    f"contain dense vector named {vector_name}. "
                    "Did you mean one of the "
                    f"existing vectors: {', '.join(vector_config.keys())}? "
                    f"If you want to recreate the collection, set `force_recreate` "
                    f"parameter to `True`."
                )
                raise QdrantVectorStoreError(msg)

            # Get the VectorParams object for the specified vector_name
            vector_config = vector_config[vector_name]  # type: ignore[assignment, index]

        # vector_config is an instance of VectorParams
        # Case of a collection with single/unnamed vector.
        elif vector_name != "":
            msg = (
                f"Existing Qdrant collection {collection_name} is built "
                "with unnamed dense vector. "
                f"If you want to reuse it, set `vector_name` to ''(empty string)."
                f"If you want to recreate the collection, "
                "set `force_recreate` to `True`."
            )
            raise QdrantVectorStoreError(msg)

        if vector_config is None:
            msg = "VectorParams is None"
            raise ValueError(msg)

        # Determine the expected dimensionality: probe the model once, or take
        # the length of an already-computed vector.
        if isinstance(dense_embeddings, TextEmbedding):
            embeddings = list(dense_embeddings.embed(["dummy_text"]))[0]
            vector_size = len(embeddings)
        elif isinstance(dense_embeddings, list):
            vector_size = len(dense_embeddings)
        else:
            msg = "Invalid `embeddings` type."
            raise TypeError(msg)

        if vector_config.size != vector_size:
            msg = (
                f"Existing Qdrant collection is configured for dense vectors with "
                f"{vector_config.size} dimensions. "
                f"Selected embeddings are {vector_size}-dimensional. "
                f"If you want to recreate the collection, set `force_recreate` "
                f"parameter to `True`."
            )
            raise QdrantVectorStoreError(msg)

        if vector_config.distance != distance:
            msg = (
                f"Existing Qdrant collection is configured for "
                f"{vector_config.distance.name} similarity, but requested "
                f"{distance.upper()}. Please set `distance` parameter to "
                f"`{vector_config.distance.name}` if you want to reuse it. "
                f"If you want to recreate the collection, set `force_recreate` "
                f"parameter to `True`."
            )
            raise QdrantVectorStoreError(msg)
1067
+
1068
+ @classmethod
1069
+ def _validate_collection_for_sparse(
1070
+ cls: type[QdrantVectorStore],
1071
+ client: QdrantClient,
1072
+ collection_name: str,
1073
+ sparse_vector_name: str,
1074
+ ) -> None:
1075
+ collection_info = client.get_collection(collection_name=collection_name)
1076
+ sparse_vector_config = collection_info.config.params.sparse_vectors
1077
+
1078
+ if (
1079
+ sparse_vector_config is None
1080
+ or sparse_vector_name not in sparse_vector_config
1081
+ ):
1082
+ msg = (
1083
+ f"Existing Qdrant collection {collection_name} does not "
1084
+ f"contain sparse vectors named {sparse_vector_name}. "
1085
+ f"If you want to recreate the collection, set `force_recreate` "
1086
+ f"parameter to `True`."
1087
+ )
1088
+ raise QdrantVectorStoreError(msg)
1089
+
1090
+ @classmethod
1091
+ def _validate_embeddings(
1092
+ cls: type[QdrantVectorStore],
1093
+ retrieval_mode: RetrievalMode,
1094
+ embedding: TextEmbedding | None,
1095
+ sparse_embedding: SparseEmbeddings | None,
1096
+ ) -> None:
1097
+ if retrieval_mode == RetrievalMode.DENSE and embedding is None:
1098
+ msg = "'embedding' cannot be None when retrieval mode is 'dense'"
1099
+ raise ValueError(msg)
1100
+
1101
+ if retrieval_mode == RetrievalMode.SPARSE and sparse_embedding is None:
1102
+ msg = "'sparse_embedding' cannot be None when retrieval mode is 'sparse'"
1103
+ raise ValueError(msg)
1104
+
1105
+ if retrieval_mode == RetrievalMode.HYBRID and any(
1106
+ [embedding is None, sparse_embedding is None]
1107
+ ):
1108
+ msg = (
1109
+ "Both 'embedding' and 'sparse_embedding' cannot be None "
1110
+ "when retrieval mode is 'hybrid'"
1111
+ )
1112
+ raise ValueError(msg)