File size: 3,410 Bytes
cf450f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging
import os
from pathlib import Path

from kaig.db import DB
from kaig.definitions import VectorTableDefinition
from kaig.embeddings import Embedder
from kaig.llm import LLM

from .definitions import EdgeTypes, Tables

logger = logging.getLogger(__name__)


# Default local embeddings model, used whenever KG_LOCAL_EMBEDDINGS_MODEL is unset.
_DEFAULT_LOCAL_EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L6-v2"


def _build_llm() -> LLM:
    """Create the LLM configured via environment variables.

    ``KG_LLM_MODEL`` selects the primary model (default ``"alias-fast"``).
    ``KG_LLM_FALLBACK_MODELS`` is an optional comma-separated list of
    fallback models; when unset, a fallback is picked so that the primary
    and the fallback are never the same model.
    """
    llm_model = os.getenv("KG_LLM_MODEL", "alias-fast")
    fallback_env = os.getenv("KG_LLM_FALLBACK_MODELS")
    if fallback_env:
        fallback_models = [
            x.strip() for x in fallback_env.split(",") if x.strip()
        ]
    elif llm_model != "alias-fast":
        fallback_models = ["alias-fast"]
    else:
        fallback_models = ["alias-large"]
    return LLM(
        provider="openai",
        model=llm_model,
        temperature=1,
        fallback_models=fallback_models,
    )


def _local_embedder() -> Embedder:
    """Create a local sentence-transformers embedder (no API key required)."""
    return Embedder(
        provider="sentence-transformers",
        model_name=os.getenv(
            "KG_LOCAL_EMBEDDINGS_MODEL",
            _DEFAULT_LOCAL_EMBEDDINGS_MODEL,
        ),
        vector_type="F32",
    )


def _build_embedder() -> Embedder:
    """Create the embedder selected by ``KG_EMBEDDINGS_PROVIDER``.

    ``"sentence-transformers"``, ``"local"`` and ``"hf"`` all map to the
    local provider; any other value is treated as OpenAI, falling back to
    the local provider if OpenAI initialization fails.
    """
    embedder_provider = os.getenv(
        "KG_EMBEDDINGS_PROVIDER", "sentence-transformers"
    ).lower()
    if embedder_provider in {"sentence-transformers", "local", "hf"}:
        return _local_embedder()
    try:
        return Embedder(
            provider="openai",
            model_name=os.getenv("KG_EMBEDDINGS_MODEL", "alias-embeddings"),
            vector_type="F32",
        )
    except Exception as exc:
        # Best-effort: a missing/invalid OpenAI setup must not block startup.
        logger.warning(
            "Embeddings init failed (%s). Falling back to local embeddings.",
            exc,
        )
        return _local_embedder()


def init_db(init_llm: bool, db_name: str, init_indexes: bool = True) -> DB:
    """Connect to the database and initialize tables, vectors and schema.

    Args:
        init_llm: when True, also construct the LLM (see ``_build_llm``);
            otherwise the DB is created without one.
        db_name: name of the database to connect to (namespace is ``kaig``).
        init_indexes: forwarded to ``DB.init_db`` as ``force``.

    Returns:
        A fully initialized ``DB`` handle with the SurrealQL schema applied.
    """
    tables = [Tables.document.value, Tables.concept.value, Tables.page.value]
    vector_tables = [
        VectorTableDefinition(Tables.chunk.value, "HNSW", "COSINE"),
        VectorTableDefinition(Tables.concept.value, "HNSW", "COSINE"),
    ]

    if init_llm:
        logger.info("Init LLM...")
        llm = _build_llm()
    else:
        logger.info("Init without LLM")
        llm = None

    embedder = _build_embedder()

    # -- DB connection (credentials default to the local dev instance)
    url = os.getenv("KG_DB_URL", "ws://localhost:8000/rpc")
    db_user = "root"
    db_pass = "root"
    db_ns = "kaig"
    db = DB(
        url,
        db_user,
        db_pass,
        db_ns,
        db_name,
        embedder,
        llm,
        tables=tables,
        original_docs_table="document",
        vector_tables=vector_tables,
        graph_relations=[EdgeTypes.MENTIONS_CONCEPT.value],
    )
    if llm:
        llm.set_analytics(db.insert_analytics_data)

    # Remove this if you don't want to clear all your tables on every run
    # db.clear()

    # Apply the SurrealQL schema files shipped alongside the package.
    surql_dir = Path(__file__).parent.parent.parent / "surql"
    for filename in ["schema.surql"]:
        # Explicit encoding: file content must not depend on the locale.
        surql = (surql_dir / filename).read_text(encoding="utf-8")
        _ = db.sync_conn.query(surql)
    db.init_db(force=init_indexes)

    return db