Spaces:
Sleeping
Sleeping
Update 2 difficulty endpoints and 2 original endpoints
Browse files- app.py +134 -27
- generator.py +37 -28
- utils.py +45 -45
app.py
CHANGED
|
@@ -8,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
# Import the user's RAGMCQ implementation
|
| 11 |
-
from generator import RAGMCQWithDifficulty
|
| 12 |
from utils import log_pipeline
|
| 13 |
|
| 14 |
app = FastAPI(title="RAG MCQ Generator API")
|
|
@@ -23,7 +23,8 @@ app.add_middleware(
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# global rag instance
|
| 26 |
-
rag: Optional[
|
|
|
|
| 27 |
|
| 28 |
class GenerateResponse(BaseModel):
|
| 29 |
mcqs: dict
|
|
@@ -34,10 +35,11 @@ class ListResponse(BaseModel):
|
|
| 34 |
|
| 35 |
@app.on_event("startup")
|
| 36 |
def startup_event():
|
| 37 |
-
global rag
|
| 38 |
|
| 39 |
# instantiate the heavy object once
|
| 40 |
-
rag =
|
|
|
|
| 41 |
print("RAGMCQ instance created on startup.")
|
| 42 |
|
| 43 |
@app.get("/health")
|
|
@@ -121,22 +123,23 @@ async def upload_multiple_files(
|
|
| 121 |
return {"files": saved_files}
|
| 122 |
|
| 123 |
|
| 124 |
-
|
| 125 |
-
@app.post("/
|
| 126 |
-
async def
|
| 127 |
n_easy_questions: int = Form(3),
|
| 128 |
n_medium_questions: int = Form(5),
|
| 129 |
n_hard_questions: int = Form(2),
|
| 130 |
qdrant_filename: str = Form("default_filename"),
|
| 131 |
collection_name: str = Form("programming"),
|
| 132 |
mode: str = Form("rag"),
|
| 133 |
-
questions_per_chunk: int = Form(3),
|
| 134 |
top_k: int = Form(3),
|
| 135 |
temperature: float = Form(0.2),
|
| 136 |
validate_mcqs: bool = Form(False),
|
| 137 |
enable_fiddler: bool = Form(False),
|
| 138 |
):
|
| 139 |
-
global
|
|
|
|
|
|
|
| 140 |
if rag is None:
|
| 141 |
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 142 |
|
|
@@ -156,8 +159,8 @@ async def generate_saved_endpoint(
|
|
| 156 |
collection=collection_name,
|
| 157 |
n_questions=n_questions,
|
| 158 |
mode=mode,
|
| 159 |
-
questions_per_chunk=
|
| 160 |
-
top_k=
|
| 161 |
temperature=temperature,
|
| 162 |
enable_fiddler=enable_fiddler,
|
| 163 |
target_difficulty=difficulty,
|
|
@@ -193,30 +196,27 @@ async def generate_saved_endpoint(
|
|
| 193 |
# don't fail the whole request for a validation error — return generator output and note the error
|
| 194 |
validation_report = {"error": f"Validation failed: {e}"}
|
| 195 |
|
| 196 |
-
# log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
|
| 197 |
-
|
| 198 |
return {"mcqs": all_mcqs, "validation": validation_report}
|
| 199 |
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
@app.post("/generate", response_model=GenerateResponse)
|
| 204 |
-
async def generate_endpoint(
|
| 205 |
background_tasks: BackgroundTasks,
|
| 206 |
file: UploadFile = File(...),
|
|
|
|
|
|
|
| 207 |
n_easy_questions: int = Form(3),
|
| 208 |
n_medium_questions: int = Form(5),
|
| 209 |
n_hard_questions: int = Form(2),
|
| 210 |
-
qdrant_filename: str = Form("default_filename"),
|
| 211 |
-
collection_name: str = Form("programming"),
|
| 212 |
mode: str = Form("rag"),
|
| 213 |
-
questions_per_page: int = Form(3),
|
| 214 |
top_k: int = Form(3),
|
| 215 |
temperature: float = Form(0.2),
|
| 216 |
validate_mcqs: bool = Form(False),
|
| 217 |
enable_fiddler: bool = Form(False)
|
| 218 |
):
|
| 219 |
-
global
|
|
|
|
|
|
|
| 220 |
if rag is None:
|
| 221 |
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 222 |
|
|
@@ -254,10 +254,9 @@ async def generate_endpoint(
|
|
| 254 |
for difficulty, n_questions in difficulty_counts.items():
|
| 255 |
try:
|
| 256 |
mcqs = rag.generate_from_pdf(
|
| 257 |
-
|
| 258 |
n_questions=n_questions,
|
| 259 |
mode=mode,
|
| 260 |
-
questions_per_page=questions_per_page,
|
| 261 |
top_k=top_k,
|
| 262 |
temperature=temperature,
|
| 263 |
enable_fiddler=enable_fiddler,
|
|
@@ -282,23 +281,131 @@ async def generate_endpoint(
|
|
| 282 |
counter += 1
|
| 283 |
|
| 284 |
except Exception as e:
|
| 285 |
-
raise HTTPException(status_code=500, detail=f"Generation from file failed: {e}")
|
| 286 |
|
| 287 |
validation_report = None
|
| 288 |
|
| 289 |
if validate_mcqs:
|
| 290 |
try:
|
| 291 |
-
# rag.build_index_from_pdf(tmp_path)
|
| 292 |
# validate_mcqs expects keys as strings and the normalized content
|
| 293 |
validation_report = rag.validate_mcqs(all_mcqs, top_k=top_k)
|
| 294 |
except Exception as e:
|
| 295 |
# don't fail the whole request for a validation error — return generator output and note the error
|
| 296 |
validation_report = {"error": f"Validation failed: {e}"}
|
| 297 |
|
|
|
|
| 298 |
|
| 299 |
-
# log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
|
| 300 |
|
| 301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
|
| 304 |
if __name__ == "__main__":
|
|
|
|
| 8 |
from pydantic import BaseModel
|
| 9 |
|
| 10 |
# Import the user's RAGMCQ implementation
|
| 11 |
+
from generator import RAGMCQWithDifficulty, RAGMCQ
|
| 12 |
from utils import log_pipeline
|
| 13 |
|
| 14 |
app = FastAPI(title="RAG MCQ Generator API")
|
|
|
|
| 23 |
)
|
| 24 |
|
| 25 |
# global rag instance
|
| 26 |
+
rag: Optional[RAGMCQ] = None
|
| 27 |
+
rag_difficulty: Optional[RAGMCQWithDifficulty] = None
|
| 28 |
|
| 29 |
class GenerateResponse(BaseModel):
|
| 30 |
mcqs: dict
|
|
|
|
| 35 |
|
| 36 |
@app.on_event("startup")
|
| 37 |
def startup_event():
|
| 38 |
+
global rag, rag_difficulty
|
| 39 |
|
| 40 |
# instantiate the heavy object once
|
| 41 |
+
rag = RAGMCQ()
|
| 42 |
+
rag_difficulty = RAGMCQWithDifficulty()
|
| 43 |
print("RAGMCQ instance created on startup.")
|
| 44 |
|
| 45 |
@app.get("/health")
|
|
|
|
| 123 |
return {"files": saved_files}
|
| 124 |
|
| 125 |
|
| 126 |
+
#? GENERATION with DIFFICULTY
|
| 127 |
+
@app.post("/generate_saved_difficulty", response_model=GenerateResponse)
|
| 128 |
+
async def generate_saved_difficulty_endpoint(
|
| 129 |
n_easy_questions: int = Form(3),
|
| 130 |
n_medium_questions: int = Form(5),
|
| 131 |
n_hard_questions: int = Form(2),
|
| 132 |
qdrant_filename: str = Form("default_filename"),
|
| 133 |
collection_name: str = Form("programming"),
|
| 134 |
mode: str = Form("rag"),
|
|
|
|
| 135 |
top_k: int = Form(3),
|
| 136 |
temperature: float = Form(0.2),
|
| 137 |
validate_mcqs: bool = Form(False),
|
| 138 |
enable_fiddler: bool = Form(False),
|
| 139 |
):
|
| 140 |
+
global rag_difficulty
|
| 141 |
+
rag = rag_difficulty
|
| 142 |
+
|
| 143 |
if rag is None:
|
| 144 |
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 145 |
|
|
|
|
| 159 |
collection=collection_name,
|
| 160 |
n_questions=n_questions,
|
| 161 |
mode=mode,
|
| 162 |
+
questions_per_chunk=3,
|
| 163 |
+
top_k=3,
|
| 164 |
temperature=temperature,
|
| 165 |
enable_fiddler=enable_fiddler,
|
| 166 |
target_difficulty=difficulty,
|
|
|
|
| 196 |
# don't fail the whole request for a validation error — return generator output and note the error
|
| 197 |
validation_report = {"error": f"Validation failed: {e}"}
|
| 198 |
|
|
|
|
|
|
|
| 199 |
return {"mcqs": all_mcqs, "validation": validation_report}
|
| 200 |
|
| 201 |
|
| 202 |
+
@app.post("/generate_difficulty", response_model=GenerateResponse)
|
| 203 |
+
async def generate_difficulty_endpoint(
|
|
|
|
|
|
|
| 204 |
background_tasks: BackgroundTasks,
|
| 205 |
file: UploadFile = File(...),
|
| 206 |
+
qdrant_filename: str = Form("default_filename"),
|
| 207 |
+
collection_name: str = Form("programming"),
|
| 208 |
n_easy_questions: int = Form(3),
|
| 209 |
n_medium_questions: int = Form(5),
|
| 210 |
n_hard_questions: int = Form(2),
|
|
|
|
|
|
|
| 211 |
mode: str = Form("rag"),
|
|
|
|
| 212 |
top_k: int = Form(3),
|
| 213 |
temperature: float = Form(0.2),
|
| 214 |
validate_mcqs: bool = Form(False),
|
| 215 |
enable_fiddler: bool = Form(False)
|
| 216 |
):
|
| 217 |
+
global rag_difficulty
|
| 218 |
+
rag = rag_difficulty
|
| 219 |
+
|
| 220 |
if rag is None:
|
| 221 |
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 222 |
|
|
|
|
| 254 |
for difficulty, n_questions in difficulty_counts.items():
|
| 255 |
try:
|
| 256 |
mcqs = rag.generate_from_pdf(
|
| 257 |
+
tmp_path,
|
| 258 |
n_questions=n_questions,
|
| 259 |
mode=mode,
|
|
|
|
| 260 |
top_k=top_k,
|
| 261 |
temperature=temperature,
|
| 262 |
enable_fiddler=enable_fiddler,
|
|
|
|
| 281 |
counter += 1
|
| 282 |
|
| 283 |
except Exception as e:
|
| 284 |
+
raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
|
| 285 |
|
| 286 |
validation_report = None
|
| 287 |
|
| 288 |
if validate_mcqs:
|
| 289 |
try:
|
|
|
|
| 290 |
# validate_mcqs expects keys as strings and the normalized content
|
| 291 |
validation_report = rag.validate_mcqs(all_mcqs, top_k=top_k)
|
| 292 |
except Exception as e:
|
| 293 |
# don't fail the whole request for a validation error — return generator output and note the error
|
| 294 |
validation_report = {"error": f"Validation failed: {e}"}
|
| 295 |
|
| 296 |
+
return {"mcqs": all_mcqs, "validation": validation_report}
|
| 297 |
|
|
|
|
| 298 |
|
| 299 |
+
|
| 300 |
+
#? ORIGINAL Generation
|
| 301 |
+
@app.post("/generate_saved", response_model=GenerateResponse)
async def generate_saved_endpoint(
    qdrant_filename: str = Form("default_filename"),
    collection_name: str = Form("programming"),
    mode: str = Form("rag"),
    n_questions: int = Form(10),
    questions_per_chunk: int = Form(3),
    top_k: int = Form(3),
    temperature: float = Form(0.2),
    validate_mcqs: bool = Form(False),
    enable_fiddler: bool = Form(False),
):
    """Generate MCQs from a document that was previously saved to Qdrant.

    All parameters are read from the form-encoded request body.
    ``n_questions`` used to be a plain default (``= 10``), which FastAPI
    treats as a query parameter; a value posted in the form was therefore
    ignored. It is now a ``Form`` field like every other parameter and like
    the sibling ``/generate`` endpoint.

    Returns:
        A dict with the generated ``mcqs`` and a ``validation`` entry that is
        ``None`` when validation was not requested, the validation report
        when it succeeded, or ``{"error": ...}`` when it failed.

    Raises:
        HTTPException: 503 if the global ``rag`` instance is not initialised,
            500 if generation itself fails.
    """
    if rag is None:
        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")

    try:
        mcqs = rag.generate_from_qdrant(
            filename=qdrant_filename,
            collection=collection_name,
            mode=mode,
            n_questions=n_questions,
            questions_per_chunk=questions_per_chunk,
            top_k=top_k,
            temperature=temperature,
            enable_fiddler=enable_fiddler,
        )
    except Exception as e:
        # Boundary handler: surface any generator failure as a 500, keeping
        # the original exception chained for server-side tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Generation from saved file failed: {e}"
        ) from e

    validation_report = None

    if validate_mcqs:
        try:
            # validate_mcqs expects keys as strings and the normalized content
            validation_report = rag.validate_mcqs(mcqs, top_k=top_k)
        except Exception as e:
            # Don't fail the whole request for a validation error: return the
            # generator output and note the error alongside it.
            validation_report = {"error": f"Validation failed: {e}"}

    return {"mcqs": mcqs, "validation": validation_report}
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
@app.post("/generate", response_model=GenerateResponse)
async def generate_endpoint(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    n_questions: int = Form(10),
    qdrant_filename: str = Form("default_filename"),
    collection_name: str = Form("programming"),
    mode: str = Form("rag"),
    questions_per_page: int = Form(3),
    top_k: int = Form(3),
    temperature: float = Form(0.2),
    validate_mcqs: bool = Form(False),
    enable_fiddler: bool = Form(False)
):
    """Generate MCQs from an uploaded PDF.

    The upload is written to a temporary file, saved to Qdrant under
    ``qdrant_filename`` / ``collection_name`` (overwriting any existing
    entry), and then passed to ``rag.generate_from_pdf``.

    Returns:
        A dict with the generated ``mcqs`` and a ``validation`` entry that is
        ``None`` when validation was not requested, the validation report
        when it succeeded, or ``{"error": ...}`` when it failed.

    Raises:
        HTTPException: 503 if the global ``rag`` instance is not initialised,
            400 if the upload is not a ``.pdf`` file, 500 if saving to Qdrant
            or generation fails.
    """
    if rag is None:
        raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")

    # Basic file validation. ``filename`` can be missing on a multipart part,
    # so guard against None before calling ``.lower()``.
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Only PDF files are supported.")

    # Save the uploaded file to a temp location.
    tmp_path = _save_upload_to_temp(file)

    def _cleanup(path: str) -> None:
        """Best-effort removal of the temp file; safe to call more than once."""
        try:
            os.remove(path)
        except OSError:
            # Already removed or not removable; nothing useful to do here.
            pass

    # Remove the temp file once the successful response has been sent.
    background_tasks.add_task(_cleanup, tmp_path)

    # Save the PDF to Qdrant.
    try:
        rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
    except Exception as e:
        # NOTE(review): injected background tasks are attached to the normal
        # response and are not expected to run when the handler raises, so
        # clean up explicitly on the failure path.
        _cleanup(tmp_path)
        raise HTTPException(
            status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}"
        ) from e

    # Generate the questions.
    try:
        mcqs = rag.generate_from_pdf(
            tmp_path,
            n_questions=n_questions,
            mode=mode,
            questions_per_page=questions_per_page,
            top_k=top_k,
            temperature=temperature,
            enable_fiddler=enable_fiddler
        )
    except Exception as e:
        _cleanup(tmp_path)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e

    validation_report = None

    if validate_mcqs:
        try:
            # validate_mcqs expects keys as strings and the normalized content
            validation_report = rag.validate_mcqs(mcqs, top_k=top_k)
        except Exception as e:
            # Don't fail the whole request for a validation error: return the
            # generator output and note the error alongside it.
            validation_report = {"error": f"Validation failed: {e}"}

    return {"mcqs": mcqs, "validation": validation_report}
|
| 409 |
|
| 410 |
|
| 411 |
if __name__ == "__main__":
|
generator.py
CHANGED
|
@@ -32,7 +32,7 @@ try:
|
|
| 32 |
except Exception:
|
| 33 |
_HAS_FAISS = False
|
| 34 |
|
| 35 |
-
from utils import generate_mcqs_from_text,
|
| 36 |
|
| 37 |
from huggingface_hub import login
|
| 38 |
login(token=os.environ['HF_MODEL_TOKEN'])
|
|
@@ -202,12 +202,12 @@ class RAGMCQ:
|
|
| 202 |
|
| 203 |
if not chunk_text.strip():
|
| 204 |
continue
|
| 205 |
-
to_gen = questions_per_page
|
| 206 |
|
| 207 |
# ask generator
|
| 208 |
try:
|
|
|
|
| 209 |
mcq_block = generate_mcqs_from_text(
|
| 210 |
-
|
| 211 |
)
|
| 212 |
except Exception as e:
|
| 213 |
# skip this chunk if generator fails
|
|
@@ -259,8 +259,9 @@ class RAGMCQ:
|
|
| 259 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 260 |
try:
|
| 261 |
# request 1 question at a time to keep diversity
|
|
|
|
| 262 |
mcq_block = generate_mcqs_from_text(
|
| 263 |
-
|
| 264 |
)
|
| 265 |
except Exception as e:
|
| 266 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
|
@@ -683,7 +684,7 @@ class RAGMCQ:
|
|
| 683 |
"text": txt,
|
| 684 |
"source_id": source_id,
|
| 685 |
}
|
| 686 |
-
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
|
| 687 |
|
| 688 |
# upsert in batches
|
| 689 |
if len(points) >= batch_size:
|
|
@@ -887,9 +888,10 @@ class RAGMCQ:
|
|
| 887 |
for i, txt in enumerate(texts):
|
| 888 |
if not txt.strip():
|
| 889 |
continue
|
| 890 |
-
|
| 891 |
try:
|
| 892 |
-
|
|
|
|
| 893 |
except Exception as e:
|
| 894 |
print(f"Generator failed on chunk (index {i}): {e}")
|
| 895 |
continue
|
|
@@ -933,7 +935,8 @@ class RAGMCQ:
|
|
| 933 |
context = "\n\n".join(context_parts)
|
| 934 |
|
| 935 |
try:
|
| 936 |
-
|
|
|
|
| 937 |
except Exception as e:
|
| 938 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 939 |
continue
|
|
@@ -967,12 +970,14 @@ class RAGMCQ:
|
|
| 967 |
else:
|
| 968 |
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
| 969 |
|
|
|
|
|
|
|
| 970 |
def _estimate_difficulty_for_generation(
|
| 971 |
self,
|
| 972 |
q_text: str,
|
| 973 |
options: Dict[str, str],
|
| 974 |
correct_text: str,
|
| 975 |
-
context_text: str
|
| 976 |
) -> Tuple[float, str]:
|
| 977 |
def safe_map_sim(s):
|
| 978 |
# map potentially [-1,1] cosine-like to [0,1], clamp
|
|
@@ -1112,13 +1117,13 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1112 |
|
| 1113 |
if not chunk_text.strip():
|
| 1114 |
continue
|
| 1115 |
-
|
| 1116 |
|
| 1117 |
# ask generator
|
| 1118 |
try:
|
| 1119 |
-
structured_context = structure_context_for_llm(
|
| 1120 |
-
mcq_block =
|
| 1121 |
-
source_text=
|
| 1122 |
)
|
| 1123 |
except Exception as e:
|
| 1124 |
# skip this chunk if generator fails
|
|
@@ -1170,9 +1175,10 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1170 |
|
| 1171 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 1172 |
try:
|
| 1173 |
-
# request 1 question at a time to keep diversity
|
| 1174 |
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 1175 |
-
mcq_block = new_generate_mcqs_from_text(
|
|
|
|
|
|
|
| 1176 |
except Exception as e:
|
| 1177 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 1178 |
continue
|
|
@@ -1264,9 +1270,12 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1264 |
for i, txt in enumerate(texts):
|
| 1265 |
if not txt.strip():
|
| 1266 |
continue
|
| 1267 |
-
to_gen = questions_per_chunk
|
| 1268 |
try:
|
| 1269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1270 |
except Exception as e:
|
| 1271 |
print(f"Generator failed on chunk (index {i}): {e}")
|
| 1272 |
continue
|
|
@@ -1298,10 +1307,8 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1298 |
seed_sent = (stripped[:200] if stripped else "[no text available]")
|
| 1299 |
query = f"Create questions about: {seed_sent}"
|
| 1300 |
|
| 1301 |
-
|
| 1302 |
# retrieve top_k chunks from the same file (restricted by filename filter)
|
| 1303 |
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
|
| 1304 |
-
print('retrieved qdrant', retrieved)
|
| 1305 |
context_parts = []
|
| 1306 |
for payload, score in retrieved:
|
| 1307 |
# payload should contain page & chunk_id and text
|
|
@@ -1313,9 +1320,11 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1313 |
|
| 1314 |
# q generation
|
| 1315 |
try:
|
| 1316 |
-
# Difficulty pipeline: easy, mid, difficult
|
| 1317 |
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 1318 |
-
mcq_block = new_generate_mcqs_from_text(
|
|
|
|
|
|
|
|
|
|
| 1319 |
except Exception as e:
|
| 1320 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 1321 |
continue
|
|
@@ -1342,18 +1351,18 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1342 |
|
| 1343 |
#? change estimate
|
| 1344 |
diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
|
| 1345 |
-
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
|
| 1346 |
)
|
| 1347 |
|
| 1348 |
payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
|
| 1349 |
|
| 1350 |
# Cap the batch size: if fewer questions remain than questions_per_chunk (e.g. 5 - 3 = 2 < 3), generate only the remaining 2 MCQs
|
| 1351 |
if n_questions - qcount < questions_per_chunk:
|
| 1352 |
-
|
| 1353 |
|
| 1354 |
qcount += 1 # count number of question
|
| 1355 |
-
print('qcount:', qcount)
|
| 1356 |
-
print('questions_per_chunk:', questions_per_chunk)
|
| 1357 |
|
| 1358 |
output[str(qcount)] = mcq_block[item]
|
| 1359 |
if qcount >= n_questions:
|
|
@@ -1371,7 +1380,7 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1371 |
q_text: str,
|
| 1372 |
options: Dict[str, str],
|
| 1373 |
correct_text: str,
|
| 1374 |
-
context_text: str
|
| 1375 |
concepts_used: Dict = {}
|
| 1376 |
) -> Tuple[float, str]:
|
| 1377 |
def safe_map_sim(s):
|
|
@@ -1468,7 +1477,7 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1468 |
score += 0.35 * float(distractor_penalty)
|
| 1469 |
score += 0.20 * float(mean_sim)
|
| 1470 |
score += 0.22 * float(amb_flag)
|
| 1471 |
-
score += 0.
|
| 1472 |
score -= 0.20 * float(gap)
|
| 1473 |
|
| 1474 |
# clamp
|
|
@@ -1493,4 +1502,4 @@ class RAGMCQWithDifficulty(RAGMCQ):
|
|
| 1493 |
else:
|
| 1494 |
label = "khó"
|
| 1495 |
|
| 1496 |
-
return score, label, components
|
|
|
|
| 32 |
except Exception:
|
| 33 |
_HAS_FAISS = False
|
| 34 |
|
| 35 |
+
from utils import generate_mcqs_from_text, structure_context_for_llm, new_generate_mcqs_from_text
|
| 36 |
|
| 37 |
from huggingface_hub import login
|
| 38 |
login(token=os.environ['HF_MODEL_TOKEN'])
|
|
|
|
| 202 |
|
| 203 |
if not chunk_text.strip():
|
| 204 |
continue
|
|
|
|
| 205 |
|
| 206 |
# ask generator
|
| 207 |
try:
|
| 208 |
+
structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
|
| 209 |
mcq_block = generate_mcqs_from_text(
|
| 210 |
+
structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
|
| 211 |
)
|
| 212 |
except Exception as e:
|
| 213 |
# skip this chunk if generator fails
|
|
|
|
| 259 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 260 |
try:
|
| 261 |
# request 1 question at a time to keep diversity
|
| 262 |
+
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
|
| 263 |
mcq_block = generate_mcqs_from_text(
|
| 264 |
+
structured_context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
|
| 265 |
)
|
| 266 |
except Exception as e:
|
| 267 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
|
|
|
| 684 |
"text": txt,
|
| 685 |
"source_id": source_id,
|
| 686 |
}
|
| 687 |
+
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload)) # pyright: ignore[reportPossiblyUnboundVariable]
|
| 688 |
|
| 689 |
# upsert in batches
|
| 690 |
if len(points) >= batch_size:
|
|
|
|
| 888 |
for i, txt in enumerate(texts):
|
| 889 |
if not txt.strip():
|
| 890 |
continue
|
| 891 |
+
|
| 892 |
try:
|
| 893 |
+
structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
|
| 894 |
+
mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
|
| 895 |
except Exception as e:
|
| 896 |
print(f"Generator failed on chunk (index {i}): {e}")
|
| 897 |
continue
|
|
|
|
| 935 |
context = "\n\n".join(context_parts)
|
| 936 |
|
| 937 |
try:
|
| 938 |
+
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
|
| 939 |
+
mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
|
| 940 |
except Exception as e:
|
| 941 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 942 |
continue
|
|
|
|
| 970 |
else:
|
| 971 |
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
| 972 |
|
| 973 |
+
|
| 974 |
+
|
| 975 |
def _estimate_difficulty_for_generation(
|
| 976 |
self,
|
| 977 |
q_text: str,
|
| 978 |
options: Dict[str, str],
|
| 979 |
correct_text: str,
|
| 980 |
+
context_text: str,
|
| 981 |
) -> Tuple[float, str]:
|
| 982 |
def safe_map_sim(s):
|
| 983 |
# map potentially [-1,1] cosine-like to [0,1], clamp
|
|
|
|
| 1117 |
|
| 1118 |
if not chunk_text.strip():
|
| 1119 |
continue
|
| 1120 |
+
|
| 1121 |
|
| 1122 |
# ask generator
|
| 1123 |
try:
|
| 1124 |
+
structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
|
| 1125 |
+
mcq_block = new_generate_mcqs_from_text(
|
| 1126 |
+
source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler
|
| 1127 |
)
|
| 1128 |
except Exception as e:
|
| 1129 |
# skip this chunk if generator fails
|
|
|
|
| 1175 |
|
| 1176 |
# call generator for 1 question (or small batch) with the retrieved context
|
| 1177 |
try:
|
|
|
|
| 1178 |
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 1179 |
+
mcq_block = new_generate_mcqs_from_text(
|
| 1180 |
+
source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler
|
| 1181 |
+
)
|
| 1182 |
except Exception as e:
|
| 1183 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 1184 |
continue
|
|
|
|
| 1270 |
for i, txt in enumerate(texts):
|
| 1271 |
if not txt.strip():
|
| 1272 |
continue
|
|
|
|
| 1273 |
try:
|
| 1274 |
+
structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 1275 |
+
mcq_block = new_generate_mcqs_from_text(
|
| 1276 |
+
source_text=structured_context, n=questions_per_chunk, model=self.generation_model,
|
| 1277 |
+
temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler
|
| 1278 |
+
)
|
| 1279 |
except Exception as e:
|
| 1280 |
print(f"Generator failed on chunk (index {i}): {e}")
|
| 1281 |
continue
|
|
|
|
| 1307 |
seed_sent = (stripped[:200] if stripped else "[no text available]")
|
| 1308 |
query = f"Create questions about: {seed_sent}"
|
| 1309 |
|
|
|
|
| 1310 |
# retrieve top_k chunks from the same file (restricted by filename filter)
|
| 1311 |
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
|
|
|
|
| 1312 |
context_parts = []
|
| 1313 |
for payload, score in retrieved:
|
| 1314 |
# payload should contain page & chunk_id and text
|
|
|
|
| 1320 |
|
| 1321 |
# q generation
|
| 1322 |
try:
|
|
|
|
| 1323 |
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
|
| 1324 |
+
mcq_block = new_generate_mcqs_from_text(
|
| 1325 |
+
source_text=structured_context, n=questions_per_chunk, model=self.generation_model,
|
| 1326 |
+
temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler
|
| 1327 |
+
)
|
| 1328 |
except Exception as e:
|
| 1329 |
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 1330 |
continue
|
|
|
|
| 1351 |
|
| 1352 |
#? change estimate
|
| 1353 |
diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
|
| 1354 |
+
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts # type: ignore
|
| 1355 |
)
|
| 1356 |
|
| 1357 |
payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
|
| 1358 |
|
| 1359 |
# Cap the batch size: if fewer questions remain than questions_per_chunk (e.g. 5 - 3 = 2 < 3), generate only the remaining 2 MCQs
|
| 1360 |
if n_questions - qcount < questions_per_chunk:
|
| 1361 |
+
questions_per_chunk = n_questions - qcount
|
| 1362 |
|
| 1363 |
qcount += 1 # count number of question
|
| 1364 |
+
# print('qcount:', qcount)
|
| 1365 |
+
# print('questions_per_chunk:', questions_per_chunk)
|
| 1366 |
|
| 1367 |
output[str(qcount)] = mcq_block[item]
|
| 1368 |
if qcount >= n_questions:
|
|
|
|
| 1380 |
q_text: str,
|
| 1381 |
options: Dict[str, str],
|
| 1382 |
correct_text: str,
|
| 1383 |
+
context_text: str,
|
| 1384 |
concepts_used: Dict = {}
|
| 1385 |
) -> Tuple[float, str]:
|
| 1386 |
def safe_map_sim(s):
|
|
|
|
| 1477 |
score += 0.35 * float(distractor_penalty)
|
| 1478 |
score += 0.20 * float(mean_sim)
|
| 1479 |
score += 0.22 * float(amb_flag)
|
| 1480 |
+
score += 0.08 * float(question_len_norm)
|
| 1481 |
score -= 0.20 * float(gap)
|
| 1482 |
|
| 1483 |
# clamp
|
|
|
|
| 1502 |
else:
|
| 1503 |
label = "khó"
|
| 1504 |
|
| 1505 |
+
return score, label, components # type: ignore
|
utils.py
CHANGED
|
@@ -99,7 +99,7 @@ def structure_context_for_llm(
|
|
| 99 |
"content": (
|
| 100 |
"Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt\n"
|
| 101 |
"Nhiệm vụ của bạn là:\n"
|
| 102 |
-
"- Nếu văn bản dài trên 500 từ chia văn bản thành
|
| 103 |
"- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
|
| 104 |
"- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
|
| 105 |
"- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
|
|
@@ -140,7 +140,7 @@ def structure_context_for_llm(
|
|
| 140 |
|
| 141 |
|
| 142 |
def new_generate_mcqs_from_text(
|
| 143 |
-
source_text:
|
| 144 |
n: int = 3,
|
| 145 |
model: str = "openai/openai/gpt-oss-120b",
|
| 146 |
temperature: float = 0.2,
|
|
@@ -151,60 +151,60 @@ def new_generate_mcqs_from_text(
|
|
| 151 |
|
| 152 |
|
| 153 |
expected_concepts = {
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
}
|
| 158 |
if isinstance(expected_concepts[target_difficulty], tuple):
|
| 159 |
-
|
| 160 |
-
|
| 161 |
else:
|
| 162 |
-
|
| 163 |
|
| 164 |
|
| 165 |
difficulty_prompts = {
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
}
|
| 187 |
|
| 188 |
difficult_criteria = difficulty_prompts[target_difficulty] # "easy", "medium", "hard"
|
| 189 |
print(concept_range)
|
| 190 |
system_message = {
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
)
|
| 209 |
}
|
| 210 |
|
|
|
|
| 99 |
"content": (
|
| 100 |
"Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt\n"
|
| 101 |
"Nhiệm vụ của bạn là:\n"
|
| 102 |
+
"- Nếu văn bản dài trên 500 từ chia văn bản thành các chunk có ý nghĩa rõ ràng.\n"
|
| 103 |
"- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
|
| 104 |
"- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
|
| 105 |
"- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
def new_generate_mcqs_from_text(
|
| 143 |
+
source_text: Dict,
|
| 144 |
n: int = 3,
|
| 145 |
model: str = "openai/openai/gpt-oss-120b",
|
| 146 |
temperature: float = 0.2,
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
expected_concepts = {
|
| 154 |
+
"easy": 1,
|
| 155 |
+
"medium": 2,
|
| 156 |
+
"hard": (3, 4)
|
| 157 |
}
|
| 158 |
if isinstance(expected_concepts[target_difficulty], tuple):
|
| 159 |
+
min_concepts, max_concepts = expected_concepts[target_difficulty]
|
| 160 |
+
concept_range = f"{min_concepts}-{max_concepts}"
|
| 161 |
else:
|
| 162 |
+
concept_range = expected_concepts[target_difficulty]
|
| 163 |
|
| 164 |
|
| 165 |
difficulty_prompts = {
|
| 166 |
+
"easy": (
|
| 167 |
+
"- Câu hỏi DỄ: kiểm tra duy nhất 1 khái niệm chính cơ bản dễ hiểu, định nghĩa, hoặc công thức đơn giản."
|
| 168 |
+
"- Đáp án có thể tìm thấy trực tiếp trong văn bản."
|
| 169 |
+
"- Ngữ cảnh đủ để hiểu khái niệm chính."
|
| 170 |
+
"- Distractors khác biệt rõ ràng, dễ loại bỏ."
|
| 171 |
+
"- Độ dài câu hỏi ngắn gọn không quá 10-20 từ hoặc ít hơn 120 ký tự, tập trung vào một ý duy nhất.\n"
|
| 172 |
+
),
|
| 173 |
+
"medium": (
|
| 174 |
+
"- Câu hỏi TRUNG BÌNH kiểm tra khái niệm chính trong văn bản"
|
| 175 |
+
"- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
|
| 176 |
+
"- Các Distractors không quá giống nhau."
|
| 177 |
+
"- Độ dài câu hỏi vừa phải khoảng 23–30 từ hoặc khoảng 150 - 180 ký tự, có thêm chi tiết phụ để suy luận.\n"
|
| 178 |
+
),
|
| 179 |
+
"hard": (
|
| 180 |
+
"- Câu hỏi KHÓ kiểm tra thông tin được phân tích/tổng hợp"
|
| 181 |
+
"- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
|
| 182 |
+
"- Ít nhất 2 distractors gần giống đáp án đúng, độ tương đồng cao. "
|
| 183 |
+
f"- Đáp án yêu cầu học sinh suy luận hoặc áp dụng công thức vào ví dụ nếu có."
|
| 184 |
+
"- Độ dài câu hỏi dài hơn 35 từ hoặc hơn 200 ký tự.\n \n"
|
| 185 |
+
)
|
| 186 |
}
|
| 187 |
|
| 188 |
difficult_criteria = difficulty_prompts[target_difficulty] # "easy", "medium", "hard"
|
| 189 |
print(concept_range)
|
| 190 |
system_message = {
|
| 191 |
+
"role": "system",
|
| 192 |
+
"content": (
|
| 193 |
+
"Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm (MCQ). Luôn trả lời bằng tiếng việt"
|
| 194 |
+
f"Đảm bảo chỉ tạo sinh câu trắc nghiệm có độ khó sau {difficult_criteria}"
|
| 195 |
+
f"Quan trọng: Mỗi câu hỏi chỉ sử dụng chính xác {concept_range} khái niệm chính (mỗi khái niệm chính có 1 danh sách khái niệm phụ) từ văn bản nguồn. "
|
| 196 |
+
"Mỗi câu hỏi và đáp án phải dựa trên thông tin từ văn bản nguồn. Không được đưa kiến thức ngoài vào."
|
| 197 |
+
"Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không kèm giải thích hay trường thêm:\n\n"
|
| 198 |
+
"{\n"
|
| 199 |
+
' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"...", "khái niệm sử dụng": {"khái niệm chính":["khái niệm phụ", "..."], "..."]}},\n'
|
| 200 |
+
' "2": { ... }\n'
|
| 201 |
+
"}\n\n"
|
| 202 |
+
"Lưu ý:\n"
|
| 203 |
+
f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
|
| 204 |
+
"- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
|
| 205 |
+
"- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'options'.\n"
|
| 206 |
+
"- Toàn bộ thông tin cần thiết để trả lời phải nằm trong chính câu hỏi, không tham chiếu lại văn bản nguồn."
|
| 207 |
+
f"- Sử dụng chính xác {concept_range} khái niệm chính"
|
| 208 |
)
|
| 209 |
}
|
| 210 |
|