shreyask Claude Opus 4.6 committed on
Commit
bdf3c61
·
verified ·
1 Parent(s): 4e354b6

add model loading layer for Transformers.js WebGPU pipelines

Browse files

Singleton loaders for embedding (embeddinggemma-300m), reranker
(Qwen3-Reranker-0.6B), and query expansion (1.7B Q4) models using
@huggingface/transformers v4 with WebGPU backend. Includes progress
reporting, WebGPU detection, and parallel loading via loadAllModels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. src/pipeline/models.ts +201 -0
src/pipeline/models.ts ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import {
2
+ pipeline,
3
+ type FeatureExtractionPipeline,
4
+ type TextGenerationPipeline,
5
+ type TextClassificationPipeline,
6
+ type ProgressInfo,
7
+ } from "@huggingface/transformers";
8
+ import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
9
+ import type { ModelState } from "../types";
10
+
11
/** Callback that receives a `ModelState` snapshot describing a model's load status. */
type ProgressCallback = (state: ModelState) => void;

// Module-level pipeline singletons. Each one starts out as null and is
// assigned by its loader once the corresponding pipeline has been created.
let embeddingPipeline: FeatureExtractionPipeline | null = null;
let rerankerPipeline: TextClassificationPipeline | null = null;
let generationPipeline: TextGenerationPipeline | null = null;
17
+
18
/**
 * Check whether WebGPU is available in the current environment.
 *
 * Never throws. Resolves to `false` when there is no `navigator` global at
 * all (for example during server-side rendering or in unit tests), when
 * `navigator.gpu` is missing, when no adapter is returned, or when requesting
 * an adapter fails.
 *
 * @returns `true` only if `navigator.gpu.requestAdapter()` yields an adapter.
 */
export async function checkWebGPU(): Promise<boolean> {
  // Touching an undeclared global raises a ReferenceError, so probe it with
  // `typeof` before reading `.gpu`.
  if (typeof navigator === "undefined" || !navigator.gpu) return false;
  try {
    const adapter = await navigator.gpu.requestAdapter();
    // `!= null` rejects both `null` and `undefined`.
    return adapter != null;
  } catch {
    return false;
  }
}
28
+
29
+ // ---------------------------------------------------------------------------
30
+ // Internal: translate Transformers.js ProgressInfo → our ModelState
31
+ // ---------------------------------------------------------------------------
32
+
33
/**
 * Build a Transformers.js progress callback that forwards events to
 * `onProgress` as `ModelState` updates for the model called `modelName`.
 *
 * Status mapping:
 * - `initiate` / `download` → `downloading` with progress 0
 * - `progress`              → `downloading` with the reported percentage
 *                             converted to a fraction in [0, 1]
 * - `done`                  → `loading` with progress 1
 * - `ready`                 → `ready` with progress 1
 *
 * Any other status is ignored.
 *
 * @param modelName Name copied into every emitted `ModelState`.
 * @param onProgress Receiver of the translated states.
 * @returns The handler, or `undefined` when `onProgress` is not supplied.
 */
function makeProgressHandler(
  modelName: string,
  onProgress?: ProgressCallback,
): ((info: ProgressInfo) => void) | undefined {
  if (!onProgress) return undefined;

  return (info: ProgressInfo) => {
    switch (info.status) {
      case "initiate":
      case "download":
        onProgress({ name: modelName, status: "downloading", progress: 0 });
        break;
      case "progress": {
        const percent = (info as { progress: number }).progress;
        // The reported percentage may be NaN/Infinity (e.g. when the total
        // download size is unknown) or fall outside 0-100. Report such
        // values as 0 / clamp them so consumers always get a fraction in
        // [0, 1].
        const fraction = Number.isFinite(percent)
          ? Math.min(1, Math.max(0, percent / 100))
          : 0;
        onProgress({
          name: modelName,
          status: "downloading",
          progress: fraction,
        });
        break;
      }
      case "done":
        onProgress({ name: modelName, status: "loading", progress: 1 });
        break;
      case "ready":
        onProgress({ name: modelName, status: "ready", progress: 1 });
        break;
      default:
        // Statuses without a ModelState equivalent are dropped.
        break;
    }
  };
}
73
+
74
+ // ---------------------------------------------------------------------------
75
+ // Individual model loaders
76
+ // ---------------------------------------------------------------------------
77
+
78
/** In-flight bookkeeping for one lazily created pipeline singleton. */
interface ModelLoadSlot {
  /** Promise of the load that is currently running, or null when none is. */
  inFlight: Promise<void> | null;
  /** Progress callbacks of every caller waiting on the running load. */
  listeners: Set<ProgressCallback>;
}

