File size: 3,550 Bytes
b3de77b
0827021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3de77b
0827021
 
 
 
b3de77b
 
 
 
0827021
 
 
 
 
 
 
 
 
2f81d82
 
 
0827021
 
 
 
 
 
 
 
 
 
 
 
 
 
2f81d82
0827021
 
 
 
 
 
2f81d82
0827021
 
 
 
2f81d82
 
0827021
2f81d82
0827021
 
 
 
 
 
2f81d82
 
0827021
 
 
 
b3de77b
0827021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3de77b
 
 
0827021
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3de77b
0827021
 
 
 
 
b3de77b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from pymilvus import DataType, MilvusClient


def get_milvus_client(db_path: str) -> "MilvusClient | None":
    """
    Create a Milvus client for the given database location.

    Construction errors are not propagated: they are printed and ``None`` is
    returned instead, so callers must check the result before using it.

    Args:
        db_path: The path to the Milvus database, passed unchanged to
            ``MilvusClient``

    Returns:
        A Milvus client, or ``None`` if the client could not be created
    """
    try:
        return MilvusClient(db_path)
    except Exception as e:
        # Best-effort by design: report the failure and signal it with None.
        print(f"Error getting Milvus client: {e}")
        return None


def create_collection_if_not_exists(
    client: MilvusClient, collection_name: str, dim: int
) -> None:
    """
    Create a collection in Milvus if it does not exist.

    The collection is created with three explicit fields and one index:

    * ``id``: auto-generated INT64 primary key
    * ``context``: VARCHAR of up to 65535 characters
    * ``binary_vector``: BINARY_VECTOR of ``dim`` dimensions, indexed with
      ``BIN_FLAT`` using the ``HAMMING`` metric

    Dynamic fields are enabled on the schema. Any error raised while checking
    for or creating the collection is printed and swallowed; nothing is raised
    to the caller.

    Args:
        client: The Milvus client
        collection_name: The name of the collection to create
        dim: The dimension of the binary vector (not validated here; Milvus
            documents binary vector dimensions as multiples of 8)
    """
    try:
        # Nothing to do when the collection is already there.
        if client.has_collection(collection_name):
            print(f"Collection {collection_name} already exists. Skipping creation.")
            return

        print(f"Collection {collection_name} not found. Creating it...")

        # The schema option is spelled ``enable_dynamic_field`` (singular);
        # a misspelled keyword would leave dynamic fields switched off.
        schema = client.create_schema(
            auto_id=True,
            enable_dynamic_field=True,
        )
        schema.add_field(
            field_name="id",
            datatype=DataType.INT64,
            is_primary=True,
            auto_id=True,
        )
        schema.add_field(
            field_name="context",
            datatype=DataType.VARCHAR,
            max_length=65535,
        )
        schema.add_field(
            field_name="binary_vector",
            datatype=DataType.BINARY_VECTOR,
            dim=dim,
        )

        index_params = client.prepare_index_params()
        index_params.add_index(
            field_name="binary_vector",
            index_name="binary_vector_index",
            index_type="BIN_FLAT",
            metric_type="HAMMING",
        )

        client.create_collection(
            collection_name=collection_name,
            schema=schema,
            index_params=index_params,
        )
        print(f"Collection {collection_name} created successfully.")
    except Exception as e:
        # Best-effort by design: report the failure instead of raising.
        print(f"Error creating collection: {e}")


def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
    """
    Insert rows into a Milvus collection.

    Any error raised by the insert is printed and swallowed, so this function
    always returns ``None``.

    Args:
        client: The Milvus client
        collection_name: The name of the collection that receives the rows
        data: The rows to insert, forwarded unchanged to ``client.insert``
    """
    insert_kwargs = {"collection_name": collection_name, "data": data}
    try:
        client.insert(**insert_kwargs)
    except Exception as err:
        print(f"Error inserting data: {err}")


def search(
    client: MilvusClient, collection_name: str, binary_query: bytes, limit: int = 5
):
    """
    Search a Milvus collection with a single binary query vector.

    The query is matched against the ``binary_vector`` field using the
    ``HAMMING`` metric, and only the ``context`` output field is requested.

    Args:
        client: The Milvus client
        collection_name: The name of the collection to search
        binary_query: The binary query vector
        limit: The maximum number of hits to request

    Returns:
        The ``context`` value of each hit for the query, or an empty list when
        the search returns nothing or any error occurs (the error is printed)
    """
    query_options = {"metric_type": "HAMMING"}
    try:
        hits_per_query = client.search(
            collection_name=collection_name,
            data=[binary_query],
            anns_field="binary_vector",
            search_params=query_options,
            output_fields=["context"],
            limit=limit,
        )

        if hits_per_query:
            # One query vector was sent, so only the first hit list is read.
            found = []
            for hit in hits_per_query[0]:
                found.append(hit.entity.context)
            return found

        print("No search results found")
        return []

    except Exception as err:
        print(f"Error searching for data: {err}")
        return []