Spaces:
Sleeping
Sleeping
namberino committed on
Commit ·
a8b213d
1
Parent(s): d3f7049
Literally everything
Browse files- Dockerfile +32 -0
- README.md +8 -3
- app.py +259 -0
- generator.py +1125 -0
- requirements.txt +10 -0
- utils.py +408 -0
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Set the HF cache to /tmp so model downloads work on Spaces' writable FS.
ENV HF_HOME=/tmp/huggingface \
    TOKENIZERS_PARALLELISM=false

# System packages needed by some Python libs (PDF rendering, audio, OpenCV).
RUN apt-get update && apt-get install -y --no-install-recommends \
        build-essential \
        git \
        wget \
        libsndfile1 \
        libgl1 \
        libglib2.0-0 \
        poppler-utils \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install dependencies first so Docker layer caching survives code-only changes.
COPY requirements.txt /app/requirements.txt
# --no-cache-dir on both steps keeps pip's wheel cache out of the image.
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r /app/requirements.txt

# Copy application code.
COPY . /app

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,15 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: gray
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Mcq Generator
|
| 3 |
+
emoji: 📚
|
| 4 |
+
colorFrom: purple
|
| 5 |
colorTo: gray
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 11 |
+
|
| 12 |
+
---
|
| 13 |
+
TODO:
|
| 14 |
+
+ Apply Cohen's Kappa to measure rate of agreement between human and AI.
|
| 15 |
+
+ Improve function transparency by adding documentation
|
app.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import tempfile
|
| 4 |
+
from typing import List, Optional, Union
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
# Import the user's RAGMCQ implementation
|
| 11 |
+
from generator import RAGMCQ
|
| 12 |
+
from utils import log_pipeline
|
| 13 |
+
|
| 14 |
+
app = FastAPI(title="RAG MCQ Generator API")

# allow cross-origin requests (adjust in production)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# maximally permissive; tighten the origin list before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global RAGMCQ instance, created once in startup_event; None until startup
# completes, which the endpoints use as their "not ready" signal.
rag: Optional[RAGMCQ] = None
|
| 27 |
+
|
| 28 |
+
class GenerateResponse(BaseModel):
    """Response schema for the MCQ generation endpoints."""
    # Generated questions keyed by question number ("1", "2", ...).
    mcqs: dict
    # Optional validation report; None when validation was not requested.
    validation: Optional[dict] = None
|
| 31 |
+
|
| 32 |
+
class ListResponse(BaseModel):
    """Response schema for file listing/upload endpoints."""
    # Filenames (or, for failed uploads, {"filename": ..., "error": ...} dicts).
    files: list