const embeddingLoad: ModelLoadSlot = {
  inFlight: null,
  listeners: new Set<ProgressCallback>(),
};
const rerankerLoad: ModelLoadSlot = {
  inFlight: null,
  listeners: new Set<ProgressCallback>(),
};
const expansionLoad: ModelLoadSlot = {
  inFlight: null,
  listeners: new Set<ProgressCallback>(),
};

/**
 * Run `create` at most once at a time for the given slot.
 *
 * - If the model is already loaded, resolves immediately without emitting
 *   any progress.
 * - If a load is already running, the caller joins it: its callback is added
 *   to the listeners and the running load's promise is returned, so the
 *   pipeline is never created twice by overlapping calls.
 * - Otherwise a new load starts and emits `pending`, then whatever the
 *   progress handler reports, then `ready` or `error`.
 *
 * After the load settles the slot is reset, so a failed load can be retried.
 *
 * @param name Model name placed in every emitted `ModelState`.
 * @param slot In-flight bookkeeping for this model.
 * @param isLoaded Tells whether the singleton has already been created.
 * @param create Creates the pipeline and stores it in its singleton.
 * @param onProgress Optional receiver of state updates for this caller.
 * @returns A promise that settles together with the underlying load.
 */
function loadOnce(
  name: string,
  slot: ModelLoadSlot,
  isLoaded: () => boolean,
  create: (
    progressCallback: ReturnType<typeof makeProgressHandler>,
  ) => Promise<void>,
  onProgress?: ProgressCallback,
): Promise<void> {
  if (isLoaded()) return Promise.resolve();

  if (onProgress) slot.listeners.add(onProgress);
  if (slot.inFlight) return slot.inFlight;

  // Fan every state out to all callers currently waiting on this load.
  // Iterate over a snapshot so a listener joining mid-dispatch is not visited.
  const notify: ProgressCallback = (state) => {
    for (const listener of Array.from(slot.listeners)) listener(state);
  };

  const run = (async (): Promise<void> => {
    notify({ name, status: "pending", progress: 0 });
    try {
      await create(makeProgressHandler(name, notify));
      notify({ name, status: "ready", progress: 1 });
    } catch (err) {
      notify({
        name,
        status: "error",
        progress: 0,
        error: err instanceof Error ? err.message : String(err),
      });
      throw err;
    }
  })();

  const release = (): void => {
    slot.inFlight = null;
    slot.listeners.clear();
  };
  slot.inFlight = run;
  // Reset from a separate reaction so the slot is only cleared after
  // `inFlight` has been assigned, whether the load succeeded or failed.
  void run.then(release, release);
  return run;
}

/**
 * Load the embedding pipeline (feature extraction on WebGPU) into its
 * module-level singleton.
 *
 * Does nothing if the pipeline already exists. Overlapping calls share a
 * single load instead of each creating their own pipeline.
 *
 * @param onProgress Optional receiver of `ModelState` updates named "embedding".
 * @throws Re-throws whatever `pipeline()` rejects with, after reporting an
 *   `error` state.
 */
export async function loadEmbeddingModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  return loadOnce(
    "embedding",
    embeddingLoad,
    () => embeddingPipeline !== null,
    async (progressCallback) => {
      embeddingPipeline = await pipeline("feature-extraction", MODEL_EMBEDDING, {
        device: "webgpu",
        progress_callback: progressCallback,
      });
    },
    onProgress,
  );
}

