File size: 3,905 Bytes
f745cab
 
 
 
 
b5d3f34
f745cab
 
 
 
 
 
 
 
b5d3f34
 
f745cab
 
 
b5d3f34
f745cab
b5d3f34
f745cab
b5d3f34
 
 
 
 
 
 
 
 
 
f745cab
b5d3f34
f745cab
b5d3f34
 
 
 
 
 
f745cab
 
 
 
b5d3f34
 
f745cab
 
 
 
b5d3f34
 
f745cab
 
 
 
 
 
 
b5d3f34
 
 
f745cab
 
 
 
 
 
b5d3f34
 
f745cab
 
 
 
b5d3f34
 
f745cab
 
 
 
 
 
b5d3f34
 
8181fcd
 
b5d3f34
 
8181fcd
 
b5d3f34
 
f745cab
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from abc import ABC, abstractmethod
from typing import List, Optional

from llama_index import (
    Document,
    LangchainEmbedding,
    LLMPredictor,
    ServiceContext,
    StorageContext,
)

from core.lifecycle import Lifecycle
from langchain_manager.manager import BaseLangChainManager


class ServiceContextManager(Lifecycle, ABC):
    """Lifecycle-managed provider of a llama_index ``ServiceContext``.

    Concrete subclasses build the context during their lifecycle init
    phase and hand it out via :meth:`get_service_context`.
    """

    @abstractmethod
    def get_service_context(self) -> ServiceContext:
        """Return the ``ServiceContext`` managed by this component."""


class AzureServiceContextManager(ServiceContextManager):
    """ServiceContextManager that builds its context from a LangChain manager.

    The embedding model and LLM are obtained from ``lc_manager`` during
    ``do_init``; until then the context is unavailable and
    :meth:`get_service_context` raises ``ValueError``.
    """

    lc_manager: BaseLangChainManager
    service_context: Optional[ServiceContext]

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager
        # Bug fix: the attribute must exist before do_init() runs, otherwise
        # get_service_context() raised AttributeError instead of the intended
        # ValueError when called too early.
        self.service_context = None

    def get_service_context(self) -> ServiceContext:
        """Return the prepared ServiceContext.

        Raises:
            ValueError: if called before the lifecycle init phase completed.
        """
        if self.service_context is None:
            raise ValueError(
                "service context is not ready, check for lifecycle statement"
            )
        return self.service_context

    def do_init(self) -> None:
        """Build the ServiceContext from the manager's embedding and LLM."""
        # define embedding
        embedding = LangchainEmbedding(self.lc_manager.get_embedding())
        # define LLM
        llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
        # configure service context
        self.service_context = ServiceContext.from_defaults(
            llm_predictor=llm_predictor, embed_model=embedding
        )

    def do_start(self) -> None:
        """Log token usage counters when the component starts."""
        self.logger.info("[do_start][embedding] last used usage: %d",
                         self.service_context.embed_model.total_tokens_used)
        self.logger.info("[do_start][predict] last used usage: %d",
                         self.service_context.llm_predictor.total_tokens_used)

    def do_stop(self) -> None:
        """Log token usage counters when the component stops."""
        self.logger.info("[do_stop][embedding] last used usage: %d",
                         self.service_context.embed_model.total_tokens_used)
        self.logger.info("[do_stop][predict] last used usage: %d",
                         self.service_context.llm_predictor.total_tokens_used)

    def do_dispose(self) -> None:
        """Log the total LLM token usage on disposal."""
        self.logger.info("[do_dispose] total used token: %d",
                         self.service_context.llm_predictor.total_tokens_used)


class StorageContextManager(Lifecycle, ABC):
    """Lifecycle-managed provider of a llama_index ``StorageContext``.

    Subclasses decide where the index storage lives (local disk, remote,
    ...) and expose it through :meth:`get_storage_context`.
    """

    @abstractmethod
    def get_storage_context(self) -> StorageContext:
        """Return the ``StorageContext`` managed by this component."""


class LocalStorageContextManager(StorageContextManager):
    """StorageContextManager persisting the index on the local filesystem.

    On init it either loads an existing persisted index from
    ``dataset_path`` or downloads the documents and builds a fresh index.
    """

    storage_context: Optional[StorageContext]

    def __init__(self,
                 dataset_path: str = "./dataset",
                 service_context_manager: Optional[ServiceContextManager] = None) -> None:
        super().__init__()
        self.dataset_path = dataset_path
        self.service_context_manager = service_context_manager
        # Bug fix: ensure the attribute exists before do_init() runs so that
        # get_storage_context() can report readiness instead of raising
        # AttributeError.
        self.storage_context = None

    def get_storage_context(self) -> StorageContext:
        """Return the prepared StorageContext.

        Raises:
            ValueError: if called before the lifecycle init phase completed
                (consistent with AzureServiceContextManager).
        """
        if self.storage_context is None:
            raise ValueError(
                "storage context is not ready, check for lifecycle statement"
            )
        return self.storage_context

    def do_init(self) -> None:
        """Load a persisted index if present, otherwise download and index."""
        from llama.utils import is_local_storage_files_ready
        if is_local_storage_files_ready(self.dataset_path):
            self.storage_context = StorageContext.from_defaults(persist_dir=self.dataset_path)
        else:
            docs = self._download()
            self._indexing(docs)

    def do_start(self) -> None:
        # self.logger.info("[do_start]%", **self.storage_context.to_dict())
        pass

    def do_stop(self) -> None:
        # self.logger.info("[do_stop]%", **self.storage_context.to_dict())
        pass

    def do_dispose(self) -> None:
        """Persist the storage context to the dataset directory on disposal."""
        self.storage_context.persist(self.dataset_path)

    def _download(self) -> List[Document]:
        """Fetch the source documents from GitHub."""
        # Fixed annotation: `[Document]` was a list literal, not a type.
        from llama.data_loader import GithubLoader
        loader = GithubLoader()
        return loader.load()

    def _indexing(self, docs: List[Document]) -> None:
        """Build a vector index over *docs* and persist it locally."""
        from llama_index import GPTVectorStoreIndex
        index = GPTVectorStoreIndex.from_documents(
            docs,
            service_context=self.service_context_manager.get_service_context(),
        )
        index.storage_context.persist(persist_dir=self.dataset_path)
        self.storage_context = index.storage_context