File size: 3,878 Bytes
e1a3c89
 
714c7da
3ea6625
 
 
 
 
 
 
e779f87
3ea6625
e779f87
3ea6625
 
 
 
 
 
 
 
 
 
 
 
 
296be2c
048bfac
 
 
 
 
 
 
 
 
 
 
 
 
 
241aff9
e1a3c89
296be2c
0726d27
3a27a81
e1a3c89
296be2c
e1a3c89
3ea6625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f3bed
e1a3c89
714c7da
e1a3c89
 
296be2c
3b0c247
714c7da
 
 
 
 
3b0c247
714c7da
e1a3c89
048bfac
296be2c
 
714c7da
 
 
7b5fcf7
 
 
3ea6625
 
 
 
 
 
 
3a27a81
3ea6625
3a27a81
3ea6625
578c052
 
 
 
 
 
 
 
 
 
3ea6625
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// OnDeviceService: uses Hugging Face transformers.js (formerly Xenova's) to run a small causal LM in the browser
// Imports transformers.js as an ES module from the jsDelivr CDN
import {pipeline} from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.8.0';


/**
 * On device llm inference service using transformers.js
 * TODO Implement this class!
 */
export class OnDeviceService {
    /**
     * Create an on-device LLM inference service backed by a transformers.js
     * `text-generation` pipeline. Call `load()` before `infer()`.
     *
     * @param {object} [options]
     * @param {string} [options.modelName=''] - Model id passed to transformers.js `pipeline()`.
     * @param {string} [options.quantization='fp32'] - Forwarded to `pipeline()` as its `dtype` option.
     */
    constructor({modelName = '', quantization = 'fp32'} = {}) {
        this.modelName = modelName;
        this.modelQuantization = quantization;
        this._ready = false;
        this._model = null;
    }


    /**
     * Load the configured model into memory so it is ready for inference.
     * Downloading and caching of model files are delegated to transformers.js `pipeline()`.
     *
     * The service is marked not-ready for the duration of the load, so a reload
     * (e.g. after `updateConfig`) never serves the previously loaded model under
     * the new name, and a failed load leaves the service not ready.
     *
     * @param {(progress: *) => void} [progressCb] - Receives loading progress events from
     *   `pipeline()`; when omitted, a default logger prints them to the console.
     * @returns {Promise<void>}
     * @throws Rethrows any error raised by `pipeline()`.
     */
    async load(progressCb) {
        console.log(`⬇️ Download Model '${this.modelName}'...`);
        this._ready = false;

        // Default progress callback used when the caller does not provide one.
        const defaultProgressCb = (progress) => {
            if (progress && typeof progress === 'object') {
                if (progress.status) {
                    console.log(`[Model Loading] ${progress.status}`);
                }
                // Numeric checks rather than truthiness so `loaded === 0` still reports 0.0%.
                if (typeof progress.loaded === 'number' && progress.total > 0) {
                    const percent = ((progress.loaded / progress.total) * 100).toFixed(1);
                    console.log(`[Model Loading] ${percent}% (${progress.loaded}/${progress.total} bytes)`);
                }
            } else {
                console.log(`[Model Loading] Progress:`, progress);
            }
        };

        this._model = await pipeline('text-generation', this.modelName, {
            progress_callback: progressCb || defaultProgressCb,
            // NOTE(review): requests the WebGPU backend; no fallback device is configured here.
            device: 'webgpu',
            dtype: this.modelQuantization, // model quantization / precision
        });
        console.log(`✅ Model '${this.modelName}' loaded and ready.`);
        this._ready = true;
    }


    /**
     * Returns whether the model is loaded and ready for inference.
     * @returns {boolean}
     */
    isReady() {
        return this._ready;
    }


    /**
     * Run a single-turn chat inference on the on-device model.
     *
     * @param {string} prompt - The user prompt.
     * @param {object} [options]
     * @param {number} [options.maxNewTokens=50] - Maximum number of new tokens to generate.
     * @returns {Promise<{answer: string, stats: {input_tokens: (number|undefined), output_tokens: (number|undefined)}}>}
     *   The trimmed model reply (empty string if none could be extracted) and token
     *   statistics, which are currently always `undefined`.
     * @throws {Error} If the model has not been loaded (or is being reloaded).
     */
    async infer(prompt, {maxNewTokens = 50} = {}) {
        if (!this._ready || !this._model) {
            console.log("model not ready:", this._ready, this._model);
            throw new Error('Model not loaded. Call load() first.');
        }
        console.log("🔄 Running inference on-device for prompt:\n", prompt);

        // Wrap the prompt as a single user chat message.
        const messages = [
            { role: "user", content: prompt },
        ];

        const output = await this._model(messages, {
            max_new_tokens: maxNewTokens,
            // NOTE(review): do_sample is not set; in transformers.js temperature may only
            // take effect when sampling is enabled — confirm the intended decoding mode.
            temperature: 0.2,
        });

        console.log("✅ Completed inference on-device for prompt:\n", prompt);

        // For chat input, generated_text holds the conversation and its last message is the
        // model's answer. Guard every step so an empty or unexpected output shape yields ''
        // instead of a TypeError.
        const generated = output?.[0]?.generated_text;
        let text = '';
        if (Array.isArray(generated)) {
            const reply = generated.at(-1)?.content;
            text = typeof reply === 'string' ? reply.trim() : '';
        } else if (typeof generated === 'string') {
            text = generated.trim();
        }

        // TODO: calculate input and output token counts
        return {answer: text, stats: {input_tokens: undefined, output_tokens: undefined}};
    }

    /**
     * Update configuration with new values. Falsy values are ignored.
     * Changes take effect on the next `load()` call; an already loaded model is not replaced.
     *
     * @param {object} [options]
     * @param {string} [options.modelName] - The name of the model to use.
     * @param {string} [options.quantization] - The quantization / dtype to load the model with.
     */
    updateConfig({modelName, quantization} = {}) {
        if (modelName) this.modelName = modelName;
        if (quantization) this.modelQuantization = quantization;
    }


    /**
     * Retrieve the name of the currently configured model.
     *
     * @returns {string} - The name of the model as a string.
     */
    getModelName(){
        return this.modelName;
    }
}