namberino commited on
Commit
1950969
·
1 Parent(s): fee5e22

Update generator

Browse files
Files changed (1) hide show
  1. app.py +45 -18
app.py CHANGED
@@ -164,7 +164,6 @@ async def generate_saved_endpoint(
164
  )
165
  questions_list = []
166
  if isinstance(mcqs, dict):
167
- # if dict values are lists or single objects, collect them
168
  for v in mcqs.values():
169
  if isinstance(v, list):
170
  questions_list.extend(v)
@@ -173,13 +172,10 @@ async def generate_saved_endpoint(
173
  elif isinstance(mcqs, list):
174
  questions_list = mcqs
175
  else:
176
- # unexpected shape — skip or handle as needed
177
  continue
178
 
179
- # assign sequential ids (as strings). use `counter` for numeric keys if preferred.
180
  for qobj in questions_list:
181
  if isinstance(qobj, dict):
182
- # optional: keep which difficulty the question came from
183
  qobj["_difficulty"] = difficulty
184
  all_mcqs[str(counter)] = qobj
185
  counter += 1
@@ -208,7 +204,9 @@ async def generate_saved_endpoint(
208
  async def generate_endpoint(
209
  background_tasks: BackgroundTasks,
210
  file: UploadFile = File(...),
211
- n_questions: int = Form(10),
 
 
212
  qdrant_filename: str = Form("default_filename"),
213
  collection_name: str = Form("programming"),
214
  mode: str = Form("rag"),
@@ -244,19 +242,48 @@ async def generate_endpoint(
244
  except Exception as e:
245
  raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
246
 
247
- # generate
248
- try:
249
- mcqs = rag.generate_from_pdf(
250
- tmp_path,
251
- n_questions=n_questions,
252
- mode=mode,
253
- questions_per_page=questions_per_page,
254
- top_k=top_k,
255
- temperature=temperature,
256
- enable_fiddler=enable_fiddler
257
- )
258
- except Exception as e:
259
- raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
  validation_report = None
262
 
 
164
  )
165
  questions_list = []
166
  if isinstance(mcqs, dict):
 
167
  for v in mcqs.values():
168
  if isinstance(v, list):
169
  questions_list.extend(v)
 
172
  elif isinstance(mcqs, list):
173
  questions_list = mcqs
174
  else:
 
175
  continue
176
 
 
177
  for qobj in questions_list:
178
  if isinstance(qobj, dict):
 
179
  qobj["_difficulty"] = difficulty
180
  all_mcqs[str(counter)] = qobj
181
  counter += 1
 
204
  async def generate_endpoint(
205
  background_tasks: BackgroundTasks,
206
  file: UploadFile = File(...),
207
+ n_easy_questions: int = Form(3),
208
+ n_medium_questions: int = Form(5),
209
+ n_hard_questions: int = Form(2),
210
  qdrant_filename: str = Form("default_filename"),
211
  collection_name: str = Form("programming"),
212
  mode: str = Form("rag"),
 
242
  except Exception as e:
243
  raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
244
 
245
+ difficulty_counts = {
246
+ "easy": n_easy_questions,
247
+ "medium": n_medium_questions,
248
+ "hard": n_hard_questions
249
+ }
250
+
251
+ all_mcqs = {}
252
+ counter = 1
253
+
254
+ for difficulty, n_questions in difficulty_counts.items():
255
+ try:
256
+ mcqs = rag.generate_from_pdf(
257
+ tmp_path,
258
+ collection=collection_name,
259
+ n_questions=n_questions,
260
+ mode=mode,
261
+ questions_per_page=questions_per_page,
262
+ top_k=top_k,
263
+ temperature=temperature,
264
+ enable_fiddler=enable_fiddler,
265
+ target_difficulty=difficulty,
266
+ )
267
+ questions_list = []
268
+ if isinstance(mcqs, dict):
269
+ for v in mcqs.values():
270
+ if isinstance(v, list):
271
+ questions_list.extend(v)
272
+ else:
273
+ questions_list.append(v)
274
+ elif isinstance(mcqs, list):
275
+ questions_list = mcqs
276
+ else:
277
+ continue
278
+
279
+ for qobj in questions_list:
280
+ if isinstance(qobj, dict):
281
+ qobj["_difficulty"] = difficulty
282
+ all_mcqs[str(counter)] = qobj
283
+ counter += 1
284
+
285
+ except Exception as e:
286
+ raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
287
 
288
  validation_report = None
289