// htaf's picture
// added new instruct pipeline for faster generation
// 2739b3a
// src/pipeline/step.mjs
import { loadProviderFor } from '../providers/provider.mjs';
import {
hybridSearch,
bm25Search,
vectorSearch,
hydeHybrid,
} from '../retrieval/retrieval.mjs';
import { runGenerator } from '../generator/generator_core.mjs';
import { runVerifier } from '../verifier/verifier_core.mjs';
import { runReward } from '../reward/reward_core.mjs';
import { preview } from './util.mjs';
/**
 * Run a single pipeline step for one question.
 *
 * Flow:
 *   retrieval (or provided context) → generator → verifier → reward
 *
 * Design constraints:
 *   - Exactly one context chunk is used per question.
 *   - If `initialContext` is a non-empty array, retrieval is never called
 *     (no Elasticsearch round-trip).
 *   - If retrieval is called, only the FIRST returned chunk is kept.
 *   - Verbose logging is kept outside the model-call try/catch blocks, so a
 *     value that is awkward to print can never turn a successful stage into a
 *     reported failure.
 *
 * @param {object} [options]
 * @param {string} options.question - Question to process. Non-string or
 *   blank values yield status 'invalid_question'.
 * @param {Array<{id?: string, content: string}>} [options.initialContext] -
 *   Pre-selected context; only its first element is used.
 * @param {string} [options.retrievalMode] - 'bm25' | 'vector' | 'hyde' |
 *   'hybrid'; any other value falls back to hybrid search. Defaults to
 *   $RETRIEVAL_MODE, else 'hybrid'.
 * @param {number} [options.k] - Number of hits requested from retrieval.
 *   Defaults to Number($RETRIEVAL_K), else 6 (not validated: a non-numeric
 *   env value yields NaN).
 * @param {object} [options.generatorProvider] - Defaults to
 *   loadProviderFor('generator'). Also passed to retrieval in 'hyde' mode.
 * @param {object} [options.verifierProvider] - Defaults to loadProviderFor('verifier').
 * @param {object} [options.rewardProvider] - Defaults to loadProviderFor('reward').
 * @param {object} [options.cachedGen] - A previous generation; when truthy
 *   the generator model is not called.
 * @param {boolean} [options.verbose=false] - Log per-stage diagnostics.
 * @param {{log?: Function, error?: Function}} [options.logger=console] -
 *   Sink for diagnostics; missing methods fall back to console.
 * @returns {Promise<{status: string, question: *, context?: Array, gen?: object,
 *   ver?: object, rew?: object, error?: string}>}
 *   `status` is one of:
 *     'accepted' | 'invalid_question' | 'retrieval_failed' |
 *     'generator_failed' | 'verifier_rejected' | 'verifier_error' |
 *     'reward_rejected' | 'reward_error'.
 *   `context` holds exactly one chunk once retrieval has succeeded.
 *   `error` is set on failures that carry a reason. Possible values are
 *   'no_chunks', 'empty_context', 'empty_answer', 'invalid_answer', or the
 *   message of a thrown error.
 */