/**
 * Load the reranker pipeline (text classification on WebGPU) into its
 * module-level singleton.
 *
 * Does nothing if the pipeline already exists. Overlapping calls share a
 * single load instead of each creating their own pipeline.
 *
 * @param onProgress Optional receiver of `ModelState` updates named "reranker".
 * @throws Re-throws whatever `pipeline()` rejects with, after reporting an
 *   `error` state.
 */
export async function loadRerankerModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  return loadOnce(
    "reranker",
    rerankerLoad,
    () => rerankerPipeline !== null,
    async (progressCallback) => {
      rerankerPipeline = await pipeline("text-classification", MODEL_RERANKER, {
        device: "webgpu",
        progress_callback: progressCallback,
      });
    },
    onProgress,
  );
}

/**
 * Load the query-expansion pipeline (text generation, `q4` dtype, on WebGPU)
 * into its module-level singleton.
 *
 * Does nothing if the pipeline already exists. Overlapping calls share a
 * single load instead of each creating their own pipeline.
 *
 * @param onProgress Optional receiver of `ModelState` updates named "expansion".
 * @throws Re-throws whatever `pipeline()` rejects with, after reporting an
 *   `error` state.
 */
export async function loadExpansionModel(
  onProgress?: ProgressCallback,
): Promise<void> {
  return loadOnce(
    "expansion",
    expansionLoad,
    () => generationPipeline !== null,
    async (progressCallback) => {
      generationPipeline = await pipeline("text-generation", MODEL_EXPANSION, {
        dtype: "q4",
        device: "webgpu",
        progress_callback: progressCallback,
      });
    },
    onProgress,
  );
}
155
+
156
+ // ---------------------------------------------------------------------------
157
+ // Load all models in parallel
158
+ // ---------------------------------------------------------------------------
159
+
160
/**
 * Load the embedding, reranker and expansion models concurrently.
 *
 * WebGPU availability is checked first. When it is missing, an `error`
 * state is reported for each of the three model names and the function
 * throws without starting any load.
 *
 * @param onProgress Optional receiver of `ModelState` updates, shared by all
 *   three loaders.
 * @throws Error when WebGPU is unavailable, or the first rejection among the
 *   individual loaders.
 */
export async function loadAllModels(
  onProgress?: ProgressCallback,
): Promise<void> {
  if (await checkWebGPU()) {
    // Start every loader before awaiting so the loads overlap.
    const loaders = [loadEmbeddingModel, loadRerankerModel, loadExpansionModel];
    await Promise.all(loaders.map((load) => load(onProgress)));
    return;
  }

  const message = "WebGPU is not available in this browser";
  ["embedding", "reranker", "expansion"].forEach((name) => {
    onProgress?.({ name, status: "error", progress: 0, error: message });
  });
  throw new Error(message);
}
178
+
179
+ // ---------------------------------------------------------------------------
180
+ // Getters
181
+ // ---------------------------------------------------------------------------
182
+
183
/**
 * Access the embedding pipeline singleton.
 *
 * @returns The current value of the singleton; `null` until it has been assigned.
 */
export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
  return embeddingPipeline;
}
186
+
187
/**
 * Access the reranker pipeline singleton.
 *
 * @returns The current value of the singleton; `null` until it has been assigned.
 */
export function getRerankerPipeline(): TextClassificationPipeline | null {
  return rerankerPipeline;
}
190
+
191
/**
 * Access the query-expansion (text generation) pipeline singleton.
 *
 * @returns The current value of the singleton; `null` until it has been assigned.
 */
export function getExpansionPipeline(): TextGenerationPipeline | null {
  return generationPipeline;
}
194
+
195
/**
 * Report whether all three pipeline singletons have been created.
 *
 * @returns `true` only when the embedding, reranker and generation
 *   pipelines are all non-null.
 */
export function isAllModelsReady(): boolean {
  const pipelines = [embeddingPipeline, rerankerPipeline, generationPipeline];
  return pipelines.every((instance) => instance !== null);
}