# Spaces:
# Sleeping
# Sleeping
| from qdrant_client import QdrantClient | |
| from qdrant_client.http import models | |
| from config import qdrant_api_key | |
| from typing import List, Dict, Any | |
| from embedding_gen import client as azure_client | |
| from qdrant_client.http.exceptions import UnexpectedResponse | |
# Your Qdrant Cloud info
# Endpoint of the managed Qdrant Cloud cluster.
url = "https://f9f00eb9-b14e-49f6-b1f1-8f77e9e4a3ab.europe-west3-0.gcp.cloud.qdrant.io"
# API key comes from config.py rather than being hard-coded here.
api_key = qdrant_api_key
# Name of the collection every function in this module operates on.
collection_name = "test"
# Shared client instance reused by all functions below.
client = QdrantClient(
    url=url,
    api_key=api_key
)
# Embedding dimensionality; must match the `dimensions=768` used when
# creating embeddings in search_in_qdrant / the embedding generator.
vector_size=768
def ensure_collection():
    """Idempotently make sure the target collection exists.

    Checks for the module-level `collection_name`; creates it with the
    module's vector configuration (COSINE distance) only when missing.

    Raises:
        UnexpectedResponse: re-raised after logging if the Qdrant API
            rejects the existence check or the create call.
    """
    try:
        if client.collection_exists(collection_name):
            print(f"Collection '{collection_name}' already exists")
            return
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
        )
        print(f"Collection '{collection_name}' created")
    except UnexpectedResponse as e:
        print(f"Error ensuring collection: {e}")
        raise


# Make sure the collection is ready as soon as this module is imported.
ensure_collection()
def batch_upsert_docling_chunks_to_qdrant(
    chunks_data: List[Dict[str, Any]],
    batch_size: int = 20
) -> None:
    """
    Upsert Docling chunks with embeddings and metadata to Qdrant in batches.

    Args:
        chunks_data: List of dicts with 'text', 'embedding', and 'metadata' keys.
        batch_size: Number of points to upsert per batch (default is 20).

    Note:
        Point ids are the positional indices of the chunks, so calling this
        again overwrites points that share the same index.
    """
    points = []
    for i, chunk_data in enumerate(chunks_data):
        metadata = chunk_data.get("metadata", {})
        # A chunk without 'headings' should not abort the whole batch;
        # default to [] for consistency with
        # batch_upsert_to_qdrant_with_metadata below.
        payload = {
            "text": chunk_data["text"],
            "headings": metadata.get("headings", [])
        }
        # Add other metadata fields if they exist
        if "page_number" in metadata:
            payload["page_number"] = metadata["page_number"]
        points.append(
            models.PointStruct(
                id=i,
                vector=chunk_data["embedding"],
                payload=payload
            )
        )
    total_points = len(points)
    # Upsert in fixed-size batches to keep individual request payloads small.
    for start_idx in range(0, total_points, batch_size):
        end_idx = min(start_idx + batch_size, total_points)
        batch_points = points[start_idx:end_idx]
        client.upsert(collection_name=collection_name, points=batch_points, wait=True)
        print(f"Upserted points {start_idx} to {end_idx - 1}")
    print(f"Total upserted points: {total_points}")
| # Alternative: Backward compatible version that accepts separate lists | |
def batch_upsert_to_qdrant_with_metadata(
    content_list: List[str],
    embedding_list: List[List[float]],
    metadata_list: List[Dict[str, Any]],
    batch_size: int = 20
) -> None:
    """
    Upsert embedding-content pairs with metadata to Qdrant in batches.

    Args:
        content_list: List of text chunks.
        embedding_list: List of vectors corresponding to content.
        metadata_list: List of metadata dicts corresponding to content.
        batch_size: Number of points to upsert per batch (default is 20).

    Raises:
        ValueError: if the three input lists differ in length.
    """
    if not (len(content_list) == len(embedding_list) == len(metadata_list)):
        raise ValueError("Content, embedding, and metadata list lengths do not match!")

    points = []
    for idx, (text, vector, metadata) in enumerate(
        zip(content_list, embedding_list, metadata_list)
    ):
        payload = {
            "text": text,
            "headings": metadata.get("headings", [])
        }
        # Carry over every remaining metadata field verbatim.
        payload.update(
            {key: value for key, value in metadata.items()
             if key not in ["text", "headings"]}
        )
        points.append(
            models.PointStruct(id=idx, vector=vector, payload=payload)
        )

    total_points = len(points)
    for start_idx in range(0, total_points, batch_size):
        end_idx = min(start_idx + batch_size, total_points)
        client.upsert(
            collection_name=collection_name,
            points=points[start_idx:end_idx],
            wait=True,
        )
        print(f"Upserted points {start_idx} to {end_idx - 1}")
    print(f"Total upserted points: {total_points}")
