| | import os |
| | from string import Template |
| | from typing import Dict, List, Union |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from zhipuai import ZhipuAI |
| |
|
| |
|
class GLM:
    """Local chat wrapper around a causal LM loaded through transformers.

    The model is loaded with ``trust_remote_code=True`` and its ``chat`` method
    comes from the model repository's own code. NOTE(review): the call pattern
    below assumes a ChatGLM3-style ``chat(tokenizer, query, history=...)``
    returning ``(response, history)`` — confirm against the model repo.
    """

    def __init__(self, model_name="silk-road/Haruhi-Zero-GLM3-6B-0_4"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Hugging Face Hub id (or local path) of the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        # device_map="auto" lets transformers decide device placement of weights.
        client = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=True, device_map="auto"
        )
        # Inference only: put the model in eval mode.
        self.client = client.eval()

    def message2query(self, messages) -> str:
        """Render chat messages into a single ``<|role|>``-tagged prompt string.

        Each message becomes ``<|role|>\\ncontent\\n``; results are concatenated
        in order.

        Args:
            messages: Iterable of mappings providing ``"role"`` and
                ``"content"`` keys.

        Returns:
            The concatenated prompt (``""`` for an empty iterable).

        Raises:
            KeyError: if a message lacks ``"role"`` or ``"content"``.
        """
        template = Template("<|$role|>\n$content\n")
        return "".join([template.substitute(message) for message in messages])

    def get_response(
        self,
        message: Union[str, list[dict[str, str]]],
        history: List[Dict[str, str]] = None,
    ):
        """Generate one reply from the local model.

        Args:
            message: Either the user query as a string, or a whole conversation
                as a list of ``{"role": ..., "content": ...}`` dicts whose last
                entry is the query and whose earlier entries are the history.
            history: Prior turns, used only when ``message`` is a string. When
                ``message`` is a list, the list itself supplies the history and
                this argument is ignored.

        Returns:
            The model's reply text (also printed to stdout).

        Raises:
            ValueError: if ``message`` is an empty list.
            TypeError: if ``message`` is neither a string nor a list.
        """
        if isinstance(message, str):
            # Forward the caller's prior turns (previously dropped). Pass a copy:
            # the remote chat() may append to the history list it is given.
            extra = {} if history is None else {"history": list(history)}
            response, _ = self.client.chat(self.tokenizer, message, **extra)
        elif isinstance(message, list):
            if not message:
                raise ValueError("message list must contain at least one turn")
            # message[:-1] is a fresh list, so the caller's list is not mutated.
            response, _ = self.client.chat(
                self.tokenizer, message[-1]["content"], history=message[:-1]
            )
        else:
            raise TypeError(
                "message must be a str or a list of dicts, got "
                f"{type(message).__name__}"
            )

        print(response)
        return response
| |
|
| |
|
class GLM_api:
    """Thin wrapper around the ZhipuAI hosted chat-completions API."""

    def __init__(self, model_name="glm-4"):
        """Create the ZhipuAI client.

        The API key is read from the ``ZHIPU_API_KEY`` environment variable. If
        it is unset, ``None`` is passed to ``ZhipuAI`` — NOTE(review): confirm
        how the SDK handles a missing key.

        Args:
            model_name: Model identifier sent with every request.
        """
        api_key = os.environ.get("ZHIPU_API_KEY")
        self.client = ZhipuAI(api_key=api_key)
        self.model = model_name

    def chat(self, message):
        """Send a conversation to the API and return the reply text.

        Args:
            message: Passed as the ``messages`` argument of
                ``chat.completions.create``; presumably a list of
                ``{"role": ..., "content": ...}`` dicts.

        Returns:
            The content of the first choice, or the fallback string
            ``"模型连接失败"`` ("model connection failed") if the request raised.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model, messages=message
            )
        except Exception as e:
            # Best-effort boundary: report the error and return a fixed
            # fallback reply instead of propagating.
            print(e)
            return "模型连接失败"

        return response.choices[0].message.content
| |
|