namberino committed on
Commit
a9fb115
·
1 Parent(s): ba3f5c1

Testing new rag difficulty generator

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. generator.py +387 -17
app.py CHANGED
@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
 
10
  # Import the user's RAGMCQ implementation
11
- from generator import RAGMCQ
12
  from utils import log_pipeline
13
 
14
  app = FastAPI(title="RAG MCQ Generator API")
@@ -23,7 +23,7 @@ app.add_middleware(
23
  )
24
 
25
  # global rag instance
26
- rag: Optional[RAGMCQ] = None
27
 
28
  class GenerateResponse(BaseModel):
29
  mcqs: dict
@@ -37,7 +37,7 @@ def startup_event():
37
  global rag
38
 
39
  # instantiate the heavy object once
40
- rag = RAGMCQ()
41
  print("RAGMCQ instance created on startup.")
42
 
43
  @app.get("/health")
 
8
  from pydantic import BaseModel
9
 
10
  # Import the user's RAGMCQ implementation
11
+ from generator import RAGMCQWithDifficulty
12
  from utils import log_pipeline
13
 
14
  app = FastAPI(title="RAG MCQ Generator API")
 
23
  )
24
 
25
  # global rag instance
26
+ rag: Optional[RAGMCQWithDifficulty] = None
27
 
28
  class GenerateResponse(BaseModel):
29
  mcqs: dict
 
37
  global rag
38
 
39
  # instantiate the heavy object once
40
+ rag = RAGMCQWithDifficulty()
41
  print("RAGMCQ instance created on startup.")
42
 
43
  @app.get("/health")
generator.py CHANGED
@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
9
  from transformers import pipeline
10
  from uuid import uuid4
11
  import pymupdf4llm
 
12
 
13
  try:
14
  from qdrant_client import QdrantClient
@@ -31,7 +32,7 @@ try:
31
  except Exception:
32
  _HAS_FAISS = False
33
 
34
- from utils import generate_mcqs_from_text, new_generate_mcqs_from_text, structure_context_for_llm
35
 
36
  from huggingface_hub import login
37
  login(token=os.environ['HF_MODEL_TOKEN'])
@@ -128,7 +129,6 @@ class RAGMCQ:
128
 
129
  return final
130
 
131
-
132
  def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
133
  pages = self.extract_pages(pdf_path)
134
 
@@ -188,7 +188,6 @@ class RAGMCQ:
188
  top_k: int = 3, # chunks to retrieve for each question in rag mode
189
  temperature: float = 0.2,
190
  enable_fiddler: bool = False,
191
- target_difficulty: str = 'easy' # easy, mid, difficult
192
  ) -> Dict[str, Any]:
193
  # build index
194
  self.build_index_from_pdf(pdf_path)
@@ -207,9 +206,8 @@ class RAGMCQ:
207
 
208
  # ask generator
209
  try:
210
- structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
211
  mcq_block = generate_mcqs_from_text(
212
- source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
213
  )
214
  except Exception as e:
215
  # skip this chunk if generator fails
@@ -227,7 +225,6 @@ class RAGMCQ:
227
 
228
  return output
229
 
230
- # pdf gene
231
  elif mode == "rag":
232
  # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
233
  # create queries by sampling chunk text sentences.
@@ -262,8 +259,9 @@ class RAGMCQ:
262
  # call generator for 1 question (or small batch) with the retrieved context
263
  try:
264
  # request 1 question at a time to keep diversity
265
- structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
266
- mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
 
267
  except Exception as e:
268
  print(f"Generator failed during RAG attempt {attempts}: {e}")
269
  continue
@@ -279,18 +277,16 @@ class RAGMCQ:
279
  if isinstance(options, list):
280
  options = {str(i+1): o for i, o in enumerate(options)}
281
  correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
282
- concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
283
  correct_text = ""
284
  if isinstance(correct_key, str) and correct_key.strip() in options:
285
  correct_text = options[correct_key.strip()]
286
  else:
287
  correct_text = payload.get("correct_text") or correct_key or ""
288
-
289
- diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
290
- q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
291
- )
292
 
293
- payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
 
 
 
294
 
295
  qcount += 1
296
  output[str(qcount)] = mcq_block[item]
@@ -840,6 +836,381 @@ class RAGMCQ:
840
  return out
841
 
