Eric Gardner committed on
Commit
be647a4
·
1 Parent(s): 7ecba03

Use Flan-T5-base to pre-generate questions

Browse files
Files changed (3) hide show
  1. index.js +8 -4
  2. routes/article.js +42 -2
  3. services/questionGenerator.js +121 -0
index.js CHANGED
@@ -3,6 +3,7 @@ import cors from 'cors';
3
  import articleRoutes from './routes/article.js';
4
  import searchRoutes from './routes/search.js';
5
  import { initEmbedder } from './services/embedder.js';
 
6
 
7
  const app = express();
8
  const PORT = process.env.PORT || 3000;
@@ -22,13 +23,16 @@ app.get( '/api/health', ( _, res ) => {
22
  res.json( { status: 'ok' } );
23
  } );
24
 
25
- // Pre-warm the embedding model on startup
26
- console.log( 'Starting server and loading embedding model...' );
27
- initEmbedder().then( () => {
 
 
 
28
  app.listen( PORT, () => {
29
  console.log( `Server running on http://localhost:${ PORT }` );
30
  } );
31
  } ).catch( ( err ) => {
32
- console.error( 'Failed to initialize embedder:', err );
33
  process.exit( 1 );
34
  } );
 
3
  import articleRoutes from './routes/article.js';
4
  import searchRoutes from './routes/search.js';
5
  import { initEmbedder } from './services/embedder.js';
6
+ import { initQuestionGenerator } from './services/questionGenerator.js';
7
 
8
  const app = express();
9
  const PORT = process.env.PORT || 3000;
 
23
  res.json( { status: 'ok' } );
24
  } );
25
 
26
+ // Pre-warm the models on startup
27
+ console.log( 'Starting server and loading models...' );
28
+ Promise.all( [
29
+ initEmbedder(),
30
+ initQuestionGenerator()
31
+ ] ).then( () => {
32
  app.listen( PORT, () => {
33
  console.log( `Server running on http://localhost:${ PORT }` );
34
  } );
35
  } ).catch( ( err ) => {
36
+ console.error( 'Failed to initialize models:', err );
37
  process.exit( 1 );
38
  } );
routes/article.js CHANGED
@@ -5,6 +5,7 @@ import { embedTexts, embedSingle } from '../services/embedder.js';
5
  import { search } from '../services/vectorSearch.js';
6
  import { getCached, setCache, isCacheValid } from '../services/cache.js';
7
  import { getProcessingState, setProcessing } from '../services/processingState.js';
 
8
 
9
  const router = Router();
10
 
@@ -32,7 +33,8 @@ router.get( '/:title', async ( req, res ) => {
32
  revisionId: cached.revisionId,
33
  html: cached.html,
34
  status: 'ready',
35
- chunkCount: cached.chunks.length
 
36
  } );
37
  }
38
 
@@ -202,6 +204,43 @@ async function processArticle( title, revisionId ) {
202
  chunk.embedding = embeddings[ i ];
203
  } );
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  // Save to cache
206
  await setCache( title, {
207
  title: articleData.title,
@@ -210,7 +249,8 @@ async function processArticle( title, revisionId ) {
210
  fetchedAt: new Date().toISOString(),
211
  html,
212
  chunkCount: chunks.length,
213
- chunks
 
214
  } );
215
 
216
  setProcessing( title, 'ready' );
 
5
  import { search } from '../services/vectorSearch.js';
6
  import { getCached, setCache, isCacheValid } from '../services/cache.js';
7
  import { getProcessingState, setProcessing } from '../services/processingState.js';
8
+ import { generateQuestions, getLeadSectionText } from '../services/questionGenerator.js';
9
 
10
  const router = Router();
11
 
 
33
  revisionId: cached.revisionId,
34
  html: cached.html,
35
  status: 'ready',
36
+ chunkCount: cached.chunks.length,
37
+ suggestedQuestions: cached.suggestedQuestions || []
38
  } );
39
  }
40
 
 
204
  chunk.embedding = embeddings[ i ];
205
  } );
206
 
207
+ // Generate suggested questions from the lead section
208
+ let suggestedQuestions = [];
209
+ try {
210
+ const leadText = getLeadSectionText( chunks );
211
+ console.log( `Lead text length: ${ leadText.length } chars` );
212
+ if ( leadText.length > 100 ) {
213
+ console.log( 'Generating suggested questions...' );
214
+ const rawQuestions = await generateQuestions( leadText, 5 );
215
+ console.log( `Raw questions from model:`, rawQuestions );
216
+
217
+ // Validate questions by checking if they match article content
218
+ const validatedQuestions = [];
219
+ for ( const question of rawQuestions ) {
220
+ const questionEmbedding = await embedSingle( question );
221
+ const { results } = search( questionEmbedding, chunks, 1 );
222
+
223
+ if ( results.length === 0 ) {
224
+ console.log( `Question: "${ question }" -> no results` );
225
+ continue;
226
+ }
227
+
228
+ const score = results[ 0 ].score;
229
+ console.log( `Question: "${ question }" -> score: ${ score.toFixed( 3 ) }` );
230
+
231
+ // Keep questions that have a good match (score > 0.3)
232
+ if ( score > 0.3 ) {
233
+ validatedQuestions.push( question );
234
+ }
235
+ }
236
+
237
+ suggestedQuestions = validatedQuestions.slice( 0, 3 );
238
+ console.log( `Generated ${ suggestedQuestions.length } validated questions` );
239
+ }
240
+ } catch ( err ) {
241
+ console.warn( 'Question generation failed, continuing without suggestions:', err.message );
242
+ }
243
+
244
  // Save to cache
