gerglitzen committed on
Commit
0ab9543
·
1 Parent(s): 47c7f9e
Files changed (1) hide show
  1. call_openai.py +231 -0
call_openai.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import json
3
+ from typing import List, Dict
4
+ from .callback_handler import BaseCallbackHandler
5
+ import tiktoken
6
+
7
def call_openai(
    messages: List[Dict[str, str]],
    functions: List[Dict] = None,
    stream: str = "no",
    model: str = "gpt-3.5-turbo",
    temperature: float = 0,
    callback: BaseCallbackHandler = None
) -> Dict:
    """
    Call the OpenAI chat completion endpoint and assemble the streamed reply.

    The request is always made with ``stream=True``; the chunks are consumed
    here and merged into a single response dictionary. See the description of
    the chat completion API at the openai website.

    Args:
        messages: messages passed to openai. List of dictionaries with keys
            role=[system, user, assistant, function] + content=message
        functions: function list passed to openai. ``None`` or an empty list
            means the request is sent without functions.
        stream: ["no", "sentence", "token"]. Controls how partial text is
            forwarded to ``callback.on_llm_new_token``: not at all, once per
            sentence-like fragment, or once per received chunk.
        model: name of openai model
        temperature: of openai model
        callback: callback handler class. Needed to receive streamed text;
            when it is missing, partial text is simply not forwarded.

    Returns:
        A dictionary with ``usage`` (estimated ``completion_tokens`` and
        ``prompt_tokens``) and ``choices`` (at most one entry holding either
        the assistant text or the function call). ``None`` if the stream ends
        without any chunk carrying a finish reason.
    """
    # The API rejects an empty ``functions`` array, so treat it as "no functions".
    if not functions:
        functions = None

    # NOTE(review): the prompt size is estimated with token_count's default
    # model, not with ``model`` -- confirm whether that is intended.
    prompt_tokens = token_count(
        messages=messages,
        functions=functions
    )

    request = {
        "model": model,
        "temperature": temperature,
        "stream": True,
        "messages": messages,
    }
    if functions is None:
        # Start offset of the chunk-based completion token estimate.
        completion_tokens = -2
    else:
        completion_tokens = -1
        request["functions"] = functions
    response = openai.ChatCompletion.create(**request)

    # Parser state: None until the first chunk shows whether the model answers
    # with plain text ("user") or with a function call ("function").
    current_state = None
    function_name = None
    function_arg = ""
    message_all = ""     # complete assistant text
    message_stream = ""  # text collected since the last sentence flush

    for chunk in response:
        completion_tokens += 1
        choice = chunk["choices"][0]
        delta = choice.get("delta") or {}
        finish_reason = choice["finish_reason"]

        if finish_reason is not None:
            if finish_reason == "function_call":
                completion_tokens += 6

            # Forward text that followed the last sentence delimiter; without
            # this the tail of the answer never reaches a "sentence" listener.
            if (
                current_state == "user"
                and stream == "sentence"
                and message_stream
                and callback is not None
            ):
                callback.on_llm_new_token(token=message_stream)
                message_stream = ""

            final_response = {
                "usage": {
                    "completion_tokens": completion_tokens,
                    "prompt_tokens": prompt_tokens,
                },
                "choices": []
            }

            if current_state == "function":
                final_response["choices"].append({
                    "finish_reason": "function_call",
                    "message": {
                        "content": None,
                        "function_call": {
                            "arguments": function_arg,
                            "name": function_name
                        },
                        "role": "assistant"
                    }
                })
            elif current_state == "user":
                final_response["choices"].append({
                    "finish_reason": "stop",
                    "message": {
                        "content": message_all,
                        "role": "assistant"
                    }
                })

            if callback is not None:
                callback.on_llm_end(response=final_response)
            return final_response

        if current_state is None:
            # The first chunk decides which kind of answer is being streamed.
            if "function_call" in delta:
                current_state = "function"
                call = delta["function_call"]
                function_name = call["name"]
                function_arg = call.get("arguments") or ""
                continue
            current_state = "user"
            if not delta.get("content"):
                # Opening chunk without text: nothing to record yet.
                continue

        if current_state == "function":
            call = delta.get("function_call") or {}
            function_arg += call.get("arguments") or ""
            continue

        token = delta.get("content")
        if token is None:
            # Chunk without text; nothing to append or forward.
            continue
        message_all += token

        if callback is None:
            continue
        if stream == "token":
            callback.on_llm_new_token(token=token)
        elif stream == "sentence":
            message_stream += token
            if any(mark in token for mark in (".", "!", "?", "\n")):
                # A trailing newline is not forwarded with the fragment.
                if message_stream.endswith("\n"):
                    fragment = message_stream[:-1]
                else:
                    fragment = message_stream
                callback.on_llm_new_token(token=fragment)
                message_stream = ""

    # The stream ended without a finish reason.
    return None
