distill-pipeline / scripts /regenerate_gold_from_cache.mjs
htaf's picture
updated a bunch of stuff
fad3187
#!/usr/bin/env node
// scripts/regenerate_gold_from_cache.mjs
// Regenerate gold/pipeline_gold.jsonl from cache JSONL files.
import { realpathSync } from 'fs';
import fs from 'fs/promises';
import path from 'path';
import { fileURLToPath } from 'url';
import { loadRagChunks } from '../src/retrieval/jsonl_chunks.mjs';
import {
  questionId,
  chunkIdFromContent,
} from '../src/pipeline/cache.mjs';
// ESM has no CommonJS __filename/__dirname, so derive them from the module URL.
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// Repository root: the parent of the directory holding this script.
const PROJECT_ROOT = path.join(__dirname, '..');

/**
 * Work out where the pipeline cache lives.
 *
 * PIPELINE_CACHE_DIR wins when set; a relative value is anchored at
 * PROJECT_ROOT rather than the current working directory. Without it the
 * cache is <PROJECT_ROOT>/data/cache.
 *
 * @returns {string} cache directory path
 */
function resolveCacheDir() {
  const override = process.env.PIPELINE_CACHE_DIR;
  if (!override) {
    return path.join(PROJECT_ROOT, 'data', 'cache');
  }
  if (path.isAbsolute(override)) {
    return override;
  }
  return path.join(PROJECT_ROOT, override);
}

const CACHE_DIR = resolveCacheDir();

// Output file; GOLD_PATH (used verbatim when non-empty) overrides the default.
const GOLD_PATH = process.env.GOLD_PATH
  ? process.env.GOLD_PATH
  : path.join(PROJECT_ROOT, 'gold', 'pipeline_gold.jsonl');

// File names, relative to CACHE_DIR, of each cached pipeline stage.
const CACHE_FILES = {
  questions: 'questions.jsonl',
  generations: 'generations.jsonl',
  verifications: 'verifications.jsonl',
  rewards: 'rewards.jsonl',
};
/**
 * Read and parse a JSONL file, by default from the pipeline cache directory.
 *
 * Blank lines are ignored and surrounding whitespace (including a trailing
 * "\r") is trimmed from each line. A line that fails to parse is skipped so a
 * single corrupt row does not abort regeneration, but the number of such
 * lines is reported on stderr so the loss is not silent. Lines that parse to
 * a falsy value (e.g. `null`) are also dropped.
 *
 * @param {string} fileName - file name, joined onto `dir`
 * @param {string} [dir=CACHE_DIR] - directory containing the file
 * @returns {Promise<any[]>} parsed records; `[]` when the file does not exist
 * @throws {Error} any filesystem error other than ENOENT
 */
async function readJsonl(fileName, dir = CACHE_DIR) {
  const filePath = path.join(dir, fileName);

  let txt;
  try {
    txt = await fs.readFile(filePath, 'utf8');
  } catch (e) {
    // A stage that has never run simply has no cache file yet.
    if (e.code === 'ENOENT') return [];
    throw e;
  }

  const records = [];
  let malformed = 0;
  for (const rawLine of txt.split('\n')) {
    const line = rawLine.trim();
    if (!line) continue;
    let parsed;
    try {
      parsed = JSON.parse(line);
    } catch {
      // Best-effort: keep going, but count it so we can report below.
      malformed += 1;
      continue;
    }
    if (parsed) records.push(parsed);
  }

  if (malformed > 0) {
    console.warn(`readJsonl: skipped ${malformed} malformed line(s) in ${filePath}`);
  }
  return records;
}
/**
 * Join id parts into a single Map key separated by '|'.
 *
 * Falsy parts (undefined, null, '', 0) are omitted entirely rather than left
 * as empty slots, so the result carries no positional information.
 *
 * @param {...*} parts - id components
 * @returns {string} joined key; '' when every part is falsy
 */
function compositeKey(...parts) {
  const present = [];
  for (const part of parts) {
    if (part) present.push(part);
  }
  return present.join('|');
}
/**
 * Index every RAG chunk by id.
 *
 * A chunk's own `id` is used when truthy; otherwise an id is derived with
 * `chunkIdFromContent(content, sourceId || source.id)` (presumably the same
 * derivation used when cache rows were written — verify against cache.mjs).
 * When two chunks map to the same id, the later one wins.
 *
 * @returns {Promise<Map<*, object>>} chunk id -> chunk record
 */
async function loadChunksMap() {
  const byId = new Map();
  for (const chunk of await loadRagChunks()) {
    const id = chunk.id
      ? chunk.id
      : chunkIdFromContent(chunk.content, chunk.sourceId || chunk.source?.id);
    byId.set(id, chunk);
  }
  return byId;
}
/**
 * Keep one record per key, preferring the one with the greatest `ts`.
 *
 * Records whose key is falsy are ignored. The first record seen for a key is
 * kept until a later record with a truthy `ts` displaces it, either because
 * the kept record has no `ts` or because the new `ts` compares strictly
 * greater (`>`). Records without a `ts` never displace an existing entry, and
 * ties keep the earlier record.
 *
 * @param {Iterable<object>} records - candidate records
 * @param {(record: object) => *} keyFn - grouping key for a record
 * @returns {Map<*, object>} key -> selected record
 */
