# NOTE(review): removed non-code page artifacts ("Spaces", "Sleeping",
# file size, commit hash, and a line-number gutter) that were scraped in
# with the source and broke parsing.
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models
import logging
import pickle
from pathlib import Path
# Configure root logging once at import time so every logger in the process
# emits timestamped, INFO-level records in a uniform format.
_LOG_FORMAT = '%(asctime)s — %(levelname)s — %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
def get_vectorstore() -> QdrantVectorStore:
    """Connect to (or create) the 'legal_db' Qdrant collection.

    Reads connection settings from the ``QDRANT_URL`` and ``QDRANT_API_KEY``
    environment variables. Embeddings are computed locally on CPU with the
    'BAAI/bge-large-en' model. If the collection already exists, its vector
    dimension is validated against the embedding model and the store is
    attached to it; otherwise the collection is created and populated from
    the pickled document list under
    ``data/processed_data/criminal_code_of_vietnam.pkl`` (resolved relative
    to this file's parent package).

    Returns:
        QdrantVectorStore: a LangChain vector store bound to 'legal_db'.

    Raises:
        ValueError: if an existing collection's vector dimension does not
            match the embedding model's output dimension.
        FileNotFoundError: if the pickle file is missing when a new
            collection must be populated.
    """
    qdrant_api_key = os.getenv('QDRANT_API_KEY')
    qdrant_url = os.getenv('QDRANT_URL')
    collection_name = 'legal_db'

    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    embeddings = HuggingFaceEmbeddings(
        model_name='BAAI/bge-large-en',
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': False},
    )
    logger.info('Embedding created.')

    # Probe the model once to discover its output dimension instead of
    # hard-coding it; used both for validation and for collection creation.
    dummy_embedding = embeddings.embed_query('A dummy to test embedding dimension')
    vector_dim = len(dummy_embedding)
    vectors_config = models.VectorParams(size=vector_dim, distance=models.Distance.COSINE)

    if collection_name in [c.name for c in client.get_collections().collections]:
        logger.info('Collection exists. Connecting...')
        collection_info = client.get_collection(collection_name)
        # The vector params live in different places across qdrant-client
        # versions, so probe both known layouts.
        existing_dim = None
        if hasattr(collection_info.config, 'vectors') and hasattr(collection_info.config.vectors, 'size'):
            existing_dim = collection_info.config.vectors.size
        elif (hasattr(collection_info.config, 'params')
              and hasattr(collection_info.config.params, 'vectors')
              and hasattr(collection_info.config.params.vectors, 'size')):
            existing_dim = collection_info.config.params.vectors.size
        logger.info(f'Existing dimension: {existing_dim}')
        if existing_dim != vector_dim:
            raise ValueError(
                f'Dimension mismatch: existing collection has {existing_dim}, but embedding model gives {vector_dim}'
            )
        db = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            prefer_grpc=False,
            url=qdrant_url,
            api_key=qdrant_api_key,
        )
    else:
        logger.info(f'Collection "{collection_name}" does not exist. Creating new collection...')
        # Load the (potentially large) pickled corpus only when a fresh
        # collection actually needs to be populated; previously it was
        # unpickled on every call, including the connect-only path.
        base_dir = Path(__file__).resolve().parent.parent
        doc_path = base_dir / 'data' / 'processed_data' / 'criminal_code_of_vietnam.pkl'
        with open(doc_path, 'rb') as f:
            doc_list = pickle.load(f)
        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )
        db = QdrantVectorStore.from_documents(
            documents=doc_list,
            embedding=embeddings,
            url=qdrant_url,
            prefer_grpc=False,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.DENSE,
            api_key=qdrant_api_key,
        )
        # Log creation only on the branch that actually creates the index
        # (previously this fired even when merely reconnecting).
        logger.info('Qdrant Index created.')

    # Payload indexes speed up metadata filtering. NOTE(review): these run on
    # every call, including reconnects — assumed safe to re-issue against an
    # existing collection; verify against the Qdrant server behavior.
    fields_to_index = {
        'metadata.article': 'keyword',
        'metadata.chapter': 'keyword',
        'metadata.id': 'keyword',
        'metadata.source': 'keyword',
        'metadata.title': 'keyword',
    }
    for field, schema in fields_to_index.items():
        client.create_payload_index(
            collection_name=collection_name,
            field_name=field,
            field_schema=schema,
        )
    return db