File size: 1,898 Bytes
beb2111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from pymilvus import (
    connections,
    Collection,
    CollectionSchema,
    FieldSchema,
    DataType,
    utility,
)
from dotenv import load_dotenv
import os
from config import DOCS_INDEX_NAME
import logging

# Populate os.environ from a .env file (when one is found) so that the
# ZILLIZ_CLOUD_* settings read by connect_db() can come from that file.
load_dotenv()

# Shared "backend" logger used by every function in this module.
logger = logging.getLogger(name="backend")


def connect_db():
    """Open the "default" pymilvus connection to Zilliz Cloud.

    The endpoint is read from the ZILLIZ_CLOUD_URI environment variable and
    the API token from ZILLIZ_CLOUD_TOKEN.

    Raises:
        ValueError: if ZILLIZ_CLOUD_URI is unset or empty. Without this check
            ``uri=None`` would be handed to pymilvus, which then falls back to
            its default local address instead of Zilliz Cloud. That either
            fails later with a confusing connection error or talks to the
            wrong server.
    """
    uri = os.getenv("ZILLIZ_CLOUD_URI")
    token = os.getenv("ZILLIZ_CLOUD_TOKEN")

    # Fail fast on missing configuration rather than connecting somewhere else.
    if not uri:
        raise ValueError("ZILLIZ_CLOUD_URI environment variable is not set")

    # Only the URI is logged; the token is a secret and must stay out of logs.
    logger.info("Connecting to DB: %s", uri)
    connections.connect(alias="default", uri=uri, token=token)
    logger.info("Success!")


def _build_docs_schema():
    """Return the CollectionSchema for the news documents collection."""
    fields = [
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=3000),
        FieldSchema(name="publish_time", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),
    ]
    return CollectionSchema(fields=fields, description="News documents collection")


def _create_embedding_index(collection):
    """Build an IVF_FLAT index with COSINE metric on *collection*'s ``embedding`` field."""
    index_params = {
        "metric_type": "COSINE",
        "index_type": "IVF_FLAT",
        "params": {"nlist": 1024},
    }
    collection.create_index(field_name="embedding", index_params=index_params)


def create_schema_if_not_exists():
    """Ensure the DOCS_INDEX_NAME collection and its vector index exist.

    Connects via connect_db(). If the collection is missing, it is created
    with the news-documents schema and indexed on ``embedding``.

    If the collection already exists but has no index, the index is created
    now. This covers a previous run that failed between creating the
    collection and building its index; without the check, that collection
    would be skipped forever and stay unindexed.

    The "default" connection is always closed before returning.

    Raises:
        Exception: any error from connecting or from pymilvus is logged
            (with traceback) and re-raised unchanged.
    """
    try:
        connect_db()

        if not utility.has_collection(DOCS_INDEX_NAME):
            collection = Collection(name=DOCS_INDEX_NAME, schema=_build_docs_schema())
            _create_embedding_index(collection)
            logger.info(f"已创建collection和索引: {DOCS_INDEX_NAME}")
        else:
            logger.warning(f"Collection {DOCS_INDEX_NAME} 已存在")
            # Attach to the existing collection (schema comes from the server)
            # and repair a missing index left behind by an interrupted run.
            collection = Collection(name=DOCS_INDEX_NAME)
            if not collection.has_index():
                _create_embedding_index(collection)
                logger.info(
                    f"Collection {DOCS_INDEX_NAME} had no index; created index on 'embedding'"
                )

    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"创建collection时出错: {e}")
        raise
    finally:
        connections.disconnect("default")