842
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
843
  def generate_from_qdrant(
844
  self,
845
  filename: str,
@@ -994,7 +1365,7 @@ class RAGMCQ:
994
  else:
995
  raise ValueError("mode must be 'per_chunk' or 'rag'.")
996
 
997
-
998
  def _estimate_difficulty_for_generation(
999
  self,
1000
  q_text: str,
@@ -1092,7 +1463,6 @@ class RAGMCQ:
1092
  # higher distractor_penalty -> harder (add)
1093
  # better gap -> easier (subtract)
1094
  # compute score (higher -> harder)
1095
- # score += 0.12 * float(concepts_penalty)
1096
 
1097
  score = 0
1098
  score += 0.35 * float(distractor_penalty)
@@ -1123,4 +1493,4 @@ class RAGMCQ:
1123
  else:
1124
  label = "khó"
1125
 
1126
- return score, label, components # type: ignore
 
9
  from transformers import pipeline
10
  from uuid import uuid4
11
  import pymupdf4llm
12
+ from typing import override
13
 
14
  try:
15
  from qdrant_client import QdrantClient
 
32
  except Exception:
33
  _HAS_FAISS = False
34
 
35
+ from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local, structure_context_for_llm, new_generate_mcqs_from_text
36
 
37
  from huggingface_hub import login
38
  login(token=os.environ['HF_MODEL_TOKEN'])
 
129
 
130
  return final
131
 
 
132
  def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
133
  pages = self.extract_pages(pdf_path)
134
 
 
188
  top_k: int = 3, # chunks to retrieve for each question in rag mode
189
  temperature: float = 0.2,
190
  enable_fiddler: bool = False,
 
191
  ) -> Dict[str, Any]:
192
  # build index
193
  self.build_index_from_pdf(pdf_path)
 
206
 
207
  # ask generator
208
  try:
 
209
  mcq_block = generate_mcqs_from_text(
210
+ chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
211
  )
212
  except Exception as e:
213
  # skip this chunk if generator fails
 
225
 
226
  return output
227
 
 
228
  elif mode == "rag":
229
  # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
230
  # create queries by sampling chunk text sentences.
 
259
  # call generator for 1 question (or small batch) with the retrieved context
260
  try:
261
  # request 1 question at a time to keep diversity
262
+ mcq_block = generate_mcqs_from_text(
263
+ context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
264
+ )
265
  except Exception as e:
266
  print(f"Generator failed during RAG attempt {attempts}: {e}")
267
  continue
 
277
  if isinstance(options, list):
278
  options = {str(i+1): o for i, o in enumerate(options)}
279
  correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
 
280
  correct_text = ""
281
  if isinstance(correct_key, str) and correct_key.strip() in options:
282
  correct_text = options[correct_key.strip()]
283
  else:
284
  correct_text = payload.get("correct_text") or correct_key or ""
 
 
 
 
285
 
286
+ diff_score, diff_label = self._estimate_difficulty_for_generation(
287
+ q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
288
+ )
289
+ payload["difficulty"] = {"score": diff_score, "label": diff_label}
290
 
291
  qcount += 1
292
  output[str(qcount)] = mcq_block[item]
 
836
  return out
837
 
838
 
839
+ def generate_from_qdrant(
840
+ self,
841
+ filename: str,
842
+ collection: str,
843
+ n_questions: int = 10,
844
+ mode: str = "rag", # 'per_chunk' or 'rag'
845
+ questions_per_chunk: int = 3, # used for 'per_chunk'
846
+ top_k: int = 3, # retrieval size used in RAG
847
+ temperature: float = 0.2,
848
+ enable_fiddler: bool = False,
849
+ ) -> Dict[str, Any]:
850
+ if self.qdrant is None:
851
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
852
+
853
+ # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
854
+ file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
855
+ if not file_points:
856
+ raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
857
+
858
+ # create a local list of texts & metadata for sampling
859
+ texts = []
860
+ metas = []
861
+ for p in file_points:
862
+ payload = p.get("payload", {})
863
+ text = payload.get("text", "")
864
+ texts.append(text)
865
+ metas.append(payload)
866
+
867
+ self.texts = texts
868
+ self.metadata = metas
869
+ embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
870
+ if embeddings is None or len(embeddings) == 0:
871
+ self.embeddings = None
872
+ self.index = None
873
+ else:
874
+ self.embeddings = embeddings.astype("float32")
875
+
876
+ # update dim in case embedder changed unexpectedly
877
+ self.dim = int(self.embeddings.shape[1])
878
+
879
+ # build index
880
+ self._build_faiss_index()
881
+
882
+ output = {}
883
+ qcount = 0
884
+
885
+ if mode == "per_chunk":
886
+ # iterate all chunks (in payload order) and request questions_per_chunk from each
887
+ for i, txt in enumerate(texts):
888
+ if not txt.strip():
889
+ continue
890
+ to_gen = questions_per_chunk
891
+ try:
892
+ mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
893
+ except Exception as e:
894
+ print(f"Generator failed on chunk (index {i}): {e}")
895
+ continue
896
+
897
+ if "error" in list(mcq_block.keys()):
898
+ return output
899
+
900
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
901
+ qcount += 1
902
+ output[str(qcount)] = mcq_block[item]
903
+ if qcount >= n_questions:
904
+ return output
905
+ return output
906
+
907
+ elif mode == "rag":
908
+ attempts = 0
909
+ max_attempts = n_questions * 4
910
+ while qcount < n_questions and attempts < max_attempts:
911
+ attempts += 1
912
+ # create a seed query: pick a random chunk, pick a sentence from it
913
+ seed_idx = random.randrange(len(self.texts))
914
+ chunk = self.texts[seed_idx]
915
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
916
+ candidate = [s for s in sents if len(s.strip()) > 20]
917
+ if candidate:
918
+ seed_sent = random.choice(candidate)
919
+ else:
920
+ stripped = chunk.strip()
921
+ seed_sent = (stripped[:200] if stripped else "[no text available]")
922
+ query = f"Create questions about: {seed_sent}"
923
+
924
+
925
+ # retrieve top_k chunks from the same file (restricted by filename filter)
926
+ retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
927
+ context_parts = []
928
+ for payload, score in retrieved:
929
+ # payload should contain page & chunk_id and text
930
+ page = payload.get("page", "?")
931
+ ctxt = payload.get("text", "")
932
+ context_parts.append(f"[page {page}] {ctxt}")
933
+ context = "\n\n".join(context_parts)
934
+
935
+ try:
936
+ mcq_block = generate_mcqs_from_text(context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
937
+ except Exception as e:
938
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
939
+ continue
940
+
941
+ if "error" in list(mcq_block.keys()):
942
+ return output
943
+
944
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
945
+ payload = mcq_block[item]
946
+ q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
947
+ options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
948
+ if isinstance(options, list):
949
+ options = {str(i+1): o for i, o in enumerate(options)}
950
+ correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
951
+ correct_text = ""
952
+ if isinstance(correct_key, str) and correct_key.strip() in options:
953
+ correct_text = options[correct_key.strip()]
954
+ else:
955
+ correct_text = payload.get("correct_text") or correct_key or ""
956
+
957
+ diff_score, diff_label = self._estimate_difficulty_for_generation(
958
+ q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
959
+ )
960
+ payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
961
+
962
+ qcount += 1
963
+ output[str(qcount)] = mcq_block[item]
964
+ if qcount >= n_questions:
965
+ return output
966
+ return output
967
+ else:
968
+ raise ValueError("mode must be 'per_chunk' or 'rag'.")
969
+
970
+ def _estimate_difficulty_for_generation(
971
+ self,
972
+ q_text: str,
973
+ options: Dict[str, str],
974
+ correct_text: str,
975
+ context_text: str = "",
976
+ ) -> Tuple[float, str]:
977
+ def safe_map_sim(s):
978
+ # map potentially [-1,1] cosine-like to [0,1], clamp
979
+ try:
980
+ s = float(s)
981
+ except Exception:
982
+ return 0.0
983
+ mapped = (s + 1.0) / 2.0
984
+ return max(0.0, min(1.0, mapped))
985
+
986
+ # embedding support
987
+ emb_support = 0.0
988
+ try:
989
+ stmt = (q_text or "").strip()
990
+ if correct_text:
991
+ stmt = f"{stmt} Answer: {correct_text}"
992
+
993
+ # use internal retrieve but map returned score
994
+ res = []
995
+ try:
996
+ res = self._retrieve(stmt, top_k=1)
997
+ except Exception:
998
+ res = []
999
+
1000
+ if res:
1001
+ raw_score = float(res[0][1])
1002
+ emb_support = safe_map_sim(raw_score)
1003
+ else:
1004
+ emb_support = 0.0
1005
+ except Exception:
1006
+ emb_support = 0.0
1007
+
1008
+ # distractor sims
1009
+ mean_sim = 0.0
1010
+ distractor_penalty = 0.0
1011
+ amb_flag = 0.0
1012
+ try:
1013
+ keys = list(options.keys())
1014
+ texts = [options[k] for k in keys]
1015
+ if correct_text is None:
1016
+ correct_text = ""
1017
+
1018
+ all_texts = [correct_text] + texts
1019
+ embs = self.embedder.encode(all_texts, convert_to_numpy=True)
1020
+ embs = np.asarray(embs, dtype=float)
1021
+ norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
1022
+ embs = embs / norms
1023
+ corr = embs[0]
1024
+ opts = embs[1:]
1025
+
1026
+ if opts.size == 0:
1027
+ mean_sim = 0.0
1028
+ distractor_penalty = 0.0
1029
+ gap = 0.0
1030
+ else:
1031
+ sims = (opts @ corr).tolist() # [-1,1]
1032
+ sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
1033
+ mean_sim = float(sum(sims_mapped) / len(sims_mapped))
1034
+ # gap between best distractor and second best (higher gap -> easier)
1035
+ sorted_s = sorted(sims_mapped, reverse=True)
1036
+ top = sorted_s[0]
1037
+ second = sorted_s[1] if len(sorted_s) > 1 else 0.0
1038
+ gap = top - second
1039
+ # penalties: if distractors are extremely close to correct -> higher penalty
1040
+ too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
1041
+ too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
1042
+ distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
1043
+ amb_flag = 1.0 if top >= 0.9 else 0.0
1044
+ except Exception:
1045
+ mean_sim = 0.0
1046
+ distractor_penalty = 0.0
1047
+ amb_flag = 0.0
1048
+ gap = 0.0
1049
+
1050
+ # stem length normalized
1051
+ qlen = len((q_text or "").strip())
1052
+ qlen_norm = min(1.0, qlen / 300.0)
1053
+
1054
+ # combine signals using safer semantics:
1055
+ # higher emb_support -> easier (so we subtract a term)
1056
+ # higher distractor_penalty -> harder (add)
1057
+ # better gap -> easier (subtract)
1058
+ # compute score (higher -> harder)
1059
+ score = 0
1060
+ score += 0.35 * float(distractor_penalty)
1061
+ score += 0.20 * float(mean_sim)
1062
+ score += 0.22 * float(amb_flag)
1063
+ score += 0.05 * float(qlen_norm)
1064
+ score -= 0.20 * float(gap)
1065
+
1066
+ # clamp
1067
+ score = max(0.0, min(1.0, float(score)))
1068
+
1069
+ # label
1070
+ if score <= 0.33:
1071
+ label = "dễ"
1072
+ elif score <= 0.66 and score > 0.33:
1073
+ label = "trung bình"
1074
+ else:
1075
+ label = "khó"
1076
+
1077
+ return score, label
1078
+
1079
+ class RAGMCQWithDifficulty(RAGMCQ):
1080
+ def __init__(
1081
+ self,
1082
+ embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
1083
+ generation_model: str = "openai/gpt-oss-120b",
1084
+ qdrant_url: str = os.environ.get('QDRANT_URL') or "",
1085
+ qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
1086
+ qdrant_prefer_grpc: bool = False,
1087
+ ):
1088
+ super(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
1089
+
1090
+ @override
1091
+ def generate_from_pdf(
1092
+ self,
1093
+ pdf_path: str,
1094
+ n_questions: int = 10,
1095
+ mode: str = "rag", # per_page or rag
1096
+ questions_per_page: int = 3, # for per_page mode
1097
+ top_k: int = 3, # chunks to retrieve for each question in rag mode
1098
+ temperature: float = 0.2,
1099
+ enable_fiddler: bool = False,
1100
+ target_difficulty: str = 'easy' # easy, mid, difficult
1101
+ ) -> Dict[str, Any]:
1102
+ # build index
1103
+ self.build_index_from_pdf(pdf_path)
1104
+
1105
+ output: Dict[str, Any] = {}
1106
+ qcount = 0
1107
+
1108
+ if mode == "per_page":
1109
+ # iterate pages -> chunks
1110
+ for idx, meta in enumerate(self.metadata):
1111
+ chunk_text = self.texts[idx]
1112
+
1113
+ if not chunk_text.strip():
1114
+ continue
1115
+ to_gen = questions_per_page
1116
+
1117
+ # ask generator
1118
+ try:
1119
+ structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
1120
+ mcq_block = generate_mcqs_from_text(
1121
+ source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
1122
+ )
1123
+ except Exception as e:
1124
+ # skip this chunk if generator fails
1125
+ print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
1126
+ continue
1127
+
1128
+ if "error" in list(mcq_block.keys()):
1129
+ return output
1130
+
1131
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
1132
+ qcount += 1
1133
+ output[str(qcount)] = mcq_block[item]
1134
+ if qcount >= n_questions:
1135
+ return output
1136
+
1137
+ return output
1138
+
1139
+ # pdf gene
1140
+ elif mode == "rag":
1141
+ # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
1142
+ # create queries by sampling chunk text sentences.
1143
+ # stop when n_questions reached or max_attempts exceeded.
1144
+ attempts = 0
1145
+ max_attempts = n_questions * 4
1146
+
1147
+ while qcount < n_questions and attempts < max_attempts:
1148
+ attempts += 1
1149
+ # create a seed query: pick a random chunk, pick a sentence from it
1150
+ seed_idx = random.randrange(len(self.texts))
1151
+ chunk = self.texts[seed_idx]
1152
+
1153
+ #? investigate better Chunking Strategy
1154
+ #with open("chunks.txt", "a", encoding="utf-8") as f:
1155
+ #f.write(chunk + "\n")
1156
+
1157
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
1158
+ seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
1159
+ query = f"Create questions about: {seed_sent}"
1160
+
1161
+ # retrieve top_k chunks
1162
+ retrieved = self._retrieve(query, top_k=top_k)
1163
+ context_parts = []
1164
+ for ridx, score in retrieved:
1165
+ md = self.metadata[ridx]
1166
+ context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
1167
+ context = "\n\n".join(context_parts)
1168
+
1169
+ # save_to_local('test/context.md', content=context)
1170
+
1171
+ # call generator for 1 question (or small batch) with the retrieved context
1172
+ try:
1173
+ # request 1 question at a time to keep diversity
1174
+ structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
1175
+ mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
1176
+ except Exception as e:
1177
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
1178
+ continue
1179
+
1180
+ if "error" in list(mcq_block.keys()):
1181
+ return output
1182
+
1183
+ # append result(s)
1184
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
1185
+ payload = mcq_block[item]
1186
+ q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
1187
+ options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
1188
+ if isinstance(options, list):
1189
+ options = {str(i+1): o for i, o in enumerate(options)}
1190
+ correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
1191
+ concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
1192
+ correct_text = ""
1193
+ if isinstance(correct_key, str) and correct_key.strip() in options:
1194
+ correct_text = options[correct_key.strip()]
1195
+ else:
1196
+ correct_text = payload.get("correct_text") or correct_key or ""
1197
+
1198
+ diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
1199
+ q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
1200
+ )
1201
+
1202
+ payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
1203
+
1204
+ qcount += 1
1205
+ output[str(qcount)] = mcq_block[item]
1206
+ if qcount >= n_questions:
1207
+ return output
1208
+
1209
+ return output
1210
+ else:
1211
+ raise ValueError("mode must be 'per_page' or 'rag'.")
1212
+
1213
+ @override
1214
  def generate_from_qdrant(
1215
  self,
1216
  filename: str,
 
1365
  else:
1366
  raise ValueError("mode must be 'per_chunk' or 'rag'.")
1367
 
1368
+ @override
1369
  def _estimate_difficulty_for_generation(
1370
  self,
1371
  q_text: str,
 
1463
  # higher distractor_penalty -> harder (add)
1464
  # better gap -> easier (subtract)
1465
  # compute score (higher -> harder)
 
1466
 
1467
  score = 0
1468
  score += 0.35 * float(distractor_penalty)
 
1493
  else:
1494
  label = "khó"
1495
 
1496
+ return score, label, components