Spaces:
Runtime error
Runtime error
| import { pipeline, env } from "@huggingface/transformers"; | |
// Disable transformers.js local-model lookup so models are not searched for on
// the local server before being fetched.
// NOTE(review): "fetched remotely (Hugging Face Hub)" is the library's documented
// behaviour for this flag — confirm against the transformers.js `env` docs.
env.allowLocalModels = false;
/**
 * Report whether WebGPU is actually usable in this context.
 *
 * The presence of `navigator.gpu` is not sufficient: `requestAdapter()` resolves
 * to `null` when no suitable GPU adapter exists, so the adapter itself is checked.
 * Any thrown error (e.g. `navigator` not defined outside a browser/worker) is
 * treated as "unsupported".
 *
 * @returns {Promise<boolean>} true only if a GPU adapter was obtained.
 */
async function supportsWebGPU() {
  try {
    if (!navigator.gpu) return false;
    const adapter = await navigator.gpu.requestAdapter();
    // A null adapter means WebGPU is exposed but cannot be used on this machine.
    return Boolean(adapter);
  } catch {
    return false;
  }
}
// Choose the inference backend once, at module load (top-level await):
// WebGPU when supportsWebGPU() reports it usable, otherwise WASM.
const hasWebGPU = await supportsWebGPU();
const device = hasWebGPU ? "webgpu" : "wasm";
/**
 * Loads and caches transformers.js pipelines and runs inference requests one at a time.
 *
 * Requests are queued. When several arrive while one is running, only the most
 * recent one is processed and the older ones are discarded as stale.
 * Load progress, readiness, results and errors are reported to the main thread
 * through `self.postMessage`.
 */
class PipelineManager {
  // Model used for each task when a request does not name one.
  static defaultConfigs = {
    "text-classification": {
      model: "onnx-community/rubert-tiny-sentiment-balanced-ONNX",
    },
    "image-classification": {
      model: "onnx-community/mobilenet_v2_1.0_224",
    },
  };

  static instances = {}; // key: `${task}:${modelName}` -> loaded pipeline instance
  static currentTask = "text-classification";
  static currentModel = PipelineManager.defaultConfigs["text-classification"].model;
  static queue = []; // pending { input, task, modelName } requests
  static isProcessing = false;

  // key -> Promise of an in-flight load, so concurrent callers share one pipeline() call.
  static #loading = new Map();

  /**
   * Return the pipeline for `task` + `modelName`, creating it on first use.
   *
   * Concurrent calls for the same key share a single load instead of each
   * calling `pipeline()`. A failed load is not cached, so a later call retries.
   *
   * @param {string} task - pipeline task name, e.g. "text-classification".
   * @param {string} modelName - model id passed to `pipeline()`.
   * @param {Function|null} [progress_callback] - forwarded to `pipeline()`. Only the
   *   call that starts a load has its callback attached.
   * @returns {Promise<*>} the loaded pipeline instance.
   */
  static async getInstance(task, modelName, progress_callback = null) {
    const key = `${task}:${modelName}`;
    if (this.instances[key]) return this.instances[key];

    let pending = this.#loading.get(key);
    if (!pending) {
      pending = this.#load(key, task, modelName, progress_callback);
      this.#loading.set(key, pending);
      // Drop the in-flight entry once settled. Handling both outcomes here also
      // keeps this bookkeeping promise from becoming an unhandled rejection.
      const forget = () => this.#loading.delete(key);
      pending.then(forget, forget);
    }
    return pending;
  }

  // Create one pipeline, cache it under `key`, and announce initiate/ready.
  static async #load(key, task, modelName, progress_callback) {
    self.postMessage({ status: "initiate", file: modelName, task });
    const instance = await pipeline(task, modelName, { progress_callback, device });
    this.instances[key] = instance;
    self.postMessage({ status: "ready", file: modelName, task });
    return instance;
  }

  /**
   * Drain the request queue, running one request at a time.
   * Re-entrant calls while a request is running return immediately. The running
   * loop picks up anything queued in the meantime.
   */
  static async processQueue() {
    if (this.isProcessing) return;
    this.isProcessing = true;
    try {
      while (this.queue.length > 0) {
        // Latest request wins: everything queued before it is superseded.
        const request = this.queue.at(-1);
        this.queue = [];
        await this.#run(request);
      }
    } finally {
      // Always release the lock, even if reporting a result throws.
      this.isProcessing = false;
    }
  }

  // Run a single request and post either a "complete" or an "error" message.
  static async #run({ input, task, modelName }) {
    try {
      const classifier = await this.getInstance(task, modelName, (x) => {
        self.postMessage({
          ...x,
          status: x.status || "progress",
          file: x.file || modelName,
          name: modelName,
          task,
          loaded: x.loaded,
          total: x.total,
          progress: x.loaded && x.total ? (x.loaded / x.total) * 100 : 0,
        });
      });

      // Speech recognition is called without options. Every other task
      // (image input may be a data URL or Blob) asks for the top 5 labels.
      const output =
        task === "automatic-speech-recognition"
          ? await classifier(input)
          : await classifier(input, { top_k: 5 });

      self.postMessage({
        status: "complete",
        output,
        file: modelName,
        task,
      });
    } catch (error) {
      self.postMessage({
        status: "error",
        // Non-Error throw values have no `.message`; fall back to their string form.
        error: error?.message ?? String(error),
        file: modelName,
        task,
      });
    }
  }
}
// Listen for messages from the main thread.
// - action "load-model": select `task`/`modelName` (falling back to the task's
//   default model) and preload that pipeline, reporting progress or an error.
// - anything else: queue an inference request for `input`, defaulting to the
//   currently selected task/model, and kick the queue.
self.addEventListener("message", async (event) => {
  const { input, modelName, task, action } = event.data;

  if (action === "load-model") {
    const loadTask = task || "text-classification";
    const loadModel = modelName || PipelineManager.defaultConfigs[loadTask]?.model;

    // Validate before touching the current selection, so a bad request leaves
    // currentTask/currentModel as a consistent pair.
    if (!loadModel) {
      self.postMessage({
        status: "error",
        error: `No model specified and no default configured for task "${loadTask}"`,
        file: loadModel,
        task: loadTask,
      });
      return;
    }

    PipelineManager.currentTask = loadTask;
    PipelineManager.currentModel = loadModel;

    try {
      await PipelineManager.getInstance(loadTask, loadModel, (x) => {
        // Use the captured task/model. The current* fields may change if
        // another load-model message arrives while this download is running.
        self.postMessage({
          ...x,
          file: loadModel,
          status: x.status || "progress",
          loaded: x.loaded,
          total: x.total,
          task: loadTask,
        });
      });
    } catch (error) {
      // Without this, a failed preload would be an unhandled rejection and the
      // main thread would never learn about it.
      self.postMessage({
        status: "error",
        error: error?.message ?? String(error),
        file: loadModel,
        task: loadTask,
      });
    }
    return;
  }

  PipelineManager.queue.push({
    input,
    task: task || PipelineManager.currentTask,
    modelName: modelName || PipelineManager.currentModel,
  });
  PipelineManager.processQueue();
});