245
  await setCache( title, {
246
  title: articleData.title,
 
249
  fetchedAt: new Date().toISOString(),
250
  html,
251
  chunkCount: chunks.length,
252
+ chunks,
253
+ suggestedQuestions
254
  } );
255
 
256
  setProcessing( title, 'ready' );
services/questionGenerator.js ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { pipeline } from '@xenova/transformers';
2
+
3
+ let generator = null;
4
+
5
+ /**
6
+ * Initialize the question generation model (Flan-T5-base)
7
+ */
8
/**
 * Initialize the question generation model (Flan-T5-base).
 *
 * Loads the text2text-generation pipeline once and caches it in the
 * module-level `generator`; subsequent calls return the cached instance.
 *
 * @returns {Promise<Object>} The shared text2text-generation pipeline
 */
export async function initQuestionGenerator() {
	if ( generator !== null ) {
		return generator;
	}
	console.log( 'Loading question generation model (flan-t5-base)...' );
	generator = await pipeline( 'text2text-generation', 'Xenova/flan-t5-base' );
	console.log( 'Question generation model loaded.' );
	return generator;
}
16
+
17
+ /**
18
+ * Generate a single question from a text passage
19
+ *
20
+ * @param {string} text - The passage to generate a question about
21
+ * @returns {Promise<string|null>} - Generated question or null
22
+ */
23
/**
 * Generate a single question from a text passage.
 *
 * @param {string} text - The passage to generate a question about
 * @returns {Promise<string|null>} - Generated question, or null when the
 *   model output is too short to be useful
 */
async function generateSingleQuestion( text ) {
	// More specific prompt to encourage factual questions
	const prompt = `Ask a specific factual question that can be answered by the following passage: ${ text }`;

	const [ first ] = await generator( prompt, {
		max_new_tokens: 60,
		num_beams: 2,
		do_sample: false
	} );

	const candidate = first.generated_text.trim();

	// Discard trivially short outputs; otherwise normalize the trailing '?'.
	if ( candidate.length <= 10 ) {
		return null;
	}
	return candidate.endsWith( '?' ) ? candidate : `${ candidate }?`;
}
41
+
42
+ /**
43
+ * Group sentences into chunks of N for more context
44
+ *
45
+ * @param {string[]} sentences - Array of sentences
46
+ * @param {number} groupSize - Number of sentences per group
47
+ * @returns {string[]} - Array of grouped sentence strings
48
+ */
49
/**
 * Group sentences into chunks of N for more context.
 *
 * @param {string[]} sentences - Array of sentences
 * @param {number} groupSize - Number of sentences per group
 * @returns {string[]} - Array of grouped sentence strings
 */
function groupSentences( sentences, groupSize = 3 ) {
	const groups = [];
	let start = 0;
	while ( start < sentences.length ) {
		groups.push( sentences.slice( start, start + groupSize ).join( ' ' ) );
		start += groupSize;
	}
	return groups;
}
57
+
58
+ /**
59
+ * Generate questions from a text passage
60
+ *
61
+ * @param {string} text - The passage to generate questions about
62
+ * @param {number} numQuestions - Number of questions to generate (default: 5)
63
+ * @returns {Promise<string[]>} - Array of generated questions
64
+ */
65
/**
 * Generate questions from a text passage.
 *
 * Splits the passage into sentences, groups them into two-sentence
 * passages for context, and asks the model for one question per passage
 * until `numQuestions` unique questions are collected.
 *
 * @param {string} text - The passage to generate questions about
 * @param {number} numQuestions - Number of questions to generate (default: 5)
 * @returns {Promise<string[]>} - Array of generated questions (empty on failure)
 */
export async function generateQuestions( text, numQuestions = 5 ) {
	if ( !generator ) {
		await initQuestionGenerator();
	}

	// Split text into sentences, dropping fragments too short to question.
	const sentences = text
		.split( /(?<=[.!?])\s+/ )
		.filter( ( s ) => s.length > 30 );

	// Group sentences (2-3 at a time) for more context per question
	const passages = groupSentences( sentences, 2 );

	// Only run the model on the leading numQuestions * 2 passages.
	const candidates = passages.slice( 0, Math.min( numQuestions * 2, passages.length ) );

	const questions = [];
	const seen = new Set();

	try {
		for ( const passage of candidates ) {
			if ( questions.length >= numQuestions ) {
				break;
			}

			const question = await generateSingleQuestion( passage );
			if ( question === null ) {
				continue;
			}

			// De-duplicate case-insensitively.
			const key = question.toLowerCase();
			if ( !seen.has( key ) ) {
				seen.add( key );
				questions.push( question );
			}
		}

		return questions;
	} catch ( error ) {
		console.error( 'Question generation failed:', error );
		return [];
	}
}
104
+
105
+ /**
106
+ * Extract lead section text from chunks
107
+ *
108
+ * @param {Array} chunks - Article chunks with sectionTitle
109
+ * @returns {string} - Combined text from the introduction/lead section
110
+ */
111
/**
 * Extract lead section text from chunks.
 *
 * A chunk belongs to the lead when its sectionTitle is 'Introduction'
 * or its sectionId is null.
 *
 * @param {Array} chunks - Article chunks with sectionTitle
 * @returns {string} - Combined text from the introduction/lead section
 */
export function getLeadSectionText( chunks ) {
	const isLead = ( chunk ) =>
		chunk.sectionTitle === 'Introduction' || chunk.sectionId === null;

	// Take up to first 3 paragraphs from the lead
	const texts = [];
	for ( const chunk of chunks ) {
		if ( texts.length === 3 ) {
			break;
		}
		if ( isLead( chunk ) ) {
			texts.push( chunk.text );
		}
	}
	return texts.join( ' ' );
}