Spaces:
Sleeping
Sleeping
namberino committed on
Commit ·
a9fb115
1
Parent(s): ba3f5c1
Testing new rag difficulty generator
Browse files- app.py +3 -3
- generator.py +387 -17
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
# Import the user's RAGMCQ implementation
|
| 11 |
-
from generator import
|
| 12 |
from utils import log_pipeline
|
| 13 |
|
| 14 |
app = FastAPI(title="RAG MCQ Generator API")
|
|
@@ -23,7 +23,7 @@ app.add_middleware(
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# global rag instance
|
| 26 |
-
rag: Optional[
|
| 27 |
|
| 28 |
class GenerateResponse(BaseModel):
|
| 29 |
mcqs: dict
|
|
@@ -37,7 +37,7 @@ def startup_event():
|
|
| 37 |
global rag
|
| 38 |
|
| 39 |
# instantiate the heavy object once
|
| 40 |
-
rag =
|
| 41 |
print("RAGMCQ instance created on startup.")
|
| 42 |
|
| 43 |
@app.get("/health")
|
|
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
# Import the user's RAGMCQ implementation
|
| 11 |
+
from generator import RAGMCQWithDifficulty
|
| 12 |
from utils import log_pipeline
|
| 13 |
|
| 14 |
app = FastAPI(title="RAG MCQ Generator API")
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# global rag instance
|
| 26 |
+
rag: Optional[RAGMCQWithDifficulty] = None
|
| 27 |
|
| 28 |
class GenerateResponse(BaseModel):
|
| 29 |
mcqs: dict
|
|
|
|
| 37 |
global rag
|
| 38 |
|
| 39 |
# instantiate the heavy object once
|
| 40 |
+
rag = RAGMCQWithDifficulty()
|
| 41 |
print("RAGMCQ instance created on startup.")
|
| 42 |
|
| 43 |
@app.get("/health")
|
generator.py
CHANGED
|
@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
|
|
| 9 |
from transformers import pipeline
|
| 10 |
from uuid import uuid4
|
| 11 |
import pymupdf4llm
|
|
|
|
| 12 |
|
| 13 |
try:
|
| 14 |
from qdrant_client import QdrantClient
|
|
@@ -31,7 +32,7 @@ try:
|
|
| 31 |
except Exception:
|
| 32 |
_HAS_FAISS = False
|
| 33 |
|
| 34 |
-
from utils import generate_mcqs_from_text,
|
| 35 |
|
| 36 |
from huggingface_hub import login
|
| 37 |
login(token=os.environ['HF_MODEL_TOKEN'])
|
|
@@ -128,7 +129,6 @@ class RAGMCQ:
|
|
| 128 |
|
| 129 |
return final
|
| 130 |
|
| 131 |
-
|
| 132 |
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
|
| 133 |
pages = self.extract_pages(pdf_path)
|
| 134 |
|
|
@@ -188,7 +188,6 @@ class RAGMCQ:
|
|
| 188 |
top_k: int = 3, # chunks to retrieve for each question in rag mode
|
| 189 |
temperature: float = 0.2,
|
| 190 |
enable_fiddler: bool = False,
|
| 191 |
-
target_difficulty: str = 'easy' # easy, mid, difficult
|
| 192 |
) -> Dict[str, Any]:
|
| 193 |
# build index
|
| 194 |
self.build_index_from_pdf(pdf_path)
|
|
@@ -207,9 +206,8 @@ class RAGMCQ:
|
|
| 207 |
|
| 208 |
# ask generator
|
| 209 |
try:
|
| 210 |
-
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 211 |
mcq_block = generate_mcqs_from_text(
|
| 212 |
-
|
| 213 |
)
|
| 214 |
except Exception as e:
|
| 215 |
# skip this chunk if generator fails
|
|
@@ -227,7 +225,6 @@ class RAGMCQ:
|
|
| 227 |
|
| 228 |
return output
|
| 229 |
|
| 230 |
-
# pdf gene
|
| 231 |
elif mode == "rag":
|
| 232 |
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
|
| 233 |
# create queries by sampling chunk text sentences.
|
|
@@ -262,8 +259,9 @@ class RAGMCQ:
|
|
| 262 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 263 |
try:
|
| 264 |
# request 1 question at a time to keep diversity
|
| 265 |
-
|
| 266 |
-
|
|
|
|
| 267 |
except Exception as e:
|
| 268 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 269 |
continue
|
|
@@ -279,18 +277,16 @@ class RAGMCQ:
|
|
| 279 |
if isinstance(options, list):
|
| 280 |
options = {str(i+1): o for i, o in enumerate(options)}
|
| 281 |
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
|
| 282 |
-
concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
|
| 283 |
correct_text = ""
|
| 284 |
if isinstance(correct_key, str) and correct_key.strip() in options:
|
| 285 |
correct_text = options[correct_key.strip()]
|
| 286 |
else:
|
| 287 |
correct_text = payload.get("correct_text") or correct_key or ""
|
| 288 |
-
|
| 289 |
-
diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
|
| 290 |
-
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
|
| 291 |
-
)
|
| 292 |
|
| 293 |
-
|
|
|
|
|
|
|
|
|
|
| 294 |
|
| 295 |
qcount += 1
|
| 296 |
output[str(qcount)] = mcq_block[item]
|
|
@@ -840,6 +836,381 @@ class RAGMCQ:
|
|
| 840 |
return out
|
| 841 |
|
| 842 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 843 |
def generate_from_qdrant(
|
| 844 |
self,
|
| 845 |
filename: str,
|
|
@@ -994,7 +1365,7 @@ class RAGMCQ:
|
|
| 994 |
else:
|
| 995 |
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
| 996 |
|
| 997 |
-
|
| 998 |
def _estimate_difficulty_for_generation(
|
| 999 |
self,
|
| 1000 |
q_text: str,
|
|
@@ -1092,7 +1463,6 @@ class RAGMCQ:
|
|
| 1092 |
# higher distractor_penalty -> harder (add)
|
| 1093 |
# better gap -> easier (subtract)
|
| 1094 |
# compute score (higher -> harder)
|
| 1095 |
-
# score += 0.12 * float(concepts_penalty)
|
| 1096 |
|
| 1097 |
score = 0
|
| 1098 |
score += 0.35 * float(distractor_penalty)
|
|
@@ -1123,4 +1493,4 @@ class RAGMCQ:
|
|
| 1123 |
else:
|
| 1124 |
label = "khó"
|
| 1125 |
|
| 1126 |
-
return score, label, components
|
|
|
|
| 9 |
from transformers import pipeline
|
| 10 |
from uuid import uuid4
|
| 11 |
import pymupdf4llm
|
| 12 |
+
from typing import override
|
| 13 |
|
| 14 |
try:
|
| 15 |
from qdrant_client import QdrantClient
|
|
|
|
| 32 |
except Exception:
|
| 33 |
_HAS_FAISS = False
|
| 34 |
|
| 35 |
+
from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local, structure_context_for_llm, new_generate_mcqs_from_text
|
| 36 |
|
| 37 |
from huggingface_hub import login
|
| 38 |
login(token=os.environ['HF_MODEL_TOKEN'])
|
|
|
|
| 129 |
|
| 130 |
return final
|
| 131 |
|
|
|
|
| 132 |
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
|
| 133 |
pages = self.extract_pages(pdf_path)
|
| 134 |
|
|
|
|
| 188 |
top_k: int = 3, # chunks to retrieve for each question in rag mode
|
| 189 |
temperature: float = 0.2,
|
| 190 |
enable_fiddler: bool = False,
|
|
|
|
| 191 |
) -> Dict[str, Any]:
|
| 192 |
# build index
|
| 193 |
self.build_index_from_pdf(pdf_path)
|
|
|
|
| 206 |
|
| 207 |
# ask generator
|
| 208 |
try:
|
|
|
|
| 209 |
mcq_block = generate_mcqs_from_text(
|
| 210 |
+
chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
|
| 211 |
)
|
| 212 |
except Exception as e:
|
| 213 |
# skip this chunk if generator fails
|
|
|
|
| 225 |
|
| 226 |
return output
|
| 227 |
|
|
|
|
| 228 |
elif mode == "rag":
|
| 229 |
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
|
| 230 |
# create queries by sampling chunk text sentences.
|
|
|
|
| 259 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 260 |
try:
|
| 261 |
# request 1 question at a time to keep diversity
|
| 262 |
+
mcq_block = generate_mcqs_from_text(
|
| 263 |
+
context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
|
| 264 |
+
)
|
| 265 |
except Exception as e:
|
| 266 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 267 |
continue
|
|
|
|
| 277 |
if isinstance(options, list):
|
| 278 |
options = {str(i+1): o for i, o in enumerate(options)}
|
| 279 |
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
|
|
|
|
| 280 |
correct_text = ""
|
| 281 |
if isinstance(correct_key, str) and correct_key.strip() in options:
|
| 282 |
correct_text = options[correct_key.strip()]
|
| 283 |
else:
|
| 284 |
correct_text = payload.get("correct_text") or correct_key or ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
|
| 286 |
+
diff_score, diff_label = self._estimate_difficulty_for_generation(
|
| 287 |
+
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
|
| 288 |
+
)
|
| 289 |
+
payload["difficulty"] = {"score": diff_score, "label": diff_label}
|
| 290 |
|
| 291 |
qcount += 1
|
| 292 |
output[str(qcount)] = mcq_block[item]
|
|
|
|
| 836 |
return out
|
| 837 |
|
| 838 |
|
| 839 |
+
def generate_from_qdrant(
    self,
    filename: str,
    collection: str,
    n_questions: int = 10,
    mode: str = "rag",  # 'per_chunk' or 'rag'
    questions_per_chunk: int = 3,  # used for 'per_chunk'
    top_k: int = 3,  # retrieval size used in RAG
    temperature: float = 0.2,
    enable_fiddler: bool = False,
) -> Dict[str, Any]:
    """Generate MCQs from the chunks stored in Qdrant for one file.

    The chunks returned by ``self.list_chunks_for_filename`` are embedded
    locally, stored on the instance (``texts``, ``metadata``, ``embeddings``,
    ``dim``) and indexed via ``self._build_faiss_index`` before generation.

    Args:
        filename: File whose chunks are looked up in ``collection``.
        collection: Qdrant collection name.
        n_questions: Stop as soon as this many questions were collected.
        mode: ``"per_chunk"`` asks the generator for ``questions_per_chunk``
            questions per non-blank chunk, in order. ``"rag"`` repeatedly
            samples a seed sentence, retrieves ``top_k`` chunks through
            ``self._retrieve_qdrant`` and asks for one question per attempt
            (at most ``4 * n_questions`` attempts).
        questions_per_chunk: Questions requested per chunk in per_chunk mode.
        top_k: Number of chunks retrieved per attempt in rag mode.
        temperature: Forwarded to ``generate_mcqs_from_text``.
        enable_fiddler: Forwarded to ``generate_mcqs_from_text``.

    Returns:
        Dict keyed ``"1"``, ``"2"``, ... with the generated question payloads.
        In rag mode each payload also gets a ``"độ khó"`` entry holding the
        estimated difficulty score and label. The dict collected so far is
        returned early if the generator output contains an ``"error"`` key.

    Raises:
        RuntimeError: If no Qdrant client is connected or the file has no chunks.
        ValueError: If ``mode`` is neither ``"per_chunk"`` nor ``"rag"``.
    """
    if self.qdrant is None:
        raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")

    points = self.list_chunks_for_filename(collection=collection, filename=filename)
    if not points:
        raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")

    # Keep local copies of the chunk payloads and their texts for sampling.
    chunk_payloads = [point.get("payload", {}) for point in points]
    chunk_texts = [chunk_payload.get("text", "") for chunk_payload in chunk_payloads]

    self.texts = chunk_texts
    self.metadata = chunk_payloads
    vectors = self.embedder.encode(chunk_texts, convert_to_numpy=True, show_progress_bar=True)
    if vectors is not None and len(vectors) > 0:
        self.embeddings = vectors.astype("float32")
        # Refresh the dimension in case the embedder changed unexpectedly.
        self.dim = int(self.embeddings.shape[1])
        self._build_faiss_index()
    else:
        self.embeddings = None
        self.index = None

    results = {}
    produced = 0

    if mode == "per_chunk":
        # Walk the chunks in payload order; blank chunks are skipped.
        for position, chunk_text in enumerate(chunk_texts):
            if not chunk_text.strip():
                continue
            try:
                block = generate_mcqs_from_text(chunk_text, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
            except Exception as exc:
                print(f"Generator failed on chunk (index {position}): {exc}")
                continue

            if "error" in block:
                return results

            for key in sorted(block, key=int):
                produced += 1
                results[str(produced)] = block[key]
                if produced >= n_questions:
                    return results
        return results

    if mode != "rag":
        raise ValueError("mode must be 'per_chunk' or 'rag'.")

    attempts = 0
    attempt_cap = n_questions * 4
    while produced < n_questions and attempts < attempt_cap:
        attempts += 1

        # Seed query: a random chunk, then a random sentence longer than 20 chars.
        seed_chunk = self.texts[random.randrange(len(self.texts))]
        long_sentences = [s for s in re.split(r'(?<=[\.\?\!])\s+', seed_chunk) if len(s.strip()) > 20]
        if long_sentences:
            seed = random.choice(long_sentences)
        else:
            # No usable sentence: fall back to the chunk head, or a placeholder.
            seed = seed_chunk.strip()[:200] or "[no text available]"

        # Retrieval is restricted to the same file through the filename filter.
        hits = self._retrieve_qdrant(
            query=f"Create questions about: {seed}",
            collection=collection,
            filename=filename,
            top_k=top_k,
        )
        context = "\n\n".join(
            f"[page {hit.get('page', '?')}] {hit.get('text', '')}" for hit, _score in hits
        )

        try:
            block = generate_mcqs_from_text(context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
        except Exception as exc:
            print(f"Generator failed during RAG attempt {attempts}: {exc}")
            continue

        if "error" in block:
            return results

        for key in sorted(block, key=int):
            mcq = block[key]
            stem = (mcq.get("câu hỏi") or mcq.get("question") or mcq.get("stem") or "").strip()
            choices = mcq.get("lựa chọn") or mcq.get("options") or mcq.get("choices") or {}
            if isinstance(choices, list):
                # Normalise list-style options to a {"1": ..., "2": ...} mapping.
                choices = {str(pos + 1): choice for pos, choice in enumerate(choices)}
            answer_key = mcq.get("đáp án") or mcq.get("answer") or mcq.get("correct") or None
            if isinstance(answer_key, str) and answer_key.strip() in choices:
                answer_text = choices[answer_key.strip()]
            else:
                answer_text = mcq.get("correct_text") or answer_key or ""

            diff_score, diff_label = self._estimate_difficulty_for_generation(
                q_text=stem,
                options={k: str(v) for k, v in choices.items()},
                correct_text=str(answer_text),
                context_text=context,
            )
            mcq["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}

            produced += 1
            results[str(produced)] = mcq
            if produced >= n_questions:
                return results
    return results
|
| 969 |
+
|
| 970 |
+
def _estimate_difficulty_for_generation(
|
| 971 |
+
self,
|
| 972 |
+
q_text: str,
|
| 973 |
+
options: Dict[str, str],
|
| 974 |
+
correct_text: str,
|
| 975 |
+
context_text: str = "",
|
| 976 |
+
) -> Tuple[float, str]:
|
| 977 |
+
def safe_map_sim(s):
|
| 978 |
+
# map potentially [-1,1] cosine-like to [0,1], clamp
|
| 979 |
+
try:
|
| 980 |
+
s = float(s)
|
| 981 |
+
except Exception:
|
| 982 |
+
return 0.0
|
| 983 |
+
mapped = (s + 1.0) / 2.0
|
| 984 |
+
return max(0.0, min(1.0, mapped))
|
| 985 |
+
|
| 986 |
+
# embedding support
|
| 987 |
+
emb_support = 0.0
|
| 988 |
+
try:
|
| 989 |
+
stmt = (q_text or "").strip()
|
| 990 |
+
if correct_text:
|
| 991 |
+
stmt = f"{stmt} Answer: {correct_text}"
|
| 992 |
+
|
| 993 |
+
# use internal retrieve but map returned score
|
| 994 |
+
res = []
|
| 995 |
+
try:
|
| 996 |
+
res = self._retrieve(stmt, top_k=1)
|
| 997 |
+
except Exception:
|
| 998 |
+
res = []
|
| 999 |
+
|
| 1000 |
+
if res:
|
| 1001 |
+
raw_score = float(res[0][1])
|
| 1002 |
+
emb_support = safe_map_sim(raw_score)
|
| 1003 |
+
else:
|
| 1004 |
+
emb_support = 0.0
|
| 1005 |
+
except Exception:
|
| 1006 |
+
emb_support = 0.0
|
| 1007 |
+
|
| 1008 |
+
# distractor sims
|
| 1009 |
+
mean_sim = 0.0
|
| 1010 |
+
distractor_penalty = 0.0
|
| 1011 |
+
amb_flag = 0.0
|
| 1012 |
+
try:
|
| 1013 |
+
keys = list(options.keys())
|
| 1014 |
+
texts = [options[k] for k in keys]
|
| 1015 |
+
if correct_text is None:
|
| 1016 |
+
correct_text = ""
|
| 1017 |
+
|
| 1018 |
+
all_texts = [correct_text] + texts
|
| 1019 |
+
embs = self.embedder.encode(all_texts, convert_to_numpy=True)
|
| 1020 |
+
embs = np.asarray(embs, dtype=float)
|
| 1021 |
+
norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
|
| 1022 |
+
embs = embs / norms
|
| 1023 |
+
corr = embs[0]
|
| 1024 |
+
opts = embs[1:]
|
| 1025 |
+
|
| 1026 |
+
if opts.size == 0:
|
| 1027 |
+
mean_sim = 0.0
|
| 1028 |
+
distractor_penalty = 0.0
|
| 1029 |
+
gap = 0.0
|
| 1030 |
+
else:
|
| 1031 |
+
sims = (opts @ corr).tolist() # [-1,1]
|
| 1032 |
+
sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
|
| 1033 |
+
mean_sim = float(sum(sims_mapped) / len(sims_mapped))
|
| 1034 |
+
# gap between best distractor and second best (higher gap -> easier)
|
| 1035 |
+
sorted_s = sorted(sims_mapped, reverse=True)
|
| 1036 |
+
top = sorted_s[0]
|
| 1037 |
+
second = sorted_s[1] if len(sorted_s) > 1 else 0.0
|
| 1038 |
+
gap = top - second
|
| 1039 |
+
# penalties: if distractors are extremely close to correct -> higher penalty
|
| 1040 |
+
too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
|
| 1041 |
+
too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
|
| 1042 |
+
distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
|
| 1043 |
+
amb_flag = 1.0 if top >= 0.9 else 0.0
|
| 1044 |
+
except Exception:
|
| 1045 |
+
mean_sim = 0.0
|
| 1046 |
+
distractor_penalty = 0.0
|
| 1047 |
+
amb_flag = 0.0
|
| 1048 |
+
gap = 0.0
|
| 1049 |
+
|
| 1050 |
+
# stem length normalized
|
| 1051 |
+
qlen = len((q_text or "").strip())
|
| 1052 |
+
qlen_norm = min(1.0, qlen / 300.0)
|
| 1053 |
+
|
| 1054 |
+
# combine signals using safer semantics:
|
| 1055 |
+
# higher emb_support -> easier (so we subtract a term)
|
| 1056 |
+
# higher distractor_penalty -> harder (add)
|
| 1057 |
+
# better gap -> easier (subtract)
|
| 1058 |
+
# compute score (higher -> harder)
|
| 1059 |
+
score = 0
|
| 1060 |
+
score += 0.35 * float(distractor_penalty)
|
| 1061 |
+
score += 0.20 * float(mean_sim)
|
| 1062 |
+
score += 0.22 * float(amb_flag)
|
| 1063 |
+
score += 0.05 * float(qlen_norm)
|
| 1064 |
+
score -= 0.20 * float(gap)
|
| 1065 |
+
|
| 1066 |
+
# clamp
|
| 1067 |
+
score = max(0.0, min(1.0, float(score)))
|
| 1068 |
+
|
| 1069 |
+
# label
|
| 1070 |
+
if score <= 0.33:
|
| 1071 |
+
label = "dễ"
|
| 1072 |
+
elif score <= 0.66 and score > 0.33:
|
| 1073 |
+
label = "trung bình"
|
| 1074 |
+
else:
|
| 1075 |
+
label = "khó"
|
| 1076 |
+
|
| 1077 |
+
return score, label
|
| 1078 |
+
|
| 1079 |
+
class RAGMCQWithDifficulty(RAGMCQ):
|
| 1080 |
+
def __init__(
    self,
    embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    generation_model: str = "openai/gpt-oss-120b",
    qdrant_url: str = os.environ.get('QDRANT_URL') or "",
    qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
    qdrant_prefer_grpc: bool = False,
):
    """Create the difficulty-aware generator by delegating to the parent initialiser.

    Args:
        embedder_model: Name of the sentence-embedding model.
        generation_model: Name of the model used to generate questions.
        qdrant_url: Qdrant endpoint; defaults to the ``QDRANT_URL``
            environment variable (read once, when the class is defined).
        qdrant_api_key: Qdrant API key; defaults to ``QDRANT_API_KEY``
            (also read once at class definition time).
        qdrant_prefer_grpc: Forwarded to the parent initialiser.
    """
    # The previous ``super(a, b, c, d, e)`` call raised TypeError (``super``
    # accepts at most two arguments) and never ran the parent initialiser.
    # Arguments are forwarded positionally in the order the original call
    # listed them. NOTE(review): assumes ``RAGMCQ.__init__`` declares its
    # parameters in this same order; confirm against the base class.
    super().__init__(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
|
| 1089 |
+
|
| 1090 |
+
@override
|
| 1091 |
+
def generate_from_pdf(
    self,
    pdf_path: str,
    n_questions: int = 10,
    mode: str = "rag", # per_page or rag
    questions_per_page: int = 3, # for per_page mode
    top_k: int = 3, # chunks to retrieve for each question in rag mode
    temperature: float = 0.2,
    enable_fiddler: bool = False,
    target_difficulty: str = 'easy' # easy, mid, difficult
) -> Dict[str, Any]:
    """Index a PDF and generate MCQs, tagging rag-mode questions with a difficulty.

    Args:
        pdf_path: Path of the PDF passed to ``self.build_index_from_pdf``.
        n_questions: Stop as soon as this many questions were collected.
        mode: ``"per_page"`` asks ``generate_mcqs_from_text`` for
            ``questions_per_page`` questions per non-blank chunk, in order.
            ``"rag"`` repeatedly samples a seed sentence, retrieves ``top_k``
            chunks, restructures them with ``structure_context_for_llm`` and
            calls ``new_generate_mcqs_from_text`` (at most ``4 * n_questions``
            attempts).
        questions_per_page: Questions requested per generator call (used in
            both modes).
        top_k: Number of chunks retrieved per attempt in rag mode.
        temperature: Sampling temperature forwarded to the generator.
        enable_fiddler: Forwarded to the generator in per_page mode only.
        target_difficulty: Forwarded to ``new_generate_mcqs_from_text`` in
            rag mode only.

    Returns:
        Dict keyed ``"1"``, ``"2"``, ... with the generated question payloads.
        In rag mode each payload also gets a ``"độ khó"`` entry with the
        estimated score and label. The dict collected so far is returned early
        if the generator output contains an ``"error"`` key. An empty dict is
        returned in rag mode when the index holds no chunks.

    Raises:
        ValueError: If ``mode`` is neither ``"per_page"`` nor ``"rag"``.
    """
    self.build_index_from_pdf(pdf_path)

    output: Dict[str, Any] = {}
    qcount = 0

    if mode == "per_page":
        for idx, meta in enumerate(self.metadata):
            chunk_text = self.texts[idx]
            if not chunk_text.strip():
                continue

            try:
                # The earlier ``structure_context_for_llm(context, ...)`` call
                # here referenced a name that is unbound in this branch; the
                # resulting error was swallowed below, so every chunk was
                # skipped. Its result was unused, so the call is dropped.
                mcq_block = generate_mcqs_from_text(
                    source_text=chunk_text, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
                )
            except Exception as e:
                # Best effort: skip this chunk if the generator fails.
                print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
                continue

            if "error" in mcq_block:
                return output

            for item in sorted(mcq_block, key=int):
                qcount += 1
                output[str(qcount)] = mcq_block[item]
                if qcount >= n_questions:
                    return output

        return output

    if mode != "rag":
        raise ValueError("mode must be 'per_page' or 'rag'.")

    # Nothing to sample from: ``random.randrange(0)`` would raise ValueError.
    if not self.texts:
        return output

    attempts = 0
    max_attempts = n_questions * 4

    while qcount < n_questions and attempts < max_attempts:
        attempts += 1

        # Seed query: a random chunk, then a random sentence longer than 20 chars.
        chunk = self.texts[random.randrange(len(self.texts))]
        sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
        candidates = [s for s in sents if len(s.strip()) > 20]
        if candidates:
            seed_sent = random.choice(candidates)
        else:
            # ``re.split`` never returns an empty list, so the old
            # ``... if sents else chunk[:200]`` fallback could not trigger and
            # ``random.choice([])`` raised IndexError for chunks without a long
            # sentence. Same fallback as ``generate_from_qdrant``.
            stripped = chunk.strip()
            seed_sent = stripped[:200] if stripped else "[no text available]"
        query = f"Create questions about: {seed_sent}"

        retrieved = self._retrieve(query, top_k=top_k)
        context = "\n\n".join(
            f"[page {self.metadata[ridx]['page']}] {self.texts[ridx]}" for ridx, _score in retrieved
        )

        try:
            # Requests ``questions_per_page`` questions per attempt.
            # NOTE(review): ``enable_fiddler`` is hard-coded to False for both
            # calls in this branch, so the parameter has no effect in rag mode;
            # confirm that this is intended.
            structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
            mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
        except Exception as e:
            print(f"Generator failed during RAG attempt {attempts}: {e}")
            continue

        if "error" in mcq_block:
            return output

        for item in sorted(mcq_block, key=int):
            payload = mcq_block[item]
            q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
            options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
            if isinstance(options, list):
                # Normalise list-style options to a {"1": ..., "2": ...} mapping.
                options = {str(i + 1): o for i, o in enumerate(options)}
            correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
            concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
            if isinstance(correct_key, str) and correct_key.strip() in options:
                correct_text = options[correct_key.strip()]
            else:
                correct_text = payload.get("correct_text") or correct_key or ""

            # Relies on this class's override, which accepts ``concepts_used``
            # and returns ``(score, label, components)``.
            diff_score, diff_label, _components = self._estimate_difficulty_for_generation( # type: ignore
                q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
            )

            payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}

            qcount += 1
            output[str(qcount)] = payload
            if qcount >= n_questions:
                return output

    return output
|
| 1212 |
+
|
| 1213 |
+
@override
|
| 1214 |
def generate_from_qdrant(
|
| 1215 |
self,
|
| 1216 |
filename: str,
|
|
|
|
| 1365 |
else:
|
| 1366 |
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
| 1367 |
|
| 1368 |
+
@override
|
| 1369 |
def _estimate_difficulty_for_generation(
|
| 1370 |
self,
|
| 1371 |
q_text: str,
|
|
|
|
| 1463 |
# higher distractor_penalty -> harder (add)
|
| 1464 |
# better gap -> easier (subtract)
|
| 1465 |
# compute score (higher -> harder)
|
|
|
|
| 1466 |
|
| 1467 |
score = 0
|
| 1468 |
score += 0.35 * float(distractor_penalty)
|
|
|
|
| 1493 |
else:
|
| 1494 |
label = "khó"
|
| 1495 |
|
| 1496 |
+
return score, label, components
|