Commit 4dcccc9 · namberino committed
Parent(s): 9e7752e

Add difficulty distribution

Files changed:
- app.py +190 -8
- generator.py +421 -1
- requirements.txt +1 -0
- utils.py +155 -8
app.py CHANGED

@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 
 # Import the user's RAGMCQ implementation
-from generator import RAGMCQ
+from generator import RAGMCQWithDifficulty, RAGMCQ
 from utils import log_pipeline
 
 app = FastAPI(title="RAG MCQ Generator API")
@@ -24,6 +24,7 @@ app.add_middleware(
 
 # global rag instance
 rag: Optional[RAGMCQ] = None
+rag_difficulty: Optional[RAGMCQWithDifficulty] = None
 
 class GenerateResponse(BaseModel):
     mcqs: dict
@@ -34,15 +35,17 @@ class ListResponse(BaseModel):
 
 @app.on_event("startup")
 def startup_event():
+    global rag_difficulty
     global rag
 
     # instantiate the heavy object once
     rag = RAGMCQ()
+    rag_difficulty = RAGMCQWithDifficulty()
     print("RAGMCQ instance created on startup.")
 
 @app.get("/health")
 def health():
-    return {"status": "ok", "ready": rag is not None}
+    return {"status": "ok", "ready": rag_difficulty is not None and rag is not None}
 
 def _save_upload_to_temp(upload: UploadFile) -> str:
     suffix = ".pdf"
@@ -57,11 +60,11 @@ def _save_upload_to_temp(upload: UploadFile) -> str:
 async def list_collection_files_endpoint(
     collection_name: str = "programming"
 ):
-    global rag
-    if rag is None:
+    global rag_difficulty
+    if rag_difficulty is None:
         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
 
-    files = rag.list_files_in_collection(collection_name)
+    files = rag_difficulty.list_files_in_collection(collection_name)
 
     return {"files": files}
 
@@ -81,8 +84,8 @@ async def upload_multiple_files(
     - overwrite: if true, existing points for each filename will be removed
     - qdrant_filename_prefix: optional prefix; if provided each file will be saved under "<prefix>_<original_filename>"
     """
-    global rag
-    if rag is None:
+    global rag_difficulty
+    if rag_difficulty is None:
         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
 
     saved_files = []
@@ -112,7 +115,7 @@ async def upload_multiple_files(
         )
 
         try:
-            rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
+            rag_difficulty.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
             saved_files.append(qdrant_filename)
         except Exception as e:
             # collect failure info rather than aborting all uploads
@@ -122,6 +125,185 @@ async def upload_multiple_files(
 
 
 
+@app.post("/generate_saved_with_difficulty", response_model=GenerateResponse)
+async def generate_saved_with_difficulty_endpoint(
+    n_easy_questions: int = Form(3),
+    n_medium_questions: int = Form(5),
+    n_hard_questions: int = Form(2),
+    qdrant_filename: str = Form("default_filename"),
+    collection_name: str = Form("programming"),
+    mode: str = Form("rag"),
+    questions_per_chunk: int = Form(3),
+    top_k: int = Form(3),
+    temperature: float = Form(0.2),
+    validate_mcqs: bool = Form(False),
+    enable_fiddler: bool = Form(False),
+):
+    global rag_difficulty
+    if rag_difficulty is None:
+        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+    difficulty_counts = {
+        "easy": n_easy_questions,
+        "medium": n_medium_questions,
+        "hard": n_hard_questions
+    }
+
+    all_mcqs = {}
+    counter = 1
+
+    for difficulty, n_questions in difficulty_counts.items():
+        try:
+            mcqs = rag_difficulty.generate_from_qdrant(
+                filename=qdrant_filename,
+                collection=collection_name,
+                n_questions=n_questions,
+                mode=mode,
+                questions_per_chunk=questions_per_chunk,
+                top_k=top_k,
+                temperature=temperature,
+                enable_fiddler=enable_fiddler,
+                target_difficulty=difficulty,
+            )
+            questions_list = []
+            if isinstance(mcqs, dict):
+                for v in mcqs.values():
+                    if isinstance(v, list):
+                        questions_list.extend(v)
+                    else:
+                        questions_list.append(v)
+            elif isinstance(mcqs, list):
+                questions_list = mcqs
+            else:
+                continue
+
+            for qobj in questions_list:
+                if isinstance(qobj, dict):
+                    qobj["_difficulty"] = difficulty
+                all_mcqs[str(counter)] = qobj
+                counter += 1
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
+
+    validation_report = None
+
+    if validate_mcqs:
+        try:
+            # validate_mcqs expects keys as strings and the normalized content
+            validation_report = rag_difficulty.validate_mcqs(all_mcqs, top_k=top_k)
+        except Exception as e:
+            # don't fail the whole request for a validation error — return generator output and note the error
+            validation_report = {"error": f"Validation failed: {e}"}
+
+    # log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
+
+    return {"mcqs": all_mcqs, "validation": validation_report}
+
+
+
+
+@app.post("/generate_with_difficulty", response_model=GenerateResponse)
+async def generate_with_difficulty_endpoint(
+    background_tasks: BackgroundTasks,
+    file: UploadFile = File(...),
+    n_easy_questions: int = Form(3),
+    n_medium_questions: int = Form(5),
+    n_hard_questions: int = Form(2),
+    qdrant_filename: str = Form("default_filename"),
+    collection_name: str = Form("programming"),
+    mode: str = Form("rag"),
+    questions_per_page: int = Form(3),
+    top_k: int = Form(3),
+    temperature: float = Form(0.2),
+    validate_mcqs: bool = Form(False),
+    enable_fiddler: bool = Form(False)
+):
+    global rag_difficulty
+    if rag_difficulty is None:
+        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+    # basic file validation
+    if not file.filename.lower().endswith(".pdf"):
+        raise HTTPException(status_code=400, detail="Only PDF files are supported.")
+
+    # save uploaded file to a temp location
+    tmp_path = _save_upload_to_temp(file)
+
+    # ensure file removed afterward
+    def _cleanup(path: str):
+        try:
+            os.remove(path)
+        except Exception:
+            pass
+
+    background_tasks.add_task(_cleanup, tmp_path)
+
+    # save pdf
+    try:
+        rag_difficulty.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
+
+    difficulty_counts = {
+        "easy": n_easy_questions,
+        "medium": n_medium_questions,
+        "hard": n_hard_questions
+    }
+
+    all_mcqs = {}
+    counter = 1
+
+    for difficulty, n_questions in difficulty_counts.items():
+        try:
+            mcqs = rag_difficulty.generate_from_pdf(
+                pdf_path=tmp_path,
+                n_questions=n_questions,
+                mode=mode,
+                questions_per_page=questions_per_page,
+                top_k=top_k,
+                temperature=temperature,
+                enable_fiddler=enable_fiddler,
+                target_difficulty=difficulty,
+            )
+            questions_list = []
+            if isinstance(mcqs, dict):
+                for v in mcqs.values():
+                    if isinstance(v, list):
+                        questions_list.extend(v)
+                    else:
+                        questions_list.append(v)
+            elif isinstance(mcqs, list):
+                questions_list = mcqs
+            else:
+                continue
+
+            for qobj in questions_list:
+                if isinstance(qobj, dict):
+                    qobj["_difficulty"] = difficulty
+                all_mcqs[str(counter)] = qobj
+                counter += 1
+
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Generation from file failed: {e}")
+
+    validation_report = None
+
+    if validate_mcqs:
+        try:
+            # rag.build_index_from_pdf(tmp_path)
+            # validate_mcqs expects keys as strings and the normalized content
+            validation_report = rag_difficulty.validate_mcqs(all_mcqs, top_k=top_k)
+        except Exception as e:
+            # don't fail the whole request for a validation error — return generator output and note the error
+            validation_report = {"error": f"Validation failed: {e}"}
+
+
+    # log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
+
+    return {"mcqs": all_mcqs, "validation": validation_report}
+
+
 @app.post("/generate_saved", response_model=GenerateResponse)
 async def generate_saved_endpoint(
     n_questions: int = Form(10),
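For quick testing, a minimal client sketch against the new /generate_saved_with_difficulty endpoint. The base URL and port are assumptions, not part of the commit; the form fields mirror the endpoint's defaults above.

# Hypothetical client call; host/port are assumed, endpoint and fields come from the diff.
import requests

resp = requests.post(
    "http://localhost:8000/generate_saved_with_difficulty",  # assumed deployment URL
    data={
        "n_easy_questions": 3,
        "n_medium_questions": 5,
        "n_hard_questions": 2,
        "qdrant_filename": "default_filename",
        "collection_name": "programming",
        "mode": "rag",
    },
)
resp.raise_for_status()
mcqs = resp.json()["mcqs"]
# each question dict carries the "_difficulty" tag the endpoint attaches
for key, q in mcqs.items():
    print(key, q.get("_difficulty"))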
generator.py CHANGED

@@ -9,6 +9,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder
 from transformers import pipeline
 from uuid import uuid4
 import pymupdf4llm
+from typing_extensions import override
 
 try:
     from qdrant_client import QdrantClient
@@ -31,7 +32,7 @@ try:
 except Exception:
     _HAS_FAISS = False
 
-from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local
+from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local, structure_context_for_llm, new_generate_mcqs_from_text
 
 from huggingface_hub import login
 login(token=os.environ['HF_MODEL_TOKEN'])
@@ -1074,3 +1075,422 @@ class RAGMCQ:
             label = "khó"
 
         return score, label
+
+class RAGMCQWithDifficulty(RAGMCQ):
+    def __init__(
+        self,
+        embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+        generation_model: str = "openai/gpt-oss-120b",
+        qdrant_url: str = os.environ.get('QDRANT_URL') or "",
+        qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
+        qdrant_prefer_grpc: bool = False,
+    ):
+        super().__init__(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
+
+    @override
+    def generate_from_pdf(
+        self,
+        pdf_path: str,
+        n_questions: int = 10,
+        mode: str = "rag",  # per_page or rag
+        questions_per_page: int = 3,  # for per_page mode
+        top_k: int = 3,  # chunks to retrieve for each question in rag mode
+        temperature: float = 0.2,
+        enable_fiddler: bool = False,
+        target_difficulty: str = 'easy'  # easy, mid, difficult
+    ) -> Dict[str, Any]:
+        # build index
+        self.build_index_from_pdf(pdf_path)
+
+        output: Dict[str, Any] = {}
+        qcount = 0
+
+        if mode == "per_page":
+            # iterate pages -> chunks
+            for idx, meta in enumerate(self.metadata):
+                chunk_text = self.texts[idx]
+
+                if not chunk_text.strip():
+                    continue
+                to_gen = questions_per_page
+
+                # ask generator
+                try:
+                    structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = generate_mcqs_from_text(
+                        source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
+                    )
+                except Exception as e:
+                    # skip this chunk if generator fails
+                    print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
+                    continue
+
+                if "error" in list(mcq_block.keys()):
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            return output
+
+        # pdf generation
+        elif mode == "rag":
+            # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
+            # create queries by sampling chunk text sentences.
+            # stop when n_questions reached or max_attempts exceeded.
+            attempts = 0
+            max_attempts = n_questions * 4
+
+            while qcount < n_questions and attempts < max_attempts:
+                attempts += 1
+                # create a seed query: pick a random chunk, pick a sentence from it
+                seed_idx = random.randrange(len(self.texts))
+                chunk = self.texts[seed_idx]
+
+                #? investigate better chunking strategy
+                # with open("chunks.txt", "a", encoding="utf-8") as f:
+                #     f.write(chunk + "\n")
+
+                sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
+                query = f"Create questions about: {seed_sent}"
+
+                # retrieve top_k chunks
+                retrieved = self._retrieve(query, top_k=top_k)
+                context_parts = []
+                for ridx, score in retrieved:
+                    md = self.metadata[ridx]
+                    context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
+                context = "\n\n".join(context_parts)
+
+                # save_to_local('test/context.md', content=context)
+
+                # call generator for 1 question (or small batch) with the retrieved context
+                try:
+                    # request 1 question at a time to keep diversity
+                    structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
+                except Exception as e:
+                    print(f"Generator failed during RAG attempt {attempts}: {e}")
+                    continue
+
+                if "error" in list(mcq_block.keys()):
+                    return output
+
+                # append result(s)
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    payload = mcq_block[item]
+                    q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
+                    options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
+                    if isinstance(options, list):
+                        options = {str(i+1): o for i, o in enumerate(options)}
+                    correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
+                    concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
+                    correct_text = ""
+                    if isinstance(correct_key, str) and correct_key.strip() in options:
+                        correct_text = options[correct_key.strip()]
+                    else:
+                        correct_text = payload.get("correct_text") or correct_key or ""
+
+                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(  # type: ignore
+                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
+                    )
+
+                    payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
+
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            return output
+        else:
+            raise ValueError("mode must be 'per_page' or 'rag'.")
+
+    @override
+    def generate_from_qdrant(
+        self,
+        filename: str,
+        collection: str,
+        n_questions: int = 10,
+        mode: str = "rag",  # 'per_chunk' or 'rag'
+        questions_per_chunk: int = 3,  # used for 'per_chunk'
+        top_k: int = 3,  # retrieval size used in RAG
+        temperature: float = 0.2,
+        enable_fiddler: bool = False,
+        target_difficulty: str = 'easy',
+    ) -> Dict[str, Any]:
+        if self.qdrant is None:
+            raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+        # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
+        file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
+        if not file_points:
+            raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
+
+        # create a local list of texts & metadata for sampling
+        texts = []
+        metas = []
+        for p in file_points:
+            payload = p.get("payload", {})
+            text = payload.get("text", "")
+            texts.append(text)
+            metas.append(payload)
+
+        self.texts = texts
+        self.metadata = metas
+        embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
+        if embeddings is None or len(embeddings) == 0:
+            self.embeddings = None
+            self.index = None
+        else:
+            self.embeddings = embeddings.astype("float32")
+
+            # update dim in case embedder changed unexpectedly
+            self.dim = int(self.embeddings.shape[1])
+
+            # build index
+            self._build_faiss_index()
+
+        output = {}
+        qcount = 0
+
+        if mode == "per_chunk":
+            # iterate all chunks (in payload order) and request questions_per_chunk from each
+            for i, txt in enumerate(texts):
+                if not txt.strip():
+                    continue
+                to_gen = questions_per_chunk
+                try:
+                    mcq_block = new_generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=False)
+                except Exception as e:
+                    print(f"Generator failed on chunk (index {i}): {e}")
+                    continue
+
+                if "error" in list(mcq_block.keys()):
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    qcount += 1
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+            return output
+
+        elif mode == "rag":
+            attempts = 0
+            max_attempts = n_questions * 4
+            while qcount < n_questions and attempts < max_attempts:
+                attempts += 1
+                # create a seed query: pick a random chunk, pick a sentence from it
+                seed_idx = random.randrange(len(self.texts))
+                chunk = self.texts[seed_idx]
+                sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                candidate = [s for s in sents if len(s.strip()) > 20]
+                if candidate:
+                    seed_sent = random.choice(candidate)
+                else:
+                    stripped = chunk.strip()
+                    seed_sent = (stripped[:200] if stripped else "[no text available]")
+                query = f"Create questions about: {seed_sent}"
+
+                # retrieve top_k chunks from the same file (restricted by filename filter)
+                retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
+                print('retrieved qdrant', retrieved)
+                context_parts = []
+                for payload, score in retrieved:
+                    # payload should contain page & chunk_id and text
+                    page = payload.get("page", "?")
+                    ctxt = payload.get("text", "")
+                    context_parts.append(f"[page {page}] {ctxt}")
+                context = "\n\n".join(context_parts)
+
+                # question generation
+                try:
+                    # Difficulty pipeline: easy, mid, difficult
+                    structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
+                    mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
+                except Exception as e:
+                    print(f"Generator failed during RAG attempt {attempts}: {e}")
+                    continue
+
+                if "error" in list(mcq_block.keys()):
+                    return output
+
+                for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                    payload = mcq_block[item]
+                    q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
+                    options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
+
+                    if isinstance(options, list):
+                        options = {str(i+1): o for i, o in enumerate(options)}
+
+                    correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
+                    concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
+
+                    correct_text = ""
+                    if isinstance(correct_key, str) and correct_key.strip() in options:
+                        correct_text = options[correct_key.strip()]
+                    else:
+                        correct_text = payload.get("correct_text") or correct_key or ""
+
+                    #? change estimate
+                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(  # type: ignore
+                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
+                    )
+
+                    payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
+
+                    # if fewer questions remain than the per-call batch (e.g. 5 - 3 = 2 < 3), only generate that many
+                    if n_questions - qcount < questions_per_chunk:
+                        questions_per_chunk = n_questions - qcount
+
+                    qcount += 1  # count generated questions
+                    print('qcount:', qcount)
+                    print('questions_per_chunk:', questions_per_chunk)
+
+                    output[str(qcount)] = mcq_block[item]
+                    if qcount >= n_questions:
+                        return output
+
+            if output is not None:
+                print("output available")
+                return output
+        else:
+            raise ValueError("mode must be 'per_chunk' or 'rag'.")
+
+    @override
+    def _estimate_difficulty_for_generation(
+        self,
+        q_text: str,
+        options: Dict[str, str],
+        correct_text: str,
+        context_text: str = "",
+        concepts_used: Dict = {}
+    ) -> Tuple[float, str, Dict]:
+        def safe_map_sim(s):
+            # map potentially [-1,1] cosine-like to [0,1], clamp
+            try:
+                s = float(s)
+            except Exception:
+                return 0.0
+            mapped = (s + 1.0) / 2.0
+            return max(0.0, min(1.0, mapped))
+
+        # embedding support
+        emb_support = 0.0
+        try:
+            stmt = (q_text or "").strip()
+            if correct_text:
+                stmt = f"{stmt} Answer: {correct_text}"
+
+            # use internal retrieve but map returned score
+            res = []
+            try:
+                res = self._retrieve(stmt, top_k=1)
+            except Exception:
+                res = []
+
+            if res:
+                raw_score = float(res[0][1])
+                emb_support = safe_map_sim(raw_score)
+            else:
+                emb_support = 0.0
+        except Exception:
+            emb_support = 0.0
+
+        # distractor sims
+        mean_sim = 0.0
+        distractor_penalty = 0.0
+        amb_flag = 0.0
+        try:
+            keys = list(options.keys())
+            texts = [options[k] for k in keys]
+            if correct_text is None:
+                correct_text = ""
+
+            all_texts = [correct_text] + texts
+            embs = self.embedder.encode(all_texts, convert_to_numpy=True)
+            embs = np.asarray(embs, dtype=float)
+            norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
+            embs = embs / norms
+            corr = embs[0]
+            opts = embs[1:]
+
+            if opts.size == 0:
+                mean_sim = 0.0
+                distractor_penalty = 0.0
+                gap = 0.0
+            else:
+                sims = (opts @ corr).tolist()  # [-1,1]
+                sims_mapped = [safe_map_sim(s) for s in sims]  # [0,1]
+                mean_sim = float(sum(sims_mapped) / len(sims_mapped))
+                # gap between best distractor and second best (higher gap -> easier)
+                sorted_s = sorted(sims_mapped, reverse=True)
+                top = sorted_s[0]
+                second = sorted_s[1] if len(sorted_s) > 1 else 0.0
+                gap = top - second
+                # penalties: if distractors are extremely close to correct -> higher penalty
+                too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
+                too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
+                distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
+                amb_flag = 1.0 if top >= 0.8 else 0.0
+        except Exception:
+            mean_sim = 0.0
+            distractor_penalty = 0.0
+            amb_flag = 0.0
+            gap = 0.0
+
+        # question length normalized
+        question_len = len((q_text or "").strip())
+        question_len_norm = min(1.0, question_len / 300.0)
+
+        # count number of concepts from the dict
+        concepts_num = len(concepts_used.keys())
+        if concepts_num < 2:
+            concepts_penalty = 0
+        else:
+            concepts_penalty = concepts_num
+
+        # combine signals using safer semantics:
+        # higher emb_support -> easier (so we subtract a term)
+        # higher distractor_penalty -> harder (add)
+        # better gap -> easier (subtract)
+        # compute score (higher -> harder)
+
+        score = 0
+        score += 0.35 * float(distractor_penalty)
+        score += 0.20 * float(mean_sim)
+        score += 0.22 * float(amb_flag)
+        score += 0.05 * float(question_len_norm)
+        score -= 0.20 * float(gap)
+
+        # clamp
+        score = max(0.0, min(1.0, float(score)))
+        components = {
+            "base": 0.3,
+            "distractor_penalty": 0.35 * float(distractor_penalty),
+            "mean_sim": 0.15 * float(mean_sim),
+            "amb_flag": 0.05 * float(amb_flag),
+            "concepts_num": 0.1 * float(concepts_num),
+            "gap": -0.12 * float(gap),
+            "question_len_norm": 0.05 * float(question_len_norm),
+            "emb_support": -0.45 * float(emb_support),
+            "total_score": score,
+        }
+
+        # label
+        if score <= 0.35:
+            label = "dễ"
+        elif score <= 0.65:
+            label = "trung bình"
+        else:
+            label = "khó"
+
+        return score, label, components
requirements.txt CHANGED

@@ -8,3 +8,4 @@ qdrant-client
 pymupdf4llm
 uuid
 huggingface_hub
+typing_extensions
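typing_extensions is added for the @override decorator imported in generator.py. On Python 3.12 and later the same decorator ships in the standard typing module, so a guarded import (a sketch, not what this commit does) would also work:

# Sketch: prefer the stdlib decorator when available, fall back to the backport.
try:
    from typing import override  # Python >= 3.12
except ImportError:
    from typing_extensions import override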
utils.py CHANGED

@@ -53,12 +53,6 @@ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: in
     resp.raise_for_status()
     data = resp.json()
 
-    # save_to_local('test/raw_resp.json', content=data)
-
-    #? Must update within _post_chat because it's the original function for LLM generation
-    update_token_count(token_usage=data['usage'])  # get data['usage']['prompt_tokens'] & data['usage']['completion_tokens']
-    # update_time_info(time_info=data['time_info'])
-
     # handle various shapes
     if "choices" in data and len(data["choices"]) > 0:
         # prefer message.content
@@ -70,7 +64,6 @@ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: in
         if "text" in ch:
             return ch["text"]
 
-    print(f'Generation Time: {data["time_info"]}')
     # final fallback
     raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
 
@@ -92,10 +85,164 @@ def _safe_extract_json(text: str) -> dict:
     return json.loads(fixed)
 
 
+def structure_context_for_llm(
+    source_text: str,
+    model: str = "openai/gpt-oss-120b",
+    temperature: float = 0.2,
+    enable_fiddler=False,
+) -> Dict[str, Any]:
+    """
+    Take a long source_text, split into N chunks, and restructure them
+    so each chunk is self-contained, structured, and semantically meaningful.
+    """
+
+    system_message = {
+        "role": "system",
+        "content": (
+            "Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt\n"
+            "Nhiệm vụ của bạn là:\n"
+            "- Nếu văn bản dài trên 500 từ chia văn bản thành 2 đoạn (chunk) có ý nghĩa rõ ràng.\n"
+            "- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
+            "- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
+            "- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
+            "  - 'đoạn văn': nội dung gốc đã cấu trúc đầy đủ\n"
+            "  - 'khái niệm chính': từ điển chứa các khái niệm chính với khái niệm phụ hỗ trợ khái niệm chính đi kèm nếu có\n"
+            "  - 'công thức': danh sách công thức (nếu có)\n"
+            "  - 'ví dụ': ví dụ minh họa (nếu có)\n"
+            "  - 'tóm tắt': tóm tắt nội dung, dễ hiểu\n"
+            "- Giữ ngữ nghĩa liền mạch.\n"
+            "- Chỉ TRẢ VỀ MỘT JSON hợp lệ theo schema, không kèm văn bản khác.\n\n"
+
+            "Chỉ TRẢ VỀ duy nhất MỘT đối tượng JSON theo schema sau và không có bất kỳ văn bản nào khác:\n\n"
+            "{\n"
+            '  "Tên topic": {"đoạn văn": "nội dung đã cấu trúc của topic 1", "khái niệm chính": {"khái niệm chính 1": ["khái niệm phụ", "..."], "khái niệm chính 2": ["khái niệm phụ", "..."]}, "công thức": ["..."], "ví dụ": ["..."], "tóm tắt": "tóm tắt ngắn gọn"},\n'
+            "}\n"
+        )
+    }
+
+    user_message = {
+        "role": "user",
+        "content": (
+            "Hãy chia văn bản sau thành nhiều chunk theo hướng dẫn trên và xuất JSON hợp lệ.\n"
+            f"### Văn bản nguồn:\n{source_text}"
+        )
+    }
+
+    if enable_fiddler:
+        max_conf, max_cat = text_safety_check(user_message['content'])
+        if max_conf > 0.5:
+            print(f"Harmful content detected: ({max_cat} : {max_conf})")
+            return {}
+
+    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
+    parsed = _safe_extract_json(raw)
+    if not isinstance(parsed, dict):
+        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
+    return parsed
+
+
+def new_generate_mcqs_from_text(
+    source_text: str,
+    n: int = 3,
+    model: str = "openai/gpt-oss-120b",
+    temperature: float = 0.2,
+    enable_fiddler=False,
+    target_difficulty: str = "easy",
+) -> Dict[str, Any]:
+
+    expected_concepts = {
+        "easy": 1,
+        "medium": 2,
+        "hard": (3, 4)
+    }
+    if isinstance(expected_concepts[target_difficulty], tuple):
+        min_concepts, max_concepts = expected_concepts[target_difficulty]
+        concept_range = f"{min_concepts}-{max_concepts}"
+    else:
+        concept_range = expected_concepts[target_difficulty]
+
+    difficulty_prompts = {
+        "easy": (
+            "- Câu hỏi DỄ: kiểm tra duy nhất 1 khái niệm chính cơ bản dễ hiểu, định nghĩa, hoặc công thức đơn giản. "
+            "- Đáp án có thể tìm thấy trực tiếp trong văn bản. "
+            "- Ngữ cảnh đủ để hiểu khái niệm chính. "
+            "- Distractors khác biệt rõ ràng, dễ loại bỏ. "
+            "- Độ dài câu hỏi ngắn gọn không quá 10-20 từ hoặc ít hơn 120 ký tự, tập trung vào một ý duy nhất.\n"
+        ),
+        "medium": (
+            "- Câu hỏi TRUNG BÌNH kiểm tra khái niệm chính trong văn bản. "
+            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn. "
+            "- Các Distractors không quá giống nhau. "
+            "- Độ dài câu hỏi vừa phải khoảng 23–30 từ hoặc khoảng 150-180 ký tự, có thêm chi tiết phụ để suy luận.\n"
+        ),
+        "hard": (
+            "- Câu hỏi KHÓ kiểm tra thông tin được phân tích/tổng hợp. "
+            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn. "
+            "- Ít nhất 2 distractors gần giống đáp án đúng, độ tương đồng cao. "
+            "- Đáp án yêu cầu học sinh suy luận hoặc áp dụng công thức vào ví dụ nếu có. "
+            "- Độ dài câu hỏi dài hơn 35 từ hoặc hơn 200 ký tự.\n\n"
+        )
+    }
+
+    difficult_criteria = difficulty_prompts[target_difficulty]  # "easy", "medium", "hard"
+    print(concept_range)
+    system_message = {
+        "role": "system",
+        "content": (
+            "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm (MCQ). Luôn trả lời bằng tiếng Việt. "
+            f"Đảm bảo chỉ tạo sinh câu trắc nghiệm có độ khó sau {difficult_criteria}"
+            f"Quan trọng: Mỗi câu hỏi chỉ sử dụng chính xác {concept_range} khái niệm chính (mỗi khái niệm chính có 1 danh sách khái niệm phụ) từ văn bản nguồn. "
+            "Mỗi câu hỏi và đáp án phải dựa trên thông tin từ văn bản nguồn. Không được đưa kiến thức ngoài vào. "
+            "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không kèm giải thích hay trường thêm:\n\n"
+            "{\n"
+            '  "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"...", "khái niệm sử dụng": {"khái niệm chính": ["khái niệm phụ", "..."]}},\n'
+            '  "2": { ... }\n'
+            "}\n\n"
+            "Lưu ý:\n"
+            f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
+            "- Khóa 'lựa chọn' phải có các khóa a, b, c, d.\n"
+            "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
+            "- Toàn bộ thông tin cần thiết để trả lời phải nằm trong chính câu hỏi, không tham chiếu lại văn bản nguồn. "
+            f"- Sử dụng chính xác {concept_range} khái niệm chính"
+        )
+    }
+
+    user_message = {
+        "role": "user",
+        "content": (
+            f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Chỉ sử dụng nội dung này làm nguồn duy nhất để xây dựng câu hỏi.\n\n"
+            "### Yêu cầu:\n"
+            "- Bám sát vào thông tin trong văn bản; không thêm kiến thức ngoài.\n"
+            "- Nếu văn bản thiếu chi tiết, hãy tạo phương án nhiễu (distractors) hợp lý, nhưng phải có thể biện minh từ nội dung hoặc ngữ cảnh.\n"
+            f"### Văn bản nguồn:\n{source_text}"
+        )
+    }
+
+    if enable_fiddler:
+        max_conf, max_cat = text_safety_check(user_message['content'])
+        if max_conf > 0.5:
+            print(f"Harmful content detected: ({max_cat} : {max_conf})")
+            return {}
+
+    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
+    # print('\n\n', raw)
+    parsed = _safe_extract_json(raw)
+    # basic validation
+    if not isinstance(parsed, dict) or len(parsed) != n:
+        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
+    return parsed
+
+
 def generate_mcqs_from_text(
     source_text: str,
     n: int = 3,
-    model: str = "gpt-oss-120b",
+    model: str = "openai/gpt-oss-120b",
     temperature: float = 0.2,
     enable_fiddler: bool = False,
 ) -> Dict[str, Any]:
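For orientation, a sketch of how new_generate_mcqs_from_text is called and the shape it is expected to return. The question content in the comment is illustrative, not real output, and the source text placeholder is an assumption.

# Assumes utils.py is importable; "..." stands for a structured context string.
from utils import new_generate_mcqs_from_text

mcqs = new_generate_mcqs_from_text(source_text="...", n=2, target_difficulty="medium")
# Expected shape, per the schema the system prompt enforces:
# {
#   "1": {
#     "câu hỏi": "...",
#     "lựa chọn": {"a": "...", "b": "...", "c": "...", "d": "..."},
#     "đáp án": "...",  # full text of the correct option, matching one value in "lựa chọn"
#     "khái niệm sử dụng": {"khái niệm chính": ["khái niệm phụ", "..."]}
#   },
#   "2": {...}
# }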