|
| 34 |
+
|
| 35 |
+
@app.on_event("startup")
|
| 36 |
+
def startup_event():
|
| 37 |
+
global rag
|
| 38 |
+
|
| 39 |
+
# instantiate the heavy object once
|
| 40 |
+
rag = RAGMCQ()
|
| 41 |
+
print("RAGMCQ instance created on startup.")
|
| 42 |
+
|
| 43 |
+
@app.get("/health")
|
| 44 |
+
def health():
|
| 45 |
+
return {"status": "ok", "ready": rag is not None}
|
| 46 |
+
|
| 47 |
+
def _save_upload_to_temp(upload: UploadFile) -> str:
|
| 48 |
+
suffix = ".pdf"
|
| 49 |
+
fd, path = tempfile.mkstemp(suffix=suffix)
|
| 50 |
+
os.close(fd)
|
| 51 |
+
with open(path, "wb") as out_file:
|
| 52 |
+
shutil.copyfileobj(upload.file, out_file)
|
| 53 |
+
return path
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@app.get("/list_collection_files", response_model=ListResponse)
|
| 57 |
+
async def list_collection_files_endpoint(
|
| 58 |
+
collection_name: str = "programming"
|
| 59 |
+
):
|
| 60 |
+
global rag
|
| 61 |
+
if rag is None:
|
| 62 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 63 |
+
|
| 64 |
+
files = rag.list_files_in_collection(collection_name)
|
| 65 |
+
|
| 66 |
+
return {"files": files}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@app.post("/upload_multiple_files", response_model=ListResponse)
|
| 70 |
+
async def upload_multiple_files(
|
| 71 |
+
background_tasks: BackgroundTasks,
|
| 72 |
+
files: List[UploadFile] = File(...), # get multiple files
|
| 73 |
+
collection_name: str = Form("programming"),
|
| 74 |
+
overwrite: bool = Form(True),
|
| 75 |
+
qdrant_filename_prefix: Optional[str] = Form(None),
|
| 76 |
+
):
|
| 77 |
+
"""
|
| 78 |
+
Upload multiple PDF files and save their chunks to Qdrant.
|
| 79 |
+
- files: one or more PDF files (multipart/form-data, repeated 'files' fields)
|
| 80 |
+
- collection_name: Qdrant collection to save into
|
| 81 |
+
- overwrite: if true, existing points for each filename will be removed
|
| 82 |
+
- qdrant_filename_prefix: optional prefix; if provided each file will be saved under "<prefix>_<original_filename>"
|
| 83 |
+
"""
|
| 84 |
+
global rag
|
| 85 |
+
if rag is None:
|
| 86 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 87 |
+
|
| 88 |
+
saved_files = []
|
| 89 |
+
|
| 90 |
+
def _cleanup(path: str):
|
| 91 |
+
try:
|
| 92 |
+
os.remove(path)
|
| 93 |
+
except Exception:
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
for idx, upload in enumerate(files):
|
| 97 |
+
if isinstance(upload, str):
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
if not upload.filename:
|
| 101 |
+
raise HTTPException(status_code=400, detail=f"Uploaded file #{idx+1} missing filename.")
|
| 102 |
+
|
| 103 |
+
if not upload.filename.lower().endswith(".pdf"):
|
| 104 |
+
raise HTTPException(status_code=400, detail=f"Only PDF files supported: {upload.filename}, error at file number: {idx}")
|
| 105 |
+
|
| 106 |
+
tmp_path = _save_upload_to_temp(upload)
|
| 107 |
+
background_tasks.add_task(_cleanup, tmp_path)
|
| 108 |
+
|
| 109 |
+
# decide filename to use in Qdrant payload
|
| 110 |
+
qdrant_filename = str(
|
| 111 |
+
f"{qdrant_filename_prefix}_{upload.filename}" if qdrant_filename_prefix else upload.filename
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
|
| 116 |
+
saved_files.append(qdrant_filename)
|
| 117 |
+
except Exception as e:
|
| 118 |
+
# collect failure info rather than aborting all uploads
|
| 119 |
+
saved_files.append({"filename": upload.filename, "error": str(e)})
|
| 120 |
+
|
| 121 |
+
return {"files": saved_files}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@app.post("/generate_saved", response_model=GenerateResponse)
|
| 126 |
+
async def generate_saved_endpoint(
|
| 127 |
+
n_easy_questions: int = Form(3),
|
| 128 |
+
n_medium_questions: int = Form(5),
|
| 129 |
+
n_hard_questions: int = Form(2),
|
| 130 |
+
qdrant_filename: str = Form("default_filename"),
|
| 131 |
+
collection_name: str = Form("programming"),
|
| 132 |
+
mode: str = Form("rag"),
|
| 133 |
+
questions_per_chunk: int = Form(3),
|
| 134 |
+
top_k: int = Form(3),
|
| 135 |
+
temperature: float = Form(0.2),
|
| 136 |
+
validate_mcqs: bool = Form(False),
|
| 137 |
+
enable_fiddler: bool = Form(False),
|
| 138 |
+
):
|
| 139 |
+
global rag
|
| 140 |
+
if rag is None:
|
| 141 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 142 |
+
|
| 143 |
+
difficulty_counts = {
|
| 144 |
+
"easy": n_easy_questions,
|
| 145 |
+
"medium": n_medium_questions,
|
| 146 |
+
"hard": n_hard_questions
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
all_mcqs = {}
|
| 150 |
+
|
| 151 |
+
for difficulty, n_questions in difficulty_counts.items():
|
| 152 |
+
try:
|
| 153 |
+
mcqs = rag.generate_from_qdrant(
|
| 154 |
+
filename=qdrant_filename,
|
| 155 |
+
collection=collection_name,
|
| 156 |
+
n_questions=n_questions,
|
| 157 |
+
mode=mode,
|
| 158 |
+
questions_per_chunk=questions_per_chunk,
|
| 159 |
+
top_k=top_k,
|
| 160 |
+
temperature=temperature,
|
| 161 |
+
enable_fiddler=enable_fiddler,
|
| 162 |
+
target_difficulty=difficulty,
|
| 163 |
+
)
|
| 164 |
+
all_mcqs.update(mcqs)
|
| 165 |
+
|
| 166 |
+
except Exception as e:
|
| 167 |
+
raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
|
| 168 |
+
|
| 169 |
+
validation_report = None
|
| 170 |
+
|
| 171 |
+
if validate_mcqs:
|
| 172 |
+
try:
|
| 173 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 174 |
+
validation_report = rag.validate_mcqs(all_mcqs, top_k=top_k)
|
| 175 |
+
except Exception as e:
|
| 176 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 177 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 178 |
+
|
| 179 |
+
# log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
|
| 180 |
+
|
| 181 |
+
return {"mcqs": all_mcqs, "validation": validation_report}
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
@app.post("/generate", response_model=GenerateResponse)
|
| 187 |
+
async def generate_endpoint(
|
| 188 |
+
background_tasks: BackgroundTasks,
|
| 189 |
+
file: UploadFile = File(...),
|
| 190 |
+
n_questions: int = Form(10),
|
| 191 |
+
qdrant_filename: str = Form("default_filename"),
|
| 192 |
+
collection_name: str = Form("programming"),
|
| 193 |
+
mode: str = Form("rag"),
|
| 194 |
+
questions_per_page: int = Form(3),
|
| 195 |
+
top_k: int = Form(3),
|
| 196 |
+
temperature: float = Form(0.2),
|
| 197 |
+
validate_mcqs: bool = Form(False),
|
| 198 |
+
enable_fiddler: bool = Form(False)
|
| 199 |
+
):
|
| 200 |
+
global rag
|
| 201 |
+
if rag is None:
|
| 202 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 203 |
+
|
| 204 |
+
# basic file validation
|
| 205 |
+
if not file.filename.lower().endswith(".pdf"):
|
| 206 |
+
raise HTTPException(status_code=400, detail="Only PDF files are supported.")
|
| 207 |
+
|
| 208 |
+
# save uploaded file to a temp location
|
| 209 |
+
tmp_path = _save_upload_to_temp(file)
|
| 210 |
+
|
| 211 |
+
# ensure file removed afterward
|
| 212 |
+
def _cleanup(path: str):
|
| 213 |
+
try:
|
| 214 |
+
os.remove(path)
|
| 215 |
+
except Exception:
|
| 216 |
+
pass
|
| 217 |
+
|
| 218 |
+
background_tasks.add_task(_cleanup, tmp_path)
|
| 219 |
+
|
| 220 |
+
# save pdf
|
| 221 |
+
try:
|
| 222 |
+
rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
|
| 223 |
+
except Exception as e:
|
| 224 |
+
raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
|
| 225 |
+
|
| 226 |
+
# generate
|
| 227 |
+
try:
|
| 228 |
+
mcqs = rag.generate_from_pdf(
|
| 229 |
+
tmp_path,
|
| 230 |
+
n_questions=n_questions,
|
| 231 |
+
mode=mode,
|
| 232 |
+
questions_per_page=questions_per_page,
|
| 233 |
+
top_k=top_k,
|
| 234 |
+
temperature=temperature,
|
| 235 |
+
enable_fiddler=enable_fiddler
|
| 236 |
+
)
|
| 237 |
+
except Exception as e:
|
| 238 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
|
| 239 |
+
|
| 240 |
+
validation_report = None
|
| 241 |
+
|
| 242 |
+
if validate_mcqs:
|
| 243 |
+
try:
|
| 244 |
+
# rag.build_index_from_pdf(tmp_path)
|
| 245 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 246 |
+
validation_report = rag.validate_mcqs(mcqs, top_k=top_k)
|
| 247 |
+
except Exception as e:
|
| 248 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 249 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
|
| 253 |
+
|
| 254 |
+
return {"mcqs": mcqs, "validation": validation_report}
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
if __name__ == "__main__":
|
| 258 |
+
import uvicorn
|
| 259 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
|
generator.py
ADDED
|
@@ -0,0 +1,1125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import random
|
| 3 |
+
import fitz
|
| 4 |
+
import string
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
from typing import List, Optional, Tuple, Dict, Any
|
| 8 |
+
from sentence_transformers import SentenceTransformer, CrossEncoder
|
| 9 |
+
from transformers import pipeline
|
| 10 |
+
from uuid import uuid4
|
| 11 |
+
import pymupdf4llm
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from qdrant_client import QdrantClient
|
| 15 |
+
from qdrant_client.http.models import (
|
| 16 |
+
PointStruct,
|
| 17 |
+
Filter,
|
| 18 |
+
FieldCondition,
|
| 19 |
+
MatchValue,
|
| 20 |
+
Distance,
|
| 21 |
+
VectorParams,
|
| 22 |
+
)
|
| 23 |
+
from qdrant_client.http import models as rest
|
| 24 |
+
_HAS_QDRANT = True
|
| 25 |
+
except Exception:
|
| 26 |
+
_HAS_QDRANT = False
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import faiss
|
| 30 |
+
_HAS_FAISS = True
|
| 31 |
+
except Exception:
|
| 32 |
+
_HAS_FAISS = False
|
| 33 |
+
|
| 34 |
+
from utils import generate_mcqs_from_text, new_generate_mcqs_from_text, structure_context_for_llm
|
| 35 |
+
|
| 36 |
+
from huggingface_hub import login

# Authenticate with the HF Hub at import time so gated models can be pulled.
# BUGFIX: os.environ['HF_MODEL_TOKEN'] raised KeyError and crashed the whole
# import when the variable was unset; degrade to anonymous access instead.
_hf_token = os.environ.get("HF_MODEL_TOKEN")
if _hf_token:
    login(token=_hf_token)
|
| 38 |
+
|
| 39 |
+
class RAGMCQ:
|
| 40 |
+
    def __init__(
        self,
        embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        generation_model: str = "openai/gpt-oss-120b",
        qdrant_url: str = os.environ.get('QDRANT_URL') or "",
        qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
        qdrant_prefer_grpc: bool = False,
    ):
        """Load the embedding, QA and reranking models and optionally connect
        to Qdrant.

        NOTE(review): the Qdrant defaults are read from the environment at
        class-definition time (default-argument evaluation), not at each
        instantiation — confirm this is intended if the env vars can change
        while the process is running.
        """
        self.embedder = SentenceTransformer(embedder_model)
        self.generation_model = generation_model
        # extractive QA model, used by validate_mcqs for answer-consistency checks
        self.qa_pipeline = pipeline("question-answering", model="nguyenvulebinh/vi-mrc-base", tokenizer="nguyenvulebinh/vi-mrc-base")
        # cross-encoder used by validate_mcqs to score statement/evidence entailment
        self.cross_entail = CrossEncoder("itdainb/PhoRanker")
        self.embeddings = None  # np.array of shape (N, D)
        self.texts = []  # list of chunk texts
        self.metadata = []  # list of dicts (page, chunk_id, char_range)
        self.index = None  # FAISS index when available, else None (numpy fallback)
        self.dim = self.embedder.get_sentence_embedding_dimension()

        self.qdrant = None
        self.qdrant_url = qdrant_url
        self.qdrant_api_key = qdrant_api_key
        self.qdrant_prefer_grpc = qdrant_prefer_grpc

        if qdrant_url:
            self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
|
| 65 |
+
|
| 66 |
+
def extract_pages(
|
| 67 |
+
self,
|
| 68 |
+
pdf_path: str,
|
| 69 |
+
*,
|
| 70 |
+
pages: Optional[List[int]] = None,
|
| 71 |
+
ignore_images: bool = False,
|
| 72 |
+
dpi: int = 150
|
| 73 |
+
) -> List[str]:
|
| 74 |
+
doc = fitz.open(pdf_path)
|
| 75 |
+
try:
|
| 76 |
+
# request page-wise output (page_chunks=True -> list[dict] per page)
|
| 77 |
+
page_dicts = pymupdf4llm.to_markdown(
|
| 78 |
+
doc,
|
| 79 |
+
pages=pages,
|
| 80 |
+
ignore_images=ignore_images,
|
| 81 |
+
dpi=dpi,
|
| 82 |
+
page_chunks=True,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
|
| 86 |
+
pages_md: List[str] = []
|
| 87 |
+
for p in page_dicts:
|
| 88 |
+
txt = p.get("text", "") or ""
|
| 89 |
+
pages_md.append(txt.strip())
|
| 90 |
+
|
| 91 |
+
return pages_md
|
| 92 |
+
finally:
|
| 93 |
+
doc.close()
|
| 94 |
+
|
| 95 |
+
def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]:
|
| 96 |
+
text = text.strip()
|
| 97 |
+
if not text:
|
| 98 |
+
return []
|
| 99 |
+
|
| 100 |
+
if len(text) <= max_chars:
|
| 101 |
+
return [text]
|
| 102 |
+
|
| 103 |
+
# split by sentence-like boundaries
|
| 104 |
+
sentences = re.split(r'(?<=[\.\?\!])\s+', text)
|
| 105 |
+
chunks = []
|
| 106 |
+
cur = ""
|
| 107 |
+
|
| 108 |
+
for s in sentences:
|
| 109 |
+
if len(cur) + len(s) + 1 <= max_chars:
|
| 110 |
+
cur += (" " if cur else "") + s
|
| 111 |
+
else:
|
| 112 |
+
if cur:
|
| 113 |
+
chunks.append(cur)
|
| 114 |
+
|
| 115 |
+
cur = (cur[-overlap:] + " " + s) if overlap > 0 else s
|
| 116 |
+
|
| 117 |
+
if cur:
|
| 118 |
+
chunks.append(cur)
|
| 119 |
+
|
| 120 |
+
# if still too long, hard-split
|
| 121 |
+
final = []
|
| 122 |
+
for c in chunks:
|
| 123 |
+
if len(c) <= max_chars:
|
| 124 |
+
final.append(c)
|
| 125 |
+
else:
|
| 126 |
+
for i in range(0, len(c), max_chars):
|
| 127 |
+
final.append(c[i:i+max_chars])
|
| 128 |
+
|
| 129 |
+
return final
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
|
| 133 |
+
pages = self.extract_pages(pdf_path)
|
| 134 |
+
|
| 135 |
+
self.texts = []
|
| 136 |
+
self.metadata = []
|
| 137 |
+
|
| 138 |
+
for p_idx, page_text in enumerate(pages, start=1):
|
| 139 |
+
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
|
| 140 |
+
for cid, ch in enumerate(chunks, start=1):
|
| 141 |
+
self.texts.append(ch)
|
| 142 |
+
self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
|
| 143 |
+
|
| 144 |
+
if not self.texts:
|
| 145 |
+
raise RuntimeError("No text extracted from PDF.")
|
| 146 |
+
|
| 147 |
+
# save_to_local('test/text_chunks.md', content=self.texts)
|
| 148 |
+
|
| 149 |
+
# compute embeddings
|
| 150 |
+
emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
|
| 151 |
+
self.embeddings = emb.astype("float32")
|
| 152 |
+
self._build_faiss_index()
|
| 153 |
+
|
| 154 |
+
def _build_faiss_index(self, ef_construction=200, M=32):
|
| 155 |
+
if _HAS_FAISS:
|
| 156 |
+
d = self.embeddings.shape[1]
|
| 157 |
+
index = faiss.IndexHNSWFlat(d, M)
|
| 158 |
+
faiss.normalize_L2(self.embeddings)
|
| 159 |
+
index.add(self.embeddings)
|
| 160 |
+
index.hnsw.efConstruction = ef_construction
|
| 161 |
+
self.index = index
|
| 162 |
+
else:
|
| 163 |
+
# store normalized embeddings and use brute-force numpy
|
| 164 |
+
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
|
| 165 |
+
self.embeddings = self.embeddings / norms
|
| 166 |
+
self.index = None
|
| 167 |
+
|
| 168 |
+
def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
|
| 169 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
|
| 170 |
+
|
| 171 |
+
if _HAS_FAISS:
|
| 172 |
+
faiss.normalize_L2(q_emb)
|
| 173 |
+
D_list, I_list = self.index.search(q_emb, top_k)
|
| 174 |
+
# D are inner products; return list of (idx, score)
|
| 175 |
+
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
|
| 176 |
+
else:
|
| 177 |
+
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
|
| 178 |
+
sims = (self.embeddings @ qn.T).squeeze(axis=1)
|
| 179 |
+
idxs = np.argsort(-sims)[:top_k]
|
| 180 |
+
return [(int(i), float(sims[i])) for i in idxs]
|
| 181 |
+
|
| 182 |
+
def generate_from_pdf(
|
| 183 |
+
self,
|
| 184 |
+
pdf_path: str,
|
| 185 |
+
n_questions: int = 10,
|
| 186 |
+
mode: str = "rag", # per_page or rag
|
| 187 |
+
questions_per_page: int = 3, # for per_page mode
|
| 188 |
+
top_k: int = 3, # chunks to retrieve for each question in rag mode
|
| 189 |
+
temperature: float = 0.2,
|
| 190 |
+
enable_fiddler: bool = False,
|
| 191 |
+
target_difficulty: str = 'easy' # easy, mid, difficult
|
| 192 |
+
) -> Dict[str, Any]:
|
| 193 |
+
# build index
|
| 194 |
+
self.build_index_from_pdf(pdf_path)
|
| 195 |
+
|
| 196 |
+
output: Dict[str, Any] = {}
|
| 197 |
+
qcount = 0
|
| 198 |
+
|
| 199 |
+
if mode == "per_page":
|
| 200 |
+
# iterate pages -> chunks
|
| 201 |
+
for idx, meta in enumerate(self.metadata):
|
| 202 |
+
chunk_text = self.texts[idx]
|
| 203 |
+
|
| 204 |
+
if not chunk_text.strip():
|
| 205 |
+
continue
|
| 206 |
+
to_gen = questions_per_page
|
| 207 |
+
|
| 208 |
+
# ask generator
|
| 209 |
+
try:
|
| 210 |
+
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
|
| 211 |
+
mcq_block = generate_mcqs_from_text(
|
| 212 |
+
source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
|
| 213 |
+
)
|
| 214 |
+
except Exception as e:
|
| 215 |
+
# skip this chunk if generator fails
|
| 216 |
+
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
|
| 217 |
+
continue
|
| 218 |
+
|
| 219 |
+
if "error" in list(mcq_block.keys()):
|
| 220 |
+
return output
|
| 221 |
+
|
| 222 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 223 |
+
qcount += 1
|
| 224 |
+
output[str(qcount)] = mcq_block[item]
|
| 225 |
+
if qcount >= n_questions:
|
| 226 |
+
return output
|
| 227 |
+
|
| 228 |
+
return output
|
| 229 |
+
|
| 230 |
+
# pdf gene
|
| 231 |
+
elif mode == "rag":
|
| 232 |
+
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
|
| 233 |
+
# create queries by sampling chunk text sentences.
|
| 234 |
+
# stop when n_questions reached or max_attempts exceeded.
|
| 235 |
+
attempts = 0
|
| 236 |
+
max_attempts = n_questions * 4
|
| 237 |
+
|
| 238 |
+
while qcount < n_questions and attempts < max_attempts:
|
| 239 |
+
attempts += 1
|
| 240 |
+
# create a seed query: pick a random chunk, pick a sentence from it
|
| 241 |
+
seed_idx = random.randrange(len(self.texts))
|
| 242 |
+
chunk = self.texts[seed_idx]
|
| 243 |
+
|
| 244 |
+
#? investigate better Chunking Strategy
|
| 245 |
+
#with open("chunks.txt", "a", encoding="utf-8") as f:
|
| 246 |
+
#f.write(chunk + "\n")
|
| 247 |
+
|
| 248 |
+
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 249 |
+
seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
|
| 250 |
+
query = f"Create questions about: {seed_sent}"
|
| 251 |
+
|
| 252 |
+
# retrieve top_k chunks
|
| 253 |
+
retrieved = self._retrieve(query, top_k=top_k)
|
| 254 |
+
context_parts = []
|
| 255 |
+
for ridx, score in retrieved:
|
| 256 |
+
md = self.metadata[ridx]
|
| 257 |
+
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
|
| 258 |
+
context = "\n\n".join(context_parts)
|
| 259 |
+
|
| 260 |
+
# save_to_local('test/context.md', content=context)
|
| 261 |
+
|
| 262 |
+
# call generator for 1 question (or small batch) with the retrieved context
|
| 263 |
+
try:
|
| 264 |
+
# request 1 question at a time to keep diversity
|
| 265 |
+
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
|
| 266 |
+
mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
|
| 267 |
+
except Exception as e:
|
| 268 |
+
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 269 |
+
continue
|
| 270 |
+
|
| 271 |
+
if "error" in list(mcq_block.keys()):
|
| 272 |
+
return output
|
| 273 |
+
|
| 274 |
+
# append result(s)
|
| 275 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 276 |
+
payload = mcq_block[item]
|
| 277 |
+
q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
|
| 278 |
+
options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
|
| 279 |
+
if isinstance(options, list):
|
| 280 |
+
options = {str(i+1): o for i, o in enumerate(options)}
|
| 281 |
+
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
|
| 282 |
+
correct_text = ""
|
| 283 |
+
if isinstance(correct_key, str) and correct_key.strip() in options:
|
| 284 |
+
correct_text = options[correct_key.strip()]
|
| 285 |
+
else:
|
| 286 |
+
correct_text = payload.get("correct_text") or correct_key or ""
|
| 287 |
+
|
| 288 |
+
diff_score, diff_label = self._estimate_difficulty_for_generation(
|
| 289 |
+
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
|
| 290 |
+
)
|
| 291 |
+
payload["difficulty"] = {"score": diff_score, "label": diff_label}
|
| 292 |
+
|
| 293 |
+
qcount += 1
|
| 294 |
+
output[str(qcount)] = mcq_block[item]
|
| 295 |
+
if qcount >= n_questions:
|
| 296 |
+
return output
|
| 297 |
+
|
| 298 |
+
return output
|
| 299 |
+
else:
|
| 300 |
+
raise ValueError("mode must be 'per_page' or 'rag'.")
|
| 301 |
+
|
| 302 |
+
def validate_mcqs(
    self,
    mcqs: Dict[str, Any],
    top_k: int = 4,
    similarity_threshold: float = 0.5,
    evidence_score_cutoff: float = 0.5,
    use_cross_encoder: bool = True,
    use_qa: bool = True,
    auto_accept_threshold: float = 0.7,
    review_threshold: float = 0.5,
    distractor_too_similar: float = 0.8,
    distractor_too_different: float = 0.15,
    model_verification_temperature: float = 0.0,
) -> Dict[str, Any]:
    """
    Upgraded validation pipeline:
    - embedding retrieval (self.index / self.embeddings)
    - cross-encoder entailment scoring (optional)
    - extractive QA consistency check (optional)
    - distractor similarity and type checks
    - aggregate into quality_score and triage_action

    Each MCQ may use Vietnamese keys ("câu hỏi", "lựa chọn", "đáp án") or
    English keys ("question", "options", "answer"); both are normalized below.

    NOTE(review): `model_verification_temperature` is accepted but never read
    in this body — confirm whether an LLM-verification stage was planned.

    Returns a dict keyed by qid with detailed info and triage decision.
    """
    # Lazily grab the optional scoring models from the instance; any failure
    # (e.g. the attribute missing because a model never loaded) just disables
    # that stage instead of aborting validation.
    cross_entail = None
    qa_pipeline = None
    if use_cross_encoder:
        try:
            cross_entail = self.cross_entail
        except Exception as e:
            cross_entail = None
    if use_qa:
        try:
            qa_pipeline = self.qa_pipeline
        except Exception:
            qa_pipeline = None

    # --- helpers ---
    def _norm_text(s: str) -> str:
        # Case-fold, strip punctuation and collapse whitespace so option text
        # can be compared loosely against the declared correct answer.
        if s is None:
            return ""
        s = s.strip().lower()
        # remove punctuation
        s = s.translate(str.maketrans("", "", string.punctuation))
        # collapse whitespace
        s = " ".join(s.split())
        return s

    def _semantic_search(statement: str, k: int = top_k):
        # returns list of (idx, score) using current embeddings/index
        q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
        if _HAS_FAISS and getattr(self, "index", None) is not None:
            try:
                faiss.normalize_L2(q_emb)
                D_list, I_list = self.index.search(q_emb, k)
                return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
            except Exception:
                pass
        # fallback to brute force cosine over self.embeddings
        qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
        sims = (self.embeddings @ qn.T).squeeze(axis=1)
        idxs = np.argsort(-sims)[:k]
        return [(int(i), float(sims[i])) for i in idxs]

    def _compose_context_from_retrieved(retrieved):
        # NOTE(review): defined but never called in this method — the main
        # loop builds its context inline instead. Candidate for removal.
        parts = []
        for ridx, score in retrieved:
            md = self.metadata[ridx] if ridx < len(self.metadata) else {}
            page = md.get("page", "?")
            text = self.texts[ridx]
            parts.append(f"[page {page}] {text}")
        return "\n\n".join(parts)

    def _compute_option_embeddings(options_map: Dict[str, str]):
        # returns dict key->embedding (one batched encode for all options)
        keys = list(options_map.keys())
        texts = [options_map[k] for k in keys]
        embs = self.embedder.encode(texts, convert_to_numpy=True)
        return dict(zip(keys, embs))

    def _cosine(a, b):
        # Plain cosine similarity with an epsilon guard against zero vectors.
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        denom = (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)
        return float(np.dot(a, b) / denom)

    # --- main loop ---
    report = {}
    for qid, item in mcqs.items():
        # support both Vietnamese keys and English keys
        q_text = (item.get("câu hỏi") or item.get("question") or item.get("q") or item.get("stem") or "").strip()
        options = item.get("lựa chọn") or item.get("options") or item.get("choices") or {}
        # options may be dict mapping letters to text, or list: normalize to dict
        if isinstance(options, list):
            options = {str(i+1): o for i, o in enumerate(options)}
        # correct answer may be a key (like "A") or the text; try both
        correct_key = item.get("đáp án") or item.get("answer") or item.get("correct") or item.get("ans")
        correct_text = ""
        if isinstance(correct_key, str) and correct_key.strip() in options:
            correct_text = options[correct_key.strip()]
        else:
            # maybe the answer is full text
            if isinstance(correct_key, str):
                correct_text = correct_key.strip()
            else:
                # fallback to 'correct_text' field
                correct_text = item.get("correct_text") or item.get("đáp án_text") or ""

        # default empty guard: stringify every option value and the answer
        options = {k: str(v) for k, v in options.items()}
        correct_text = str(correct_text)

        # prepare statement for retrieval: question plus claimed answer
        statement = f"{q_text} Answer: {correct_text}"
        retrieved = _semantic_search(statement, k=top_k)
        # build context from top retrieved chunks
        context_parts = []
        for ridx, score in retrieved:
            md = self.metadata[ridx] if ridx < len(self.metadata) else {}
            context_parts.append({"idx": ridx, "score": float(score), "page": md.get("page", None), "text": self.texts[ridx]})
        context_text = "\n\n".join([f"[page {p['page']}] {p['text']}" for p in context_parts])

        # Evidence list (embedding-based): chunks above the cutoff, snippets
        # trimmed to 1000 chars; also track the best similarity seen.
        evidence_list = []
        max_sim = 0.0
        for r in context_parts:
            if r["score"] >= evidence_score_cutoff:
                snippet = r["text"]
                evidence_list.append({
                    "idx": r["idx"],
                    "page": r["page"],
                    "score": r["score"],
                    "text": (snippet[:1000] + ("..." if len(snippet) > 1000 else "")),
                })
            if r["score"] > max_sim:
                max_sim = float(r["score"])
        supported_by_embeddings = max_sim >= similarity_threshold

        # Cross-encoder entailment scores for each option
        entailment_scores = {}
        correct_entail = 0.0
        try:
            if cross_entail is not None and context_text.strip():
                # prepare list of (premise, hypothesis)
                pairs = []
                opt_keys = list(options.keys())
                for k in opt_keys:
                    hyp = f"{q_text} Answer: {options[k]}"
                    pairs.append((context_text, hyp))
                scores = cross_entail.predict(pairs)  # returns list of floats
                # normalize scores to 0-1 if needed (cross-encoder may return arbitrary positive)
                # do a min-max normalization across the returned scores
                # but avoid division by zero
                min_s = float(min(scores)) if len(scores) else 0.0
                max_s = float(max(scores)) if len(scores) else 1.0
                denom = max_s - min_s if max_s - min_s > 1e-6 else 1.0
                for k, raw in zip(opt_keys, scores):
                    scaled = (raw - min_s) / denom
                    entailment_scores[k] = float(scaled)
                # find correct key if available
                # if `correct_text` exactly matches one of options, find that key
                matched_key = None
                for k, v in options.items():
                    if _norm_text(v) == _norm_text(correct_text):
                        matched_key = k
                        break
                if matched_key:
                    correct_entail = entailment_scores.get(matched_key, 0.0)
                else:
                    # fallback: treat 'correct_text' as a separate hypothesis
                    hyp = f"{q_text} Answer: {correct_text}"
                    raw = cross_entail.predict([(context_text, hyp)])[0]
                    # scale relative to min/max used above
                    correct_entail = float((raw - min_s) / denom)
            else:
                entailment_scores = {}
                correct_entail = 0.0
        except Exception as e:
            # model failure disables the entailment signal for this question
            entailment_scores = {}
            correct_entail = 0.0

        def embed_cosine_sim(a, b):
            # Embedding cosine between two strings (normalized embeddings, so
            # the dot product is the cosine). NOTE(review): redefined on every
            # loop iteration; could be hoisted next to the other helpers.
            emb = self.embedder.encode([a, b], convert_to_numpy=True, normalize_embeddings=True)
            return float(np.dot(emb[0], emb[1]))

        # QA consistency: ask an extractive-QA model the question over the
        # retrieved context and compare its answer with the claimed one.
        qa_answer = None
        qa_score = 0.0
        qa_agrees = False
        if qa_pipeline is not None and context_text.strip():
            try:
                qa_res = qa_pipeline(question=q_text, context=context_text)
                # some QA pipelines return list of answers or dict
                if isinstance(qa_res, list) and len(qa_res) > 0:
                    top = qa_res[0]
                    qa_answer = top.get("answer") if isinstance(top, dict) else str(top)
                    # qa_score = float(top.get("score", 0.0) if isinstance(top, dict) else 0.0)
                elif isinstance(qa_res, dict):
                    qa_answer = qa_res.get("answer", "")
                    qa_score = float(qa_res.get("score", 0.0))
                else:
                    qa_answer = str(qa_res)
                    qa_score = 0.0
                # The pipeline confidence assigned above is overwritten: the
                # final qa_score is the embedding similarity between the QA
                # model's answer and the claimed correct answer.
                qa_score = embed_cosine_sim(qa_answer, correct_text)
                qa_agrees = (qa_score >= 0.5)
            except Exception:
                qa_answer = None
                qa_score = 0.0
                qa_agrees = False

        # Distractor similarity: cosine between each option and the correct
        # answer, computed via a sentinel key in one batched encode.
        try:
            opt_embs = _compute_option_embeddings({**options, "__CORRECT__": correct_text})
            correct_emb = opt_embs.pop("__CORRECT__")
            distractor_similarities = {}
            for k, emb in opt_embs.items():
                distractor_similarities[k] = float(_cosine(correct_emb, emb))
        except Exception:
            distractor_similarities = {k: None for k in options.keys()}

        # distractor flags: penalize options that are near-duplicates of the
        # answer or essentially unrelated to it
        distractor_penalty = 0.0
        distractor_flags = []
        for k, sim in distractor_similarities.items():
            # skip missing sims, the correct option itself (sim ~ 1.0) and
            # degenerate near-zero similarities
            if sim is None or sim >= 0.999999 or (sim >= -0.01 and sim <= 0):
                continue
            if sim >= distractor_too_similar:
                distractor_flags.append({"key": k, "reason": "too_similar", "similarity": sim})
                distractor_penalty += 0.25
            elif sim <= distractor_too_different:
                distractor_flags.append({"key": k, "reason": "too_different", "similarity": sim})
                distractor_penalty += 0.15
        # clamp penalty
        distractor_penalty = min(distractor_penalty, 1.0)

        # Ambiguity detection: how many options have entailment >= threshold
        ambiguous = False
        ambiguous_options = []
        if entailment_scores:
            # count options whose entailment >= max(correct_entail * 0.9, 0.6)
            amb_thresh = max(correct_entail * 0.9, 0.6)
            for k, sc in entailment_scores.items():
                # NOTE(review): raw string comparison here (not _norm_text), so
                # a punctuation/case difference can flag the correct option
                # itself as "ambiguous".
                if sc >= amb_thresh and (options.get(k, "") != correct_text):
                    ambiguous_options.append({"key": k, "score": sc, "text": options[k]})
            ambiguous = len(ambiguous_options) > 0

        # Compose aggregated quality score
        # Components:
        # - embedding_support: normalized max_sim (0..1)
        # - entailment: correct_entail (0..1)
        # - qa_agree: boolean -> 1 or 0 times qa_score
        # - distractor_penalty: subtracted
        emb_support_norm = max_sim  # embedding similarity typically already 0..1 (inner product normalized)
        entail_component = float(correct_entail)
        qa_component = float(qa_score) if qa_agrees else 0.0

        # weighted sum
        quality_score = (
            0.40 * emb_support_norm +
            0.35 * entail_component +
            0.20 * qa_component -
            0.05 * distractor_penalty
        )
        # clamp to 0..1
        quality_score = max(0.0, min(1.0, quality_score))

        # triage decision: pass / review / reject
        triage_action = "reject"
        if quality_score >= auto_accept_threshold and not ambiguous:
            triage_action = "pass"
        elif quality_score >= review_threshold:
            triage_action = "review"
        else:
            triage_action = "reject"

        # compile flags/reasons
        flag_reasons = []
        if not supported_by_embeddings:
            flag_reasons.append("no_strong_embedding_evidence")
        if entailment_scores and correct_entail < 0.6:
            flag_reasons.append("low_entailment_score_for_correct")
        # NOTE(review): qa_agrees is (qa_score >= 0.5), so qa_score > 0.6 can
        # never coincide with `not qa_agrees` — this flag looks unreachable.
        if qa_pipeline is not None and qa_score > 0.6 and not qa_agrees:
            flag_reasons.append("qa_contradiction")
        if ambiguous:
            flag_reasons.append("ambiguous_options_supported")
        if distractor_flags:
            flag_reasons.append({"distractor_issues": distractor_flags})

        # assemble per-question report
        report[qid] = {
            "supported_by_embeddings": bool(supported_by_embeddings),
            "max_similarity": float(max_sim),
            "evidence": evidence_list,
            "entailment_scores": entailment_scores,
            "correct_entailment": float(correct_entail),
            "qa_answer": qa_answer,
            "qa_score": float(qa_score),
            "qa_agrees": bool(qa_agrees),
            "distractor_similarities": distractor_similarities,
            "distractor_flags": distractor_flags,
            "distractor_penalty": float(distractor_penalty),
            "ambiguous_options": ambiguous_options,
            "quality_score": float(quality_score),
            "triage_action": triage_action,
            "flag_reasons": flag_reasons,
        }

    return report
def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False):
    """Open a Qdrant connection and remember the connection settings on self.

    Raises RuntimeError when the optional qdrant-client package is missing.
    """
    if not _HAS_QDRANT:
        raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
    # Record the settings first, then build the client from the same values.
    self.qdrant_url, self.qdrant_api_key, self.qdrant_prefer_grpc = url, api_key, prefer_grpc
    self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
def _ensure_collection(self, collection_name: str):
|
| 620 |
+
if self.qdrant is None:
|
| 621 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 622 |
+
try:
|
| 623 |
+
# get_collection will raise if not present
|
| 624 |
+
_ = self.qdrant.get_collection(collection_name)
|
| 625 |
+
except Exception:
|
| 626 |
+
# create collection with vector size = self.dim
|
| 627 |
+
vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
|
| 628 |
+
self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
|
| 629 |
+
# recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
|
| 630 |
+
|
| 631 |
+
def save_pdf_to_qdrant(
    self,
    pdf_path: str,
    filename: str,
    collection: str,
    max_chars: int = 1200,
    batch_size: int = 64,
    overwrite: bool = False,
):
    """Extract, chunk, embed a PDF and upsert the chunks into Qdrant.

    Args:
        pdf_path: Path of the PDF file to ingest.
        filename: Logical name stored in each point's payload (used later
            for filename-filtered retrieval).
        collection: Target Qdrant collection (created if missing).
        max_chars: Maximum characters per text chunk.
        batch_size: Number of points per upsert request.
        overwrite: If True, best-effort delete of existing points with the
            same filename before uploading.

    Returns:
        Summary dict: status, uploaded chunk count, collection, filename.

    Raises:
        RuntimeError: If no Qdrant client is connected or no text could be
            extracted from the PDF.
    """
    if self.qdrant is None:
        raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")

    # extract pages and chunks (re-using existing helpers)
    pages = self.extract_pages(pdf_path)

    all_chunks = []
    all_meta = []
    for p_idx, page_text in enumerate(pages, start=1):
        chunks = self.chunk_text(page_text or "", max_chars=max_chars)
        for cid, ch in enumerate(chunks, start=1):
            all_chunks.append(ch)
            all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})

    if not all_chunks:
        # fixed typo in the error message ("No tSext" -> "No text")
        raise RuntimeError("No text extracted from PDF.")

    # ensure collection exists
    self._ensure_collection(collection)

    # optional: delete previous points for this filename if overwrite
    if overwrite:
        flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
        try:
            # BUGFIX: qdrant-client's delete() takes `points_selector`, not
            # `filter`; the old keyword raised TypeError, which the except
            # swallowed, so overwrite silently never deleted anything.
            self.qdrant.delete(collection_name=collection, points_selector=flt)
        except Exception:
            # best-effort deletion: ignore failures (e.g. older client versions)
            pass

    # compute embeddings for all chunks in one call
    embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
    embeddings = embeddings.astype("float32")

    # prepare and upload points
    points = []
    for emb, md, txt in zip(embeddings, all_meta, all_chunks):
        pid = str(uuid4())
        # source_id embeds the logical filename (was a hard-coded placeholder)
        source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
        payload = {
            "filename": filename,
            "page": md["page"],
            "chunk_id": md["chunk_id"],
            "length": md["length"],
            "text": txt,
            "source_id": source_id,
        }
        points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))

        # BUGFIX: flush inside the loop so uploads really happen in batches
        # of `batch_size`; previously this check ran once, after the loop.
        if len(points) >= batch_size:
            self.qdrant.upsert(collection_name=collection, points=points)
            points = []

    # upsert remaining points
    if points:
        self.qdrant.upsert(collection_name=collection, points=points)

    # index the `filename` payload field so filtered scroll/search stays fast
    try:
        self.qdrant.create_payload_index(
            collection_name=collection,
            field_name="filename",
            field_schema=rest.PayloadSchemaType.KEYWORD
        )
    except Exception as e:
        print(f"Index creation skipped or failed: {e}")

    return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