130
+
131
+
132
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages.

    Args:
        messages: iterable of message dictionaries. String values are
            encoded and counted; a ``function_call`` entry is expected to be
            a mapping whose values are strings. Entries whose value is
            ``None`` are ignored.
        model: model name. Unpinned names containing "gpt-3.5-turbo" or
            "gpt-4" are counted as their ``-0613`` snapshot.

    Returns:
        The token count, including the per-message overhead and the 3 tokens
        that prime the reply.

    Raises:
        NotImplementedError: if ``model`` is not a supported chat model.
    """
    # Resolve the per-message constants first so that an unsupported model
    # fails before any encoding is loaded.
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        # gpt-3.5-turbo may update over time; count as gpt-3.5-turbo-0613.
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        # gpt-4 may update over time; count as gpt-4-0613.
        return num_tokens_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to tiktoken: fall back to the cl100k_base encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if value is None:
                # e.g. "content": None on a function-call message
                continue
            if key == "function_call":
                num_tokens += tokens_per_name
                num_tokens += sum(
                    len(encoding.encode(part))
                    for part in value.values()
                    if part is not None
                )
                continue
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
178
+
179
def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of functions.

    Args:
        functions: iterable of function specification dictionaries. ``name``
            is required; ``description`` and ``parameters`` are optional.
            Within ``parameters["properties"]`` only the ``type``,
            ``description`` and ``enum`` fields of a property are counted;
            any other field is ignored.
        model: model name used to select the tiktoken encoding.

    Returns:
        The sum of the encoded lengths of the counted strings plus the fixed
        overheads applied below.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Model unknown to tiktoken: fall back to the cl100k_base encoding.
        encoding = tiktoken.get_encoding("cl100k_base")

    def encoded_length(text):
        return len(encoding.encode(text))

    num_tokens = 0
    for function in functions:
        function_tokens = encoded_length(function["name"])
        # A function spec may come without a description.
        function_tokens += encoded_length(function.get("description") or "")

        parameters = function.get("parameters") or {}
        if "properties" in parameters:
            for property_name, spec in parameters["properties"].items():
                function_tokens += encoded_length(property_name)
                for field in spec:
                    if field == "type":
                        function_tokens += 2
                        function_tokens += encoded_length(spec["type"])
                    elif field == "description":
                        function_tokens += 2
                        function_tokens += encoded_length(spec["description"])
                    elif field == "enum":
                        function_tokens -= 3
                        for option in spec["enum"]:
                            function_tokens += 3
                            # Enum options are not necessarily strings.
                            function_tokens += encoded_length(str(option))
                    # Any other field is not counted.
            # NOTE(review): fixed overhead applied once per function that
            # declares properties -- confirm this nesting against the
            # original file, the indentation was not recoverable here.
            function_tokens += 16

        num_tokens += function_tokens

    num_tokens += 16
    return num_tokens
219
+
220
def token_count(
    messages: List[Dict[str, str]],
    functions: List[str] = None,
    model = "gpt-3.5-turbo-0613"
) -> int:
    """Return the token count of the messages plus, if given, the functions.

    Args:
        messages: messages handed to ``num_tokens_from_messages``.
        functions: function list handed to ``num_tokens_from_functions``;
            skipped entirely when ``None``.
        model: model name forwarded to both counting helpers.

    Returns:
        Message tokens, increased by the function tokens when ``functions``
        is not ``None``.
    """
    total = num_tokens_from_messages(messages=messages, model=model)
    if functions is None:
        return total
    return total + num_tokens_from_functions(functions=functions, model=model)