Spaces:
Sleeping
Sleeping
| import os | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from langchain_qdrant import QdrantVectorStore, RetrievalMode | |
| from qdrant_client import QdrantClient, models | |
| import logging | |
| import pickle | |
| from pathlib import Path | |
# Layout of every log line: timestamp, level and message, separated by em
# dashes. The separator is spelled as an escape sequence so that it cannot be
# corrupted if this file is ever re-saved or served with a different encoding.
LOG_FORMAT = '%(asctime)s \u2014 %(levelname)s \u2014 %(message)s'

# Configure the root logger once at import time; INFO and above are emitted.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def _load_documents(doc_path):
    """Return the object unpickled from the file at *doc_path*."""
    # SECURITY: unpickling can execute arbitrary code. This file must only
    # ever be an artifact produced by a trusted preprocessing step.
    with open(doc_path, 'rb') as f:
        return pickle.load(f)


def _existing_vector_size(collection_info):
    """Return the vector size recorded on *collection_info*, or None.

    Looks for ``config.vectors.size`` first and ``config.params.vectors.size``
    second. None is returned when neither attribute chain yields a size (for
    example when ``vectors`` is a mapping rather than a single params object).
    """
    config = collection_info.config
    for holder in (config, getattr(config, 'params', None)):
        size = getattr(getattr(holder, 'vectors', None), 'size', None)
        if size is not None:
            return size
    return None


def get_vectorstore() -> QdrantVectorStore:
    """Connect to the 'legal_db' Qdrant collection, creating it if missing.

    Connection settings are read from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables. Embeddings are produced on CPU
    by the 'BAAI/bge-large-en' model; one probe query is embedded to learn the
    vector dimension.

    If the collection already exists, its stored vector size is compared with
    the model's dimension and the store is opened as-is. Otherwise the pickled
    documents are loaded, the collection is created with cosine distance, the
    documents are indexed, and keyword payload indexes are added for the
    metadata fields.

    Returns:
        The QdrantVectorStore bound to the collection.

    Raises:
        ValueError: the existing collection's vector size differs from the
            embedding dimension (or could not be determined).
        OSError: the pickled document file cannot be opened when a new
            collection has to be built.
    """
    collection_name = 'legal_db'
    qdrant_api_key = os.getenv('QDRANT_API_KEY')
    qdrant_url = os.getenv('QDRANT_URL')
    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    embeddings = HuggingFaceEmbeddings(
        model_name='BAAI/bge-large-en',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )
    logger.info('Embedding created.')

    # Embed a throwaway query purely to discover the model's output size.
    vector_dim = len(embeddings.embed_query('A dummy to test embedding dimension'))

    existing_names = {c.name for c in client.get_collections().collections}
    if collection_name in existing_names:
        logger.info('Collection exists. Connecting...')
        existing_dim = _existing_vector_size(client.get_collection(collection_name))
        logger.info('Existing dimension: %s', existing_dim)
        if existing_dim != vector_dim:
            raise ValueError(
                f'Dimension mismatch: existing collection has {existing_dim}, but embedding model gives {vector_dim}'
            )
        return QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            prefer_grpc=False,
            url=qdrant_url,
            api_key=qdrant_api_key,
        )

    logger.info('Collection "%s" does not exist. Creating new collection...', collection_name)

    # The documents are only needed when building a new collection. Load them
    # before touching the server so that a missing file cannot leave an empty
    # collection behind that later runs would treat as ready.
    base_dir = Path(__file__).resolve().parent.parent
    doc_list = _load_documents(
        base_dir / 'data' / 'processed_data' / 'criminal_code_of_vietnam.pkl'
    )

    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=vector_dim,
            distance=models.Distance.COSINE,
        ),
    )
    db = QdrantVectorStore.from_documents(
        documents=doc_list,
        embedding=embeddings,
        url=qdrant_url,
        prefer_grpc=False,
        collection_name=collection_name,
        retrieval_mode=RetrievalMode.DENSE,
        api_key=qdrant_api_key,
    )
    logger.info('Qdrant Index created.')

    # NOTE(review): payload indexes are created only for a freshly built
    # collection; confirm they should not also be ensured on existing ones.
    keyword_fields = (
        'metadata.article',
        'metadata.chapter',
        'metadata.id',
        'metadata.source',
        'metadata.title',
    )
    for field in keyword_fields:
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field,
            field_schema="keyword",
        )
    return db