def list_files_in_collection(
    self,
    collection: str,
    payload_field: str = "filename",
    batch_size: int = 500,
) -> List[str]:
    """Return the sorted set of distinct `payload_field` values in `collection`.

    Scrolls through every point (payload only, no vectors) and collects the
    values of the requested payload field; list-valued fields contribute each
    element. Raises RuntimeError when no client is connected or the collection
    does not exist.
    """
    if self.qdrant is None:
        raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")

    # Fail fast on a missing collection; propagate connection errors unchanged.
    try:
        if not self.qdrant.collection_exists(collection):
            raise RuntimeError(f"Collection '{collection}' does not exist.")
    except Exception:
        raise

    def _payload_of(point):
        # Points may be client objects (with .payload), dicts, or dict-convertible.
        if hasattr(point, "payload"):
            return point.payload
        if isinstance(point, dict):
            return point.get("payload") or point
        try:
            return dict(point)
        except Exception:
            return None

    seen = set()
    cursor = None
    while True:
        # scroll() yields (points, next_offset)
        batch, next_cursor = self.qdrant.scroll(
            collection_name=collection,
            limit=batch_size,
            offset=cursor,
            with_payload=[payload_field],
            with_vectors=False,
        )
        if not batch:
            break
        for point in batch:
            payload = _payload_of(point)
            if not payload:
                continue
            value = (
                payload.get(payload_field)
                if isinstance(payload, dict)
                else getattr(payload, payload_field, None)
            )
            # list-like values contribute each element; scalars a single entry
            if isinstance(value, (list, tuple, set)):
                seen.update(str(v) for v in value if v is not None)
            elif value is not None:
                seen.add(str(value))
        if not next_cursor:
            break
        cursor = next_cursor

    return sorted(seen)
