File size: 4,942 Bytes
ca97aa9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import { Processor } from "../../base/processing_utils.js";
import { AutoImageProcessor } from "../auto/image_processing_auto.js";
import { AutoTokenizer } from "../../tokenizers.js";

/**
 * Processor that pairs an image processor with a tokenizer, expands task tokens
 * into full text prompts (using tables read from the image processor config),
 * and parses generated text back into structured, task-specific outputs.
 */
export class Florence2Processor extends Processor {
    static tokenizer_class = AutoTokenizer;
    static image_processor_class = AutoImageProcessor;

    /**
     * Reads the task tables from `this.image_processor.config` and prepares the
     * regular expressions used when parsing location tokens.
     * @param {Object} config Forwarded unchanged to the base `Processor`.
     * @param {Object} components Forwarded unchanged to the base `Processor`.
     * @param {*} chat_template Forwarded unchanged to the base `Processor`.
     */
    constructor(config, components, chat_template) {
        super(config, components, chat_template);

        const {
            // @ts-expect-error TS2339
            tasks_answer_post_processing_type,
            // @ts-expect-error TS2339
            task_prompts_without_inputs,
            // @ts-expect-error TS2339
            task_prompts_with_input,
        } = this.image_processor.config;

        // Each table falls back to an empty map when the config omits it.

        /** @type {Map<string, string>} Task -> post-processing type. */
        this.tasks_answer_post_processing_type = new Map(Object.entries(tasks_answer_post_processing_type ?? {}));

        /** @type {Map<string, string>} Task -> fixed prompt (used verbatim). */
        this.task_prompts_without_inputs = new Map(Object.entries(task_prompts_without_inputs ?? {}));

        /** @type {Map<string, string>} Task -> prompt template containing `{input}`. */
        this.task_prompts_with_input = new Map(Object.entries(task_prompts_with_input ?? {}));

        this.regexes = {
            // Label text followed by eight consecutive `<loc_N>` tokens.
            quad_boxes: /(.+?)<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>/gm,
            // Optional label text followed by four consecutive `<loc_N>` tokens.
            bboxes: /([^<]+)?<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>/gm,
        };

        // Location token values are divided by this before being scaled by the image size.
        this.size_per_bin = 1000;
    }

    /**
     * Helper function to construct prompts from input texts.
     *
     * Each input is resolved independently, in this order:
     *  1. exact match against a fixed task prompt (no additional input),
     *  2. first task-with-input whose token occurs in the text: the template's
     *     `{input}` is filled with the text and the task token is stripped,
     *  3. otherwise the text is used as-is.
     *
     * The result always has exactly one prompt per input, in the same order.
     * @param {string|string[]} text A single text or a batch of texts.
     * @returns {string[]} One prompt per input text.
     */
    construct_prompts(text) {
        const texts = typeof text === 'string' ? [text] : text;

        const prompts = [];
        for (const t of texts) {
            // 1. fixed task prompts without additional inputs
            if (this.task_prompts_without_inputs.has(t)) {
                prompts.push(this.task_prompts_without_inputs.get(t));
                continue;
            }

            // 2. task prompts with additional inputs
            // The decision to fall back is made per item; comparing the running
            // output length with the batch size would push a second entry for
            // every matched item that is not the last one in the batch.
            let resolved = t; // 3. default prompt
            for (const [task, prompt] of this.task_prompts_with_input) {
                if (t.includes(task)) {
                    // A function replacer inserts `t` literally, so `$&`, `$$`, etc.
                    // in user text are not treated as replacement patterns.
                    resolved = prompt.replaceAll('{input}', () => t).replaceAll(task, '');
                    break;
                }
            }
            prompts.push(resolved);
        }
        return prompts;
    }

    /**
     * Post-process the output of the model to each of the task outputs.
     * @param {string} text The text to post-process.
     * @param {string} task The task to post-process the text for. Tasks without a
     * registered post-processing type are treated as `'pure_text'`.
     * @param {[number, number]} image_size Image dimensions used to scale location
     * values: element 0 scales the 1st, 3rd, ... value of each box and element 1
     * scales the 2nd, 4th, ... value.
     * NOTE(review): previously documented as "height x width"; if location tokens
     * are emitted x-before-y this should be [width, height] - confirm against callers.
     * @returns {Object} An object with a single key, `task`, whose value is either
     * the cleaned text or `{ labels, bboxes }` / `{ labels, quad_boxes }`.
     * @throws {Error} If the task's post-processing type is not implemented.
     */
    post_process_generation(text, task, image_size) {
        const task_answer_post_processing_type = this.tasks_answer_post_processing_type.get(task) ?? 'pure_text';

        // remove the special tokens
        text = text.replaceAll('<s>', '').replaceAll('</s>', '');

        let final_answer;
        switch (task_answer_post_processing_type) {
            case 'pure_text':
                final_answer = text;
                break;

            case 'description_with_bboxes':
            case 'bboxes':
            case 'phrase_grounding':
            case 'ocr': {
                // Braces keep these declarations scoped to this clause.
                const key = task_answer_post_processing_type === 'ocr' ? 'quad_boxes' : 'bboxes';
                // `matchAll` works on a copy of the regex, so the shared `/g`
                // pattern's `lastIndex` is not disturbed between calls.
                const matches = text.matchAll(this.regexes[key]);
                const labels = [];
                const items = [];
                for (const [_, label, ...locations] of matches) {
                    // Push new label, or duplicate the last label
                    labels.push(label ? label.trim() : (labels.at(-1) ?? ''));
                    items.push(locations.map((x, i) =>
                        // NOTE: Add 0.5 to use the center position of the bin as the coordinate.
                        (Number(x) + 0.5) / this.size_per_bin * image_size[i % 2])
                    );
                }
                final_answer = { labels, [key]: items };
                break;
            }

            default:
                throw new Error(`Task "${task}" (of type "${task_answer_post_processing_type}") not yet implemented.`);
        }

        return { [task]: final_answer };
    }

    // NOTE: images and text are switched from the python version
    // `images` is required, `text` is optional
    /**
     * Runs the image processor on `images` and, when `text` is given, tokenizes
     * the prompts constructed from it.
     * @param {*} images Input passed to the image processor.
     * @param {string|string[]|null} [text=null] Task token(s) or free-form text.
     * @param {Object} [kwargs={}] Forwarded to both the image processor and the tokenizer.
     * @returns {Promise<Object>} Image processor outputs merged with tokenizer
     * outputs (tokenizer keys win on collision).
     * @throws {Error} If neither `images` nor `text` is provided.
     */
    async _call(images, text = null, kwargs = {}) {

        if (!images && !text) {
            throw new Error('Either text or images must be provided');
        }

        const image_inputs = await this.image_processor(images, kwargs);
        const text_inputs = text ? this.tokenizer(this.construct_prompts(text), kwargs) : {};

        return {
            ...image_inputs,
            ...text_inputs,
        };
    }
}