File size: 4,671 Bytes
ca97aa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157

/**
 * @module generation/stopping_criteria
 */

import { Callable } from "../utils/generic.js";

// NOTE:
// Every stopping criterion returns a list of `batch_size` booleans, indicating whether each sequence in the batch should be stopped.

/**
 * Abstract base class for all stopping criteria that can be applied during generation.
 */
export class StoppingCriteria extends Callable {
    /**
     * Decides, for every sequence in the batch, whether generation should stop.
     * The base implementation only throws; concrete criteria must override it.
     *
     * @param {number[][]} input_ids (`number[][]` of shape `(batch_size, sequence_length)`):
     * Indices of input sequence tokens in the vocabulary.
     * @param {number[][]} scores (`number[][]` of shape `(batch_size, config.vocab_size)`):
     * Prediction scores of a language modeling head, given for each vocabulary token
     * either before or after SoftMax.
     * @returns {boolean[]} One boolean per sequence; `true` means that sequence should be stopped.
     * @throws {Error} Always, when the method has not been overridden by a subclass.
     */
    _call(input_ids, scores) {
        throw new Error("StoppingCriteria needs to be subclassed");
    }
}
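/**
 * A list of stopping criteria that is evaluated as one: a sequence is reported as done
 * as soon as any criterion in the list reports it as done.
 */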
export class StoppingCriteriaList extends Callable {
    /**
     * Creates a list that initially holds no stopping criteria.
     */
    constructor() {
        super();
        // Criteria are kept in insertion order; `_call` invokes each entry as a function.
        this.criteria = [];
    }

    /**
     * Appends a single stopping criterion to the end of the list.
     *
     * @param {StoppingCriteria} item The criterion to append.
     */
    push(item) {
        this.criteria.push(item);
    }

    /**
     * Appends several stopping criteria to the end of the list.
     *
     * @param {StoppingCriteria|StoppingCriteriaList|StoppingCriteria[]} items Either a single criterion,
     * another list (whose criteria are copied over), or any iterable of criteria.
     */
    extend(items) {
        let additions = items;
        if (items instanceof StoppingCriteriaList) {
            // Take the other list's members rather than nesting the list itself.
            additions = items.criteria;
        } else if (items instanceof StoppingCriteria) {
            additions = [items];
        }
        this.criteria.push(...additions);
    }

    /**
     * Evaluates every criterion in the list and combines their per-sequence results with a logical OR.
     *
     * @param {number[][]} input_ids Batch of token id sequences; its length determines the size of the result.
     * @param {number[][]} scores Prediction scores, passed through unchanged to every criterion.
     * @returns {boolean[]} One entry per sequence, truthy once any criterion has flagged that sequence.
     */
    _call(input_ids, scores) {
        const stopped = new Array(input_ids.length).fill(false);
        const batch_size = stopped.length;
        for (const criterion of this.criteria) {
            const verdicts = criterion(input_ids, scores);
            for (let idx = 0; idx < batch_size; idx += 1) {
                // Once a sequence is flagged it stays flagged; later verdicts are not consulted for it.
                if (!stopped[idx]) {
                    stopped[idx] = verdicts[idx];
                }
            }
        }
        return stopped;
    }

    /**
     * Iterates over the stored criteria in insertion order.
     */
    [Symbol.iterator]() {
        return this.criteria[Symbol.iterator]();
    }
}

/**
 * This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`.
 * Keep in mind for decoder-only type of transformers, this will include the initial prompted tokens.
 */
export class MaxLengthCriteria extends StoppingCriteria {

    /**
     * Creates a criterion that stops sequences once they reach `max_length` tokens.
     *
     * @param {number} max_length The maximum length that the output sequence can have in number of tokens.
     * @param {number} [max_position_embeddings=null] The maximum model length, as defined by the model's
     * `config.max_position_embeddings` attribute. It is stored on the instance but not consulted by `_call`.
     */
    constructor(max_length, max_position_embeddings = null) {
        super();
        this.max_length = max_length;
        this.max_position_embeddings = max_position_embeddings;
    }

    /**
     * Flags every sequence whose current length has reached `max_length`.
     *
     * @param {number[][]} input_ids Batch of token id sequences.
     * @returns {boolean[]} For each sequence, `true` when its length is at least `max_length`.
     */
    _call(input_ids) {
        return input_ids.map(({ length }) => length >= this.max_length);
    }
}

// TODO: add MaxTimeCriteria

/**
 * This class can be used to stop generation whenever the "end-of-sequence" token is generated.
 * By default, it uses the `model.generation_config.eos_token_id`.
 */
export class EosTokenCriteria extends StoppingCriteria {

    /**
     * Creates a criterion that stops a sequence when its last token is an *end-of-sequence* token.
     *
     * @param {number|number[]} eos_token_id The id of the *end-of-sequence* token.
     * Optionally, use a list to set multiple *end-of-sequence* tokens.
     * A single id is stored wrapped in a one-element array.
     */
    constructor(eos_token_id) {
        super();
        this.eos_token_id = Array.isArray(eos_token_id) ? eos_token_id : [eos_token_id];
    }

    /**
     * Flags every sequence whose final token equals one of the configured end-of-sequence ids.
     *
     * @param {number[][]} input_ids Batch of token id sequences.
     * @param {number[][]} scores Not used by this criterion.
     * @returns {boolean[]} For each sequence, `true` when its last token is an end-of-sequence token.
     */
    _call(input_ids, scores) {
        // NOTE: Loose equality (==) is deliberate so that number and bigint ids compare equal.
        const isEos = (token) => this.eos_token_id.some((candidate) => token == candidate);
        return input_ids.map((sequence) => isEos(sequence.at(-1)));
    }
}

/**
 * This class can be used to stop generation whenever the user interrupts the process.
 */
export class InterruptableStoppingCriteria extends StoppingCriteria {
    /**
     * Creates a criterion that starts out in the non-interrupted state.
     */
    constructor() {
        super();
        this.interrupted = false;
    }

    /**
     * Marks the criterion as interrupted; subsequent calls report every sequence as done.
     */
    interrupt() {
        this.interrupted = true;
    }

    /**
     * Clears the interrupted flag so that subsequent calls no longer request a stop.
     */
    reset() {
        this.interrupted = false;
    }

    /**
     * Reports the current interrupted flag for every sequence in the batch.
     *
     * @param {number[][]} input_ids Batch of token id sequences; only its length is used.
     * @param {number[][]} scores Not used by this criterion.
     * @returns {boolean[]} An array with one entry per sequence, each equal to the interrupted flag.
     */
    _call(input_ids, scores) {
        const decisions = new Array(input_ids.length);
        return decisions.fill(this.interrupted);
    }
}