def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
    """Return every stored point whose payload filename matches.

    Result is a list of {"point_id": ..., "payload": ...} dicts in scroll
    order. Raises RuntimeError when no Qdrant client is connected.
    """
    if self.qdrant is None:
        raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")

    # The payload filter is loop-invariant; build it once.
    name_filter = Filter(
        must=[
            FieldCondition(key="filename", match=MatchValue(value=filename))
        ]
    )

    collected: List[Dict[str, Any]] = []
    cursor = None
    while True:
        # scroll() yields (points, next_offset); vectors are skipped on purpose.
        page, next_cursor = self.qdrant.scroll(
            collection_name=collection,
            scroll_filter=name_filter,
            limit=batch,
            offset=cursor,
            with_payload=True,
            with_vectors=False,
        )
        # each point exposes .id and .payload
        collected.extend({"point_id": pt.id, "payload": pt.payload} for pt in page)
        if not next_cursor:
            break
        cursor = next_cursor
    return collected
def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
|
| 817 |
+
if self.qdrant is None:
|
| 818 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 819 |
+
|
| 820 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
|
| 821 |
+
q_filter = None
|
| 822 |
+
if filename:
|
| 823 |
+
q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
|
| 824 |
+
|
| 825 |
+
search_res = self.qdrant.search(
|
| 826 |
+
collection_name=collection,
|
| 827 |
+
query_vector=q_emb,
|
| 828 |
+
query_filter=q_filter,
|
| 829 |
+
limit=top_k,
|
| 830 |
+
with_payload=True,
|
| 831 |
+
with_vectors=False,
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
out = []
|
| 835 |
+
for hit in search_res:
|
| 836 |
+
# hit.payload is the stored payload, hit.score is similarity
|
| 837 |
+
out.append((hit.payload, float(getattr(hit, "score", 0.0))))
|
| 838 |
+
return out
|
| 839 |
+
|
| 840 |
+
|
| 841 |
+
def generate_from_qdrant(
    self,
    filename: str,
    collection: str,
    n_questions: int = 10,
    mode: str = "rag", # 'per_chunk' or 'rag'
    questions_per_chunk: int = 3, # used for 'per_chunk'
    top_k: int = 3, # retrieval size used in RAG
    temperature: float = 0.2,
    enable_fiddler: bool = False,
    target_difficulty: str = 'easy',

) -> Dict[str, Any]:
    """
    Generate MCQs from chunks previously stored in Qdrant for one file.

    mode='per_chunk': walk every stored chunk in payload order and ask the
    generator for `questions_per_chunk` questions from each until
    `n_questions` have been collected.
    mode='rag': repeatedly sample a seed sentence from a random chunk,
    retrieve `top_k` related chunks (restricted to `filename`), structure
    the context and generate; each question gets a difficulty estimate.

    Side effects: replaces self.texts / self.metadata / self.embeddings /
    self.index (and self.dim) with this file's chunks, which later helpers
    (difficulty estimation, validate_mcqs) rely on.

    NOTE(review): `enable_fiddler` is accepted but the generator calls below
    hard-code enable_fiddler=False — confirm intent.

    Returns a dict mapping "1".."n" to generated MCQ payload dicts.
    """
    if self.qdrant is None:
        raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")

    # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
    file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
    if not file_points:
        # NOTE(review): the literal "(unknown)" in this message looks like a
        # placeholder — probably meant to interpolate {filename}.
        raise RuntimeError(f"No chunks found for filename=(unknown) in collection={collection}.")

    # create a local list of texts & metadata for sampling
    texts = []
    metas = []
    for p in file_points:
        payload = p.get("payload", {})
        text = payload.get("text", "")
        texts.append(text)
        metas.append(payload)

    # cache on the instance so retrieval/difficulty helpers see this file only
    self.texts = texts
    self.metadata = metas
    embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    if embeddings is None or len(embeddings) == 0:
        self.embeddings = None
        self.index = None
    else:
        self.embeddings = embeddings.astype("float32")

        # update dim in case embedder changed unexpectedly
        self.dim = int(self.embeddings.shape[1])

        # build index
        self._build_faiss_index()

    output = {}
    qcount = 0  # questions emitted so far

    if mode == "per_chunk":
        # iterate all chunks (in payload order) and request questions_per_chunk from each
        for i, txt in enumerate(texts):
            if not txt.strip():
                continue
            to_gen = questions_per_chunk
            try:
                mcq_block = new_generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=False)
            except Exception as e:
                print(f"Generator failed on chunk (index {i}): {e}")
                continue

            # generator signals failure with an "error" key: stop and return
            # whatever was collected so far
            if "error" in list(mcq_block.keys()):
                return output

            for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
                qcount += 1
                output[str(qcount)] = mcq_block[item]
                if qcount >= n_questions:
                    return output
        return output

    elif mode == "rag":
        attempts = 0
        # allow several retries per requested question before giving up
        max_attempts = n_questions * 4
        while qcount < n_questions and attempts < max_attempts:
            attempts += 1
            # create a seed query: pick a random chunk, pick a sentence from it
            seed_idx = random.randrange(len(self.texts))
            chunk = self.texts[seed_idx]
            sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
            candidate = [s for s in sents if len(s.strip()) > 20]
            if candidate:
                seed_sent = random.choice(candidate)
            else:
                # no sentence long enough: fall back to a raw 200-char prefix
                stripped = chunk.strip()
                seed_sent = (stripped[:200] if stripped else "[no text available]")
            query = f"Create questions about: {seed_sent}"

            # retrieve top_k chunks from the same file (restricted by filename filter)
            retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
            print('retrieved qdrant', retrieved)
            context_parts = []
            for payload, score in retrieved:
                # payload should contain page & chunk_id and text
                page = payload.get("page", "?")
                ctxt = payload.get("text", "")
                context_parts.append(f"[page {page}] {ctxt}")
            context = "\n\n".join(context_parts)

            # q generation
            try:
                # Difficulty pipeline: easy, mid, difficult
                structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
                mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
            except Exception as e:
                print(f"Generator failed during RAG attempt {attempts}: {e}")
                continue

            if "error" in list(mcq_block.keys()):
                return output

            for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
                payload = mcq_block[item]
                # accept Vietnamese or English field names from the generator
                q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
                options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}

                if isinstance(options, list):
                    options = {str(i+1): o for i, o in enumerate(options)}

                correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
                concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None

                # resolve the answer: either a key into options, or free text
                correct_text = ""
                if isinstance(correct_key, str) and correct_key.strip() in options:
                    correct_text = options[correct_key.strip()]
                else:
                    correct_text = payload.get("correct_text") or correct_key or ""

                #? change estimate
                # NOTE(review): this unpacks 3 values although the method is
                # annotated -> Tuple[float, str]; the `type: ignore` suggests
                # the implementation actually returns a third `components`
                # element — confirm against _estimate_difficulty_for_generation.
                diff_score, diff_label, components = self._estimate_difficulty_for_generation(  # type: ignore
                    q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used = concepts
                )

                # attach difficulty info (Vietnamese keys: "độ khó" = difficulty,
                # "điểm" = score, "mức độ" = level)
                payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}

                # CHECK n generation: if number of request mcqs < default generation number e.g. 5 - 3 = 2 < 3 then only genearate 2 mcqs
                # (shrinking questions_per_chunk only affects how many questions
                # the NEXT generator call asks for)
                if n_questions - qcount < questions_per_chunk:
                    questions_per_chunk = n_questions - qcount

                qcount += 1 # count number of question
                print('qcount:', qcount)
                print('questions_per_chunk:', questions_per_chunk)

                output[str(qcount)] = mcq_block[item]
                if qcount >= n_questions:
                    return output

        # NOTE(review): `output` is always a dict here, so this check always passes.
        if output is not None:
            print("output available")
            return output
    else:
        raise ValueError("mode must be 'per_chunk' or 'rag'.")
