wang-run committed on
Commit
ceb607c
·
verified ·
1 Parent(s): 83eb413

Update zhipuai_LLM.py

Browse files
Files changed (1) hide show
  1. zhipuai_LLM.py +164 -161
zhipuai_LLM.py CHANGED
@@ -1,161 +1,164 @@
1
- from typing import Any, Dict, Iterator, List, Optional, Union
2
- import os
3
- import time
4
- from zhipuai import ZhipuAI
5
- from langchain_core.callbacks import CallbackManagerForLLMRun
6
- from langchain_core.language_models import BaseChatModel
7
- from langchain_core.messages import (
8
- AIMessage,
9
- AIMessageChunk,
10
- BaseMessage,
11
- SystemMessage,
12
- ChatMessage,
13
- HumanMessage
14
- )
15
- from langchain_core.messages.ai import UsageMetadata
16
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
17
-
18
- def _convert_message_to_dict(message: Union[BaseMessage, dict, tuple]) -> dict:
19
- role = "user"
20
- content = ""
21
-
22
- if isinstance(message, tuple) and len(message) == 2:
23
- msg_type, content = message
24
- if msg_type == "system":
25
- role = "system"
26
- elif msg_type in ["ai", "assistant"]:
27
- role = "assistant"
28
- else:
29
- role = "user"
30
-
31
- elif isinstance(message, dict):
32
- msg_type = message.get("role", "user")
33
- content = message.get("content", "")
34
- if msg_type == "system":
35
- role = "system"
36
- elif msg_type in ["ai", "assistant"]:
37
- role = "assistant"
38
- else:
39
- role = "user"
40
-
41
- elif isinstance(message, BaseMessage):
42
- content = message.content
43
- if isinstance(message, ChatMessage):
44
- role = message.role
45
- elif isinstance(message, HumanMessage):
46
- role = "user"
47
- elif isinstance(message, AIMessage):
48
- role = "assistant"
49
- elif isinstance(message, SystemMessage):
50
- role = "system"
51
- else:
52
- role = "user"
53
- else:
54
- content = str(message)
55
-
56
- return {"role": role, "content": content}
57
-
58
- class ZhipuaiLLM(BaseChatModel):
59
- model_name: str = "glm-4-flash"
60
- temperature: Optional[float] = 0.1
61
- max_tokens: Optional[int] = None
62
- timeout: Optional[int] = None
63
- stop: Optional[List[str]] = None
64
- max_retries: int = 3
65
- api_key: str | None = None
66
-
67
- def _get_client(self) -> ZhipuAI:
68
- current_api_key = self.api_key or os.environ.get("ZHIPUAI_API_KEY")
69
- return ZhipuAI(api_key=current_api_key)
70
-
71
- def _generate(
72
- self,
73
- messages: List[Any],
74
- stop: Optional[List[str]] = None,
75
- run_manager: Optional[CallbackManagerForLLMRun] = None,
76
- **kwargs: Any,
77
- ) -> ChatResult:
78
- zhipu_messages = [_convert_message_to_dict(message) for message in messages]
79
- start_time = time.time()
80
-
81
- client = self._get_client()
82
- response = client.chat.completions.create(
83
- model=self.model_name,
84
- temperature=self.temperature,
85
- max_tokens=self.max_tokens,
86
- timeout=self.timeout,
87
- stop=stop,
88
- messages=zhipu_messages,
89
- **kwargs
90
- )
91
-
92
- time_in_seconds = time.time() - start_time
93
- message = AIMessage(
94
- content=response.choices[0].message.content,
95
- additional_kwargs={},
96
- response_metadata={"time_in_seconds": round(time_in_seconds, 3)},
97
- usage_metadata={
98
- "input_tokens": response.usage.prompt_tokens,
99
- "output_tokens": response.usage.completion_tokens,
100
- "total_tokens": response.usage.total_tokens,
101
- },
102
- )
103
- return ChatResult(generations=[ChatGeneration(message=message)])
104
-
105
- def _stream(
106
- self,
107
- messages: List[Any],
108
- stop: Optional[List[str]] = None,
109
- run_manager: Optional[CallbackManagerForLLMRun] = None,
110
- **kwargs: Any,
111
- ) -> Iterator[ChatGenerationChunk]:
112
- zhipu_messages = [_convert_message_to_dict(message) for message in messages]
113
- start_time = time.time()
114
-
115
- client = self._get_client()
116
- response = client.chat.completions.create(
117
- model=self.model_name,
118
- stream=True,
119
- temperature=self.temperature,
120
- max_tokens=self.max_tokens,
121
- timeout=self.timeout,
122
- stop=stop,
123
- messages=zhipu_messages,
124
- **kwargs
125
- )
126
-
127
- usage_metadata = None
128
- for res in response:
129
- if hasattr(res, 'usage') and res.usage:
130
- usage_metadata = UsageMetadata({
131
- "input_tokens": getattr(res.usage, 'prompt_tokens', 0),
132
- "output_tokens": getattr(res.usage, 'completion_tokens', 0),
133
- "total_tokens": getattr(res.usage, 'total_tokens', 0),
134
- })
135
-
136
- chunk_content = res.choices[0].delta.content if res.choices and res.choices[0].delta.content else ""
137
- chunk = ChatGenerationChunk(message=AIMessageChunk(content=chunk_content))
138
-
139
- if run_manager and chunk_content:
140
- run_manager.on_llm_new_token(chunk_content, chunk=chunk)
141
-
142
- yield chunk
143
-
144
- time_in_sec = time.time() - start_time
145
- final_chunk = ChatGenerationChunk(
146
- message=AIMessageChunk(
147
- content="",
148
- response_metadata={"time_in_sec": round(time_in_sec, 3)},
149
- usage_metadata=usage_metadata
150
- )
151
- )
152
- if run_manager:
153
- run_manager.on_llm_new_token("", chunk=final_chunk)
154
- yield final_chunk
155
-
156
- @property
157
- def _llm_type(self) -> str:
158
- return self.model_name
159
-
160
- @property
161
- def _identifying_params(self) -> Dict[str,
 
 
 
 
1
+ from typing import Any, Dict, Iterator, List, Optional, Union
2
+ import os
3
+ import time
4
+ from zhipuai import ZhipuAI
5
+ from langchain_core.callbacks import CallbackManagerForLLMRun
6
+ from langchain_core.language_models import BaseChatModel
7
+ from langchain_core.messages import (
8
+ AIMessage,
9
+ AIMessageChunk,
10
+ BaseMessage,
11
+ SystemMessage,
12
+ ChatMessage,
13
+ HumanMessage
14
+ )
15
+ from langchain_core.messages.ai import UsageMetadata
16
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
17
+
18
+ def _convert_message_to_dict(message: Union[BaseMessage, dict, tuple]) -> dict:
19
+ role = "user"
20
+ content = ""
21
+
22
+ if isinstance(message, tuple) and len(message) == 2:
23
+ msg_type, content = message
24
+ if msg_type == "system":
25
+ role = "system"
26
+ elif msg_type in ["ai", "assistant"]:
27
+ role = "assistant"
28
+ else:
29
+ role = "user"
30
+
31
+ elif isinstance(message, dict):
32
+ msg_type = message.get("role", "user")
33
+ content = message.get("content", "")
34
+ if msg_type == "system":
35
+ role = "system"
36
+ elif msg_type in ["ai", "assistant"]:
37
+ role = "assistant"
38
+ else:
39
+ role = "user"
40
+
41
+ elif isinstance(message, BaseMessage):
42
+ content = message.content
43
+ if isinstance(message, ChatMessage):
44
+ role = message.role
45
+ elif isinstance(message, HumanMessage):
46
+ role = "user"
47
+ elif isinstance(message, AIMessage):
48
+ role = "assistant"
49
+ elif isinstance(message, SystemMessage):
50
+ role = "system"
51
+ else:
52
+ role = "user"
53
+ else:
54
+ content = str(message)
55
+
56
+ return {"role": role, "content": content}
57
+
58
+ class ZhipuaiLLM(BaseChatModel):
59
+ model_name: str = "glm-4-flash"
60
+ temperature: Optional[float] = 0.1
61
+ max_tokens: Optional[int] = None
62
+ timeout: Optional[int] = None
63
+ stop: Optional[List[str]] = None
64
+ max_retries: int = 3
65
+ api_key: str | None = None
66
+
67
+ def _get_client(self) -> ZhipuAI:
68
+ current_api_key = self.api_key or os.environ.get("ZHIPUAI_API_KEY")
69
+ return ZhipuAI(api_key=current_api_key)
70
+
71
+ def _generate(
72
+ self,
73
+ messages: List[Any],
74
+ stop: Optional[List[str]] = None,
75
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
76
+ **kwargs: Any,
77
+ ) -> ChatResult:
78
+ zhipu_messages = [_convert_message_to_dict(message) for message in messages]
79
+ start_time = time.time()
80
+
81
+ client = self._get_client()
82
+ response = client.chat.completions.create(
83
+ model=self.model_name,
84
+ temperature=self.temperature,
85
+ max_tokens=self.max_tokens,
86
+ timeout=self.timeout,
87
+ stop=stop,
88
+ messages=zhipu_messages,
89
+ **kwargs
90
+ )
91
+
92
+ time_in_seconds = time.time() - start_time
93
+ message = AIMessage(
94
+ content=response.choices[0].message.content,
95
+ additional_kwargs={},
96
+ response_metadata={"time_in_seconds": round(time_in_seconds, 3)},
97
+ usage_metadata={
98
+ "input_tokens": response.usage.prompt_tokens,
99
+ "output_tokens": response.usage.completion_tokens,
100
+ "total_tokens": response.usage.total_tokens,
101
+ },
102
+ )
103
+ return ChatResult(generations=[ChatGeneration(message=message)])
104
+
105
+ def _stream(
106
+ self,
107
+ messages: List[Any],
108
+ stop: Optional[List[str]] = None,
109
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
110
+ **kwargs: Any,
111
+ ) -> Iterator[ChatGenerationChunk]:
112
+ zhipu_messages = [_convert_message_to_dict(message) for message in messages]
113
+ start_time = time.time()
114
+
115
+ client = self._get_client()
116
+ response = client.chat.completions.create(
117
+ model=self.model_name,
118
+ stream=True,
119
+ temperature=self.temperature,
120
+ max_tokens=self.max_tokens,
121
+ timeout=self.timeout,
122
+ stop=stop,
123
+ messages=zhipu_messages,
124
+ **kwargs
125
+ )
126
+
127
+ usage_metadata = None
128
+ for res in response:
129
+ if hasattr(res, 'usage') and res.usage:
130
+ usage_metadata = UsageMetadata({
131
+ "input_tokens": getattr(res.usage, 'prompt_tokens', 0),
132
+ "output_tokens": getattr(res.usage, 'completion_tokens', 0),
133
+ "total_tokens": getattr(res.usage, 'total_tokens', 0),
134
+ })
135
+
136
+ chunk_content = res.choices[0].delta.content if res.choices and res.choices[0].delta.content else ""
137
+ chunk = ChatGenerationChunk(message=AIMessageChunk(content=chunk_content))
138
+
139
+ if run_manager and chunk_content:
140
+ run_manager.on_llm_new_token(chunk_content, chunk=chunk)
141
+
142
+ yield chunk
143
+
144
+ time_in_sec = time.time() - start_time
145
+ final_chunk = ChatGenerationChunk(
146
+ message=AIMessageChunk(
147
+ content="",
148
+ response_metadata={"time_in_sec": round(time_in_sec, 3)},
149
+ usage_metadata=usage_metadata
150
+ )
151
+ )
152
+ if run_manager:
153
+ run_manager.on_llm_new_token("", chunk=final_chunk)
154
+ yield final_chunk
155
+
156
+ @property
157
+ def _llm_type(self) -> str:
158
+ return self.model_name
159
+
160
+ @property
161
+ def _identifying_params(self) -> Dict[str, Any]:
162
+ return {
163
+ "model_name": self.model_name,
164
+ }