File size: 3,399 Bytes
d3c7f96
 
 
 
 
 
 
 
45f314a
d3c7f96
 
 
 
0c2fc21
 
 
 
 
 
 
 
 
 
 
 
d3c7f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45f314a
d3c7f96
 
 
 
 
 
 
 
 
 
0c2fc21
d3c7f96
 
 
45f314a
 
 
 
 
 
 
 
 
d3c7f96
 
 
 
 
 
 
 
 
0c2fc21
d3c7f96
 
0c2fc21
 
d3c7f96
 
 
 
 
 
 
 
 
 
 
0c2fc21
d3c7f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// web/src/worker.js

import {
  AutoProcessor,
  Gemma4ForConditionalGeneration,
  TextStreamer,
  InterruptableStoppingCriteria,
  load_image,
  RawImage,
} from "@huggingface/transformers";

const MODEL_ID = "onnx-community/gemma-4-E2B-it-ONNX";

const THINK_START = "‹‹THINK››";
const THINK_END = "‹‹/THINK››";

function cleanGemmaOutput(raw) {
  return raw
    .replace(/<\|?channel\|?>?\s*thought\s*/gi, THINK_START)
    .replace(/<\|?channell?\|?>/gi, THINK_END)
    .replace(/<\|?[a-z_]+\|?>/gi, "")
    .trim();
}


let processor = null;
let model = null;

const stoppingCriteria = new InterruptableStoppingCriteria();

async function checkWebGPU() {
  try {
    const adapter = await navigator.gpu?.requestAdapter();
    self.postMessage({
      type: "status",
      status: adapter ? "webgpu-available" : "webgpu-unavailable",
    });
  } catch {
    self.postMessage({ type: "status", status: "webgpu-unavailable" });
  }
}

async function loadModel() {
  try {
    self.postMessage({ type: "status", status: "loading" });

    const progress_callback = (p) => self.postMessage({ type: "progress", ...p });

    processor = await AutoProcessor.from_pretrained(MODEL_ID, { progress_callback });

    model = await Gemma4ForConditionalGeneration.from_pretrained(MODEL_ID, {
      dtype: "q4f16",
      device: "webgpu",
      progress_callback,
    });

    self.postMessage({ type: "status", status: "ready" });
  } catch (err) {
    self.postMessage({ type: "error", message: err.message });
  }
}

async function generate({ messages, imageUrl, videoData, audioData, enableThinking }) {
  if (!model || !processor) {
    self.postMessage({ type: "error", message: "Model not loaded" });
    return;
  }

  try {
    self.postMessage({ type: "status", status: "generating" });
    stoppingCriteria.reset();

    const prompt = processor.apply_chat_template(messages, {
      enable_thinking: enableThinking,
      add_generation_prompt: true,
    });

    // Gemma4ImageProcessor expects RawImage | RawImage[], not RawVideo
    let image = null;
    if (videoData) {
      image = videoData.frames.map((f) =>
        new RawImage(new Uint8ClampedArray(f.data), f.width, f.height, f.channels)
      );
    } else if (imageUrl) {
      image = await load_image(imageUrl);
    }
    const audio = audioData ?? null;

    const inputs = await processor(prompt, image, audio, {
      add_special_tokens: false,
    });

    let fullText = "";
    const streamer = new TextStreamer(processor.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: false,
      callback_function: (text) => {
        fullText += text;
        const cleaned = cleanGemmaOutput(fullText);
        self.postMessage({ type: "update", text: cleaned });
      },
    });

    await model.generate({
      ...inputs,
      max_new_tokens: 512,
      do_sample: false,
      streamer,
      stopping_criteria: [stoppingCriteria],
    });

    self.postMessage({ type: "complete", text: cleanGemmaOutput(fullText) });
  } catch (err) {
    self.postMessage({ type: "error", message: err.message });
  }
}

self.onmessage = (e) => {
  switch (e.data.type) {
    case "check":
      checkWebGPU();
      break;
    case "load":
      loadModel();
      break;
    case "generate":
      generate(e.data);
      break;
    case "interrupt":
      stoppingCriteria.interrupt();
      break;
  }
};