# Image_generation / qdrant_setup.py
# (HF Space page header: author manasdhir, commit "minor changes", 789f374)
from qdrant_client import QdrantClient
from qdrant_client.http import models
from config import qdrant_api_key
from typing import List, Dict, Any
from embedding_gen import client as azure_client
from qdrant_client.http.exceptions import UnexpectedResponse
# Your Qdrant Cloud info
# NOTE(review): the cluster URL is hard-coded here while the API key is read
# from config — consider moving the URL to config as well.
url = "https://f9f00eb9-b14e-49f6-b1f1-8f77e9e4a3ab.europe-west3-0.gcp.cloud.qdrant.io"
api_key = qdrant_api_key
# Name of the Qdrant collection used by every function in this module.
collection_name = "test"
# Shared client, created once at import time.
client = QdrantClient(
    url=url,
    api_key=api_key
)
# Vector dimensionality; must match the `dimensions=768` requested from the
# embedding model in search_in_qdrant.
vector_size=768
def ensure_collection():
    """Create the target collection if it does not exist yet.

    Idempotent: an existing collection is left untouched. New collections are
    created with cosine distance and the module-level vector size.

    Raises:
        UnexpectedResponse: re-raised after logging if Qdrant rejects a request.
    """
    try:
        exists = client.collection_exists(collection_name)
        if not exists:
            client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size,
                    distance=models.Distance.COSINE,
                ),
            )
            print(f"Collection '{collection_name}' created")
        else:
            print(f"Collection '{collection_name}' already exists")
    except UnexpectedResponse as e:
        print(f"Error ensuring collection: {e}")
        raise

# Guarantee the collection is available as soon as this module is imported.
ensure_collection()
def batch_upsert_docling_chunks_to_qdrant(
    chunks_data: List[Dict[str, Any]],
    batch_size: int = 20
) -> None:
    """
    Upsert Docling chunks with embeddings and metadata to Qdrant in batches.

    Point ids are the chunk's position in `chunks_data`, so re-running with the
    same input overwrites the previous points rather than duplicating them.

    Args:
        chunks_data: List of dicts with 'text', 'embedding', and 'metadata' keys.
        batch_size: Number of points to upsert per batch (default is 20).
    """
    points = []
    for idx, chunk in enumerate(chunks_data):
        meta = chunk["metadata"]
        # Payload always carries the text and its headings; page number is optional.
        payload = {
            "text": chunk["text"],
            "headings": meta["headings"],
        }
        if "page_number" in meta:
            payload["page_number"] = meta["page_number"]
        points.append(
            models.PointStruct(id=idx, vector=chunk["embedding"], payload=payload)
        )

    total = len(points)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        client.upsert(collection_name=collection_name, points=points[start:end], wait=True)
        print(f"Upserted points {start} to {end - 1}")
    print(f"Total upserted points: {total}")
# Alternative entry point that takes parallel lists instead of chunk dicts.
def batch_upsert_to_qdrant_with_metadata(
    content_list: List[str],
    embedding_list: List[List[float]],
    metadata_list: List[Dict[str, Any]],
    batch_size: int = 20
) -> None:
    """
    Upsert embedding-content pairs with metadata to Qdrant in batches.

    Args:
        content_list: List of text chunks.
        embedding_list: List of vectors corresponding to content.
        metadata_list: List of metadata dicts corresponding to content.
        batch_size: Number of points to upsert per batch (default is 20).

    Raises:
        ValueError: if the three input lists differ in length.
    """
    if not (len(content_list) == len(embedding_list) == len(metadata_list)):
        raise ValueError("Content, embedding, and metadata list lengths do not match!")

    points = []
    for idx, (text, vector, meta) in enumerate(
        zip(content_list, embedding_list, metadata_list)
    ):
        payload = {
            "text": text,
            "headings": meta.get("headings", []),
        }
        # Copy through any extra metadata without clobbering the core fields.
        payload.update(
            {k: v for k, v in meta.items() if k not in ("text", "headings")}
        )
        points.append(models.PointStruct(id=idx, vector=vector, payload=payload))

    total = len(points)
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        client.upsert(collection_name=collection_name, points=points[start:end], wait=True)
        print(f"Upserted points {start} to {end - 1}")
    print(f"Total upserted points: {total}")
