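// worker.js — Web Worker that runs an in-browser text-generation pipeline on
// WebGPU via @huggingface/transformers (transformers.js). The main thread
// drives it with { type, data } messages ("check", "load", "generate",
// "interrupt", "reset"); the worker replies with { status, ... } updates.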
import {
  pipeline,
  TextStreamer,
  DynamicCache,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";

// Registry mapping UI-facing model keys to ONNX repos on the Hugging Face Hub.
const MODEL_IDS = {
  "1.7b": "onnx-community/Bonsai-1.7B-ONNX",
};

// Feature-detect WebGPU: request an adapter and report any failure to the UI.
async function check() {
  try {
    const adapter = await navigator.gpu?.requestAdapter();
    if (!adapter) throw new Error("WebGPU is not supported (no adapter found)");
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}

// Lazily constructs one pipeline per model key and caches the resulting
// promise, so concurrent callers share a single download/initialization.
class TextGenerationPipeline {
  static instances = new Map();

  static getInstance(modelKey, progress_callback = null) {
    const modelId = MODEL_IDS[modelKey];
    if (!modelId) throw new Error(`Unknown model: ${modelKey}`);
    if (!this.instances.has(modelKey)) {
      // The pending promise itself is cached (not awaited here), which
      // deduplicates loads; every caller awaits the same shared promise.
      this.instances.set(
        modelKey,
        pipeline("text-generation", modelId, {
          device: "webgpu",
          dtype: "q1", // quantized weight variant to fetch from the repo
          progress_callback,
        }),
      );
    }
    return this.instances.get(modelKey);
  }
}

const stopping_criteria = new InterruptableStoppingCriteria();

// KV cache reused across turns so earlier tokens are not re-processed.
let past_key_values_cache = null;
let current_model_key = null;

function disposePastKeyValues() {
  // dispose() frees GPU buffers if the cache implementation provides it.
  past_key_values_cache?.dispose?.();
  past_key_values_cache = null;
}

async function load(modelKey) {
  // Switching models invalidates the cached KV state from the previous one.
  if (current_model_key && current_model_key !== modelKey) {
    disposePastKeyValues();
  }
  current_model_key = modelKey;

  self.postMessage({ status: "loading", data: "Loading model..." });

  const generator = await TextGenerationPipeline.getInstance(
    modelKey,
    (info) => {
      // Forward aggregate download progress events to the UI's progress bar.
      if (info.status === "progress_total") {
        self.postMessage({
          status: "progress_total",
          progress: Number(info.progress ?? 0),
          loaded: Number(info.loaded ?? 0),
          total: Number(info.total ?? 0),
        });
      }
    },
  );

  self.postMessage({
    status: "loading",
    data: "Optimizing model for 1-bit execution",
  });

  // Warm-up run: generate a single token to trigger shader compilation,
  // so the first real request does not pay that cost.
  const inputs = generator.tokenizer("a");
  await generator.model.generate({ ...inputs, max_new_tokens: 1 });

  self.postMessage({ status: "ready" });
}

async function generate(messages) {
  // Guard: a "generate" before any "load" would otherwise reject outside
  // the try/catch below and never reach the UI as an error message.
  if (!current_model_key) {
    self.postMessage({
      status: "error",
      data: "No model loaded. Send a 'load' message first.",
    });
    return;
  }

  const generator = await TextGenerationPipeline.getInstance(current_model_key);

  let startTime;
  let numTokens = 0;
  let tps; // tokens per second, measured from the first emitted token

  const streamer = new TextStreamer(generator.tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (output) => {
      // Stream each decoded chunk to the UI along with throughput stats.
      self.postMessage({ status: "update", output, tps, numTokens });
    },
    token_callback_function: () => {
      startTime ??= performance.now();
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    },
  });

  self.postMessage({ status: "start" });

  // Reuse the KV cache across turns; create it lazily on first use.
  past_key_values_cache ??= new DynamicCache();

  try {
    const output = await generator(messages, {
      max_new_tokens: 1024,
      do_sample: false, // greedy decoding for deterministic output
      streamer,
      stopping_criteria,
      past_key_values: past_key_values_cache,
    });

    // The pipeline returns the full chat transcript; the last message is
    // the newly generated assistant reply.
    self.postMessage({
      status: "complete",
      output: output[0].generated_text.at(-1).content,
    });
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}

// Message protocol from the main thread: { type, data }.
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;
  switch (type) {
    case "check": // WebGPU availability probe
      check();
      break;
    case "load": // data: model key, e.g. "1.7b"
      load(data);
      break;
    case "generate": // data: chat messages array
      stopping_criteria.reset();
      generate(data);
      break;
    case "interrupt": // stop the current generation
      stopping_criteria.interrupt();
      break;
    case "reset": // clear conversation state (KV cache)
      disposePastKeyValues();
      stopping_criteria.reset();
      break;
  }
});
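
// Example main-thread usage (a sketch; assumes this file is bundled as a
// module worker at ./worker.js — adjust the path to your setup):
//
//   const worker = new Worker(new URL("./worker.js", import.meta.url), {
//     type: "module",
//   });
//   worker.addEventListener("message", (e) => console.log(e.data));
//   worker.postMessage({ type: "check" });
//   worker.postMessage({ type: "load", data: "1.7b" });
//   worker.postMessage({
//     type: "generate",
//     data: [{ role: "user", content: "Hello!" }],
//   });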