export async function runPipelineStep({
  question,
  initialContext, // optional: [{ id?, content, ... }]
  retrievalMode = process.env.RETRIEVAL_MODE || 'hybrid',
  k = Number(process.env.RETRIEVAL_K || '6'),
  generatorProvider,
  verifierProvider,
  rewardProvider,
  cachedGen,
  verbose = false,
  logger = console,
} = {}) {
  const log = logger?.log?.bind(logger) || console.log;
  const errLog = logger?.error?.bind(logger) || console.error;

  // ----------------------------------------
  // Question sanity
  // ----------------------------------------
  // Check the type first: `.trim()` on a non-string would throw instead of
  // reporting an invalid question.
  if (typeof question !== 'string' || !question.trim()) {
    if (verbose) log(' [pipeline] empty / invalid question, skipping');
    return { status: 'invalid_question', question };
  }

  const genProv = generatorProvider || loadProviderFor('generator');
  const verProv = verifierProvider || loadProviderFor('verifier');
  const rewProv = rewardProvider || loadProviderFor('reward');

  // ----------------------------------------
  // Retrieval / context selection
  // ----------------------------------------
  let context;
  if (Array.isArray(initialContext) && initialContext.length > 0) {
    // Use provided context, no ES call
    context = initialContext.slice(0, 1); // enforce single-chunk invariant
    if (verbose) {
      log(
        ` [retrieval] using initialContext provided (len=${initialContext.length}), ` +
          `keeping first chunk only`,
      );
      logChunk(log, ' [context] first chunk (provided):', context[0]);
    }
  } else {
    // Go to ES exactly once
    let hits;
    try {
      if (verbose) log(` [retrieval] mode=${retrievalMode} k=${k}`);
      hits = await searchByMode(retrievalMode, question, k, genProv);
    } catch (e) {
      const msg = errorMessage(e);
      if (verbose) errLog(' [retrieval] ERROR:', msg);
      return { status: 'retrieval_failed', question, error: msg };
    }

    // `hits` may be null/undefined: never dereference `.length` unguarded,
    // or a missing result would surface as a TypeError instead of 'no_chunks'.
    if (verbose) log(` [retrieval] got ${hits?.length ?? 0} chunks from ES`);
    if (!hits || hits.length === 0) {
      if (verbose) log(' [retrieval] no chunks found → retrieval_failed');
      return { status: 'retrieval_failed', question, error: 'no_chunks' };
    }

    // Enforce single-chunk context
    context = [hits[0]];
    if (verbose) logChunk(log, ' [context] first chunk (from ES):', context[0]);
  }

  // Safety: fail fast if no usable chunk was selected
  // (e.g. initialContext = [null] or a retrieval hit list starting with null).
  if (context[0] == null) {
    if (verbose) log(' [retrieval] context empty after selection → retrieval_failed');
    return { status: 'retrieval_failed', question, error: 'empty_context' };
  }

  // ----------------------------------------
  // Generator
  // ----------------------------------------
  let gen;
  if (cachedGen) {
    gen = cachedGen;
    if (verbose) log(' [generator] using cached generation');
  } else {
    try {
      if (verbose) log(' [generator] calling model…');
      gen = await runGenerator(question, context, genProv);
    } catch (e) {
      const msg = errorMessage(e);
      if (verbose) errLog(' [generator] ERROR:', msg);
      return { status: 'generator_failed', question, context, error: msg };
    }
    // Logged outside the try: a printing problem is not a generator failure.
    if (verbose) logGeneration(log, gen);
  }

  // Missing or blank answer means the generator failed. A non-string answer
  // would make `.trim()` throw, so it is rejected explicitly instead.
  const answer = gen?.answer;
  if (!answer || (typeof answer === 'string' && !answer.trim())) {
    if (verbose) log(' [generator] empty answer — generator_failed');
    return { status: 'generator_failed', question, context, gen, error: 'empty_answer' };
  }
  if (typeof answer !== 'string') {
    if (verbose) log(' [generator] non-string answer — generator_failed');
    return { status: 'generator_failed', question, context, gen, error: 'invalid_answer' };
  }

  // ----------------------------------------
  // Verifier
  // ----------------------------------------
  let ver;
  try {
    if (verbose) log(' [verifier] calling model…');
    ver = await runVerifier({ question, context, gen }, verProv);
  } catch (e) {
    const msg = errorMessage(e);
    if (verbose) errLog(' [verifier] ERROR:', msg);
    return { status: 'verifier_error', question, context, gen, error: msg };
  }
  if (verbose) {
    log(' [verifier] ok=' + (ver?.ok === true));
    log(' [verifier] raw transcript:');
    // `raw` is not guaranteed to be a string; toText keeps this from throwing.
    log(indentLines(toText(ver?.raw ?? '')));
  }

  if (!ver || ver.ok !== true) {
    if (verbose) log(' [verifier] rejected sample');
    return { status: 'verifier_rejected', question, context, gen, ver };
  }

  // ----------------------------------------
  // Reward
  // ----------------------------------------
  let rew;
  try {
    if (verbose) log(' [reward] calling model…');
    rew = await runReward({ question, context, gen, ver }, rewProv);
  } catch (e) {
    const msg = errorMessage(e);
    if (verbose) errLog(' [reward] ERROR:', msg);
    return { status: 'reward_error', question, context, gen, ver, error: msg };
  }
  if (verbose) {
    log(` [reward] score=${rew?.score} ok=${rew?.ok}`);
    log(indentLines(preview(toText(rew?.raw ?? ''), 200)));
  }

  if (!rew || rew.ok !== true) {
    if (verbose) log(' [reward] rejected sample');
    return { status: 'reward_rejected', question, context, gen, ver, rew };
  }

  // ----------------------------------------
  // Accepted sample
  // ----------------------------------------
  if (verbose) log(' [pipeline] accepted ✅');
  return { status: 'accepted', question, context, gen, ver, rew };
}