def clear_collection_payloads():
    """Empty the module-level Qdrant collection without changing its schema.

    Despite the name, this removes *all points* (vectors and payloads) by
    dropping the collection and immediately recreating it with the same
    vector configuration (size `vector_size`, cosine distance). This is the
    fastest way to empty a collection and avoids per-point deletion.

    Side effects:
        Destroys and recreates the collection named by the module-level
        `collection_name` using the shared `client`.
    """
    # Drop-and-recreate is atomic enough for this use case; recreating with
    # the same VectorParams leaves an empty collection with the original schema.
    client.delete_collection(collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )
    print("delete done")
def search_in_qdrant(
    query: str,
    qdrant_client: QdrantClient = client,
    initial_limit: int = 5,
    length_threshold: int = 4500
) -> List[Dict[str, Any]]:
    """
    Search qdrant with up to `initial_limit` results.

    If combined content length of top 3 is less than `length_threshold`,
    return only top 3 results, else return top `initial_limit`.

    Args:
        query (str): Query string.
        qdrant_client (QdrantClient): Initialized Qdrant client.
        initial_limit (int): Max number of results to retrieve (e.g., 5).
        length_threshold (int): Threshold in characters on combined content length.

    Returns:
        List[Dict]: List of scored points with payloads.
    """
    # Embed the query with the same model/dimensionality used at index time.
    embedding_response = azure_client.embeddings.create(
        input=query,
        model="text-embedding-3-large",
        dimensions=768
    )
    query_vector = embedding_response.data[0].embedding

    hits = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=initial_limit,
        with_payload=True,
    ).points
    if not hits:
        return []

    # If the best three chunks are short, returning more results adds context;
    # otherwise three chunks already fill the budget.
    top3_chars = sum(len(p.payload.get('text', '')) for p in hits[:3])
    keep = 3 if top3_chars < length_threshold else initial_limit

    return [
        {
            "id": p.id,
            "score": p.score,
            "payload": p.payload,
        }
        for p in hits[:keep]
    ]
# def get_context_for_questions(questions: List[str], top_k: int = 3) -> List[str]:
# contexts = []
# for q in questions:
# results = search_in_qdrant(q, top_k=top_k)
# if results:
# # Use text from the top hit (or join top_k if you wish)
# context = results[0]['payload'].get('text', '')
# else:
# context = ''
# contexts.append(context)
# return contexts
import concurrent.futures
from typing import List
def get_context_for_questions(
    questions: List[str],
    max_workers: int = 4
) -> List[str]:
    """Retrieve the top context chunk for each question, in parallel.

    Each question is searched independently on a thread pool; the text of the
    best hit is kept. Any per-question failure is logged and degrades to an
    empty string, so the result always has one entry per question, in order.

    Args:
        questions: Questions to look up.
        max_workers: Thread pool size (default 4).

    Returns:
        List of context strings, aligned with `questions`.
    """
    def _fetch(idx: int, question: str) -> str:
        # Self-contained worker: never lets an exception escape the pool.
        try:
            hits = search_in_qdrant(question)
            return hits[0]['payload'].get('text', '') if hits else ''
        except Exception as e:
            print(f"Error fetching context for question {idx}: {e}")
            return ""  # fallback empty string

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map preserves input order, so no index bookkeeping is needed.
        return list(pool.map(_fetch, range(len(questions)), questions))
if __name__=="__main__":
# import time
# t1=time.time()
# clear_collection_payloads()
# t2=time.time()
# print("time taken", t2-t1)
# questions= [
# "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
# "What is the waiting period for pre-existing diseases (PED) to be covered?",
# "Does this policy cover maternity expenses, and what are the conditions?",
# "What is the waiting period for cataract surgery?",
# "Are the medical expenses for an organ donor covered under this policy?",
# "What is the No Claim Discount (NCD) offered in this policy?",
# "Is there a benefit for preventive health check-ups?",
# "How does the policy define a 'Hospital'?",
# "What is the extent of coverage for AYUSH treatments?",
# "Are there any sub-limits on room rent and ICU charges for Plan A?"
# ]
# import time
# t1=time.time()
# print(get_context_for_questions(questions)==get_context_for_questions_parallel(questions))
# t2=time.time()
# print("time taken",t2-t1)
print(search_in_qdrant("What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?"))
#print_all_points()