namberino committed
Commit 4dcccc9 · 1 Parent(s): 9e7752e

Add difficulty distribution

Files changed (4)
  1. app.py +190 -8
  2. generator.py +421 -1
  3. requirements.txt +1 -0
  4. utils.py +155 -8
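
The new /generate_saved_with_difficulty endpoint (added in app.py below) takes per-difficulty question counts as form fields. A minimal client sketch, assuming the server runs locally on port 8000 (the base URL is illustrative, not part of this commit):

import requests

resp = requests.post(
    "http://localhost:8000/generate_saved_with_difficulty",
    data={
        "n_easy_questions": 3,     # difficulty distribution: 3 easy,
        "n_medium_questions": 5,   # 5 medium,
        "n_hard_questions": 2,     # and 2 hard questions
        "qdrant_filename": "default_filename",
        "collection_name": "programming",
        "mode": "rag",
    },
)
resp.raise_for_status()
for key, q in resp.json()["mcqs"].items():
    # each question carries the "_difficulty" tag added by the endpoint
    print(key, q.get("_difficulty"))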
app.py CHANGED
@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
 # Import the user's RAGMCQ implementation
-from generator import RAGMCQ
+from generator import RAGMCQWithDifficulty, RAGMCQ
 from utils import log_pipeline
 
 app = FastAPI(title="RAG MCQ Generator API")
@@ -24,6 +24,7 @@ app.add_middleware(
 
 # global rag instance
 rag: Optional[RAGMCQ] = None
+rag_difficulty: Optional[RAGMCQWithDifficulty] = None
 
 class GenerateResponse(BaseModel):
     mcqs: dict
@@ -34,15 +35,17 @@ class ListResponse(BaseModel):
 
 @app.on_event("startup")
 def startup_event():
+    global rag_difficulty
     global rag
 
     # instantiate the heavy objects once
     rag = RAGMCQ()
+    rag_difficulty = RAGMCQWithDifficulty()
     print("RAGMCQ instance created on startup.")
 
 @app.get("/health")
 def health():
-    return {"status": "ok", "ready": rag is not None}
+    return {"status": "ok", "ready": rag_difficulty is not None and rag is not None}
 
 def _save_upload_to_temp(upload: UploadFile) -> str:
     suffix = ".pdf"
@@ -57,11 +60,11 @@ def _save_upload_to_temp(upload: UploadFile) -> str:
 async def list_collection_files_endpoint(
     collection_name: str = "programming"
 ):
-    global rag
-    if rag is None:
+    global rag_difficulty
+    if rag_difficulty is None:
         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
 
-    files = rag.list_files_in_collection(collection_name)
+    files = rag_difficulty.list_files_in_collection(collection_name)
 
     return {"files": files}
 
@@ -81,8 +84,8 @@ async def upload_multiple_files(
     - overwrite: if true, existing points for each filename will be removed
     - qdrant_filename_prefix: optional prefix; if provided each file will be saved under "<prefix>_<original_filename>"
     """
-    global rag
-    if rag is None:
+    global rag_difficulty
+    if rag_difficulty is None:
         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
 
     saved_files = []
@@ -112,7 +115,7 @@ async def upload_multiple_files(
         )
 
         try:
-            rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
+            rag_difficulty.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
             saved_files.append(qdrant_filename)
         except Exception as e:
             # collect failure info rather than aborting all uploads
@@ -122,6 +125,185 @@ async def upload_multiple_files(
 
 
 
+@app.post("/generate_saved_with_difficulty", response_model=GenerateResponse)
+async def generate_saved_with_difficulty_endpoint(
+    n_easy_questions: int = Form(3),
+    n_medium_questions: int = Form(5),
+    n_hard_questions: int = Form(2),
+    qdrant_filename: str = Form("default_filename"),
+    collection_name: str = Form("programming"),
+    mode: str = Form("rag"),
+    questions_per_chunk: int = Form(3),
+    top_k: int = Form(3),
+    temperature: float = Form(0.2),
+    validate_mcqs: bool = Form(False),
+    enable_fiddler: bool = Form(False),
+):
+    global rag_difficulty
+    if rag_difficulty is None:
+        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+    difficulty_counts = {
+        "easy": n_easy_questions,
+        "medium": n_medium_questions,
+        "hard": n_hard_questions
+    }
+
+    all_mcqs = {}
+    counter = 1
+
+    for difficulty, n_questions in difficulty_counts.items():
+        try:
+            mcqs = rag_difficulty.generate_from_qdrant(
+                filename=qdrant_filename,
+                collection=collection_name,
+                n_questions=n_questions,
+                mode=mode,
+                questions_per_chunk=questions_per_chunk,
+                top_k=top_k,
+                temperature=temperature,
+                enable_fiddler=enable_fiddler,
+                target_difficulty=difficulty,
+            )
+            # normalize the generator output into a flat list of question dicts
+            questions_list = []
+            if isinstance(mcqs, dict):
+                for v in mcqs.values():
+                    if isinstance(v, list):
+                        questions_list.extend(v)
+                    else:
+                        questions_list.append(v)
+            elif isinstance(mcqs, list):
+                questions_list = mcqs
+            else:
+                continue
+
+            for qobj in questions_list:
+                if isinstance(qobj, dict):
+                    qobj["_difficulty"] = difficulty
+                all_mcqs[str(counter)] = qobj
+                counter += 1
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
+
+    validation_report = None
+
+    if validate_mcqs:
+        try:
+            # validate_mcqs expects keys as strings and the normalized content
+            validation_report = rag_difficulty.validate_mcqs(all_mcqs, top_k=top_k)
+        except Exception as e:
+            # don't fail the whole request for a validation error — return generator output and note the error
+            validation_report = {"error": f"Validation failed: {e}"}
+
+    # log_pipeline('test/mcq_output.json', content={"mcqs": all_mcqs, "validation": validation_report})
+
+    return {"mcqs": all_mcqs, "validation": validation_report}
+
+
+@app.post("/generate_with_difficulty", response_model=GenerateResponse)
+async def generate_with_difficulty_endpoint(
+    background_tasks: BackgroundTasks,
+    file: UploadFile = File(...),
+    n_easy_questions: int = Form(3),
+    n_medium_questions: int = Form(5),
+    n_hard_questions: int = Form(2),
+    qdrant_filename: str = Form("default_filename"),
+    collection_name: str = Form("programming"),
+    mode: str = Form("rag"),
+    questions_per_page: int = Form(3),
+    top_k: int = Form(3),
+    temperature: float = Form(0.2),
+    validate_mcqs: bool = Form(False),
+    enable_fiddler: bool = Form(False)
+):
+    global rag_difficulty
+    if rag_difficulty is None:
+        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+    # basic file validation
+    if not file.filename.lower().endswith(".pdf"):
+        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
+
+    # save uploaded file to a temp location
+    tmp_path = _save_upload_to_temp(file)
+
+    # ensure the temp file is removed afterward
+    def _cleanup(path: str):
+        try:
+            os.remove(path)
+        except Exception:
+            pass
+
+    background_tasks.add_task(_cleanup, tmp_path)
+
+    # save pdf
+    try:
+        rag_difficulty.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
+
+    difficulty_counts = {
+        "easy": n_easy_questions,
+        "medium": n_medium_questions,
+        "hard": n_hard_questions
+    }
+
+    all_mcqs = {}
+    counter = 1
+
+    for difficulty, n_questions in difficulty_counts.items():
+        try:
+            mcqs = rag_difficulty.generate_from_pdf(
+                pdf_path=tmp_path,
+                n_questions=n_questions,
+                mode=mode,
+                questions_per_page=questions_per_page,
+                top_k=top_k,
+                temperature=temperature,
+                enable_fiddler=enable_fiddler,
+                target_difficulty=difficulty,
+            )
+            # normalize the generator output into a flat list of question dicts
+            questions_list = []
+            if isinstance(mcqs, dict):
+                for v in mcqs.values():
+                    if isinstance(v, list):
+                        questions_list.extend(v)
+                    else:
+                        questions_list.append(v)
+            elif isinstance(mcqs, list):
+                questions_list = mcqs
+            else:
+                continue
+
+            for qobj in questions_list:
+                if isinstance(qobj, dict):
+                    qobj["_difficulty"] = difficulty
+                all_mcqs[str(counter)] = qobj
+                counter += 1
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Generation from file failed: {e}")
+
+    validation_report = None
+
+    if validate_mcqs:
+        try:
+            # rag.build_index_from_pdf(tmp_path)
+            # validate_mcqs expects keys as strings and the normalized content
+            validation_report = rag_difficulty.validate_mcqs(all_mcqs, top_k=top_k)
+        except Exception as e:
+            # don't fail the whole request for a validation error — return generator output and note the error
+            validation_report = {"error": f"Validation failed: {e}"}
+
+    # log_pipeline('test/mcq_output.json', content={"mcqs": all_mcqs, "validation": validation_report})
+
+    return {"mcqs": all_mcqs, "validation": validation_report}
+
+
 @app.post("/generate_saved", response_model=GenerateResponse)
 async def generate_saved_endpoint(
     n_questions: int = Form(10),
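
For reference, the merged payload returned by the two new endpoints has roughly this shape (a hand-written illustration, not captured server output):

example_response = {
    "mcqs": {
        "1": {
            "câu hỏi": "...",
            "lựa chọn": {"a": "...", "b": "...", "c": "...", "d": "..."},
            "đáp án": "...",
            "khái niệm sử dụng": {"khái niệm chính": ["khái niệm phụ"]},
            "độ khó": {"điểm": 0.31, "mức độ": "dễ"},  # score/label from the difficulty estimator
            "_difficulty": "easy",                     # target difficulty of the generation pass
        },
        # "2", "3", ... continue with sequential string keys across all passes
    },
    "validation": None,  # or a report dict when validate_mcqs=True
}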
generator.py CHANGED
@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
 from transformers import pipeline
 from uuid import uuid4
 import pymupdf4llm
+from typing_extensions import override
 
 try:
     from qdrant_client import QdrantClient
@@ -31,7 +32,7 @@ try:
 except Exception:
     _HAS_FAISS = False
 
-from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local
+from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local, structure_context_for_llm, new_generate_mcqs_from_text
 
 from huggingface_hub import login
 login(token=os.environ['HF_MODEL_TOKEN'])
@@ -1074,3 +1075,422 @@ class RAGMCQ:
             label = "khó"
 
         return score, label
+
+class RAGMCQWithDifficulty(RAGMCQ):
+    def __init__(
+        self,
+        embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+        generation_model: str = "openai/gpt-oss-120b",
+        qdrant_url: str = os.environ.get('QDRANT_URL') or "",
+        qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
+        qdrant_prefer_grpc: bool = False,
+    ):
+        super().__init__(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
+
+    @override
+    def generate_from_pdf(
+        self,
+        pdf_path: str,
+        n_questions: int = 10,
+        mode: str = "rag",  # 'per_page' or 'rag'
+        questions_per_page: int = 3,  # for per_page mode
+        top_k: int = 3,  # chunks to retrieve for each question in rag mode
+        temperature: float = 0.2,
+        enable_fiddler: bool = False,
+        target_difficulty: str = 'easy'  # 'easy', 'medium', or 'hard'
+    ) -> Dict[str, Any]:
+        # build index
+        self.build_index_from_pdf(pdf_path)
+
+        output: Dict[str, Any] = {}
+        qcount = 0
+
+        if mode == "per_page":
+            # iterate pages -> chunks
+            for idx, meta in enumerate(self.metadata):
+                chunk_text = self.texts[idx]
+
+                if not chunk_text.strip():
+                    continue
+                to_gen = questions_per_page
+
+                # ask generator
+                try:
+                    structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = generate_mcqs_from_text(
+                        source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
+                    )
+                except Exception as e:
+                    # skip this chunk if generator fails
+                    print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
+                    continue
+
+                if "error" in mcq_block:
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            return output
+
+        # RAG-style generation from the PDF index
+        elif mode == "rag":
+            # strategy: create a few natural short queries by sampling sentences
+            # from chunks; stop when n_questions is reached or max_attempts exceeded
+            attempts = 0
+            max_attempts = n_questions * 4
+
+            while qcount < n_questions and attempts < max_attempts:
+                attempts += 1
+                # create a seed query: pick a random chunk, then a sentence from it
+                seed_idx = random.randrange(len(self.texts))
+                chunk = self.texts[seed_idx]
+
+                #? investigate a better chunking strategy
+                # with open("chunks.txt", "a", encoding="utf-8") as f:
+                #     f.write(chunk + "\n")
+
+                sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                candidates = [s for s in sents if len(s.strip()) > 20]
+                seed_sent = random.choice(candidates) if candidates else chunk[:200]
+                query = f"Create questions about: {seed_sent}"
+
+                # retrieve top_k chunks
+                retrieved = self._retrieve(query, top_k=top_k)
+                context_parts = []
+                for ridx, score in retrieved:
+                    md = self.metadata[ridx]
+                    context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
+                context = "\n\n".join(context_parts)
+
+                # save_to_local('test/context.md', content=context)
+
+                # call the generator with the retrieved context, in small batches to keep diversity
+                try:
+                    structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
+                except Exception as e:
+                    print(f"Generator failed during RAG attempt {attempts}: {e}")
+                    continue
+
+                if "error" in mcq_block:
+                    return output
+
+                # append result(s)
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    payload = mcq_block[item]
+                    q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
+                    options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
+                    if isinstance(options, list):
+                        options = {str(i+1): o for i, o in enumerate(options)}
+                    correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
+                    concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
+                    if isinstance(correct_key, str) and correct_key.strip() in options:
+                        correct_text = options[correct_key.strip()]
+                    else:
+                        correct_text = payload.get("correct_text") or correct_key or ""
+
+                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(
+                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=str(structured_context), concepts_used=concepts
+                    )
+
+                    payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
+
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            return output
+        else:
+            raise ValueError("mode must be 'per_page' or 'rag'.")
+
+    @override
+    def generate_from_qdrant(
+        self,
+        filename: str,
+        collection: str,
+        n_questions: int = 10,
+        mode: str = "rag",  # 'per_chunk' or 'rag'
+        questions_per_chunk: int = 3,  # used for 'per_chunk'
+        top_k: int = 3,  # retrieval size used in RAG
+        temperature: float = 0.2,
+        enable_fiddler: bool = False,
+        target_difficulty: str = 'easy',
+    ) -> Dict[str, Any]:
+        if self.qdrant is None:
+            raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+        # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
+        file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
+        if not file_points:
+            raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
+
+        # create a local list of texts & metadata for sampling
+        texts = []
+        metas = []
+        for p in file_points:
+            payload = p.get("payload", {})
+            text = payload.get("text", "")
+            texts.append(text)
+            metas.append(payload)
+
+        self.texts = texts
+        self.metadata = metas
+        embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
+        if embeddings is None or len(embeddings) == 0:
+            self.embeddings = None
+            self.index = None
+        else:
+            self.embeddings = embeddings.astype("float32")
+
+            # update dim in case embedder changed unexpectedly
+            self.dim = int(self.embeddings.shape[1])
+
+            # build index
+            self._build_faiss_index()
+
+        output = {}
+        qcount = 0
+
+        if mode == "per_chunk":
+            # iterate all chunks (in payload order) and request questions_per_chunk from each
+            for i, txt in enumerate(texts):
+                if not txt.strip():
+                    continue
+                to_gen = questions_per_chunk
+                try:
+                    mcq_block = new_generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
+                except Exception as e:
+                    print(f"Generator failed on chunk (index {i}): {e}")
+                    continue
+
+                if "error" in mcq_block:
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+            return output
+
+        elif mode == "rag":
+            attempts = 0
+            max_attempts = n_questions * 4
+            while qcount < n_questions and attempts < max_attempts:
+                attempts += 1
+                # create a seed query: pick a random chunk, then a sentence from it
+                seed_idx = random.randrange(len(self.texts))
+                chunk = self.texts[seed_idx]
+                sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                candidate = [s for s in sents if len(s.strip()) > 20]
+                if candidate:
+                    seed_sent = random.choice(candidate)
+                else:
+                    stripped = chunk.strip()
+                    seed_sent = (stripped[:200] if stripped else "[no text available]")
+                query = f"Create questions about: {seed_sent}"
+
+                # retrieve top_k chunks from the same file (restricted by filename filter)
+                retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
+                print('retrieved qdrant', retrieved)
+                context_parts = []
+                for payload, score in retrieved:
+                    # payload should contain page, chunk_id, and text
+                    page = payload.get("page", "?")
+                    ctxt = payload.get("text", "")
+                    context_parts.append(f"[page {page}] {ctxt}")
+                context = "\n\n".join(context_parts)
+
+                # question generation
+                try:
+                    # difficulty pipeline: easy, medium, hard
+                    structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
+                except Exception as e:
+                    print(f"Generator failed during RAG attempt {attempts}: {e}")
+                    continue
+
+                if "error" in mcq_block:
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    payload = mcq_block[item]
+                    q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
+                    options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
+
+                    if isinstance(options, list):
+                        options = {str(i+1): o for i, o in enumerate(options)}
+
+                    correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
+                    concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
+
+                    if isinstance(correct_key, str) and correct_key.strip() in options:
+                        correct_text = options[correct_key.strip()]
+                    else:
+                        correct_text = payload.get("correct_text") or correct_key or ""
+
+                    #? change estimate
+                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(
+                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=str(structured_context), concepts_used=concepts
+                    )
+
+                    payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
+
+                    # if fewer questions remain than the default batch size
+                    # (e.g. 5 - 3 = 2 < 3), only request that many on the next attempt
+                    if n_questions - qcount < questions_per_chunk:
+                        questions_per_chunk = n_questions - qcount
+
+                    qcount += 1  # count the accepted question
+                    print('qcount:', qcount)
+                    print('questions_per_chunk:', questions_per_chunk)
+
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            print("output available")
+            return output
+        else:
+            raise ValueError("mode must be 'per_chunk' or 'rag'.")
+
+    @override
+    def _estimate_difficulty_for_generation(
+        self,
+        q_text: str,
+        options: Dict[str, str],
+        correct_text: str,
+        context_text: str = "",
+        concepts_used: Dict = {}
+    ) -> Tuple[float, str, Dict[str, Any]]:
+        def safe_map_sim(s):
+            # map a potentially [-1, 1] cosine-like score to [0, 1], with clamping
+            try:
+                s = float(s)
+            except Exception:
+                return 0.0
+            mapped = (s + 1.0) / 2.0
+            return max(0.0, min(1.0, mapped))
+
+        # embedding support: how well the statement is supported by the indexed text
+        emb_support = 0.0
+        try:
+            stmt = (q_text or "").strip()
+            if correct_text:
+                stmt = f"{stmt} Answer: {correct_text}"
+
+            # use internal retrieve but map the returned score
+            res = []
+            try:
+                res = self._retrieve(stmt, top_k=1)
+            except Exception:
+                res = []
+
+            if res:
+                raw_score = float(res[0][1])
+                emb_support = safe_map_sim(raw_score)
+            else:
+                emb_support = 0.0
+        except Exception:
+            emb_support = 0.0
+
+        # distractor similarities
+        mean_sim = 0.0
+        distractor_penalty = 0.0
+        amb_flag = 0.0
+        gap = 0.0
+        try:
+            keys = list(options.keys())
+            texts = [options[k] for k in keys]
+            if correct_text is None:
+                correct_text = ""
+
+            all_texts = [correct_text] + texts
+            embs = self.embedder.encode(all_texts, convert_to_numpy=True)
+            embs = np.asarray(embs, dtype=float)
+            norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
+            embs = embs / norms
+            corr = embs[0]
+            opts = embs[1:]
+
+            if opts.size == 0:
+                mean_sim = 0.0
+                distractor_penalty = 0.0
+                gap = 0.0
+            else:
+                sims = (opts @ corr).tolist()  # [-1, 1]
+                sims_mapped = [safe_map_sim(s) for s in sims]  # [0, 1]
+                mean_sim = float(sum(sims_mapped) / len(sims_mapped))
+                # gap between best distractor and second best (higher gap -> easier)
+                sorted_s = sorted(sims_mapped, reverse=True)
+                top = sorted_s[0]
+                second = sorted_s[1] if len(sorted_s) > 1 else 0.0
+                gap = top - second
+                # penalties: if distractors are extremely close to the correct answer -> higher penalty
+                too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
+                too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
+                distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
+                amb_flag = 1.0 if top >= 0.8 else 0.0
+        except Exception:
+            mean_sim = 0.0
+            distractor_penalty = 0.0
+            amb_flag = 0.0
+            gap = 0.0
+
+        # question length, normalized
+        question_len = len((q_text or "").strip())
+        question_len_norm = min(1.0, question_len / 300.0)
+
+        # count the number of main concepts used
+        concepts_num = len((concepts_used or {}).keys())
+        if concepts_num < 2:
+            concepts_penalty = 0
+        else:
+            concepts_penalty = concepts_num
+
+        # combine signals (higher score -> harder):
+        # similar distractors and ambiguity add difficulty,
+        # a clear gap between the best and second-best distractor subtracts
+        score = 0
+        score += 0.35 * float(distractor_penalty)
+        score += 0.20 * float(mean_sim)
+        score += 0.22 * float(amb_flag)
+        score += 0.05 * float(question_len_norm)
+        score -= 0.20 * float(gap)
+
+        # clamp
+        score = max(0.0, min(1.0, float(score)))
+        # the component breakdown mirrors the weights actually used above;
+        # emb_support and concepts_num are reported as raw diagnostics
+        components = {
+            "distractor_penalty": 0.35 * float(distractor_penalty),
+            "mean_sim": 0.20 * float(mean_sim),
+            "amb_flag": 0.22 * float(amb_flag),
+            "question_len_norm": 0.05 * float(question_len_norm),
+            "gap": -0.20 * float(gap),
+            "emb_support": float(emb_support),
+            "concepts_num": float(concepts_num),
+            "total_score": score,
+        }
+
+        # label
+        if score <= 0.35:
+            label = "dễ"
+        elif score <= 0.65:
+            label = "trung bình"
+        else:
+            label = "khó"
+
+        return score, label, components
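
As a sanity check on the weighted sum in _estimate_difficulty_for_generation, a worked example with made-up component values:

# made-up inputs, mirroring the weights used in the method above
distractor_penalty = 0.6   # distractors fairly close to the correct answer
mean_sim = 0.7             # mean mapped cosine similarity, in [0, 1]
amb_flag = 0.0             # best distractor below the 0.8 ambiguity threshold
question_len_norm = 0.5    # a 150-character question, normalized by 300
gap = 0.3                  # best distractor clearly ahead of the second best

score = (0.35 * distractor_penalty   # 0.210
         + 0.20 * mean_sim           # 0.140
         + 0.22 * amb_flag           # 0.000
         + 0.05 * question_len_norm  # 0.025
         - 0.20 * gap)               # -0.060
score = max(0.0, min(1.0, score))
print(score)  # 0.315 -> labeled "dễ" (easy), since score <= 0.35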
requirements.txt CHANGED
@@ -8,3 +8,4 @@ qdrant-client
 pymupdf4llm
 uuid
 huggingface_hub
+typing_extensions
utils.py CHANGED
@@ -53,12 +53,6 @@ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: in
     resp.raise_for_status()
     data = resp.json()
 
-    #save_to_local('test/raw_resp.json', content=data)
-
-    #? Must update within _post_chat because it the original function for LLM generation
-    update_token_count(token_usage=data['usage']) # get data['usages']['prompt_tokens'] & data['usages']['completion_tokens']
-    # update_time_info(time_info=data['time_info'])
-
     # handle various shapes
     if "choices" in data and len(data["choices"]) > 0:
         # prefer message.content
@@ -70,7 +64,6 @@ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: in
             if "text" in ch:
                 return ch["text"]
 
-    print(f'Generation Time: {data["time_info"]}')
     # final fallback
     raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
 
@@ -92,10 +85,164 @@ def _safe_extract_json(text: str) -> dict:
     return json.loads(fixed)
 
 
+def structure_context_for_llm(
+    source_text: str,
+    model: str = "openai/gpt-oss-120b",
+    temperature: float = 0.2,
+    enable_fiddler: bool = False,
+) -> Dict[str, Any]:
+    """
+    Take a long source_text, split it into chunks, and restructure them
+    so each chunk is self-contained, structured, and semantically meaningful.
+    """
+
+    # the system prompt is in Vietnamese: it asks the model to split texts
+    # longer than 500 words into two self-contained chunks, use each chunk's
+    # main topic as the JSON key, and fill the fields 'đoạn văn' (structured
+    # passage), 'khái niệm chính' (main concepts with supporting sub-concepts),
+    # 'công thức' (formulas), 'ví dụ' (examples), and 'tóm tắt' (summary),
+    # returning exactly one valid JSON object and nothing else
+    system_message = {
+        "role": "system",
+        "content": (
+            "Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt.\n"
+            "Nhiệm vụ của bạn là:\n"
+            "- Nếu văn bản dài trên 500 từ, chia văn bản thành 2 đoạn (chunk) có ý nghĩa rõ ràng.\n"
+            "- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
+            "- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
+            "- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
+            "  - 'đoạn văn': nội dung gốc đã cấu trúc đầy đủ\n"
+            "  - 'khái niệm chính': từ điển chứa các khái niệm chính với khái niệm phụ hỗ trợ khái niệm chính đi kèm nếu có\n"
+            "  - 'công thức': danh sách công thức (nếu có)\n"
+            "  - 'ví dụ': ví dụ minh họa (nếu có)\n"
+            "  - 'tóm tắt': tóm tắt nội dung, dễ hiểu\n"
+            "- Giữ ngữ nghĩa liền mạch.\n"
+            "- Chỉ TRẢ VỀ MỘT JSON hợp lệ theo schema, không kèm văn bản khác.\n\n"
+
+            "Chỉ TRẢ VỀ duy nhất MỘT đối tượng JSON theo schema sau và không có bất kỳ văn bản nào khác:\n\n"
+            "{\n"
+            '  "Tên topic": {"đoạn văn": "nội dung đã cấu trúc của topic 1", "khái niệm chính": {"khái niệm chính 1": ["khái niệm phụ", "..."], "khái niệm chính 2": ["khái niệm phụ", "..."]}, "công thức": ["..."], "ví dụ": ["..."], "tóm tắt": "tóm tắt ngắn gọn"}\n'
+            "}\n"
+        )
+    }
+
+    user_message = {
+        "role": "user",
+        "content": (
+            "Hãy chia văn bản sau thành nhiều chunk theo hướng dẫn trên và xuất JSON hợp lệ.\n"
+            f"### Văn bản nguồn:\n{source_text}"
+        )
+    }
+
+    if enable_fiddler:
+        max_conf, max_cat = text_safety_check(user_message['content'])
+        if max_conf > 0.5:
+            print(f"Harmful content detected: ({max_cat} : {max_conf})")
+            return {}
+
+    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
+    parsed = _safe_extract_json(raw)
+    if not isinstance(parsed, dict):
+        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
+    return parsed
+
+
+def new_generate_mcqs_from_text(
+    source_text: str,
+    n: int = 3,
+    model: str = "openai/gpt-oss-120b",
+    temperature: float = 0.2,
+    enable_fiddler: bool = False,
+    target_difficulty: str = "easy",
+) -> Dict[str, Any]:
+    # how many main concepts a question at each difficulty should draw on
+    expected_concepts = {
+        "easy": 1,
+        "medium": 2,
+        "hard": (3, 4)
+    }
+    if isinstance(expected_concepts[target_difficulty], tuple):
+        min_concepts, max_concepts = expected_concepts[target_difficulty]
+        concept_range = f"{min_concepts}-{max_concepts}"
+    else:
+        concept_range = expected_concepts[target_difficulty]
+
+    # Vietnamese difficulty criteria: easy questions test a single basic
+    # concept answerable directly from the text, with clearly distinct
+    # distractors and a short stem; medium questions test a main concept with
+    # moderately similar distractors and room for inference; hard questions
+    # require analysis or synthesis, with at least two distractors close to
+    # the correct answer and a longer stem
+    difficulty_prompts = {
+        "easy": (
+            "- Câu hỏi DỄ: kiểm tra duy nhất 1 khái niệm chính cơ bản dễ hiểu, định nghĩa, hoặc công thức đơn giản. "
+            "- Đáp án có thể tìm thấy trực tiếp trong văn bản. "
+            "- Ngữ cảnh đủ để hiểu khái niệm chính. "
+            "- Distractors khác biệt rõ ràng, dễ loại bỏ. "
+            "- Độ dài câu hỏi ngắn gọn không quá 10-20 từ hoặc ít hơn 120 ký tự, tập trung vào một ý duy nhất.\n"
+        ),
+        "medium": (
+            "- Câu hỏi TRUNG BÌNH kiểm tra khái niệm chính trong văn bản. "
+            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn. "
+            "- Các distractors không quá giống nhau. "
+            "- Độ dài câu hỏi vừa phải khoảng 23-30 từ hoặc khoảng 150-180 ký tự, có thêm chi tiết phụ để suy luận.\n"
+        ),
+        "hard": (
+            "- Câu hỏi KHÓ kiểm tra thông tin được phân tích/tổng hợp. "
+            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn. "
+            "- Ít nhất 2 distractors gần giống đáp án đúng, độ tương đồng cao. "
+            "- Đáp án yêu cầu học sinh suy luận hoặc áp dụng công thức vào ví dụ nếu có. "
+            "- Độ dài câu hỏi dài hơn 35 từ hoặc hơn 200 ký tự.\n\n"
+        )
+    }
+
+    difficult_criteria = difficulty_prompts[target_difficulty]  # "easy", "medium", or "hard"
+    print(concept_range)
+
+    # the system prompt instructs the model (in Vietnamese) to generate MCQs
+    # matching the difficulty criteria, use exactly concept_range main concepts,
+    # ground everything in the source text, and return one JSON object where
+    # 'đáp án' is the full text of the correct option
+    system_message = {
+        "role": "system",
+        "content": (
+            "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm (MCQ). Luôn trả lời bằng tiếng Việt. "
+            f"Đảm bảo chỉ tạo sinh câu trắc nghiệm có độ khó sau {difficult_criteria} "
+            f"Quan trọng: Mỗi câu hỏi chỉ sử dụng chính xác {concept_range} khái niệm chính (mỗi khái niệm chính có 1 danh sách khái niệm phụ) từ văn bản nguồn. "
+            "Mỗi câu hỏi và đáp án phải dựa trên thông tin từ văn bản nguồn. Không được đưa kiến thức ngoài vào. "
+            "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không kèm giải thích hay trường thêm:\n\n"
+            "{\n"
+            '  "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"...", "khái niệm sử dụng": {"khái niệm chính": ["khái niệm phụ", "..."]}},\n'
+            '  "2": { ... }\n'
+            "}\n\n"
+            "Lưu ý:\n"
+            f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
+            "- Khóa 'lựa chọn' phải có các khóa a, b, c, d.\n"
+            "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
+            "- Toàn bộ thông tin cần thiết để trả lời phải nằm trong chính câu hỏi, không tham chiếu lại văn bản nguồn. "
+            f"- Sử dụng chính xác {concept_range} khái niệm chính."
+        )
+    }
+
+    user_message = {
+        "role": "user",
+        "content": (
+            f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Chỉ sử dụng nội dung này làm nguồn duy nhất để xây dựng câu hỏi.\n\n"
+            "### Yêu cầu:\n"
+            "- Bám sát vào thông tin trong văn bản; không thêm kiến thức ngoài.\n"
+            "- Nếu văn bản thiếu chi tiết, hãy tạo phương án nhiễu (distractors) hợp lý, nhưng phải có thể biện minh từ nội dung hoặc ngữ cảnh.\n"
+            f"### Văn bản nguồn:\n{source_text}"
+        )
+    }
+
+    if enable_fiddler:
+        max_conf, max_cat = text_safety_check(user_message['content'])
+        if max_conf > 0.5:
+            print(f"Harmful content detected: ({max_cat} : {max_conf})")
+            return {}
+
+    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
+    # print('\n\n', raw)
+    parsed = _safe_extract_json(raw)
+    # basic validation
+    if not isinstance(parsed, dict) or len(parsed) != n:
+        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
+    return parsed
+
+
 def generate_mcqs_from_text(
     source_text: str,
     n: int = 3,
-    model: str = "gpt-oss-120b",
+    model: str = "openai/gpt-oss-120b",
     temperature: float = 0.2,
     enable_fiddler: bool = False,
 ) -> Dict[str, Any]:
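
Taken together, the two new helpers form a small pipeline: structure the raw context first, then generate difficulty-targeted questions from the structured form. A minimal usage sketch, mirroring the call sites in generator.py (the input text is illustrative, and the HF credentials that _post_chat relies on are assumed to be configured):

source = "..."  # any Vietnamese source passage, e.g. a retrieved chunk

structured = structure_context_for_llm(source)  # topic-keyed dict
mcqs = new_generate_mcqs_from_text(
    structured,                 # interpolated into the user prompt, as in generator.py
    n=3,
    target_difficulty="hard",   # one of "easy" / "medium" / "hard"
)
for k in sorted(mcqs, key=int):
    print(k, mcqs[k]["câu hỏi"])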