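"""Lifecycle-managed providers of llama_index ServiceContext objects.

Each manager wires an LLM (and an embedding model) obtained from a
BaseLangChainManager into a ServiceContext, and logs token usage at
lifecycle transitions.
"""
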
from abc import abstractmethod, ABC
from typing import Optional

from llama_index import ServiceContext, LLMPredictor, LangchainEmbedding
from core.lifecycle import Lifecycle
from langchain_manager.manager import BaseLangChainManager
# def get_callback_manager() -> CallbackManager:
#     from llama_index.callbacks import (
#         WandbCallbackHandler,
#         CallbackManager,
#         LlamaDebugHandler,
#     )
#
#     llama_debug = LlamaDebugHandler(print_trace_on_end=True)
#     # wandb.init args
#     run_args = dict(
#         project="llamaindex",
#     )
#     wandb_callback = WandbCallbackHandler(run_args=run_args)
#     return CallbackManager([llama_debug, wandb_callback])

class ServiceContextManager(Lifecycle, ABC):
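    """Abstract, lifecycle-aware provider of a llama_index ServiceContext."""
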
    @abstractmethod
    def get_service_context(self) -> ServiceContext:
        """Return the initialized ServiceContext; raise if the lifecycle has not run yet."""

class AzureServiceContextManager(ServiceContextManager):
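    """ServiceContextManager backed by the LLM and embedding model supplied
    by the injected BaseLangChainManager (e.g. Azure OpenAI deployments)."""
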
    lc_manager: BaseLangChainManager
    service_context: Optional[ServiceContext]

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager
        # not ready until do_init() runs; get_service_context() guards on this
        self.service_context = None
def get_service_context(self) -> ServiceContext:
        if self.service_context is None:
            raise ValueError(
                "service context is not ready; ensure do_init() has run via the lifecycle"
            )
return self.service_context
def do_init(self) -> None:
# define embedding
embedding = LangchainEmbedding(self.lc_manager.get_embedding())
# define LLM
llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
# configure service context
self.service_context = ServiceContext.from_defaults(
llm_predictor=llm_predictor,
embed_model=embedding,
# callback_manager=get_callback_manager(),
)
def do_start(self) -> None:
        self.logger.info(
            "[do_start][embedding] tokens used so far: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_start][predict] tokens used so far: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
def do_stop(self) -> None:
        self.logger.info(
            "[do_stop][embedding] tokens used so far: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_stop][predict] tokens used so far: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
def do_dispose(self) -> None:
        self.logger.info(
            "[do_dispose] total tokens used: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
class HuggingFaceChineseOptServiceContextManager(ServiceContextManager):
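    """ServiceContextManager that pairs the managed LLM with a local
    HuggingFace Chinese embedding model (GanymedeNil/text2vec-large-chinese)
    running on CPU."""
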
    lc_manager: BaseLangChainManager
    service_context: Optional[ServiceContext]

    def __init__(self, lc_manager: BaseLangChainManager):
        super().__init__()
        self.lc_manager = lc_manager
        # not ready until do_init() runs; get_service_context() guards on this
        self.service_context = None
def get_service_context(self) -> ServiceContext:
        if self.service_context is None:
            raise ValueError(
                "service context is not ready; ensure do_init() has run via the lifecycle"
            )
return self.service_context
def do_init(self) -> None:
        # define embedding: import lazily so the heavy HuggingFace/torch
        # dependency is only loaded when this manager is actually initialized
        from langchain.embeddings import HuggingFaceEmbeddings

        model_name = "GanymedeNil/text2vec-large-chinese"
        hf_embedding = HuggingFaceEmbeddings(
            model_name=model_name, model_kwargs={"device": "cpu"}
        )
embedding = LangchainEmbedding(hf_embedding)
        # define LLM
        llm_predictor = LLMPredictor(llm=self.lc_manager.get_llm())
# configure service context
self.service_context = ServiceContext.from_defaults(
llm_predictor=llm_predictor,
embed_model=embedding,
# callback_manager=get_callback_manager()
)
def do_start(self) -> None:
        self.logger.info(
            "[do_start][embedding] tokens used so far: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_start][predict] tokens used so far: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
def do_stop(self) -> None:
        self.logger.info(
            "[do_stop][embedding] tokens used so far: %d",
            self.service_context.embed_model.total_tokens_used,
        )
        self.logger.info(
            "[do_stop][predict] tokens used so far: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
def do_dispose(self) -> None:
        self.logger.info(
            "[do_dispose] total tokens used: %d",
            self.service_context.llm_predictor.total_tokens_used,
        )
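

# A minimal usage sketch (an assumption: it drives the do_* lifecycle hooks
# directly; the actual entry points may be wrapper methods on Lifecycle, and
# `my_langchain_manager` is a hypothetical BaseLangChainManager instance):
#
#     manager = AzureServiceContextManager(lc_manager=my_langchain_manager)
#     manager.do_init()     # builds the ServiceContext
#     manager.do_start()    # logs token usage at startup
#     service_context = manager.get_service_context()
#     ...                   # build llama_index indices / query engines with it
#     manager.do_stop()
#     manager.do_dispose()  # logs total token usage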