Files changed (5) hide show
  1. .gitignore +95 -0
  2. modeling_kimi_k25.py +0 -0
  3. requirements.txt +6 -0
  4. tokenization_kimi.py +368 -352
  5. tool_declaration_ts.py +500 -479
.gitignore ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # PyInstaller
28
+ *.manifest
29
+ *.spec
30
+
31
+ # Installer logs
32
+ pip-log.txt
33
+ pip-delete-this-directory.txt
34
+
35
+ # Unit test / coverage reports
36
+ htmlcov/
37
+ .tox/
38
+ .nox/
39
+ .coverage
40
+ .coverage.*
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+ *.cover
45
+ *.py,cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Jupyter Notebook
54
+ .ipynb_checkpoints
55
+
56
+ # IPython
57
+ profile_default/
58
+ ipython_config.py
59
+
60
+ # pyenv
61
+ .python-version
62
+
63
+ # Environment
64
+ .env
65
+ .venv
66
+ env/
67
+ venv/
68
+ ENV/
69
+ env.bak/
70
+ venv.bak/
71
+
72
+ # IDE
73
+ .idea/
74
+ .vscode/
75
+ *.swp
76
+ *.swo
77
+ *~
78
+
79
+ # Model weights and large files
80
+ *.bin
81
+ *.safetensors
82
+ *.gguf
83
+ *.pt
84
+ *.pth
85
+ *.ckpt
86
+ *.h5
87
+ model-*.json
88
+
89
+ # OS files
90
+ .DS_Store
91
+ Thumbs.db
92
+
93
+ # Logs
94
+ *.log
95
+ logs/
modeling_kimi_k25.py CHANGED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.57.1
3
+ tiktoken>=0.5.0
4
+ numpy>=1.24.0
5
+ Pillow>=9.0.0
6
+ pydantic>=2.0.0
tokenization_kimi.py CHANGED
@@ -1,352 +1,368 @@
1
- import os
2
- from collections import OrderedDict
3
- from logging import getLogger
4
- from pathlib import Path
5
- from shutil import copyfile
6
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
7
-
8
- import tiktoken
9
- from tiktoken.load import load_tiktoken_bpe
10
- from tokenizers import AddedToken
11
-
12
- from transformers.convert_slow_tokenizer import bytes_to_unicode
13
- from transformers.tokenization_utils import PreTrainedTokenizer
14
-
15
- from .tool_declaration_ts import encode_tools_to_typescript_style
16
-
17
- logger = getLogger(__name__)
18
- VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
19
-
20
-
21
- class TikTokenTokenizer(PreTrainedTokenizer):
22
- """
23
- Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.
24
-
25
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
26
- this superclass for more information regarding those methods.
27
-
28
- Args:
29
- vocab_file (`str`):
30
- The path to the Tiktoken model file.
31
- bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|begin_of_text|>",`):
32
- The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
33
- eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|end_of_text|>"`):
34
- The end of sequence token.
35
- unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_249|>"`):
36
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
37
- token instead. The second to last item in special_tokens.
38
- pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<|reserved_special_token_250|>"`):
39
- The token used for padding, for example when batching sequences of different lengths.
40
- additional_special_tokens (list of `str`, *optional*):
41
- A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
42
- skipped when decoding if `skip_special_tokens` is set to `True`.
43
- """
44
-
45
- vocab_files_names = VOCAB_FILES_NAMES
46
-
47
- model_input_names = ["input_ids", "attention_mask"]
48
-
49
- special_tokens: Dict[str, int]
50
-
51
- num_reserved_special_tokens = 256
52
-
53
- pat_str = "|".join([
54
- r"""[\p{Han}]+""",
55
- r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
56
- r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
57
- r"""\p{N}{1,3}""",
58
- r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
59
- r"""\s*[\r\n]+""",
60
- r"""\s+(?!\S)""",
61
- r"""\s+""",
62
- ])
63
-
64
- def __init__(
65
- self,
66
- vocab_file,
67
- bos_token: Union[str, AddedToken] = "[BOS]",
68
- eos_token: Union[str, AddedToken] = "[EOS]",
69
- unk_token: Union[str, AddedToken, None] = None,
70
- pad_token: Union[str, AddedToken, None] = None,
71
- additional_special_tokens: List[str] = None,
72
- added_tokens_decoder: Optional[dict] = None,
73
- **kwargs,
74
- ):
75
- assert os.path.isfile(vocab_file), vocab_file
76
-
77
- if additional_special_tokens is None:
78
- additional_special_tokens = [
79
- "<|im_end|>",
80
- "<|im_user|>",
81
- "<|im_assistant|>",
82
- "<|start_header_id|>",
83
- "<|end_header_id|>",
84
- "[EOT]",
85
- "<|im_system|>",
86
- "<|im_middle|>",
87
- ]
88
-
89
- if added_tokens_decoder:
90
- special_tokens_mapping = {
91
- i: added_tokens_decoder[i].content
92
- for i in added_tokens_decoder
93
- }
94
- else:
95
- special_tokens_mapping = {}
96
-
97
- self.vocab_file = vocab_file
98
- mergeable_ranks = load_tiktoken_bpe(vocab_file)
99
- num_base_tokens = len(mergeable_ranks)
100
- self.special_tokens = {
101
- special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
102
- for i in range(num_base_tokens, num_base_tokens +
103
- self.num_reserved_special_tokens)
104
- }
105
-
106
- self.model = tiktoken.Encoding(
107
- name=Path(vocab_file).name,
108
- pat_str=self.pat_str,
109
- mergeable_ranks=mergeable_ranks,
110
- special_tokens=self.special_tokens,
111
- )
112
- logger.info(f"Reloaded tiktoken model from {vocab_file}")
113
-
114
- self.n_words: int = self.model.n_vocab
115
- # BOS / EOS token IDs
116
- self.bos_id: int = self.special_tokens[str(bos_token)]
117
- self.eos_id: int = self.special_tokens[str(eos_token)]
118
- logger.info(
119
- f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
120
- )
121
-
122
- self.pad_id: int = self.special_tokens[str(pad_token)]
123
- self.unk_id: int = self.special_tokens[str(unk_token)]
124
-
125
- self.byte_encoder = bytes_to_unicode()
126
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
127
-
128
- self.decoder = {}
129
- for i in range(self.n_words):
130
- # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
131
- decoding = ''.join([
132
- self.byte_encoder[ord(char)] for char in
133
- self.model.decode_single_token_bytes(i).decode('latin-1')
134
- ])
135
- self.decoder[i] = decoding
136
-
137
- self.encoder = {}
138
- for i in range(self.n_words):
139
- if i in self.decoder:
140
- self.encoder[self.decoder[i]] = i
141
-
142
- self._token_config_cache = OrderedDict()
143
- self._cache_max_size = 128
144
-
145
- super().__init__(
146
- bos_token=bos_token,
147
- eos_token=eos_token,
148
- unk_token=unk_token,
149
- pad_token=pad_token,
150
- additional_special_tokens=additional_special_tokens,
151
- added_tokens_decoder=added_tokens_decoder,
152
- **kwargs,
153
- )
154
- self.all_special_ids_set = set(self.all_special_ids)
155
-
156
- def encode(self,
157
- text: str,
158
- allow_special_tokens: bool = True,
159
- **kwargs) -> List[int]:
160
- """
161
- Encodes a string into a list of token IDs.
162
-
163
- Args:
164
- text (str): The input string to be encoded.
165
-
166
- Returns:
167
- list[int]: A list of token IDs.
168
- """
169
- # If there are other args, we should call super().encode because there are a lot of code
170
- # to handle those args. supper().encode finally will call _tokenize and _convert_token_to_id.
171
- # NOTE: our encode method is not compatible with the super().encode method,
172
- # e.g. split_special_tokens' default is True in our encode method.
173
- if len(kwargs) > 0:
174
- logger.warning(f"Calling super().encode with {kwargs}")
175
- return super().encode(text, **kwargs)
176
-
177
- assert type(text) is str
178
-
179
- # The tiktoken tokenizer can handle <=400k chars without
180
- # pyo3_runtime.PanicException.
181
- TIKTOKEN_MAX_ENCODE_CHARS = 400_000
182
-
183
- # https://github.com/openai/tiktoken/issues/195
184
- # Here we iterate over subsequences and split if we exceed the limit
185
- # of max consecutive non-whitespace or whitespace characters.
186
- MAX_NO_WHITESPACES_CHARS = 25_000
187
-
188
- texts = self.pre_tokenizer_process(text)
189
-
190
- all_substrs = []
191
- for text in texts:
192
- substrs = (
193
- substr for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
194
- for substr in self._split_whitespaces_or_nonwhitespaces(
195
- text[i:i +
196
- TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS))
197
- all_substrs.extend(substrs)
198
-
199
- t: List[int] = []
200
- for substr in all_substrs:
201
- if allow_special_tokens:
202
- t.extend(
203
- # we should consider special token as a common token
204
- self.model.encode(
205
- substr,
206
- allowed_special="all",
207
- ))
208
- else:
209
- t.extend(
210
- # we should consider special token as a common token
211
- self.model.encode(
212
- substr,
213
- disallowed_special=(),
214
- ))
215
-
216
- return t
217
-
218
- def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
219
- """
220
- Decodes a list of token IDs into a string.
221
-
222
- Args:
223
- token_ids (List[int]): The list of token IDs to be decoded.
224
-
225
- Returns:
226
- str: The decoded string.
227
- """
228
- # If there are other args, we should call super().decode because there are a lot of code
229
- # to handle those args. supper().encode finally will call convert_tokens_to_string and _convert_id_to_token.
230
- if len(kwargs) > 0:
231
- return super().decode(token_ids, **kwargs)
232
-
233
- if type(token_ids) is int:
234
- token_ids = [token_ids]
235
-
236
- return self.model.decode(cast(List[int], token_ids))
237
-
238
- @staticmethod
239
- def _split_whitespaces_or_nonwhitespaces(
240
- s: str, max_consecutive_slice_len: int) -> Iterator[str]:
241
- """
242
- Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
243
- consecutive whitespaces or consecutive non-whitespaces.
244
- """
245
- current_slice_len = 0
246
- current_slice_is_space = s[0].isspace() if len(s) > 0 else False
247
- slice_start = 0
248
-
249
- for i in range(len(s)):
250
- is_now_space = s[i].isspace()
251
-
252
- if current_slice_is_space ^ is_now_space:
253
- current_slice_len = 1
254
- current_slice_is_space = is_now_space
255
- else:
256
- current_slice_len += 1
257
- if current_slice_len > max_consecutive_slice_len:
258
- yield s[slice_start:i]
259
- slice_start = i
260
- current_slice_len = 1
261
- yield s[slice_start:]
262
-
263
- def pre_tokenizer_process(self, text: str) -> List[str]:
264
- """
265
- pre-tokenizes the input text into a list of tokens.
266
- This method is used to split the input text into smaller chunks for internal processing.
267
- """
268
- return [text]
269
-
270
- """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """
271
-
272
- @property
273
- def vocab_size(self) -> int:
274
- return self.n_words
275
-
276
- def get_vocab(self) -> Dict[str, int]:
277
- return self.encoder
278
-
279
- def _tokenize(self, text: str, **kwargs) -> List[str]:
280
- return [self.decoder[t] for t in self.encode(text)]
281
-
282
- def _convert_token_to_id(self, token: str) -> int:
283
- return self.encoder.get(token, self.unk_id)
284
-
285
- def _convert_id_to_token(self, index: int) -> str:
286
- return self.decoder.get(index)
287
-
288
- @staticmethod
289
- def clean_up_tokenization(out_string: str) -> str:
290
- return out_string
291
-
292
- def convert_tokens_to_string(self, tokens: List[str]) -> str:
293
- text = ''.join(tokens)
294
- text = bytearray([self.byte_decoder[c]
295
- for c in text]).decode('utf-8', 'replace')
296
- return text
297
-
298
- def save_vocabulary(self,
299
- save_directory: str,
300
- filename_prefix: Optional[str] = None) -> Tuple[str]:
301
- if not os.path.isdir(save_directory):
302
- raise ValueError(
303
- f"vocabulary path ({save_directory}) should be a directory")
304
- out_vocab_file = os.path.join(
305
- save_directory,
306
- (filename_prefix + "-" if filename_prefix else "") +
307
- VOCAB_FILES_NAMES["vocab_file"])
308
-
309
- if os.path.abspath(self.vocab_file) != os.path.abspath(
310
- out_vocab_file) and os.path.isfile(self.vocab_file):
311
- copyfile(self.vocab_file, out_vocab_file)
312
-
313
- return (out_vocab_file, )
314
-
315
- def apply_chat_template(self,
316
- conversation,
317
- tools: Optional[list[dict]] = None,
318
- tokenize: bool = False,
319
- add_generation_prompt: bool = True,
320
- thinking: bool = True,
321
- **kwargs):
322
-
323
- tools = deep_sort_dict(tools)
324
-
325
- # Convert tools to TypeScript style string if tools are provided
326
- tools_ts_str = None
327
- if tools:
328
- try:
329
- tools_ts_str = encode_tools_to_typescript_style(tools)
330
-
331
- except Exception as e:
332
- print(f"Failed to convert tools to TypeScript style: {e}")
333
- tools_ts_str = None
334
-
335
- # Store the TypeScript string in kwargs so it can be accessed by the template
336
- if tools_ts_str is not None:
337
- kwargs['tools_ts_str'] = tools_ts_str
338
- return super().apply_chat_template(
339
- conversation,
340
- tools=tools,
341
- tokenize=tokenize,
342
- add_generation_prompt=add_generation_prompt,
343
- thinking=thinking,
344
- **kwargs)
345
-
346
-
347
- def deep_sort_dict(obj: Any) -> Any:
348
- if isinstance(obj, dict):
349
- return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
350
- if isinstance(obj, list):
351
- return [deep_sort_dict(item) for item in obj]
352
- return obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+ from logging import getLogger
4
+ from pathlib import Path
5
+ from shutil import copyfile
6
+ from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
7
+
8
+ import tiktoken
9
+ from tiktoken.load import load_tiktoken_bpe
10
+ from tokenizers import AddedToken
11
+
12
+ from transformers.convert_slow_tokenizer import bytes_to_unicode
13
+ from transformers.tokenization_utils import PreTrainedTokenizer
14
+
15
+ from .tool_declaration_ts import encode_tools_to_typescript_style
16
+
17
+ logger = getLogger(__name__)
18
+ VOCAB_FILES_NAMES = {"vocab_file": "tiktoken.model"}
19
+
20
+
21
class TikTokenTokenizer(PreTrainedTokenizer):
    """
    Tokenizing and encoding/decoding text using the Tiktoken tokenizer. See megatron/tokenizer/tiktoken_tokenizer.py.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            The path to the Tiktoken model file.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[BOS]"`):
            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"[EOS]"`):
            The end of sequence token.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
            this token instead. Its string form must name one of the reserved special tokens.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
            The token used for padding, for example when batching sequences of different lengths. Its string form
            must name one of the reserved special tokens.
        additional_special_tokens (list of `str`, *optional*):
            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
            skipped when decoding if `skip_special_tokens` is set to `True`.
    """

    vocab_files_names = VOCAB_FILES_NAMES

    model_input_names = ["input_ids", "attention_mask"]

    # Mapping of special-token string -> token id, built in __init__.
    special_tokens: Dict[str, int]

    # Number of token ids reserved after the base BPE vocabulary for specials.
    num_reserved_special_tokens = 256

    # Pre-tokenization regex handed to tiktoken: CJK runs, cased/uncased word
    # pieces with optional English contractions, 1-3 digit groups, punctuation,
    # and whitespace runs.
    pat_str = "|".join(
        [
            r"""[\p{Han}]+""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

    def __init__(
        self,
        vocab_file,
        bos_token: Union[str, AddedToken] = "[BOS]",
        eos_token: Union[str, AddedToken] = "[EOS]",
        unk_token: Union[str, AddedToken, None] = None,
        pad_token: Union[str, AddedToken, None] = None,
        additional_special_tokens: Optional[List[str]] = None,
        added_tokens_decoder: Optional[dict] = None,
        **kwargs,
    ):
        assert os.path.isfile(vocab_file), vocab_file

        if additional_special_tokens is None:
            additional_special_tokens = [
                "<|im_end|>",
                "<|im_user|>",
                "<|im_assistant|>",
                "<|start_header_id|>",
                "<|end_header_id|>",
                "[EOT]",
                "<|im_system|>",
                "<|im_middle|>",
            ]

        # id -> content mapping for specials supplied via the tokenizer config.
        if added_tokens_decoder:
            special_tokens_mapping = {
                i: added_tokens_decoder[i].content for i in added_tokens_decoder
            }
        else:
            special_tokens_mapping = {}

        self.vocab_file = vocab_file
        mergeable_ranks = load_tiktoken_bpe(vocab_file)
        num_base_tokens = len(mergeable_ranks)
        # Reserve a fixed window of ids after the base vocabulary; ids without
        # a configured name get a "<|reserved_token_i|>" placeholder.
        self.special_tokens = {
            special_tokens_mapping.get(i, f"<|reserved_token_{i}|>"): i
            for i in range(
                num_base_tokens, num_base_tokens + self.num_reserved_special_tokens
            )
        }

        self.model = tiktoken.Encoding(
            name=Path(vocab_file).name,
            pat_str=self.pat_str,
            mergeable_ranks=mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        logger.info(f"Reloaded tiktoken model from {vocab_file}")

        self.n_words: int = self.model.n_vocab
        # BOS / EOS token IDs
        self.bos_id: int = self.special_tokens[str(bos_token)]
        self.eos_id: int = self.special_tokens[str(eos_token)]
        logger.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

        # NOTE(review): str(None) == "None", so when pad/unk tokens are left
        # as None these lookups require a special literally named "None" —
        # otherwise they raise KeyError. Confirm against the shipped config.
        self.pad_id: int = self.special_tokens[str(pad_token)]
        self.unk_id: int = self.special_tokens[str(unk_token)]

        # GPT-2 style byte<->unicode tables used to expose a str-keyed vocab.
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        self.decoder = {}
        for i in range(self.n_words):
            # Taken from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee
            decoding = "".join(
                [
                    self.byte_encoder[ord(char)]
                    for char in self.model.decode_single_token_bytes(i).decode(
                        "latin-1"
                    )
                ]
            )
            self.decoder[i] = decoding

        self.encoder = {}
        for i in range(self.n_words):
            if i in self.decoder:
                self.encoder[self.decoder[i]] = i

        # Small bounded cache slot (LRU via OrderedDict) for token configs.
        self._token_config_cache = OrderedDict()
        self._cache_max_size = 128

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            added_tokens_decoder=added_tokens_decoder,
            **kwargs,
        )
        self.all_special_ids_set = set(self.all_special_ids)

    def encode(
        self, text: str, allow_special_tokens: bool = True, **kwargs
    ) -> List[int]:
        """
        Encodes a string into a list of token IDs.

        Args:
            text (str): The input string to be encoded.
            allow_special_tokens (bool): When True, special-token strings
                appearing in `text` are encoded to their special ids;
                otherwise they are tokenized as plain text.

        Returns:
            list[int]: A list of token IDs.
        """
        # If there are other args, we should call super().encode because there are a lot of code
        # to handle those args. super().encode finally will call _tokenize and _convert_token_to_id.
        # NOTE: our encode method is not compatible with the super().encode method,
        # e.g. split_special_tokens' default is True in our encode method.
        if len(kwargs) > 0:
            logger.warning(f"Calling super().encode with {kwargs}")
            return super().encode(text, **kwargs)

        assert type(text) is str

        # The tiktoken tokenizer can handle <=400k chars without
        # pyo3_runtime.PanicException.
        TIKTOKEN_MAX_ENCODE_CHARS = 400_000

        # https://github.com/openai/tiktoken/issues/195
        # Here we iterate over subsequences and split if we exceed the limit
        # of max consecutive non-whitespace or whitespace characters.
        MAX_NO_WHITESPACES_CHARS = 25_000

        texts = self.pre_tokenizer_process(text)

        all_substrs = []
        for text in texts:
            substrs = (
                substr
                for i in range(0, len(text), TIKTOKEN_MAX_ENCODE_CHARS)
                for substr in self._split_whitespaces_or_nonwhitespaces(
                    text[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
                )
            )
            all_substrs.extend(substrs)

        t: List[int] = []
        for substr in all_substrs:
            if allow_special_tokens:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        allowed_special="all",
                    )
                )
            else:
                t.extend(
                    # we should consider special token as a common token
                    self.model.encode(
                        substr,
                        disallowed_special=(),
                    )
                )

        return t

    def decode(self, token_ids: Union[int, List[int]], **kwargs) -> str:
        """
        Decodes a list of token IDs into a string.

        Args:
            token_ids (List[int]): The list of token IDs to be decoded.

        Returns:
            str: The decoded string.
        """
        # If there are other args, we should call super().decode because there are a lot of code
        # to handle those args. super().decode finally will call convert_tokens_to_string and _convert_id_to_token.
        if len(kwargs) > 0:
            return super().decode(token_ids, **kwargs)

        if type(token_ids) is int:
            token_ids = [token_ids]

        return self.model.decode(cast(List[int], token_ids))

    @staticmethod
    def _split_whitespaces_or_nonwhitespaces(
        s: str, max_consecutive_slice_len: int
    ) -> Iterator[str]:
        """
        Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len`
        consecutive whitespaces or consecutive non-whitespaces.
        """
        current_slice_len = 0
        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
        slice_start = 0

        for i in range(len(s)):
            is_now_space = s[i].isspace()

            if current_slice_is_space ^ is_now_space:
                # Run type flipped (space <-> non-space): reset the counter.
                current_slice_len = 1
                current_slice_is_space = is_now_space
            else:
                current_slice_len += 1
                if current_slice_len > max_consecutive_slice_len:
                    yield s[slice_start:i]
                    slice_start = i
                    current_slice_len = 1
        yield s[slice_start:]

    def pre_tokenizer_process(self, text: str) -> List[str]:
        """
        pre-tokenizes the input text into a list of tokens.
        This method is used to split the input text into smaller chunks for internal processing.
        """
        return [text]

    """ ----- Below are the abstract methods required by PreTrainedTokenizer ----- """

    @property
    def vocab_size(self) -> int:
        return self.n_words

    def get_vocab(self) -> Dict[str, int]:
        return self.encoder

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        return [self.decoder[t] for t in self.encode(text)]

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.unk_id)

    def _convert_id_to_token(self, index: int) -> Optional[str]:
        # Returns None for out-of-range ids (dict.get default), hence Optional.
        return self.decoder.get(index)

    @staticmethod
    def clean_up_tokenization(out_string: str) -> str:
        # Byte-level decoding already yields clean text; skip the superclass
        # heuristics on purpose.
        return out_string

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        text = "".join(tokens)
        # Map the unicode surrogate alphabet back to raw bytes, then decode.
        text = bytearray([self.byte_decoder[c] for c in text]).decode(
            "utf-8", "replace"
        )
        return text

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """Copy the tiktoken model file into `save_directory` and return its path."""
        if not os.path.isdir(save_directory):
            raise ValueError(
                f"vocabulary path ({save_directory}) should be a directory"
            )
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    def apply_chat_template(
        self,
        conversation,
        tools: Optional[list[dict]] = None,
        tokenize: bool = False,
        add_generation_prompt: bool = True,
        thinking: bool = True,
        **kwargs,
    ):
        """Render `conversation` through the chat template, passing tool
        declarations to the template both as structured dicts and as a
        TypeScript-style string (`tools_ts_str`)."""
        # Sort keys recursively so the rendered prompt is deterministic.
        tools = deep_sort_dict(tools)

        # Convert tools to TypeScript style string if tools are provided
        tools_ts_str = None
        if tools:
            try:
                tools_ts_str = encode_tools_to_typescript_style(tools)
            except Exception as e:
                # Use the module logger instead of print() so library users
                # control verbosity; rendering proceeds without the TS string.
                logger.warning(f"Failed to convert tools to TypeScript style: {e}")
                tools_ts_str = None

        # Store the TypeScript string in kwargs so it can be accessed by the template
        if tools_ts_str is not None:
            kwargs["tools_ts_str"] = tools_ts_str
        return super().apply_chat_template(
            conversation,
            tools=tools,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            thinking=thinking,
            **kwargs,
        )
361
+
362
+
363
+ def deep_sort_dict(obj: Any) -> Any:
364
+ if isinstance(obj, dict):
365
+ return {k: deep_sort_dict(v) for k, v in sorted(obj.items())}
366
+ if isinstance(obj, list):
367
+ return [deep_sort_dict(item) for item in obj]
368
+ return obj
tool_declaration_ts.py CHANGED
@@ -1,479 +1,500 @@
1
- """
2
- Encode structured tool declaration to typescript style string.
3
- """
4
- import dataclasses
5
- import json
6
- import logging
7
- from collections.abc import Sequence
8
- from typing import Any
9
-
10
- logger = logging.getLogger(__name__)
11
-
12
- _TS_INDENT = " "
13
- _TS_FIELD_DELIMITER = ",\n"
14
-
15
-
16
- class _SchemaRegistry:
17
- """Registry for schema definitions to handle $ref resolution"""
18
-
19
- def __init__(self):
20
- self.definitions = {}
21
- self.has_self_ref = False
22
-
23
- def register_definitions(self, defs: dict[str, Any]):
24
- """Register schema definitions from $defs section"""
25
- if not defs:
26
- return
27
- for def_name, def_schema in defs.items():
28
- self.definitions[def_name] = def_schema
29
-
30
- def resolve_ref(self, ref: str) -> dict[str, Any]:
31
- """Resolve a reference to its schema definition"""
32
- if ref == "#":
33
- self.has_self_ref = True
34
- return {"$self_ref": True}
35
- elif ref.startswith("#/$defs/"):
36
- def_name = ref.split("/")[-1]
37
- if def_name not in self.definitions:
38
- raise ValueError(f"Reference not found: {ref}")
39
- return self.definitions[def_name]
40
- else:
41
- raise ValueError(f"Unsupported reference format: {ref}")
42
-
43
-
44
- def _format_description(description: str, indent: str = "") -> str:
45
- return "\n".join([
46
- f"{indent}// {line}" if line else ""
47
- for line in description.split("\n")
48
- ])
49
-
50
-
51
- class _BaseType:
52
- description: str
53
- constraints: dict[str, Any]
54
-
55
- def __init__(
56
- self,
57
- extra_props: dict[str, Any],
58
- *,
59
- allowed_constraint_keys: Sequence[str] = (),
60
- ):
61
- self.description = extra_props.get("description", "")
62
- self.constraints = {
63
- k: v
64
- for k, v in extra_props.items() if k in allowed_constraint_keys
65
- }
66
-
67
- def to_typescript_style(self, indent: str = "") -> str:
68
- raise NotImplementedError
69
-
70
- def format_docstring(self, indent: str) -> str:
71
- lines = []
72
- if self.description:
73
- lines.append(_format_description(self.description, indent))
74
- if self.constraints:
75
- constraints_str = ", ".join(f"{k}: {v}" for k, v in sorted(
76
- self.constraints.items(), key=lambda kv: kv[0]))
77
- lines.append(f"{indent}// {constraints_str}")
78
-
79
- return "".join(x + "\n" for x in lines)
80
-
81
-
82
- class _ParameterTypeScalar(_BaseType):
83
- type: str
84
-
85
- def __init__(self, type: str, extra_props: dict[str, Any] | None = None):
86
- self.type = type
87
-
88
- allowed_constraint_keys: list[str] = []
89
- if self.type == "string":
90
- allowed_constraint_keys = ["maxLength", "minLength", "pattern"]
91
- elif self.type in ("number", "integer"):
92
- allowed_constraint_keys = ["maximum", "minimum"]
93
-
94
- super().__init__(extra_props or {},
95
- allowed_constraint_keys=allowed_constraint_keys)
96
-
97
- def to_typescript_style(self, indent: str = "") -> str:
98
- # Map integer to number in TypeScript
99
- if self.type == "integer":
100
- return "number"
101
- return self.type
102
-
103
-
104
- class _ParameterTypeObject(_BaseType):
105
- properties: list["_Parameter"]
106
- additional_properties: Any | None = None
107
-
108
- def __init__(self,
109
- json_schema_object: dict[str, Any],
110
- registry: _SchemaRegistry | None = None):
111
- super().__init__(json_schema_object)
112
-
113
- self.properties = []
114
- self.additional_properties = None
115
-
116
- if not json_schema_object:
117
- return
118
-
119
- if "$defs" in json_schema_object and registry:
120
- registry.register_definitions(json_schema_object["$defs"])
121
-
122
- self.additional_properties = json_schema_object.get(
123
- "additionalProperties")
124
- if isinstance(self.additional_properties, dict):
125
- self.additional_properties = _parse_parameter_type(
126
- self.additional_properties, registry)
127
-
128
- if "properties" not in json_schema_object:
129
- return
130
-
131
- required_parameters = json_schema_object.get("required", [])
132
- optional_parameters = set(
133
- json_schema_object["properties"].keys()) - set(required_parameters)
134
-
135
- self.properties = [
136
- _Parameter(
137
- name=name,
138
- type=_parse_parameter_type(prop, registry),
139
- optional=name in optional_parameters,
140
- default=prop.get("default")
141
- if isinstance(prop, dict) else None,
142
- ) for name, prop in json_schema_object["properties"].items()
143
- ]
144
-
145
- def to_typescript_style(self, indent: str = "") -> str:
146
- # sort by optional, make the required parameters first
147
- parameters = [p for p in self.properties if not p.optional]
148
- opt_params = [p for p in self.properties if p.optional]
149
-
150
- parameters = sorted(parameters, key=lambda p: p.name)
151
- parameters.extend(sorted(opt_params, key=lambda p: p.name))
152
-
153
- param_strs = []
154
- for p in parameters:
155
- one = p.to_typescript_style(indent=indent + _TS_INDENT)
156
- param_strs.append(one)
157
-
158
- if self.additional_properties is not None:
159
- ap_type_str = "any"
160
- if self.additional_properties is True:
161
- ap_type_str = "any"
162
- elif self.additional_properties is False:
163
- ap_type_str = "never"
164
- elif isinstance(self.additional_properties, _ParameterType):
165
- ap_type_str = self.additional_properties.to_typescript_style(
166
- indent=indent + _TS_INDENT)
167
- else:
168
- raise ValueError(
169
- f"Unknown additionalProperties: {self.additional_properties}"
170
- )
171
- param_strs.append(
172
- f"{indent + _TS_INDENT}[k: string]: {ap_type_str}")
173
-
174
- if not param_strs:
175
- return "{}"
176
-
177
- params_str = _TS_FIELD_DELIMITER.join(param_strs)
178
- if params_str:
179
- # add new line before and after
180
- params_str = f"\n{params_str}\n"
181
- # always wrap with object
182
- return f"{{{params_str}{indent}}}"
183
-
184
-
185
- class _ParameterTypeArray(_BaseType):
186
- item: "_ParameterType"
187
-
188
- def __init__(self,
189
- json_schema_object: dict[str, Any],
190
- registry: _SchemaRegistry | None = None):
191
- super().__init__(json_schema_object,
192
- allowed_constraint_keys=("minItems", "maxItems"))
193
- if json_schema_object.get("items"):
194
- self.item = _parse_parameter_type(json_schema_object["items"],
195
- registry)
196
- else:
197
- self.item = _ParameterTypeScalar(type="any")
198
-
199
- def to_typescript_style(self, indent: str = "") -> str:
200
- item_docstring = self.item.format_docstring(indent + _TS_INDENT)
201
- if item_docstring:
202
- return ("Array<\n" + item_docstring + indent + _TS_INDENT +
203
- self.item.to_typescript_style(indent=indent + _TS_INDENT) +
204
- "\n" + indent + ">")
205
- else:
206
- return f"Array<{self.item.to_typescript_style(indent=indent)}>"
207
-
208
-
209
- class _ParameterTypeEnum(_BaseType):
210
- # support scalar types only
211
- enum: list[str | int | float | bool | None]
212
-
213
- def __init__(self, json_schema_object: dict[str, Any]):
214
- super().__init__(json_schema_object)
215
- self.enum = json_schema_object["enum"]
216
-
217
- # Validate enum values against declared type if present
218
- if "type" in json_schema_object:
219
- typ = json_schema_object["type"]
220
- if isinstance(typ, list):
221
- if len(typ) == 1:
222
- typ = typ[0]
223
- elif len(typ) == 2:
224
- if "null" not in typ:
225
- raise ValueError(f"Enum type {typ} is not supported")
226
- else:
227
- typ = typ[0] if typ[0] != "null" else typ[1]
228
- else:
229
- raise ValueError(f"Enum type {typ} is not supported")
230
- for val in self.enum:
231
- if val is None:
232
- continue
233
- if typ == "string" and not isinstance(val, str):
234
- raise ValueError(f"Enum value {val} is not a string")
235
- elif typ == "number" and not isinstance(val, (int, float)):
236
- raise ValueError(f"Enum value {val} is not a number")
237
- elif typ == "integer" and not isinstance(val, int):
238
- raise ValueError(f"Enum value {val} is not an integer")
239
- elif typ == "boolean" and not isinstance(val, bool):
240
- raise ValueError(f"Enum value {val} is not a boolean")
241
-
242
- def to_typescript_style(self, indent: str = "") -> str:
243
- return " | ".join(
244
- [f'"{e}"' if isinstance(e, str) else str(e) for e in self.enum])
245
-
246
-
247
- class _ParameterTypeAnyOf(_BaseType):
248
- types: list["_ParameterType"]
249
-
250
- def __init__(
251
- self,
252
- json_schema_object: dict[str, Any],
253
- registry: _SchemaRegistry | None = None,
254
- ):
255
- super().__init__(json_schema_object)
256
- self.types = [
257
- _parse_parameter_type(t, registry)
258
- for t in json_schema_object["anyOf"]
259
- ]
260
-
261
- def to_typescript_style(self, indent: str = "") -> str:
262
- return " | ".join(
263
- [t.to_typescript_style(indent=indent) for t in self.types])
264
-
265
-
266
- class _ParameterTypeUnion(_BaseType):
267
- types: list[str]
268
-
269
- def __init__(self, json_schema_object: dict[str, Any]):
270
- super().__init__(json_schema_object)
271
-
272
- mapping = {
273
- "string": "string",
274
- "number": "number",
275
- "integer": "number",
276
- "boolean": "boolean",
277
- "null": "null",
278
- "object": "{}",
279
- "array": "Array<any>",
280
- }
281
- self.types = [mapping[t] for t in json_schema_object["type"]]
282
-
283
- def to_typescript_style(self, indent: str = "") -> str:
284
- return " | ".join(self.types)
285
-
286
-
287
- class _ParameterTypeRef(_BaseType):
288
- ref_name: str
289
- is_self_ref: bool = False
290
-
291
- def __init__(self, json_schema_object: dict[str, Any],
292
- registry: _SchemaRegistry):
293
- super().__init__(json_schema_object)
294
-
295
- ref = json_schema_object["$ref"]
296
- resolved_schema = registry.resolve_ref(ref)
297
-
298
- if resolved_schema.get("$self_ref", False):
299
- self.ref_name = "parameters"
300
- self.is_self_ref = True
301
- else:
302
- self.ref_name = ref.split("/")[-1]
303
-
304
- def to_typescript_style(self, indent: str = "") -> str:
305
- return self.ref_name
306
-
307
-
308
- _ParameterType = (_ParameterTypeScalar
309
- | _ParameterTypeObject
310
- | _ParameterTypeArray
311
- | _ParameterTypeEnum
312
- | _ParameterTypeAnyOf
313
- | _ParameterTypeUnion
314
- | _ParameterTypeRef)
315
-
316
-
317
- @dataclasses.dataclass
318
- class _Parameter:
319
- """
320
- A parameter in a function, or a field in a object.
321
- It consists of the type as well as the name.
322
- """
323
-
324
- type: _ParameterType
325
- name: str = "_"
326
- optional: bool = True
327
- default: Any | None = None
328
-
329
- @classmethod
330
- def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter":
331
- if not attributes:
332
- raise ValueError("attributes is empty")
333
-
334
- return cls(
335
- name=attributes.get("name", "_"),
336
- type=_parse_parameter_type(attributes),
337
- optional=attributes.get("optional", False),
338
- default=attributes.get("default"),
339
- )
340
-
341
- def to_typescript_style(self, indent: str = "") -> str:
342
- comments = self.type.format_docstring(indent)
343
-
344
- if self.default is not None:
345
- default_repr = (json.dumps(self.default, ensure_ascii=False)
346
- if not isinstance(self.default, (int, float, bool))
347
- else repr(self.default))
348
- comments += f"{indent}// Default: {default_repr}\n"
349
-
350
- return (
351
- comments +
352
- f"{indent}{self.name}{'?' if self.optional else ''}: {self.type.to_typescript_style(indent=indent)}"
353
- )
354
-
355
-
356
- def _parse_parameter_type(
357
- json_schema_object: dict[str, Any] | bool,
358
- registry: _SchemaRegistry | None = None) -> _ParameterType:
359
- if isinstance(json_schema_object, bool):
360
- if json_schema_object:
361
- return _ParameterTypeScalar(type="any")
362
- else:
363
- logger.warning(
364
- f"Warning: Boolean value {json_schema_object} is not supported, use null instead."
365
- )
366
- return _ParameterTypeScalar(type="null")
367
-
368
- if "$ref" in json_schema_object and registry:
369
- return _ParameterTypeRef(json_schema_object, registry)
370
-
371
- if "anyOf" in json_schema_object:
372
- return _ParameterTypeAnyOf(json_schema_object, registry)
373
- elif "enum" in json_schema_object:
374
- return _ParameterTypeEnum(json_schema_object)
375
- elif "type" in json_schema_object:
376
- typ = json_schema_object["type"]
377
- if isinstance(typ, list):
378
- return _ParameterTypeUnion(json_schema_object)
379
- elif typ == "object":
380
- return _ParameterTypeObject(json_schema_object, registry)
381
- elif typ == "array":
382
- return _ParameterTypeArray(json_schema_object, registry)
383
- else:
384
- return _ParameterTypeScalar(typ, json_schema_object)
385
- elif json_schema_object == {}:
386
- return _ParameterTypeScalar(type="any")
387
- else:
388
- raise ValueError(f"Invalid JSON Schema object: {json_schema_object}")
389
-
390
-
391
- def _openai_function_to_typescript_style(function: dict[str, Any], ) -> str:
392
- """Convert OpenAI function definition (dict) to TypeScript style string."""
393
- registry = _SchemaRegistry()
394
- parameters = function.get("parameters") or {}
395
- parsed = _ParameterTypeObject(parameters, registry)
396
-
397
- interfaces = []
398
- root_interface_name = None
399
- if registry.has_self_ref:
400
- root_interface_name = "parameters"
401
- params_str = _TS_FIELD_DELIMITER.join([
402
- p.to_typescript_style(indent=_TS_INDENT) for p in parsed.properties
403
- ])
404
- params_str = f"\n{params_str}\n" if params_str else ""
405
- interface_def = f"interface {root_interface_name} {{{params_str}}}"
406
- interfaces.append(interface_def)
407
-
408
- definitions_copy = dict(registry.definitions)
409
- for def_name, def_schema in definitions_copy.items():
410
- obj_type = _parse_parameter_type(def_schema, registry)
411
- params_str = obj_type.to_typescript_style()
412
-
413
- description_part = ""
414
- if obj_description := def_schema.get("description", ""):
415
- description_part = _format_description(obj_description) + "\n"
416
-
417
- interface_def = f"{description_part}interface {def_name} {params_str}"
418
- interfaces.append(interface_def)
419
-
420
- interface_str = "\n".join(interfaces)
421
- function_name = function.get("name", "function")
422
- if root_interface_name:
423
- type_def = f"type {function_name} = (_: {root_interface_name}) => any;"
424
- else:
425
- params_str = parsed.to_typescript_style()
426
- type_def = f"type {function_name} = (_: {params_str}) => any;"
427
-
428
- description = function.get("description")
429
- return "\n".join(
430
- filter(
431
- bool,
432
- [
433
- interface_str,
434
- ((description and _format_description(description)) or ""),
435
- type_def,
436
- ],
437
- ))
438
-
439
-
440
- def encode_tools_to_typescript_style(tools: list[dict[str, Any]], ) -> str:
441
- """
442
- Convert tools (list of dict) to TypeScript style string.
443
-
444
- Supports OpenAI format: {"type": "function", "function": {...}}
445
-
446
- Args:
447
- tools: List of tool definitions in dict format
448
-
449
- Returns:
450
- TypeScript style string representation of the tools
451
- """
452
- if not tools:
453
- return ""
454
-
455
- functions = []
456
-
457
- for tool in tools:
458
- tool_type = tool.get("type")
459
- if tool_type == "function":
460
- func_def = tool.get("function", {})
461
- if func_def:
462
- functions.append(
463
- _openai_function_to_typescript_style(func_def))
464
- else:
465
- # Skip unsupported tool types (like "_plugin")
466
- continue
467
-
468
- if not functions:
469
- return ""
470
-
471
- functions_str = "\n".join(functions)
472
- result = "# Tools\n\n"
473
-
474
- if functions_str:
475
- result += "## functions\nnamespace functions {\n"
476
- result += functions_str + "\n"
477
- result += "}\n"
478
-
479
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Encode structured tool declaration to typescript style string.
3
+ """
4
+
5
+ import dataclasses
6
+ import json
7
+ import logging
8
+ from collections.abc import Sequence
9
+ from typing import Any
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ _TS_INDENT = " "
14
+ _TS_FIELD_DELIMITER = ",\n"
15
+
16
+
17
+ class _SchemaRegistry:
18
+ """Registry for schema definitions to handle $ref resolution"""
19
+
20
+ def __init__(self):
21
+ self.definitions = {}
22
+ self.has_self_ref = False
23
+
24
+ def register_definitions(self, defs: dict[str, Any]):
25
+ """Register schema definitions from $defs section"""
26
+ if not defs:
27
+ return
28
+ for def_name, def_schema in defs.items():
29
+ self.definitions[def_name] = def_schema
30
+
31
+ def resolve_ref(self, ref: str) -> dict[str, Any]:
32
+ """Resolve a reference to its schema definition"""
33
+ if ref == "#":
34
+ self.has_self_ref = True
35
+ return {"$self_ref": True}
36
+ elif ref.startswith("#/$defs/"):
37
+ def_name = ref.split("/")[-1]
38
+ if def_name not in self.definitions:
39
+ raise ValueError(f"Reference not found: {ref}")
40
+ return self.definitions[def_name]
41
+ else:
42
+ raise ValueError(f"Unsupported reference format: {ref}")
43
+
44
+
45
+ def _format_description(description: str, indent: str = "") -> str:
46
+ return "\n".join(
47
+ [f"{indent}// {line}" if line else "" for line in description.split("\n")]
48
+ )
49
+
50
+
51
+ class _BaseType:
52
+ description: str
53
+ constraints: dict[str, Any]
54
+
55
+ def __init__(
56
+ self,
57
+ extra_props: dict[str, Any],
58
+ *,
59
+ allowed_constraint_keys: Sequence[str] = (),
60
+ ):
61
+ self.description = extra_props.get("description", "")
62
+ self.constraints = {
63
+ k: v for k, v in extra_props.items() if k in allowed_constraint_keys
64
+ }
65
+
66
+ def to_typescript_style(self, indent: str = "") -> str:
67
+ raise NotImplementedError
68
+
69
+ def format_docstring(self, indent: str) -> str:
70
+ lines = []
71
+ if self.description:
72
+ lines.append(_format_description(self.description, indent))
73
+ if self.constraints:
74
+ constraints_str = ", ".join(
75
+ f"{k}: {v}"
76
+ for k, v in sorted(self.constraints.items(), key=lambda kv: kv[0])
77
+ )
78
+ lines.append(f"{indent}// {constraints_str}")
79
+
80
+ return "".join(x + "\n" for x in lines)
81
+
82
+
83
class _ParameterTypeScalar(_BaseType):
    """A scalar JSON-Schema type (string/number/integer/boolean/null/any)."""

    type: str

    # Constraint keywords surfaced as comments, keyed by scalar kind.
    _CONSTRAINT_KEYS = {
        "string": ["maxLength", "minLength", "pattern"],
        "number": ["maximum", "minimum"],
        "integer": ["maximum", "minimum"],
    }

    def __init__(self, type: str, extra_props: dict[str, Any] | None = None):
        self.type = type
        super().__init__(
            extra_props or {},
            allowed_constraint_keys=self._CONSTRAINT_KEYS.get(type, []),
        )

    def to_typescript_style(self, indent: str = "") -> str:
        # TypeScript has a single numeric type, so integer maps to number.
        return "number" if self.type == "integer" else self.type
104
+
105
+
106
class _ParameterTypeObject(_BaseType):
    """An object schema: named properties plus optional additionalProperties.

    Renders as a TypeScript object literal type; any $defs found here are
    registered with the schema registry for later $ref resolution.
    """

    properties: list["_Parameter"]
    additional_properties: Any | None = None

    def __init__(
        self,
        json_schema_object: dict[str, Any],
        registry: _SchemaRegistry | None = None,
    ):
        super().__init__(json_schema_object)

        self.properties = []
        self.additional_properties = None

        if not json_schema_object:
            return

        if "$defs" in json_schema_object and registry:
            registry.register_definitions(json_schema_object["$defs"])

        ap = json_schema_object.get("additionalProperties")
        if isinstance(ap, dict):
            # A schema-valued additionalProperties becomes a parsed type.
            ap = _parse_parameter_type(ap, registry)
        self.additional_properties = ap

        if "properties" not in json_schema_object:
            return

        required = set(json_schema_object.get("required", []))
        for prop_name, prop_schema in json_schema_object["properties"].items():
            default = (
                prop_schema.get("default") if isinstance(prop_schema, dict) else None
            )
            self.properties.append(
                _Parameter(
                    name=prop_name,
                    type=_parse_parameter_type(prop_schema, registry),
                    optional=prop_name not in required,
                    default=default,
                )
            )

    def to_typescript_style(self, indent: str = "") -> str:
        # Required parameters come first; each group is sorted by name.
        required_first = sorted(
            (p for p in self.properties if not p.optional), key=lambda p: p.name
        )
        optional_last = sorted(
            (p for p in self.properties if p.optional), key=lambda p: p.name
        )

        inner_indent = indent + _TS_INDENT
        rendered = [
            p.to_typescript_style(indent=inner_indent)
            for p in required_first + optional_last
        ]

        ap = self.additional_properties
        if ap is not None:
            if ap is True:
                ap_ts = "any"  # open object
            elif ap is False:
                ap_ts = "never"  # closed object
            elif isinstance(ap, _ParameterType):
                ap_ts = ap.to_typescript_style(indent=inner_indent)
            else:
                raise ValueError(f"Unknown additionalProperties: {ap}")
            rendered.append(f"{inner_indent}[k: string]: {ap_ts}")

        if not rendered:
            return "{}"

        body = _TS_FIELD_DELIMITER.join(rendered)
        # Multi-line object literal with the closing brace at the outer indent.
        return f"{{\n{body}\n{indent}}}"
188
+
189
+
190
class _ParameterTypeArray(_BaseType):
    """An array schema; the item schema defaults to `any` when unspecified."""

    item: "_ParameterType"

    def __init__(
        self,
        json_schema_object: dict[str, Any],
        registry: _SchemaRegistry | None = None,
    ):
        super().__init__(
            json_schema_object, allowed_constraint_keys=("minItems", "maxItems")
        )
        items_schema = json_schema_object.get("items")
        if items_schema:
            self.item = _parse_parameter_type(items_schema, registry)
        else:
            self.item = _ParameterTypeScalar(type="any")

    def to_typescript_style(self, indent: str = "") -> str:
        inner_indent = indent + _TS_INDENT
        item_docs = self.item.format_docstring(inner_indent)
        if not item_docs:
            return f"Array<{self.item.to_typescript_style(indent=indent)}>"
        # Multi-line form keeps the item's comment block attached to it.
        item_ts = self.item.to_typescript_style(indent=inner_indent)
        return f"Array<\n{item_docs}{inner_indent}{item_ts}\n{indent}>"
221
+
222
+
223
class _ParameterTypeEnum(_BaseType):
    """An enum schema (scalar members only), rendered as a TS literal union.

    Raises ValueError when members do not match the declared "type", or when
    the declared type is an unsupported list form.
    """

    # support scalar types only
    enum: list[str | int | float | bool | None]

    def __init__(self, json_schema_object: dict[str, Any]):
        super().__init__(json_schema_object)
        self.enum = json_schema_object["enum"]

        # Validate enum values against declared type if present.
        if "type" in json_schema_object:
            typ = json_schema_object["type"]
            if isinstance(typ, list):
                # Only a single type, optionally paired with "null", is supported.
                if len(typ) == 1:
                    typ = typ[0]
                elif len(typ) == 2:
                    if "null" not in typ:
                        raise ValueError(f"Enum type {typ} is not supported")
                    else:
                        typ = typ[0] if typ[0] != "null" else typ[1]
                else:
                    raise ValueError(f"Enum type {typ} is not supported")
            for val in self.enum:
                if val is None:
                    continue
                # bool is a subclass of int in Python, so it must be rejected
                # explicitly for the numeric JSON-Schema types.
                if typ == "string" and not isinstance(val, str):
                    raise ValueError(f"Enum value {val} is not a string")
                elif typ == "number" and (
                    isinstance(val, bool) or not isinstance(val, (int, float))
                ):
                    raise ValueError(f"Enum value {val} is not a number")
                elif typ == "integer" and (
                    isinstance(val, bool) or not isinstance(val, int)
                ):
                    raise ValueError(f"Enum value {val} is not an integer")
                elif typ == "boolean" and not isinstance(val, bool):
                    raise ValueError(f"Enum value {val} is not a boolean")

    def to_typescript_style(self, indent: str = "") -> str:
        # json.dumps renders members in their TypeScript spelling —
        # true/false/null instead of Python's True/False/None — and escapes
        # string members correctly.
        return " | ".join(json.dumps(e, ensure_ascii=False) for e in self.enum)
260
+
261
+
262
class _ParameterTypeAnyOf(_BaseType):
    """An anyOf schema, rendered as a TypeScript union of its branches."""

    types: list["_ParameterType"]

    def __init__(
        self,
        json_schema_object: dict[str, Any],
        registry: _SchemaRegistry | None = None,
    ):
        super().__init__(json_schema_object)
        self.types = []
        for subschema in json_schema_object["anyOf"]:
            self.types.append(_parse_parameter_type(subschema, registry))

    def to_typescript_style(self, indent: str = "") -> str:
        branches = (t.to_typescript_style(indent=indent) for t in self.types)
        return " | ".join(branches)
277
+
278
+
279
class _ParameterTypeUnion(_BaseType):
    """A schema whose "type" is a list of primitive type names.

    Raises KeyError for unrecognized primitive names.
    """

    types: list[str]

    # JSON-Schema primitive name -> TypeScript spelling.
    _TS_NAMES = {
        "string": "string",
        "number": "number",
        "integer": "number",
        "boolean": "boolean",
        "null": "null",
        "object": "{}",
        "array": "Array<any>",
    }

    def __init__(self, json_schema_object: dict[str, Any]):
        super().__init__(json_schema_object)
        self.types = [self._TS_NAMES[name] for name in json_schema_object["type"]]

    def to_typescript_style(self, indent: str = "") -> str:
        return " | ".join(self.types)
298
+
299
+
300
class _ParameterTypeRef(_BaseType):
    """A $ref to a named definition, or to the root schema itself."""

    ref_name: str
    is_self_ref: bool = False

    def __init__(self, json_schema_object: dict[str, Any], registry: _SchemaRegistry):
        super().__init__(json_schema_object)

        ref = json_schema_object["$ref"]
        target = registry.resolve_ref(ref)

        if target.get("$self_ref", False):
            # "#" points back at the root parameters object, which is
            # emitted as `interface parameters` by the caller.
            self.is_self_ref = True
            self.ref_name = "parameters"
        else:
            self.ref_name = ref.split("/")[-1]

    def to_typescript_style(self, indent: str = "") -> str:
        return self.ref_name
318
+
319
+
320
# Union of every concrete parameter-type class. Also usable as an
# isinstance() target (PEP 604 unions support isinstance on Python 3.10+).
_ParameterType = (
    _ParameterTypeScalar | _ParameterTypeObject | _ParameterTypeArray
    | _ParameterTypeEnum | _ParameterTypeAnyOf | _ParameterTypeUnion
    | _ParameterTypeRef
)
329
+
330
+
331
@dataclasses.dataclass
class _Parameter:
    """
    A parameter in a function, or a field in an object.

    Combines a parameter type with its name, optionality and default value;
    renders as `name?: type` preceded by comment lines (description,
    constraints, default).
    """

    type: _ParameterType
    name: str = "_"
    optional: bool = True
    default: Any | None = None

    @classmethod
    def parse_extended(cls, attributes: dict[str, Any]) -> "_Parameter":
        """Build a parameter from an extended attribute dict that carries
        name/optional/default alongside the JSON-Schema type keywords.

        Raises ValueError when `attributes` is empty.
        """
        if not attributes:
            raise ValueError("attributes is empty")

        return cls(
            name=attributes.get("name", "_"),
            type=_parse_parameter_type(attributes),
            optional=attributes.get("optional", False),
            default=attributes.get("default"),
        )

    def to_typescript_style(self, indent: str = "") -> str:
        comments = self.type.format_docstring(indent)

        if self.default is not None:
            # json.dumps renders every JSON value in its TypeScript spelling;
            # the previous repr() branch emitted Python-style True/False for
            # boolean defaults (ints/floats/strings are unchanged).
            default_repr = json.dumps(self.default, ensure_ascii=False)
            comments += f"{indent}// Default: {default_repr}\n"

        marker = "?" if self.optional else ""
        type_ts = self.type.to_typescript_style(indent=indent)
        return f"{comments}{indent}{self.name}{marker}: {type_ts}"
370
+
371
+
372
def _parse_parameter_type(
    json_schema_object: dict[str, Any] | bool, registry: _SchemaRegistry | None = None
) -> _ParameterType:
    """Dispatch a JSON-Schema node to the matching _ParameterType* class.

    Raises ValueError for schemas that fit none of the supported shapes.
    """
    # Boolean schemas: `true` accepts anything; `false` is unsupported and
    # degrades to null with a warning.
    if isinstance(json_schema_object, bool):
        if not json_schema_object:
            logger.warning(
                f"Warning: Boolean value {json_schema_object} is not supported, use null instead."
            )
            return _ParameterTypeScalar(type="null")
        return _ParameterTypeScalar(type="any")

    # $ref takes precedence, but only when a registry is available.
    if "$ref" in json_schema_object and registry:
        return _ParameterTypeRef(json_schema_object, registry)

    if "anyOf" in json_schema_object:
        return _ParameterTypeAnyOf(json_schema_object, registry)
    if "enum" in json_schema_object:
        return _ParameterTypeEnum(json_schema_object)
    if "type" in json_schema_object:
        typ = json_schema_object["type"]
        if isinstance(typ, list):
            return _ParameterTypeUnion(json_schema_object)
        if typ == "object":
            return _ParameterTypeObject(json_schema_object, registry)
        if typ == "array":
            return _ParameterTypeArray(json_schema_object, registry)
        return _ParameterTypeScalar(typ, json_schema_object)
    # A completely empty schema means "anything".
    if json_schema_object == {}:
        return _ParameterTypeScalar(type="any")
    raise ValueError(f"Invalid JSON Schema object: {json_schema_object}")
405
+
406
+
407
def _openai_function_to_typescript_style(
    function: dict[str, Any],
) -> str:
    """Convert an OpenAI function definition (dict) to a TypeScript-style
    declaration: optional interface blocks for $defs / the self-referencing
    root, a `//` description, and a `type name = (_: params) => any;` line."""
    registry = _SchemaRegistry()
    root_schema = function.get("parameters") or {}
    root_obj = _ParameterTypeObject(root_schema, registry)

    interfaces: list[str] = []
    root_interface_name = None

    # When the schema references itself ("#"), emit the root object as a
    # named interface so the self-reference has something to point at.
    if registry.has_self_ref:
        root_interface_name = "parameters"
        fields = _TS_FIELD_DELIMITER.join(
            p.to_typescript_style(indent=_TS_INDENT) for p in root_obj.properties
        )
        if fields:
            fields = f"\n{fields}\n"
        interfaces.append(f"interface {root_interface_name} {{{fields}}}")

    # Iterate over a snapshot: parsing a definition may register further
    # definitions with the registry.
    for def_name, def_schema in dict(registry.definitions).items():
        body = _parse_parameter_type(def_schema, registry).to_typescript_style()
        header = ""
        if obj_description := def_schema.get("description", ""):
            header = _format_description(obj_description) + "\n"
        interfaces.append(f"{header}interface {def_name} {body}")

    interface_str = "\n".join(interfaces)

    # TypeScript identifiers cannot contain '-' or ' '; sanitize the name.
    raw_function_name = function.get("name", "function")
    function_name = raw_function_name.replace("-", "_").replace(" ", "_")
    if root_interface_name:
        type_def = f"type {function_name} = (_: {root_interface_name}) => any;"
    else:
        type_def = f"type {function_name} = (_: {root_obj.to_typescript_style()}) => any;"

    description = function.get("description")
    sections = [
        interface_str,
        ((description and _format_description(description)) or ""),
        type_def,
    ]
    return "\n".join(filter(bool, sections))
458
+
459
+
460
def encode_tools_to_typescript_style(
    tools: list[dict[str, Any]],
) -> str:
    """
    Convert tools (list of dict) to TypeScript style string.

    Supports OpenAI format: {"type": "function", "function": {...}}

    Args:
        tools: List of tool definitions in dict format

    Returns:
        TypeScript style string representation of the tools
    """
    if not tools:
        return ""

    declarations = []
    for tool in tools:
        # Only OpenAI "function" tools are supported; any other type
        # (e.g. "_plugin") is skipped.
        if tool.get("type") != "function":
            continue
        func_def = tool.get("function", {})
        if func_def:
            declarations.append(_openai_function_to_typescript_style(func_def))

    if not declarations:
        return ""

    body = "\n".join(declarations)
    result = "# Tools\n\n"
    if body:
        result += "## functions\nnamespace functions {\n" + body + "\n}\n"
    return result