/** * @file Tokenizers are used to prepare textual inputs for a model. * * **Example:** Create an `AutoTokenizer` and use it to tokenize a sentence. * This will automatically detect the tokenizer type based on the tokenizer class defined in `tokenizer.json`. * ```javascript * import { AutoTokenizer } from '@huggingface/transformers'; * * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); * const { input_ids } = await tokenizer('I love transformers!'); * // Tensor { * // data: BigInt64Array(6) [101n, 1045n, 2293n, 19081n, 999n, 102n], * // dims: [1, 6], * // type: 'int64', * // size: 6, * // } * ``` * * @module tokenizers */ import { Callable, } from './utils/generic.js'; import { reverseDictionary, escapeRegExp, isIntegralNumber, mergeArrays, len, } from './utils/core.js'; import { getModelJSON, } from './utils/hub.js'; import { max, min, round } from './utils/maths.js'; import { Tensor } from './utils/tensor.js'; import { PriorityQueue, TokenLattice, CharTrie, DictionarySplitter, LRUCache, } from './utils/data-structures.js'; import { Template } from '@huggingface/jinja'; import { WHISPER_LANGUAGE_MAPPING } from './models/whisper/common_whisper.js'; /** * @typedef {Object} TokenizerProperties Additional tokenizer-specific properties. * @property {boolean} [legacy=false] Whether or not the `legacy` behavior of the tokenizer should be used. * @typedef {import('./utils/hub.js').PretrainedOptions & TokenizerProperties} PretrainedTokenizerOptions */ /** * Loads a tokenizer from the specified path. * @param {string} pretrained_model_name_or_path The path to the tokenizer directory. * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * @returns {Promise} A promise that resolves with information about the loaded tokenizer. */ async function loadTokenizer(pretrained_model_name_or_path, options) { const info = await Promise.all([ getModelJSON(pretrained_model_name_or_path, 'tokenizer.json', true, options), getModelJSON(pretrained_model_name_or_path, 'tokenizer_config.json', true, options), ]) // Override legacy option if `options.legacy` is not null if (options.legacy !== null) { info[1].legacy = options.legacy; } return info; } /** * Helper function to split a string on a regex, but keep the delimiters. * This is required, because the JavaScript `.split()` method does not keep the delimiters, * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting). * @param {string} text The text to split. * @param {RegExp} regex The regex to split on. * @returns {string[]} The split string. */ function regexSplit(text, regex) { const result = []; let prev = 0; for (const match of text.matchAll(regex)) { const fullMatch = match[0]; if (prev < match.index) { result.push(text.slice(prev, match.index)); } if (fullMatch.length > 0) { result.push(fullMatch); } prev = match.index + fullMatch.length; } if (prev < text.length) { result.push(text.slice(prev)); } return result; } /** * Helper method to construct a pattern from a config object. * @param {Object} pattern The pattern object. * @param {boolean} invert Whether to invert the pattern. * @returns {RegExp|null} The compiled pattern. */ function createPattern(pattern, invert = true) { if (pattern.Regex !== undefined) { // In certain cases, the pattern may contain unnecessary escape sequences (e.g., \# or \& or \~). // i.e., valid in Python (where the patterns are exported from) but invalid in JavaScript (where the patterns are parsed). 
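// (For example, Python's `re` module accepts the escaped pattern '\#', whereas `new RegExp('\\#', 'u')` throws a SyntaxError in JavaScript.)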
// This isn't an issue when creating the regex w/o the 'u' flag, but it is when the 'u' flag is used. // For this reason, it is necessary to remove these backslashes before creating the regex. // See https://stackoverflow.com/a/63007777/13989043 for more information let regex = pattern.Regex.replace(/\\([#&~])/g, '$1'); // TODO: add more characters to this list if necessary // We also handle special cases where the regex contains invalid (non-JS compatible) syntax. for (const [key, value] of PROBLEMATIC_REGEX_MAP) { regex = regex.replaceAll(key, value); } return new RegExp(regex, 'gu'); } else if (pattern.String !== undefined) { const escaped = escapeRegExp(pattern.String); // NOTE: if invert is true, we wrap the pattern in a group so that it is kept when performing .split() return new RegExp(invert ? escaped : `(${escaped})`, 'gu'); } else { console.warn('Unknown pattern type:', pattern) return null; } } /** * Helper function to convert an Object to a Map * @param {Object} obj The object to convert. * @returns {Map} The map. */ function objectToMap(obj) { return new Map(Object.entries(obj)); } /** * Helper function to convert a tensor to a list before decoding. * @param {Tensor} tensor The tensor to convert. * @returns {number[]} The tensor as a list. */ function prepareTensorForDecode(tensor) { const dims = tensor.dims; switch (dims.length) { case 1: return tensor.tolist(); case 2: if (dims[0] !== 1) { throw new Error('Unable to decode tensor with `batch size !== 1`. Use `tokenizer.batch_decode(...)` for batched inputs.'); } return tensor.tolist()[0]; default: throw new Error(`Expected tensor to have 1-2 dimensions, got ${dims.length}.`) } } /** * Clean up a list of simple English tokenization artifacts like spaces before punctuations and abbreviated forms * @param {string} text The text to clean up. * @returns {string} The cleaned up text. */ function clean_up_tokenization(text) { // Clean up a list of simple English tokenization artifacts // like spaces before punctuations and abbreviated forms return text.replace(/ \./g, '.') .replace(/ \?/g, '?') .replace(/ \!/g, '!') .replace(/ ,/g, ',') .replace(/ \' /g, "'") .replace(/ n\'t/g, "n't") .replace(/ \'m/g, "'m") .replace(/ \'s/g, "'s") .replace(/ \'ve/g, "'ve") .replace(/ \'re/g, "'re"); } /** * Helper function to remove accents from a string. * @param {string} text The text to remove accents from. * @returns {string} The text with accents removed. */ function remove_accents(text) { return text.replace(/\p{M}/gu, ''); } /** * Helper function to lowercase a string and remove accents. * @param {string} text The text to lowercase and remove accents from. * @returns {string} The lowercased text with accents removed. */ function lowercase_and_remove_accent(text) { return remove_accents(text.toLowerCase()); } /** * Checks whether the given Unicode codepoint represents a CJK (Chinese, Japanese, or Korean) character. * * A "chinese character" is defined as anything in the CJK Unicode block: * https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) * * Note that the CJK Unicode block is NOT all Japanese and Korean characters, despite its name. * The modern Korean Hangul alphabet is a different block, as is Japanese Hiragana and Katakana. * Those alphabets are used to write space-separated words, so they are not treated specially * and are handled like all other languages. * * @param {number|bigint} cp The Unicode codepoint to check. * @returns {boolean} True if the codepoint represents a CJK character, false otherwise. 
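 *
 * A minimal illustration (assuming the argument is a full codepoint, e.g. from `codePointAt`):
 * ```javascript
 * is_chinese_char('中'.codePointAt(0)); // true  (U+4E2D, CJK Unified Ideographs)
 * is_chinese_char('あ'.codePointAt(0)); // false (U+3042, Hiragana)
 * ```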
*/ export function is_chinese_char(cp) { return ( (cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) || (cp >= 0x20000 && cp <= 0x2A6DF) || (cp >= 0x2A700 && cp <= 0x2B73F) || (cp >= 0x2B740 && cp <= 0x2B81F) || (cp >= 0x2B820 && cp <= 0x2CEAF) || (cp >= 0xF900 && cp <= 0xFAFF) || (cp >= 0x2F800 && cp <= 0x2FA1F) ) } /** * Helper function to fuse consecutive unknown tokens. * @param {string[]} arr The list of input tokens * @param {Map} tokens_to_ids The mapping from tokens to token ids. * @param {number} unk_token_id The value to fuse on. * @private */ function fuse_unk(arr, tokens_to_ids, unk_token_id) { const fused = []; let i = 0; while (i < arr.length) { fused.push(arr[i]) if ((tokens_to_ids.get(arr[i]) ?? unk_token_id) !== unk_token_id) { ++i; continue; } while (++i < arr.length && (tokens_to_ids.get(arr[i]) ?? unk_token_id) === unk_token_id) { if (tokens_to_ids.get(fused.at(-1)) !== unk_token_id) { fused[fused.length - 1] += arr[i]; } } } return fused; } /** * Split a string on whitespace. * @param {string} text The text to split. * @returns {string[]} The split string. */ function whitespace_split(text) { return text.match(/\S+/g) || []; } const PUNCTUATION_REGEX = '\\p{P}\\u0021-\\u002F\\u003A-\\u0040\\u005B-\\u0060\\u007B-\\u007E'; const PUNCTUATION_ONLY_REGEX = new RegExp(`^[${PUNCTUATION_REGEX}]+$`, 'gu'); const BLOOM_SPLIT_CHARS = '.,!?\u2026\u3002\uff0c\u3001\u0964\u06d4\u060c'; // A mapping of regex patterns to their equivalent (but possibly longer) JS-compatible versions. const PROBLEMATIC_REGEX_MAP = new Map([ // This uses the case insensitive group modifier, which is not supported in JavaScript. // When parsing the regex, an "Invalid group" error is thrown. ["(?i:'s|'t|'re|'ve|'m|'ll|'d)", "(?:'([sS]|[tT]|[rR][eE]|[vV][eE]|[mM]|[lL][lL]|[dD]))"], // Used to override the default (invalid) regex of the bloom pretokenizer. // For more information, see https://github.com/huggingface/transformers.js/issues/94 [` ?[^(\\s|[${BLOOM_SPLIT_CHARS}])]+`, ` ?[^\\s${BLOOM_SPLIT_CHARS}]+`], ]) /** * Represent a token added by the user on top of the existing Model vocabulary. * AddedToken can be configured to specify the behavior they should have in various situations like: * - Whether they should only match single words * - Whether to include any whitespace on its left or right */ class AddedToken { /** * Creates a new instance of AddedToken. * @param {Object} config Added token configuration object. * @param {string} config.content The content of the added token. * @param {number} config.id The id of the added token. * @param {boolean} [config.single_word=false] Whether this token must be a single word or can break words. * @param {boolean} [config.lstrip=false] Whether this token should strip whitespaces on its left. * @param {boolean} [config.rstrip=false] Whether this token should strip whitespaces on its right. * @param {boolean} [config.normalized=false] Whether this token should be normalized. * @param {boolean} [config.special=false] Whether this token is special. */ constructor(config) { this.content = config.content; this.id = config.id; this.single_word = config.single_word ?? false; this.lstrip = config.lstrip ?? false; this.rstrip = config.rstrip ?? false; this.special = config.special ?? false; this.normalized = config.normalized ?? null; } } /** * Abstract base class for tokenizer models. * * @extends Callable */ export class TokenizerModel extends Callable { /** * Creates a new instance of TokenizerModel. 
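 *
 * Calling the resulting model (see `_call` below) runs the subclass `encode` and, if
 * `fuse_unk` is enabled, merges runs of consecutive unknown tokens into a single token
 * before returning (illustrative: two adjacent out-of-vocabulary pieces are concatenated
 * into one unknown piece).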
* @param {Object} config The configuration object for the TokenizerModel. */ constructor(config) { super(); this.config = config; /** @type {string[]} */ this.vocab = []; /** * A mapping of tokens to ids. * @type {Map} */ this.tokens_to_ids = new Map(); this.unk_token_id = undefined; this.unk_token = undefined; this.end_of_word_suffix = undefined; /** @type {boolean} Whether to fuse unknown tokens when encoding. Defaults to false. */ this.fuse_unk = this.config.fuse_unk ?? false; } /** * Instantiates a new TokenizerModel instance based on the configuration object provided. * @param {Object} config The configuration object for the TokenizerModel. * @param {...*} args Optional arguments to pass to the specific TokenizerModel constructor. * @returns {TokenizerModel} A new instance of a TokenizerModel. * @throws Will throw an error if the TokenizerModel type in the config is not recognized. */ static fromConfig(config, ...args) { switch (config.type) { case 'WordPiece': return new WordPieceTokenizer(config); case 'Unigram': // @ts-ignore return new Unigram(config, ...args); case 'BPE': return new BPE(config); default: // Some older tokenizers, like `google-t5/t5-small`, `openai-community/gpt2`, and `distilbert/distilbert-base-uncased`, do not have a `type` field. // In this case, we can infer the tokenizer type based on the structure of the `vocab` field and other properties. if (config.vocab) { if (Array.isArray(config.vocab)) { // config.vocab is of type `[string, number][]` // @ts-ignore return new Unigram(config, ...args); } else if (Object.hasOwn(config, 'continuing_subword_prefix') && Object.hasOwn(config, 'unk_token')) { if (Object.hasOwn(config, 'merges')) { return new BPE(config); } else { return new WordPieceTokenizer(config); } } else { // @ts-ignore return new LegacyTokenizerModel(config, ...args); } } throw new Error(`Unknown TokenizerModel type: ${config.type}`); } } /** * Internal function to call the TokenizerModel instance. * @param {string[]} tokens The tokens to encode. * @returns {string[]} The encoded tokens. */ _call(tokens) { tokens = this.encode(tokens); if (this.fuse_unk) { // Fuse unknown tokens tokens = fuse_unk(tokens, this.tokens_to_ids, this.unk_token_id); } return tokens; } /** * Encodes a list of tokens into a list of token IDs. * @param {string[]} tokens The tokens to encode. * @returns {string[]} The encoded tokens. * @throws Will throw an error if not implemented in a subclass. */ encode(tokens) { throw Error("encode should be implemented in subclass.") } /** * Converts a list of tokens into a list of token IDs. * @param {string[]} tokens The tokens to convert. * @returns {number[]} The converted token IDs. */ convert_tokens_to_ids(tokens) { return tokens.map(t => this.tokens_to_ids.get(t) ?? this.unk_token_id); } /** * Converts a list of token IDs into a list of tokens. * @param {number[]|bigint[]} ids The token IDs to convert. * @returns {string[]} The converted tokens. */ convert_ids_to_tokens(ids) { return ids.map(i => this.vocab[i] ?? this.unk_token); } } /** * A subclass of TokenizerModel that uses WordPiece encoding to encode tokens. * @extends TokenizerModel */ class WordPieceTokenizer extends TokenizerModel { /** * @param {Object} config The configuration object. * @param {Object} config.vocab A mapping of tokens to ids. * @param {string} config.unk_token The unknown token string. * @param {string} config.continuing_subword_prefix The prefix to use for continuing subwords. 
* @param {number} [config.max_input_chars_per_word=100] The maximum number of characters per word. */ constructor(config) { super(config); /** * A mapping of tokens to ids. * @type {Map} */ this.tokens_to_ids = objectToMap(config.vocab); /** * The id of the unknown token. * @type {number} */ this.unk_token_id = this.tokens_to_ids.get(config.unk_token); /** * The unknown token string. * @type {string} */ this.unk_token = config.unk_token; /** * The maximum number of characters allowed per word. * @type {number} */ this.max_input_chars_per_word = config.max_input_chars_per_word ?? 100; /** * An array of tokens. * @type {string[]} */ this.vocab = new Array(this.tokens_to_ids.size); for (const [key, value] of this.tokens_to_ids) { this.vocab[value] = key; } } /** * Encodes an array of tokens using WordPiece encoding. * @param {string[]} tokens The tokens to encode. * @returns {string[]} An array of encoded tokens. */ encode(tokens) { const outputTokens = []; for (const token of tokens) { const chars = [...token]; if (chars.length > this.max_input_chars_per_word) { outputTokens.push(this.unk_token); continue; } let isUnknown = false; let start = 0; const subTokens = []; while (start < chars.length) { let end = chars.length; let currentSubstring = null; while (start < end) { let substr = chars.slice(start, end).join(''); if (start > 0) { substr = this.config.continuing_subword_prefix + substr; } if (this.tokens_to_ids.has(substr)) { currentSubstring = substr; break; } --end; } if (currentSubstring === null) { isUnknown = true; break; } subTokens.push(currentSubstring); start = end; } if (isUnknown) { outputTokens.push(this.unk_token); } else { outputTokens.push(...subTokens); } } return outputTokens; } } /** * Class representing a Unigram tokenizer model. * @extends TokenizerModel */ class Unigram extends TokenizerModel { /** * Create a new Unigram tokenizer model. * @param {Object} config The configuration object for the Unigram model. * @param {number} config.unk_id The ID of the unknown token * @param {[string, number][]} config.vocab A 2D array representing a mapping of tokens to scores. * @param {Object} moreConfig Additional configuration object for the Unigram model. */ constructor(config, moreConfig) { super(config); const vocabSize = config.vocab.length; this.vocab = new Array(vocabSize); /** @type {number[]} */ this.scores = new Array(vocabSize); for (let i = 0; i < vocabSize; ++i) { [this.vocab[i], this.scores[i]] = config.vocab[i]; } this.unk_token_id = config.unk_id; this.unk_token = this.vocab[config.unk_id]; this.tokens_to_ids = new Map(this.vocab.map((x, i) => [x, i])); this.bos_token = ' '; // beginning of a sentence token this.bos_token_id = this.tokens_to_ids.get(this.bos_token); // NOTE: may be undefined this.eos_token = moreConfig.eos_token; this.eos_token_id = this.tokens_to_ids.get(this.eos_token); this.unk_token = this.vocab[this.unk_token_id]; this.minScore = min(this.scores)[0]; this.unk_score = this.minScore - 10.0; this.scores[this.unk_token_id] = this.unk_score; this.trie = new CharTrie(); this.trie.extend(this.vocab); // NOTE: `fuse_unk` is hardcoded to true for Unigram models // See: https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/models/unigram/model.rs#L119 this.fuse_unk = true; } /** * Populates lattice nodes. * @param {TokenLattice} lattice The token lattice to populate with nodes. 
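 *
 * Illustrative sketch (hypothetical vocabulary): for the input "abc" with entries "a", "ab" and "c",
 * nodes for "a" and "ab" are inserted at position 0, an unknown single-character node with
 * `unk_score` at position 1 (no entry starts with "b"), and a node for "c" at position 2.
 * The Viterbi search over the resulting lattice (see `tokenize`) then selects the highest-scoring path.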
*/ populateNodes(lattice) { const chars = lattice.chars; const mblen = 1; let beginPos = 0; while (beginPos < chars.length) { let hasSingleNode = false; const tokens = []; const sliced = chars.slice(beginPos).join(''); const prefixedTokens = this.trie.commonPrefixSearch(sliced); for (const token of prefixedTokens) { tokens.push(token); const tokenId = this.tokens_to_ids.get(token); const tokenScore = this.scores[tokenId]; const n = len(token); lattice.insert(beginPos, n, tokenScore, tokenId); if (!hasSingleNode && n === mblen) { hasSingleNode = true; } } if (!hasSingleNode) { lattice.insert(beginPos, mblen, this.unk_score, this.unk_token_id); } beginPos += mblen; } } /** * Encodes an array of tokens into an array of subtokens using the unigram model. * * @param {string} normalized The normalized string. * @returns {string[]} An array of subtokens obtained by encoding the input tokens using the unigram model. */ tokenize(normalized) { const lattice = new TokenLattice(normalized, this.bos_token_id, this.eos_token_id); this.populateNodes(lattice); return lattice.tokens(); } /** * Encodes an array of tokens using Unigram encoding. * @param {string[]} tokens The tokens to encode. * @returns {string[]} An array of encoded tokens. */ encode(tokens) { const toReturn = []; for (const token of tokens) { const tokenized = this.tokenize(token); toReturn.push(...tokenized); } return toReturn; } } /** * Returns list of utf-8 byte and a mapping to unicode strings. * Specifically avoids mapping to whitespace/control characters the BPE code barfs on. * @returns {Object} Object with utf-8 byte keys and unicode string values. */ const BYTES_TO_UNICODE = (() => { // Returns list of utf-8 byte and a mapping to unicode strings. // We specifically avoids mapping to whitespace/control characters // the bpe code barfs on. const bs = [ ...Array.from({ length: "~".charCodeAt(0) - "!".charCodeAt(0) + 1 }, (_, i) => i + "!".charCodeAt(0)), ...Array.from({ length: "¬".charCodeAt(0) - "¡".charCodeAt(0) + 1 }, (_, i) => i + "¡".charCodeAt(0)), ...Array.from({ length: "ÿ".charCodeAt(0) - "®".charCodeAt(0) + 1 }, (_, i) => i + "®".charCodeAt(0)), ]; const cs = bs.slice(); let n = 0; for (let b = 0; b < 256; ++b) { if (!bs.includes(b)) { bs.push(b); cs.push(256 + n); n += 1; } } const ccs = cs.map(n => String.fromCharCode(n)); return Object.fromEntries(bs.map((b, i) => [b, ccs[i]])); })(); const UNICODE_TO_BYTES = reverseDictionary(BYTES_TO_UNICODE); /** * @typedef {Object} BPENode * @property {string} token The token associated with the node * @property {number} bias A positional bias for the node. * @property {number} [score] The score of the node. * @property {BPENode} [prev] The previous node in the linked list. * @property {BPENode} [next] The next node in the linked list. */ /** * BPE class for encoding text into Byte-Pair-Encoding (BPE) tokens. * @extends TokenizerModel */ class BPE extends TokenizerModel { /** * Create a BPE instance. * @param {Object} config The configuration object for BPE. * @param {Object} config.vocab A mapping of tokens to ids. * @param {string[]|[string, string][]} config.merges An array of BPE merges as strings. * @param {string} config.unk_token The unknown token used for out of vocabulary words. * @param {string} config.end_of_word_suffix The suffix to place at the end of each word. * @param {string} [config.continuing_subword_suffix] The suffix to insert between words. 
* @param {boolean} [config.byte_fallback=false] Whether to use spm byte-fallback trick (defaults to False) * @param {boolean} [config.ignore_merges=false] Whether or not to match tokens with the vocab before using merges. */ constructor(config) { super(config); /** @type {Map} */ this.tokens_to_ids = objectToMap(config.vocab); this.unk_token_id = this.tokens_to_ids.get(config.unk_token); this.unk_token = config.unk_token; this.vocab = new Array(this.tokens_to_ids.size); for (const [key, value] of this.tokens_to_ids) { this.vocab[value] = key; } // Tokenizers >= 0.20.0 serializes BPE merges as a [string, string][] instead of a string[], // which resolves the ambiguity for merges containing spaces. const use_new_merge_format = Array.isArray(config.merges[0]); /** @type {[string, string][]} */ this.merges = use_new_merge_format ? /** @type {[string, string][]} */(config.merges) : (/** @type {string[]} */(config.merges)).map(x => /** @type {[string, string]} */(x.split(' ', 2))); this.bpe_ranks = new Map(this.merges.map((x, i) => [JSON.stringify(x), i])); this.end_of_word_suffix = config.end_of_word_suffix; // NOTE: `continuing_subword_suffix` is custom (to support `BlenderbotSmallTokenizer`) this.continuing_subword_suffix = config.continuing_subword_suffix ?? null; this.byte_fallback = this.config.byte_fallback ?? false; if (this.byte_fallback) { this.text_encoder = new TextEncoder(); } this.ignore_merges = this.config.ignore_merges ?? false; /** * The maximum length we should cache in a model. * Strings that are too long have minimal chances to cache hit anyway */ this.max_length_to_cache = 256; /** * The default capacity for a `BPE`'s internal cache. */ this.cache_capacity = 10000; this.cache = new LRUCache(this.cache_capacity); } /** * Clears the cache. */ clear_cache() { this.cache.clear(); } /** * Apply Byte-Pair-Encoding (BPE) to a given token. Efficient heap-based priority * queue implementation adapted from https://github.com/belladoreai/llama-tokenizer-js. * @param {string} token The token to encode. * @returns {string[]} The BPE encoded tokens. */ bpe(token) { if (token.length === 0) { return []; } const cached = this.cache.get(token); if (cached !== undefined) { return cached; } const word = Array.from(token); if (this.end_of_word_suffix) { word[word.length - 1] += this.end_of_word_suffix; } let result = []; if (word.length > 1) { // Create a priority queue to store the nodes that will be merged. // The comparator function compares the scores of the nodes. const queue = new PriorityQueue((a, b) => a.score < b.score); // Construct a doubly-linked list of nodes that will be inserted into the priority queue, // starting with the individual characters. We also populate each node with a positional // bias to break ties in the priority queue. let startingNode = { token: word[0], bias: 0, prev: null, next: null, } let previousNode = startingNode for (let i = 1; i < word.length; ++i) { const currentNode = { bias: i / word.length, // Add fractional component to break ties token: word[i], prev: previousNode, next: null, } previousNode.next = currentNode this._add_node(queue, previousNode) previousNode = currentNode } while (!queue.isEmpty()) { // Get the next node with the highest priority const node = queue.pop(); // Check that this merge is still possible if (node.deleted || !node.next || node.next.deleted) continue; // Here, we mark the current node (left side of the merge) and the next node (right side of the merge) as deleted. 
// This is because they will both be replaced by a new node representing the merge result. node.deleted = true; node.next.deleted = true; // Next, we fix the node that comes before the current node (i.e., left side of the merge). if (node.prev) { // Make a shallow copy of the previous node const newPreviousNode = { ...node.prev }; // Mark the old previous node as deleted. This avoids erroneous merges later, // because there may still be references to this node in the priority queue. node.prev.deleted = true; node.prev = newPreviousNode; // Update the reference of the previous node, by pointing its previous node to this new previous node. if (newPreviousNode.prev) { newPreviousNode.prev.next = newPreviousNode; } else { // If the previous of the previous node does not exist, it means that // `newPreviousNode` must be the new `startingNode`. startingNode = newPreviousNode; } } // Create a new node which represents the result of the merge. const merged = { token: node.token + node.next.token, bias: node.bias, prev: node.prev, next: node.next.next, } // We now consider where we can add the new merged node to the priority queue: // 1. prev <-> merged if (merged.prev) { merged.prev.next = merged; this._add_node(queue, merged.prev); } else { // If `merged.prev` does not exist, then `merged` must be the new `startingNode`. startingNode = merged; } // 2. merged <-> next if (merged.next) { merged.next.prev = merged; this._add_node(queue, merged); } } // Traverse the linked list, starting from the `startingNode`, and collect the tokens. for (let currentNode = startingNode; currentNode !== null; currentNode = currentNode.next) { result.push(currentNode.token); } } else { result = word; } // Possibly append suffix if (this.continuing_subword_suffix) { // Do not append suffix to the last token for (let i = 0; i < result.length - 1; ++i) { result[i] += this.continuing_subword_suffix; } } if (token.length < this.max_length_to_cache) { // Save the result to the cache this.cache.put(token, result); } return result; } /** * Helper function to add a node to the priority queue. * @param {PriorityQueue} queue * @param {BPENode} node * @private */ _add_node(queue, node) { // `score` is a measure of the merge priority: lower means higher priority // We use the BPE rank as a measure of priority (i.e., the local of the merge in the merges list) // We also add a fractional component to the score to break ties (with the earlier character having higher priority) const rank = this.bpe_ranks.get(JSON.stringify([node.token, node.next.token])); if (rank !== undefined) { node.score = rank + node.bias; queue.push(node); } } /** * Encodes the input sequence of tokens using the BPE algorithm and returns the resulting subword tokens. * @param {string[]} tokens The input sequence of tokens to encode. * @returns {string[]} The resulting subword tokens after applying the BPE algorithm to the input sequence of tokens. */ encode(tokens) { const outputTokens = []; for (const token of tokens) { if (this.ignore_merges && this.tokens_to_ids.has(token)) { outputTokens.push(token); continue; } const bpe_token_list = this.bpe(token); for (const t of bpe_token_list) { if (this.tokens_to_ids.has(t)) { outputTokens.push(t); } else if (this.byte_fallback) { const byteTokens = Array.from(this.text_encoder.encode(t)) .map(x => `<0x${x.toString(16).toUpperCase().padStart(2, '0')}>`); if (byteTokens.every(x => this.tokens_to_ids.has(x))) { // Ensure the byte tokens are actually in the vocabulary, otherwise // we fall back to the unknown token. 
For more information, see // https://github.com/huggingface/transformers/issues/28096. outputTokens.push(...byteTokens); } else { outputTokens.push(this.unk_token); } } else { outputTokens.push(this.unk_token); } } } return outputTokens; } } /** * Legacy tokenizer class for tokenizers with only a vocabulary. */ class LegacyTokenizerModel extends TokenizerModel { /** * Create a LegacyTokenizerModel instance. * @param {Object} config The configuration object for LegacyTokenizerModel. * @param {Object} config.vocab A (possibly nested) mapping of tokens to ids. * @param {Object} moreConfig Additional configuration object for the LegacyTokenizerModel model. */ constructor(config, moreConfig) { super(config); /**@type {Map} */ this.tokens_to_ids = objectToMap( moreConfig.target_lang ? config.vocab[moreConfig.target_lang] : config.vocab ); this.bos_token = moreConfig.bos_token; this.bos_token_id = this.tokens_to_ids.get(this.bos_token); this.eos_token = moreConfig.eos_token; this.eos_token_id = this.tokens_to_ids.get(this.eos_token); this.pad_token = moreConfig.pad_token; this.pad_token_id = this.tokens_to_ids.get(this.pad_token); this.unk_token = moreConfig.unk_token; this.unk_token_id = this.tokens_to_ids.get(this.unk_token); this.vocab = new Array(this.tokens_to_ids.size); for (const [key, value] of this.tokens_to_ids) { this.vocab[value] = key; } } encode(tokens) { return tokens; } } /** * A base class for text normalization. * @abstract */ class Normalizer extends Callable { /** * @param {Object} config The configuration object for the normalizer. */ constructor(config) { super(); this.config = config; } /** * Factory method for creating normalizers from config objects. * @static * @param {Object} config The configuration object for the normalizer. * @returns {Normalizer} A Normalizer object. * @throws {Error} If an unknown Normalizer type is specified in the config. */ static fromConfig(config) { if (config === null) return null; switch (config.type) { case 'BertNormalizer': return new BertNormalizer(config); case 'Precompiled': return new Precompiled(config); case 'Sequence': return new NormalizerSequence(config); case 'Replace': return new Replace(config); case 'NFC': return new NFC(config); case 'NFD': return new NFD(config); case 'NFKC': return new NFKC(config); case 'NFKD': return new NFKD(config); case 'Strip': return new StripNormalizer(config); case 'StripAccents': return new StripAccents(config); case 'Lowercase': return new Lowercase(config); case 'Prepend': return new Prepend(config); default: throw new Error(`Unknown Normalizer type: ${config.type}`); } } /** * Normalize the input text. * @abstract * @param {string} text The text to normalize. * @returns {string} The normalized text. * @throws {Error} If this method is not implemented in a subclass. */ normalize(text) { throw Error("normalize should be implemented in subclass.") } /** * Alias for {@link Normalizer#normalize}. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ _call(text) { return this.normalize(text); } } /** * Replace normalizer that replaces occurrences of a pattern with a given string or regular expression. * @extends Normalizer */ class Replace extends Normalizer { /** * Normalize the input text by replacing the pattern with the content. * @param {string} text The input text to be normalized. * @returns {string} The normalized text after replacing the pattern with the content. 
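 *
 * A minimal illustration (assuming a string pattern, as used by SentencePiece-style tokenizers):
 * ```javascript
 * // config: { pattern: { String: ' ' }, content: '\u2581' }
 * // normalize('Hello world') => 'Hello\u2581world'
 * ```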
*/ normalize(text) { const pattern = createPattern(this.config.pattern); return pattern === null ? text : text.replaceAll(pattern, this.config.content); } } /** * A normalizer that applies Unicode normalization to the input text. * @extends Normalizer * @abstract */ class UnicodeNormalizer extends Normalizer { /** * @type {string} The Unicode normalization form to apply. * Should be one of: 'NFC', 'NFD', 'NFKC', or 'NFKD'. */ form = undefined; /** * Normalize the input text by applying Unicode normalization. * @param {string} text The input text to be normalized. * @returns {string} The normalized text. */ normalize(text) { text = text.normalize(this.form) return text; } } /** * A normalizer that applies Unicode normalization form C (NFC) to the input text. * Canonical Decomposition, followed by Canonical Composition. * @extends UnicodeNormalizer */ class NFC extends UnicodeNormalizer { form = 'NFC'; } /** * A normalizer that applies Unicode normalization form D (NFD) to the input text. * Canonical Decomposition. * @extends UnicodeNormalizer */ class NFD extends UnicodeNormalizer { form = 'NFD'; } /** * A normalizer that applies Unicode normalization form KC (NFKC) to the input text. * Compatibility Decomposition, followed by Canonical Composition. * @extends UnicodeNormalizer */ class NFKC extends UnicodeNormalizer { form = 'NFKC'; } /** * A normalizer that applies Unicode normalization form KD (NFKD) to the input text. * Compatibility Decomposition. * @extends UnicodeNormalizer */ class NFKD extends UnicodeNormalizer { form = 'NFKD'; } /** * A normalizer that strips leading and/or trailing whitespace from the input text. */ class StripNormalizer extends Normalizer { /** * Strip leading and/or trailing whitespace from the input text. * @param {string} text The input text. * @returns {string} The normalized text. */ normalize(text) { if (this.config.strip_left && this.config.strip_right) { // Fast path to avoid an extra trim call text = text.trim(); } else { if (this.config.strip_left) { text = text.trimStart(); } if (this.config.strip_right) { text = text.trimEnd(); } } return text; } } /** * StripAccents normalizer removes all accents from the text. * @extends Normalizer */ class StripAccents extends Normalizer { /** * Remove all accents from the text. * @param {string} text The input text. * @returns {string} The normalized text without accents. */ normalize(text) { text = remove_accents(text); return text; } } /** * A Normalizer that lowercases the input string. * @extends Normalizer */ class Lowercase extends Normalizer { /** * Lowercases the input string. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ normalize(text) { text = text.toLowerCase(); return text; } } /** * A Normalizer that prepends a string to the input string. * @extends Normalizer */ class Prepend extends Normalizer { /** * Prepends the input string. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ normalize(text) { text = this.config.prepend + text; return text; } } /** * A Normalizer that applies a sequence of Normalizers. * @extends Normalizer */ class NormalizerSequence extends Normalizer { /** * Create a new instance of NormalizerSequence. * @param {Object} config The configuration object. * @param {Object[]} config.normalizers An array of Normalizer configuration objects. 
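 *
 * Example configuration (illustrative):
 * ```javascript
 * // { type: 'Sequence', normalizers: [{ type: 'NFKC' }, { type: 'Lowercase' }] }
 * // "ＨＥＬＬＯ" -> NFKC -> "HELLO" -> Lowercase -> "hello"
 * ```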
*/ constructor(config) { super(config); this.normalizers = config.normalizers.map(x => Normalizer.fromConfig(x)); } /** * Apply a sequence of Normalizers to the input text. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ normalize(text) { return this.normalizers.reduce((t, normalizer) => { return normalizer.normalize(t); }, text); } } /** * A class representing a normalizer used in BERT tokenization. * @extends Normalizer */ class BertNormalizer extends Normalizer { /** * Adds whitespace around any CJK (Chinese, Japanese, or Korean) character in the input text. * * @param {string} text The input text to tokenize. * @returns {string} The tokenized text with whitespace added around CJK characters. */ _tokenize_chinese_chars(text) { /* Adds whitespace around any CJK character. */ const output = []; for (let i = 0; i < text.length; ++i) { const char = text[i]; const cp = char.charCodeAt(0); if (is_chinese_char(cp)) { output.push(" "); output.push(char); output.push(" "); } else { output.push(char); } } return output.join(""); } /** * Strips accents from the given text. * @param {string} text The text to strip accents from. * @returns {string} The text with accents removed. */ stripAccents(text) { // "Mark, Nonspacing" (Mn) return text.normalize('NFD').replace(/\p{Mn}/gu, ''); } /** * Checks whether `char` is a control character. * @param {string} char The character to check. * @returns {boolean} Whether `char` is a control character. * @private */ _is_control(char) { switch (char) { case '\t': case '\n': case '\r': // These are technically control characters but we count them as whitespace characters. return false; default: // Check if unicode category starts with C: // Cc - Control // Cf - Format // Co - Private Use // Cs - Surrogate return /^\p{Cc}|\p{Cf}|\p{Co}|\p{Cs}$/u.test(char); } } /** * Performs invalid character removal and whitespace cleanup on text. * @param {string} text The text to clean. * @returns {string} The cleaned text. * @private */ _clean_text(text) { const output = []; for (const char of text) { const cp = char.charCodeAt(0); if (cp === 0 || cp === 0xFFFD || this._is_control(char)) { continue; } if (/^\s$/.test(char)) { // is whitespace output.push(" "); } else { output.push(char); } } return output.join(""); } /** * Normalizes the given text based on the configuration. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ normalize(text) { if (this.config.clean_text) { text = this._clean_text(text); } if (this.config.handle_chinese_chars) { text = this._tokenize_chinese_chars(text); } if (this.config.lowercase) { text = text.toLowerCase(); if (this.config.strip_accents !== false) { text = this.stripAccents(text); } } else if (this.config.strip_accents) { text = this.stripAccents(text); } return text; } } /** * A callable class representing a pre-tokenizer used in tokenization. Subclasses * should implement the `pre_tokenize_text` method to define the specific pre-tokenization logic. * @extends Callable */ class PreTokenizer extends Callable { /** * Factory method that returns an instance of a subclass of `PreTokenizer` based on the provided configuration. * * @static * @param {Object} config A configuration object for the pre-tokenizer. * @returns {PreTokenizer} An instance of a subclass of `PreTokenizer`. * @throws {Error} If the provided configuration object does not correspond to any known pre-tokenizer. 
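 *
 * Example (illustrative; these classes are internal to this module):
 * ```javascript
 * const pre = PreTokenizer.fromConfig({ type: 'BertPreTokenizer' });
 * pre.pre_tokenize('Hello, world!'); // ['Hello', ',', 'world', '!']
 * ```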
*/ static fromConfig(config) { if (config === null) return null; switch (config.type) { case 'BertPreTokenizer': return new BertPreTokenizer(config); case 'Sequence': return new PreTokenizerSequence(config); case 'Whitespace': return new WhitespacePreTokenizer(config); case 'WhitespaceSplit': return new WhitespaceSplit(config); case 'Metaspace': return new MetaspacePreTokenizer(config); case 'ByteLevel': return new ByteLevelPreTokenizer(config); case 'Split': return new SplitPreTokenizer(config); case 'Punctuation': return new PunctuationPreTokenizer(config); case 'Digits': return new DigitsPreTokenizer(config); case 'Replace': return new ReplacePreTokenizer(config); default: throw new Error(`Unknown PreTokenizer type: ${config.type}`); } } /** * Method that should be implemented by subclasses to define the specific pre-tokenization logic. * * @abstract * @param {string} text The text to pre-tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} The pre-tokenized text. * @throws {Error} If the method is not implemented in the subclass. */ pre_tokenize_text(text, options) { throw Error("pre_tokenize_text should be implemented in subclass.") } /** * Tokenizes the given text into pre-tokens. * @param {string|string[]} text The text or array of texts to pre-tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of pre-tokens. */ pre_tokenize(text, options) { return (Array.isArray(text) ? text.map(x => this.pre_tokenize_text(x, options)) : this.pre_tokenize_text(text, options) ).flat(); } /** * Alias for {@link PreTokenizer#pre_tokenize}. * @param {string|string[]} text The text or array of texts to pre-tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of pre-tokens. */ _call(text, options) { return this.pre_tokenize(text, options); } } /** * @extends PreTokenizer */ class BertPreTokenizer extends PreTokenizer { /** * A PreTokenizer that splits text into wordpieces using a basic tokenization scheme * similar to that used in the original implementation of BERT. * * @param {Object} config The configuration object. */ constructor(config) { super(); // Construct a pattern which matches the rust implementation: // https://github.com/huggingface/tokenizers/blob/b4fcc9ce6e4ad5806e82826f816acfdfdc4fcc67/tokenizers/src/pre_tokenizers/bert.rs#L11 // Equivalent to removing whitespace and splitting on punctuation (both \p{P} and other ascii characters) this.pattern = new RegExp(`[^\\s${PUNCTUATION_REGEX}]+|[${PUNCTUATION_REGEX}]`, 'gu'); } /** * Tokenizes a single text using the BERT pre-tokenization scheme. * * @param {string} text The text to tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ pre_tokenize_text(text, options) { return text.trim().match(this.pattern) || []; } } /** * A pre-tokenizer that splits text into Byte-Pair-Encoding (BPE) subwords. * @extends PreTokenizer */ class ByteLevelPreTokenizer extends PreTokenizer { /** * Creates a new instance of the `ByteLevelPreTokenizer` class. * @param {Object} config The configuration object. */ constructor(config) { super(); this.config = config; /** * @type {boolean} Whether to add a leading space to the first word. * This allows to treat the leading word just as any other word. 
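 *
 * For example (illustrative), with `add_prefix_space` enabled, "Hello world" becomes
 * " Hello world", the GPT-2 split regex yields [" Hello", " world"], and the byte-to-unicode
 * mapping then produces ["ĠHello", "Ġworld"] (the space byte 0x20 maps to 'Ġ').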
*/ this.add_prefix_space = this.config.add_prefix_space; /** * @type {boolean} Whether the post processing step should trim offsets * to avoid including whitespaces. * @todo Use this in the pretokenization step. */ this.trim_offsets = this.config.trim_offsets; /** * @type {boolean} Whether to use the standard GPT2 regex for whitespace splitting. * Set it to False if you want to use your own splitting. Defaults to true. */ this.use_regex = this.config.use_regex ?? true; this.pattern = /'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+/gu; this.byte_encoder = BYTES_TO_UNICODE; this.text_encoder = new TextEncoder(); } /** * Tokenizes a single piece of text using byte-level tokenization. * @param {string} text The text to tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ pre_tokenize_text(text, options) { // Add a leading space if the option is enabled if (this.add_prefix_space && !text.startsWith(' ')) { text = ' ' + text; } // Split on whitespace and punctuation const tokens = this.use_regex ? (text.match(this.pattern) || []) : [text]; // Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) return tokens.map( token => Array.from(this.text_encoder.encode(token), byte => this.byte_encoder[byte]).join('') ); } } /** * @typedef {'removed'|'isolated'|'mergedWithPrevious'|'mergedWithNext'|'contiguous'} SplitDelimiterBehavior */ /** * Splits text using a given pattern. * @extends PreTokenizer */ class SplitPreTokenizer extends PreTokenizer { /** * @param {Object} config The configuration options for the pre-tokenizer. * @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object. * @param {string|undefined} config.pattern.String The string to use for splitting. Only defined if the pattern is a string. * @param {string|undefined} config.pattern.Regex The regex to use for splitting. Only defined if the pattern is a regex. * @param {SplitDelimiterBehavior} config.behavior The behavior to use when splitting. * @param {boolean} config.invert Whether to split (invert=false) or match (invert=true) the pattern. */ constructor(config) { super(); this.config = config; // TODO support all behaviours (config.behavior) this.pattern = createPattern(this.config.pattern, this.config.invert); } /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ pre_tokenize_text(text, options) { if (this.pattern === null) { return []; } if (this.config.invert) { return text.match(this.pattern) || []; } else if (this.config.behavior?.toLowerCase() === 'removed') { return text.split(this.pattern).filter(x => x); } else { return regexSplit(text, this.pattern); } } } /** * Splits text based on punctuation. * @extends PreTokenizer */ class PunctuationPreTokenizer extends PreTokenizer { /** * @param {Object} config The configuration options for the pre-tokenizer. * @param {SplitDelimiterBehavior} config.behavior The behavior to use when splitting. */ constructor(config) { super(); this.config = config; this.pattern = new RegExp(`[^${PUNCTUATION_REGEX}]+|[${PUNCTUATION_REGEX}]+`, 'gu'); } /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. 
* @returns {string[]} An array of tokens. */ pre_tokenize_text(text, options) { return text.match(this.pattern) || []; } } /** * Splits text based on digits. * @extends PreTokenizer */ class DigitsPreTokenizer extends PreTokenizer { /** * @param {Object} config The configuration options for the pre-tokenizer. * @param {boolean} config.individual_digits Whether to split on individual digits. */ constructor(config) { super(); this.config = config; // Construct a pattern which matches the rust implementation: const digit_pattern = `[^\\d]+|\\d${this.config.individual_digits ? '' : '+'}`; this.pattern = new RegExp(digit_pattern, 'gu'); } /** * Tokenizes text by splitting it using the given pattern. * @param {string} text The text to tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens. */ pre_tokenize_text(text, options) { return text.match(this.pattern) || []; } } /** * @typedef {Object} PostProcessedOutput * @property {string[]} tokens List of token produced by the post-processor. * @property {number[]} [token_type_ids] List of token type ids produced by the post-processor. */ /** * @typedef {Object} EncodingSingle * @property {number[]} input_ids List of token ids to be fed to a model. * @property {number[]} attention_mask List of token type ids to be fed to a model * @property {number[]} [token_type_ids] List of indices specifying which tokens should be attended to by the model */ /** * @extends Callable */ class PostProcessor extends Callable { /** * @param {Object} config The configuration for the post-processor. */ constructor(config) { super(); this.config = config; } /** * Factory method to create a PostProcessor object from a configuration object. * * @param {Object} config Configuration object representing a PostProcessor. * @returns {PostProcessor} A PostProcessor object created from the given configuration. * @throws {Error} If an unknown PostProcessor type is encountered. */ static fromConfig(config) { if (config === null) return null; switch (config.type) { case 'TemplateProcessing': return new TemplateProcessing(config); case 'ByteLevel': return new ByteLevelPostProcessor(config); case 'RobertaProcessing': return new RobertaProcessing(config); case 'BertProcessing': return new BertProcessing(config); case 'Sequence': return new PostProcessorSequence(config); default: throw new Error(`Unknown PostProcessor type: ${config.type}`); } } /** * Method to be implemented in subclass to apply post-processing on the given tokens. * * @param {Array} tokens The input tokens to be post-processed. * @param {...*} args Additional arguments required by the post-processing logic. * @returns {PostProcessedOutput} The post-processed tokens. * @throws {Error} If the method is not implemented in subclass. */ post_process(tokens, ...args) { throw Error("post_process should be implemented in subclass.") } /** * Alias for {@link PostProcessor#post_process}. * @param {Array} tokens The text or array of texts to post-process. * @param {...*} args Additional arguments required by the post-processing logic. * @returns {PostProcessedOutput} The post-processed tokens. */ _call(tokens, ...args) { return this.post_process(tokens, ...args); } } /** * A post-processor that adds special tokens to the beginning and end of the input. */ class BertProcessing extends PostProcessor { /** * @param {Object} config The configuration for the post-processor. * @param {string[]} config.cls The special tokens to add to the beginning of the input. 
* @param {string[]} config.sep The special tokens to add to the end of the input. */ constructor(config) { super(config); // TODO use all of config: add_prefix_space, trim_offsets this.cls = config.cls[0]; this.sep = config.sep[0]; } /** * Adds the special tokens to the beginning and end of the input. * @param {string[]} tokens The input tokens. * @param {string[]} [tokens_pair=null] An optional second set of input tokens. * @returns {PostProcessedOutput} The post-processed tokens with the special tokens added to the beginning and end. */ post_process(tokens, tokens_pair = null, { add_special_tokens = true, } = {}) { if (add_special_tokens) { tokens = mergeArrays([this.cls], tokens, [this.sep]); } let token_type_ids = new Array(tokens.length).fill(0); if (tokens_pair !== null) { // NOTE: It is intended to add 2 EOS tokens after the first set of tokens // https://github.com/huggingface/tokenizers/issues/983 const middle = (add_special_tokens && this instanceof RobertaProcessing) ? [this.sep] : []; const after = add_special_tokens ? [this.sep] : []; tokens = mergeArrays(tokens, middle, tokens_pair, after); token_type_ids = mergeArrays(token_type_ids, new Array(tokens_pair.length + middle.length + after.length).fill(1)); } return { tokens, token_type_ids }; } } class RobertaProcessing extends BertProcessing { } // NOTE: extends BertProcessing /** * Post processor that replaces special tokens in a template with actual tokens. * @extends PostProcessor */ class TemplateProcessing extends PostProcessor { /** * Creates a new instance of `TemplateProcessing`. * @param {Object} config The configuration options for the post processor. * @param {Array} config.single The template for a single sequence of tokens. * @param {Array} config.pair The template for a pair of sequences of tokens. */ constructor(config) { super(config); this.single = config.single; this.pair = config.pair; } /** * Replaces special tokens in the template with actual tokens. * @param {string[]} tokens The list of tokens for the first sequence. * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). * @returns {PostProcessedOutput} An object containing the list of tokens with the special tokens replaced with actual tokens. */ post_process(tokens, tokens_pair = null, { add_special_tokens = true, } = {}) { const type = tokens_pair === null ? this.single : this.pair let processedTokens = []; let types = []; for (const item of type) { if ('SpecialToken' in item) { if (add_special_tokens) { processedTokens.push(item.SpecialToken.id); types.push(item.SpecialToken.type_id); } } else if ('Sequence' in item) { if (item.Sequence.id === 'A') { processedTokens = mergeArrays(processedTokens, tokens); types = mergeArrays(types, new Array(tokens.length).fill(item.Sequence.type_id)); } else if (item.Sequence.id === 'B') { processedTokens = mergeArrays(processedTokens, tokens_pair); types = mergeArrays(types, new Array(tokens_pair.length).fill(item.Sequence.type_id)); } } } return { tokens: processedTokens, token_type_ids: types }; } } /** * A PostProcessor that returns the given tokens as is. * @extends PostProcessor */ class ByteLevelPostProcessor extends PostProcessor { /** * Post process the given tokens. * @param {string[]} tokens The list of tokens for the first sequence. * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). * @returns {PostProcessedOutput} An object containing the post-processed tokens. 
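 *
 * For example (illustrative), `post_process(['Hello'], ['Ġworld'])` simply returns
 * `{ tokens: ['Hello', 'Ġworld'] }`; no special tokens are added.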
*/ post_process(tokens, tokens_pair = null) { if (tokens_pair) { tokens = mergeArrays(tokens, tokens_pair); } return { tokens }; } } /** * A post-processor that applies multiple post-processors in sequence. */ class PostProcessorSequence extends PostProcessor { /** * Creates a new instance of PostProcessorSequence. * @param {Object} config The configuration object. * @param {Object[]} config.processors The list of post-processors to apply. */ constructor(config) { super(config); this.processors = config.processors.map(x => PostProcessor.fromConfig(x)); } /** * Post process the given tokens. * @param {string[]} tokens The list of tokens for the first sequence. * @param {string[]} [tokens_pair=null] The list of tokens for the second sequence (optional). * @returns {PostProcessedOutput} An object containing the post-processed tokens. */ post_process(tokens, tokens_pair = null, options = {}) { let token_type_ids; for (const processor of this.processors) { if (processor instanceof ByteLevelPostProcessor) { // Special case where we need to pass the tokens_pair to the post-processor const output = processor.post_process(tokens); tokens = output.tokens; if (tokens_pair) { const pair_output = processor.post_process(tokens_pair); tokens_pair = pair_output.tokens; } } else { const output = processor.post_process(tokens, tokens_pair, options); tokens = output.tokens; token_type_ids = output.token_type_ids; } } return { tokens, token_type_ids }; } } /** * The base class for token decoders. * @extends Callable */ class Decoder extends Callable { /** * Creates an instance of `Decoder`. * * @param {Object} config The configuration object. */ constructor(config) { super(); this.config = config; /** @type {AddedToken[]} */ this.added_tokens = []; this.end_of_word_suffix = null; this.trim_offsets = config.trim_offsets; } /** * Creates a decoder instance based on the provided configuration. * * @param {Object} config The configuration object. * @returns {Decoder} A decoder instance. * @throws {Error} If an unknown decoder type is provided. */ static fromConfig(config) { if (config === null) return null; switch (config.type) { case 'WordPiece': return new WordPieceDecoder(config); case 'Metaspace': return new MetaspaceDecoder(config); case 'ByteLevel': return new ByteLevelDecoder(config); case 'Replace': return new ReplaceDecoder(config); case 'ByteFallback': return new ByteFallback(config); case 'Fuse': return new FuseDecoder(config); case 'Strip': return new StripDecoder(config); case 'Sequence': return new DecoderSequence(config); case 'CTC': return new CTCDecoder(config); case 'BPEDecoder': return new BPEDecoder(config); default: throw new Error(`Unknown Decoder type: ${config.type}`); } } /** * Calls the `decode` method. * * @param {string[]} tokens The list of tokens. * @returns {string} The decoded string. */ _call(tokens) { return this.decode(tokens); } /** * Decodes a list of tokens. * @param {string[]} tokens The list of tokens. * @returns {string} The decoded string. */ decode(tokens) { return this.decode_chain(tokens).join(''); } /** * Apply the decoder to a list of tokens. * * @param {string[]} tokens The list of tokens. * @returns {string[]} The decoded list of tokens. * @throws {Error} If the `decode_chain` method is not implemented in the subclass. 
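 *
 * For example (illustrative), a WordPiece decoder with prefix '##' maps
 * ['un', '##aff', '##able'] to ['un', 'aff', 'able'], which `decode` then joins to "unaffable".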
*/ decode_chain(tokens) { throw Error("`decode_chain` should be implemented in subclass.") } } class ReplaceDecoder extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { const pattern = createPattern(this.config.pattern); return pattern === null ? tokens : tokens.map(token => token.replaceAll(pattern, this.config.content)) } } class ByteFallback extends Decoder { constructor(config) { super(config); this.text_decoder = new TextDecoder(); } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { const new_tokens = []; let previous_byte_tokens = []; for (const token of tokens) { let bytes = null; if (token.length === 6 && token.startsWith('<0x') && token.endsWith('>')) { const byte = parseInt(token.slice(3, 5), 16); if (!isNaN(byte)) { bytes = byte; } } if (bytes !== null) { previous_byte_tokens.push(bytes); } else { if (previous_byte_tokens.length > 0) { const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); new_tokens.push(string); previous_byte_tokens = []; } new_tokens.push(token); } } if (previous_byte_tokens.length > 0) { const string = this.text_decoder.decode(Uint8Array.from(previous_byte_tokens)); new_tokens.push(string); previous_byte_tokens = []; } return new_tokens; } } /** * Fuse simply fuses all tokens into one big string. * It's usually the last decoding step anyway, but this decoder * exists incase some decoders need to happen after that step */ class FuseDecoder extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { return [tokens.join('')]; } } class StripDecoder extends Decoder { constructor(config) { super(config); this.content = this.config.content; this.start = this.config.start; this.stop = this.config.stop; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { return tokens.map(token => { let start_cut = 0; for (let i = 0; i < this.start; ++i) { if (token[i] === this.content) { start_cut = i + 1; continue; } else { break; } } let stop_cut = token.length; for (let i = 0; i < this.stop; ++i) { const index = token.length - i - 1; if (token[index] === this.content) { stop_cut = index; continue; } else { break; } } return token.slice(start_cut, stop_cut) }); } } /** * A decoder that decodes a list of WordPiece tokens into a single string. * @extends Decoder */ class WordPieceDecoder extends Decoder { /** * Creates a new instance of WordPieceDecoder. * @param {Object} config The configuration object. * @param {string} config.prefix The prefix used for WordPiece encoding. * @param {boolean} config.cleanup Whether to cleanup the decoded string. */ constructor(config) { super(config); this.cleanup = config.cleanup; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { return tokens.map((token, i) => { if (i !== 0) { if (token.startsWith(this.config.prefix)) { // NOTE: .replace() is intended; only replace first occurrence token = token.replace(this.config.prefix, ''); } else { token = ' ' + token; } } if (this.cleanup) { token = clean_up_tokenization(token) } return token; }); } } /** * Byte-level decoder for tokenization output. Inherits from the `Decoder` class. * @extends Decoder */ class ByteLevelDecoder extends Decoder { /** * Create a `ByteLevelDecoder` object. * @param {Object} config Configuration object. */ constructor(config) { super(config); this.byte_decoder = UNICODE_TO_BYTES; this.text_decoder = new TextDecoder("utf-8", { fatal: false, ignoreBOM: true, }); this.end_of_word_suffix = null; } /** * Convert an array of tokens to string by decoding each byte. 
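 *
 * For example (illustrative), ['Hello', 'Ġworld'] joins to 'HelloĠworld'; each character is
 * mapped back to its original byte ('Ġ' -> 0x20), and the TextDecoder yields "Hello world".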
* @param {string[]} tokens Array of tokens to be decoded. * @returns {string} The decoded string. */ convert_tokens_to_string(tokens) { const text = tokens.join(''); const byteArray = new Uint8Array([...text].map(c => this.byte_decoder[c])); const decoded_text = this.text_decoder.decode(byteArray); return decoded_text; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { // TODO move to base class (like HF) // tokens === filtered_tokens // To avoid mixing byte-level and unicode for byte-level BPE // we need to build string separately for added tokens and byte-level tokens // cf. https://github.com/huggingface/transformers/issues/1133 const sub_texts = []; let current_sub_text = []; for (const token of tokens) { // tokens sent here are already filtered, so we don't need to do this // if (skip_special_tokens && this.all_special_ids.includes(token)) { // continue; // } if (this.added_tokens.find(x => x.content === token) !== undefined) { if (current_sub_text.length > 0) { sub_texts.push(this.convert_tokens_to_string(current_sub_text)); current_sub_text = []; } sub_texts.push(token); } else { current_sub_text.push(token); } } if (current_sub_text.length > 0) { sub_texts.push(this.convert_tokens_to_string(current_sub_text)); } // TODO add spaces_between_special_tokens and clean_up_tokenization_spaces options return sub_texts; } } /** * The CTC (Connectionist Temporal Classification) decoder. * See https://github.com/huggingface/tokenizers/blob/bb38f390a61883fc2f29d659af696f428d1cda6b/tokenizers/src/decoders/ctc.rs */ class CTCDecoder extends Decoder { constructor(config) { super(config); this.pad_token = this.config.pad_token; this.word_delimiter_token = this.config.word_delimiter_token; this.cleanup = this.config.cleanup; } /** * Converts connectionist-temporal-classification (CTC) output tokens into a single string. * @param {string[]} tokens Array of tokens to be decoded. * @returns {string} The decoded string. */ convert_tokens_to_string(tokens) { if (tokens.length === 0) return ''; // group same tokens into non-repeating tokens in CTC style decoding const grouped_tokens = [tokens[0]]; for (let i = 1; i < tokens.length; ++i) { if (tokens[i] !== grouped_tokens.at(-1)) { grouped_tokens.push(tokens[i]); } } // filter this.pad_token, which is used as the CTC-blank token const filtered_tokens = grouped_tokens.filter(token => token !== this.pad_token); let text = filtered_tokens.join(''); if (this.cleanup) { // cleanup and replace delimiter token text = clean_up_tokenization(text) .replaceAll(this.word_delimiter_token, ' ') .trim(); } return text; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { return [this.convert_tokens_to_string(tokens)]; } } /** * Apply a sequence of decoders. * @extends Decoder */ class DecoderSequence extends Decoder { /** * Creates a new instance of DecoderSequence. * @param {Object} config The configuration object. * @param {Object[]} config.decoders The list of decoders to apply.
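 *
 * **Example:** A minimal, hypothetical configuration (the decoder entries are illustrative):
 * ```javascript
 * const decoder = new DecoderSequence({
 *     decoders: [
 *         { type: 'Replace', pattern: { String: '▁' }, content: ' ' },
 *         { type: 'Fuse' },
 *     ],
 * });
 * decoder.decode(['▁Hello', '▁world']); // ' Hello world'
 * ```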
*/ constructor(config) { super(config); this.decoders = config.decoders.map(x => Decoder.fromConfig(x)); } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { // Use reduce to apply each decoder to the tokens return this.decoders.reduce((toks, decoder) => { return decoder.decode_chain(toks); }, tokens); } } class BPEDecoder extends Decoder { constructor(config) { super(config); this.suffix = this.config.suffix; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { return tokens.map((token, i) => { return token.replaceAll(this.suffix, (i === tokens.length - 1) ? '' : ' ') }); } } // Custom decoder for VITS class VitsDecoder extends Decoder { /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { let decoded = ''; for (let i = 1; i < tokens.length; i += 2) { decoded += tokens[i]; } return [decoded]; } } /** * This PreTokenizer replaces spaces with the given replacement character, adds a prefix space if requested, * and returns a list of tokens. * @extends PreTokenizer */ class MetaspacePreTokenizer extends PreTokenizer { /** * @param {Object} config The configuration object for the MetaspacePreTokenizer. * @param {boolean} config.add_prefix_space Whether to add a prefix space to the first token. * @param {string} config.replacement The character to replace spaces with. * @param {string} [config.str_rep=config.replacement] An optional string representation of the replacement character. * @param {'first'|'never'|'always'} [config.prepend_scheme='always'] The metaspace prepending scheme. */ constructor(config) { super(); this.addPrefixSpace = config.add_prefix_space; this.replacement = config.replacement; this.strRep = config.str_rep || this.replacement; this.prepend_scheme = config.prepend_scheme ?? 'always'; } /** * This method takes a string, replaces spaces with the replacement character, * adds a prefix space if requested, and returns a new list of tokens. * @param {string} text The text to pre-tokenize. * @param {Object} [options] The options for the pre-tokenization. * @param {number} [options.section_index] The index of the section to pre-tokenize. * @returns {string[]} A new list of pre-tokenized tokens. */ pre_tokenize_text(text, { section_index = undefined, } = {}) { let normalized = text.replaceAll(' ', this.strRep); if ( // We add a prefix space if: // (1) The addPrefixSpace option is enabled and the normalized // token does not already start with the replacement character. (this.addPrefixSpace && !normalized.startsWith(this.replacement)) // and (2) either: // (a) prepend_scheme is 'always' // (b) prepend_scheme is 'first' and this is the first section && ( this.prepend_scheme === 'always' || (this.prepend_scheme === 'first' && section_index === 0) ) ) { normalized = this.strRep + normalized; } return [normalized]; } } /** * MetaspaceDecoder class extends the Decoder class and decodes Metaspace tokenization. * @extends Decoder */ class MetaspaceDecoder extends Decoder { /** * Constructs a new MetaspaceDecoder object. * @param {Object} config The configuration object for the MetaspaceDecoder. * @param {boolean} config.add_prefix_space Whether to add a prefix space to the decoded string. * @param {string} config.replacement The string to replace spaces with. 
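 *
 * **Example:** A minimal sketch (the configuration values and tokens are illustrative):
 * ```javascript
 * const decoder = new MetaspaceDecoder({ add_prefix_space: true, replacement: '▁' });
 * decoder.decode_chain(['▁Hello', '▁world']); // ['Hello', ' world']
 * ```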
*/ constructor(config) { super(config); this.addPrefixSpace = config.add_prefix_space; this.replacement = config.replacement; } /** @type {Decoder['decode_chain']} */ decode_chain(tokens) { const result = []; for (let i = 0; i < tokens.length; ++i) { let normalized = tokens[i].replaceAll(this.replacement, ' '); if (this.addPrefixSpace && i == 0 && normalized.startsWith(' ')) { normalized = normalized.substring(1); } result.push(normalized); } return result; } } /** * A normalizer that applies a precompiled charsmap. * This is useful for applying complex normalizations in C++ and exposing them to JavaScript. * @extends Normalizer * @param {Object} config The configuration object for the Precompiled normalizer. * @param {Object} config.precompiled_charsmap The precompiled charsmap object. */ class Precompiled extends Normalizer { /** * Create a new instance of Precompiled normalizer. * @param {Object} config The configuration object. * @param {any} config.precompiled_charsmap Precompiled chars mapping. */ constructor(config) { super(config); this.charsmap = config.precompiled_charsmap; } /** * Normalizes the given text by applying the precompiled charsmap. * @param {string} text The text to normalize. * @returns {string} The normalized text. */ normalize(text) { // As stated in the sentencepiece normalization docs (https://github.com/google/sentencepiece/blob/master/doc/normalization.md#use-pre-defined-normalization-rule), // there are 5 pre-defined normalization rules: // 1. nmt_nfkc: NFKC normalization with some additional normalization around spaces. (default) // 2. nfkc: original NFKC normalization. // 3. nmt_nfkc_cf: nmt_nfkc + Unicode case folding (mostly lower casing) // 4. nfkc_cf: nfkc + Unicode case folding. // 5. identity: no normalization // // For now, we only implement the default (nmt_nfkc). // See https://raw.githubusercontent.com/google/sentencepiece/master/data/nmt_nfkc.tsv for the full list of rules. // TODO: detect when a different `this.charsmap` is used. text = text.replace(/[\u0001-\u0008\u000B\u000E-\u001F\u007F\u008F\u009F]/gm, ''); // Remove control characters text = text.replace(/[\u0009\u000A\u000C\u000D\u00A0\u1680\u2000-\u200F\u2028\u2029\u202F\u205F\u2581\u3000\uFEFF\uFFFD]/gm, '\u0020'); // Replace certain characters with a space if (text.includes('\uFF5E')) { // To match the sentencepiece implementation 100%, we must handle a very strange edge-case. // For some reason, the "Fullwidth Tilde" character (\uFF5E) should not be converted to the standard Tilde character (\u007E). // However, NFKC normalization does do this conversion. As a result, we split the string on the Fullwidth Tilde character, // perform NFKC normalization on each substring, and then join them back together with the Fullwidth Tilde character. const parts = text.split('\uFF5E'); text = parts.map(part => part.normalize('NFKC')).join('\uFF5E'); } else { text = text.normalize('NFKC'); } return text; } } /** * A pre-tokenizer that applies a sequence of pre-tokenizers to the input text. * @extends PreTokenizer */ class PreTokenizerSequence extends PreTokenizer { /** * Creates an instance of PreTokenizerSequence. * @param {Object} config The configuration object for the pre-tokenizer sequence. * @param {Object[]} config.pretokenizers An array of pre-tokenizer configurations. */ constructor(config) { super(); this.tokenizers = config.pretokenizers.map(x => PreTokenizer.fromConfig(x)); } /** * Applies each pre-tokenizer in the sequence to the input text in turn. 
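 *
 * **Example:** A minimal, hypothetical configuration (the pre-tokenizer entries and the output shown are illustrative):
 * ```javascript
 * const pre_tokenizer = new PreTokenizerSequence({
 *     pretokenizers: [
 *         { type: 'WhitespaceSplit' },
 *         { type: 'Metaspace', replacement: '▁', add_prefix_space: true, prepend_scheme: 'always' },
 *     ],
 * });
 * pre_tokenizer.pre_tokenize_text('Hello world'); // ['▁Hello', '▁world']
 * ```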
* @param {string} text The text to pre-tokenize. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} The pre-tokenized text. */ pre_tokenize_text(text, options) { // Use reduce to apply each tokenizer to the text return this.tokenizers.reduce((preTokenizedText, tokenizer) => { return tokenizer.pre_tokenize(preTokenizedText, options); }, [text]); } } /** * Splits on word boundaries (using the following regular expression: `\w+|[^\w\s]+`). */ class WhitespacePreTokenizer extends PreTokenizer { /** * Creates an instance of WhitespacePreTokenizer. * @param {Object} config The configuration object for the pre-tokenizer. */ constructor(config) { super(); } /** * Pre-tokenizes the input text by splitting it on word boundaries. * @param {string} text The text to be pre-tokenized. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens produced by splitting the input text on whitespace. */ pre_tokenize_text(text, options) { return text.match(/\w+|[^\w\s]+/g) || []; } } /** * Splits a string of text by whitespace characters into individual tokens. * @extends PreTokenizer */ class WhitespaceSplit extends PreTokenizer { /** * Creates an instance of WhitespaceSplit. * @param {Object} config The configuration object for the pre-tokenizer. */ constructor(config) { super(); } /** * Pre-tokenizes the input text by splitting it on whitespace characters. * @param {string} text The text to be pre-tokenized. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens produced by splitting the input text on whitespace. */ pre_tokenize_text(text, options) { return whitespace_split(text); } } // NOTE: `ReplacePreTokenizer` is custom (to support `BlenderbotSmallTokenizer`) class ReplacePreTokenizer extends PreTokenizer { /** * @param {Object} config The configuration options for the pre-tokenizer. * @param {Object} config.pattern The pattern used to split the text. Can be a string or a regex object. * @param {string} config.content What to replace the pattern with. */ constructor(config) { super(); this.config = config; this.pattern = createPattern(this.config.pattern); this.content = this.config.content; } /** * Pre-tokenizes the input text by replacing certain characters. * @param {string} text The text to be pre-tokenized. * @param {Object} [options] Additional options for the pre-tokenization logic. * @returns {string[]} An array of tokens produced by replacing certain characters. */ pre_tokenize_text(text, options) { if (this.pattern === null) { return [text]; } return [text.replaceAll(this.pattern, this.config.content)]; } } const SPECIAL_TOKEN_ATTRIBUTES = [ 'bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', // additional_special_tokens (TODO) ] /** * * Helper function for padding values of an object, which are each arrays. * NOTE: No additional checks are made here for validity of arguments. * @param {Record} item The input object. * @param {number} length The length to pad to. * @param {(key: string) => any} value_fn Determine the value to fill the array, based on its key. * @param {string} side Which side to pad the array. * @private */ function padHelper(item, length, value_fn, side) { for (const key of Object.keys(item)) { const diff = length - item[key].length; const value = value_fn(key); const padData = new Array(diff).fill(value); item[key] = side === 'right' ? 
mergeArrays(item[key], padData) : mergeArrays(padData, item[key]); } } /** * Helper function for truncating values of an object, which are each arrays. * NOTE: No additional checks are made here for validity of arguments. * @param {Record} item The input object. * @param {number} length The length to truncate to. * @private */ function truncateHelper(item, length) { // Setting .length to a lower value truncates the array in-place: // https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/length for (const key of Object.keys(item)) { item[key].length = length; } } /** * @typedef {Object} Message * @property {string} role The role of the message (e.g., "user" or "assistant" or "system"). * @property {string} content The content of the message. */ export class PreTrainedTokenizer extends Callable { return_token_type_ids = false; padding_side = 'right'; /** * Create a new PreTrainedTokenizer instance. * @param {Object} tokenizerJSON The JSON of the tokenizer. * @param {Object} tokenizerConfig The config of the tokenizer. */ constructor(tokenizerJSON, tokenizerConfig) { super(); this.config = tokenizerConfig; // Construct parts of the tokenizer from the JSON this.normalizer = Normalizer.fromConfig(tokenizerJSON.normalizer); this.pre_tokenizer = PreTokenizer.fromConfig(tokenizerJSON.pre_tokenizer); this.model = TokenizerModel.fromConfig(tokenizerJSON.model, tokenizerConfig); this.post_processor = PostProcessor.fromConfig(tokenizerJSON.post_processor); this.decoder = Decoder.fromConfig(tokenizerJSON.decoder); // Add added_tokens to model this.special_tokens = []; this.all_special_ids = []; /** @type {AddedToken[]} */ this.added_tokens = []; for (const addedToken of tokenizerJSON.added_tokens) { const token = new AddedToken(addedToken); this.added_tokens.push(token); this.model.tokens_to_ids.set(token.content, token.id); this.model.vocab[token.id] = token.content; if (token.special) { this.special_tokens.push(token.content); this.all_special_ids.push(token.id); } } // Update additional_special_tokens this.additional_special_tokens = tokenizerConfig.additional_special_tokens ?? []; this.special_tokens.push(...this.additional_special_tokens); this.special_tokens = [...new Set(this.special_tokens)]; // Remove duplicates if (this.decoder) { // Slight hack, but it prevents code duplication: this.decoder.added_tokens = this.added_tokens; // Another slight hack to add `end_of_word_suffix` (if present) to the decoder // This is needed for cases where BPE model and ByteLevel decoder are used // For more information, see https://github.com/huggingface/transformers.js/issues/74 // TODO: save this to the decoder when exporting? 
this.decoder.end_of_word_suffix = this.model.end_of_word_suffix; } this.added_tokens_splitter = new DictionarySplitter( this.added_tokens.map(x => x.content), ); /** @type {Map} */ this.added_tokens_map = new Map(this.added_tokens.map(x => [x.content, x])) // Set mask token if present (otherwise will be undefined, which is fine) this.mask_token = this.getToken('mask_token'); this.mask_token_id = this.model.tokens_to_ids.get(this.mask_token); this.pad_token = this.getToken('pad_token', 'eos_token'); this.pad_token_id = this.model.tokens_to_ids.get(this.pad_token); this.sep_token = this.getToken('sep_token'); this.sep_token_id = this.model.tokens_to_ids.get(this.sep_token); this.unk_token = this.getToken('unk_token'); this.unk_token_id = this.model.tokens_to_ids.get(this.unk_token); this.bos_token = this.getToken('bos_token'); this.bos_token_id = this.model.tokens_to_ids.get(this.bos_token); this.eos_token = this.getToken('eos_token'); this.eos_token_id = this.model.tokens_to_ids.get(this.eos_token); this.model_max_length = tokenizerConfig.model_max_length; /** @type {boolean} Whether or not to strip the text when tokenizing (removing excess spaces before and after the string). */ this.remove_space = tokenizerConfig.remove_space; this.clean_up_tokenization_spaces = tokenizerConfig.clean_up_tokenization_spaces ?? true; this.do_lowercase_and_remove_accent = tokenizerConfig.do_lowercase_and_remove_accent ?? false; if (tokenizerConfig.padding_side) { this.padding_side = tokenizerConfig.padding_side; } this.add_bos_token = tokenizerConfig.add_bos_token; this.add_eos_token = tokenizerConfig.add_eos_token; this.legacy = false; this.chat_template = tokenizerConfig.chat_template ?? null; if (Array.isArray(this.chat_template)) { // Chat templates are stored as lists of dicts with fixed key names, // we reconstruct that into a single dict while loading them. const chat_template = Object.create(null); for (const { name, template } of this.chat_template) { if (typeof name !== 'string' || typeof template !== 'string') { throw new Error('Chat template must be a list of objects with "name" and "template" properties'); } chat_template[name] = template; } this.chat_template = chat_template; } this._compiled_template_cache = new Map(); } /** * Returns the value of the first matching key in the tokenizer config object. * @param {...string} keys One or more keys to search for in the tokenizer config object. * @returns {string|null} The value associated with the first matching key, or null if no match is found. * @throws {Error} If an object is found for a matching key and its __type property is not "AddedToken". * @private */ getToken(...keys) { for (const key of keys) { const item = this.config[key]; if (!item) continue; if (typeof item === 'object') { if (item.__type === 'AddedToken') { return item.content; } else { throw Error(`Unknown token: ${item}`); } } else { return item; } } return null; } /** * Loads a pre-trained tokenizer from the given `pretrained_model_name_or_path`. * * @param {string} pretrained_model_name_or_path The path to the pre-trained tokenizer. * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * * @throws {Error} Throws an error if the tokenizer.json or tokenizer_config.json files are not found in the `pretrained_model_name_or_path`. * @returns {Promise} A new instance of the `PreTrainedTokenizer` class. 
*/ static async from_pretrained(pretrained_model_name_or_path, { progress_callback = null, config = null, cache_dir = null, local_files_only = false, revision = 'main', legacy = null, } = {}) { const info = await loadTokenizer(pretrained_model_name_or_path, { progress_callback, config, cache_dir, local_files_only, revision, legacy, }) // @ts-ignore return new this(...info); } /** * @typedef {number[]|number[][]|Tensor} BatchEncodingItem * * @typedef {Object} BatchEncoding Holds the output of the tokenizer's call function. * @property {BatchEncodingItem} input_ids List of token ids to be fed to a model. * @property {BatchEncodingItem} attention_mask List of indices specifying which tokens should be attended to by the model. * @property {BatchEncodingItem} [token_type_ids] List of token type ids to be fed to a model. */ /** * Encode/tokenize the given text(s). * @param {string|string[]} text The text to tokenize. * @param {Object} options An optional object containing the following properties: * @param {string|string[]} [options.text_pair=null] Optional second sequence to be encoded. If set, must be the same type as text. * @param {boolean|'max_length'} [options.padding=false] Whether to pad the input sequences. * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. * @param {boolean} [options.truncation=null] Whether to truncate the input sequences. * @param {number} [options.max_length=null] Maximum length of the returned list and optionally padding length. * @param {boolean} [options.return_tensor=true] Whether to return the results as Tensors or arrays. * @param {boolean} [options.return_token_type_ids=null] Whether to return the token type ids. * @returns {BatchEncoding} Object to be passed to the model. */ _call( // Required positional arguments text, // Optional keyword arguments { text_pair = null, add_special_tokens = true, padding = false, truncation = null, max_length = null, return_tensor = true, // Different to HF return_token_type_ids = null, } = {}, ) { const isBatched = Array.isArray(text); /** @type {EncodingSingle[]} */ let encodedTokens; if (isBatched) { if (text.length === 0) { throw Error('text array must be non-empty') } if (text_pair !== null) { if (!Array.isArray(text_pair)) { throw Error('text_pair must also be an array') } else if (text.length !== text_pair.length) { throw Error('text and text_pair must have the same length') } encodedTokens = text.map( (t, i) => this._encode_plus(t, { text_pair: text_pair[i], add_special_tokens, return_token_type_ids }) ) } else { encodedTokens = text.map(x => this._encode_plus(x, { add_special_tokens, return_token_type_ids })); } } else { if (text === null || text === undefined) { throw Error('text may not be null or undefined') } if (Array.isArray(text_pair)) { throw Error('When specifying `text_pair`, since `text` is a string, `text_pair` must also be a string (i.e., not an array).') } // For single input, we just wrap in an array, and then unwrap later. encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })]; } // At this point, `encodedTokens` is batched, of shape [batch_size, tokens]. // However, array may be jagged. So, we may need pad to max_length. if (max_length === null) { max_length = this.model_max_length; } else if (truncation === null) { if (padding === true) { console.warn( "`max_length` is ignored when `padding: true` and there is no truncation strategy. 
" + "To pad to max length, use `padding: 'max_length'`." ) max_length = this.model_max_length; } else if (padding === false) { console.warn("Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length."); truncation = true; } } // padding: 'max_length' doesn't require any additional calculation // but padding: true has to calculate max_length from the sequences if (padding === true) { max_length = Math.min(max(encodedTokens.map(x => x.input_ids.length))[0], max_length ?? Infinity); } // Ensure it is less than model max length max_length = Math.min(max_length, this.model_max_length ?? Infinity); if (padding || truncation) { // Perform padding and/or truncation for (let i = 0; i < encodedTokens.length; ++i) { if (encodedTokens[i].input_ids.length === max_length) { continue; } else if (encodedTokens[i].input_ids.length > max_length) { // possibly truncate if (truncation) { truncateHelper(encodedTokens[i], max_length); } } else { // t.length < max_length // possibly pad if (padding) { padHelper( encodedTokens[i], max_length, key => key === 'input_ids' ? this.pad_token_id : 0, this.padding_side ); } } } } const result = {}; if (return_tensor) { if (!(padding && truncation)) { // Not, guaranteed that all items have same length, so // we perform additional check if ( encodedTokens.some(x => { for (const key of Object.keys(x)) { if (x[key].length !== encodedTokens[0][key]?.length) { return true; } } return false; }) ) { throw Error( "Unable to create tensor, you should probably activate truncation and/or padding " + "with 'padding=true' and 'truncation=true' to have batched tensors with the same length." ) } } // Now we actually convert to tensor // NOTE: In the same way as the python library, we return a batched tensor, regardless of // whether we have a single input or multiple inputs. const dims = [encodedTokens.length, encodedTokens[0].input_ids.length]; for (const key of Object.keys(encodedTokens[0])) { result[key] = new Tensor('int64', BigInt64Array.from(encodedTokens.flatMap(x => x[key]).map(BigInt)), dims ); } } else { for (const key of Object.keys(encodedTokens[0])) { result[key] = encodedTokens.map(x => x[key]); } // If not returning a tensor, we match the input type if (!isBatched) { // Input was not batched, so we unwrap for (const key of Object.keys(result)) { result[key] = result[key][0]; } } } return /** @type {BatchEncoding} */(result); } /** * Encodes a single text using the preprocessor pipeline of the tokenizer. * * @param {string|null} text The text to encode. * @returns {string[]|null} The encoded tokens. */ _encode_text(text) { if (text === null) return null; // Actual function which does encoding, for a single text // First, we take care of special tokens. 
Needed to avoid issues arising from // normalization and/or pretokenization (which may not preserve special tokens) const sections = this.added_tokens_splitter.split(text); // Process left/right stripping of added tokens for (let i = 0; i < sections.length; ++i) { const addedToken = this.added_tokens_map.get(sections[i]); if (addedToken) { if (addedToken.lstrip && i > 0) { sections[i - 1] = sections[i - 1].trimEnd(); } if (addedToken.rstrip && i < sections.length - 1) { sections[i + 1] = sections[i + 1].trimStart(); } } } const tokens = sections.flatMap((x, section_index) => { if (x.length === 0) return []; if (this.added_tokens_map.has(x)) return [x]; // Return added tokens unchanged if (this.remove_space === true) { x = x.trim().split(/\s+/).join(' '); } if (this.do_lowercase_and_remove_accent) { x = lowercase_and_remove_accent(x); } if (this.normalizer !== null) { x = this.normalizer(x); } // If, after normalization, this section is empty (e.g., trimming whitespace), // we return an empty array if (x.length === 0) { return []; } const sectionTokens = (this.pre_tokenizer !== null) ? this.pre_tokenizer(x, { section_index, }) : [x]; const tokens = this.model(sectionTokens); return tokens; }); return tokens; } /** * Encodes a single text or a pair of texts using the model's tokenizer. * * @param {string} text The text to encode. * @param {Object} options An optional object containing the following properties: * @param {string} [options.text_pair=null] The optional second text to encode. * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids. * @returns {EncodingSingle} An object containing the encoded text. * @private */ _encode_plus(text, { text_pair = null, add_special_tokens = true, return_token_type_ids = null, } = {}) { const { tokens, token_type_ids } = this._tokenize_helper(text, { pair: text_pair, add_special_tokens }); const input_ids = this.model.convert_tokens_to_ids(tokens); const result = { input_ids, attention_mask: new Array(input_ids.length).fill(1), } if ((return_token_type_ids ?? this.return_token_type_ids) && token_type_ids) { result.token_type_ids = token_type_ids; } return result; } /** * Internal helper function to tokenize a text, and optionally a pair of texts. * @param {string} text The text to tokenize. * @param {Object} options An optional object containing the following properties: * @param {string} [options.pair=null] The optional second text to tokenize. * @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. * @returns {{tokens: string[], token_type_ids?: number[]}} An object containing the tokens and optionally the token type IDs. */ _tokenize_helper(text, { pair = null, add_special_tokens = false, } = {}) { const tokens = this._encode_text(text); const tokens2 = this._encode_text(pair); return this.post_processor ? this.post_processor(tokens, tokens2, { add_special_tokens }) : { tokens: mergeArrays(tokens ?? [], tokens2 ?? []) }; } /** * Converts a string into a sequence of tokens. * @param {string} text The sequence to be encoded. * @param {Object} options An optional object containing the following properties: * @param {string} [options.pair] A second sequence to be encoded with the first. 
* @param {boolean} [options.add_special_tokens=false] Whether or not to add the special tokens associated with the corresponding model. * @returns {string[]} The list of tokens. */ tokenize(text, { pair = null, add_special_tokens = false, } = {}) { return this._tokenize_helper(text, { pair, add_special_tokens }).tokens; } /** * Encodes a single text or a pair of texts using the model's tokenizer. * * @param {string} text The text to encode. * @param {Object} options An optional object containing the following properties: * @param {string} [options.text_pair=null] The optional second text to encode. * @param {boolean} [options.add_special_tokens=true] Whether or not to add the special tokens associated with the corresponding model. * @param {boolean} [options.return_token_type_ids=null] Whether to return token_type_ids. * @returns {number[]} An array of token IDs representing the encoded text(s). */ encode(text, { text_pair = null, add_special_tokens = true, return_token_type_ids = null, } = {}) { return this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids, }).input_ids; } /** * Decode a batch of tokenized sequences. * @param {number[][]|Tensor} batch List/Tensor of tokenized input sequences. * @param {Object} decode_args (Optional) Object with decoding arguments. * @returns {string[]} List of decoded sequences. */ batch_decode(batch, decode_args = {}) { if (batch instanceof Tensor) { batch = batch.tolist(); } return batch.map(x => this.decode(x, decode_args)); } /** * Decodes a sequence of token IDs back to a string. * * @param {number[]|bigint[]|Tensor} token_ids List/Tensor of token IDs to decode. * @param {Object} [decode_args={}] * @param {boolean} [decode_args.skip_special_tokens=false] If true, special tokens are removed from the output string. * @param {boolean} [decode_args.clean_up_tokenization_spaces=true] If true, spaces before punctuations and abbreviated forms are removed. * * @returns {string} The decoded string. * @throws {Error} If `token_ids` is not a non-empty array of integers. */ decode( token_ids, decode_args = {}, ) { if (token_ids instanceof Tensor) { token_ids = prepareTensorForDecode(token_ids); } if (!Array.isArray(token_ids) || token_ids.length === 0 || !isIntegralNumber(token_ids[0])) { throw Error("token_ids must be a non-empty array of integers."); } return this.decode_single(token_ids, decode_args) } /** * Decode a single list of token ids to a string. * @param {number[]|bigint[]} token_ids List of token ids to decode * @param {Object} decode_args Optional arguments for decoding * @param {boolean} [decode_args.skip_special_tokens=false] Whether to skip special tokens during decoding * @param {boolean} [decode_args.clean_up_tokenization_spaces=null] Whether to clean up tokenization spaces during decoding. * If null, the value is set to `this.decoder.cleanup` if it exists, falling back to `this.clean_up_tokenization_spaces` if it exists, falling back to `true`. * @returns {string} The decoded string */ decode_single( token_ids, { skip_special_tokens = false, clean_up_tokenization_spaces = null, } ) { let tokens = this.model.convert_ids_to_tokens(token_ids); if (skip_special_tokens) { tokens = tokens.filter(x => !this.special_tokens.includes(x)); } // If `this.decoder` is null, we just join tokens with a space: // https://github.com/huggingface/tokenizers/blob/8edec536a737cb04494b454805be16c020abb14f/tokenizers/src/tokenizer/mod.rs#L835 /** @type {string} */ let decoded = this.decoder ? 
this.decoder(tokens) : tokens.join(' '); // Slight hack, but prevents having to pass `skip_special_tokens` to // each call to `decode`, which would lead to code duplication. if (this.decoder && this.decoder.end_of_word_suffix) { decoded = decoded.replaceAll(this.decoder.end_of_word_suffix, ' '); if (skip_special_tokens) { decoded = decoded.trim(); } } if (clean_up_tokenization_spaces ?? this.clean_up_tokenization_spaces) { decoded = clean_up_tokenization(decoded); } return decoded; } /** * Retrieve the chat template string used for tokenizing chat messages. This template is used * internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat * template for better generation tracking. * * @param {Object} options An optional object containing the following properties: * @param {string} [options.chat_template=null] * A Jinja template or the name of a template to use for this conversion. * It is usually not necessary to pass anything to this argument, * as the model's template will be used by default. * @param {Object[]} [options.tools=null] * A list of tools (callable functions) that will be accessible to the model. If the template does not * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, * giving the name, description and argument types for the tool. See our * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) * for more information. * @returns {string} The chat template string. */ get_chat_template({ chat_template = null, tools = null, } = {}) { // First, handle the cases when the model has a dict of multiple templates if (this.chat_template && typeof this.chat_template === 'object') { const template_dict = this.chat_template; if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) { // The user can pass the name of a template to the chat template argument instead of an entire template chat_template = template_dict[chat_template]; } else if (chat_template === null) { if (tools !== null && 'tool_use' in template_dict) { chat_template = template_dict['tool_use']; } else if ('default' in template_dict) { chat_template = template_dict['default']; } else { throw Error( `This model has multiple chat templates with no default specified! Please either pass a chat ` + `template or the name of the template you wish to use to the 'chat_template' argument. Available ` + `template names are ${Object.keys(template_dict).sort()}.` ) } } } else if (chat_template === null) { // These are the cases when the model has a single template // priority: `chat_template` argument > `tokenizer.chat_template` if (this.chat_template) { chat_template = this.chat_template; } else { throw Error( "Cannot use apply_chat_template() because tokenizer.chat_template is not set and no template " + "argument was passed! For information about writing templates and setting the " + "tokenizer.chat_template attribute, please see the documentation at " + "https://huggingface.co/docs/transformers/main/en/chat_templating" ) } } return chat_template; } /** * Converts a list of message objects with `"role"` and `"content"` keys to a list of token * ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to * determine the format and control tokens to use when converting. * * See [here](https://huggingface.co/docs/transformers/chat_templating) for more information. 
* * **Example:** Applying a chat template to a conversation. * * ```javascript * import { AutoTokenizer } from "@huggingface/transformers"; * * const tokenizer = await AutoTokenizer.from_pretrained("Xenova/mistral-tokenizer-v1"); * * const chat = [ * { "role": "user", "content": "Hello, how are you?" }, * { "role": "assistant", "content": "I'm doing great. How can I help you today?" }, * { "role": "user", "content": "I'd like to show off how chat templating works!" }, * ] * * const text = tokenizer.apply_chat_template(chat, { tokenize: false }); * // "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" * * const input_ids = tokenizer.apply_chat_template(chat, { tokenize: true, return_tensor: false }); * // [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793] * ``` * * @param {Message[]} conversation A list of message objects with `"role"` and `"content"` keys, * representing the chat history so far. * @param {Object} options An optional object containing the following properties: * @param {string} [options.chat_template=null] A Jinja template to use for this conversion. If * this is not passed, the model's chat template will be used instead. * @param {Object[]} [options.tools=null] * A list of tools (callable functions) that will be accessible to the model. If the template does not * support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema, * giving the name, description and argument types for the tool. See our * [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use) * for more information. * @param {Record[]} [options.documents=null] * A list of dicts representing documents that will be accessible to the model if it is performing RAG * (retrieval-augmented generation). If the template does not support RAG, this argument will have no * effect. We recommend that each document should be a dict containing "title" and "text" keys. Please * see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG) * for examples of passing documents with chat templates. * @param {boolean} [options.add_generation_prompt=false] Whether to end the prompt with the token(s) that indicate * the start of an assistant message. This is useful when you want to generate a response from the model. * Note that this argument will be passed to the chat template, and so it must be supported in the * template for this argument to have any effect. * @param {boolean} [options.tokenize=true] Whether to tokenize the output. If false, the output will be a string. * @param {boolean} [options.padding=false] Whether to pad sequences to the maximum length. Has no effect if tokenize is false. * @param {boolean} [options.truncation=false] Whether to truncate sequences to the maximum length. Has no effect if tokenize is false. * @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false. * If not specified, the tokenizer's `max_length` attribute will be used as a default. 
* @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false. * @param {boolean} [options.return_dict=true] Whether to return a dictionary with named outputs. Has no effect if tokenize is false. * @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer. * @returns {string | Tensor | number[]| number[][]|BatchEncoding} The tokenized output. */ apply_chat_template(conversation, { tools = null, documents = null, chat_template = null, add_generation_prompt = false, tokenize = true, padding = false, truncation = false, max_length = null, return_tensor = true, return_dict = false, tokenizer_kwargs = {}, ...kwargs } = {}) { chat_template = this.get_chat_template({ chat_template, tools }); if (typeof chat_template !== 'string') { throw Error(`chat_template must be a string, but got ${typeof chat_template}`); } // Compilation function uses a cache to avoid recompiling the same template let compiledTemplate = this._compiled_template_cache.get(chat_template); if (compiledTemplate === undefined) { compiledTemplate = new Template(chat_template); this._compiled_template_cache.set(chat_template, compiledTemplate); } const special_tokens_map = Object.create(null); for (const key of SPECIAL_TOKEN_ATTRIBUTES) { const value = this.getToken(key); if (value) { special_tokens_map[key] = value; } } const rendered = compiledTemplate.render({ messages: conversation, add_generation_prompt, tools, documents, ...special_tokens_map, ...kwargs, }); if (tokenize) { const out = this._call(rendered, { add_special_tokens: false, padding, truncation, max_length, return_tensor, ...tokenizer_kwargs, }); return return_dict ? out : out.input_ids; } return rendered; } } /** * BertTokenizer is a class used to tokenize text for BERT models. * @extends PreTrainedTokenizer */ export class BertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } /** * Albert tokenizer * @extends PreTrainedTokenizer */ export class AlbertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class MobileBertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class SqueezeBertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class DebertaTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class DebertaV2Tokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class HerbertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class ConvBertTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class RoFormerTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class DistilBertTokenizer extends PreTrainedTokenizer { } export class CamembertTokenizer extends PreTrainedTokenizer { } export class XLMTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); console.warn('WARNING: `XLMTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. 
Therefore, you may experience slightly inaccurate results.') } } export class ElectraTokenizer extends PreTrainedTokenizer { return_token_type_ids = true; } export class T5Tokenizer extends PreTrainedTokenizer { } export class GPT2Tokenizer extends PreTrainedTokenizer { } export class BartTokenizer extends PreTrainedTokenizer { } export class MBartTokenizer extends PreTrainedTokenizer { constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.languageRegex = /^[a-z]{2}_[A-Z]{2}$/; this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x)); this.lang_to_token = x => x; // Identity function } /** * Helper function to build translation inputs for an `MBartTokenizer`. * @param {string|string[]} raw_inputs The text to tokenize. * @param {Object} tokenizer_options Options to be sent to the tokenizer * @param {Object} generate_kwargs Generation options. * @returns {Object} Object to be passed to the model. */ _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); } } export class MBart50Tokenizer extends MBartTokenizer { } // NOTE: extends MBartTokenizer export class RobertaTokenizer extends PreTrainedTokenizer { } export class BloomTokenizer extends PreTrainedTokenizer { } const SPIECE_UNDERLINE = "▁"; export class LlamaTokenizer extends PreTrainedTokenizer { padding_side = 'left'; constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.legacy = tokenizerConfig.legacy ?? true; if (!this.legacy) { // See https://github.com/huggingface/transformers/pull/24565 for more information this.normalizer = null; this.pre_tokenizer = new MetaspacePreTokenizer({ replacement: SPIECE_UNDERLINE, add_prefix_space: true, prepend_scheme: "first", }); } } /** * Helper function to handle legacy encoding of SPM tokenizers. * Adapted from https://github.com/huggingface/transformers/blob/e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2/src/transformers/models/t5/tokenization_t5.py#L374-L387 * @param {string} text The text to encode. * @returns {string[]} The encoded tokens. */ _encode_text(text) { if (text === null) return null; if (this.legacy || text.length === 0) { return super._encode_text(text); } let tokens = super._encode_text(SPIECE_UNDERLINE + text.replaceAll(SPIECE_UNDERLINE, " ")); if (tokens.length > 1 && tokens[0] === SPIECE_UNDERLINE && this.special_tokens.includes(tokens[1])) { tokens = tokens.slice(1); } return tokens; } } export class CodeLlamaTokenizer extends PreTrainedTokenizer { } export class XLMRobertaTokenizer extends PreTrainedTokenizer { } export class MPNetTokenizer extends PreTrainedTokenizer { } export class FalconTokenizer extends PreTrainedTokenizer { } export class GPTNeoXTokenizer extends PreTrainedTokenizer { } export class EsmTokenizer extends PreTrainedTokenizer { } export class Qwen2Tokenizer extends PreTrainedTokenizer { } export class GemmaTokenizer extends PreTrainedTokenizer { } export class Grok1Tokenizer extends PreTrainedTokenizer { } /** * Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`. * @param {PreTrainedTokenizer} self The tokenizer instance. * @param {string|string[]} raw_inputs The text to tokenize. * @param {Object} tokenizer_options Options to be sent to the tokenizer * @param {Object} generate_kwargs Generation options. * @returns {Object} Object to be passed to the model. 
* @private */ function _build_translation_inputs(self, raw_inputs, tokenizer_options, generate_kwargs) { if (!('language_codes' in self) || !Array.isArray(self.language_codes)) { throw new Error('Tokenizer must have `language_codes` attribute set and it should be an array of language ids.') } if (!('languageRegex' in self) || !(self.languageRegex instanceof RegExp)) { throw new Error('Tokenizer must have `languageRegex` attribute set and it should be a regular expression.') } if (!('lang_to_token' in self) || typeof self.lang_to_token !== 'function') { throw new Error('Tokenizer must have `lang_to_token` attribute set and it should be a function.') } const src_lang_token = generate_kwargs.src_lang; const tgt_lang_token = generate_kwargs.tgt_lang; // Check that the target language is valid: if (!self.language_codes.includes(tgt_lang_token)) { throw new Error(`Target language code "${tgt_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`); } // Allow `src_lang` to be optional. If not set, we'll use the tokenizer's default. if (src_lang_token !== undefined) { // Check that the source language is valid: if (!self.language_codes.includes(src_lang_token)) { throw new Error(`Source language code "${src_lang_token}" is not valid. Must be one of: {${self.language_codes.join(', ')}}`); } // In the same way as the Python library, we override the post-processor // to force the source language to be first: for (const item of self.post_processor.config.single) { if ('SpecialToken' in item && self.languageRegex.test(item.SpecialToken.id)) { item.SpecialToken.id = self.lang_to_token(src_lang_token); break; } } // TODO: Do the same for pair? } // Override the `forced_bos_token_id` to force the correct language generate_kwargs.forced_bos_token_id = self.model.convert_tokens_to_ids([self.lang_to_token(tgt_lang_token)])[0]; return self._call(raw_inputs, tokenizer_options); } /** * The NllbTokenizer class is used to tokenize text for NLLB ("No Language Left Behind") models. * * No Language Left Behind (NLLB) is a first-of-its-kind, AI breakthrough project * that open-sources models capable of delivering high-quality translations directly * between any pair of 200+ languages — including low-resource languages like Asturian, * Luganda, Urdu and more. It aims to help people communicate with anyone, anywhere, * regardless of their language preferences. For more information, check out their * [paper](https://huggingface.co/papers/2207.04672). * * For a list of supported languages (along with their language codes), * @see {@link https://github.com/facebookresearch/flores/blob/main/flores200/README.md#languages-in-flores-200} */ export class NllbTokenizer extends PreTrainedTokenizer { constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.languageRegex = /^[a-z]{3}_[A-Z][a-z]{3}$/; this.language_codes = this.special_tokens.filter(x => this.languageRegex.test(x)); this.lang_to_token = x => x; // Identity function } /** * Helper function to build translation inputs for an `NllbTokenizer`. * @param {string|string[]} raw_inputs The text to tokenize. * @param {Object} tokenizer_options Options to be sent to the tokenizer * @param {Object} generate_kwargs Generation options. * @returns {Object} Object to be passed to the model. 
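 *
 * **Example:** A minimal sketch of how a translation pipeline might call this helper
 * (`tokenizer` is assumed to be a loaded `NllbTokenizer`; the FLORES-200 codes are illustrative):
 * ```javascript
 * const generate_kwargs = { src_lang: 'eng_Latn', tgt_lang: 'fra_Latn' };
 * const inputs = tokenizer._build_translation_inputs('Hello world', { padding: true }, generate_kwargs);
 * // `inputs` contains `input_ids`/`attention_mask`, and `generate_kwargs.forced_bos_token_id`
 * // is now set to the id of the target-language token.
 * ```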
*/ _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); } } /** * The M2M100Tokenizer class is used to tokenize text for M2M100 ("Many-to-Many") models. * * M2M100 is a multilingual encoder-decoder (seq-to-seq) model trained for Many-to-Many * multilingual translation. It was introduced in this [paper](https://huggingface.co/papers/2010.11125) * and first released in [this](https://github.com/pytorch/fairseq/tree/master/examples/m2m_100) repository. * * For a list of supported languages (along with their language codes), * @see {@link https://huggingface.co/facebook/m2m100_418M#languages-covered} */ export class M2M100Tokenizer extends PreTrainedTokenizer { constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.languageRegex = /^__[a-z]{2,3}__$/; this.language_codes = this.special_tokens .filter(x => this.languageRegex.test(x)) .map(x => x.slice(2, -2)); this.lang_to_token = x => `__${x}__`; } /** * Helper function to build translation inputs for an `M2M100Tokenizer`. * @param {string|string[]} raw_inputs The text to tokenize. * @param {Object} tokenizer_options Options to be sent to the tokenizer * @param {Object} generate_kwargs Generation options. * @returns {Object} Object to be passed to the model. */ _build_translation_inputs(raw_inputs, tokenizer_options, generate_kwargs) { return _build_translation_inputs(this, raw_inputs, tokenizer_options, generate_kwargs); } } /** * The Whisper tokenizer. * @extends PreTrainedTokenizer */ export class WhisperTokenizer extends PreTrainedTokenizer { get timestamp_begin() { return this.model.convert_tokens_to_ids(["<|notimestamps|>"])[0] + 1; } /** * Decodes automatic speech recognition (ASR) sequences. * @param {Array<{tokens: bigint[], token_timestamps?: number[], stride: number[]}>} sequences The sequences to decode. * @param {Object} options The options to use for decoding. * @returns {Array<string|{chunks?: Array<{language: string|null, timestamp: Array<number|null>, text: string}>}>} The decoded sequences. */ _decode_asr(sequences, { return_timestamps = false, return_language = false, time_precision = null, force_full_sequences = true } = {}) { // Set force_full_sequences=false if you want streaming // TODO add support for `return_language` // Internal method meant to only be used by the ASR pipeline. // Handles all the little quirks specific to Whisper and // the various options not allowed in other seq2seq models // =========== Overview ============ // - iterate over all outputs // - all tokens within output // - Each token can be // - language token // - special token // - timestamp token // - text token // - We accumulate the text tokens. // - We split on end timestamps // - Lots of complexity comes from stride and timestamps if (time_precision === null) { throw Error("Must specify time_precision") } let last_language = null; const returnWordTimestamps = return_timestamps === "word"; function new_chunk() { return { "language": last_language, "timestamp": [null, null], "text": "" }; } // Welcome to the state machine! const chunks = []; let chunk = new_chunk(); let time_offset = 0.0; const timestamp_begin = this.timestamp_begin; // Whisper timestamp tokens start from 0.00 and go to timestamp 30.00 in 0.02 increments. // We can calculate the last timestamp token as timestamp_begin plus the number of // timestamp tokens from 0.00 to 30.00, which is 1500.
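        // For example, assuming time_precision = 0.02, the token id `timestamp_begin + 150`
        // corresponds to 150 * 0.02 = 3.00 seconds.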
const total_timestamp_tokens = 1500; // (30.00 - 0.00) / 0.02 const timestamp_end = timestamp_begin + total_timestamp_tokens; let previous_tokens = []; let previous_token_timestamps = []; let skip = false; let right_stride_start = null; const all_special_ids = new Set(this.all_special_ids); for (const output of sequences) { // NOTE: python version has batches, so it uses [0] const token_ids = output.tokens; const token_timestamps = returnWordTimestamps ? output.token_timestamps : null; // These keep track of timestamps within strides, which need // to be skipped so we can resolve all tokens in a single chunk. let last_timestamp = null; let first_timestamp = timestamp_begin; if ("stride" in output) { const [chunk_len, stride_left, stride_right] = output.stride; // Offset the timings to account for the other `model_outputs`. time_offset -= stride_left; right_stride_start = chunk_len - stride_right; // Keeping track of timestamps within strides // We're going to NOT split on those, and delay until we're // out of BOTH strides. Otherwise lots of issues and // corner cases occur. if (stride_left) { first_timestamp = stride_left / time_precision + timestamp_begin; } if (stride_right) { for (let i = token_ids.length - 1; i >= 0; --i) { const token = Number(token_ids[i]); if (token >= timestamp_begin) { // There can be several tokens in the right stride // But the last one is ALWAYS going to be skipped if (last_timestamp !== null && (token - timestamp_begin) * time_precision < right_stride_start) { break; } last_timestamp = token; } } } } let current_tokens = []; let current_token_timestamps = []; // - all tokens within output for (let i = 0; i < token_ids.length; ++i) { const token = Number(token_ids[i]); // 4 possible states for each token // - 1/ Language code // - 2/ all other special tokens (which we ignore) // - 3/ Timestamp // - 4/ Regular text if (all_special_ids.has(token)) { const text = this.decode([token]); const language = WHISPER_LANGUAGE_MAPPING.get(text.slice(2, -2)); if (language !== undefined) { // 1/ Indeed some language // TODO Handle when language is different from the previous // one, and we cannot use timestamped tokens to create chunks if (last_language !== null && language !== last_language && !return_timestamps) { previous_tokens.push(current_tokens); const resolved_tokens = this.findLongestCommonSequence(previous_tokens)[0]; const resolved_text = this.decode(resolved_tokens); chunk.text = resolved_text; chunks.push(chunk); // Flush all our temporary context previous_tokens = []; current_tokens = []; chunk = new_chunk(); } last_language = chunk.language = language; } else { // 2/ This is a regular special token, ignoring it } } else if (token >= timestamp_begin && token <= timestamp_end) { // 3/ Timestamp token const time = (token - timestamp_begin) * time_precision + time_offset; const rounded_time = round(time, 2); if (last_timestamp !== null && token >= last_timestamp) { // Whisper outputted a timestamp token, but it falls within // our stride, so we're going to skip it for the time being // and resolve this later // Skip is necessary because timestamp tokens always come // in pairs, so we need to skip the next one too (which would mark the start of another chunk).
skip = true; } else if (skip || (previous_tokens.length > 0 && token < first_timestamp)) { skip = false; } else if (chunk.timestamp[0] === null) { chunk.timestamp[0] = rounded_time; } else { // This is the end of the timestamp chunk if (rounded_time === chunk.timestamp[0]) { // This is a bug in timestamp token output // where we're taking the duplicate token // as a stop where it should be a start. // This is an issue in the underlying model output. // Let's just skip it so it becomes de facto a start again } else { chunk.timestamp[1] = rounded_time; // Handling merges previous_tokens.push(current_tokens) if (returnWordTimestamps) { previous_token_timestamps.push(current_token_timestamps); } const [resolved_tokens, resolved_token_timestamps] = this.findLongestCommonSequence( previous_tokens, previous_token_timestamps ) const resolved_text = this.decode(resolved_tokens) chunk.text = resolved_text if (returnWordTimestamps) { chunk.words = this.collateWordTimestamps( resolved_tokens, resolved_token_timestamps, last_language, ) } chunks.push(chunk) // Flush all our temporary context previous_tokens = [] current_tokens = [] previous_token_timestamps = [] current_token_timestamps = [] chunk = new_chunk() } } } else { // 4/ Regular token // We just append to the list of all tokens so we can handle // merges later and decode into text. current_tokens.push(token) if (returnWordTimestamps) { let start_time = round(token_timestamps[i] + time_offset, 2); let end_time; if (i + 1 < token_timestamps.length) { end_time = round(token_timestamps[i + 1] + time_offset, 2); // Do not allow punctuation-only tokens to have a duration. // This prevents long pauses from messing up the timestamps. const decoded_text = this.decode([token]); if (PUNCTUATION_ONLY_REGEX.test(decoded_text)) { // Add `time_precision` to avoid overlapping timestamps end_time = round(Math.min(start_time + time_precision, end_time), 2); } } else { // should never happen end_time = null; } current_token_timestamps.push([start_time, end_time]); } } } if ('stride' in output) { const [chunk_len, stride_left, stride_right] = output.stride; time_offset += chunk_len - stride_right } // Leftover tokens if (current_tokens.length > 0) { previous_tokens.push(current_tokens) if (returnWordTimestamps) { previous_token_timestamps.push(current_token_timestamps); } } else if (previous_tokens.every(p => p.length === 0)) { // Flushing previous tokens (END) chunk = new_chunk() previous_tokens = [] current_tokens = [] previous_token_timestamps = []; current_token_timestamps = []; } } if (previous_tokens.length > 0) { if (force_full_sequences && return_timestamps) { // The last token should always be a timestamp, so there shouldn't be // any leftover throw new Error( "Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. " + "Also make sure WhisperTimeStampLogitsProcessor was used during generation."
); } // Happens when we don't use timestamps const [resolved_tokens, resolved_token_timestamps] = this.findLongestCommonSequence(previous_tokens, previous_token_timestamps); // Flushing previous tokens (FINAL) const resolved_text = this.decode(resolved_tokens); chunk.text = resolved_text; if (returnWordTimestamps) { chunk.words = this.collateWordTimestamps( resolved_tokens, resolved_token_timestamps, last_language, ) } chunks.push(chunk); } let optional = Object.create(null); // Preparing and cleaning up the pipeline output const full_text = chunks.map(chunk => chunk.text).join(''); if (return_timestamps || return_language) { for (let i = 0; i < chunks.length; ++i) { const chunk = chunks[i]; if (!return_timestamps) { delete chunk["timestamp"]; } if (!return_language) { delete chunk["language"]; } } if (returnWordTimestamps) { const new_chunks = []; for (const chunk of chunks) { for (const word of chunk.words) { new_chunks.push(word); } } optional = { "chunks": new_chunks }; } else { optional = { "chunks": chunks }; } } return [full_text, optional]; } /** * Finds the longest common sequence among the provided sequences. * @param {number[][]} sequences An array of sequences of token ids to compare. * @returns {number[][]} The longest common sequence found. * @throws {Error} If there is a bug within the function. * @private */ findLongestCommonSequence(sequences, token_timestamp_sequences = null) { // It would be much harder to do O(n) because of fault tolerance. // We actually have a really good property which is that the total sequence // MUST be those subsequences in order. // If token_timestamp_sequences is provided, will split those sequences in // exactly the same way. let leftSequence = sequences[0]; let leftLength = leftSequence.length; let totalSequence = []; const use_token_timestamp_sequences = Array.isArray(token_timestamp_sequences) && token_timestamp_sequences.length > 0; let total_token_timestamp_sequence = use_token_timestamp_sequences ? [] : null; let left_token_timestamp_sequence = use_token_timestamp_sequences ? token_timestamp_sequences[0] : null; for (let i = 1; i < sequences.length; ++i) { const rightSequence = sequences[i]; let max = 0.0; let maxIndices = [leftLength, leftLength, 0, 0]; // Here we're sliding matches // [a, b, c, d] // [c, d, f] // = [c] == [d] // [a, b, c, d] // [c, d, f] // = [c, d] == [c, d] // [a, b, c, d] // [c, d, f] // = [b, c, d] == [c, d, f] // [a, b, c, d] // [c, d, f] // [a, b, c] == [c, d, f] // [a, b, c, d] // [d, f] // [a, b] == [d, f] // [a, b, c, d] // [f] // [a] == [f] const rightLength = rightSequence.length; for (let j = 1; j < leftLength + rightLength; ++j) { // Slightly convoluted because we don't want out of bound indices // This will be necessary for a small conflict resolution optimization // later const leftStart = Math.max(0, leftLength - j); const leftStop = Math.min(leftLength, leftLength + rightLength - j); const left = leftSequence.slice(leftStart, leftStop); const rightStart = Math.max(0, j - leftLength); const rightStop = Math.min(rightLength, j); const right = rightSequence.slice(rightStart, rightStop); if (left.length !== right.length) { throw new Error("There is a bug within whisper `decode_asr` function, please report it. 
Dropping to prevent bad inference.");
                }

                let matches;
                if (use_token_timestamp_sequences) {
                    // Get length of longest subsequence of tokens that match
                    // and have timestamps that are in order
                    matches = left.filter((elem, idx) => (
                        elem === right[idx]
                        && left_token_timestamp_sequence[leftStart + idx] <= token_timestamp_sequences[i][rightStart + idx]
                    )).length;
                } else {
                    matches = left.filter((elem, idx) => elem === right[idx]).length;
                }

                // epsilon to favor long perfect matches
                const eps = j / 10000.0;
                const matching = matches / j + eps;

                if (matches > 1 && matching > max) {
                    max = matching;
                    maxIndices = [leftStart, leftStop, rightStart, rightStop];
                }
            }

            const [leftStart, leftStop, rightStart, rightStop] = maxIndices;
            const leftMid = Math.floor((leftStop + leftStart) / 2);
            const rightMid = Math.floor((rightStop + rightStart) / 2);
            totalSequence.push(...leftSequence.slice(0, leftMid));
            leftSequence = rightSequence.slice(rightMid);
            leftLength = leftSequence.length;

            if (use_token_timestamp_sequences) {
                total_token_timestamp_sequence.push(...left_token_timestamp_sequence.slice(0, leftMid));
                left_token_timestamp_sequence = token_timestamp_sequences[i].slice(rightMid);
            }
        }
        totalSequence.push(...leftSequence);

        if (use_token_timestamp_sequences) {
            total_token_timestamp_sequence.push(...left_token_timestamp_sequence);
            return [totalSequence, total_token_timestamp_sequence];
        } else {
            return [totalSequence, []];
        }
    }

    /** @private */
    collateWordTimestamps(tokens, token_timestamps, language) {
        const [words, _, token_indices] = this.combineTokensIntoWords(tokens, language);

        const timings = [];
        for (let i = 0; i < words.length; ++i) {
            const indices = token_indices[i];
            timings.push({
                text: words[i],
                timestamp: [
                    token_timestamps[indices.at(0)][0],
                    token_timestamps[indices.at(-1)][1],
                ],
            });
        }
        return timings;
    }

    /**
     * Groups tokens by word. Returns a tuple containing a list of strings with the words,
     * and a list of `token_id` sequences with the tokens making up each word.
     * @param {number[]} tokens
     * @param {string} [language]
     * @param {string} prepend_punctuations
     * @param {string} append_punctuations
     *
     * @private
     */
    combineTokensIntoWords(tokens, language, prepend_punctuations = "\"'“¡¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、") {
        language = language ?? 'english';

        let words, word_tokens, token_indices;
        if (["chinese", "japanese", "thai", "lao", "myanmar"].includes(language)) {
            // These languages don't typically use spaces.
            [words, word_tokens, token_indices] = this.splitTokensOnUnicode(tokens);
        } else {
            [words, word_tokens, token_indices] = this.splitTokensOnSpaces(tokens);
        }

        return this.mergePunctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations);
    }

    /** @type {PreTrainedTokenizer['decode']} */
    decode(
        token_ids,
        decode_args,
    ) {
        let text;
        // @ts-ignore
        if (decode_args?.decode_with_timestamps) {
            if (token_ids instanceof Tensor) {
                token_ids = prepareTensorForDecode(token_ids);
            }
            text = this.decodeWithTimestamps(token_ids, decode_args);
        } else {
            text = super.decode(token_ids, decode_args);
        }
        // TODO: implement offsets
        // if (decode_args.output_offsets) {
        //     let offsets = this.computeOffsets
        // }
        return text;
    }

    /**
     * @param {number[]|bigint[]} token_ids List of token IDs to decode.
     * @param {Object} decode_args Optional arguments for decoding.
     * @private
     */
    decodeWithTimestamps(token_ids, decode_args) {
        const time_precision = decode_args?.time_precision ??
0.02; const timestamp_begin = Array.from(this.all_special_ids).at(-1) + 1; /**@type {Array} */ let outputs = [[]]; for (let token of token_ids) { token = Number(token); if (token >= timestamp_begin) { const timestamp = ((token - timestamp_begin) * time_precision).toFixed(2); outputs.push(`<|${timestamp}|>`); outputs.push([]); } else { outputs[outputs.length - 1].push(token); } } outputs = outputs.map( s => typeof s === 'string' ? s : super.decode(s, decode_args) ) return outputs.join(''); } /** * Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points. * @param {number[]} tokens * @returns {*} * @private */ splitTokensOnUnicode(tokens) { const decoded_full = this.decode(tokens, { // @ts-ignore decode_with_timestamps: true, }); const replacement_char = '\uFFFD'; const words = [] const word_tokens = [] const token_indices = [] let current_tokens = [] let current_indices = [] let unicode_offset = 0 for (let token_idx = 0; token_idx < tokens.length; ++token_idx) { const token = tokens[token_idx]; current_tokens.push(token); current_indices.push(token_idx); const decoded = this.decode(current_tokens, { // @ts-ignore decode_with_timestamps: true, }); if (!decoded.includes(replacement_char) || decoded_full[unicode_offset + decoded.indexOf(replacement_char)] === replacement_char) { words.push(decoded) word_tokens.push(current_tokens) token_indices.push(current_indices) current_tokens = [] current_indices = [] unicode_offset += decoded.length; } } return [words, word_tokens, token_indices] } /** * Combine tokens into words by splitting at whitespace and punctuation tokens. * @param {number[]} tokens * @private */ splitTokensOnSpaces(tokens) { const [subwords, subword_tokens_list, subword_indices_list] = this.splitTokensOnUnicode(tokens); const words = [] const word_tokens = [] const token_indices = [] const punctuationRegex = new RegExp(`^[${PUNCTUATION_REGEX}]$`, 'gu'); for (let i = 0; i < subwords.length; ++i) { const subword = subwords[i]; const subword_tokens = subword_tokens_list[i]; const subword_indices = subword_indices_list[i]; // @ts-ignore const special = subword_tokens[0] >= this.model.tokens_to_ids.get('<|endoftext|>'); const with_space = subword.startsWith(' '); const trimmed = subword.trim(); const punctuation = punctuationRegex.test(trimmed); if (special || with_space || punctuation || words.length === 0) { words.push(subword); word_tokens.push(subword_tokens); token_indices.push(subword_indices); } else { const ix = words.length - 1; words[ix] += subword; word_tokens[ix].push(...subword_tokens); token_indices[ix].push(...subword_indices); } } return [words, word_tokens, token_indices]; } /** * Merges punctuation tokens with neighboring words. 
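     * For example (illustrative values, not taken from the library's tests): with the default
     * punctuation sets, the words [" “", "Hi", ",", " there", "!"] collapse into
     * [" “Hi,", " there!"], and the corresponding token-id and index arrays are merged
     * in exactly the same way.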
* @param {string[]} words * @param {number[][]} tokens * @param {number[][]} indices * @param {string} prepended * @param {string} appended * @private */ mergePunctuations(words, tokens, indices, prepended, appended) { const newWords = structuredClone(words); const newTokens = structuredClone(tokens); const newIndices = structuredClone(indices); // prepend punctuations let i = newWords.length - 2; let j = newWords.length - 1; while (i >= 0) { if (newWords[i].startsWith(' ') && prepended.includes(newWords[i].trim())) { newWords[j] = newWords[i] + newWords[j]; newTokens[j] = mergeArrays(newTokens[i], newTokens[j]); newIndices[j] = mergeArrays(newIndices[i], newIndices[j]); newWords[i] = ''; newTokens[i] = []; newIndices[i] = []; } else { j = i; } --i; } // append punctuations i = 0; j = 1; while (j < newWords.length) { if (!newWords[i].endsWith(' ') && appended.includes(newWords[j])) { newWords[i] += newWords[j]; newTokens[i] = mergeArrays(newTokens[i], newTokens[j]); newIndices[i] = mergeArrays(newIndices[i], newIndices[j]); newWords[j] = ''; newTokens[j] = []; newIndices[j] = []; } else { i = j; } ++j; } return [ newWords.filter(x => x), newTokens.filter(x => x.length > 0), newIndices.filter(x => x.length > 0), ] } } export class CodeGenTokenizer extends PreTrainedTokenizer { } export class CLIPTokenizer extends PreTrainedTokenizer { } export class SiglipTokenizer extends PreTrainedTokenizer { } /** * @todo This model is not yet supported by Hugging Face's "fast" tokenizers library (https://github.com/huggingface/tokenizers). * Therefore, this implementation (which is based on fast tokenizers) may produce slightly inaccurate results. */ export class MarianTokenizer extends PreTrainedTokenizer { /** * Create a new MarianTokenizer instance. * @param {Object} tokenizerJSON The JSON of the tokenizer. * @param {Object} tokenizerConfig The config of the tokenizer. */ constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); this.languageRegex = /^(>>\w+<<)\s*/g; this.supported_language_codes = this.model.vocab.filter( x => this.languageRegex.test(x) ); console.warn('WARNING: `MarianTokenizer` is not yet supported by Hugging Face\'s "fast" tokenizers library. Therefore, you may experience slightly inaccurate results.') } /** * Encodes a single text. Overriding this method is necessary since the language codes * must be removed before encoding with sentencepiece model. * @see https://github.com/huggingface/transformers/blob/12d51db243a00726a548a43cc333390ebae731e3/src/transformers/models/marian/tokenization_marian.py#L204-L213 * * @param {string|null} text The text to encode. * @returns {Array} The encoded tokens. */ _encode_text(text) { if (text === null) return null; // Check if text starts with language code: const [matchInfo, ...remainder] = text.trim().split(this.languageRegex); if (remainder.length === 0) { // No language code, encode normally return super._encode_text(matchInfo); } else if (remainder.length === 2) { // Text starts with language code, so we do not encode it with sentencepiece. const [language, text] = remainder; if (!this.supported_language_codes.includes(language)) { console.warn(`Unsupported language code "${language}" detected, which may lead to unexpected behavior. 
Should be one of: ${JSON.stringify(this.supported_language_codes)}`) } return mergeArrays([language], super._encode_text(text)); } } } export class Wav2Vec2CTCTokenizer extends PreTrainedTokenizer { } export class BlenderbotTokenizer extends PreTrainedTokenizer { } export class BlenderbotSmallTokenizer extends PreTrainedTokenizer { } export class SpeechT5Tokenizer extends PreTrainedTokenizer { } export class NougatTokenizer extends PreTrainedTokenizer { } export class VitsTokenizer extends PreTrainedTokenizer { constructor(tokenizerJSON, tokenizerConfig) { super(tokenizerJSON, tokenizerConfig); // Custom decoder function this.decoder = new VitsDecoder({}); } } export class CohereTokenizer extends PreTrainedTokenizer { } export class MgpstrTokenizer extends PreTrainedTokenizer { } export class Ernie4_5_Tokenizer extends PreTrainedTokenizer { } /** * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function. * The chosen tokenizer class is determined by the type specified in the tokenizer config. * * @example * const tokenizer = await AutoTokenizer.from_pretrained('Xenova/bert-base-uncased'); */ export class AutoTokenizer { static TOKENIZER_CLASS_MAPPING = { T5Tokenizer, DistilBertTokenizer, CamembertTokenizer, DebertaTokenizer, DebertaV2Tokenizer, BertTokenizer, HerbertTokenizer, ConvBertTokenizer, RoFormerTokenizer, XLMTokenizer, ElectraTokenizer, MobileBertTokenizer, SqueezeBertTokenizer, AlbertTokenizer, GPT2Tokenizer, BartTokenizer, MBartTokenizer, MBart50Tokenizer, RobertaTokenizer, WhisperTokenizer, CodeGenTokenizer, CLIPTokenizer, SiglipTokenizer, MarianTokenizer, BloomTokenizer, NllbTokenizer, M2M100Tokenizer, LlamaTokenizer, CodeLlamaTokenizer, XLMRobertaTokenizer, MPNetTokenizer, FalconTokenizer, GPTNeoXTokenizer, EsmTokenizer, Wav2Vec2CTCTokenizer, BlenderbotTokenizer, BlenderbotSmallTokenizer, SpeechT5Tokenizer, NougatTokenizer, VitsTokenizer, Qwen2Tokenizer, GemmaTokenizer, Grok1Tokenizer, CohereTokenizer, MgpstrTokenizer, Ernie4_5_Tokenizer, // Base case: PreTrainedTokenizer, } /** * Instantiate one of the tokenizer classes of the library from a pretrained model. * * The tokenizer class to instantiate is selected based on the `tokenizer_class` property of the config object * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) * * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: * - A string, the *model id* of a pretrained tokenizer hosted inside a model repo on huggingface.co. * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a * user or organization name, like `dbmdz/bert-base-german-cased`. * - A path to a *directory* containing tokenizer files, e.g., `./my_model_directory/`. * @param {PretrainedTokenizerOptions} options Additional options for loading the tokenizer. * * @returns {Promise} A new instance of the PreTrainedTokenizer class. */ static async from_pretrained(pretrained_model_name_or_path, { progress_callback = null, config = null, cache_dir = null, local_files_only = false, revision = 'main', legacy = null, } = {}) { const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { progress_callback, config, cache_dir, local_files_only, revision, legacy, }) // Some tokenizers are saved with the "Fast" suffix, so we remove that if present. const tokenizerName = tokenizerConfig.tokenizer_class?.replace(/Fast$/, '') ?? 
'PreTrainedTokenizer'; let cls = this.TOKENIZER_CLASS_MAPPING[tokenizerName]; if (!cls) { console.warn(`Unknown tokenizer class "${tokenizerName}", attempting to construct from base class.`); cls = PreTrainedTokenizer; } return new cls(tokenizerJSON, tokenizerConfig); } }
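/**
 * Example (sketch only; the model id and token values are illustrative): decode Whisper
 * generation output with inline timestamp markers via `WhisperTokenizer`, which
 * `AutoTokenizer` selects automatically from the tokenizer config.
 * ```javascript
 * import { AutoTokenizer } from '@huggingface/transformers';
 *
 * const tokenizer = await AutoTokenizer.from_pretrained('onnx-community/whisper-tiny.en');
 * // `token_ids` would come from a Whisper model's `generate()` call.
 * const text = tokenizer.decode(token_ids, { decode_with_timestamps: true, time_precision: 0.02 });
 * // e.g. "<|0.00|> Hello world.<|1.20|>"
 * ```
 */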