def _estimate_difficulty_for_generation(
|
| 997 |
+
self,
|
| 998 |
+
q_text: str,
|
| 999 |
+
options: Dict[str, str],
|
| 1000 |
+
correct_text: str,
|
| 1001 |
+
context_text: str = "",
|
| 1002 |
+
concepts_used: Dict = {}
|
| 1003 |
+
) -> Tuple[float, str]:
|
| 1004 |
+
def safe_map_sim(s):
|
| 1005 |
+
# map potentially [-1,1] cosine-like to [0,1], clamp
|
| 1006 |
+
try:
|
| 1007 |
+
s = float(s)
|
| 1008 |
+
except Exception:
|
| 1009 |
+
return 0.0
|
| 1010 |
+
mapped = (s + 1.0) / 2.0
|
| 1011 |
+
return max(0.0, min(1.0, mapped))
|
| 1012 |
+
|
| 1013 |
+
# embedding support
|
| 1014 |
+
emb_support = 0.0
|
| 1015 |
+
try:
|
| 1016 |
+
stmt = (q_text or "").strip()
|
| 1017 |
+
if correct_text:
|
| 1018 |
+
stmt = f"{stmt} Answer: {correct_text}"
|
| 1019 |
+
|
| 1020 |
+
# use internal retrieve but map returned score
|
| 1021 |
+
res = []
|
| 1022 |
+
try:
|
| 1023 |
+
res = self._retrieve(stmt, top_k=1)
|
| 1024 |
+
except Exception:
|
| 1025 |
+
res = []
|
| 1026 |
+
|
| 1027 |
+
if res:
|
| 1028 |
+
raw_score = float(res[0][1])
|
| 1029 |
+
emb_support = safe_map_sim(raw_score)
|
| 1030 |
+
else:
|
| 1031 |
+
emb_support = 0.0
|
| 1032 |
+
except Exception:
|
| 1033 |
+
emb_support = 0.0
|
| 1034 |
+
|
| 1035 |
+
# distractor sims
|
| 1036 |
+
mean_sim = 0.0
|
| 1037 |
+
distractor_penalty = 0.0
|
| 1038 |
+
amb_flag = 0.0
|
| 1039 |
+
try:
|
| 1040 |
+
keys = list(options.keys())
|
| 1041 |
+
texts = [options[k] for k in keys]
|
| 1042 |
+
if correct_text is None:
|
| 1043 |
+
correct_text = ""
|
| 1044 |
+
|
| 1045 |
+
all_texts = [correct_text] + texts
|
| 1046 |
+
embs = self.embedder.encode(all_texts, convert_to_numpy=True)
|
| 1047 |
+
embs = np.asarray(embs, dtype=float)
|
| 1048 |
+
norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
|
| 1049 |
+
embs = embs / norms
|
| 1050 |
+
corr = embs[0]
|
| 1051 |
+
opts = embs[1:]
|
| 1052 |
+
|
| 1053 |
+
if opts.size == 0:
|
| 1054 |
+
mean_sim = 0.0
|
| 1055 |
+
distractor_penalty = 0.0
|
| 1056 |
+
gap = 0.0
|
| 1057 |
+
else:
|
| 1058 |
+
sims = (opts @ corr).tolist() # [-1,1]
|
| 1059 |
+
sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
|
| 1060 |
+
mean_sim = float(sum(sims_mapped) / len(sims_mapped))
|
| 1061 |
+
# gap between best distractor and second best (higher gap -> easier)
|
| 1062 |
+
sorted_s = sorted(sims_mapped, reverse=True)
|
| 1063 |
+
top = sorted_s[0]
|
| 1064 |
+
second = sorted_s[1] if len(sorted_s) > 1 else 0.0
|
| 1065 |
+
gap = top - second
|
| 1066 |
+
# penalties: if distractors are extremely close to correct -> higher penalty
|
| 1067 |
+
too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
|
| 1068 |
+
too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
|
| 1069 |
+
distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
|
| 1070 |
+
amb_flag = 1.0 if top >= 0.8 else 0.0
|
| 1071 |
+
except Exception:
|
| 1072 |
+
mean_sim = 0.0
|
| 1073 |
+
distractor_penalty = 0.0
|
| 1074 |
+
amb_flag = 0.0
|
| 1075 |
+
gap = 0.0
|
| 1076 |
+
|
| 1077 |
+
# question length normalized
|
| 1078 |
+
question_len = len((q_text or "").strip())
|
| 1079 |
+
question_len_norm = min(1.0, question_len / 300.0)
|
| 1080 |
+
|
| 1081 |
+
# count number of concept from string
|
| 1082 |
+
concepts_num = len(concepts_used.keys())
|
| 1083 |
+
if concepts_num < 2:
|
| 1084 |
+
concepts_penalty = 0
|
| 1085 |
+
else:
|
| 1086 |
+
concepts_penalty = concepts_num
|
| 1087 |
+
|
| 1088 |
+
# combine signals using safer semantics:
|
| 1089 |
+
# higher emb_support -> easier (so we subtract a term)
|
| 1090 |
+
# higher distractor_penalty -> harder (add)
|
| 1091 |
+
# better gap -> easier (subtract)
|
| 1092 |
+
# compute score (higher -> harder)
|
| 1093 |
+
|
| 1094 |
+
score = 0.3 # more toward easy
|
| 1095 |
+
score += 0.35 * float(distractor_penalty) # stronger penalty for similar distractors
|
| 1096 |
+
score += 0.2 * float(mean_sim) # emphasizes average distractor similarity (harder if distractors are close, per "khó" criteria)
|
| 1097 |
+
score += 0.12 * float(amb_flag) # penalty if the best distractors hard to distinguish
|
| 1098 |
+
score += 0.1 * float(concepts_penalty) # boost difficuty if more concept used in a question
|
| 1099 |
+
score -= 0.15 * float(gap) # less emphasis on "dễ" if gap is large but other factors are hard
|
| 1100 |
+
score += 0.05 * float(question_len_norm)
|
| 1101 |
+
score -= 0.45 * float(emb_support) # easy ques is obvious while hard question get penalty because they often get rephrase from the original concet -> harder for embedding suppport to be meaningful.
|
| 1102 |
+
|
| 1103 |
+
# clamp
|
| 1104 |
+
score = max(0.0, min(1.0, float(score)))
|
| 1105 |
+
components = {
|
| 1106 |
+
"base": 0.3,
|
| 1107 |
+
"distractor_penalty": 0.35 * float(distractor_penalty),
|
| 1108 |
+
"mean_sim": 0.15 * float(mean_sim),
|
| 1109 |
+
"amb_flag": 0.05 * float(amb_flag),
|
| 1110 |
+
"concepts_num": 0.1 * float(concepts_num),
|
| 1111 |
+
"gap": -0.12 * float(gap),
|
| 1112 |
+
"question_len_norm": 0.05 * float(question_len_norm),
|
| 1113 |
+
"emb_support": -0.45 * float(emb_support),
|
| 1114 |
+
"total_score": score,
|
| 1115 |
+
}
|
| 1116 |
+
|
| 1117 |
+
# label
|
| 1118 |
+
if score <= 0.56:
|
| 1119 |
+
label = "dễ"
|
| 1120 |
+
elif score <= 0.755 and score > 0.56:
|
| 1121 |
+
label = "trung bình"
|
| 1122 |
+
else:
|
| 1123 |
+
label = "khó"
|
| 1124 |
+
|
| 1125 |
+
return score, label, components # type: ignore
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
boto3
|
| 2 |
+
faiss-cpu
|
| 3 |
+
transformers
|
| 4 |
+
sentence-transformers
|
| 5 |
+
fastapi[standard]
|
| 6 |
+
uvicorn
|
| 7 |
+
qdrant-client
|
| 8 |
+
pymupdf4llm
|
| 9 |
+
# "uuid" removed: it is part of the Python standard library (the PyPI "uuid" package is an obsolete Python 2 shim)
|
| 10 |
+
huggingface_hub
|
utils.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
import json
from typing import Dict, Any
import requests
import os
import numpy as np
import uuid
import datetime
import pathlib
import time

