from typing import Any, List, Mapping, Optional

import chatglm_cpp
from langchain import LLMChain, PromptTemplate
from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM

DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"

# Load the quantized GGML model once at module level and share the pipeline
# across calls; the callback manager streams generated tokens to stdout.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)

class ChatGLM(LLM):
    """LangChain LLM wrapper around a local chatglm.cpp pipeline."""

    temperature: float = 0.7
    base_model: str = DEFAULT_MODEL_PATH
    max_length: int = 2048
    verbose: bool = False
    streaming: bool = False
    top_p: float = 0.9
    top_k: int = 0
    max_context_length: int = 512
    threads: int = 0  # 0 lets chatglm.cpp choose its default thread count

    @property
    def _llm_type(self) -> str:
        return "chatglm"
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        if self.verbose:
            print("Prompt: ", prompt)
        history = [prompt]
        if self.streaming:
            # Accumulate the streamed pieces and forward each one to the
            # run manager so callback handlers (e.g. stdout streaming) fire
            # as tokens arrive rather than after generation finishes.
            response = ""
            for piece in pipeline.stream_chat(
                history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
                num_threads=self.threads,
            ):
                response += piece
                if run_manager:
                    run_manager.on_llm_new_token(piece)
        else:
            response = pipeline.chat(
                history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
                num_threads=self.threads,
            )
        return response
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "base_model": self.base_model,
            "max_length": self.max_length,
            "verbose": self.verbose,
            "streaming": self.streaming,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_context_length": self.max_context_length,
            "threads": self.threads,
        }

# The prompt is a classic Chinese riddle: "Xiaoming's mother has two children;
# one is called Daming. {question}" / "What is the other one called?"
template = "小明的妈妈有两个孩子,一个叫大明 {question}"
prompt = PromptTemplate(template=template, input_variables=["question"])
question = "另外一个叫什么?"
llm = ChatGLM(streaming=False, callback_manager=callback_manager)
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run(question))
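
# The run above exercises only the blocking chat() path. A minimal sketch of
# the streaming path, assuming the same model file and the on_llm_new_token
# wiring in _call above: StreamingStdOutCallbackHandler prints each piece as
# it arrives, so the final print() will repeat the completed answer.
streaming_llm = ChatGLM(streaming=True, callback_manager=callback_manager)
streaming_chain = LLMChain(prompt=prompt, llm=streaming_llm)
print(streaming_chain.run(question))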