def clear_collection_payloads():
    """
    Empty the collection by dropping and recreating it.

    NOTE: despite the original intent described around this helper, the
    implementation does NOT delete points in place — it deletes the whole
    collection and recreates it with the same vector configuration, which
    is the fastest way to remove every point. Takes no arguments; operates
    on the module-level `client`, `collection_name`, and `vector_size`.
    """
    # Use empty filter to delete all points
    # Note: If your Qdrant server/client version doesn't support empty filter deletion,
    # alternatively you can delete by querying all IDs first and then deleting.
    # Drop-and-recreate is cheaper than deleting points one by one.
    client.delete_collection(collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
    )
    print("delete done")
def search_in_qdrant(
    query: str,
    qdrant_client: QdrantClient = client,
    initial_limit: int = 5,
    length_threshold: int = 4500
) -> List[Dict[str, Any]]:
    """
    Semantic search against the module's Qdrant collection.

    Retrieves up to `initial_limit` hits. When the combined text length of
    the top 3 hits is below `length_threshold` characters, only those 3 are
    returned; otherwise all `initial_limit` hits are.

    Args:
        query (str): Query string.
        qdrant_client (QdrantClient): Initialized Qdrant client.
        initial_limit (int): Max number of results to retrieve (e.g., 5).
        length_threshold (int): Threshold in characters on combined content length.

    Returns:
        List[Dict]: List of scored points with payloads.
    """
    # Embed the query with the same model/dimensions used at indexing time.
    embedding_response = azure_client.embeddings.create(
        input=query,
        model="text-embedding-3-large",
        dimensions=768
    )
    query_vector = embedding_response.data[0].embedding

    hits = qdrant_client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=initial_limit,
        with_payload=True,
    ).points
    if not hits:
        return []

    # Compact top-3 answers need less context, so trim the result set.
    top_three_chars = sum(len(hit.payload.get('text', '')) for hit in hits[:3])
    keep = 3 if top_three_chars < length_threshold else initial_limit

    return [
        {
            "id": hit.id,
            "score": hit.score,
            "payload": hit.payload,
        }
        for hit in hits[:keep]
    ]
| # def get_context_for_questions(questions: List[str], top_k: int = 3) -> List[str]: | |
| # contexts = [] | |
| # for q in questions: | |
| # results = search_in_qdrant(q, top_k=top_k) | |
| # if results: | |
| # # Use text from the top hit (or join top_k if you wish) | |
| # context = results[0]['payload'].get('text', '') | |
| # else: | |
| # context = '' | |
| # contexts.append(context) | |
| # return contexts | |
| import concurrent.futures | |
| from typing import List | |
def get_context_for_questions(
    questions: List[str],
    max_workers: int = 4
) -> List[str]:
    """
    Fetch the best-matching context passage for each question, in parallel.

    Args:
        questions: Questions to look up.
        max_workers: Size of the thread pool used for concurrent searches.

    Returns:
        A list the same length and order as `questions`; a question whose
        lookup fails (or matches nothing) yields an empty string.
    """
    def _fetch_one(question: str) -> str:
        # Text of the top hit, or '' when the search returns nothing.
        hits = search_in_qdrant(question)
        return hits[0]['payload'].get('text', '') if hits else ''

    contexts = [None] * len(questions)
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Map each future back to its question's position so results
        # land in the original order regardless of completion order.
        pending = {pool.submit(_fetch_one, q): idx for idx, q in enumerate(questions)}
        for finished in concurrent.futures.as_completed(pending):
            idx = pending[finished]
            try:
                contexts[idx] = finished.result()
            except Exception as e:
                print(f"Error fetching context for question {idx}: {e}")
                contexts[idx] = ""  # fallback empty string
    return contexts
if __name__=="__main__":
    # Smoke test: run one live search against the collection and print the hits.
    # import time
    # t1=time.time()
    # clear_collection_payloads()
    # t2=time.time()
    # print("time taken", t2-t1)
    # questions= [
    #     "What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?",
    #     "What is the waiting period for pre-existing diseases (PED) to be covered?",
    #     "Does this policy cover maternity expenses, and what are the conditions?",
    #     "What is the waiting period for cataract surgery?",
    #     "Are the medical expenses for an organ donor covered under this policy?",
    #     "What is the No Claim Discount (NCD) offered in this policy?",
    #     "Is there a benefit for preventive health check-ups?",
    #     "How does the policy define a 'Hospital'?",
    #     "What is the extent of coverage for AYUSH treatments?",
    #     "Are there any sub-limits on room rent and ICU charges for Plan A?"
    # ]
    # import time
    # t1=time.time()
    # print(get_context_for_questions(questions)==get_context_for_questions_parallel(questions))
    # t2=time.time()
    # print("time taken",t2-t1)
    print(search_in_qdrant("What is the grace period for premium payment under the National Parivar Mediclaim Plus Policy?"))
    #print_all_points()