#TODO: allow to choose different provider later + dynamic routing when token expired
API_URL = "https://openrouter.ai/api/v1/chat/completions"
# NOTE(review): despite the name, this variable holds the OpenRouter key
# (requests are routed to Cerebras via OpenRouter's provider field).
# Raises KeyError at import time if OPENROUTER_KEY is not set.
CEREBRAS_API_KEY = os.environ['OPENROUTER_KEY']

# Shared headers for all OpenRouter chat-completion calls.
HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}", "Content-Type": "application/json"}
# Greedy match from the first "{" to the last "}" — used to pull a JSON
# object out of free-form model output.
JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)

# Per-session token accounting, appended to by update_token_count() and
# summarized/reset by get_token_count_record() / reset_token_count().
INPUT_TOKEN_COUNT = np.array([], dtype=int)
OUTPUT_TOKEN_COUNT = np.array([], dtype=int)
TOTAL_TOKEN_COUNT = np.array([], dtype=int)
TOTAL_TOKEN_COUNT_EACH_GENERATION = np.array([])
# Timing-info accumulator returned by get_time_info().
# NOTE(review): nothing in this module ever writes to it — confirm intent.
TIME_INFOs = {}


# Fiddler safety-guardrail configuration; raises KeyError at import time
# if FIDDLER_TOKEN is not set.
FIDDLER_GUARDRAILS_TOKEN = os.environ['FIDDLER_TOKEN']
SAFETY_GUARDRAILS_URL = "https://guardrails.cloud.fiddler.ai/v3/guardrails/ftl-safety"
GUARDRAILS_HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': f'Bearer {FIDDLER_GUARDRAILS_TOKEN}',
}
| 32 |
+
|
| 33 |
+
def get_safety_response(text, sleep_seconds: float = 0.5):
    """Send *text* to the Fiddler safety guardrail and return the parsed JSON body.

    Raises requests.HTTPError on a non-2xx response.
    """
    # Crude client-side throttle: the guardrail endpoint is rate limited.
    time.sleep(sleep_seconds)
    payload = {'data': {'input': text}}
    response = requests.post(SAFETY_GUARDRAILS_URL, headers=GUARDRAILS_HEADERS, json=payload)
    response.raise_for_status()
    return response.json()