// ----------------------------------------
// Private helpers
// ----------------------------------------

/** Extract a human-readable message from any thrown value. */
function errorMessage(e) {
  return e?.message || String(e);
}

/**
 * Dispatch a search to the retrieval backend selected by `mode`.
 * Unknown modes fall back to hybrid search. Declared async so synchronous
 * throws from a backend surface as rejections, like any other retrieval error.
 */
async function searchByMode(mode, question, k, genProv) {
  switch (mode) {
    case 'bm25':
      return bm25Search(question, k);
    case 'vector':
      return vectorSearch(question, k);
    case 'hyde':
      return hydeHybrid(question, k, genProv);
    case 'hybrid':
    default:
      return hybridSearch(question, k);
  }
}

/**
 * Best-effort conversion of any value to loggable text. Strings pass through
 * unchanged. Other values are pretty-printed as JSON, falling back to
 * String() when JSON.stringify throws (cycles, BigInt) or yields undefined.
 */
function toText(value) {
  if (typeof value === 'string') return value;
  if (value === undefined) return '';
  try {
    return JSON.stringify(value, null, 2) ?? String(value);
  } catch {
    return String(value);
  }
}

/** Pretty-print `raw` when it is a JSON string; otherwise fall back to toText. */
function prettyJsonOrText(raw) {
  if (typeof raw === 'string') {
    try {
      return JSON.stringify(JSON.parse(raw), null, 2);
    } catch {
      return raw; // not JSON: leave as-is
    }
  }
  return toText(raw);
}

/** Prefix the first line and every following line of `text` with one space. */
function indentLines(text) {
  return ' ' + text.replace(/\n/g, '\n ');
}

/** Log `header`, then a 200-char indented preview of `chunk.content`. */
function logChunk(log, header, chunk) {
  log(header);
  log(indentLines(preview(toText(chunk?.content ?? ''), 200)));
}

/** Log the interesting fields of a fresh generation (verbose mode only). */
function logGeneration(log, gen) {
  if (gen?.thought) {
    log(' [generator] thought:');
    log(indentLines(preview(toText(gen.thought), 500)));
  }
  if (gen?.thinking) {
    log(' [generator] thinking (raw from provider):');
    log(indentLines(preview(toText(gen.thinking), 500)));
  }
  log(' [generator] answer:');
  log(indentLines(preview(toText(gen?.answer ?? ''), 400)));
  // `!= null` so a legitimate confidence of 0 is still reported.
  if (gen?.confidence != null) {
    log(' [generator] confidence: ' + toText(gen.confidence));
  }
  if (gen?.evidence) {
    const evidence = Array.isArray(gen.evidence)
      ? gen.evidence.map((item) => toText(item)).join(' | ')
      : toText(gen.evidence);
    log(' [generator] evidence: ' + preview(evidence, 400).replace(/\n/g, '\n '));
  }
  if (gen?.limitations) {
    log(' [generator] limitations: ' + preview(toText(gen.limitations), 200));
  }
  if (gen?.raw) {
    log(' [generator] raw response (JSON if parsable):');
    log(indentLines(preview(prettyJsonOrText(gen.raw), 2000)));
  }
  if (gen?.rawJson?.response) {
    log(' [generator] ollama response text (full):');
    log(indentLines(preview(toText(gen.rawJson.response), 2000)));
  }
  if (gen?.rawJson) {
    log(' [generator] ollama full JSON:');
    log(indentLines(toText(gen.rawJson)));
  }
}