| /** | |
| * @file Tokenizers are used to prepare textual inputs for a model. | |
| * | |
| * **Example:** Create an `AutoTokenizer` and use it to tokenize a sentence. | |
| * This will automatically detect the tokenizer type based on the tokenizer class defined in `tokenizer.json`. | |
| * ```javascript | |
| * import { AutoTokenizer } from '@huggingface/transformers'; | |
| * | |
| * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); | |
| * const { input_ids } = await tokenizer('I love transformers!'); | |
| * // Tensor { | |
| * // data: BigInt64Array(6) [101n, 1045n, 2293n, 19081n, 999n, 102n], | |
| * // dims: [1, 6], | |
| * // type: 'int64', | |
| * // size: 6, | |
| * // } | |
| * ``` | |
| * | |
| * @module tokenizers | |
| */ | |
| import { | |
| Callable, | |
| } from './utils/generic.js'; | |
| import { | |
| reverseDictionary, | |
| escapeRegExp, | |
| isIntegralNumber, | |
| mergeArrays, | |
| len, | |
| } from './utils/core.js'; | |
| import { | |
| getModelJSON, | |
| } from './utils/hub.js'; | |
| import { max, min, round } from './utils/maths.js'; | |
| import { Tensor } from './utils/tensor.js'; | |
| import { | |
| PriorityQueue, | |
| TokenLattice, | |
| CharTrie, | |
| DictionarySplitter, | |
| LRUCache, | |
| } from './utils/data-structures.js'; | |
| import { Template } from '@huggingface/jinja'; | |
| import { | |
| WHISPER_LANGUAGE_MAPPING | |
| } from './models/whisper/common_whisper.js'; | |
| /** | |
| * @typedef {Object} TokenizerProperties Additional tokenizer-specific properties. | |
| * @property {boolean} [legacy=null] Whether or not the `legacy` behavior of the tokenizer should be used. If `null`, the value from `tokenizer_config.json` is used. | |
| * @typedef {import('./utils/hub.js').PretrainedOptions & TokenizerProperties} PretrainedTokenizerOptions | |
| */ | |
| /** | |
| * Loads a tokenizer from the specified path. | |
| * @param {string} pretrained_model_name_or_path The path to the tokenizer directory. | |
| * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. | |
| * @returns {Promise<any[]>} A promise that resolves with information about the loaded tokenizer. | |
| */ | |
| async function loadTokenizer(pretrained_model_name_or_path, options) { | |
| const info = await Promise.all([ | |
| getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options), | |
| getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options), | |
| ]) | |
| // Override legacy option if `options.legacy` is not null | |
| if (options.legacy !== null) { | |
| info[1].legacy = options.legacy; | |
| } | |
| return info; | |
| } | |
| /** | |
| * Helper function to split a string on a regex, but keep the delimiters. | |
| * This is required, because the JavaScript `.split()` method does not keep the delimiters, | |
| * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting). | |
| * @param {string} text The text to split. | |
| * @param {RegExp} regex The regex to split on. | |
| * @returns {string[]} The split string. | |
| */ | |
| function regexSplit(text, regex) { | |
| const result = []; | |
| let prev = 0; | |
| for (const match of text.matchAll(regex)) { | |
| const fullMatch = match[0]; | |
| if (prev < match.index) { | |
| result.push(text.slice(prev, match.index)); | |
| } | |
| if (fullMatch.length > 0) { | |
| result.push(fullMatch); | |
| } | |
| prev = match.index + fullMatch.length; | |
| } | |
| if (prev < text.length) { | |
| result.push(text.slice(prev)); | |
| } | |
| return result; | |
| } | |
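| // Illustrative example (not part of the library source): unlike `String.prototype.split`, | |
| // `regexSplit` keeps the delimiters in the output, e.g. | |
| //   regexSplit('foo, bar!', /[,!]/gu); // => ['foo', ',', ' bar', '!'] | |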
| /** | |
| * Helper method to construct a pattern from a config object. | |
| * @param {Object} pattern The pattern object. | |
| * @param {boolean} invert Whether to invert the pattern. | |
| * @returns {RegExp|null} The compiled pattern. | |
| */ | |
| function createPattern(pattern, invert = true) { | |
| if (pattern.Regex !== undefined) { | |
| // In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~). | |
| // i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed). | |
| // This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used. | |
| // For this reason, it is necessary to remove these backslashes before creating the regex. | |
| // See https://stackoverflow.com/a/63007777/13989043 for more information | |
| let regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary | |
| // We also handle special cases where the regex contains invalid (non-JS compatible) syntax. | |
| for (const [key, value] of PROBLEMATIC_REGEX_MAP) { | |
| regex = regex.replaceAll(key, value); | |
| } | |
| return new RegExp(regex, 'gu'); | |
| } else if (pattern.String !== undefined) { | |
| const escaped = escapeRegExp(pattern.String); | |
| // NOTE: if invert is false, we wrap the pattern in a capturing group so that it is kept when performing .split() | |
| return new RegExp(invert ? escaped : `(${escaped})`, 'gu'); | |
| } else { | |
| console.warn('Unknown pattern type:', pattern) | |
| return null; | |
| } | |
| } | |
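| // Illustrative example (not part of the library source): pattern objects come from `tokenizer.json`, | |
| // e.g. { Regex: "\s+" } or { String: "▁" }. | |
| //   createPattern({ Regex: '\\s+' });       // => /\s+/gu | |
| //   createPattern({ String: '▁' }, false);  // => /(▁)/gu (wrapped in a group so .split() keeps it) | |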
| /** | |
| * Helper function to convert an Object to a Map | |
| * @param {Object} obj The object to convert. | |
| * @returns {Map<string, any>} The map. | |
| */ | |
| function objectToMap(obj) { | |
| return new Map(Object.entries(obj)); | |
| } | |
| /** | |
| * Helper function to convert a tensor to a list before decoding. | |
| * @param {Tensor} tensor The tensor to convert. | |
| * @returns {number[]} The tensor as a list. | |
| */ | |
| function prepareTensorForDecode(tensor) { | |
| const dims = tensor.dims; | |
| switch (dims.length) { | |
| case 1: | |
| return tensor.tolist(); | |
| case 2: | |
| if (dims[0] !== 1) { | |
| throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.'); | |
| } | |
| return tensor.tolist()[0]; | |
| default: | |
| throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`) | |
| } | |
| } | |
| /** | |
| * Clean up a list of simple English tokenization artifacts like spaces before punctuation and abbreviated forms | |
| * @param {string} text The text to clean up. | |
| * @returns {string} The cleaned up text. | |
| */ | |
| function clean_up_tokenization(text) { | |
| // Clean up a list of simple English tokenization artifacts | |
| // like spaces before punctuation and abbreviated forms | |
| return text.replace(/ \./g, '.') | |
| .replace(/ \?/g, '?') | |
| .replace(/ \!/g, '!') | |
| .replace(/ ,/g, ',') | |
| .replace(/ \' /g, "'") | |
| .replace(/ n\'t/g, "n't") | |
| .replace(/ \'m/g, "'m") | |
| .replace(/ \'s/g, "'s") | |
| .replace(/ \'ve/g, "'ve") | |
| .replace(/ \'re/g, "'re"); | |
| } | |
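| // Illustrative example (not part of the library source): | |
| //   clean_up_tokenization("Hello , world ! I 'm here ."); // => "Hello, world! I'm here." | |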
| /** | |
| * Helper function to remove accents from a string. | |
| * @param {string} text The text to remove accents from. | |
| * @returns {string} The text with accents removed. | |
| */ | |
| function remove_accents(text) { | |
| return text.replace(/\p{M}/gu, ''); | |
| } | |
| /** | |
| * Helper function to lowercase a string and remove accents. | |
| * @param {string} text The text to lowercase and remove accents from. | |
| * @returns {string} The lowercased text with accents removed. | |
| */ | |
| function lowercase_and_remove_accent(text) { | |
| return remove_accents(text.toLowerCase()); | |
| } | |
| /** | |
| * Checks whether the given Unicode codepoint represents a CJK (Chinese, Japanese, or Korean) character. | |
| * | |
| * A "chinese character" is defined as anything in the CJK Unicode block: | |
| * https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |
| * | |
| * Note that the CJK Unicode block is NOT all Japanese and Korean characters, despite its name. | |
| * The modern Korean Hangul alphabet is a different block, as is Japanese Hiragana and Katakana. | |
| * Those alphabets are used to write space-separated words, so they are not treated specially | |
| * and are handled like all other languages. | |
| * | |
| * @param {number|bigint} cp The Unicode codepoint to check. | |
| * @returns {boolean} True if the codepoint represents a CJK character, false otherwise. | |
| */ | |
| export function is_chinese_char(cp) { | |
| return ( | |
| (cp >= 0x4E00 && cp <= 0x9FFF) | |
| || (cp >= 0x3400 && cp <= 0x4DBF) | |
| || (cp >= 0x20000 && cp <= 0x2A6DF) | |
| || (cp >= 0x2A700 && cp <= 0x2B73F) | |
| || (cp >= 0x2B740 && cp <= 0x2B81F) | |
| || (cp >= 0x2B820 && cp <= 0x2CEAF) | |
| || (cp >= 0xF900 && cp <= 0xFAFF) | |
| || (cp >= 0x2F800 && cp <= 0x2FA1F) | |
| ) | |
| } | |
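| // Illustrative example (not part of the library source): | |
| //   is_chinese_char('中'.codePointAt(0)); // => true  (U+4E2D is in the CJK Unified Ideographs block) | |
| //   is_chinese_char('a'.codePointAt(0)); // => false | |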
| /** | |
| * Helper function to fuse consecutive unknown tokens. | |
| * @param {string[]} arr The list of input tokens | |
| * @param {Map<string, any>} tokens_to_ids The mapping from tokens to token ids. | |
| * @param {number} unk_token_id The value to fuse on. | |
| * @private | |
| */ | |
| function fuse_unk(arr, tokens_to_ids, unk_token_id) { | |
| const fused = []; | |
| let i = 0; | |
| while (i < arr.length) { | |
| fused.push(arr[i]) | |
| if ((tokens_to_ids.get(arr[i]) ?? unk_token_id) !== unk_token_id) { | |
| ++i; | |
| continue; | |
| } | |
| while (++i < arr.length && (tokens_to_ids.get(arr[i]) ?? unk_token_id) === unk_token_id) { | |
| if (tokens_to_ids.get(fused.at(-1)) !== unk_token_id) { | |
| fused[fused.length - 1] += arr[i]; | |
| } | |
| } | |
| } | |
| return fused; | |
| } | |
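| // Illustrative example (not part of the library source): with a made-up vocabulary { a: 0, b: 1, [UNK]: 2 } | |
| // and unk_token_id = 2, consecutive out-of-vocabulary tokens are fused into one: | |
| //   fuse_unk(['a', 'x', 'y', 'b'], new Map([['a', 0], ['b', 1], ['[UNK]', 2]]), 2); // => ['a', 'xy', 'b'] | |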
| /** | |
| * Split a string on whitespace. | |
| * @param {string} text The text to split. | |
| * @returns {string[]} The split string. | |
| */ | |
| function whitespace_split(text) { | |
| return text.match(/\S+/g) || []; | |
| } | |
| const PUNCTUATION_REGEX = '\\p{P}\\u0021-\\u002F\\u003A-\\u0040\\u005B-\\u0060\\u007B-\\u007E'; | |
| const PUNCTUATION_ONLY_REGEX = new RegExp(`^[${PUNCTUATION_REGEX}]+$`, 'gu'); | |
| const BLOOM_SPLIT_CHARS = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c'; | |
| // A mapping of regex patterns to their equivalent (but possibly longer) JS-compatible versions. | |
| const PROBLEMATIC_REGEX_MAP = new Map([ | |
| // This uses the case insensitive group modifier, which is not supported in JavaScript. | |
| // When parsing the regex, an "Invalid group" error is thrown. | |
| ["(?i:'s|'t|'re|'ve|'m|'ll|'d)", "(?:'([sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD]))"], | |
| // Used to override the default (invalid) regex of the bloom pretokenizer. | |
| // For more information, see https://github.com/huggingface/transformers.js/issues/94 | |
| [` ?[^(\\s|[${BLOOM_SPLIT_CHARS}])]+`, ` ?[^\\s${BLOOM_SPLIT_CHARS}]+`], | |
| ]) | |
| /** | |
| * Represents a token added by the user on top of the existing model vocabulary. | |
| * An AddedToken can be configured to specify the behavior it should have in various situations, such as: | |
| * - Whether it should only match single words | |
| * - Whether to include any whitespace on its left or right | |
| */ | |
| class AddedToken { | |
| /** | |
| * Creates a new instance of AddedToken. | |
| * @param {Object} config Added token configuration object. | |
| * @param {string} config.content The content of the added token. | |
| * @param {number} config.id The id of the added token. | |
| * @param {boolean} [config.single_word=false] Whether this token must be a single word or can break words. | |
| * @param {boolean} [config.lstrip=false] Whether this token should strip whitespaces on its left. | |
| * @param {boolean} [config.rstrip=false] Whether this token should strip whitespaces on its right. | |
| * @param {boolean} [config.normalized=null] Whether this token should be normalized. | |
| * @param {boolean} [config.special=false] Whether this token is special. | |
| */ | |
| constructor(config) { | |
| this.content = config.content; | |
| this.id = config.id; | |
| this.single_word = config.single_word ?? false; | |
| this.lstrip = config.lstrip ?? false; | |
| this.rstrip = config.rstrip ?? false; | |
| this.special = config.special ?? false; | |
| this.normalized = config.normalized ?? null; | |
| } | |
| } | |
| /** | |
| * Abstract base class for tokenizer models. | |
| * | |
| * @extends Callable | |
| */ | |
| export class TokenizerModel extends Callable { | |
| /** | |
| * Creates a new instance of TokenizerModel. | |
| * @param {Object} config The configuration object for the TokenizerModel. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| /** @type {string[]} */ | |
| this.vocab = []; | |
| /** | |
| * A mapping of tokens to ids. | |
| * @type {Map<string, number>} | |
| */ | |
| this.tokens_to_ids = new Map(); | |
| this.unk_token_id = undefined; | |
| this.unk_token = undefined; | |
| this.end_of_word_suffix = undefined; | |
| /** @type {boolean} Whether to fuse unknown tokens when encoding. Defaults to false. */ | |
| this.fuse_unk = this.config.fuse_unk ?? false; | |
| } | |
| /** | |
| * Instantiates a new TokenizerModel instance based on the configuration object provided. | |
| * @param {Object} config The configuration object for the TokenizerModel. | |
| * @param {...*} args Optional arguments to pass to the specific TokenizerModel constructor. | |
| * @returns {TokenizerModel} A new instance of a TokenizerModel. | |
| * @throws Will throw an error if the TokenizerModel type in the config is not recognized. | |
| */ | |
| static fromConfig(config, ...args) { | |
| switch (config.type) { | |
| case 'WordPiece': | |
| return new WordPieceTokenizer(config); | |
| case 'Unigram': | |
| // @ts-ignore | |
| return new Unigram(config, ...args); | |
| case 'BPE': | |
| return new BPE(config); | |
| default: | |
| // Some older tokenizers, like `google-t5/t5-small`, `openai-community/gpt2`, and `distilbert/distilbert-base-uncased`, do not have a `type` field. | |
| // In this case, we can infer the tokenizer type based on the structure of the `vocab` field and other properties. | |
| if (config.vocab) { | |
| if (Array.isArray(config.vocab)) { | |
| // config.vocab is of type `[string, number][]` | |
| // @ts-ignore | |
| return new Unigram(config, ...args); | |
| } else if (Object.hasOwn(config, 'continuing_subword_prefix') && Object.hasOwn(config, 'unk_token')) { | |
| if (Object.hasOwn(config, 'merges')) { | |
| return new BPE(config); | |
| } else { | |
| return new WordPieceTokenizer(config); | |
| } | |
| } else { | |
| // @ts-ignore | |
| return new LegacyTokenizerModel(config, ...args); | |
| } | |
| } | |
| throw new Error(`Unknown TokenizerModel type: ${config.type}`); | |
| } | |
| } | |
| /** | |
| * Internal function to call the TokenizerModel instance. | |
| * @param {string[]} tokens The tokens to encode. | |
| * @returns {string[]} The encoded tokens. | |
| */ | |
| _call(tokens) { | |
| tokens = this.encode(tokens); | |
| if (this.fuse_unk) { | |
| // Fuse unknown tokens | |
| tokens = fuse_unk(tokens, this.tokens_to_ids, this.unk_token_id); | |
| } | |
| return tokens; | |
| } | |
| /** | |
| * Encodes a list of tokens into a list of subword tokens. | |
| * @param {string[]} tokens The tokens to encode. | |
| * @returns {string[]} The encoded tokens. | |
| * @throws Will throw an error if not implemented in a subclass. | |
| */ | |
| encode(tokens) { | |
| throw Error("encode should be implemented in subclass.") | |
| } | |
| /** | |
| * Converts a list of tokens into a list of token IDs. | |
| * @param {string[]} tokens The tokens to convert. | |
| * @returns {number[]} The converted token IDs. | |
| */ | |
| convert_tokens_to_ids(tokens) { | |
| return tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id); | |
| } | |
| /** | |
| * Converts a list of token IDs into a list of tokens. | |
| * @param {number[]|bigint[]} ids The token IDs to convert. | |
| * @returns {string[]} The converted tokens. | |
| */ | |
| convert_ids_to_tokens(ids) { | |
| return ids.map(i => this.vocab[i] ?? this.unk_token); | |
| } | |
| } | |
| /** | |
| * A subclass of TokenizerModel that uses WordPiece encoding to encode tokens. | |
| * @extends TokenizerModel | |
| */ | |
| class WordPieceTokenizer extends TokenizerModel { | |
| /** | |
| * @param {Object} config The configuration object. | |
| * @param {Object} config.vocab A mapping of tokens to ids. | |
| * @param {string} config.unk_token The unknown token string. | |
| * @param {string} config.continuing_subword_prefix The prefix to use for continuing subwords. | |
| * @param {number} [config.max_input_chars_per_word=100] The maximum number of characters per word. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| /** | |
| * A mapping of tokens to ids. | |
| * @type {Map<string, number>} | |
| */ | |
| this.tokens_to_ids = objectToMap(config.vocab); | |
| /** | |
| * The id of the unknown token. | |
| * @type {number} | |
| */ | |
| this.unk_token_id = this.tokens_to_ids.get(config.unk_token); | |
| /** | |
| * The unknown token string. | |
| * @type {string} | |
| */ | |
| this.unk_token = config.unk_token; | |
| /** | |
| * The maximum number of characters allowed per word. | |
| * @type {number} | |
| */ | |
| this.max_input_chars_per_word = config.max_input_chars_per_word ?? 100; | |
| /** | |
| * An array of tokens. | |
| * @type {string[]} | |
| */ | |
| this.vocab = new Array(this.tokens_to_ids.size); | |
| for (const [key, value] of this.tokens_to_ids) { | |
| this.vocab[value] = key; | |
| } | |
| } | |
| /** | |
| * Encodes an array of tokens using WordPiece encoding. | |
| * @param {string[]} tokens The tokens to encode. | |
| * @returns {string[]} An array of encoded tokens. | |
| */ | |
| encode(tokens) { | |
| const outputTokens = []; | |
| for (const token of tokens) { | |
| const chars = [...token]; | |
| if (chars.length > this.max_input_chars_per_word) { | |
| outputTokens.push(this.unk_token); | |
| continue; | |
| } | |
| let isUnknown = false; | |
| let start = 0; | |
| const subTokens = []; | |
| while (start < chars.length) { | |
| let end = chars.length; | |
| let currentSubstring = null; | |
| while (start < end) { | |
| let substr = chars.slice(start, end).join(''); | |
| if (start > 0) { | |
| substr = this.config.continuing_subword_prefix + substr; | |
| } | |
| if (this.tokens_to_ids.has(substr)) { | |
| currentSubstring = substr; | |
| break; | |
| } | |
| --end; | |
| } | |
| if (currentSubstring === null) { | |
| isUnknown = true; | |
| break; | |
| } | |
| subTokens.push(currentSubstring); | |
| start = end; | |
| } | |
| if (isUnknown) { | |
| outputTokens.push(this.unk_token); | |
| } else { | |
| outputTokens.push(...subTokens); | |
| } | |
| } | |
| return outputTokens; | |
| } | |
| } | |
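| // Illustrative sketch (not part of the library source; `WordPieceTokenizer` is internal and the tiny | |
| // vocabulary below is made up). WordPiece greedily matches the longest prefix found in the vocabulary, | |
| // prefixing continuation pieces with `continuing_subword_prefix`: | |
| //   const wp = new WordPieceTokenizer({ | |
| //     vocab: { 'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3 }, | |
| //     unk_token: '[UNK]', | |
| //     continuing_subword_prefix: '##', | |
| //   }); | |
| //   wp.encode(['unaffable']); // => ['un', '##aff', '##able'] | |
| //   wp.encode(['xyz']);       // => ['[UNK]'] (no prefix of 'xyz' is in the vocabulary) | |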
| /** | |
| * Class representing a Unigram tokenizer model. | |
| * @extends TokenizerModel | |
| */ | |
| class Unigram extends TokenizerModel { | |
| /** | |
| * Create a new Unigram tokenizer model. | |
| * @param {Object} config The configuration object for the Unigram model. | |
| * @param {number} config.unk_id The ID of the unknown token | |
| * @param {[string, number][]} config.vocab A 2D array representing a mapping of tokens to scores. | |
| * @param {Object} moreConfig Additional configuration object for the Unigram model. | |
| */ | |
| constructor(config, moreConfig) { | |
| super(config); | |
| const vocabSize = config.vocab.length; | |
| this.vocab = new Array(vocabSize); | |
| /** @type {number[]} */ | |
| this.scores = new Array(vocabSize); | |
| for (let i = 0; i < vocabSize; ++i) { | |
| [this.vocab[i], this.scores[i]] = config.vocab[i]; | |
| } | |
| this.unk_token_id = config.unk_id; | |
| this.unk_token = this.vocab[config.unk_id]; | |
| this.tokens_to_ids = new Map(this.vocab.map((x, i) => [x, i])); | |
| this.bos_token = ' '; // beginning of a sentence token | |
| this.bos_token_id = this.tokens_to_ids.get(this.bos_token); // NOTE: may be undefined | |
| this.eos_token = moreConfig.eos_token; | |
| this.eos_token_id = this.tokens_to_ids.get(this.eos_token); | |
| this.unk_token = this.vocab[this.unk_token_id]; | |
| this.minScore = min(this.scores)[0]; | |
| this.unk_score = this.minScore - 10.0; | |
| this.scores[this.unk_token_id] = this.unk_score; | |
| this.trie = new CharTrie(); | |
| this.trie.extend(this.vocab); | |
| // NOTE: `fuse_unk` is hardcoded to true for Unigram models | |
| // See: https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/models/unigram/model.rs#L119 | |
| this.fuse_unk = true; | |
| } | |
| /** | |
| * Populates lattice nodes. | |
| * @param {TokenLattice} lattice The token lattice to populate with nodes. | |
| */ | |
| populateNodes(lattice) { | |
| const chars = lattice.chars; | |
| const mblen = 1; | |
| let beginPos = 0; | |
| while (beginPos < chars.length) { | |
| let hasSingleNode = false; | |
| const tokens = []; | |
| const sliced = chars.slice(beginPos).join(''); | |
| const prefixedTokens = this.trie.commonPrefixSearch(sliced); | |
| for (const token of prefixedTokens) { | |
| tokens.push(token); | |
| const tokenId = this.tokens_to_ids.get(token); | |
| const tokenScore = this.scores[tokenId]; | |
| const n = len(token); | |
| lattice.insert(beginPos, n, tokenScore, tokenId); | |
| if (!hasSingleNode && n === mblen) { | |
| hasSingleNode = true; | |
| } | |
| } | |
| if (!hasSingleNode) { | |
| lattice.insert(beginPos, mblen, this.unk_score, this.unk_token_id); | |
| } | |
| beginPos += mblen; | |
| } | |
| } | |
| /** | |
| * Tokenizes a normalized string into an array of subtokens using the unigram model. | |
| * | |
| * @param {string} normalized The normalized string. | |
| * @returns {string[]} An array of subtokens obtained by encoding the input tokens using the unigram model. | |
| */ | |
| tokenize(normalized) { | |
| const lattice = new TokenLattice(normalized, this.bos_token_id, this.eos_token_id); | |
| this.populateNodes(lattice); | |
| return lattice.tokens(); | |
| } | |
| /** | |
| * Encodes an array of tokens using Unigram encoding. | |
| * @param {string[]} tokens The tokens to encode. | |
| * @returns {string[]} An array of encoded tokens. | |
| */ | |
| encode(tokens) { | |
| const toReturn = []; | |
| for (const token of tokens) { | |
| const tokenized = this.tokenize(token); | |
| toReturn.push(...tokenized); | |
| } | |
| return toReturn; | |
| } | |
| } | |
| /** | |
| * Returns a mapping from UTF-8 bytes to unicode strings. | |
| * Specifically avoids mapping to whitespace/control characters the BPE code barfs on. | |
| * @returns {Object} Object with utf-8 byte keys and unicode string values. | |
| */ | |
| const BYTES_TO_UNICODE = (() => { | |
| // Returns a mapping from UTF-8 bytes to unicode strings. | |
| // We specifically avoid mapping to whitespace/control characters | |
| // the bpe code barfs on. | |
| const bs = [ | |
| ...Array.from({ length: "~".charCodeAt(0) - "!".charCodeAt(0) + 1 }, (_, i) => i + "!".charCodeAt(0)), | |
| ...Array.from({ length: "¬".charCodeAt(0) - "¡".charCodeAt(0) + 1 }, (_, i) => i + "¡".charCodeAt(0)), | |
| ...Array.from({ length: "ÿ".charCodeAt(0) - "®".charCodeAt(0) + 1 }, (_, i) => i + "®".charCodeAt(0)), | |
| ]; | |
| const cs = bs.slice(); | |
| let n = 0; | |
| for (let b = 0; b < 256; ++b) { | |
| if (!bs.includes(b)) { | |
| bs.push(b); | |
| cs.push(256 + n); | |
| n += 1; | |
| } | |
| } | |
| const ccs = cs.map(n => String.fromCharCode(n)); | |
| return Object.fromEntries(bs.map((b, i) => [b, ccs[i]])); | |
| })(); | |
| const UNICODE_TO_BYTES = reverseDictionary(BYTES_TO_UNICODE); | |
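| // Illustrative example (not part of the library source): every byte gets a printable stand-in, | |
| // which is how byte-level BPE vocabularies represent e.g. a leading space as 'Ġ': | |
| //   BYTES_TO_UNICODE[' '.charCodeAt(0)]; // => 'Ġ' (U+0120) | |
| //   BYTES_TO_UNICODE['A'.charCodeAt(0)]; // => 'A' (printable ASCII maps to itself) | |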
| /** | |
| * @typedef {Object} BPENode | |
| * @property {string} token The token associated with the node | |
| * @property {number} bias A positional bias for the node. | |
| * @property {number} [score] The score of the node. | |
| * @property {BPENode} [prev] The previous node in the linked list. | |
| * @property {BPENode} [next] The next node in the linked list. | |
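| * @property {boolean} [deleted] Whether the node has been marked as deleted (i.e., replaced by a merged node). | |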
| */ | |
| /** | |
| * BPE class for encoding text into Byte-Pair-Encoding (BPE) tokens. | |
| * @extends TokenizerModel | |
| */ | |
| class BPE extends TokenizerModel { | |
| /** | |
| * Create a BPE instance. | |
| * @param {Object} config The configuration object for BPE. | |
| * @param {Object} config.vocab A mapping of tokens to ids. | |
| * @param {string[]|[string, string][]} config.merges An array of BPE merges, either as "a b" strings or ["a", "b"] pairs. | |
| * @param {string} config.unk_token The unknown token used for out of vocabulary words. | |
| * @param {string} config.end_of_word_suffix The suffix to place at the end of each word. | |
| * @param {string} [config.continuing_subword_suffix] The suffix to insert between words. | |
| * @param {boolean} [config.byte_fallback=false] Whether to use spm byte-fallback trick (defaults to False) | |
| * @param {boolean} [config.ignore_merges=false] Whether or not to match tokens with the vocab before using merges. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| /** @type {Map<string, number>} */ | |
| this.tokens_to_ids = objectToMap(config.vocab); | |
| this.unk_token_id = this.tokens_to_ids.get(config.unk_token); | |
| this.unk_token = config.unk_token; | |
| this.vocab = new Array(this.tokens_to_ids.size); | |
| for (const [key, value] of this.tokens_to_ids) { | |
| this.vocab[value] = key; | |
| } | |
| // Tokenizers >= 0.20.0 serializes BPE merges as a [string, string][] instead of a string[], | |
| // which resolves the ambiguity for merges containing spaces. | |
| const use_new_merge_format = Array.isArray(config.merges[0]); | |
| /** @type {[string, string][]} */ | |
| this.merges = use_new_merge_format | |
| ? /** @type {[string, string][]} */(config.merges) | |
| : (/** @type {string[]} */(config.merges)).map(x => /** @type {[string, string]} */(x.split(' ', 2))); | |
| this.bpe_ranks = new Map(this.merges.map((x, i) => [JSON.stringify(x), i])); | |
| this.end_of_word_suffix = config.end_of_word_suffix; | |
| // NOTE: `continuing_subword_suffix` is custom (to support `BlenderbotSmallTokenizer`) | |
| this.continuing_subword_suffix = config.continuing_subword_suffix ?? null; | |
| this.byte_fallback = this.config.byte_fallback ?? false; | |
| if (this.byte_fallback) { | |
| this.text_encoder = new TextEncoder(); | |
| } | |
| this.ignore_merges = this.config.ignore_merges ?? false; | |
| /** | |
| * The maximum length we should cache in a model. | |
| * Strings that are too long have little chance of a cache hit anyway. | |
| */ | |
| this.max_length_to_cache = 256; | |
| /** | |
| * The default capacity for a `BPE`'s internal cache. | |
| */ | |
| this.cache_capacity = 10000; | |
| this.cache = new LRUCache(this.cache_capacity); | |
| } | |
| /** | |
| * Clears the cache. | |
| */ | |
| clear_cache() { | |
| this.cache.clear(); | |
| } | |
| /** | |
| * Apply Byte-Pair-Encoding (BPE) to a given token. Efficient heap-based priority | |
| * queue implementation adapted from https://github.com/belladoreai/llama-tokenizer-js. | |
| * @param {string} token The token to encode. | |
| * @returns {string[]} The BPE encoded tokens. | |
| */ | |
| bpe(token) { | |
| if (token.length === 0) { | |
| return []; | |
| } | |
| const cached = this.cache.get(token); | |
| if (cached !== undefined) { | |
| return cached; | |
| } | |
| const word = Array.from(token); | |
| if (this.end_of_word_suffix) { | |
| word[word.length - 1] += this.end_of_word_suffix; | |
| } | |
| let result = []; | |
| if (word.length > 1) { | |
| // Create a priority queue to store the nodes that will be merged. | |
| // The comparator function compares the scores of the nodes. | |
| const queue = new PriorityQueue((a, b) => a.score < b.score); | |
| // Construct a doubly-linked list of nodes that will be inserted into the priority queue, | |
| // starting with the individual characters. We also populate each node with a positional | |
| // bias to break ties in the priority queue. | |
| let startingNode = { | |
| token: word[0], | |
| bias: 0, | |
| prev: null, | |
| next: null, | |
| } | |
| let previousNode = startingNode | |
| for (let i = 1; i < word.length; ++i) { | |
| const currentNode = { | |
| bias: i / word.length, // Add fractional component to break ties | |
| token: word[i], | |
| prev: previousNode, | |
| next: null, | |
| } | |
| previousNode.next = currentNode | |
| this._add_node(queue, previousNode) | |
| previousNode = currentNode | |
| } | |
| while (!queue.isEmpty()) { | |
| // Get the next node with the highest priority | |
| const node = queue.pop(); | |
| // Check that this merge is still possible | |
| if (node.deleted || !node.next || node.next.deleted) continue; | |
| // Here, we mark the current node (left side of the merge) and the next node (right side of the merge) as deleted. | |
| // This is because they will both be replaced by a new node representing the merge result. | |
| node.deleted = true; | |
| node.next.deleted = true; | |
| // Next, we fix the node that comes before the current node (i.e., left side of the merge). | |
| if (node.prev) { | |
| // Make a shallow copy of the previous node | |
| const newPreviousNode = { ...node.prev }; | |
| // Mark the old previous node as deleted. This avoids erroneous merges later, | |
| // because there may still be references to this node in the priority queue. | |
| node.prev.deleted = true; | |
| node.prev = newPreviousNode; | |
| // Update the reference of the previous node, by pointing its previous node to this new previous node. | |
| if (newPreviousNode.prev) { | |
| newPreviousNode.prev.next = newPreviousNode; | |
| } else { | |
| // If the previous of the previous node does not exist, it means that | |
| // `newPreviousNode` must be the new `startingNode`. | |
| startingNode = newPreviousNode; | |
| } | |
| } | |
| // Create a new node which represents the result of the merge. | |
| const merged = { | |
| token: node.token + node.next.token, | |
| bias: node.bias, | |
| prev: node.prev, | |
| next: node.next.next, | |
| } | |
| // We now consider where we can add the new merged node to the priority queue: | |
| // 1. prev <-> merged | |
| if (merged.prev) { | |
| merged.prev.next = merged; | |
| this._add_node(queue, merged.prev); | |
| } else { | |
| // If `merged.prev` does not exist, then `merged` must be the new `startingNode`. | |
| startingNode = merged; | |
| } | |
| // 2. merged <-> next | |
| if (merged.next) { | |
| merged.next.prev = merged; | |
| this._add_node(queue, merged); | |
| } | |
| } | |
| // Traverse the linked list, starting from the `startingNode`, and collect the tokens. | |
| for (let currentNode = startingNode; currentNode !== null; currentNode = currentNode.next) { | |
| result.push(currentNode.token); | |
| } | |
| } else { | |
| result = word; | |
| } | |
| // Possibly append suffix | |
| if (this.continuing_subword_suffix) { | |
| // Do not append suffix to the last token | |
| for (let i = 0; i < result.length - 1; ++i) { | |
| result[i] += this.continuing_subword_suffix; | |
| } | |
| } | |
| if (token.length < this.max_length_to_cache) { | |
| // Save the result to the cache | |
| this.cache.put(token, result); | |
| } | |
| return result; | |
| } | |
| /** | |
| * Helper function to add a node to the priority queue. | |
| * @param {PriorityQueue} queue | |
| * @param {BPENode} node | |
| * @private | |
| */ | |
| _add_node(queue, node) { | |
| // `score` is a measure of the merge priority: lower means higher priority | |
| // We use the BPE rank as a measure of priority (i.e., the position of the merge in the merges list) | |
| // We also add a fractional component to the score to break ties (with the earlier character having higher priority) | |
| const rank = this.bpe_ranks.get(JSON.stringify([node.token, node.next.token])); | |
| if (rank !== undefined) { | |
| node.score = rank + node.bias; | |
| queue.push(node); | |
| } | |
| } | |
| /** | |
| * Encodes the input sequence of tokens using the BPE algorithm and returns the resulting subword tokens. | |
| * @param {string[]} tokens The input sequence of tokens to encode. | |
| * @returns {string[]} The resulting subword tokens after applying the BPE algorithm to the input sequence of tokens. | |
| */ | |
| encode(tokens) { | |
| const outputTokens = []; | |
| for (const token of tokens) { | |
| if (this.ignore_merges && this.tokens_to_ids.has(token)) { | |
| outputTokens.push(token); | |
| continue; | |
| } | |
| const bpe_token_list = this.bpe(token); | |
| for (const t of bpe_token_list) { | |
| if (this.tokens_to_ids.has(t)) { | |
| outputTokens.push(t); | |
| } else if (this.byte_fallback) { | |
| const byteTokens = Array.from(this.text_encoder.encode(t)) | |
| .map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`); | |
| if (byteTokens.every(x => this.tokens_to_ids.has(x))) { | |
| // Ensure the byte tokens are actually in the vocabulary, otherwise | |
| // we fall back to the unknown token. For more information, see | |
| // https://github.com/huggingface/transformers/issues/28096. | |
| outputTokens.push(...byteTokens); | |
| } else { | |
| outputTokens.push(this.unk_token); | |
| } | |
| } else { | |
| outputTokens.push(this.unk_token); | |
| } | |
| } | |
| } | |
| return outputTokens; | |
| } | |
| } | |
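| // Illustrative sketch (not part of the library source; `BPE` is internal and the vocabulary and merges | |
| // below are made up). Merges are applied greedily in rank order via the priority queue: | |
| //   const bpe = new BPE({ | |
| //     vocab: { 'h': 0, 'e': 1, 'l': 2, 'o': 3, 'he': 4, 'll': 5, 'hell': 6, 'hello': 7, '[UNK]': 8 }, | |
| //     merges: [['h', 'e'], ['l', 'l'], ['he', 'll'], ['hell', 'o']], | |
| //     unk_token: '[UNK]', | |
| //   }); | |
| //   bpe.encode(['hello']); // => ['hello'] (all merges apply) | |
| //   bpe.encode(['hehe']);  // => ['he', 'he'] (only the 'h' + 'e' merge applies) | |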
| /** | |
| * Legacy tokenizer class for tokenizers with only a vocabulary. | |
| */ | |
| class LegacyTokenizerModel extends TokenizerModel { | |
| /** | |
| * Create a LegacyTokenizerModel instance. | |
| * @param {Object} config The configuration object for LegacyTokenizerModel. | |
| * @param {Object} config.vocab A (possibly nested) mapping of tokens to ids. | |
| * @param {Object} moreConfig Additional configuration object for the LegacyTokenizerModel model. | |
| */ | |
| constructor(config, moreConfig) { | |
| super(config); | |
| /**@type {Map<string, number>} */ | |
| this.tokens_to_ids = objectToMap( | |
| moreConfig.target_lang | |
| ? config.vocab[moreConfig.target_lang] | |
| : config.vocab | |
| ); | |
| this.bos_token = moreConfig.bos_token; | |
| this.bos_token_id = this.tokens_to_ids.get(this.bos_token); | |
| this.eos_token = moreConfig.eos_token; | |
| this.eos_token_id = this.tokens_to_ids.get(this.eos_token); | |
| this.pad_token = moreConfig.pad_token; | |
| this.pad_token_id = this.tokens_to_ids.get(this.pad_token); | |
| this.unk_token = moreConfig.unk_token; | |
| this.unk_token_id = this.tokens_to_ids.get(this.unk_token); | |
| this.vocab = new Array(this.tokens_to_ids.size); | |
| for (const [key, value] of this.tokens_to_ids) { | |
| this.vocab[value] = key; | |
| } | |
| } | |
| encode(tokens) { | |
| return tokens; | |
| } | |
| } | |
| /** | |
| * A base class for text normalization. | |
| * @abstract | |
| */ | |
| class Normalizer extends Callable { | |
| /** | |
| * @param {Object} config The configuration object for the normalizer. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| } | |
| /** | |
| * Factory method for creating normalizers from config objects. | |
| * @static | |
| * @param {Object} config The configuration object for the normalizer. | |
| * @returns {Normalizer} A Normalizer object. | |
| * @throws {Error} If an unknown Normalizer type is specified in the config. | |
| */ | |
| static fromConfig(config) { | |
| if (config === null) return null; | |
| switch (config.type) { | |
| case 'BertNormalizer': | |
| return new BertNormalizer(config); | |
| case 'Precompiled': | |
| return new Precompiled(config); | |
| case 'Sequence': | |
| return new NormalizerSequence(config); | |
| case 'Replace': | |
| return new Replace(config); | |
| case 'NFC': | |
| return new NFC(config); | |
| case 'NFD': | |
| return new NFD(config); | |
| case 'NFKC': | |
| return new NFKC(config); | |
| case 'NFKD': | |
| return new NFKD(config); | |
| case 'Strip': | |
| return new StripNormalizer(config); | |
| case 'StripAccents': | |
| return new StripAccents(config); | |
| case 'Lowercase': | |
| return new Lowercase(config); | |
| case 'Prepend': | |
| return new Prepend(config); | |
| default: | |
| throw new Error(`Unknown Normalizer type: ${config.type}`); | |
| } | |
| } | |
| /** | |
| * Normalize the input text. | |
| * @abstract | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| * @throws {Error} If this method is not implemented in a subclass. | |
| */ | |
| normalize(text) { | |
| throw Error("normalize should be implemented in subclass.") | |
| } | |
| /** | |
| * Alias for {@link Normalizer#normalize}. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| _call(text) { | |
| return this.normalize(text); | |
| } | |
| } | |
| /** | |
| * Replace normalizer that replaces occurrences of a pattern with a given string or regular expression. | |
| * @extends Normalizer | |
| */ | |
| class Replace extends Normalizer { | |
| /** | |
| * Normalize the input text by replacing the pattern with the content. | |
| * @param {string} text The input text to be normalized. | |
| * @returns {string} The normalized text after replacing the pattern with the content. | |
| */ | |
| normalize(text) { | |
| const pattern = createPattern(this.config.pattern); | |
| return pattern === null | |
| ? text | |
| : text.replaceAll(pattern, this.config.content); | |
| } | |
| } | |
| /** | |
| * A normalizer that applies Unicode normalization to the input text. | |
| * @extends Normalizer | |
| * @abstract | |
| */ | |
| class UnicodeNormalizer extends Normalizer { | |
| /** | |
| * @type {string} The Unicode normalization form to apply. | |
| * Should be one of: 'NFC', 'NFD', 'NFKC', or 'NFKD'. | |
| */ | |
| form = undefined; | |
| /** | |
| * Normalize the input text by applying Unicode normalization. | |
| * @param {string} text The input text to be normalized. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| text = text.normalize(this.form) | |
| return text; | |
| } | |
| } | |
| /** | |
| * A normalizer that applies Unicode normalization form C (NFC) to the input text. | |
| * Canonical Decomposition, followed by Canonical Composition. | |
| * @extends UnicodeNormalizer | |
| */ | |
| class NFC extends UnicodeNormalizer { | |
| form = 'NFC'; | |
| } | |
| /** | |
| * A normalizer that applies Unicode normalization form D (NFD) to the input text. | |
| * Canonical Decomposition. | |
| * @extends UnicodeNormalizer | |
| */ | |
| class NFD extends UnicodeNormalizer { | |
| form = 'NFD'; | |
| } | |
| /** | |
| * A normalizer that applies Unicode normalization form KC (NFKC) to the input text. | |
| * Compatibility Decomposition, followed by Canonical Composition. | |
| * @extends UnicodeNormalizer | |
| */ | |
| class NFKC extends UnicodeNormalizer { | |
| form = 'NFKC'; | |
| } | |
| /** | |
| * A normalizer that applies Unicode normalization form KD (NFKD) to the input text. | |
| * Compatibility Decomposition. | |
| * @extends UnicodeNormalizer | |
| */ | |
| class NFKD extends UnicodeNormalizer { | |
| form = 'NFKD'; | |
| } | |
| /** | |
| * A normalizer that strips leading and/or trailing whitespace from the input text. | |
| */ | |
| class StripNormalizer extends Normalizer { | |
| /** | |
| * Strip leading and/or trailing whitespace from the input text. | |
| * @param {string} text The input text. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| if (this.config.strip_left && this.config.strip_right) { | |
| // Fast path to avoid an extra trim call | |
| text = text.trim(); | |
| } else { | |
| if (this.config.strip_left) { | |
| text = text.trimStart(); | |
| } | |
| if (this.config.strip_right) { | |
| text = text.trimEnd(); | |
| } | |
| } | |
| return text; | |
| } | |
| } | |
| /** | |
| * StripAccents normalizer removes all accents from the text. | |
| * @extends Normalizer | |
| */ | |
| class StripAccents extends Normalizer { | |
| /** | |
| * Remove all accents from the text. | |
| * @param {string} text The input text. | |
| * @returns {string} The normalized text without accents. | |
| */ | |
| normalize(text) { | |
| text = remove_accents(text); | |
| return text; | |
| } | |
| } | |
| /** | |
| * A Normalizer that lowercases the input string. | |
| * @extends Normalizer | |
| */ | |
| class Lowercase extends Normalizer { | |
| /** | |
| * Lowercases the input string. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| text = text.toLowerCase(); | |
| return text; | |
| } | |
| } | |
| /** | |
| * A Normalizer that prepends a string to the input string. | |
| * @extends Normalizer | |
| */ | |
| class Prepend extends Normalizer { | |
| /** | |
| * Prepends the configured string to the input string. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| text = this.config.prepend + text; | |
| return text; | |
| } | |
| } | |
| /** | |
| * A Normalizer that applies a sequence of Normalizers. | |
| * @extends Normalizer | |
| */ | |
| class NormalizerSequence extends Normalizer { | |
| /** | |
| * Create a new instance of NormalizerSequence. | |
| * @param {Object} config The configuration object. | |
| * @param {Object[]} config.normalizers An array of Normalizer configuration objects. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.normalizers = config.normalizers.map(x => Normalizer.fromConfig(x)); | |
| } | |
| /** | |
| * Apply a sequence of Normalizers to the input text. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| return this.normalizers.reduce((t, normalizer) => { | |
| return normalizer.normalize(t); | |
| }, text); | |
| } | |
| } | |
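| // Illustrative example (not part of the library source): a `Sequence` normalizer config, as it would | |
| // appear in `tokenizer.json`, chains its children left to right: | |
| //   const normalizer = Normalizer.fromConfig({ | |
| //     type: 'Sequence', | |
| //     normalizers: [{ type: 'NFD' }, { type: 'StripAccents' }, { type: 'Lowercase' }], | |
| //   }); | |
| //   normalizer.normalize('Héllo'); // => 'hello' | |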
| /** | |
| * A class representing a normalizer used in BERT tokenization. | |
| * @extends Normalizer | |
| */ | |
| class BertNormalizer extends Normalizer { | |
| /** | |
| * Adds whitespace around any CJK (Chinese, Japanese, or Korean) character in the input text. | |
| * | |
| * @param {string} text The input text to tokenize. | |
| * @returns {string} The tokenized text with whitespace added around CJK characters. | |
| */ | |
| _tokenize_chinese_chars(text) { | |
| /* Adds whitespace around any CJK character. */ | |
| const output = []; | |
| for (let i = 0; i < text.length; ++i) { | |
| const char = text[i]; | |
| const cp = char.charCodeAt(0); | |
| if (is_chinese_char(cp)) { | |
| output.push(" "); | |
| output.push(char); | |
| output.push(" "); | |
| } else { | |
| output.push(char); | |
| } | |
| } | |
| return output.join(""); | |
| } | |
| /** | |
| * Strips accents from the given text. | |
| * @param {string} text The text to strip accents from. | |
| * @returns {string} The text with accents removed. | |
| */ | |
| stripAccents(text) { | |
| // "Mark, Nonspacing" (Mn) | |
| return text.normalize('NFD').replace(/\p{Mn}/gu, ''); | |
| } | |
| /** | |
| * Checks whether `char` is a control character. | |
| * @param {string} char The character to check. | |
| * @returns {boolean} Whether `char` is a control character. | |
| * @private | |
| */ | |
| _is_control(char) { | |
| switch (char) { | |
| case '\t': | |
| case '\n': | |
| case '\r': | |
| // These are technically control characters but we count them as whitespace characters. | |
| return false; | |
| default: | |
| // Check if unicode category starts with C: | |
| // Cc - Control | |
| // Cf - Format | |
| // Co - Private Use | |
| // Cs - Surrogate | |
| return /^\p{Cc}|\p{Cf}|\p{Co}|\p{Cs}$/u.test(char); | |
| } | |
| } | |
| /** | |
| * Performs invalid character removal and whitespace cleanup on text. | |
| * @param {string} text The text to clean. | |
| * @returns {string} The cleaned text. | |
| * @private | |
| */ | |
| _clean_text(text) { | |
| const output = []; | |
| for (const char of text) { | |
| const cp = char.charCodeAt(0); | |
| if (cp === 0 || cp === 0xFFFD || this._is_control(char)) { | |
| continue; | |
| } | |
| if (/^\s$/.test(char)) { // is whitespace | |
| output.push(" "); | |
| } else { | |
| output.push(char); | |
| } | |
| } | |
| return output.join(""); | |
| } | |
| /** | |
| * Normalizes the given text based on the configuration. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| if (this.config.clean_text) { | |
| text = this._clean_text(text); | |
| } | |
| if (this.config.handle_chinese_chars) { | |
| text = this._tokenize_chinese_chars(text); | |
| } | |
| if (this.config.lowercase) { | |
| text = text.toLowerCase(); | |
| if (this.config.strip_accents !== false) { | |
| text = this.stripAccents(text); | |
| } | |
| } else if (this.config.strip_accents) { | |
| text = this.stripAccents(text); | |
| } | |
| return text; | |
| } | |
| } | |
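| // Illustrative example (not part of the library source; the config fields mirror those read in | |
| // `normalize` above): | |
| //   const bert_normalizer = new BertNormalizer({ clean_text: true, handle_chinese_chars: true, lowercase: true }); | |
| //   bert_normalizer.normalize('Héllo, World!'); // => 'hello, world!' | |
| //   bert_normalizer.normalize('abc中def');       // => 'abc 中 def' (CJK characters get surrounding spaces) | |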
| /** | |
| * A callable class representing a pre-tokenizer used in tokenization. Subclasses | |
| * should implement the `pre_tokenize_text` method to define the specific pre-tokenization logic. | |
| * @extends Callable | |
| */ | |
| class PreTokenizer extends Callable { | |
| /** | |
| * Factory method that returns an instance of a subclass of `PreTokenizer` based on the provided configuration. | |
| * | |
| * @static | |
| * @param {Object} config A configuration object for the pre-tokenizer. | |
| * @returns {PreTokenizer} An instance of a subclass of `PreTokenizer`. | |
| * @throws {Error} If the provided configuration object does not correspond to any known pre-tokenizer. | |
| */ | |
| static fromConfig(config) { | |
| if (config === null) return null; | |
| switch (config.type) { | |
| case 'BertPreTokenizer': | |
| return new BertPreTokenizer(config); | |
| case 'Sequence': | |
| return new PreTokenizerSequence(config); | |
| case 'Whitespace': | |
| return new WhitespacePreTokenizer(config); | |
| case 'WhitespaceSplit': | |
| return new WhitespaceSplit(config); | |
| case 'Metaspace': | |
| return new MetaspacePreTokenizer(config); | |
| case 'ByteLevel': | |
| return new ByteLevelPreTokenizer(config); | |
| case 'Split': | |
| return new SplitPreTokenizer(config); | |
| case 'Punctuation': | |
| return new PunctuationPreTokenizer(config); | |
| case 'Digits': | |
| return new DigitsPreTokenizer(config); | |
| case 'Replace': | |
| return new ReplacePreTokenizer(config); | |
| default: | |
| throw new Error(`Unknown PreTokenizer type: ${config.type}`); | |
| } | |
| } | |
| /** | |
| * Method that should be implemented by subclasses to define the specific pre-tokenization logic. | |
| * | |
| * @abstract | |
| * @param {string} text The text to pre-tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} The pre-tokenized text. | |
| * @throws {Error} If the method is not implemented in the subclass. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| throw Error("pre_tokenize_text should be implemented in subclass.") | |
| } | |
| /** | |
| * Tokenizes the given text into pre-tokens. | |
| * @param {string|string[]} text The text or array of texts to pre-tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of pre-tokens. | |
| */ | |
| pre_tokenize(text, options) { | |
| return (Array.isArray(text) | |
| ? text.map(x => this.pre_tokenize_text(x, options)) | |
| : this.pre_tokenize_text(text, options) | |
| ).flat(); | |
| } | |
| /** | |
| * Alias for {@link PreTokenizer#pre_tokenize}. | |
| * @param {string|string[]} text The text or array of texts to pre-tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of pre-tokens. | |
| */ | |
| _call(text, options) { | |
| return this.pre_tokenize(text, options); | |
| } | |
| } | |
| /** | |
| * @extends PreTokenizer | |
| */ | |
| class BertPreTokenizer extends PreTokenizer { | |
| /** | |
| * A PreTokenizer that splits text into wordpieces using a basic tokenization scheme | |
| * similar to that used in the original implementation of BERT. | |
| * | |
| * @param {Object} config The configuration object. | |
| */ | |
| constructor(config) { | |
| super(); | |
| // Construct a pattern which matches the rust implementation: | |
| // https://github.com/huggingface/tokenizers/blob/b4fcc9ce6e4ad5806e82826f816acfdfdc4fcc67/tokenizers/src/pre_tokenizers/bert.rs#L11 | |
| // Equivalent to removing whitespace and splitting on punctuation (both \p{P} and other ascii characters) | |
| this.pattern = new RegExp(`[^\\s${PUNCTUATION_REGEX}]+|[${PUNCTUATION_REGEX}]`, 'gu'); | |
| } | |
| /** | |
| * Tokenizes a single text using the BERT pre-tokenization scheme. | |
| * | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| return text.trim().match(this.pattern) || []; | |
| } | |
| } | |
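| // Illustrative example (not part of the library source): | |
| //   new BertPreTokenizer({}).pre_tokenize_text('Hello, world!'); // => ['Hello', ',', 'world', '!'] | |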
| /** | |
| * A pre-tokenizer that splits text and remaps its bytes to printable unicode characters, as used by byte-level BPE tokenizers. | |
| * @extends PreTokenizer | |
| */ | |
| class ByteLevelPreTokenizer extends PreTokenizer { | |
| /** | |
| * Creates a new instance of the `ByteLevelPreTokenizer` class. | |
| * @param {Object} config The configuration object. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| /** | |
| * @type {boolean} Whether to add a leading space to the first word. | |
| * This allows the leading word to be treated just like any other word. | |
| */ | |
| this.add_prefix_space = this.config.add_prefix_space; | |
| /** | |
| * @type {boolean} Whether the post processing step should trim offsets | |
| * to avoid including whitespaces. | |
| * @todo Use this in the pretokenization step. | |
| */ | |
| this.trim_offsets = this.config.trim_offsets; | |
| /** | |
| * @type {boolean} Whether to use the standard GPT2 regex for whitespace splitting. | |
| * Set it to false if you want to use your own splitting. Defaults to true. | |
| */ | |
| this.use_regex = this.config.use_regex ?? true; | |
| this.pattern = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu; | |
| this.byte_encoder = BYTES_TO_UNICODE; | |
| this.text_encoder = new TextEncoder(); | |
| } | |
| /** | |
| * Tokenizes a single piece of text using byte-level tokenization. | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| // Add a leading space if the option is enabled | |
| if (this.add_prefix_space && !text.startsWith(' ')) { | |
| text = ' ' + text; | |
| } | |
| // Split on whitespace and punctuation | |
| const tokens = this.use_regex ? (text.match(this.pattern) || []) : [text]; | |
| // Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) | |
| return tokens.map( | |
| token => Array.from(this.text_encoder.encode(token), byte => this.byte_encoder[byte]).join('') | |
| ); | |
| } | |
| } | |
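| // Illustrative example (not part of the library source; the config fields mirror those read in the | |
| // constructor above): | |
| //   const pre_tokenizer = new ByteLevelPreTokenizer({ add_prefix_space: false, trim_offsets: true }); | |
| //   pre_tokenizer.pre_tokenize_text('Hello world'); // => ['Hello', 'Ġworld'] (the space is encoded as 'Ġ') | |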
| /** | |
| * @typedef {'removed'|'isolated'|'mergedWithPrevious'|'mergedWithNext'|'contiguous'} SplitDelimiterBehavior | |
| */ | |
| /** | |
| * Splits text using a given pattern. | |
| * @extends PreTokenizer | |
| */ | |
| class SplitPreTokenizer extends PreTokenizer { | |
| /** | |
| * @param {Object} config The configuration options for the pre-tokenizer. | |
| * @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object. | |
| * @param {string|undefined} config.pattern.String The string to use for splitting. Only defined if the pattern is a string. | |
| * @param {string|undefined} config.pattern.Regex The regex to use for splitting. Only defined if the pattern is a regex. | |
| * @param {SplitDelimiterBehavior} config.behavior The behavior to use when splitting. | |
| * @param {boolean} config.invert Whether to split (invert=false) or match (invert=true) the pattern. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| // TODO support all behaviours (config.behavior) | |
| this.pattern = createPattern(this.config.pattern, this.config.invert); | |
| } | |
| /** | |
| * Tokenizes text by splitting it using the given pattern. | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| if (this.pattern === null) { | |
| return []; | |
| } | |
| if (this.config.invert) { | |
| return text.match(this.pattern) || []; | |
| } else if (this.config.behavior?.toLowerCase() === 'removed') { | |
| return text.split(this.pattern).filter(x => x); | |
| } else { | |
| return regexSplit(text, this.pattern); | |
| } | |
| } | |
| } | |
| /** | |
| * Splits text based on punctuation. | |
| * @extends PreTokenizer | |
| */ | |
| class PunctuationPreTokenizer extends PreTokenizer { | |
| /** | |
| * @param {Object} config The configuration options for the pre-tokenizer. | |
| * @param {SplitDelimiterBehavior} config.behavior The behavior to use when splitting. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| this.pattern = new RegExp(`[^${PUNCTUATION_REGEX}]+|[${PUNCTUATION_REGEX}]+`, 'gu'); | |
| } | |
| /** | |
| * Tokenizes text by splitting it using the given pattern. | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| return text.match(this.pattern) || []; | |
| } | |
| } | |
| /** | |
| * Splits text based on digits. | |
| * @extends PreTokenizer | |
| */ | |
| class DigitsPreTokenizer extends PreTokenizer { | |
| /** | |
| * @param {Object} config The configuration options for the pre-tokenizer. | |
| * @param {boolean} config.individual_digits Whether to split on individual digits. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| // Construct a pattern which matches the rust implementation: | |
| const digit_pattern = `[^\\d]+|\\d${this.config.individual_digits ? '' : '+'}`; | |
| this.pattern = new RegExp(digit_pattern, 'gu'); | |
| } | |
| /** | |
| * Tokenizes text by splitting it using the given pattern. | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| return text.match(this.pattern) || []; | |
| } | |
| } | |
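| // Illustrative example (not part of the library source): | |
| //   new DigitsPreTokenizer({ individual_digits: false }).pre_tokenize_text('Call 911'); // => ['Call ', '911'] | |
| //   new DigitsPreTokenizer({ individual_digits: true }).pre_tokenize_text('Call 911');  // => ['Call ', '9', '1', '1'] | |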
| /** | |
| * @typedef {Object} PostProcessedOutput | |
| * @property {string[]} tokens List of tokens produced by the post-processor. | |
| * @property {number[]} [token_type_ids] List of token type ids produced by the post-processor. | |
| */ | |
| /** | |
| * @typedef {Object} EncodingSingle | |
| * @property {number[]} input_ids List of token ids to be fed to a model. | |
| * @property {number[]} attention_mask List of indices specifying which tokens should be attended to by the model | |
| * @property {number[]} [token_type_ids] List of token type ids to be fed to a model | |
| */ | |
| /** | |
| * @extends Callable | |
| */ | |
| class PostProcessor extends Callable { | |
| /** | |
| * @param {Object} config The configuration for the post-processor. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| } | |
| /** | |
| * Factory method to create a PostProcessor object from a configuration object. | |
| * | |
| * @param {Object} config Configuration object representing a PostProcessor. | |
| * @returns {PostProcessor} A PostProcessor object created from the given configuration. | |
| * @throws {Error} If an unknown PostProcessor type is encountered. | |
| */ | |
| static fromConfig(config) { | |
| if (config === null) return null; | |
| switch (config.type) { | |
| case 'TemplateProcessing': | |
| return new TemplateProcessing(config); | |
| case 'ByteLevel': | |
| return new ByteLevelPostProcessor(config); | |
| case 'RobertaProcessing': | |
| return new RobertaProcessing(config); | |
| case 'BertProcessing': | |
| return new BertProcessing(config); | |
| case 'Sequence': | |
| return new PostProcessorSequence(config); | |
| default: | |
| throw new Error(`Unknown PostProcessor type: ${config.type}`); | |
| } | |
| } | |
| /** | |
| * Method to be implemented in subclass to apply post-processing on the given tokens. | |
| * | |
| * @param {Array} tokens The input tokens to be post-processed. | |
| * @param {...*} args Additional arguments required by the post-processing logic. | |
| * @returns {PostProcessedOutput} The post-processed tokens. | |
| * @throws {Error} If the method is not implemented in subclass. | |
| */ | |
| post_process(tokens, ...args) { | |
| throw Error("post_process should be implemented in subclass.") | |
| } | |
| /** | |
| * Alias for {@link PostProcessor#post_process}. | |
| * @param {Array} tokens The text or array of texts to post-process. | |
| * @param {...*} args Additional arguments required by the post-processing logic. | |
| * @returns {PostProcessedOutput} The post-processed tokens. | |
| */ | |
| _call(tokens, ...args) { | |
| return this.post_process(tokens, ...args); | |
| } | |
| } | |
| /** | |
| * A post-processor that adds special tokens to the beginning and end of the input. | |
| */ | |
| class BertProcessing extends PostProcessor { | |
| /** | |
| * @param {Object} config The configuration for the post-processor. | |
| * @param {string[]} config.cls The special tokens to add to the beginning of the input. | |
| * @param {string[]} config.sep The special tokens to add to the end of the input. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| // TODO use all of config: add_prefix_space, trim_offsets | |
| this.cls = config.cls[0]; | |
| this.sep = config.sep[0]; | |
| } | |
| /** | |
| * Adds the special tokens to the beginning and end of the input. | |
| * @param {string[]} tokens The input tokens. | |
| * @param {string[]} [tokens_pair=null] An optional second set of input tokens. | |
| * @returns {PostProcessedOutput} The post-processed tokens with the special tokens added to the beginning and end. | |
| */ | |
| post_process(tokens, tokens_pair = null, { | |
| add_special_tokens = true, | |
| } = {}) { | |
| if (add_special_tokens) { | |
| tokens = mergeArrays([this.cls], tokens, [this.sep]); | |
| } | |
| let token_type_ids = new Array(tokens.length).fill(0); | |
| if (tokens_pair !== null) { | |
| // NOTE: For RoBERTa-style processing, a second SEP token is intentionally inserted | |
| // between the two sequences (i.e. `</s></s>`). | |
| // https://github.com/huggingface/tokenizers/issues/983 | |
| const middle = (add_special_tokens && this instanceof RobertaProcessing) | |
| ? [this.sep] | |
| : []; | |
| const after = add_special_tokens ? [this.sep] : []; | |
| tokens = mergeArrays(tokens, middle, tokens_pair, after); | |
| token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + middle.length + after.length).fill(1)); | |
| } | |
| return { tokens, token_type_ids }; | |
| } | |
| } | |
| class RobertaProcessing extends BertProcessing { } // NOTE: extends BertProcessing | |
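| // Illustrative sketch of the processing above, assuming BERT-style '[CLS]'/'[SEP]' special tokens: | |
| // | |
| //   post_process(['hello', 'world'], ['how', 'are', 'you']) | |
| //   // tokens:         ['[CLS]', 'hello', 'world', '[SEP]', 'how', 'are', 'you', '[SEP]'] | |
| //   // token_type_ids: [0, 0, 0, 0, 1, 1, 1, 1] | |
| // | |
| // RobertaProcessing additionally inserts a second separator between the two sequences. | |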
| /** | |
| * Post processor that replaces special tokens in a template with actual tokens. | |
| * @extends PostProcessor | |
| */ | |
| class TemplateProcessing extends PostProcessor { | |
| /** | |
| * Creates a new instance of `TemplateProcessing`. | |
| * @param {Object} config The configuration options for the post processor. | |
| * @param {Array} config.single The template for a single sequence of tokens. | |
| * @param {Array} config.pair The template for a pair of sequences of tokens. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.single = config.single; | |
| this.pair = config.pair; | |
| } | |
| /** | |
| * Replaces special tokens in the template with actual tokens. | |
| * @param {string[]} tokens The list of tokens for the first sequence. | |
| * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). | |
| * @returns {PostProcessedOutput} An object containing the list of tokens with the special tokens replaced with actual tokens. | |
| */ | |
| post_process(tokens, tokens_pair = null, { | |
| add_special_tokens = true, | |
| } = {}) { | |
| const type = tokens_pair === null ? this.single : this.pair; | |
| let processedTokens = []; | |
| let types = []; | |
| for (const item of type) { | |
| if ('SpecialToken' in item) { | |
| if (add_special_tokens) { | |
| processedTokens.push(item.SpecialToken.id); | |
| types.push(item.SpecialToken.type_id); | |
| } | |
| } else if ('Sequence' in item) { | |
| if (item.Sequence.id === 'A') { | |
| processedTokens = mergeArrays(processedTokens, tokens); | |
| types = mergeArrays(types, new Array(tokens.length).fill(item.Sequence.type_id)); | |
| } else if (item.Sequence.id === 'B') { | |
| processedTokens = mergeArrays(processedTokens, tokens_pair); | |
| types = mergeArrays(types, new Array(tokens_pair.length).fill(item.Sequence.type_id)); | |
| } | |
| } | |
| } | |
| return { tokens: processedTokens, token_type_ids: types }; | |
| } | |
| } | |
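| // Illustrative sketch: given a (hypothetical) template config such as | |
| //   { single: [{ SpecialToken: { id: '<s>', type_id: 0 } }, { Sequence: { id: 'A', type_id: 0 } }, { SpecialToken: { id: '</s>', type_id: 0 } }] } | |
| // post_process(['hello', 'world']) returns | |
| //   { tokens: ['<s>', 'hello', 'world', '</s>'], token_type_ids: [0, 0, 0, 0] } | |
| // and the SpecialToken entries are skipped when add_special_tokens is false. | |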
| /** | |
| * A PostProcessor that returns the given tokens as is. | |
| * @extends PostProcessor | |
| */ | |
| class ByteLevelPostProcessor extends PostProcessor { | |
| /** | |
| * Post process the given tokens. | |
| * @param {string[]} tokens The list of tokens for the first sequence. | |
| * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). | |
| * @returns {PostProcessedOutput} An object containing the post-processed tokens. | |
| */ | |
| post_process(tokens, tokens_pair = null) { | |
| if (tokens_pair) { | |
| tokens = mergeArrays(tokens, tokens_pair); | |
| } | |
| return { tokens }; | |
| } | |
| } | |
| /** | |
| * A post-processor that applies multiple post-processors in sequence. | |
| */ | |
| class PostProcessorSequence extends PostProcessor { | |
| /** | |
| * Creates a new instance of PostProcessorSequence. | |
| * @param {Object} config The configuration object. | |
| * @param {Object[]} config.processors The list of post-processors to apply. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.processors = config.processors.map(x => PostProcessor.fromConfig(x)); | |
| } | |
| /** | |
| * Post process the given tokens. | |
| * @param {string[]} tokens The list of tokens for the first sequence. | |
| * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). | |
| * @returns {PostProcessedOutput} An object containing the post-processed tokens. | |
| */ | |
| post_process(tokens, tokens_pair = null, options = {}) { | |
| let token_type_ids; | |
| for (const processor of this.processors) { | |
| if (processor instanceof ByteLevelPostProcessor) { | |
| // Special case where we need to pass the tokens_pair to the post-processor | |
| const output = processor.post_process(tokens); | |
| tokens = output.tokens; | |
| if (tokens_pair) { | |
| const pair_output = processor.post_process(tokens_pair); | |
| tokens_pair = pair_output.tokens; | |
| } | |
| } else { | |
| const output = processor.post_process(tokens, tokens_pair, options); | |
| tokens = output.tokens; | |
| token_type_ids = output.token_type_ids; | |
| } | |
| } | |
| return { tokens, token_type_ids }; | |
| } | |
| } | |
| /** | |
| * The base class for token decoders. | |
| * @extends Callable | |
| */ | |
| class Decoder extends Callable { | |
| /** | |
| * Creates an instance of `Decoder`. | |
| * | |
| * @param {Object} config The configuration object. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| /** @type {AddedToken[]} */ | |
| this.added_tokens = []; | |
| this.end_of_word_suffix = null; | |
| this.trim_offsets = config.trim_offsets; | |
| } | |
| /** | |
| * Creates a decoder instance based on the provided configuration. | |
| * | |
| * @param {Object} config The configuration object. | |
| * @returns {Decoder} A decoder instance. | |
| * @throws {Error} If an unknown decoder type is provided. | |
| */ | |
| static fromConfig(config) { | |
| if (config === null) return null; | |
| switch (config.type) { | |
| case 'WordPiece': | |
| return new WordPieceDecoder(config); | |
| case 'Metaspace': | |
| return new MetaspaceDecoder(config); | |
| case 'ByteLevel': | |
| return new ByteLevelDecoder(config); | |
| case 'Replace': | |
| return new ReplaceDecoder(config); | |
| case 'ByteFallback': | |
| return new ByteFallback(config); | |
| case 'Fuse': | |
| return new FuseDecoder(config); | |
| case 'Strip': | |
| return new StripDecoder(config); | |
| case 'Sequence': | |
| return new DecoderSequence(config); | |
| case 'CTC': | |
| return new CTCDecoder(config); | |
| case 'BPEDecoder': | |
| return new BPEDecoder(config); | |
| default: | |
| throw new Error(`Unknown Decoder type: ${config.type}`); | |
| } | |
| } | |
| /** | |
| * Calls the `decode` method. | |
| * | |
| * @param {string[]} tokens The list of tokens. | |
| * @returns {string} The decoded string. | |
| */ | |
| _call(tokens) { | |
| return this.decode(tokens); | |
| } | |
| /** | |
| * Decodes a list of tokens. | |
| * @param {string[]} tokens The list of tokens. | |
| * @returns {string} The decoded string. | |
| */ | |
| decode(tokens) { | |
| return this.decode_chain(tokens).join(''); | |
| } | |
| /** | |
| * Apply the decoder to a list of tokens. | |
| * | |
| * @param {string[]} tokens The list of tokens. | |
| * @returns {string[]} The decoded list of tokens. | |
| * @throws {Error} If the `decode_chain` method is not implemented in the subclass. | |
| */ | |
| decode_chain(tokens) { | |
| throw Error("`decode_chain` should be implemented in subclass.") | |
| } | |
| } | |
| class ReplaceDecoder extends Decoder { | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| const pattern = createPattern(this.config.pattern); | |
| return pattern === null | |
| ? tokens | |
| : tokens.map(token => token.replaceAll(pattern, this.config.content)) | |
| } | |
| } | |
| class ByteFallback extends Decoder { | |
| constructor(config) { | |
| super(config); | |
| this.text_decoder = new TextDecoder(); | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| const new_tokens = []; | |
| let previous_byte_tokens = []; | |
| for (const token of tokens) { | |
| let bytes = null; | |
| if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) { | |
| const byte = parseInt(token.slice(3, 5), 16); | |
| if (!isNaN(byte)) { | |
| bytes = byte; | |
| } | |
| } | |
| if (bytes !== null) { | |
| previous_byte_tokens.push(bytes); | |
| } else { | |
| if (previous_byte_tokens.length > 0) { | |
| const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); | |
| new_tokens.push(string); | |
| previous_byte_tokens = []; | |
| } | |
| new_tokens.push(token); | |
| } | |
| } | |
| if (previous_byte_tokens.length > 0) { | |
| const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); | |
| new_tokens.push(string); | |
| previous_byte_tokens = []; | |
| } | |
| return new_tokens; | |
| } | |
| } | |
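| // Illustrative sketch of the byte-fallback decoding above: consecutive '<0xNN>' tokens are | |
| // collected and decoded together as UTF-8 bytes (example values only): | |
| // | |
| //   const byteTokens = ['<0xE2>', '<0x82>', '<0xAC>']; | |
| //   const bytes = byteTokens.map(t => parseInt(t.slice(3, 5), 16)); // [226, 130, 172] | |
| //   new TextDecoder().decode(Uint8Array.from(bytes));               // '€' | |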
| /** | |
| * Fuse simply fuses all tokens into one big string. | |
| * It's usually the last decoding step anyway, but this decoder | |
| * exists in case some decoders need to run after that step. | |
| */ | |
| class FuseDecoder extends Decoder { | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| return [tokens.join('')]; | |
| } | |
| } | |
| class StripDecoder extends Decoder { | |
| constructor(config) { | |
| super(config); | |
| this.content = this.config.content; | |
| this.start = this.config.start; | |
| this.stop = this.config.stop; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| return tokens.map(token => { | |
| let start_cut = 0; | |
| for (let i = 0; i < this.start; ++i) { | |
| if (token[i] === this.content) { | |
| start_cut = i + 1; | |
| continue; | |
| } else { | |
| break; | |
| } | |
| } | |
| let stop_cut = token.length; | |
| for (let i = 0; i < this.stop; ++i) { | |
| const index = token.length - i - 1; | |
| if (token[index] === this.content) { | |
| stop_cut = index; | |
| continue; | |
| } else { | |
| break; | |
| } | |
| } | |
| return token.slice(start_cut, stop_cut) | |
| }); | |
| } | |
| } | |
| /** | |
| * A decoder that decodes a list of WordPiece tokens into a single string. | |
| * @extends Decoder | |
| */ | |
| class WordPieceDecoder extends Decoder { | |
| /** | |
| * Creates a new instance of WordPieceDecoder. | |
| * @param {Object} config The configuration object. | |
| * @param {string} config.prefix The prefix used for WordPiece encoding. | |
| * @param {boolean} config.cleanup Whether to cleanup the decoded string. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.cleanup = config.cleanup; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| return tokens.map((token, i) => { | |
| if (i !== 0) { | |
| if (token.startsWith(this.config.prefix)) { | |
| // NOTE: .replace() is intended; only replace first occurrence | |
| token = token.replace(this.config.prefix, ''); | |
| } else { | |
| token = ' ' + token; | |
| } | |
| } | |
| if (this.cleanup) { | |
| token = clean_up_tokenization(token) | |
| } | |
| return token; | |
| }); | |
| } | |
| } | |
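| // Illustrative sketch of WordPiece joining, assuming the common '##' continuation prefix: | |
| // | |
| //   ['play', '##ing', 'outside'] // -> 'play' + 'ing' + ' outside' === 'playing outside' | |
| // | |
| // Tokens that do not start with the prefix are simply separated by a space. | |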
| /** | |
| * Byte-level decoder for tokenization output. Inherits from the `Decoder` class. | |
| * @extends Decoder | |
| */ | |
| class ByteLevelDecoder extends Decoder { | |
| /** | |
| * Create a `ByteLevelDecoder` object. | |
| * @param {Object} config Configuration object. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.byte_decoder = UNICODE_TO_BYTES; | |
| this.text_decoder = new TextDecoder("utf-8", { | |
| fatal: false, | |
| ignoreBOM: true, | |
| }); | |
| this.end_of_word_suffix = null; | |
| } | |
| /** | |
| * Convert an array of tokens to string by decoding each byte. | |
| * @param {string[]} tokens Array of tokens to be decoded. | |
| * @returns {string} The decoded string. | |
| */ | |
| convert_tokens_to_string(tokens) { | |
| const text = tokens.join(''); | |
| const byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c])); | |
| const decoded_text = this.text_decoder.decode(byteArray); | |
| return decoded_text; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| // TODO move to base class (like HF) | |
| // tokens === filtered_tokens | |
| // To avoid mixing byte-level and unicode for byte-level BPE, | |
| // we need to build string separately for added tokens and byte-level tokens | |
| // cf. https://github.com/huggingface/transformers/issues/1133 | |
| const sub_texts = []; | |
| let current_sub_text = []; | |
| for (const token of tokens) { | |
| // tokens sent here are already filtered, so we don't need to do this | |
| // if (skip_special_tokens && this.all_special_ids.includes(token)) { | |
| // continue; | |
| // } | |
| if (this.added_tokens.find(x => x.content === token) !== undefined) { | |
| if (current_sub_text.length > 0) { | |
| sub_texts.push(this.convert_tokens_to_string(current_sub_text)); | |
| current_sub_text = []; | |
| } | |
| sub_texts.push(token); | |
| } else { | |
| current_sub_text.push(token); | |
| } | |
| } | |
| if (current_sub_text.length > 0) { | |
| sub_texts.push(this.convert_tokens_to_string(current_sub_text)); | |
| } | |
| // TODO add spaces_between_special_tokens and clean_up_tokenization_spaces options | |
| return sub_texts; | |
| } | |
| } | |
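| // Illustrative sketch: byte-level tokenizers represent the space byte (0x20) with the visible | |
| // character 'Ġ' (U+0120), so a token like 'Ġhello' decodes back to ' hello' once each character | |
| // has been mapped through `byte_decoder` and the resulting bytes are UTF-8 decoded: | |
| // | |
| //   new TextDecoder().decode(Uint8Array.from([0x20, 0x68, 0x65, 0x6C, 0x6C, 0x6F])); // ' hello' | |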
| /** | |
| * The CTC (Connectionist Temporal Classification) decoder. | |
| * See https://github.com/huggingface/tokenizers/blob/bb38f390a61883fc2f29d659af696f428d1cda6b/tokenizers/src/decoders/ctc.rs | |
| */ | |
| class CTCDecoder extends Decoder { | |
| constructor(config) { | |
| super(config); | |
| this.pad_token = this.config.pad_token; | |
| this.word_delimiter_token = this.config.word_delimiter_token; | |
| this.cleanup = this.config.cleanup; | |
| } | |
| /** | |
| * Converts connectionist temporal classification (CTC) output tokens into a single string. | |
| * @param {string[]} tokens Array of tokens to be decoded. | |
| * @returns {string} The decoded string. | |
| */ | |
| convert_tokens_to_string(tokens) { | |
| if (tokens.length === 0) return ''; | |
| // group same tokens into non-repeating tokens in CTC style decoding | |
| const grouped_tokens = [tokens[0]]; | |
| for (let i = 1; i < tokens.length; ++i) { | |
| if (tokens[i] !== grouped_tokens.at(-1)) { | |
| grouped_tokens.push(tokens[i]); | |
| } | |
| } | |
| // filter self.pad_token which is used as CTC-blank token | |
| const filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token); | |
| let text = filtered_tokens.join(''); | |
| if (this.cleanup) { | |
| // cleanup and replace delimiter token | |
| text = clean_up_tokenization(text) | |
| .replaceAll(this.word_delimiter_token, ' ') | |
| .trim(); | |
| } | |
| return text; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| return [this.convert_tokens_to_string(tokens)]; | |
| } | |
| } | |
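| // Illustrative sketch of CTC decoding, assuming '<pad>' as the blank token and '|' as the | |
| // word delimiter (as in typical Wav2Vec2-style checkpoints): | |
| // | |
| //   const tokens = ['<pad>', 'h', 'h', 'e', '<pad>', 'l', 'l', '<pad>', 'l', 'o', '|']; | |
| //   const grouped = tokens.filter((t, i) => i === 0 || t !== tokens[i - 1]); | |
| //   grouped.filter(t => t !== '<pad>').join('').replaceAll('|', ' ').trim(); // 'hello' | |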
| /** | |
| * Apply a sequence of decoders. | |
| * @extends Decoder | |
| */ | |
| class DecoderSequence extends Decoder { | |
| /** | |
| * Creates a new instance of DecoderSequence. | |
| * @param {Object} config The configuration object. | |
| * @param {Object[]} config.decoders The list of decoders to apply. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.decoders = config.decoders.map(x => Decoder.fromConfig(x)); | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| // Use reduce to apply each decoder to the tokens | |
| return this.decoders.reduce((toks, decoder) => { | |
| return decoder.decode_chain(toks); | |
| }, tokens); | |
| } | |
| } | |
| class BPEDecoder extends Decoder { | |
| constructor(config) { | |
| super(config); | |
| this.suffix = this.config.suffix; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| return tokens.map((token, i) => { | |
| return token.replaceAll(this.suffix, (i === tokens.length - 1) ? '' : ' ') | |
| }); | |
| } | |
| } | |
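| // Illustrative sketch, assuming the common '</w>' end-of-word suffix: | |
| // | |
| //   ['hello</w>', 'world</w>'].map((t, i, arr) => t.replaceAll('</w>', i === arr.length - 1 ? '' : ' ')).join(''); | |
| //   // 'hello world' | |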
| // Custom decoder for VITS | |
| class VitsDecoder extends Decoder { | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| let decoded = ''; | |
| for (let i = 1; i < tokens.length; i += 2) { | |
| decoded += tokens[i]; | |
| } | |
| return [decoded]; | |
| } | |
| } | |
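| // Illustrative sketch: VITS input interleaves characters with a pad/blank token, so keeping | |
| // every other token (the odd indices) recovers the text: | |
| // | |
| //   ['_', 'h', '_', 'i', '_'].filter((_, i) => i % 2 === 1).join(''); // 'hi' | |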
| /** | |
| * This PreTokenizer replaces spaces with the given replacement character, adds a prefix space if requested, | |
| * and returns a list of tokens. | |
| * @extends PreTokenizer | |
| */ | |
| class MetaspacePreTokenizer extends PreTokenizer { | |
| /** | |
| * @param {Object} config The configuration object for the MetaspacePreTokenizer. | |
| * @param {boolean} config.add_prefix_space Whether to add a prefix space to the first token. | |
| * @param {string} config.replacement The character to replace spaces with. | |
| * @param {string} [config.str_rep=config.replacement] An optional string representation of the replacement character. | |
| * @param {'first'|'never'|'always'} [config.prepend_scheme='always'] The metaspace prepending scheme. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.addPrefixSpace = config.add_prefix_space; | |
| this.replacement = config.replacement; | |
| this.strRep = config.str_rep || this.replacement; | |
| this.prepend_scheme = config.prepend_scheme ?? 'always'; | |
| } | |
| /** | |
| * This method takes a string, replaces spaces with the replacement character, | |
| * adds a prefix space if requested, and returns a new list of tokens. | |
| * @param {string} text The text to pre-tokenize. | |
| * @param {Object} [options] The options for the pre-tokenization. | |
| * @param {number} [options.section_index] The index of the section to pre-tokenize. | |
| * @returns {string[]} A new list of pre-tokenized tokens. | |
| */ | |
| pre_tokenize_text(text, { | |
| section_index = undefined, | |
| } = {}) { | |
| let normalized = text.replaceAll(' ', this.strRep); | |
| if ( | |
| // We add a prefix space if: | |
| // (1) The addPrefixSpace option is enabled and the normalized | |
| // token does not already start with the replacement character. | |
| (this.addPrefixSpace && !normalized.startsWith(this.replacement)) | |
| // and (2) either: | |
| // (a) prepend_scheme is 'always' | |
| // (b) prepend_scheme is 'first' and this is the first section | |
| && ( | |
| this.prepend_scheme === 'always' || | |
| (this.prepend_scheme === 'first' && section_index === 0) | |
| ) | |
| ) { | |
| normalized = this.strRep + normalized; | |
| } | |
| return [normalized]; | |
| } | |
| } | |
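| // Illustrative sketch of the metaspace replacement above, assuming the usual '▁' (U+2581) replacement character: | |
| // | |
| //   const replacement = '\u2581'; | |
| //   let normalized = 'Hello world'.replaceAll(' ', replacement); // 'Hello▁world' | |
| //   normalized = replacement + normalized;                       // '▁Hello▁world' (prefix space added) | |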
| /** | |
| * MetaspaceDecoder class extends the Decoder class and decodes Metaspace tokenization. | |
| * @extends Decoder | |
| */ | |
| class MetaspaceDecoder extends Decoder { | |
| /** | |
| * Constructs a new MetaspaceDecoder object. | |
| * @param {Object} config The configuration object for the MetaspaceDecoder. | |
| * @param {boolean} config.add_prefix_space Whether to add a prefix space to the decoded string. | |
| * @param {string} config.replacement The string to replace spaces with. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.addPrefixSpace = config.add_prefix_space; | |
| this.replacement = config.replacement; | |
| } | |
| /** @type {Decoder['decode_chain']} */ | |
| decode_chain(tokens) { | |
| const result = []; | |
| for (let i = 0; i < tokens.length; ++i) { | |
| let normalized = tokens[i].replaceAll(this.replacement, ' '); | |
| if (this.addPrefixSpace && i === 0 && normalized.startsWith(' ')) { | |
| normalized = normalized.substring(1); | |
| } | |
| result.push(normalized); | |
| } | |
| return result; | |
| } | |
| } | |
| /** | |
| * A normalizer that applies a precompiled charsmap. | |
| * This is useful for applying complex normalizations in C++ and exposing them to JavaScript. | |
| * @extends Normalizer | |
| * @param {Object} config The configuration object for the Precompiled normalizer. | |
| * @param {Object} config.precompiled_charsmap The precompiled charsmap object. | |
| */ | |
| class Precompiled extends Normalizer { | |
| /** | |
| * Create a new instance of Precompiled normalizer. | |
| * @param {Object} config The configuration object. | |
| * @param {any} config.precompiled_charsmap Precompiled chars mapping. | |
| */ | |
| constructor(config) { | |
| super(config); | |
| this.charsmap = config.precompiled_charsmap; | |
| } | |
| /** | |
| * Normalizes the given text by applying the precompiled charsmap. | |
| * @param {string} text The text to normalize. | |
| * @returns {string} The normalized text. | |
| */ | |
| normalize(text) { | |
| // As stated in the sentencepiece normalization docs (https://github.com/google/sentencepiece/blob/master/doc/normalization.md#use-pre-defined-normalization-rule), | |
| // there are 5 pre-defined normalization rules: | |
| // 1. nmt_nfkc: NFKC normalization with some additional normalization around spaces. (default) | |
| // 2. nfkc: original NFKC normalization. | |
| // 3. nmt_nfkc_cf: nmt_nfkc + Unicode case folding (mostly lower casing) | |
| // 4. nfkc_cf: nfkc + Unicode case folding. | |
| // 5. identity: no normalization | |
| // | |
| // For now, we only implement the default (nmt_nfkc). | |
| // See https://raw.githubusercontent.com/google/sentencepiece/master/data/nmt_nfkc.tsv for the full list of rules. | |
| // TODO: detect when a different `this.charsmap` is used. | |
| text = text.replace(/[\u0001-\u0008\u000B\u000E-\u001F\u007F\u008F\u009F]/gm, ''); // Remove control characters | |
| text = text.replace(/[\u0009\u000A\u000C\u000D\u00A0\u1680\u2000-\u200F\u2028\u2029\u202F\u205F\u2581\u3000\uFEFF\uFFFD]/gm, '\u0020'); // Replace certain characters with a space | |
| if (text.includes('\uFF5E')) { | |
| // To match the sentencepiece implementation 100%, we must handle a very strange edge-case. | |
| // For some reason, the "Fullwidth Tilde" character (\uFF5E) should not be converted to the standard Tilde character (\u007E). | |
| // However, NFKC normalization does do this conversion. As a result, we split the string on the Fullwidth Tilde character, | |
| // perform NFKC normalization on each substring, and then join them back together with the Fullwidth Tilde character. | |
| const parts = text.split('\uFF5E'); | |
| text = parts.map(part => part.normalize('NFKC')).join('\uFF5E'); | |
| } else { | |
| text = text.normalize('NFKC'); | |
| } | |
| return text; | |
| } | |
| } | |
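| // Illustrative sketch of the nmt_nfkc-style behaviour implemented above (example inputs only): | |
| // | |
| //   'ｆｕｌｌｗｉｄｔｈ'.normalize('NFKC');        // 'fullwidth' (compatibility forms are folded) | |
| //   'a\u200Bb'.replace(/[\u2000-\u200F]/g, ' '); // 'a b'       (zero-width space becomes a space) | |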
| /** | |
| * A pre-tokenizer that applies a sequence of pre-tokenizers to the input text. | |
| * @extends PreTokenizer | |
| */ | |
| class PreTokenizerSequence extends PreTokenizer { | |
| /** | |
| * Creates an instance of PreTokenizerSequence. | |
| * @param {Object} config The configuration object for the pre-tokenizer sequence. | |
| * @param {Object[]} config.pretokenizers An array of pre-tokenizer configurations. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.tokenizers = config.pretokenizers.map(x => PreTokenizer.fromConfig(x)); | |
| } | |
| /** | |
| * Applies each pre-tokenizer in the sequence to the input text in turn. | |
| * @param {string} text The text to pre-tokenize. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} The pre-tokenized text. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| // Use reduce to apply each tokenizer to the text | |
| return this.tokenizers.reduce((preTokenizedText, tokenizer) => { | |
| return tokenizer.pre_tokenize(preTokenizedText, options); | |
| }, [text]); | |
| } | |
| } | |
| /** | |
| * Splits on word boundaries (using the following regular expression: `\w+|[^\w\s]+`). | |
| */ | |
| class WhitespacePreTokenizer extends PreTokenizer { | |
| /** | |
| * Creates an instance of WhitespacePreTokenizer. | |
| * @param {Object} config The configuration object for the pre-tokenizer. | |
| */ | |
| constructor(config) { | |
| super(); | |
| } | |
| /** | |
| * Pre-tokenizes the input text by splitting it on word boundaries. | |
| * @param {string} text The text to be pre-tokenized. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens produced by splitting the input text on whitespace. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| return text.match(/\w+|[^\w\s]+/g) || []; | |
| } | |
| } | |
| /** | |
| * Splits a string of text by whitespace characters into individual tokens. | |
| * @extends PreTokenizer | |
| */ | |
| class WhitespaceSplit extends PreTokenizer { | |
| /** | |
| * Creates an instance of WhitespaceSplit. | |
| * @param {Object} config The configuration object for the pre-tokenizer. | |
| */ | |
| constructor(config) { | |
| super(); | |
| } | |
| /** | |
| * Pre-tokenizes the input text by splitting it on whitespace characters. | |
| * @param {string} text The text to be pre-tokenized. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens produced by splitting the input text on whitespace. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| return whitespace_split(text); | |
| } | |
| } | |
| // NOTE: `ReplacePreTokenizer` is custom (to support `BlenderbotSmallTokenizer`) | |
| class ReplacePreTokenizer extends PreTokenizer { | |
| /** | |
| * @param {Object} config The configuration options for the pre-tokenizer. | |
| * @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object. | |
| * @param {string} config.content What to replace the pattern with. | |
| */ | |
| constructor(config) { | |
| super(); | |
| this.config = config; | |
| this.pattern = createPattern(this.config.pattern); | |
| this.content = this.config.content; | |
| } | |
| /** | |
| * Pre-tokenizes the input text by replacing certain characters. | |
| * @param {string} text The text to be pre-tokenized. | |
| * @param {Object} [options] Additional options for the pre-tokenization logic. | |
| * @returns {string[]} An array of tokens produced by replacing certain characters. | |
| */ | |
| pre_tokenize_text(text, options) { | |
| if (this.pattern === null) { | |
| return [text]; | |
| } | |
| return [text.replaceAll(this.pattern, this.config.content)]; | |
| } | |
| } | |
| const SPECIAL_TOKEN_ATTRIBUTES = [ | |
| 'bos_token', | |
| 'eos_token', | |
| 'unk_token', | |
| 'sep_token', | |
| 'pad_token', | |
| 'cls_token', | |
| 'mask_token', | |
| // additional_special_tokens (TODO) | |
| ] | |
| /** | |
| * | |
| * Helper function for padding the values of an object, where each value is an array. | |
| * NOTE: No additional checks are made here for validity of arguments. | |
| * @param {Record<string, any[]>} item The input object. | |
| * @param {number} length The length to pad to. | |
| * @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key. | |
| * @param {string} side Which side to pad the array. | |
| * @private | |
| */ | |
| function padHelper(item, length, value_fn, side) { | |
| for (const key of Object.keys(item)) { | |
| const diff = length - item[key].length; | |
| const value = value_fn(key); | |
| const padData = new Array(diff).fill(value); | |
| item[key] = side === 'right' | |
| ? mergeArrays(item[key], padData) | |
| : mergeArrays(padData, item[key]); | |
| } | |
| } | |
| /** | |
| * Helper function for truncating the values of an object, where each value is an array. | |
| * NOTE: No additional checks are made here for validity of arguments. | |
| * @param {Record<string, any[]>} item The input object. | |
| * @param {number} length The length to truncate to. | |
| * @private | |
| */ | |
| function truncateHelper(item, length) { | |
| // Setting .length to a lower value truncates the array in-place: | |
| // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length | |
| for (const key of Object.keys(item)) { | |
| item[key].length = length; | |
| } | |
| } | |
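| // Illustrative sketch of the two helpers above (example values only): | |
| // | |
| //   const item = { input_ids: [101, 2023, 102], attention_mask: [1, 1, 1] }; | |
| //   padHelper(item, 5, () => 0, 'right'); | |
| //   // item -> { input_ids: [101, 2023, 102, 0, 0], attention_mask: [1, 1, 1, 0, 0] } | |
| //   truncateHelper(item, 2); | |
| //   // item -> { input_ids: [101, 2023], attention_mask: [1, 1] } | |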
| /** | |
| * @typedef {Object} Message | |
| * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). | |
| * @property {string} content The content of the message. | |
| */ | |
| export class PreTrainedTokenizer extends Callable { | |
| return_token_type_ids = false; | |
| padding_side = 'right'; | |
| /** | |
| * Create a new PreTrainedTokenizer instance. | |
| * @param {Object} tokenizerJSON The JSON of the tokenizer. | |
| * @param {Object} tokenizerConfig The config of the tokenizer. | |
| */ | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(); | |
| this.config = tokenizerConfig; | |
| // Construct parts of the tokenizer from the JSON | |
| this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer); | |
| this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer); | |
| this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig); | |
| this.post_processor = PostProcessor.fromConfig(tokenizerJSON.post_processor); | |
| this.decoder = Decoder.fromConfig(tokenizerJSON.decoder); | |
| // Add added_tokens to model | |
| this.special_tokens = []; | |
| this.all_special_ids = []; | |
| /** @type {AddedToken[]} */ | |
| this.added_tokens = []; | |
| for (const addedToken of tokenizerJSON.added_tokens) { | |
| const token = new AddedToken(addedToken); | |
| this.added_tokens.push(token); | |
| this.model.tokens_to_ids.set(token.content, token.id); | |
| this.model.vocab[token.id] = token.content; | |
| if (token.special) { | |
| this.special_tokens.push(token.content); | |
| this.all_special_ids.push(token.id); | |
| } | |
| } | |
| // Update additional_special_tokens | |
| this.additional_special_tokens = tokenizerConfig.additional_special_tokens ?? []; | |
| this.special_tokens.push(...this.additional_special_tokens); | |
| this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates | |
| if (this.decoder) { | |
| // Slight hack, but it prevents code duplication: | |
| this.decoder.added_tokens = this.added_tokens; | |
| // Another slight hack to add `end_of_word_suffix` (if present) to the decoder | |
| // This is needed for cases where BPE model and ByteLevel decoder are used | |
| // For more information, see https://github.com/huggingface/transformers.js/issues/74 | |
| // TODO: save this to the decoder when exporting? | |
| this.decoder.end_of_word_suffix = this.model.end_of_word_suffix; | |
| } | |
| this.added_tokens_splitter = new DictionarySplitter( | |
| this.added_tokens.map(x => x.content), | |
| ); | |
| /** @type {Map<string, AddedToken>} */ | |
| this.added_tokens_map = new Map(this.added_tokens.map(x => [x.content, x])) | |
| // Set mask token if present (otherwise will be undefined, which is fine) | |
| this.mask_token = this.getToken('mask_token'); | |
| this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token); | |
| this.pad_token = this.getToken('pad_token', 'eos_token'); | |
| this.pad_token_id = this.model.tokens_to_ids.get(this.pad_token); | |
| this.sep_token = this.getToken('sep_token'); | |
| this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token); | |
| this.unk_token = this.getToken('unk_token'); | |
| this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token); | |
| this.bos_token = this.getToken('bos_token'); | |
| this.bos_token_id = this.model.tokens_to_ids.get(this.bos_token); | |
| this.eos_token = this.getToken('eos_token'); | |
| this.eos_token_id = this.model.tokens_to_ids.get(this.eos_token); | |
| this.model_max_length = tokenizerConfig.model_max_length; | |
| /** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */ | |
| this.remove_space = tokenizerConfig.remove_space; | |
| this.clean_up_tokenization_spaces = tokenizerConfig.clean_up_tokenization_spaces ?? true; | |
| this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false; | |
| if (tokenizerConfig.padding_side) { | |
| this.padding_side = tokenizerConfig.padding_side; | |
| } | |
| this.add_bos_token = tokenizerConfig.add_bos_token; | |
| this.add_eos_token = tokenizerConfig.add_eos_token; | |
| this.legacy = false; | |
| this.chat_template = tokenizerConfig.chat_template ?? null; | |
| if (Array.isArray(this.chat_template)) { | |
| // Chat templates are stored as lists of dicts with fixed key names; | |
| // we reconstruct them into a single dict while loading. | |
| const chat_template = Object.create(null); | |
| for (const { name, template } of this.chat_template) { | |
| if (typeof name !== 'string' || typeof template !== 'string') { | |
| throw new Error('Chat template must be a list of objects with "name" and "template" properties'); | |
| } | |
| chat_template[name] = template; | |
| } | |
| this.chat_template = chat_template; | |
| } | |
| this._compiled_template_cache = new Map(); | |
| } | |
| /** | |
| * Returns the value of the first matching key in the tokenizer config object. | |
| * @param {...string} keys One or more keys to search for in the tokenizer config object. | |
| * @returns {string|null} The value associated with the first matching key, or null if no match is found. | |
| * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken". | |
| * @private | |
| */ | |
| getToken(...keys) { | |
| for (const key of keys) { | |
| const item = this.config[key]; | |
| if (!item) continue; | |
| if (typeof item === 'object') { | |
| if (item.__type === 'AddedToken') { | |
| return item.content; | |
| } else { | |
| throw Error(`Unknown token: ${item}`); | |
| } | |
| } else { | |
| return item; | |
| } | |
| } | |
| return null; | |
| } | |
| /** | |
| * Loads a pre-trained tokenizer from the given `pretrained_model_name_or_path`. | |
| * | |
| * @param {string} pretrained_model_name_or_path The path to the pre-trained tokenizer. | |
| * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. | |
| * | |
| * @throws {Error} Throws an error if the tokenizer.json or tokenizer_config.json files are not found in the `pretrained_model_name_or_path`. | |
| * @returns {Promise<PreTrainedTokenizer>} A new instance of the `PreTrainedTokenizer` class. | |
| */ | |
| static async from_pretrained(pretrained_model_name_or_path, { | |
| progress_callback = null, | |
| config = null, | |
| cache_dir = null, | |
| local_files_only = false, | |
| revision = 'main', | |
| legacy = null, | |
| } = {}) { | |
| const info = await loadTokenizer(pretrained_model_name_or_path, { | |
| progress_callback, | |
| config, | |
| cache_dir, | |
| local_files_only, | |
| revision, | |
| legacy, | |
| }) | |
| // @ts-ignore | |
| return new this(...info); | |
| } | |
| /** | |
| * @typedef {number[]|number[][]|Tensor} BatchEncodingItem | |
| * | |
| * @typedef {Object} BatchEncoding Holds the output of the tokenizer's call function. | |
| * @property {BatchEncodingItem} input_ids List of token ids to be fed to a model. | |
| * @property {BatchEncodingItem} attention_mask List of indices specifying which tokens should be attended to by the model. | |
| * @property {BatchEncodingItem} [token_type_ids] List of token type ids to be fed to a model. | |
| */ | |
| /** | |
| * Encode/tokenize the given text(s). | |
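| * | |
| * **Example:** Batch-encode two sentences with dynamic padding (a sketch only; the exact ids depend | |
| * on the loaded checkpoint, and `tokenizer` is assumed to come from `AutoTokenizer.from_pretrained`). | |
| * ```javascript | |
| * const { input_ids, attention_mask } = tokenizer( | |
| *     ['a short input', 'a slightly longer second input'], | |
| *     { padding: true, truncation: true }, | |
| * ); | |
| * // `input_ids` and `attention_mask` are [2, N] Tensors, with the shorter row padded. | |
| * ``` | |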
| * @param {string|string[]} text The text to tokenize. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string|string[]} [options.text_pair=null] Optional second sequence to be encoded. If set, must be the same type as text. | |
| * @param {boolean|'max_length'} [options.padding=false] Whether to pad the input sequences. | |
| * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. | |
| * @param {boolean} [options.truncation=null] Whether to truncate the input sequences. | |
| * @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length. | |
| * @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays. | |
| * @param {boolean} [options.return_token_type_ids=null] Whether to return the token type ids. | |
| * @returns {BatchEncoding} Object to be passed to the model. | |
| */ | |
| _call( | |
| // Required positional arguments | |
| text, | |
| // Optional keyword arguments | |
| { | |
| text_pair = null, | |
| add_special_tokens = true, | |
| padding = false, | |
| truncation = null, | |
| max_length = null, | |
| return_tensor = true, // Different to HF | |
| return_token_type_ids = null, | |
| } = {}, | |
| ) { | |
| const isBatched = Array.isArray(text); | |
| /** @type {EncodingSingle[]} */ | |
| let encodedTokens; | |
| if (isBatched) { | |
| if (text.length === 0) { | |
| throw Error('text array must be non-empty') | |
| } | |
| if (text_pair !== null) { | |
| if (!Array.isArray(text_pair)) { | |
| throw Error('text_pair must also be an array') | |
| } else if (text.length !== text_pair.length) { | |
| throw Error('text and text_pair must have the same length') | |
| } | |
| encodedTokens = text.map( | |
| (t, i) => this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids }) | |
| ) | |
| } else { | |
| encodedTokens = text.map(x => this._encode_plus(x, { add_special_tokens, return_token_type_ids })); | |
| } | |
| } else { | |
| if (text === null || text === undefined) { | |
| throw Error('text may not be null or undefined') | |
| } | |
| if (Array.isArray(text_pair)) { | |
| throw Error('When specifying `text_pair`, since `text` is a string, `text_pair` must also be a string (i.e., not an array).') | |
| } | |
| // For single input, we just wrap in an array, and then unwrap later. | |
| encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })]; | |
| } | |
| // At this point, `encodedTokens` is batched, of shape [batch_size, tokens]. | |
| // However, the array may be jagged, so we may need to pad to max_length. | |
| if (max_length === null) { | |
| max_length = this.model_max_length; | |
| } else if (truncation === null) { | |
| if (padding === true) { | |
| console.warn( | |
| "`max_length` is ignored when `padding: true` and there is no truncation strategy. " + | |
| "To pad to max length, use `padding: 'max_length'`." | |
| ) | |
| max_length = this.model_max_length; | |
| } else if (padding === false) { | |
| console.warn("Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length."); | |
| truncation = true; | |
| } | |
| } | |
| // padding: 'max_length' doesn't require any additional calculation | |
| // but padding: true has to calculate max_length from the sequences | |
| if (padding === true) { | |
| max_length = Math.min(max(encodedTokens.map(x => x.input_ids.length))[0], max_length ?? Infinity); | |
| } | |
| // Ensure it is less than model max length | |
| max_length = Math.min(max_length, this.model_max_length ?? Infinity); | |
| if (padding || truncation) { | |
| // Perform padding and/or truncation | |
| for (let i = 0; i < encodedTokens.length; ++i) { | |
| if (encodedTokens[i].input_ids.length === max_length) { | |
| continue; | |
| } else if (encodedTokens[i].input_ids.length > max_length) { | |
| // possibly truncate | |
| if (truncation) { | |
| truncateHelper(encodedTokens[i], max_length); | |
| } | |
| } else { // t.length < max_length | |
| // possibly pad | |
| if (padding) { | |
| padHelper( | |
| encodedTokens[i], | |
| max_length, | |
| key => key === 'input_ids' ? this.pad_token_id : 0, | |
| this.padding_side | |
| ); | |
| } | |
| } | |
| } | |
| } | |
| const result = {}; | |
| if (return_tensor) { | |
| if (!(padding && truncation)) { | |
| // It is not guaranteed that all items have the same length, | |
| // so we perform an additional check. | |
| if ( | |
| encodedTokens.some(x => { | |
| for (const key of Object.keys(x)) { | |
| if (x[key].length !== encodedTokens[0][key]?.length) { | |
| return true; | |
| } | |
| } | |
| return false; | |
| }) | |
| ) { | |
| throw Error( | |
| "Unable to create tensor, you should probably activate truncation and/or padding " + | |
| "with 'padding=true' and 'truncation=true' to have batched tensors with the same length." | |
| ) | |
| } | |
| } | |
| // Now we actually convert to tensor | |
| // NOTE: In the same way as the python library, we return a batched tensor, regardless of | |
| // whether we have a single input or multiple inputs. | |
| const dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; | |
| for (const key of Object.keys(encodedTokens[0])) { | |
| result[key] = new Tensor('int64', | |
| BigInt64Array.from(encodedTokens.flatMap(x => x[key]).map(BigInt)), | |
| dims | |
| ); | |
| } | |
| } else { | |
| for (const key of Object.keys(encodedTokens[0])) { | |
| result[key] = encodedTokens.map(x => x[key]); | |
| } | |
| // If not returning a tensor, we match the input type | |
| if (!isBatched) { | |
| // Input was not batched, so we unwrap | |
| for (const key of Object.keys(result)) { | |
| result[key] = result[key][0]; | |
| } | |
| } | |
| } | |
| return /** @type {BatchEncoding} */(result); | |
| } | |
| /** | |
| * Encodes a single text using the preprocessor pipeline of the tokenizer. | |
| * | |
| * @param {string|null} text The text to encode. | |
| * @returns {string[]|null} The encoded tokens. | |
| */ | |
| _encode_text(text) { | |
| if (text === null) return null; | |
| // Actual function which does encoding, for a single text | |
| // First, we take care of special tokens. Needed to avoid issues arising from | |
| // normalization and/or pretokenization (which may not preserve special tokens) | |
| const sections = this.added_tokens_splitter.split(text); | |
| // Process left/right stripping of added tokens | |
| for (let i = 0; i < sections.length; ++i) { | |
| const addedToken = this.added_tokens_map.get(sections[i]); | |
| if (addedToken) { | |
| if (addedToken.lstrip && i > 0) { | |
| sections[i - 1] = sections[i - 1].trimEnd(); | |
| } | |
| if (addedToken.rstrip && i < sections.length - 1) { | |
| sections[i + 1] = sections[i + 1].trimStart(); | |
| } | |
| } | |
| } | |
| const tokens = sections.flatMap((x, section_index) => { | |
| if (x.length === 0) return []; | |
| if (this.added_tokens_map.has(x)) return [x]; // Return added tokens unchanged | |
| if (this.remove_space === true) { | |
| x = x.trim().split(/\s+/).join(' '); | |
| } | |
| if (this.do_lowercase_and_remove_accent) { | |
| x = lowercase_and_remove_accent(x); | |
| } | |
| if (this.normalizer !== null) { | |
| x = this.normalizer(x); | |
| } | |
| // If, after normalization, this section is empty (e.g., trimming whitespace), | |
| // we return an empty array | |
| if (x.length === 0) { | |
| return []; | |
| } | |
| const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, { | |
| section_index, | |
| }) : [x]; | |
| const tokens = this.model(sectionTokens); | |
| return tokens; | |
| }); | |
| return tokens; | |
| } | |
| /** | |
| * Encodes a single text or a pair of texts using the model's tokenizer. | |
| * | |
| * @param {string} text The text to encode. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.text_pair=null] The optional second text to encode. | |
| * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. | |
| * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids. | |
| * @returns {EncodingSingle} An object containing the encoded text. | |
| * @private | |
| */ | |
| _encode_plus(text, { | |
| text_pair = null, | |
| add_special_tokens = true, | |
| return_token_type_ids = null, | |
| } = {}) { | |
| const { tokens, token_type_ids } = this._tokenize_helper(text, { pair: text_pair, add_special_tokens }); | |
| const input_ids = this.model.convert_tokens_to_ids(tokens); | |
| const result = { | |
| input_ids, | |
| attention_mask: new Array(input_ids.length).fill(1), | |
| } | |
| if ((return_token_type_ids ?? this.return_token_type_ids) && token_type_ids) { | |
| result.token_type_ids = token_type_ids; | |
| } | |
| return result; | |
| } | |
| /** | |
| * Internal helper function to tokenize a text, and optionally a pair of texts. | |
| * @param {string} text The text to tokenize. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.pair=null] The optional second text to tokenize. | |
| * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. | |
| * @returns {{tokens: string[], token_type_ids?: number[]}} An object containing the tokens and optionally the token type IDs. | |
| */ | |
| _tokenize_helper(text, { | |
| pair = null, | |
| add_special_tokens = false, | |
| } = {}) { | |
| const tokens = this._encode_text(text); | |
| const tokens2 = this._encode_text(pair); | |
| return this.post_processor | |
| ? this.post_processor(tokens, tokens2, { add_special_tokens }) | |
| : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) }; | |
| } | |
| /** | |
| * Converts a string into a sequence of tokens. | |
| * @param {string} text The sequence to be encoded. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.pair] A second sequence to be encoded with the first. | |
| * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. | |
| * @returns {string[]} The list of tokens. | |
| */ | |
| tokenize(text, { | |
| pair = null, | |
| add_special_tokens = false, | |
| } = {}) { | |
| return this._tokenize_helper(text, { pair, add_special_tokens }).tokens; | |
| } | |
| /** | |
| * Encodes a single text or a pair of texts using the model's tokenizer. | |
| * | |
| * @param {string} text The text to encode. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.text_pair=null] The optional second text to encode. | |
| * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. | |
| * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids. | |
| * @returns {number[]} An array of token IDs representing the encoded text(s). | |
| */ | |
| encode(text, { | |
| text_pair = null, | |
| add_special_tokens = true, | |
| return_token_type_ids = null, | |
| } = {}) { | |
| return this._encode_plus(text, { | |
| text_pair, | |
| add_special_tokens, | |
| return_token_type_ids, | |
| }).input_ids; | |
| } | |
| /** | |
| * Decode a batch of tokenized sequences. | |
| * @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences. | |
| * @param {Object} decode_args (Optional) Object with decoding arguments. | |
| * @returns {string[]} List of decoded sequences. | |
| */ | |
| batch_decode(batch, decode_args = {}) { | |
| if (batch instanceof Tensor) { | |
| batch = batch.tolist(); | |
| } | |
| return batch.map(x => this.decode(x, decode_args)); | |
| } | |
| /** | |
| * Decodes a sequence of token IDs back to a string. | |
| * | |
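| * **Example:** (illustrative; the ids and the decoded text depend on the tokenizer's vocabulary) | |
| * ```javascript | |
| * const text = tokenizer.decode(output_ids, { skip_special_tokens: true }); | |
| * // where `output_ids` is e.g. a row of ids produced by a model's `generate` call. | |
| * ``` | |
| * | |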
| * @param {number[]|bigint[]|Tensor} token_ids List/Tensor of token IDs to decode. | |
| * @param {Object} [decode_args={}] | |
| * @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string. | |
| * @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed. | |
| * | |
| * @returns {string} The decoded string. | |
| * @throws {Error} If `token_ids` is not a non-empty array of integers. | |
| */ | |
| decode( | |
| token_ids, | |
| decode_args = {}, | |
| ) { | |
| if (token_ids instanceof Tensor) { | |
| token_ids = prepareTensorForDecode(token_ids); | |
| } | |
| if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) { | |
| throw Error("token_ids must be a non-empty array of integers."); | |
| } | |
| return this.decode_single(token_ids, decode_args) | |
| } | |
| /** | |
| * Decode a single list of token ids to a string. | |
| * @param {number[]|bigint[]} token_ids List of token ids to decode | |
| * @param {Object} decode_args Optional arguments for decoding | |
| * @param {boolean} [decode_args.skip_special_tokens=false] Whether to skip special tokens during decoding | |
| * @param {boolean} [decode_args.clean_up_tokenization_spaces=null] Whether to clean up tokenization spaces during decoding. | |
| * If null, the value is set to `this.decoder.cleanup` if it exists, falling back to `this.clean_up_tokenization_spaces` if it exists, falling back to `true`. | |
| * @returns {string} The decoded string | |
| */ | |
| decode_single( | |
| token_ids, | |
| { | |
| skip_special_tokens = false, | |
| clean_up_tokenization_spaces = null, | |
| } | |
| ) { | |
| let tokens = this.model.convert_ids_to_tokens(token_ids); | |
| if (skip_special_tokens) { | |
| tokens = tokens.filter(x => !this.special_tokens.includes(x)); | |
| } | |
| // If `this.decoder` is null, we just join tokens with a space: | |
| // https://github.com/huggingface/tokenizers/blob/8edec536a737cb04494b454805be16c020abb14f/tokenizers/src/tokenizer/mod.rs#L835 | |
| /** @type {string} */ | |
| let decoded = this.decoder ? this.decoder(tokens) : tokens.join(' '); | |
| // Slight hack, but prevents having to pass `skip_special_tokens` to | |
| // each call to `decode`, which would lead to code duplication. | |
| if (this.decoder && this.decoder.end_of_word_suffix) { | |
| decoded = decoded.replaceAll(this.decoder.end_of_word_suffix, ' '); | |
| if (skip_special_tokens) { | |
| decoded = decoded.trim(); | |
| } | |
| } | |
| if (clean_up_tokenization_spaces ?? this.clean_up_tokenization_spaces) { | |
| decoded = clean_up_tokenization(decoded); | |
| } | |
| return decoded; | |
| } | |
| /** | |
| * Retrieve the chat template string used for tokenizing chat messages. This template is used | |
| * internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat | |
| * template for better generation tracking. | |
| * | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.chat_template=null] | |
| * A Jinja template or the name of a template to use for this conversion. | |
| * It is usually not necessary to pass anything to this argument, | |
| * as the model's template will be used by default. | |
| * @param {Object[]} [options.tools=null] | |
| * A list of tools (callable functions) that will be accessible to the model. If the template does not | |
| * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, | |
| * giving the name, description and argument types for the tool. See our | |
| * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) | |
| * for more information. | |
| * @returns {string} The chat template string. | |
| */ | |
| get_chat_template({ | |
| chat_template = null, | |
| tools = null, | |
| } = {}) { | |
| // First, handle the cases when the model has a dict of multiple templates | |
| if (this.chat_template && typeof this.chat_template === 'object') { | |
| const template_dict = this.chat_template; | |
| if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) { | |
| // The user can pass the name of a template to the chat template argument instead of an entire template | |
| chat_template = template_dict[chat_template]; | |
| } else if (chat_template === null) { | |
| if (tools !== null && 'tool_use' in template_dict) { | |
| chat_template = template_dict['tool_use']; | |
| } else if ('default' in template_dict) { | |
| chat_template = template_dict['default']; | |
| } else { | |
| throw Error( | |
| `This model has multiple chat templates with no default specified! Please either pass a chat ` + | |
| `template or the name of the template you wish to use to the 'chat_template' argument. Available ` + | |
| `template names are ${Object.keys(template_dict).sort()}.` | |
| ) | |
| } | |
| } | |
| } else if (chat_template === null) { | |
| // These are the cases when the model has a single template | |
| // priority: `chat_template` argument > `tokenizer.chat_template` | |
| if (this.chat_template) { | |
| chat_template = this.chat_template; | |
| } else { | |
| throw Error( | |
| "Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template " + | |
| "argument was passed! For information about writing templates and setting the " + | |
| "tokenizer.chat_template attribute, please see the documentation at " + | |
| "https://huggingface.co/docs/transformers/main/en/chat_templating" | |
| ) | |
| } | |
| } | |
| return chat_template; | |
| } | |
| /** | |
| * Converts a list of message objects with `"role"` and `"content"` keys to a list of token | |
| * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to | |
| * determine the format and control tokens to use when converting. | |
| * | |
| * See [here](https://huggingface.co/docs/transformers/chat_templating) for more information. | |
| * | |
| * **Example:** Applying a chat template to a conversation. | |
| * | |
| * ```javascript | |
| * import { AutoTokenizer } from "@huggingface/transformers"; | |
| * | |
| * const tokenizer = await AutoTokenizer.from_pretrained("Xenova/mistral-tokenizer-v1"); | |
| * | |
| * const chat = [ | |
| * { "role": "user", "content": "Hello, how are you?" }, | |
| * { "role": "assistant", "content": "I'm doing great. How can I help you today?" }, | |
| * { "role": "user", "content": "I'd like to show off how chat templating works!" }, | |
| * ] | |
| * | |
| * const text = tokenizer.apply_chat_template(chat, { tokenize: false }); | |
| * // "<s>[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?</s> [INST] I'd like to show off how chat templating works! [/INST]" | |
| * | |
| * const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); | |
| * // [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793] | |
| * ``` | |
| * | |
| * @param {Message[]} conversation A list of message objects with `"role"` and `"content"` keys, | |
| * representing the chat history so far. | |
| * @param {Object} options An optional object containing the following properties: | |
| * @param {string} [options.chat_template=null] A Jinja template to use for this conversion. If | |
| * this is not passed, the model's chat template will be used instead. | |
| * @param {Object[]} [options.tools=null] | |
| * A list of tools (callable functions) that will be accessible to the model. If the template does not | |
| * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, | |
| * giving the name, description and argument types for the tool. See our | |
| * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) | |
| * for more information. | |
| * @param {Record<string, string>[]} [options.documents=null] | |
| * A list of dicts representing documents that will be accessible to the model if it is performing RAG | |
| * (retrieval-augmented generation). If the template does not support RAG, this argument will have no | |
| * effect. We recommend that each document should be a dict containing "title" and "text" keys. Please | |
| * see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) | |
| * for examples of passing documents with chat templates. | |
| * @param {boolean} [options.add_generation_prompt=false] Whether to end the prompt with the token(s) that indicate | |
| * the start of an assistant message. This is useful when you want to generate a response from the model. | |
| * Note that this argument will be passed to the chat template, and so it must be supported in the | |
| * template for this argument to have any effect. | |
| * @param {boolean} [options.tokenize=true] Whether to tokenize the output. If false, the output will be a string. | |
| * @param {boolean} [options.padding=false] Whether to pad sequences to the maximum length. Has no effect if tokenize is false. | |
| * @param {boolean} [options.truncation=false] Whether to truncate sequences to the maximum length. Has no effect if tokenize is false. | |
| * @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false. | |
| * If not specified, the tokenizer's `max_length` attribute will be used as a default. | |
| * @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false. | |
| * @param {boolean} [options.return_dict=false] Whether to return a dictionary with named outputs. Has no effect if tokenize is false. | |
| * @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer. | |
| * @returns {string|Tensor|number[]|number[][]|BatchEncoding} The tokenized output. | |
| */ | |
| apply_chat_template(conversation, { | |
| tools = null, | |
| documents = null, | |
| chat_template = null, | |
| add_generation_prompt = false, | |
| tokenize = true, | |
| padding = false, | |
| truncation = false, | |
| max_length = null, | |
| return_tensor = true, | |
| return_dict = false, | |
| tokenizer_kwargs = {}, | |
| ...kwargs | |
| } = {}) { | |
| chat_template = this.get_chat_template({ chat_template, tools }); | |
| if (typeof chat_template !== 'string') { | |
| throw Error(`chat_template must be a string, but got ${typeof chat_template}`); | |
| } | |
| // Compilation function uses a cache to avoid recompiling the same template | |
| let compiledTemplate = this._compiled_template_cache.get(chat_template); | |
| if (compiledTemplate === undefined) { | |
| compiledTemplate = new Template(chat_template); | |
| this._compiled_template_cache.set(chat_template, compiledTemplate); | |
| } | |
| const special_tokens_map = Object.create(null); | |
| for (const key of SPECIAL_TOKEN_ATTRIBUTES) { | |
| const value = this.getToken(key); | |
| if (value) { | |
| special_tokens_map[key] = value; | |
| } | |
| } | |
| const rendered = compiledTemplate.render({ | |
| messages: conversation, | |
| add_generation_prompt, | |
| tools, | |
| documents, | |
| ...special_tokens_map, | |
| ...kwargs, | |
| }); | |
| if (tokenize) { | |
| const out = this._call(rendered, { | |
| add_special_tokens: false, | |
| padding, | |
| truncation, | |
| max_length, | |
| return_tensor, | |
| ...tokenizer_kwargs, | |
| }); | |
| return return_dict ? out : out.input_ids; | |
| } | |
| return rendered; | |
| } | |
| } | |
| /** | |
| * BertTokenizer is a class used to tokenize text for BERT models. | |
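| * | |
| * **Example:** Tokenizing a sentence pair (a minimal sketch; the checkpoint name is only an assumption). | |
| * Because `return_token_type_ids` is enabled, the output also contains `token_type_ids`, which | |
| * distinguish the first sequence (0) from the second (1). | |
| * ```javascript | |
| * import { AutoTokenizer } from '@huggingface/transformers'; | |
| * | |
| * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); | |
| * const { input_ids, token_type_ids } = await tokenizer('What is the capital of France?', { | |
| *   text_pair: 'Paris is the capital of France.', | |
| * }); | |
| * ``` | |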
| * @extends PreTrainedTokenizer | |
| */ | |
| export class BertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| /** | |
| * Albert tokenizer | |
| * @extends PreTrainedTokenizer | |
| */ | |
| export class AlbertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class MobileBertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class SqueezeBertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class DebertaTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class DebertaV2Tokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class HerbertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class ConvBertTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class RoFormerTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class DistilBertTokenizer extends PreTrainedTokenizer { } | |
| export class CamembertTokenizer extends PreTrainedTokenizer { } | |
| export class XLMTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| console.warn('WARNING: `XLMTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. Therefore, you may experience slightly inaccurate results.') | |
| } | |
| } | |
| export class ElectraTokenizer extends PreTrainedTokenizer { | |
| return_token_type_ids = true; | |
| } | |
| export class T5Tokenizer extends PreTrainedTokenizer { } | |
| export class GPT2Tokenizer extends PreTrainedTokenizer { } | |
| export class BartTokenizer extends PreTrainedTokenizer { } | |
| export class MBartTokenizer extends PreTrainedTokenizer { | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| this.languageRegex = /^[a-z]{2}_[A-Z]{2}$/; | |
| this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x)); | |
| this.lang_to_token = x => x; // Identity function | |
| } | |
| /** | |
| * Helper function to build translation inputs for an `MBartTokenizer`. | |
| * @param {string|string[]} raw_inputs The text to tokenize. | |
| * @param {Object} tokenizer_options Options to be sent to the tokenizer | |
| * @param {Object} generate_kwargs Generation options. | |
| * @returns {Object} Object to be passed to the model. | |
| */ | |
| _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { | |
| return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); | |
| } | |
| } | |
| export class MBart50Tokenizer extends MBartTokenizer { } // NOTE: extends MBartTokenizer | |
| export class RobertaTokenizer extends PreTrainedTokenizer { } | |
| export class BloomTokenizer extends PreTrainedTokenizer { } | |
| const SPIECE_UNDERLINE = "▁"; | |
| export class LlamaTokenizer extends PreTrainedTokenizer { | |
| padding_side = 'left'; | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| this.legacy = tokenizerConfig.legacy ?? true; | |
| if (!this.legacy) { | |
| // See https://github.com/huggingface/transformers/pull/24565 for more information | |
| this.normalizer = null; | |
| this.pre_tokenizer = new MetaspacePreTokenizer({ | |
| replacement: SPIECE_UNDERLINE, | |
| add_prefix_space: true, | |
| prepend_scheme: "first", | |
| }); | |
| } | |
| } | |
| /** | |
| * Helper function to handle legacy encoding of SPM tokenizers. | |
| * Adapted from https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/models/t5/tokenization_t5.py#L374-L387 | |
| * @param {string} text The text to encode. | |
| * @returns {string[]} The encoded tokens. | |
| */ | |
| _encode_text(text) { | |
| if (text === null) return null; | |
| if (this.legacy || text.length === 0) { | |
| return super._encode_text(text); | |
| } | |
| let tokens = super._encode_text(SPIECE_UNDERLINE + text.replaceAll(SPIECE_UNDERLINE, " ")); | |
| if (tokens.length > 1 && tokens[0] === SPIECE_UNDERLINE && this.special_tokens.includes(tokens[1])) { | |
| tokens = tokens.slice(1); | |
| } | |
| return tokens; | |
| } | |
| } | |
| export class CodeLlamaTokenizer extends PreTrainedTokenizer { } | |
| export class XLMRobertaTokenizer extends PreTrainedTokenizer { } | |
| export class MPNetTokenizer extends PreTrainedTokenizer { } | |
| export class FalconTokenizer extends PreTrainedTokenizer { } | |
| export class GPTNeoXTokenizer extends PreTrainedTokenizer { } | |
| export class EsmTokenizer extends PreTrainedTokenizer { } | |
| export class Qwen2Tokenizer extends PreTrainedTokenizer { } | |
| export class GemmaTokenizer extends PreTrainedTokenizer { } | |
| export class Grok1Tokenizer extends PreTrainedTokenizer { } | |
| /** | |
| * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`. | |
| * @param {PreTrainedTokenizer} self The tokenizer instance. | |
| * @param {string|string[]} raw_inputs The text to tokenize. | |
| * @param {Object} tokenizer_options Options to be sent to the tokenizer | |
| * @param {Object} generate_kwargs Generation options. | |
| * @returns {Object} Object to be passed to the model. | |
| * @private | |
| */ | |
| function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) { | |
| if (!('language_codes' in self) || !Array.isArray(self.language_codes)) { | |
| throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.') | |
| } | |
| if (!('languageRegex' in self) || !(self.languageRegex instanceof RegExp)) { | |
| throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.') | |
| } | |
| if (!('lang_to_token' in self) || typeof self.lang_to_token !== 'function') { | |
| throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.') | |
| } | |
| const src_lang_token = generate_kwargs.src_lang; | |
| const tgt_lang_token = generate_kwargs.tgt_lang; | |
| // Check that the target language is valid: | |
| if (!self.language_codes.includes(tgt_lang_token)) { | |
| throw new Error(`Target language code "${tgt_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`); | |
| } | |
| // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default. | |
| if (src_lang_token !== undefined) { | |
| // Check that the source language is valid: | |
| if (!self.language_codes.includes(src_lang_token)) { | |
| throw new Error(`Source language code "${src_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`); | |
| } | |
| // In the same way as the Python library, we override the post-processor | |
| // to force the source language to be first: | |
| for (const item of self.post_processor.config.single) { | |
| if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) { | |
| item.SpecialToken.id = self.lang_to_token(src_lang_token); | |
| break; | |
| } | |
| } | |
| // TODO: Do the same for pair? | |
| } | |
| // Override the `forced_bos_token_id` to force the correct language | |
| generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang_token)])[0]; | |
| return self._call(raw_inputs, tokenizer_options); | |
| } | |
| /** | |
| * The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models. | |
| * | |
| * No Language Left Behind (NLLB) is a first-of-its-kind, AI breakthrough project | |
| * that open-sources models capable of delivering high-quality translations directly | |
| * between any pair of 200+ languages — including low-resource languages like Asturian, | |
| * Luganda, Urdu and more. It aims to help people communicate with anyone, anywhere, | |
| * regardless of their language preferences. For more information, check out their | |
| * [paper](https://huggingface.co/papers/2207.04672). | |
| * | |
| * For a list of supported languages (along with their language codes), | |
| * @see {@link https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200} | |
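| * | |
| * **Example:** Preparing translation inputs (a minimal sketch; the checkpoint name and language codes | |
| * are assumptions, and `_build_translation_inputs` is normally called for you by the translation pipeline). | |
| * ```javascript | |
| * import { AutoTokenizer } from '@huggingface/transformers'; | |
| * | |
| * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/nllb-200-distilled-600M'); | |
| * const generate_kwargs = { src_lang: 'eng_Latn', tgt_lang: 'fra_Latn' }; | |
| * const inputs = tokenizer._build_translation_inputs('Hello world', { return_tensor: false }, generate_kwargs); | |
| * // `generate_kwargs.forced_bos_token_id` now holds the token id of 'fra_Latn', which generation | |
| * // uses to force the target language. | |
| * ``` | |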
| */ | |
| export class NllbTokenizer extends PreTrainedTokenizer { | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/; | |
| this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x)); | |
| this.lang_to_token = x => x; // Identity function | |
| } | |
| /** | |
| * Helper function to build translation inputs for an `NllbTokenizer`. | |
| * @param {string|string[]} raw_inputs The text to tokenize. | |
| * @param {Object} tokenizer_options Options to be sent to the tokenizer | |
| * @param {Object} generate_kwargs Generation options. | |
| * @returns {Object} Object to be passed to the model. | |
| */ | |
| _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { | |
| return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); | |
| } | |
| } | |
| /** | |
| * The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models. | |
| * | |
| * M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many | |
| * multilingual translation. It was introduced in this [paper](https://huggingface.co/papers/2010.11125) | |
| * and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository. | |
| * | |
| * For a list of supported languages (along with their language codes), | |
| * @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered} | |
| */ | |
| export class M2M100Tokenizer extends PreTrainedTokenizer { | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| this.languageRegex = /^__[a-z]{2,3}__$/; | |
| this.language_codes = this.special_tokens | |
| .filter(x => this.languageRegex.test(x)) | |
| .map(x => x.slice(2, -2)); | |
| this.lang_to_token = x => `__${x}__`; | |
| } | |
| /** | |
| * Helper function to build translation inputs for an `M2M100Tokenizer`. | |
| * @param {string|string[]} raw_inputs The text to tokenize. | |
| * @param {Object} tokenizer_options Options to be sent to the tokenizer | |
| * @param {Object} generate_kwargs Generation options. | |
| * @returns {Object} Object to be passed to the model. | |
| */ | |
| _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { | |
| return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); | |
| } | |
| } | |
| /** | |
| * WhisperTokenizer is a class used to tokenize text for Whisper (automatic speech recognition) models. | |
| * @extends PreTrainedTokenizer | |
| */ | |
| export class WhisperTokenizer extends PreTrainedTokenizer { | |
| get timestamp_begin() { | |
| return this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1; | |
| } | |
| /** | |
| * Decodes automatic speech recognition (ASR) sequences. | |
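| * | |
| * **Example:** (a minimal sketch; `chunk1_ids`/`chunk2_ids` are placeholder token id arrays and the | |
| * stride values are illustrative.) | |
| * ```javascript | |
| * const [text, { chunks }] = tokenizer._decode_asr([ | |
| *   { tokens: chunk1_ids, stride: [30, 0, 5] }, // stride = [chunk_len, stride_left, stride_right] | |
| *   { tokens: chunk2_ids, stride: [30, 5, 0] }, | |
| * ], { return_timestamps: true, time_precision: 0.02 }); | |
| * // `text` is the merged transcript; `chunks` contains { timestamp: [start, end], text } segments. | |
| * ``` | |
| * | |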
| * @param {Array<{tokens: bigint[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode. | |
| * @param {Object} options The options to use for decoding. | |
| * @returns {Array<string|{chunks?: undefined|Array<{language: string|null, timestamp: Array<number|null>, text: string}>}>} The decoded sequences. | |
| */ | |
| _decode_asr(sequences, { | |
| return_timestamps = false, | |
| return_language = false, | |
| time_precision = null, | |
| force_full_sequences = true | |
| } = {}) { | |
| // Set force_full_sequences=false if you want streaming | |
| // TODO add support for `return_language` | |
| // Internal method meant to be used only by the ASR pipeline. | |
| // Handles all the little quirks specific to Whisper, covering | |
| // the various options not allowed in other seq2seq models. | |
| // =========== Overview ============ | |
| // - iterate over all outputs | |
| // - all tokens within output | |
| // - Each token can be | |
| // - language token | |
| // - special token | |
| // - timestamp token | |
| // - text token | |
| // - We accumulate the text tokens. | |
| // - We split on end timestamps | |
| // - Lots of complexity comes from stride and timestamps | |
| if (time_precision === null) { | |
| throw Error("Must specify time_precision") | |
| } | |
| let last_language = null; | |
| const returnWordTimestamps = return_timestamps === "word"; | |
| function new_chunk() { | |
| return { "language": last_language, "timestamp": [null, null], "text": "" }; | |
| } | |
| // Welcome to the state machine! | |
| const chunks = []; | |
| let chunk = new_chunk(); | |
| let time_offset = 0.0; | |
| const timestamp_begin = this.timestamp_begin; | |
| // Whisper timestamp tokens start from 0.00 and go up to 30.00 in 0.02 increments. | |
| // We can calculate the last timestamp token as timestamp_begin plus the number of tokens | |
| // from 0.00 to 30.00, which is 1500. | |
| const total_timestamp_tokens = 1500; // (30.00 - 0.00) / 0.02 | |
| const timestamp_end = timestamp_begin + total_timestamp_tokens; | |
| let previous_tokens = []; | |
| let previous_token_timestamps = []; | |
| let skip = false; | |
| let right_stride_start = null; | |
| const all_special_ids = new Set(this.all_special_ids); | |
| for (const output of sequences) { | |
| // NOTE: python version has batches, so it uses [0] | |
| const token_ids = output.tokens; | |
| const token_timestamps = returnWordTimestamps ? output.token_timestamps : null; | |
| // These keep track of timestamps within strides, which need | |
| // to be skipped so that all tokens are resolved in a single chunk. | |
| let last_timestamp = null; | |
| let first_timestamp = timestamp_begin; | |
| if ("stride" in output) { | |
| const [chunk_len, stride_left, stride_right] = output.stride; | |
| // Offset the timings to account for the other `model_outputs`. | |
| time_offset -= stride_left; | |
| right_stride_start = chunk_len - stride_right; | |
| // Keeping track of timestamps within strides | |
| // We're going to NOT split on those, and delay splitting until we're | |
| // out of BOTH strides. Otherwise, lots of issues and corner cases occur. | |
| if (stride_left) { | |
| first_timestamp = stride_left / time_precision + timestamp_begin; | |
| } | |
| if (stride_right) { | |
| for (let i = token_ids.length - 1; i >= 0; --i) { | |
| const token = Number(token_ids[i]); | |
| if (token >= timestamp_begin) { | |
| // There can be several tokens in the right stride | |
| // But the last one is ALWAYS going to be skipped | |
| if (last_timestamp !== null && (token - timestamp_begin) * time_precision < right_stride_start) { | |
| break; | |
| } | |
| last_timestamp = token; | |
| } | |
| } | |
| } | |
| } | |
| let current_tokens = []; | |
| let current_token_timestamps = []; | |
| // - all tokens within output | |
| for (let i = 0; i < token_ids.length; ++i) { | |
| const token = Number(token_ids[i]); | |
| // 4 possible states for each token | |
| // - 1/ Language code | |
| // - 2/ all other special tokens (which we ignore) | |
| // - 3/ Timestamp | |
| // - 4/ Regular text | |
| if (all_special_ids.has(token)) { | |
| const text = this.decode([token]); | |
| const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(2, -2)); | |
| if (language !== undefined) { | |
| // 1/ Indeed some language | |
| // TODO Handle when language is different from the previous | |
| // one, and we cannot use timestamped tokens to create chunks | |
| if (last_language !== null && language !== last_language && !return_timestamps) { | |
| previous_tokens.push(current_tokens); | |
| const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0]; | |
| const resolved_text = this.decode(resolved_tokens); | |
| chunk.text = resolved_text; | |
| chunks.push(chunk); | |
| // Flush all our temporary context | |
| previous_tokens = []; | |
| current_tokens = []; | |
| chunk = new_chunk(); | |
| } | |
| last_language = chunk.language = language; | |
| } else { | |
| // 2/ This is a regular special token, ignoring it | |
| } | |
| } else if (token >= timestamp_begin && token <= timestamp_end) { | |
| // 3/ Timestamp token | |
| const time = (token - timestamp_begin) * time_precision + time_offset; | |
| const rounded_time = round(time, 2); | |
| if (last_timestamp !== null && token >= last_timestamp) { | |
| // Whisper outputted a timestamp token, but it falls within | |
| // our stride, so we're going to skip it for the time being | |
| // and resolve this later | |
| // Skip is necessary because timestamp tokens always come | |
| // in pairs, so we need to skip the next one too (which would mark the start of another chunk). | |
| skip = true; | |
| } else if (skip || (previous_tokens.length > 0 && token < first_timestamp)) { | |
| skip = false; | |
| } else if (chunk.timestamp[0] === null) { | |
| chunk.timestamp[0] = rounded_time; | |
| } else { | |
| // This is the end of the timestamp chunk | |
| if (rounded_time === chunk.timestamp[0]) { | |
| // This is a bug in timestamp token output | |
| // where we're taking the duplicate token | |
| // as a stop where it should be a start. | |
| // This is an issue in the underlying model output | |
| // Let's just skip it so it de facto becomes a start again. | |
| } else { | |
| chunk.timestamp[1] = rounded_time; | |
| // Handling merges | |
| previous_tokens.push(current_tokens) | |
| if (returnWordTimestamps) { | |
| previous_token_timestamps.push(current_token_timestamps); | |
| } | |
| const [resolved_tokens, resolved_token_timestamps] = this.findLongestCommonSequence( | |
| previous_tokens, previous_token_timestamps | |
| ) | |
| const resolved_text = this.decode(resolved_tokens) | |
| chunk.text = resolved_text | |
| if (returnWordTimestamps) { | |
| chunk.words = this.collateWordTimestamps( | |
| resolved_tokens, resolved_token_timestamps, last_language, | |
| ) | |
| } | |
| chunks.push(chunk) | |
| // Flush all our temporary context | |
| previous_tokens = [] | |
| current_tokens = [] | |
| previous_token_timestamps = [] | |
| current_token_timestamps = [] | |
| chunk = new_chunk() | |
| } | |
| } | |
| } else { | |
| // 4/ Regular token | |
| // We just append to the list of all tokens so we can handle | |
| // merges later and decode into text. | |
| current_tokens.push(token) | |
| if (returnWordTimestamps) { | |
| let start_time = round(token_timestamps[i] + time_offset, 2); | |
| let end_time; | |
| if (i + 1 < token_timestamps.length) { | |
| end_time = round(token_timestamps[i + 1] + time_offset, 2); | |
| // Do not allow punctuation-only tokens to have a duration. | |
| // This prevents long pauses from messing up the timestamps. | |
| const decoded_text = this.decode([token]); | |
| if (PUNCTUATION_ONLY_REGEX.test(decoded_text)) { | |
| // Add `time_precision` to avoid overlapping timestamps | |
| end_time = round(Math.min(start_time + time_precision, end_time), 2); | |
| } | |
| } else { | |
| // should never happen | |
| end_time = null; | |
| } | |
| current_token_timestamps.push([start_time, end_time]); | |
| } | |
| } | |
| } | |
| if ('stride' in output) { | |
| const [chunk_len, stride_left, stride_right] = output.stride; | |
| time_offset += chunk_len - stride_right | |
| } | |
| // Leftover tokens | |
| if (current_tokens.length > 0) { | |
| previous_tokens.push(current_tokens) | |
| if (returnWordTimestamps) { | |
| previous_token_timestamps.push(current_token_timestamps); | |
| } | |
| } else if (previous_tokens.every(p => p.length === 0)) { | |
| // Flushing previous tokens (END) | |
| chunk = new_chunk() | |
| previous_tokens = [] | |
| current_tokens = [] | |
| previous_token_timestamps = []; | |
| current_token_timestamps = []; | |
| } | |
| } | |
| if (previous_tokens.length > 0) { | |
| if (force_full_sequences && return_timestamps) { | |
| // Last token should always be timestamps, so there shouldn't be | |
| // leftover | |
| throw new Error( | |
| "Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. " + | |
| "Also make sure WhisperTimeStampLogitsProcessor was used during generation." | |
| ); | |
| } | |
| // Happens when we don't use timestamps | |
| const [resolved_tokens, resolved_token_timestamps] = this.findLongestCommonSequence(previous_tokens, previous_token_timestamps); | |
| // Flushing previous tokens (FINAL) | |
| const resolved_text = this.decode(resolved_tokens); | |
| chunk.text = resolved_text; | |
| if (returnWordTimestamps) { | |
| chunk.words = this.collateWordTimestamps( | |
| resolved_tokens, resolved_token_timestamps, last_language, | |
| ) | |
| } | |
| chunks.push(chunk); | |
| } | |
| let optional = Object.create(null); | |
| // Preparing and cleaning up the pipeline output | |
| const full_text = chunks.map(chunk => chunk.text).join(''); | |
| if (return_timestamps || return_language) { | |
| for (let i = 0; i < chunks.length; ++i) { | |
| const chunk = chunks[i]; | |
| if (!return_timestamps) { | |
| delete chunk["timestamp"]; | |
| } | |
| if (!return_language) { | |
| delete chunk["language"]; | |
| } | |
| } | |
| if (returnWordTimestamps) { | |
| const new_chunks = []; | |
| for (const chunk of chunks) { | |
| for (const word of chunk.words) { | |
| new_chunks.push(word); | |
| } | |
| } | |
| optional = { "chunks": new_chunks }; | |
| } else { | |
| optional = { "chunks": chunks }; | |
| } | |
| } | |
| return [full_text, optional]; | |
| } | |
| /** | |
| * Finds the longest common sequence among the provided sequences. | |
| * @param {number[][]} sequences An array of sequences of token ids to compare. | |
| * @param {number[][]} [token_timestamp_sequences=null] An optional array of token timestamp sequences, which will be split in exactly the same way as `sequences`. | |
| * @returns {[number[], number[]]} The merged token id sequence, along with the merged token timestamps (empty if no timestamp sequences were provided). | |
| * @throws {Error} If there is a bug within the function. | |
| * @private | |
| */ | |
| findLongestCommonSequence(sequences, token_timestamp_sequences = null) { | |
| // It would be much harder to do O(n) because of fault tolerance. | |
| // We actually have a really good property which is that the total sequence | |
| // MUST be those subsequences in order. | |
| // If token_timestamp_sequences is provided, will split those sequences in | |
| // exactly the same way. | |
| let leftSequence = sequences[0]; | |
| let leftLength = leftSequence.length; | |
| let totalSequence = []; | |
| const use_token_timestamp_sequences = Array.isArray(token_timestamp_sequences) && token_timestamp_sequences.length > 0; | |
| let total_token_timestamp_sequence = use_token_timestamp_sequences ? [] : null; | |
| let left_token_timestamp_sequence = use_token_timestamp_sequences ? token_timestamp_sequences[0] : null; | |
| for (let i = 1; i < sequences.length; ++i) { | |
| const rightSequence = sequences[i]; | |
| let max = 0.0; | |
| let maxIndices = [leftLength, leftLength, 0, 0]; | |
| // Here we're sliding matches | |
| // [a, b, c, d] | |
| // [c, d, f] | |
| // = [c] == [d] | |
| // [a, b, c, d] | |
| // [c, d, f] | |
| // = [c, d] == [c, d] | |
| // [a, b, c, d] | |
| // [c, d, f] | |
| // = [b, c, d] == [c, d, f] | |
| // [a, b, c, d] | |
| // [c, d, f] | |
| // [a, b, c] == [c, d, f] | |
| // [a, b, c, d] | |
| // [d, f] | |
| // [a, b] == [d, f] | |
| // [a, b, c, d] | |
| // [f] | |
| // [a] == [f] | |
| const rightLength = rightSequence.length; | |
| for (let j = 1; j < leftLength + rightLength; ++j) { | |
| // Slightly convoluted because we don't want out of bound indices | |
| // This will be necessary for a small conflict resolution optimization | |
| // later | |
| const leftStart = Math.max(0, leftLength - j); | |
| const leftStop = Math.min(leftLength, leftLength + rightLength - j); | |
| const left = leftSequence.slice(leftStart, leftStop); | |
| const rightStart = Math.max(0, j - leftLength); | |
| const rightStop = Math.min(rightLength, j); | |
| const right = rightSequence.slice(rightStart, rightStop); | |
| if (left.length !== right.length) { | |
| throw new Error("There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."); | |
| } | |
| let matches; | |
| if (use_token_timestamp_sequences) { | |
| // Get length of longest subsequence of tokens that match | |
| // and have timestamps that are in order | |
| matches = left.filter((elem, idx) => ( | |
| elem === right[idx] | |
| && left_token_timestamp_sequence[leftStart + idx] <= token_timestamp_sequences[i][rightStart + idx] | |
| )).length; | |
| } else { | |
| matches = left.filter((elem, idx) => elem === right[idx]).length; | |
| } | |
| // epsilon to favor long perfect matches | |
| const eps = j / 10000.0; | |
| const matching = matches / j + eps; | |
| if (matches > 1 && matching > max) { | |
| max = matching; | |
| maxIndices = [leftStart, leftStop, rightStart, rightStop]; | |
| } | |
| } | |
| const [leftStart, leftStop, rightStart, rightStop] = maxIndices; | |
| const leftMid = Math.floor((leftStop + leftStart) / 2); | |
| const rightMid = Math.floor((rightStop + rightStart) / 2); | |
| totalSequence.push(...leftSequence.slice(0, leftMid)); | |
| leftSequence = rightSequence.slice(rightMid); | |
| leftLength = leftSequence.length; | |
| if (use_token_timestamp_sequences) { | |
| total_token_timestamp_sequence.push(...left_token_timestamp_sequence.slice(0, leftMid)); | |
| left_token_timestamp_sequence = token_timestamp_sequences[i].slice(rightMid); | |
| } | |
| } | |
| totalSequence.push(...leftSequence); | |
| if (use_token_timestamp_sequences) { | |
| total_token_timestamp_sequence.push(...left_token_timestamp_sequence); | |
| return [totalSequence, total_token_timestamp_sequence]; | |
| } else { | |
| return [totalSequence, []]; | |
| } | |
| } | |
| /** @private */ | |
| collateWordTimestamps(tokens, token_timestamps, language) { | |
| const [words, _, token_indices] = this.combineTokensIntoWords(tokens, language); | |
| const timings = []; | |
| for (let i = 0; i < words.length; ++i) { | |
| const indices = token_indices[i]; | |
| timings.push({ | |
| text: words[i], | |
| timestamp: [ | |
| token_timestamps[indices.at(0)][0], | |
| token_timestamps[indices.at(-1)][1], | |
| ], | |
| }); | |
| } | |
| return timings; | |
| } | |
| /** | |
| * Groups tokens by word. Returns a tuple containing a list of strings with the words, | |
| * and a list of `token_id` sequences with the tokens making up each word. | |
| * @param {number[]} tokens | |
| * @param {string} [language] | |
| * @param {string} prepend_punctuations | |
| * @param {string} append_punctuations | |
| * | |
| * @private | |
| */ | |
| combineTokensIntoWords(tokens, language, prepend_punctuations = "\"'“¡¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、") { | |
| language = language ?? 'english'; | |
| let words, word_tokens, token_indices; | |
| if (["chinese", "japanese", "thai", "lao", "myanmar"].includes(language)) { | |
| // These languages don't typically use spaces. | |
| [words, word_tokens, token_indices] = this.splitTokensOnUnicode(tokens) | |
| } else { | |
| [words, word_tokens, token_indices] = this.splitTokensOnSpaces(tokens) | |
| } | |
| return this.mergePunctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations); | |
| } | |
| /** @type {PreTrainedTokenizer['decode']} */ | |
| decode( | |
| token_ids, | |
| decode_args, | |
| ) { | |
| let text; | |
| // @ts-ignore | |
| if (decode_args?.decode_with_timestamps) { | |
| if (token_ids instanceof Tensor) { | |
| token_ids = prepareTensorForDecode(token_ids); | |
| } | |
| text = this.decodeWithTimestamps(token_ids, decode_args); | |
| } else { | |
| text = super.decode(token_ids, decode_args); | |
| } | |
| // TODO: implement offsets | |
| // if (decode_args.output_offsets) { | |
| // let offsets = this.computeOffsets | |
| // } | |
| return text; | |
| } | |
| /** | |
| * @param {number[]|bigint[]} token_ids List of token IDs to decode. | |
| * @param {Object} decode_args Optional arguments for decoding | |
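| * | |
| * **Example:** (a minimal sketch; `token_ids` and the output are illustrative.) Tokens at or above the | |
| * first timestamp id are rendered as `<|x.xx|>` markers and the remaining tokens are decoded as text: | |
| * ```javascript | |
| * const text = tokenizer.decode(token_ids, { decode_with_timestamps: true }); | |
| * // e.g. "<|0.00|> Hello world.<|1.24|>" | |
| * ``` | |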
| * @private | |
| */ | |
| decodeWithTimestamps(token_ids, decode_args) { | |
| const time_precision = decode_args?.time_precision ?? 0.02; | |
| const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1; | |
| /**@type {Array} */ | |
| let outputs = [[]]; | |
| for (let token of token_ids) { | |
| token = Number(token); | |
| if (token >= timestamp_begin) { | |
| const timestamp = ((token - timestamp_begin) * time_precision).toFixed(2); | |
| outputs.push(`<|${timestamp}|>`); | |
| outputs.push([]); | |
| } else { | |
| outputs[outputs.length - 1].push(token); | |
| } | |
| } | |
| outputs = outputs.map( | |
| s => typeof s === 'string' ? s : super.decode(s, decode_args) | |
| ) | |
| return outputs.join(''); | |
| } | |
| /** | |
| * Combine tokens into words by splitting at any position where the tokens are decoded as valid Unicode code points. | |
| * @param {number[]} tokens | |
| * @returns {*} | |
| * @private | |
| */ | |
| splitTokensOnUnicode(tokens) { | |
| const decoded_full = this.decode(tokens, { | |
| // @ts-ignore | |
| decode_with_timestamps: true, | |
| }); | |
| const replacement_char = '\uFFFD'; | |
| const words = [] | |
| const word_tokens = [] | |
| const token_indices = [] | |
| let current_tokens = [] | |
| let current_indices = [] | |
| let unicode_offset = 0 | |
| for (let token_idx = 0; token_idx < tokens.length; ++token_idx) { | |
| const token = tokens[token_idx]; | |
| current_tokens.push(token); | |
| current_indices.push(token_idx); | |
| const decoded = this.decode(current_tokens, { | |
| // @ts-ignore | |
| decode_with_timestamps: true, | |
| }); | |
| if (!decoded.includes(replacement_char) || decoded_full[unicode_offset + decoded.indexOf(replacement_char)] === replacement_char) { | |
| words.push(decoded) | |
| word_tokens.push(current_tokens) | |
| token_indices.push(current_indices) | |
| current_tokens = [] | |
| current_indices = [] | |
| unicode_offset += decoded.length; | |
| } | |
| } | |
| return [words, word_tokens, token_indices] | |
| } | |
| /** | |
| * Combine tokens into words by splitting at whitespace and punctuation tokens. | |
| * @param {number[]} tokens | |
| * @private | |
| */ | |
| splitTokensOnSpaces(tokens) { | |
| const [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens); | |
| const words = [] | |
| const word_tokens = [] | |
| const token_indices = [] | |
| const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'u'); // no 'g' flag: a global regex is stateful across .test() calls | |
| for (let i = 0; i < subwords.length; ++i) { | |
| const subword = subwords[i]; | |
| const subword_tokens = subword_tokens_list[i]; | |
| const subword_indices = subword_indices_list[i]; | |
| // @ts-ignore | |
| const special = subword_tokens[0] >= this.model.tokens_to_ids.get('<|endoftext|>'); | |
| const with_space = subword.startsWith(' '); | |
| const trimmed = subword.trim(); | |
| const punctuation = punctuationRegex.test(trimmed); | |
| if (special || with_space || punctuation || words.length === 0) { | |
| words.push(subword); | |
| word_tokens.push(subword_tokens); | |
| token_indices.push(subword_indices); | |
| } else { | |
| const ix = words.length - 1; | |
| words[ix] += subword; | |
| word_tokens[ix].push(...subword_tokens); | |
| token_indices[ix].push(...subword_indices); | |
| } | |
| } | |
| return [words, word_tokens, token_indices]; | |
| } | |
| /** | |
| * Merges punctuation tokens with neighboring words. | |
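| * For example (illustrative), the word pieces `[' ¿', 'Cómo', ' estás', '?']` would be merged into | |
| * `[' ¿Cómo', ' estás?']`, with the corresponding token id and index lists merged in the same way. | |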
| * @param {string[]} words | |
| * @param {number[][]} tokens | |
| * @param {number[][]} indices | |
| * @param {string} prepended | |
| * @param {string} appended | |
| * @private | |
| */ | |
| mergePunctuations(words, tokens, indices, prepended, appended) { | |
| const newWords = structuredClone(words); | |
| const newTokens = structuredClone(tokens); | |
| const newIndices = structuredClone(indices); | |
| // prepend punctuations | |
| let i = newWords.length - 2; | |
| let j = newWords.length - 1; | |
| while (i >= 0) { | |
| if (newWords[i].startsWith(' ') && prepended.includes(newWords[i].trim())) { | |
| newWords[j] = newWords[i] + newWords[j]; | |
| newTokens[j] = mergeArrays(newTokens[i], newTokens[j]); | |
| newIndices[j] = mergeArrays(newIndices[i], newIndices[j]); | |
| newWords[i] = ''; | |
| newTokens[i] = []; | |
| newIndices[i] = []; | |
| } else { | |
| j = i; | |
| } | |
| --i; | |
| } | |
| // append punctuations | |
| i = 0; | |
| j = 1; | |
| while (j < newWords.length) { | |
| if (!newWords[i].endsWith(' ') && appended.includes(newWords[j])) { | |
| newWords[i] += newWords[j]; | |
| newTokens[i] = mergeArrays(newTokens[i], newTokens[j]); | |
| newIndices[i] = mergeArrays(newIndices[i], newIndices[j]); | |
| newWords[j] = ''; | |
| newTokens[j] = []; | |
| newIndices[j] = []; | |
| } else { | |
| i = j; | |
| } | |
| ++j; | |
| } | |
| return [ | |
| newWords.filter(x => x), | |
| newTokens.filter(x => x.length > 0), | |
| newIndices.filter(x => x.length > 0), | |
| ] | |
| } | |
| } | |
| export class CodeGenTokenizer extends PreTrainedTokenizer { } | |
| export class CLIPTokenizer extends PreTrainedTokenizer { } | |
| export class SiglipTokenizer extends PreTrainedTokenizer { } | |
| /** | |
| * @todo This model is not yet supported by Hugging Face's "fast" tokenizers library (https://github.com/huggingface/tokenizers). | |
| * Therefore, this implementation (which is based on fast tokenizers) may produce slightly inaccurate results. | |
| */ | |
| export class MarianTokenizer extends PreTrainedTokenizer { | |
| /** | |
| * Create a new MarianTokenizer instance. | |
| * @param {Object} tokenizerJSON The JSON of the tokenizer. | |
| * @param {Object} tokenizerConfig The config of the tokenizer. | |
| */ | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| this.languageRegex = /^(>>\w+<<)\s*/; // no 'g' flag: a global regex is stateful across .test() calls | |
| this.supported_language_codes = this.model.vocab.filter( | |
| x => this.languageRegex.test(x) | |
| ); | |
| console.warn('WARNING: `MarianTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. Therefore, you may experience slightly inaccurate results.') | |
| } | |
| /** | |
| * Encodes a single text. Overriding this method is necessary since the language codes | |
| * must be removed before encoding with sentencepiece model. | |
| * @see https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/models/marian/tokenization_marian.py#L204-L213 | |
| * | |
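| * **Example:** (a minimal sketch) a leading language code such as `>>fra<<` is split off and kept as a | |
| * single token, while only the remaining text is encoded with the sentencepiece model; the snippet is | |
| * illustrative and assumes the loaded checkpoint actually supports `>>fra<<`. | |
| * ```javascript | |
| * const tokens = tokenizer._encode_text('>>fra<< Hello world'); | |
| * // tokens[0] === '>>fra<<', followed by the sentencepiece tokens for "Hello world" | |
| * ``` | |
| * | |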
| * @param {string|null} text The text to encode. | |
| * @returns {Array} The encoded tokens. | |
| */ | |
| _encode_text(text) { | |
| if (text === null) return null; | |
| // Check if text starts with language code: | |
| const [matchInfo, ...remainder] = text.trim().split(this.languageRegex); | |
| if (remainder.length === 0) { | |
| // No language code, encode normally | |
| return super._encode_text(matchInfo); | |
| } else if (remainder.length === 2) { | |
| // Text starts with language code, so we do not encode it with sentencepiece. | |
| const [language, text] = remainder; | |
| if (!this.supported_language_codes.includes(language)) { | |
| console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. Should be one of: ${JSON.stringify(this.supported_language_codes)}`) | |
| } | |
| return mergeArrays([language], super._encode_text(text)); | |
| } | |
| } | |
| } | |
| export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } | |
| export class BlenderbotTokenizer extends PreTrainedTokenizer { } | |
| export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } | |
| export class SpeechT5Tokenizer extends PreTrainedTokenizer { } | |
| export class NougatTokenizer extends PreTrainedTokenizer { } | |
| export class VitsTokenizer extends PreTrainedTokenizer { | |
| constructor(tokenizerJSON, tokenizerConfig) { | |
| super(tokenizerJSON, tokenizerConfig); | |
| // Custom decoder function | |
| this.decoder = new VitsDecoder({}); | |
| } | |
| } | |
| export class CohereTokenizer extends PreTrainedTokenizer { } | |
| export class MgpstrTokenizer extends PreTrainedTokenizer { } | |
| export class Ernie4_5_Tokenizer extends PreTrainedTokenizer { } | |
| /** | |
| * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. | |
| * The chosen tokenizer class is determined by the type specified in the tokenizer config. | |
| * | |
| * @example | |
| * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); | |
| */ | |
| export class AutoTokenizer { | |
| static TOKENIZER_CLASS_MAPPING = { | |
| T5Tokenizer, | |
| DistilBertTokenizer, | |
| CamembertTokenizer, | |
| DebertaTokenizer, | |
| DebertaV2Tokenizer, | |
| BertTokenizer, | |
| HerbertTokenizer, | |
| ConvBertTokenizer, | |
| RoFormerTokenizer, | |
| XLMTokenizer, | |
| ElectraTokenizer, | |
| MobileBertTokenizer, | |
| SqueezeBertTokenizer, | |
| AlbertTokenizer, | |
| GPT2Tokenizer, | |
| BartTokenizer, | |
| MBartTokenizer, | |
| MBart50Tokenizer, | |
| RobertaTokenizer, | |
| WhisperTokenizer, | |
| CodeGenTokenizer, | |
| CLIPTokenizer, | |
| SiglipTokenizer, | |
| MarianTokenizer, | |
| BloomTokenizer, | |
| NllbTokenizer, | |
| M2M100Tokenizer, | |
| LlamaTokenizer, | |
| CodeLlamaTokenizer, | |
| XLMRobertaTokenizer, | |
| MPNetTokenizer, | |
| FalconTokenizer, | |
| GPTNeoXTokenizer, | |
| EsmTokenizer, | |
| Wav2Vec2CTCTokenizer, | |
| BlenderbotTokenizer, | |
| BlenderbotSmallTokenizer, | |
| SpeechT5Tokenizer, | |
| NougatTokenizer, | |
| VitsTokenizer, | |
| Qwen2Tokenizer, | |
| GemmaTokenizer, | |
| Grok1Tokenizer, | |
| CohereTokenizer, | |
| MgpstrTokenizer, | |
| Ernie4_5_Tokenizer, | |
| // Base case: | |
| PreTrainedTokenizer, | |
| } | |
| /** | |
| * Instantiate one of the tokenizer classes of the library from a pretrained model. | |
| * | |
| * The tokenizer class to instantiate is selected based on the `tokenizer_class` property of the config object | |
| * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) | |
| * | |
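| * **Example:** Passing loading options (a minimal sketch; the model id is only an assumption). | |
| * ```javascript | |
| * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2', { | |
| *   revision: 'main', | |
| *   progress_callback: (progress) => console.log(progress), | |
| * }); | |
| * ``` | |
| * | |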
| * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: | |
| * - A string, the *model id* of a pretrained tokenizer hosted inside a model repo on huggingface.co. | |
| * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a | |
| * user or organization name, like `dbmdz/bert-base-german-cased`. | |
| * - A path to a *directory* containing tokenizer files, e.g., `./my_model_directory/`. | |
| * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. | |
| * | |
| * @returns {Promise<PreTrainedTokenizer>} A new instance of the PreTrainedTokenizer class. | |
| */ | |
| static async from_pretrained(pretrained_model_name_or_path, { | |
| progress_callback = null, | |
| config = null, | |
| cache_dir = null, | |
| local_files_only = false, | |
| revision = 'main', | |
| legacy = null, | |
| } = {}) { | |
| const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { | |
| progress_callback, | |
| config, | |
| cache_dir, | |
| local_files_only, | |
| revision, | |
| legacy, | |
| }) | |
| // Some tokenizers are saved with the "Fast" suffix, so we remove that if present. | |
| const tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 'PreTrainedTokenizer'; | |
| let cls = this.TOKENIZER_CLASS_MAPPING[tokenizerName]; | |
| if (!cls) { | |
| console.warn(`Unknown tokenizer class "${tokenizerName}", attempting to construct from base class.`); | |
| cls = PreTrainedTokenizer; | |
| } | |
| return new cls(tokenizerJSON, tokenizerConfig); | |
| } | |
| } | |