|
| 43 |
+
|
| 44 |
+
def text_safety_check(text: str, sleep_seconds: float = 0.5):
    """Run *text* through the Fiddler safety guardrail.

    Returns:
        (max_conf, max_category): the highest confidence value in the
        guardrail response and the category key it belongs to.

    NOTE(review): assumes the response body is a flat {category: confidence}
    mapping — confirm against the Fiddler API.
    """
    confs = get_safety_response(text, sleep_seconds)
    # Single-pass argmax instead of the old double list() scan + .index(),
    # which walked the dict three times and broke if values were unorderable.
    max_category = max(confs, key=confs.get)
    return confs[max_category], max_category
|
| 49 |
+
|
| 50 |
+
def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
    """POST a chat-completion request to the OpenRouter API and return the
    assistant message content as a string.

    Raises requests.HTTPError on a non-2xx response and RuntimeError when the
    response body has no recognizable completion payload.
    """
    payload = {"model": model, "messages": messages, "temperature": temperature}
    # OpenRouter's hosting of GPT-OSS needs an explicit provider allowlist.
    if model == 'openai/gpt-oss-120b':
        payload["provider"] = {"only": ["Cerebras", "together", "baseten", "deepinfra/fp4"]}

    resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()

    # Responses may carry the text under message.content or a bare "text" key.
    choices = data.get("choices") or []
    if choices:
        first = choices[0]
        if isinstance(first, dict) and "message" in first and "content" in first["message"]:
            return first["message"]["content"]
        if "text" in first:
            return first["text"]
    raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
