|
|
|
|
|
|
|
|
import { pipeline } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.7.0/dist/transformers.min.js'; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Wrapper around a transformers.js `text-generation` pipeline that runs the
 * model locally ("on-device").
 *
 * Lifecycle: construct -> `load()` -> `infer()`. Switching models with
 * `updateConfig()` discards the loaded pipeline, so `load()` must be called
 * again before the next `infer()`.
 */
export class OnDeviceService {
  /**
   * @param {object} [options]
   * @param {string} [options.modelName='Xenova/distilgpt2'] - Model id passed to `pipeline()`.
   */
  constructor({modelName = 'Xenova/distilgpt2'} = {}) {
    this.modelName = modelName;
    this._ready = false;
    this._model = null;
    // In-flight load as `{modelName, promise}`; lets concurrent load() calls
    // for the same model share a single download.
    this._loading = null;
  }

  /**
   * Default progress reporter: logs the status string and, when byte counts
   * are present, a percentage.
   *
   * @param {*} progress - Progress payload emitted by transformers.js.
   */
  static _logProgress(progress) {
    if (!progress || typeof progress !== 'object') {
      console.log(`[Model Loading] Progress:`, progress);
      return;
    }

    if (progress.status) {
      console.log(`[Model Loading] ${progress.status}`);
    }

    const {loaded, total} = progress;
    // Explicit numeric checks: `loaded === 0` still reports 0.0%, and a zero
    // `total` is skipped instead of dividing by zero.
    if (typeof loaded === 'number' && typeof total === 'number' && total > 0) {
      const percent = ((loaded / total) * 100).toFixed(1);
      console.log(`[Model Loading] ${percent}% (${loaded}/${total} bytes)`);
    }
  }

  /**
   * Download and initialise the pipeline for the current `modelName`.
   *
   * Concurrent calls for the same model share one download; in that case only
   * the first caller's `progressCb` is used.
   *
   * @param {(progress: *) => void} [progressCb] - Forwarded to `pipeline()` as
   *   `progress_callback`; defaults to console logging.
   * @returns {Promise<void>} Resolves when the download finishes. If
   *   `updateConfig()` switched to a different model meanwhile, the downloaded
   *   pipeline is discarded and `isReady()` stays false.
   * @throws Whatever `pipeline()` rejects with (e.g. network or model errors).
   */
  async load(progressCb) {
    const modelName = this.modelName;

    if (this._loading && this._loading.modelName === modelName) {
      return this._loading.promise;
    }

    const promise = this._downloadAndPublish(modelName, progressCb);
    this._loading = {modelName, promise};
    try {
      await promise;
    } finally {
      // Only clear our own entry; a newer load for another model may have replaced it.
      if (this._loading && this._loading.promise === promise) {
        this._loading = null;
      }
    }
  }

  /**
   * Run `pipeline()` for `modelName` and, if that is still the configured
   * model, store it as the active pipeline.
   *
   * @param {string} modelName - Model id captured when the load started.
   * @param {(progress: *) => void} [progressCb] - Optional progress callback.
   * @returns {Promise<void>}
   */
  async _downloadAndPublish(modelName, progressCb) {
    console.log("Downloading model:", modelName);

    const model = await pipeline('text-generation', modelName, {
      progress_callback: progressCb || OnDeviceService._logProgress,
    });

    if (this.modelName !== modelName) {
      // updateConfig() switched models while we were downloading; publishing
      // this pipeline would make isReady() lie about which model is loaded.
      console.warn(`Discarding ${modelName}: model changed to ${this.modelName} during load.`);
      return;
    }

    this._model = model;
    console.log("Model loaded and ready.");
    this._ready = true;
  }

  /**
   * @returns {boolean} True once `load()` has completed for the current model.
   */
  isReady() {
    return this._ready;
  }

  /**
   * Generate text for `prompt` with the loaded pipeline.
   *
   * @param {string} prompt - Input text for the model.
   * @param {object} [options]
   * @param {number} [options.maxNewTokens=100] - Maximum number of tokens to generate.
   *   Any other keys are passed through as generation options (snake_case, e.g.
   *   `do_sample`, `top_k`) and override the defaults below.
   * @returns {Promise<{answer: string, stats: {input_tokens: undefined, output_tokens: undefined}}>}
   *   `answer` is the trimmed `generated_text` of the first returned sequence,
   *   or '' when absent. Token counts are not computed.
   *   NOTE(review): transformers.js appears to include the prompt in
   *   `generated_text`; confirm whether callers expect that.
   * @throws {Error} If `load()` has not completed for the current model.
   */
  async infer(prompt, {maxNewTokens = 100, ...generationOptions} = {}) {
    if (!this._ready || !this._model) {
      console.log("model not ready:", this._ready, this._model);
      throw new Error('Model not loaded. Call load() first.');
    }

    console.log("running inference on-device:\n", prompt);

    const output = await this._model(prompt, {
      // Defaults kept from the original configuration. Dividing logits by a
      // positive temperature never changes the arg-max token, so `temperature`
      // only matters when sampling is enabled (e.g. `do_sample: true`).
      temperature: 1.5,
      repetition_penalty: 1.5,
      no_repeat_ngram_size: 2,
      num_beams: 1,
      num_return_sequences: 1,
      // Caller-supplied generation options override the defaults above...
      ...generationOptions,
      // ...but the documented `maxNewTokens` option always decides the length.
      max_new_tokens: maxNewTokens,
    });

    const text = output[0]?.generated_text?.trim() || '';

    return {answer: text, stats: {input_tokens: undefined, output_tokens: undefined}};
  }

  /**
   * Update runtime configuration.
   *
   * Switching to a different `modelName` drops the loaded pipeline, so
   * `isReady()` becomes false until `load()` is called again. Passing the same
   * name (or no name) is a no-op.
   *
   * @param {object} [config]
   * @param {string} [config.modelName] - New model id.
   */
  updateConfig({modelName} = {}) {
    if (!modelName || modelName === this.modelName) return;
    this.modelName = modelName;
    this._model = null;
    this._ready = false;
  }
}