File size: 3,648 Bytes
d9762cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore, RetrievalMode
from qdrant_client import QdrantClient, models
import logging
import pickle
from pathlib import Path

# Root-logger configuration for this module's console output.
_LOG_FORMAT = '%(asctime)s — %(levelname)s — %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

def get_vectorstore() -> QdrantVectorStore:
    """Return a dense QdrantVectorStore over the 'legal_db' collection.

    If the collection already exists on the Qdrant server, connect to it
    after verifying its vector dimension matches the embedding model's
    output. Otherwise, create the collection, index the pickled
    criminal-code documents into it, and add keyword payload indexes on
    the metadata fields used for filtered retrieval.

    Environment variables read:
        QDRANT_URL:     Qdrant server URL (client default applies if unset).
        QDRANT_API_KEY: API key for the Qdrant server (may be None).

    Returns:
        QdrantVectorStore: a dense-retrieval store bound to 'legal_db'.

    Raises:
        FileNotFoundError: if the pickled document file is missing.
        ValueError: if an existing collection's vector dimension does not
            match (or cannot be reconciled with) the embedding dimension.
    """
    base_dir = Path(__file__).resolve().parent.parent
    doc_path = base_dir / 'data' / 'processed_data' / 'criminal_code_of_vietnam.pkl'

    # NOTE(review): pickle.load runs arbitrary code on load — this file must
    # always come from a trusted, project-controlled source.
    with open(doc_path, 'rb') as f:
        doc_list = pickle.load(f)

    qdrant_api_key = os.getenv('QDRANT_API_KEY')
    qdrant_url = os.getenv('QDRANT_URL')

    collection_name = 'legal_db'
    client = QdrantClient(url=qdrant_url, api_key=qdrant_api_key)

    # CPU-only, non-normalized embeddings (normalization is unnecessary for
    # cosine distance, which Qdrant computes on the raw vectors below).
    model_name = 'BAAI/bge-large-en'
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}

    embeddings = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    logger.info('Embedding created.')

    # Probe the model once to discover its output dimension instead of
    # hard-coding it; used both for creation and for the mismatch check.
    dummy_embedding = embeddings.embed_query('A dummy to test embedding dimension')
    vector_dim = len(dummy_embedding)

    vectors_config = models.VectorParams(size=vector_dim, distance=models.Distance.COSINE)

    existing_names = [c.name for c in client.get_collections().collections]
    if collection_name in existing_names:
        logger.info('Collection exists. Connecting...')
        collection_info = client.get_collection(collection_name)

        # The config layout differs across qdrant-client versions; probe both
        # known attribute paths. Stays None if neither is present.
        existing_dim = None
        if hasattr(collection_info.config, 'vectors') and hasattr(collection_info.config.vectors, 'size'):
            existing_dim = collection_info.config.vectors.size
        elif hasattr(collection_info.config, 'params') and hasattr(collection_info.config.params, 'vectors') and hasattr(collection_info.config.params.vectors, 'size'):
            existing_dim = collection_info.config.params.vectors.size

        logger.info('Existing dimension: %s', existing_dim)
        if existing_dim != vector_dim:
            # Also raised when existing_dim is None (undeterminable layout):
            # refusing to connect is safer than silently mixing dimensions.
            raise ValueError(
                f'Dimension mismatch: existing collection has {existing_dim}, but embedding model gives {vector_dim}'
            )

        db = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            collection_name=collection_name,
            prefer_grpc=False,
            url=qdrant_url,
            api_key=qdrant_api_key,
            # Pin DENSE explicitly so both branches agree on retrieval mode.
            retrieval_mode=RetrievalMode.DENSE,
        )
    else:
        logger.info('Collection "%s" does not exist. Creating new collection...', collection_name)
        client.create_collection(
            collection_name=collection_name,
            vectors_config=vectors_config,
        )
        db = QdrantVectorStore.from_documents(
            documents=doc_list,
            embedding=embeddings,
            url=qdrant_url,
            prefer_grpc=False,
            collection_name=collection_name,
            retrieval_mode=RetrievalMode.DENSE,
            api_key=qdrant_api_key,
        )
        logger.info('Qdrant Index created.')

        # Keyword payload indexes on every metadata field the retriever
        # filters on; created once, right after the initial ingest.
        fields_to_index = {
            'metadata.article': "keyword",
            'metadata.chapter': "keyword",
            'metadata.id': "keyword",
            'metadata.source': "keyword",
            'metadata.title': "keyword",
        }
        for field, schema in fields_to_index.items():
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field,
                field_schema=schema,
            )

    return db