"""Qdrant vector-store helpers: collection setup, batched upserts, and semantic search."""

import concurrent.futures
from typing import Any, Dict, List

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse

from config import qdrant_api_key
from embedding_gen import client as azure_client

# Qdrant Cloud connection details.
url = "https://f9f00eb9-b14e-49f6-b1f1-8f77e9e4a3ab.europe-west3-0.gcp.cloud.qdrant.io"
api_key = qdrant_api_key
collection_name = "test"

client = QdrantClient(url=url, api_key=api_key)

# Must match the `dimensions` requested from the embedding model in search_in_qdrant.
vector_size = 768


def ensure_collection() -> None:
    """Create the target collection (cosine distance, `vector_size` dims) if absent.

    Raises:
        UnexpectedResponse: re-raised after logging when the Qdrant API fails.
    """
    try:
        if client.collection_exists(collection_name):
            print(f"Collection '{collection_name}' already exists")
        else:
            client.create_collection(
                collection_name=collection_name,
                vectors_config=models.VectorParams(
                    size=vector_size, distance=models.Distance.COSINE
                ),
            )
            print(f"Collection '{collection_name}' created")
    except UnexpectedResponse as e:
        print(f"Error ensuring collection: {e}")
        raise


# NOTE: runs at import time so the collection exists before any upsert/search call.
ensure_collection()


def _upsert_in_batches(points: List[models.PointStruct], batch_size: int) -> None:
    """Upsert `points` into the module collection in chunks of `batch_size` (blocking)."""
    total_points = len(points)
    for start_idx in range(0, total_points, batch_size):
        end_idx = min(start_idx + batch_size, total_points)
        client.upsert(
            collection_name=collection_name,
            points=points[start_idx:end_idx],
            wait=True,
        )
        print(f"Upserted points {start_idx} to {end_idx - 1}")
    print(f"Total upserted points: {total_points}")


def batch_upsert_docling_chunks_to_qdrant(
    chunks_data: List[Dict[str, Any]],
    batch_size: int = 20,
) -> None:
    """
    Upsert Docling chunks with embeddings and metadata to Qdrant in batches.

    Args:
        chunks_data: List of dicts with 'text', 'embedding', and 'metadata' keys;
            'metadata' must contain 'headings' and may contain 'page_number'.
        batch_size: Number of points to upsert per batch (default is 20).
    """
    points = []
    for i, chunk_data in enumerate(chunks_data):
        payload: Dict[str, Any] = {
            "text": chunk_data["text"],
            "headings": chunk_data["metadata"]["headings"],
        }
        if "page_number" in chunk_data["metadata"]:
            payload["page_number"] = chunk_data["metadata"]["page_number"]
        # NOTE: ids are positional, so re-running with different data overwrites
        # previously stored points with the same index.
        points.append(
            models.PointStruct(id=i, vector=chunk_data["embedding"], payload=payload)
        )
    _upsert_in_batches(points, batch_size)


# Alternative: backward-compatible version that accepts separate parallel lists.
def batch_upsert_to_qdrant_with_metadata(
    content_list: List[str],
    embedding_list: List[List[float]],
    metadata_list: List[Dict[str, Any]],
    batch_size: int = 20,
) -> None:
    """
    Upsert embedding-content pairs with metadata to Qdrant in batches.

    Args:
        content_list: List of text chunks.
        embedding_list: List of vectors corresponding to content.
        metadata_list: List of metadata dicts corresponding to content.
        batch_size: Number of points to upsert per batch (default is 20).

    Raises:
        ValueError: if the three lists do not have equal lengths.
    """
    if not (len(content_list) == len(embedding_list) == len(metadata_list)):
        raise ValueError("Content, embedding, and metadata list lengths do not match!")

    points = []
    for i, (text, vector, metadata) in enumerate(
        zip(content_list, embedding_list, metadata_list)
    ):
        payload: Dict[str, Any] = {
            "text": text,
            "headings": metadata.get("headings", []),
        }
        # Copy through any extra metadata fields verbatim.
        for key, value in metadata.items():
            if key not in ("text", "headings"):
                payload[key] = value
        points.append(models.PointStruct(id=i, vector=vector, payload=payload))
    _upsert_in_batches(points, batch_size)


def clear_collection_payloads() -> None:
    """
    Empty the collection by dropping and recreating it.

    NOTE: this deletes the collection and recreates it with the same vector
    config — point-level deletion via an empty filter is not used here because
    it is not supported uniformly across server/client versions.
    """
    client.delete_collection(collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=vector_size, distance=models.Distance.COSINE
        ),
    )
    print("delete done")


def search_in_qdrant(
    query: str,
    qdrant_client: QdrantClient = client,
    initial_limit: int = 5,
    length_threshold: int = 4500,
) -> List[Dict[str, Any]]:
    """
    Search Qdrant with up to `initial_limit` results.

    If the combined text length of the top 3 hits is below `length_threshold`,
    only the top 3 are returned; otherwise up to `initial_limit` are returned.

    Args:
        query (str): Query string.
        qdrant_client (QdrantClient): Initialized Qdrant client.
        initial_limit (int): Max number of results to retrieve (e.g., 5).
        length_threshold (int): Threshold in characters on combined content length.

    Returns:
        List[Dict]: Scored points as dicts with 'id', 'score', and 'payload'.
    """
    # Embed the query; `dimensions` must equal the collection's vector_size.
    response = azure_client.embeddings.create(
        input=query,
        model="text-embedding-3-large",
        dimensions=768,
    )
    query_embedding = response.data[0].embedding

    search_results = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_embedding,
        limit=initial_limit,
        with_payload=True,
    )
    if not search_results.points:
        return []

    # Adaptive cutoff: short top-3 context -> return only 3 results.
    top_3 = search_results.points[:3]
    combined_length = sum(len(p.payload.get("text", "")) for p in top_3)
    final_limit = 3 if combined_length < length_threshold else initial_limit

    return [
        {"id": point.id, "score": point.score, "payload": point.payload}
        for point in search_results.points[:final_limit]
    ]


def get_context_for_questions(
    questions: List[str],
    max_workers: int = 4,
) -> List[str]:
    """
    Fetch the top search hit's text for each question, in parallel.

    Args:
        questions: Questions to look up.
        max_workers: Thread-pool size for concurrent searches.

    Returns:
        One context string per question (empty string on miss or error),
        in the same order as `questions`.
    """

    def get_context_single(question: str) -> str:
        results = search_in_qdrant(question)
        if results:
            return results[0]["payload"].get("text", "")
        return ""

    contexts: List[str] = [""] * len(questions)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its question index so order is preserved.
        futures = {
            executor.submit(get_context_single, q): idx
            for idx, q in enumerate(questions)
        }
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                contexts[idx] = future.result()
            except Exception as e:
                print(f"Error fetching context for question {idx}: {e}")
                contexts[idx] = ""  # fallback empty string
    return contexts


if __name__ == "__main__":
    print(
        search_in_qdrant(
            "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?"
        )
    )