shreyask Claude Opus 4.6 committed on
Commit
a6ac99b
·
verified ·
1 Parent(s): bdf3c61

fix: use AutoModelForCausalLM for reranker (yes/no token scoring)

Browse files

Matches the working pattern from shreyask/Qwen3-Reranker-0.6B-Web.
The reranker is a causal LM that scores query-document pairs by
comparing logits for "yes" vs "no" tokens, not a text-classification
pipeline.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. src/pipeline/models.ts +40 -14
src/pipeline/models.ts CHANGED
@@ -1,8 +1,11 @@
1
  import {
2
  pipeline,
 
 
3
  type FeatureExtractionPipeline,
4
  type TextGenerationPipeline,
5
- type TextClassificationPipeline,
 
6
  type ProgressInfo,
7
  } from "@huggingface/transformers";
8
  import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
@@ -12,9 +15,14 @@ type ProgressCallback = (state: ModelState) => void;
12
 
13
  // Singleton model instances
14
  let embeddingPipeline: FeatureExtractionPipeline | null = null;
15
- let rerankerPipeline: TextClassificationPipeline | null = null;
16
  let generationPipeline: TextGenerationPipeline | null = null;
17
 
 
 
 
 
 
 
18
  /** Check whether WebGPU is available in this browser. */
19
  export async function checkWebGPU(): Promise<boolean> {
20
  if (!navigator.gpu) return false;
@@ -101,18 +109,28 @@ export async function loadEmbeddingModel(
101
  export async function loadRerankerModel(
102
  onProgress?: ProgressCallback,
103
  ): Promise<void> {
104
- if (rerankerPipeline) return;
105
  const name = "reranker";
106
  onProgress?.({ name, status: "pending", progress: 0 });
107
  try {
108
- rerankerPipeline = await pipeline(
109
- "text-classification",
110
- MODEL_RERANKER,
111
- {
112
- device: "webgpu",
113
- progress_callback: makeProgressHandler(name, onProgress),
114
- },
115
- );
 
 
 
 
 
 
 
 
 
 
116
  onProgress?.({ name, status: "ready", progress: 1 });
117
  } catch (err) {
118
  onProgress?.({
@@ -184,8 +202,16 @@ export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
184
  return embeddingPipeline;
185
  }
186
 
187
- export function getRerankerPipeline(): TextClassificationPipeline | null {
188
- return rerankerPipeline;
 
 
 
 
 
 
 
 
189
  }
190
 
191
  export function getExpansionPipeline(): TextGenerationPipeline | null {
@@ -195,7 +221,7 @@ export function getExpansionPipeline(): TextGenerationPipeline | null {
195
  export function isAllModelsReady(): boolean {
196
  return (
197
  embeddingPipeline !== null &&
198
- rerankerPipeline !== null &&
199
  generationPipeline !== null
200
  );
201
  }
 
1
  import {
2
  pipeline,
3
+ AutoTokenizer,
4
+ AutoModelForCausalLM,
5
  type FeatureExtractionPipeline,
6
  type TextGenerationPipeline,
7
+ type PreTrainedTokenizer,
8
+ type PreTrainedModel,
9
  type ProgressInfo,
10
  } from "@huggingface/transformers";
11
  import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
 
15
 
16
  // Singleton model instances
17
  let embeddingPipeline: FeatureExtractionPipeline | null = null;
 
18
  let generationPipeline: TextGenerationPipeline | null = null;
19
 
20
+ // Reranker uses AutoModel + AutoTokenizer (not a pipeline)
21
+ let rerankerModel: PreTrainedModel | null = null;
22
+ let rerankerTokenizer: PreTrainedTokenizer | null = null;
23
+ let rerankerTokenYes = -1;
24
+ let rerankerTokenNo = -1;
25
+
26
  /** Check whether WebGPU is available in this browser. */
27
  export async function checkWebGPU(): Promise<boolean> {
28
  if (!navigator.gpu) return false;
 
109
  export async function loadRerankerModel(
110
  onProgress?: ProgressCallback,
111
  ): Promise<void> {
112
+ if (rerankerModel) return;
113
  const name = "reranker";
114
  onProgress?.({ name, status: "pending", progress: 0 });
115
  try {
116
+ const progressHandler = makeProgressHandler(name, onProgress);
117
+
118
+ // Load tokenizer and model separately (cross-encoder pattern)
119
+ rerankerTokenizer = await AutoTokenizer.from_pretrained(MODEL_RERANKER, {
120
+ progress_callback: progressHandler,
121
+ });
122
+
123
+ // Pre-compute "yes" and "no" token IDs for scoring
124
+ const yesIds = rerankerTokenizer("yes", { add_special_tokens: false }).input_ids.data;
125
+ const noIds = rerankerTokenizer("no", { add_special_tokens: false }).input_ids.data;
126
+ rerankerTokenYes = Number(yesIds[yesIds.length - 1]);
127
+ rerankerTokenNo = Number(noIds[noIds.length - 1]);
128
+
129
+ rerankerModel = await AutoModelForCausalLM.from_pretrained(MODEL_RERANKER, {
130
+ device: "webgpu",
131
+ progress_callback: progressHandler,
132
+ });
133
+
134
  onProgress?.({ name, status: "ready", progress: 1 });
135
  } catch (err) {
136
  onProgress?.({
 
202
  return embeddingPipeline;
203
  }
204
 
205
+ export function getRerankerModel(): PreTrainedModel | null {
206
+ return rerankerModel;
207
+ }
208
+
209
+ export function getRerankerTokenizer(): PreTrainedTokenizer | null {
210
+ return rerankerTokenizer;
211
+ }
212
+
213
+ export function getRerankerTokenIds(): { yes: number; no: number } {
214
+ return { yes: rerankerTokenYes, no: rerankerTokenNo };
215
  }
216
 
217
  export function getExpansionPipeline(): TextGenerationPipeline | null {
 
221
  export function isAllModelsReady(): boolean {
222
  return (
223
  embeddingPipeline !== null &&
224
+ rerankerModel !== null &&
225
  generationPipeline !== null
226
  );
227
  }