File size: 4,725 Bytes
257bc0d
 
bd59653
257bc0d
175a385
257bc0d
 
 
bd59653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257bc0d
 
 
 
175a385
 
257bc0d
 
 
374292c
257bc0d
175a385
257bc0d
175a385
 
374292c
 
 
6c20719
175a385
 
374292c
175a385
257bc0d
175a385
257bc0d
175a385
6c20719
bd59653
 
 
6c20719
175a385
374292c
01d6691
 
 
 
 
 
 
 
175a385
374292c
01d6691
 
 
 
 
 
 
5ea412d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd59653
 
 
5ea412d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d6691
175a385
374292c
01d6691
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from abc import ABC, abstractmethod
from typing import Optional

from llama_index import LangchainEmbedding, LLMPredictor, ServiceContext

from core.lifecycle import Lifecycle
from langchain_manager.manager import BaseLangChainManager


# def get_callback_manager() -> CallbackManager:
#     from llama_index.callbacks import (
#         WandbCallbackHandler,
#         CallbackManager,
#         LlamaDebugHandler,
#     )
#     llama_debug = LlamaDebugHandler(print_trace_on_end=True)
#     # wandb.init args
#     run_args = dict(
#         project="llamaindex",
#     )
#     wandb_callback = WandbCallbackHandler(run_args=run_args)
#     return CallbackManager([llama_debug, wandb_callback])


class ServiceContextManager(Lifecycle, ABC):
    """Lifecycle-aware provider of a llama_index ``ServiceContext``.

    Concrete subclasses build the context during their init phase and
    expose it through :meth:`get_service_context`.
    """

    @abstractmethod
    def get_service_context(self) -> ServiceContext:
        """Return the managed ``ServiceContext`` once it is ready."""
        ...


class AzureServiceContextManager(ServiceContextManager):
    """ServiceContextManager that wires the LLM and embedding model supplied
    by a ``BaseLangChainManager`` (Azure-backed) into a ``ServiceContext``.
    """

    lc_manager: BaseLangChainManager
    # None until do_init() has run; get_service_context() checks this.
    service_context: Optional[ServiceContext]

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager
        # Explicitly initialize so get_service_context() raises the intended
        # ValueError instead of an AttributeError when called before do_init()
        # (a bare class-level annotation does not create the attribute).
        self.service_context = None

    def get_service_context(self) -> ServiceContext:
        """Return the ServiceContext built in do_init().

        Raises:
            ValueError: if the lifecycle init phase has not run yet.
        """
        if self.service_context is None:
            raise ValueError(
                "service context is not ready, check for lifecycle statement"
            )
        return self.service_context

    def do_init(self) -> None:
        """Build the ServiceContext from the manager's embedding and LLM."""
        # define embedding
        embedding = LangchainEmbedding(self.lc_manager.get_embedding())
        # define LLM
        llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
        # configure service context
        self.service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            embed_model=embedding,
            # callback_manager=get_callback_manager(),
        )

    def do_start(self) -> None:
        """Log token usage counters at start for observability."""
        self.logger.info(
            "[do_start][embedding] last used usage: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_start][predict] last used usage: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )

    def do_stop(self) -> None:
        """Log token usage counters at stop for observability."""
        self.logger.info(
            "[do_stop][embedding] last used usage: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_stop][predict] last used usage: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )

    def do_dispose(self) -> None:
        """Log the total LLM token usage when the manager is disposed."""
        self.logger.info(
            "[do_dispose] total used token: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )


class HuggingFaceChineseOptServiceContextManager(ServiceContextManager):
    """ServiceContextManager using a local HuggingFace Chinese embedding model
    (text2vec-large-chinese, CPU) with the LLM supplied by the manager.
    """

    lc_manager: BaseLangChainManager
    # None until do_init() has run; get_service_context() checks this.
    service_context: Optional[ServiceContext]

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager
        # Explicitly initialize so get_service_context() raises the intended
        # ValueError instead of an AttributeError when called before do_init()
        # (a bare class-level annotation does not create the attribute).
        self.service_context = None

    def get_service_context(self) -> ServiceContext:
        """Return the ServiceContext built in do_init().

        Raises:
            ValueError: if the lifecycle init phase has not run yet.
        """
        if self.service_context is None:
            raise ValueError(
                "service context is not ready, check for lifecycle statement"
            )
        return self.service_context

    def do_init(self) -> None:
        """Build the ServiceContext with a HuggingFace embedding model."""
        # define embedding; imported lazily so the module loads without
        # langchain's HuggingFace extras installed.
        from langchain.embeddings import HuggingFaceEmbeddings

        model_name = "GanymedeNil/text2vec-large-chinese"
        hf_embedding = HuggingFaceEmbeddings(
            model_name=model_name, model_kwargs={"device": "cpu"}
        )

        embedding = LangchainEmbedding(hf_embedding)
        # define LLM (keyword arg for consistency with the Azure manager)
        llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
        # configure service context
        self.service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor,
            embed_model=embedding,
            # callback_manager=get_callback_manager()
        )

    def do_start(self) -> None:
        """Log token usage counters at start for observability."""
        self.logger.info(
            "[do_start][embedding] last used usage: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_start][predict] last used usage: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )

    def do_stop(self) -> None:
        """Log token usage counters at stop for observability."""
        self.logger.info(
            "[do_stop][embedding] last used usage: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_stop][predict] last used usage: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )

    def do_dispose(self) -> None:
        """Log the total LLM token usage when the manager is disposed."""
        self.logger.info(
            "[do_dispose] total used token: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )