fix: use AutoModelForCausalLM for reranker (yes/no token scoring)
Matches the working pattern from shreyask/Qwen3-Reranker-0.6B-Web.
The reranker is a causal LM that scores query-document pairs by
comparing logits for "yes" vs "no" tokens, not a text-classification
pipeline.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- src/pipeline/models.ts +40 -14
src/pipeline/models.ts
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
import {
|
| 2 |
pipeline,
|
|
|
|
|
|
|
| 3 |
type FeatureExtractionPipeline,
|
| 4 |
type TextGenerationPipeline,
|
| 5 |
-
type
|
|
|
|
| 6 |
type ProgressInfo,
|
| 7 |
} from "@huggingface/transformers";
|
| 8 |
import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
|
|
@@ -12,9 +15,14 @@ type ProgressCallback = (state: ModelState) => void;
|
|
| 12 |
|
| 13 |
// Singleton model instances
|
| 14 |
let embeddingPipeline: FeatureExtractionPipeline | null = null;
|
| 15 |
-
let rerankerPipeline: TextClassificationPipeline | null = null;
|
| 16 |
let generationPipeline: TextGenerationPipeline | null = null;
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
/** Check whether WebGPU is available in this browser. */
|
| 19 |
export async function checkWebGPU(): Promise<boolean> {
|
| 20 |
if (!navigator.gpu) return false;
|
|
@@ -101,18 +109,28 @@ export async function loadEmbeddingModel(
|
|
| 101 |
export async function loadRerankerModel(
|
| 102 |
onProgress?: ProgressCallback,
|
| 103 |
): Promise<void> {
|
| 104 |
-
if (
|
| 105 |
const name = "reranker";
|
| 106 |
onProgress?.({ name, status: "pending", progress: 0 });
|
| 107 |
try {
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
onProgress?.({ name, status: "ready", progress: 1 });
|
| 117 |
} catch (err) {
|
| 118 |
onProgress?.({
|
|
@@ -184,8 +202,16 @@ export function getEmbeddingPipeline(): FeatureExtractionPipeline | null {
|
|
| 184 |
return embeddingPipeline;
|
| 185 |
}
|
| 186 |
|
| 187 |
-
export function
|
| 188 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
}
|
| 190 |
|
| 191 |
export function getExpansionPipeline(): TextGenerationPipeline | null {
|
|
@@ -195,7 +221,7 @@ export function getExpansionPipeline(): TextGenerationPipeline | null {
|
|
| 195 |
export function isAllModelsReady(): boolean {
|
| 196 |
return (
|
| 197 |
embeddingPipeline !== null &&
|
| 198 |
-
|
| 199 |
generationPipeline !== null
|
| 200 |
);
|
| 201 |
}
|
|
|
|
| 1 |
import {
|
| 2 |
pipeline,
|
| 3 |
+
AutoTokenizer,
|
| 4 |
+
AutoModelForCausalLM,
|
| 5 |
type FeatureExtractionPipeline,
|
| 6 |
type TextGenerationPipeline,
|
| 7 |
+
type PreTrainedTokenizer,
|
| 8 |
+
type PreTrainedModel,
|
| 9 |
type ProgressInfo,
|
| 10 |
} from "@huggingface/transformers";
|
| 11 |
import { MODEL_EMBEDDING, MODEL_RERANKER, MODEL_EXPANSION } from "../constants";
|
|
|
|
| 15 |
|
| 16 |
// Singleton model instances
|
| 17 |
let embeddingPipeline: FeatureExtractionPipeline | null = null;
|
|
|
|
| 18 |
let generationPipeline: TextGenerationPipeline | null = null;
|
| 19 |
|
| 20 |
+
// Reranker uses AutoModel + AutoTokenizer (not a pipeline)
|
| 21 |
+
let rerankerModel: PreTrainedModel | null = null;
|
| 22 |
+
let rerankerTokenizer: PreTrainedTokenizer | null = null;
|
| 23 |
+
let rerankerTokenYes = -1;
|
| 24 |
+
let rerankerTokenNo = -1;
|
| 25 |
+
|
| 26 |
/** Check whether WebGPU is available in this browser. */
|
| 27 |
export async function checkWebGPU(): Promise<boolean> {
|
| 28 |
if (!navigator.gpu) return false;
|
|
|
|
| 109 |
export async function loadRerankerModel(
|
| 110 |
onProgress?: ProgressCallback,
|
| 111 |
): Promise<void> {
|
| 112 |
+
if (rerankerModel) return;
|
| 113 |
const name = "reranker";
|
| 114 |
onProgress?.({ name, status: "pending", progress: 0 });
|
| 115 |
try {
|
| 116 |
+
const progressHandler = makeProgressHandler(name, onProgress);
|
| 117 |
+
|
| 118 |
+
// Load tokenizer and model separately (cross-encoder pattern)
|
| 119 |
+
rerankerTokenizer = await AutoTokenizer.from_pretrained(MODEL_RERANKER, {
|
| 120 |
+
progress_callback: progressHandler,
|
| 121 |
+
});
|
| 122 |
+
|
| 123 |
+
// Pre-compute "yes" and "no" token IDs for scoring
|
| 124 |
+
const yesIds = rerankerTokenizer("yes", { add_special_tokens: false }).input_ids.data;
|
| 125 |
+
const noIds = rerankerTokenizer("no", { add_special_tokens: false }).input_ids.data;
|
| 126 |
+
rerankerTokenYes = Number(yesIds[yesIds.length - 1]);
|
| 127 |
+
rerankerTokenNo = Number(noIds[noIds.length - 1]);
|
| 128 |
+
|
| 129 |
+
rerankerModel = await AutoModelForCausalLM.from_pretrained(MODEL_RERANKER, {
|
| 130 |
+
device: "webgpu",
|
| 131 |
+
progress_callback: progressHandler,
|
| 132 |
+
});
|
| 133 |
+
|
| 134 |
onProgress?.({ name, status: "ready", progress: 1 });
|
| 135 |
} catch (err) {
|
| 136 |
onProgress?.({
|
|
|
|
| 202 |
return embeddingPipeline;
|
| 203 |
}
|
| 204 |
|
| 205 |
+
export function getRerankerModel(): PreTrainedModel | null {
|
| 206 |
+
return rerankerModel;
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
export function getRerankerTokenizer(): PreTrainedTokenizer | null {
|
| 210 |
+
return rerankerTokenizer;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
export function getRerankerTokenIds(): { yes: number; no: number } {
|
| 214 |
+
return { yes: rerankerTokenYes, no: rerankerTokenNo };
|
| 215 |
}
|
| 216 |
|
| 217 |
export function getExpansionPipeline(): TextGenerationPipeline | null {
|
|
|
|
| 221 |
export function isAllModelsReady(): boolean {
|
| 222 |
return (
|
| 223 |
embeddingPipeline !== null &&
|
| 224 |
+
rerankerModel !== null &&
|
| 225 |
generationPipeline !== null
|
| 226 |
);
|
| 227 |
}
|