File size: 4,046 Bytes
4ccf537
 
 
 
 
b5d3f34
4ccf537
 
 
 
 
 
 
 
b5d3f34
 
4ccf537
 
 
b5d3f34
4ccf537
b5d3f34
4ccf537
b5d3f34
 
 
 
 
 
 
 
 
 
 
 
 
 
4ccf537
b5d3f34
4ccf537
b5d3f34
 
 
 
 
 
4ccf537
 
 
 
b5d3f34
 
4ccf537
 
 
 
b5d3f34
 
4ccf537
 
 
 
 
 
 
b5d3f34
 
 
4ccf537
 
 
 
 
 
b5d3f34
 
4ccf537
 
 
 
b5d3f34
 
4ccf537
 
 
 
 
 
b5d3f34
 
4ccf537
b5d3f34
 
4ccf537
b5d3f34
 
4ccf537
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from abc import ABC, abstractmethod
from typing import List, Optional

from llama_index import (
    Document,
    LangchainEmbedding,
    LLMPredictor,
    ServiceContext,
    StorageContext,
)

from core.lifecycle import Lifecycle
from langchain_manager.manager import BaseLangChainManager


class ServiceContextManager(Lifecycle, ABC):
    """Lifecycle-managed provider of a llama-index ``ServiceContext``."""

    @abstractmethod
    def get_service_context(self) -> ServiceContext:
        """Return the managed ``ServiceContext`` once it is available."""
        ...


class AzureServiceContextManager(ServiceContextManager):
    """ServiceContextManager whose LLM and embedding model come from a
    ``BaseLangChainManager`` (e.g. Azure-hosted models).

    Lifecycle contract: ``do_init`` builds the ``ServiceContext``;
    ``get_service_context`` is only valid after the manager has started.
    """

    # Injected provider of the LangChain LLM / embedding objects.
    lc_manager: BaseLangChainManager
    # Built lazily in do_init(); unset before that (class-level annotation only).
    service_context: ServiceContext

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager

    def get_service_context(self) -> ServiceContext:
        """Return the ready ``ServiceContext``.

        Raises:
            KeyError: if the manager has not been started yet.
            ValueError: if ``do_init`` has not built the context.
        """
        # BUG FIX: the original raised when the lifecycle *was* started,
        # which made this accessor unusable in the only valid state.
        if not self.lifecycle_state.is_started():
            raise KeyError(
                "incorrect lifecycle state: {}".format(self.lifecycle_state.phase)
            )
        # BUG FIX: `service_context` is only a class-level annotation until
        # do_init() assigns it, so a plain attribute read would raise
        # AttributeError instead of the intended ValueError.
        if getattr(self, "service_context", None) is None:
            raise ValueError(
                "service context is not ready, check for lifecycle statement"
            )
        return self.service_context

    def do_init(self) -> None:
        """Build the ServiceContext from the injected LangChain models."""
        # Wrap the LangChain embedding so llama-index can use it.
        embedding = LangchainEmbedding(self.lc_manager.get_embedding())
        # Wrap the LangChain LLM as a llama-index predictor.
        llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
        self.service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor, embed_model=embedding
        )

    def do_start(self) -> None:
        # Log token usage counters at start (lazy %-style args, no f-strings).
        self.logger.info("[do_start][embedding] last used usage: %d",
                         self.service_context.embed_model.total_tokens_used)
        self.logger.info("[do_start][predict] last used usage: %d",
                         self.service_context.llm_predictor.total_tokens_used)

    def do_stop(self) -> None:
        # Log token usage counters at stop for bookkeeping.
        self.logger.info("[do_stop][embedding] last used usage: %d",
                         self.service_context.embed_model.total_tokens_used)
        self.logger.info("[do_stop][predict] last used usage: %d",
                         self.service_context.llm_predictor.total_tokens_used)

    def do_dispose(self) -> None:
        # Final token accounting before the context is discarded.
        self.logger.info("[do_dispose] total used token: %d",
                         self.service_context.llm_predictor.total_tokens_used)


class StorageContextManager(Lifecycle, ABC):
    """Lifecycle-managed provider of a llama-index ``StorageContext``."""

    @abstractmethod
    def get_storage_context(self) -> StorageContext:
        """Return the managed ``StorageContext`` once it is available."""
        ...


class LocalStorageContextManager(StorageContextManager):
    """StorageContextManager persisted on the local filesystem.

    On init it either loads an existing index from ``dataset_path`` or
    downloads documents and builds (and persists) a fresh index using the
    injected ``ServiceContextManager``.
    """

    # Built in do_init(); unset before that (class-level annotation only).
    storage_context: StorageContext

    def __init__(self,
                 dataset_path: str = "./dataset",
                 service_context_manager: Optional[ServiceContextManager] = None) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.service_context_manager = service_context_manager

    def get_storage_context(self) -> StorageContext:
        """Return the managed ``StorageContext`` (valid after do_init)."""
        return self.storage_context

    def do_init(self) -> None:
        """Load the persisted index if present, otherwise build it."""
        # Local import keeps module import light; helper lives in this project.
        from llama.utils import is_local_storage_files_ready
        if is_local_storage_files_ready(self.dataset_path):
            self.storage_context = StorageContext.from_defaults(persist_dir=self.dataset_path)
        else:
            docs = self._download()
            self._indexing(docs)

    def do_start(self) -> None:
        # BUG FIX: the original passed the dict via ** (unexpected keyword
        # args to Logger.info) and used a bare "%" format — both raise at
        # runtime. Log the dict as a single lazy %s argument instead.
        self.logger.info("[do_start]%s", self.storage_context.to_dict())

    def do_stop(self) -> None:
        # BUG FIX: same broken **kwargs / bare-"%" logging call as do_start.
        self.logger.info("[do_stop]%s", self.storage_context.to_dict())

    def do_dispose(self) -> None:
        # Flush the index to disk so the next do_init can reload it.
        self.storage_context.persist(self.dataset_path)

    def _download(self) -> List[Document]:
        """Fetch the source documents to index (from GitHub)."""
        from llama.data_loader import GithubLoader
        loader = GithubLoader()
        return loader.load()

    def _indexing(self, docs: List[Document]) -> None:
        """Build a vector index over *docs*, persist it, and adopt its storage."""
        from llama_index import GPTVectorStoreIndex
        index = GPTVectorStoreIndex.from_documents(
            docs,
            service_context=self.service_context_manager.get_service_context(),
        )
        index.storage_context.persist(persist_dir=self.dataset_path)
        self.storage_context = index.storage_context