|
| 70 |
+
|
| 71 |
+
def _safe_extract_json(text: str) -> dict:
    """Extract the first {...} span from raw model output and parse it as JSON.

    Strips markdown code fences first; if strict parsing fails, retries after
    removing trailing commas before } or ]. Raises ValueError when no JSON
    object is present at all.
    """
    cleaned = re.sub(r"```(?:json)?\n?", "", text)
    match = JSON_OBJ_RE.search(cleaned)
    if match is None:
        raise ValueError("No JSON object found in model output.")

    candidate = match.group(1)
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        # Some models emit trailing commas like {"a": 1,} — drop and retry.
        return json.loads(re.sub(r",\s*([}\]])", r"\1", candidate))
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def structure_context_for_llm(
    source_text: str,
    model: str = "openai/gpt-oss-120b",
    temperature: float = 0.2,
    enable_fiddler = False,
) -> Dict[str, Any]:
    """
    Take a long source_text, split into N chunks, and restructure them
    so each chunk is self-contained, structured, and semantically meaningful.

    Args:
        source_text: raw document text to restructure.
        model: OpenRouter model id to call.
        temperature: sampling temperature for the LLM call.
        enable_fiddler: when True, run the prompt through the Fiddler safety
            guardrail first and return {} if it is flagged harmful.

    Returns:
        Dict mapping each topic name to its structured fields (as described
        in the Vietnamese system prompt below).

    Raises:
        ValueError: when the model output cannot be parsed into a dict.
    """

    # System prompt (Vietnamese): instructs the model to split the text into
    # self-contained chunks keyed by topic, each with content, main concepts,
    # formulas, examples, and a summary — returned as a single JSON object.
    system_message = {
        "role": "system",
        "content": (
            "Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt\n"
            "Nhiệm vụ của bạn là:\n"
            "- Nếu văn bản dài trên 500 từ chia văn bản thành 2 đoạn (chunk) có ý nghĩa rõ ràng.\n"
            "- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
            "- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
            "- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
            "  - 'đoạn văn': nội dung gốc đã cấu trúc đầy đủ\n"
            "  - 'khái niệm chính': từ điểm chứa các khái niệm chính với khái niệm phụ hỗ trợ khái niệm chính đi kèm nếu có\n"
            "  - 'công thức': danh sách công thức (nếu có)\n"
            "  - 'ví dụ': ví dụ minh họa (nếu có)\n"
            "  - 'tóm tắt': tóm tắt nội dung, dễ hiểu\n"
            "- Giữ ngữ nghĩa liền mạch.\n"
            "- Chỉ TRẢ VỀ MỘT JSON hợp lệ theo schema, không kèm văn bản khác.\n\n"

            "Chỉ TRẢ VỀ duy nhất MỘT đối tượng JSON theo schema sau và không có bất kỳ văn bản nào khác:\n\n"
            "{\n"
            ' "Tên topic": {"đoạn văn": "nội dung đã cấu trúc của topic 1", "khái niệm chính": {"khái niệm chính 1":["khái niệm phụ", "..."],"khái niệm chính 2":["khái niệm phụ", "..."]}, "công thức": ["..."], "ví dụ": ["..."], "tóm tắt": "tóm tắt ngắn gọn"},\n'
            "}\n"
        )
    }

    user_message = {
        "role": "user",
        "content": (
            "Hãy chia văn bản sau thành nhiều chunk theo hướng dẫn trên và xuất JSON hợp lệ.\n"
            f"### Văn bản nguồn:\n{source_text}"
        )
    }

    # Optionally refuse to call the LLM when the guardrail flags the prompt.
    if enable_fiddler:
        max_conf, max_cat = text_safety_check(user_message['content'])
        if max_conf > 0.5:
            print(f"Harmful content detected: ({max_cat} : {max_conf})")
            return {}

    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
    parsed = _safe_extract_json(raw)
    # Basic validation: the structured context must be a JSON object.
    if not isinstance(parsed, dict):
        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
    return parsed
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def new_generate_mcqs_from_text(
    source_text: str,
    n: int = 3,
    model: str = "openai/gpt-oss-120b",
    temperature: float = 0.2,
    enable_fiddler = False,
    target_difficulty: str = "easy",
) -> Dict[str, Any]:
    """Generate *n* difficulty-controlled MCQs (with concept annotations) from source_text.

    Fixes vs. the previous version:
      - default `model` was the typo "openai/openai/gpt-oss-120b", which never
        matched _post_chat's provider routing for "openai/gpt-oss-120b";
      - the JSON schema example in the system prompt was malformed
        (unbalanced brackets in the 'khái niệm sử dụng' example);
      - one prompt bullet referenced 'options' while the schema key is 'lựa chọn'.

    Args:
        source_text: source material the questions must be grounded in.
        n: number of questions to produce (output keys "1".."n").
        model: OpenRouter model id.
        temperature: sampling temperature.
        enable_fiddler: run the prompt through the safety guardrail first;
            return {} when flagged harmful.
        target_difficulty: "easy" | "medium" | "hard" — selects the prompt
            criteria and how many main concepts each question must use.

    Raises:
        KeyError: on an unknown target_difficulty.
        ValueError: when the model output is not a dict of exactly n items.
    """
    # How many main concepts a question at each difficulty should combine.
    expected_concepts = {
        "easy": 1,
        "medium": 2,
        "hard": (3, 4)
    }
    if isinstance(expected_concepts[target_difficulty], tuple):
        min_concepts, max_concepts = expected_concepts[target_difficulty]
        concept_range = f"{min_concepts}-{max_concepts}"
    else:
        concept_range = expected_concepts[target_difficulty]

    # Per-difficulty generation criteria (Vietnamese), spliced into the prompt.
    difficulty_prompts = {
        "easy": (
            "- Câu hỏi DỄ: kiểm tra duy nhất 1 khái niệm chính cơ bản dễ hiểu, định nghĩa, hoặc công thức đơn giản."
            "- Đáp án có thể tìm thấy trực tiếp trong văn bản."
            "- Ngữ cảnh đủ để hiểu khái niệm chính."
            "- Distractors khác biệt rõ ràng, dễ loại bỏ."
            "- Độ dài câu hỏi ngắn gọn không quá 10-20 từ hoặc ít hơn 120 ký tự, tập trung vào một ý duy nhất.\n"
        ),
        "medium": (
            "- Câu hỏi TRUNG BÌNH kiểm tra khái niệm chính trong văn bản"
            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
            "- Các Distractors không quá giống nhau."
            "- Độ dài câu hỏi vừa phải khoảng 23–30 từ hoặc khoảng 150 - 180 ký tự, có thêm chi tiết phụ để suy luận.\n"
        ),
        "hard": (
            "- Câu hỏi KHÓ kiểm tra thông tin được phân tích/tổng hợp"
            "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
            "- Ít nhất 2 distractors gần giống đáp án đúng, độ tương đồng cao. "
            "- Đáp án yêu cầu học sinh suy luận hoặc áp dụng công thức vào ví dụ nếu có."
            "- Độ dài câu hỏi dài hơn 35 từ hoặc hơn 200 ký tự.\n \n"
        )
    }

    difficult_criteria = difficulty_prompts[target_difficulty]  # "easy", "medium", "hard"
    print(concept_range)
    system_message = {
        "role": "system",
        "content": (
            "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm (MCQ). Luôn trả lời bằng tiếng việt"
            f"Đảm bảo chỉ tạo sinh câu trắc nghiệm có độ khó sau {difficult_criteria}"
            f"Quan trọng: Mỗi câu hỏi chỉ sử dụng chính xác {concept_range} khái niệm chính (mỗi khái niệm chính có 1 danh sách khái niệm phụ) từ văn bản nguồn. "
            "Mỗi câu hỏi và đáp án phải dựa trên thông tin từ văn bản nguồn. Không được đưa kiến thức ngoài vào."
            "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không kèm giải thích hay trường thêm:\n\n"
            "{\n"
            ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"...", "khái niệm sử dụng": {"khái niệm chính":["khái niệm phụ", "..."]}},\n'
            ' "2": { ... }\n'
            "}\n\n"
            "Lưu ý:\n"
            f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
            "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
            "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
            "- Toàn bộ thông tin cần thiết để trả lời phải nằm trong chính câu hỏi, không tham chiếu lại văn bản nguồn."
            f"- Sử dụng chính xác {concept_range} khái niệm chính"
        )
    }

    user_message = {
        "role": "user",
        "content": (
            f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Chỉ sử dụng nội dung này làm nguồn duy nhất để xây dựng câu hỏi.\n\n"

            "### Yêu cầu:\n"
            "- Bám sát vào thông tin trong văn bản; không thêm kiến thức ngoài.\n"
            "- Nếu văn bản thiếu chi tiết, hãy tạo phương án nhiễu (distractors) hợp lý, nhưng phải có thể biện minh từ nội dung hoặc ngữ cảnh.\n"
            f"### Văn bản nguồn:\n{source_text}"
        )
    }

    # Optionally refuse to call the LLM when the guardrail flags the prompt.
    if enable_fiddler:
        max_conf, max_cat = text_safety_check(user_message['content'])
        if max_conf > 0.5:
            print(f"Harmful content detected: ({max_cat} : {max_conf})")
            return {}

    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
    parsed = _safe_extract_json(raw)
    # Basic validation: exactly n questions keyed "1".."n".
    if not isinstance(parsed, dict) or len(parsed) != n:
        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
    return parsed
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def generate_mcqs_from_text(
    source_text: str,
    n: int = 3,
    model: str = "openai/gpt-oss-120b",
    temperature: float = 0.2,
    enable_fiddler: bool = False,
) -> Dict[str, Any]:
    """Generate *n* plain MCQs (no difficulty/concept control) from source_text.

    Args:
        source_text: the only material the questions may draw on.
        n: number of questions (output keys "1".."n").
        model: OpenRouter model id.
        temperature: sampling temperature.
        enable_fiddler: run the prompt through the safety guardrail first.

    Returns:
        Dict of n MCQ entries, or an {"error": ...} dict when the guardrail
        flags the prompt.
        NOTE(review): other generators in this module return {} on harmful
        content — confirm which contract callers expect.

    Raises:
        ValueError: when the model output is not a dict of exactly n items.
    """
    # System prompt (Vietnamese): emit exactly n MCQs as one JSON object,
    # options keyed a-d, answer given as the full option text.
    system_message = {
        "role": "system",
        "content": (
            "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
            "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
            "{\n"
            ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
            ' "2": { ... }\n'
            "}\n\n"
            "Lưu ý:\n"
            f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
            "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
            "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
            "- Không kèm giải thích hay trường thêm.\n"
            "- Các phương án sai (distractors) phải hợp lý và không lặp lại."
        )
    }
    user_message = {
        "role": "user",
        "content": (
            f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
            "Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
            f"Nội dung:\n\n{source_text}"
        )
    }

    # Optionally refuse to call the LLM when the guardrail flags the prompt.
    if enable_fiddler:
        max_conf, max_cat = text_safety_check(user_message['content'])
        if max_conf > 0.5:
            print(f"Harmful content detected: ({max_cat} : {max_conf})")
            return {"error": "Harmful content detected", f"{max_cat}": f"{str(max_conf)}"}

    raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
    parsed = _safe_extract_json(raw)

    # validate structure and length
    if not isinstance(parsed, dict) or len(parsed) != n:
        raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
    return parsed
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# helpers to read/reset token counts
|
| 289 |
+
def get_token_count_record():
    """Build a summary dict of token usage for the current generation session.

    Side effect: appends the session's total token count to the
    per-generation history (TOTAL_TOKEN_COUNT_EACH_GENERATION).
    """
    global TOTAL_TOKEN_COUNT_EACH_GENERATION
    TOTAL_TOKEN_COUNT_EACH_GENERATION = np.append(TOTAL_TOKEN_COUNT_EACH_GENERATION, np.sum(TOTAL_TOKEN_COUNT))

    def _avg(arr):
        # np.average on an empty array emits a RuntimeWarning and yields NaN;
        # report 0.0 instead when nothing has been recorded yet.
        return float(np.average(arr)) if arr.size else 0.0

    token_record = {
        'INPUT_token_count': np.sum(INPUT_TOKEN_COUNT),
        'OUTPUT_token_count': np.sum(OUTPUT_TOKEN_COUNT),
        'AVG_INPUT_token_count': _avg(INPUT_TOKEN_COUNT),
        'AVG_OUTPUT_token_count': _avg(OUTPUT_TOKEN_COUNT),
        'TOTAL_token_count': TOTAL_TOKEN_COUNT,
        # key previously carried a stray " - " suffix
        'TOTAL_token_count_PER_GENERATION': TOTAL_TOKEN_COUNT_EACH_GENERATION,
        'AVG_TOTAL_token_count_PER_GENERATION': [_avg(TOTAL_TOKEN_COUNT_EACH_GENERATION), len(TOTAL_TOKEN_COUNT_EACH_GENERATION)],
    }

    return token_record
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def reset_token_count(reset_all=None):
    """Clear the per-session token counters (call from app.py after a session).

    When *reset_all* is truthy, the cross-session per-generation history is
    wiped as well.
    """
    global INPUT_TOKEN_COUNT, OUTPUT_TOKEN_COUNT, TOTAL_TOKEN_COUNT, TOTAL_TOKEN_COUNT_EACH_GENERATION

    INPUT_TOKEN_COUNT, OUTPUT_TOKEN_COUNT, TOTAL_TOKEN_COUNT = (
        np.array([]),
        np.array([]),
        np.array([]),
    )

    if reset_all:
        TOTAL_TOKEN_COUNT_EACH_GENERATION = np.array([])
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def update_token_count(token_usage):
    """Accumulate token counts from one API "usage" payload into the module counters.

    Expected payload shape:
        {"prompt_tokens": 1209, "completion_tokens": 313, "total_tokens": 1522,
         "prompt_tokens_details": {"cached_tokens": 0}}
    """
    global INPUT_TOKEN_COUNT, OUTPUT_TOKEN_COUNT, TOTAL_TOKEN_COUNT

    INPUT_TOKEN_COUNT = np.append(INPUT_TOKEN_COUNT, token_usage['prompt_tokens'])        # input side
    OUTPUT_TOKEN_COUNT = np.append(OUTPUT_TOKEN_COUNT, token_usage['completion_tokens'])  # output side
    TOTAL_TOKEN_COUNT = np.append(TOTAL_TOKEN_COUNT, token_usage['total_tokens'])         # combined
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def save_logs(record: dict, log_path: str = "logs/generation_log.jsonl"):
    """Append *record* as one JSON line to *log_path*, creating parent dirs.

    Adds 'id' (uuid4) and 'timestamp_utc' (ISO-8601, UTC) when missing.

    Args:
        record: dict with whatever you want to store (filename, token counts,
            collection, etc.).
        log_path: JSONL file to append to.
    """
    # create file's parent directory if it does not exist
    p = pathlib.Path(log_path)
    p.parent.mkdir(parents=True, exist_ok=True)

    # add id/timestamp if missing
    record.setdefault('id', str(uuid.uuid4()))
    # isoformat() on an aware datetime already carries the "+00:00" offset;
    # the old code appended "Z" on top, producing invalid "...+00:00Z" stamps.
    record.setdefault('timestamp_utc', datetime.datetime.now(datetime.timezone.utc).isoformat())

    # append one JSON object per line (JSONL)
    with open(p, "a", encoding='utf-8') as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def update_time_info(time_info):
    """Strip the 'created' epoch field from an API "time_info" payload, in place.

    Example payload:
        {"queue_time": 0.0006, "prompt_time": 0.0527, "completion_time": 0.1569,
         "total_time": 0.2117, "created": 1755599458}

    The previous implementation assigned ``time_info['created'] = time_info``
    (a self-reference) and then popped 'created' from that same dict — the net
    effect was simply removing the 'created' key, which is done directly here.

    NOTE(review): the module-level TIME_INFOs accumulator is never written
    anywhere; presumably this function was meant to store timings there
    keyed by 'created' — confirm the original intent.
    """
    time_info.pop('created', None)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def get_time_info():
    """Return the module-level timing-info accumulator.

    NOTE(review): nothing in this module currently populates TIME_INFOs,
    so this returns an empty dict unless a caller fills it in — confirm.
    """
    return TIME_INFOs
|
| 380 |
+
|
| 381 |
+
def log_pipeline(path, content):
    """Print the token-usage summary for the last generation session, then reset counters.

    Args:
        path: destination the result would be saved to.
        content: generated payload.
            NOTE(review): persisting `content` is currently disabled (the
            save_to_local call is commented out), so `content` is unused.
    """
    # Message previously hard-coded "test/mcq_output.json" regardless of `path`.
    print(f"Save result to {path}")
    #save_to_local(path=path, content=content)
    token_record = get_token_count_record()

    print("Token Record:")
    for record, value in token_record.items():
        print(f'{record}:{value}', '\n')

    reset_token_count()
|
| 391 |
+
|
| 392 |
+
def save_to_local(path, content):
    """Write *content* to *path*, creating parent directories as needed.

    ``.json`` paths get pretty-printed JSON; anything else (md, txt, ...) is
    written via plain string formatting.

    Typical paths:
        test/raw_data.json
        test/mcq_output.json
        test/extract_output.md
    """
    target = pathlib.Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)  # create folder if missing
    target.touch(exist_ok=True)                       # create file if missing

    if path.lower().endswith('.json'):
        text = json.dumps(content, ensure_ascii=False, indent=2)
    else:
        text = f'{content}'  # md, txt

    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
|