function latestByTs(records, keyFn) {
  const newestByKey = new Map();
  for (const record of records) {
    const key = keyFn(record);
    if (!key) continue;
    const incumbent = newestByKey.get(key);
    if (incumbent) {
      if (!record.ts) continue;
      if (incumbent.ts && !(record.ts > incumbent.ts)) continue;
    }
    newestByKey.set(key, record);
  }
  return newestByKey;
}
/**
 * Decide whether a cached reward record counts as a pass.
 *
 * Passes when `ok === true`; otherwise when `score` is a number >= 0.5, the
 * string "pass" (case- and whitespace-insensitive), or a string that converts
 * to a finite number >= 0.5. Anything else, including a missing record,
 * fails. An explicit `ok: false` does not override a passing score.
 *
 * @param {object|undefined|null} r - reward record
 * @returns {boolean} true when the reward passes
 */
function rewardOk(r) {
  if (!r) return false;
  if (r.ok === true) return true;
  const { score } = r;
  switch (typeof score) {
    case 'number':
      return score >= 0.5;
    case 'string': {
      if (score.trim().toLowerCase() === 'pass') return true;
      const numeric = Number(score);
      return Number.isFinite(numeric) && numeric >= 0.5;
    }
    default:
      return false;
  }
}
/**
 * Index cached question rows by `chunkId|questionId`.
 *
 * Each question is registered under the id from `questionId(chunkId, q)` and,
 * when the row's `question_ids` array has a truthy entry at the same position,
 * under that provided id as well. This lets a generation keyed by either id be
 * matched to its question. Rows without a `chunk_id` are ignored. A row may
 * carry either a `questions` array or a single `question`.
 *
 * @param {object[]} questionRows - parsed records from questions.jsonl
 * @returns {Map<string, *>} composite key -> question value
 */
function buildQuestionMap(questionRows) {
  const qMap = new Map();
  for (const rec of questionRows) {
    const chunkId = rec.chunk_id;
    if (!chunkId) continue;

    let qs = [];
    if (Array.isArray(rec.questions)) {
      qs = rec.questions;
    } else if (rec.question) {
      qs = [rec.question];
    }
    const qIds = Array.isArray(rec.question_ids) ? rec.question_ids : [];

    for (const [i, q] of qs.entries()) {
      const providedId = qIds[i];
      const hashedId = questionId(chunkId, q);
      if (providedId) {
        qMap.set(compositeKey(chunkId, providedId), q);
      }
      qMap.set(compositeKey(chunkId, hashedId), q);
    }
  }
  return qMap;
}

/**
 * Rebuild GOLD_PATH from the cached pipeline stages.
 *
 * For each (chunk_id, question_id) pair, the newest generation (by `ts`) is
 * looked up together with the newest verification and reward sharing its
 * chunk_id, question_id and gen_id. The generation becomes a gold record when
 * either the reward passes (`rewardOk`) or the verification has `ok === true`.
 * The gold file is overwritten; an empty result produces an empty file.
 *
 * @returns {Promise<void>}
 */
async function main() {
  const [questions, generations, verifications, rewards] = await Promise.all([
    readJsonl(CACHE_FILES.questions),
    readJsonl(CACHE_FILES.generations),
    readJsonl(CACHE_FILES.verifications),
    readJsonl(CACHE_FILES.rewards),
  ]);
  const chunkMap = await loadChunksMap();
  const qMap = buildQuestionMap(questions);

  // compositeKey drops empty parts, so a generation missing chunk_id would
  // otherwise be keyed by its question id alone and be attributed to a
  // "chunk" named after that question. Such rows cannot be placed; skip them.
  const keyedGenerations = generations.filter((g) => g.chunk_id && g.question_id);
  const skipped = generations.length - keyedGenerations.length;
  if (skipped > 0) {
    console.warn(`Skipped ${skipped} generation(s) missing chunk_id or question_id`);
  }

  // Latest generation per chunk+question (by ts)
  const genMap = latestByTs(keyedGenerations, (g) =>
    compositeKey(g.chunk_id, g.question_id),
  );
  // Latest verification per chunk+question+gen
  const verMap = latestByTs(verifications, (v) =>
    compositeKey(v.chunk_id, v.question_id, v.gen_id),
  );
  // Latest reward per chunk+question+gen
  const rewMap = latestByTs(rewards, (r) =>
    compositeKey(r.chunk_id, r.question_id, r.gen_id),
  );

  const out = [];
  for (const gen of genMap.values()) {
    // Read ids from the record rather than splitting the map key, which
    // would mis-split any id that itself contains '|'.
    const { chunk_id: chunkId, question_id: qId, gen_id: genId } = gen;
    const ver = verMap.get(compositeKey(chunkId, qId, genId));
    const rew = rewMap.get(compositeKey(chunkId, qId, genId));
    if (!rewardOk(rew) && ver?.ok !== true) continue;

    const question = qMap.get(compositeKey(chunkId, qId)) || '[unknown question]';
    const chunk = chunkMap.get(chunkId) || {};
    const chunkText = chunk.content ?? chunk.text;
    out.push({
      question,
      sourceChunkId: chunkId,
      sourceChunk: chunkText,
      sourceDoc: chunk.source,
      context: [{ id: chunkId, content: chunkText ?? '' }],
      sample: gen,
      verifier: ver,
      reward: rew,
    });
  }

  await fs.mkdir(path.dirname(GOLD_PATH), { recursive: true });
  // Terminate every record with "\n" instead of join()+"\n", so that an empty
  // result writes an empty file rather than a single blank JSONL line.
  const body = out.map((r) => `${JSON.stringify(r)}\n`).join('');
  await fs.writeFile(GOLD_PATH, body, 'utf8');
  console.log(`Regenerated gold at ${GOLD_PATH}`);
  console.log(`Accepted records: ${out.length}`);
  console.log(`Total written: ${out.length}`);
}
if (import.meta.url === `file://${__filename}`) {
main().catch((err) => {
console.error('Regenerate gold error:', err);
process.exit(1);
});
}
export { main };