Spaces:
Runtime error
Runtime error
namberino
commited on
Commit
·
dfa5afb
1
Parent(s):
073c79b
Initial commit
Browse files- .dockerignore +15 -0
- .gitattributes +1 -0
- .github/workflows/hf_deploy.yml +19 -0
- .github/workflows/image_scan.yml +25 -0
- .github/workflows/lint_test.yml +30 -0
- .github/workflows/sast.yml +31 -0
- .github/workflows/security_scan.yml +23 -0
- .gitignore +1 -0
- Dockerfile +32 -0
- app.py +179 -0
- app/app.py +176 -0
- app/generator.py +695 -0
- app/output.json +0 -0
- app/software_report_template.md +261 -0
- app/utils.py +88 -0
- generator.py +696 -0
- requirements.txt +8 -0
- test/cerebras-api.py +48 -0
- test/logging.txt +111 -0
- test/mcq_output.json +163 -0
- test/output.json +0 -0
- utils.py +86 -0
.dockerignore
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.pyd
|
| 5 |
+
*.sqlite3
|
| 6 |
+
.env
|
| 7 |
+
.env.*
|
| 8 |
+
.git
|
| 9 |
+
.gitignore
|
| 10 |
+
.wheelhouse
|
| 11 |
+
wheels
|
| 12 |
+
dist
|
| 13 |
+
build
|
| 14 |
+
.vscode
|
| 15 |
+
.idea
|
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/hf_deploy.yml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Sync to HuggingFace
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
branches: [ main ]
|
| 5 |
+
|
| 6 |
+
workflow_dispatch:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
sync-to-hub:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- uses: actions/checkout@v3
|
| 13 |
+
with:
|
| 14 |
+
fetch-depth: 0
|
| 15 |
+
lfs: true
|
| 16 |
+
- name: Push to hub
|
| 17 |
+
env:
|
| 18 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 19 |
+
run: git push https://namberino:$HF_TOKEN@huggingface.co/spaces/namberino/mcq-generator main
|
.github/workflows/image_scan.yml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Scan docker image for security issues
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ main ]
|
| 6 |
+
|
| 7 |
+
pull_request:
|
| 8 |
+
|
| 9 |
+
workflow_dispatch:
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
trivy:
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- uses: actions/checkout@v4
|
| 16 |
+
|
| 17 |
+
- name: Build image
|
| 18 |
+
run: docker build -t mcq-gen:ci .
|
| 19 |
+
|
| 20 |
+
- name: Run Trivy scan
|
| 21 |
+
uses: aquasecurity/trivy-action@0.32.0
|
| 22 |
+
with:
|
| 23 |
+
image-ref: mcq-gen:ci
|
| 24 |
+
format: 'table'
|
| 25 |
+
exit-code: '1'
|
.github/workflows/lint_test.yml
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Lint, Typecheck, Tests
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches: [ main ]
|
| 6 |
+
|
| 7 |
+
pull_request:
|
| 8 |
+
|
| 9 |
+
workflow_dispatch:
|
| 10 |
+
|
| 11 |
+
jobs:
|
| 12 |
+
lint-and-test:
|
| 13 |
+
runs-on: ubuntu-latest
|
| 14 |
+
steps:
|
| 15 |
+
- uses: actions/checkout@v4
|
| 16 |
+
with:
|
| 17 |
+
fetch-depth: 0
|
| 18 |
+
|
| 19 |
+
- name: Set up Python
|
| 20 |
+
uses: actions/setup-python@v4
|
| 21 |
+
with:
|
| 22 |
+
python-version: "3.11"
|
| 23 |
+
|
| 24 |
+
- name: Install dependencies
|
| 25 |
+
run: |
|
| 26 |
+
pip install ruff
|
| 27 |
+
|
| 28 |
+
- name: Run ruff (lint)
|
| 29 |
+
run: |
|
| 30 |
+
ruff check .
|
.github/workflows/sast.yml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: SAST
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
pull_request:
|
| 6 |
+
workflow_dispatch:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
sast:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- uses: actions/checkout@v4
|
| 13 |
+
|
| 14 |
+
- name: Install tools
|
| 15 |
+
run: |
|
| 16 |
+
python -m pip install --upgrade pip
|
| 17 |
+
pip install semgrep bandit
|
| 18 |
+
|
| 19 |
+
- name: Run semgrep
|
| 20 |
+
run: semgrep --config auto --output semgrep-results.txt || true
|
| 21 |
+
|
| 22 |
+
- name: Run bandit
|
| 23 |
+
run: bandit -r . -f json -o bandit-results.json || true
|
| 24 |
+
|
| 25 |
+
- name: Upload SARIF/artifacts
|
| 26 |
+
uses: actions/upload-artifact@v4
|
| 27 |
+
with:
|
| 28 |
+
name: security-reports
|
| 29 |
+
path: |
|
| 30 |
+
semgrep-results.txt
|
| 31 |
+
bandit-results.json
|
.github/workflows/security_scan.yml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Scan for security issues
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
pull_request:
|
| 5 |
+
workflow_dispatch:
|
| 6 |
+
|
| 7 |
+
jobs:
|
| 8 |
+
gitleaks-scan:
|
| 9 |
+
runs-on: ubuntu-latest
|
| 10 |
+
name: Scan for secrets and sensitive information
|
| 11 |
+
steps:
|
| 12 |
+
- name: Checkout repo
|
| 13 |
+
uses: actions/checkout@v4
|
| 14 |
+
with:
|
| 15 |
+
fetch-depth: 0
|
| 16 |
+
|
| 17 |
+
- name: Run gitleaks
|
| 18 |
+
uses: gitleaks/gitleaks-action@v2
|
| 19 |
+
env:
|
| 20 |
+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
| 21 |
+
# GITLEAKS_CONFIG: .gitleaks.toml
|
| 22 |
+
GITLEAKS_ENABLE_UPLOAD_ARTIFACT: true
|
| 23 |
+
GITLEAKS_ENABLE_SUMMARY: true
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.vscode
|
Dockerfile
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# set HF cache to /tmp for writable FS on Spaces
|
| 4 |
+
ENV HF_HOME=/tmp/huggingface
|
| 5 |
+
ENV TOKENIZERS_PARALLELISM=false
|
| 6 |
+
|
| 7 |
+
# install system packages needed by some python libs
|
| 8 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 9 |
+
build-essential \
|
| 10 |
+
git \
|
| 11 |
+
wget \
|
| 12 |
+
libsndfile1 \
|
| 13 |
+
libgl1 \
|
| 14 |
+
libglib2.0-0 \
|
| 15 |
+
poppler-utils \
|
| 16 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
+
WORKDIR /app
|
| 19 |
+
|
| 20 |
+
# copy requirements and install
|
| 21 |
+
COPY requirements.txt /app/requirements.txt
|
| 22 |
+
RUN pip install --upgrade pip
|
| 23 |
+
# try to be robust to wheels/build issues
|
| 24 |
+
# RUN pip wheel --no-cache-dir --wheel-dir=/wheels -r /app/requirements.txt || true
|
| 25 |
+
RUN pip install --no-cache-dir -r /app/requirements.txt
|
| 26 |
+
|
| 27 |
+
# copy app code
|
| 28 |
+
COPY . /app
|
| 29 |
+
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import tempfile
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
# Import the user's RAGMCQ implementation
|
| 11 |
+
from generator import RAGMCQ
|
| 12 |
+
|
| 13 |
+
app = FastAPI(title="RAG MCQ Generator API")
|
| 14 |
+
|
| 15 |
+
# allow cross-origin requests (adjust in production)
|
| 16 |
+
app.add_middleware(
|
| 17 |
+
CORSMiddleware,
|
| 18 |
+
allow_origins=["*"],
|
| 19 |
+
allow_credentials=True,
|
| 20 |
+
allow_methods=["*"],
|
| 21 |
+
allow_headers=["*"],
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# global rag instance
|
| 25 |
+
rag: Optional[RAGMCQ] = None
|
| 26 |
+
|
| 27 |
+
class GenerateResponse(BaseModel):
|
| 28 |
+
mcqs: dict
|
| 29 |
+
validation: Optional[dict] = None
|
| 30 |
+
|
| 31 |
+
class ListResponse(BaseModel):
|
| 32 |
+
files: list
|
| 33 |
+
|
| 34 |
+
@app.on_event("startup")
|
| 35 |
+
def startup_event():
|
| 36 |
+
global rag
|
| 37 |
+
|
| 38 |
+
# instantiate the heavy object once
|
| 39 |
+
rag = RAGMCQ(
|
| 40 |
+
qdrant_url=os.environ['QDRANT_URL'],
|
| 41 |
+
qdrant_api_key=os.environ['QDRANT_API_KEY']
|
| 42 |
+
)
|
| 43 |
+
print("RAGMCQ instance created on startup.")
|
| 44 |
+
|
| 45 |
+
@app.get("/health")
|
| 46 |
+
def health():
|
| 47 |
+
return {"status": "ok", "ready": rag is not None}
|
| 48 |
+
|
| 49 |
+
def _save_upload_to_temp(upload: UploadFile) -> str:
|
| 50 |
+
suffix = ".pdf"
|
| 51 |
+
fd, path = tempfile.mkstemp(suffix=suffix)
|
| 52 |
+
os.close(fd)
|
| 53 |
+
with open(path, "wb") as out_file:
|
| 54 |
+
shutil.copyfileobj(upload.file, out_file)
|
| 55 |
+
return path
|
| 56 |
+
|
| 57 |
+
@app.get("/list_collection_files", response_model=ListResponse)
|
| 58 |
+
async def list_collection_files_endpoint(
|
| 59 |
+
collection_name: str = "programming"
|
| 60 |
+
):
|
| 61 |
+
global rag
|
| 62 |
+
if rag is None:
|
| 63 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 64 |
+
|
| 65 |
+
files = rag.list_files_in_collection(collection_name)
|
| 66 |
+
|
| 67 |
+
return {"files": files}
|
| 68 |
+
|
| 69 |
+
@app.post("/generate_saved", response_model=GenerateResponse)
|
| 70 |
+
async def generate_saved_endpoint(
|
| 71 |
+
n_questions: int = Form(10),
|
| 72 |
+
qdrant_filename: str = Form("default_filename"),
|
| 73 |
+
collection_name: str = Form("programming"),
|
| 74 |
+
mode: str = Form("rag"),
|
| 75 |
+
questions_per_chunk: int = Form(3),
|
| 76 |
+
top_k: int = Form(3),
|
| 77 |
+
temperature: float = Form(0.2),
|
| 78 |
+
validate: bool = Form(False),
|
| 79 |
+
use_model_verification: bool = Form(False)
|
| 80 |
+
):
|
| 81 |
+
global rag
|
| 82 |
+
if rag is None:
|
| 83 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
mcqs = rag.generate_from_qdrant(
|
| 87 |
+
filename=qdrant_filename,
|
| 88 |
+
collection=collection_name,
|
| 89 |
+
n_questions=n_questions,
|
| 90 |
+
mode=mode,
|
| 91 |
+
questions_per_chunk=questions_per_chunk,
|
| 92 |
+
top_k=top_k,
|
| 93 |
+
temperature=temperature
|
| 94 |
+
)
|
| 95 |
+
except Exception as e:
|
| 96 |
+
raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
|
| 97 |
+
|
| 98 |
+
validation_report = None
|
| 99 |
+
|
| 100 |
+
if validate:
|
| 101 |
+
try:
|
| 102 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 103 |
+
validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
|
| 104 |
+
except Exception as e:
|
| 105 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 106 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 107 |
+
|
| 108 |
+
return {"mcqs": mcqs, "validation": validation_report}
|
| 109 |
+
|
| 110 |
+
@app.post("/generate", response_model=GenerateResponse)
|
| 111 |
+
async def generate_endpoint(
|
| 112 |
+
background_tasks: BackgroundTasks,
|
| 113 |
+
file: UploadFile = File(...),
|
| 114 |
+
n_questions: int = Form(10),
|
| 115 |
+
qdrant_filename: str = Form("default_filename"),
|
| 116 |
+
collection_name: str = Form("programming"),
|
| 117 |
+
mode: str = Form("rag"),
|
| 118 |
+
questions_per_page: int = Form(3),
|
| 119 |
+
top_k: int = Form(3),
|
| 120 |
+
temperature: float = Form(0.2),
|
| 121 |
+
validate: bool = Form(False),
|
| 122 |
+
use_model_verification: bool = Form(False)
|
| 123 |
+
):
|
| 124 |
+
global rag
|
| 125 |
+
if rag is None:
|
| 126 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 127 |
+
|
| 128 |
+
# basic file validation
|
| 129 |
+
if not file.filename.lower().endswith(".pdf"):
|
| 130 |
+
raise HTTPException(status_code=400, detail="Only PDF files are supported.")
|
| 131 |
+
|
| 132 |
+
# save uploaded file to a temp location
|
| 133 |
+
tmp_path = _save_upload_to_temp(file)
|
| 134 |
+
|
| 135 |
+
# ensure file removed afterward
|
| 136 |
+
def _cleanup(path: str):
|
| 137 |
+
try:
|
| 138 |
+
os.remove(path)
|
| 139 |
+
except Exception:
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
background_tasks.add_task(_cleanup, tmp_path)
|
| 143 |
+
|
| 144 |
+
# save pdf
|
| 145 |
+
try:
|
| 146 |
+
rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
|
| 147 |
+
except Exception as e:
|
| 148 |
+
raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
|
| 149 |
+
|
| 150 |
+
# generate
|
| 151 |
+
try:
|
| 152 |
+
mcqs = rag.generate_from_pdf(
|
| 153 |
+
tmp_path,
|
| 154 |
+
n_questions=n_questions,
|
| 155 |
+
mode=mode,
|
| 156 |
+
questions_per_page=questions_per_page,
|
| 157 |
+
top_k=top_k,
|
| 158 |
+
temperature=temperature,
|
| 159 |
+
)
|
| 160 |
+
except Exception as e:
|
| 161 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
|
| 162 |
+
|
| 163 |
+
validation_report = None
|
| 164 |
+
|
| 165 |
+
if validate:
|
| 166 |
+
try:
|
| 167 |
+
# rag.build_index_from_pdf(tmp_path)
|
| 168 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 169 |
+
validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
|
| 170 |
+
except Exception as e:
|
| 171 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 172 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 173 |
+
|
| 174 |
+
return {"mcqs": mcqs, "validation": validation_report}
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
import uvicorn
|
| 179 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
|
app/app.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import tempfile
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
# Import the user's RAGMCQ implementation
|
| 11 |
+
from generator import RAGMCQ
|
| 12 |
+
|
| 13 |
+
app = FastAPI(title="RAG MCQ Generator API")
|
| 14 |
+
|
| 15 |
+
# allow cross-origin requests (adjust in production)
|
| 16 |
+
app.add_middleware(
|
| 17 |
+
CORSMiddleware,
|
| 18 |
+
allow_origins=["*"],
|
| 19 |
+
allow_credentials=True,
|
| 20 |
+
allow_methods=["*"],
|
| 21 |
+
allow_headers=["*"],
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# global rag instance
|
| 25 |
+
rag: Optional[RAGMCQ] = None
|
| 26 |
+
|
| 27 |
+
class GenerateResponse(BaseModel):
|
| 28 |
+
mcqs: dict
|
| 29 |
+
validation: Optional[dict] = None
|
| 30 |
+
|
| 31 |
+
class ListResponse(BaseModel):
|
| 32 |
+
files: list
|
| 33 |
+
|
| 34 |
+
@app.on_event("startup")
|
| 35 |
+
def startup_event():
|
| 36 |
+
global rag
|
| 37 |
+
|
| 38 |
+
# instantiate the heavy object once
|
| 39 |
+
rag = RAGMCQ()
|
| 40 |
+
print("RAGMCQ instance created on startup.")
|
| 41 |
+
|
| 42 |
+
@app.get("/health")
|
| 43 |
+
def health():
|
| 44 |
+
return {"status": "ok", "ready": rag is not None}
|
| 45 |
+
|
| 46 |
+
def _save_upload_to_temp(upload: UploadFile) -> str:
|
| 47 |
+
suffix = ".pdf"
|
| 48 |
+
fd, path = tempfile.mkstemp(suffix=suffix)
|
| 49 |
+
os.close(fd)
|
| 50 |
+
with open(path, "wb") as out_file:
|
| 51 |
+
shutil.copyfileobj(upload.file, out_file)
|
| 52 |
+
return path
|
| 53 |
+
|
| 54 |
+
@app.get("/list_collection_files", response_model=ListResponse)
|
| 55 |
+
async def list_collection_files_endpoint(
|
| 56 |
+
collection_name: str = "programming"
|
| 57 |
+
):
|
| 58 |
+
global rag
|
| 59 |
+
if rag is None:
|
| 60 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 61 |
+
|
| 62 |
+
files = rag.list_files_in_collection(collection_name)
|
| 63 |
+
|
| 64 |
+
return {"files": files}
|
| 65 |
+
|
| 66 |
+
@app.post("/generate_saved", response_model=GenerateResponse)
|
| 67 |
+
async def generate_saved_endpoint(
|
| 68 |
+
n_questions: int = Form(10),
|
| 69 |
+
qdrant_filename: str = Form("default_filename"),
|
| 70 |
+
collection_name: str = Form("programming"),
|
| 71 |
+
mode: str = Form("rag"),
|
| 72 |
+
questions_per_chunk: int = Form(3),
|
| 73 |
+
top_k: int = Form(3),
|
| 74 |
+
temperature: float = Form(0.2),
|
| 75 |
+
validate: bool = Form(False),
|
| 76 |
+
use_model_verification: bool = Form(False)
|
| 77 |
+
):
|
| 78 |
+
global rag
|
| 79 |
+
if rag is None:
|
| 80 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
mcqs = rag.generate_from_qdrant(
|
| 84 |
+
filename=qdrant_filename,
|
| 85 |
+
collection=collection_name,
|
| 86 |
+
n_questions=n_questions,
|
| 87 |
+
mode=mode,
|
| 88 |
+
questions_per_chunk=questions_per_chunk,
|
| 89 |
+
top_k=top_k,
|
| 90 |
+
temperature=temperature
|
| 91 |
+
)
|
| 92 |
+
except Exception as e:
|
| 93 |
+
raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
|
| 94 |
+
|
| 95 |
+
validation_report = None
|
| 96 |
+
|
| 97 |
+
if validate:
|
| 98 |
+
try:
|
| 99 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 100 |
+
validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 103 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 104 |
+
|
| 105 |
+
return {"mcqs": mcqs, "validation": validation_report}
|
| 106 |
+
|
| 107 |
+
@app.post("/generate", response_model=GenerateResponse)
|
| 108 |
+
async def generate_endpoint(
|
| 109 |
+
background_tasks: BackgroundTasks,
|
| 110 |
+
file: UploadFile = File(...),
|
| 111 |
+
n_questions: int = Form(10),
|
| 112 |
+
qdrant_filename: str = Form("default_filename"),
|
| 113 |
+
collection_name: str = Form("programming"),
|
| 114 |
+
mode: str = Form("rag"),
|
| 115 |
+
questions_per_page: int = Form(3),
|
| 116 |
+
top_k: int = Form(3),
|
| 117 |
+
temperature: float = Form(0.2),
|
| 118 |
+
validate: bool = Form(False),
|
| 119 |
+
use_model_verification: bool = Form(False)
|
| 120 |
+
):
|
| 121 |
+
global rag
|
| 122 |
+
if rag is None:
|
| 123 |
+
raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
|
| 124 |
+
|
| 125 |
+
# basic file validation
|
| 126 |
+
if not file.filename.lower().endswith(".pdf"):
|
| 127 |
+
raise HTTPException(status_code=400, detail="Only PDF files are supported.")
|
| 128 |
+
|
| 129 |
+
# save uploaded file to a temp location
|
| 130 |
+
tmp_path = _save_upload_to_temp(file)
|
| 131 |
+
|
| 132 |
+
# ensure file removed afterward
|
| 133 |
+
def _cleanup(path: str):
|
| 134 |
+
try:
|
| 135 |
+
os.remove(path)
|
| 136 |
+
except Exception:
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
background_tasks.add_task(_cleanup, tmp_path)
|
| 140 |
+
|
| 141 |
+
# save pdf
|
| 142 |
+
try:
|
| 143 |
+
rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
|
| 144 |
+
except Exception as e:
|
| 145 |
+
raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
|
| 146 |
+
|
| 147 |
+
# generate
|
| 148 |
+
try:
|
| 149 |
+
mcqs = rag.generate_from_pdf(
|
| 150 |
+
tmp_path,
|
| 151 |
+
n_questions=n_questions,
|
| 152 |
+
mode=mode,
|
| 153 |
+
questions_per_page=questions_per_page,
|
| 154 |
+
top_k=top_k,
|
| 155 |
+
temperature=temperature,
|
| 156 |
+
)
|
| 157 |
+
except Exception as e:
|
| 158 |
+
raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
|
| 159 |
+
|
| 160 |
+
validation_report = None
|
| 161 |
+
|
| 162 |
+
if validate:
|
| 163 |
+
try:
|
| 164 |
+
# rag.build_index_from_pdf(tmp_path)
|
| 165 |
+
# validate_mcqs expects keys as strings and the normalized content
|
| 166 |
+
validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
|
| 167 |
+
except Exception as e:
|
| 168 |
+
# don't fail the whole request for a validation error — return generator output and note the error
|
| 169 |
+
validation_report = {"error": f"Validation failed: {e}"}
|
| 170 |
+
|
| 171 |
+
return {"mcqs": mcqs, "validation": validation_report}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
if __name__ == "__main__":
|
| 175 |
+
import uvicorn
|
| 176 |
+
uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
|
app/generator.py
ADDED
|
@@ -0,0 +1,695 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Tuple, Dict, Any, Optional
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from uuid import uuid4
|
| 7 |
+
import pymupdf4llm
|
| 8 |
+
import pymupdf as fitz
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from qdrant_client import QdrantClient
|
| 12 |
+
from qdrant_client.http.models import (
|
| 13 |
+
PointStruct,
|
| 14 |
+
Filter,
|
| 15 |
+
FieldCondition,
|
| 16 |
+
MatchValue,
|
| 17 |
+
Distance,
|
| 18 |
+
VectorParams,
|
| 19 |
+
)
|
| 20 |
+
from qdrant_client.http import models as rest
|
| 21 |
+
_HAS_QDRANT = True
|
| 22 |
+
except Exception:
|
| 23 |
+
_HAS_QDRANT = False
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import faiss
|
| 27 |
+
_HAS_FAISS = True
|
| 28 |
+
except Exception:
|
| 29 |
+
_HAS_FAISS = False
|
| 30 |
+
|
| 31 |
+
from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json
|
| 32 |
+
|
| 33 |
+
class RAGMCQ:
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
| 37 |
+
hf_model: str = "gpt-oss-120b",
|
| 38 |
+
qdrant_url: str = None,
|
| 39 |
+
qdrant_api_key: str = None,
|
| 40 |
+
qdrant_prefer_grpc: bool = False,
|
| 41 |
+
):
|
| 42 |
+
self.embedder = SentenceTransformer(embedder_model)
|
| 43 |
+
self.hf_model = hf_model
|
| 44 |
+
self.embeddings = None # np.array of shape (N, D)
|
| 45 |
+
self.texts = [] # list of chunk texts
|
| 46 |
+
self.metadata = [] # list of dicts (page, chunk_id, char_range)
|
| 47 |
+
self.index = None
|
| 48 |
+
self.dim = self.embedder.get_sentence_embedding_dimension()
|
| 49 |
+
|
| 50 |
+
self.qdrant = None
|
| 51 |
+
self.qdrant_url = qdrant_url
|
| 52 |
+
self.qdrant_api_key = qdrant_api_key
|
| 53 |
+
self.qdrant_prefer_grpc = qdrant_prefer_grpc
|
| 54 |
+
if qdrant_url:
|
| 55 |
+
self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
|
| 56 |
+
|
| 57 |
+
def extract_pages(
|
| 58 |
+
self,
|
| 59 |
+
pdf_path: str,
|
| 60 |
+
*,
|
| 61 |
+
pages: Optional[List[int]] = None,
|
| 62 |
+
ignore_images: bool = False,
|
| 63 |
+
dpi: int = 150
|
| 64 |
+
) -> List[str]:
|
| 65 |
+
doc = fitz.open(pdf_path)
|
| 66 |
+
try:
|
| 67 |
+
# request page-wise output (page_chunks=True -> list[dict] per page)
|
| 68 |
+
page_dicts = pymupdf4llm.to_markdown(
|
| 69 |
+
doc,
|
| 70 |
+
pages=pages,
|
| 71 |
+
ignore_images=ignore_images,
|
| 72 |
+
dpi=dpi,
|
| 73 |
+
page_chunks=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
|
| 77 |
+
pages_md: List[str] = []
|
| 78 |
+
for p in page_dicts:
|
| 79 |
+
txt = p.get("text", "") or ""
|
| 80 |
+
pages_md.append(txt.strip())
|
| 81 |
+
|
| 82 |
+
return pages_md
|
| 83 |
+
finally:
|
| 84 |
+
doc.close()
|
| 85 |
+
|
| 86 |
+
def chunk_text(self, text: str, max_chars: int = 1200) -> List[str]:
|
| 87 |
+
text = text.strip()
|
| 88 |
+
if not text:
|
| 89 |
+
return []
|
| 90 |
+
if len(text) <= max_chars:
|
| 91 |
+
return [text]
|
| 92 |
+
|
| 93 |
+
# split by sentence-like boundaries
|
| 94 |
+
sentences = re.split(r'(?<=[\.\?\!])\s+', text)
|
| 95 |
+
chunks = []
|
| 96 |
+
cur = ""
|
| 97 |
+
for s in sentences:
|
| 98 |
+
if len(cur) + len(s) + 1 <= max_chars:
|
| 99 |
+
cur += (" " if cur else "") + s
|
| 100 |
+
else:
|
| 101 |
+
if cur:
|
| 102 |
+
chunks.append(cur)
|
| 103 |
+
cur = s
|
| 104 |
+
if cur:
|
| 105 |
+
chunks.append(cur)
|
| 106 |
+
|
| 107 |
+
# if still too long, hard-split
|
| 108 |
+
final = []
|
| 109 |
+
for c in chunks:
|
| 110 |
+
if len(c) <= max_chars:
|
| 111 |
+
final.append(c)
|
| 112 |
+
else:
|
| 113 |
+
for i in range(0, len(c), max_chars):
|
| 114 |
+
final.append(c[i:i+max_chars])
|
| 115 |
+
return final
|
| 116 |
+
|
| 117 |
+
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
|
| 118 |
+
pages = self.extract_pages(pdf_path)
|
| 119 |
+
self.texts = []
|
| 120 |
+
self.metadata = []
|
| 121 |
+
|
| 122 |
+
for p_idx, page_text in enumerate(pages, start=1):
|
| 123 |
+
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
|
| 124 |
+
for cid, ch in enumerate(chunks, start=1):
|
| 125 |
+
self.texts.append(ch)
|
| 126 |
+
self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
|
| 127 |
+
|
| 128 |
+
if not self.texts:
|
| 129 |
+
raise RuntimeError("No text extracted from PDF.")
|
| 130 |
+
|
| 131 |
+
# compute embeddings
|
| 132 |
+
emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
|
| 133 |
+
self.embeddings = emb.astype("float32")
|
| 134 |
+
self._build_faiss_index()
|
| 135 |
+
|
| 136 |
+
def _build_faiss_index(self):
|
| 137 |
+
if _HAS_FAISS:
|
| 138 |
+
d = self.embeddings.shape[1]
|
| 139 |
+
index = faiss.IndexFlatIP(d) # inner product -> cosine if vectors normalized
|
| 140 |
+
faiss.normalize_L2(self.embeddings)
|
| 141 |
+
index.add(self.embeddings)
|
| 142 |
+
self.index = index
|
| 143 |
+
else:
|
| 144 |
+
# store normalized embeddings and use brute-force numpy
|
| 145 |
+
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
|
| 146 |
+
self.embeddings = self.embeddings / norms
|
| 147 |
+
self.index = None
|
| 148 |
+
|
| 149 |
+
def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
|
| 150 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
|
| 151 |
+
|
| 152 |
+
if _HAS_FAISS:
|
| 153 |
+
faiss.normalize_L2(q_emb)
|
| 154 |
+
D_list, I_list = self.index.search(q_emb, top_k)
|
| 155 |
+
# D are inner products; return list of (idx, score)
|
| 156 |
+
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
|
| 157 |
+
else:
|
| 158 |
+
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
|
| 159 |
+
sims = (self.embeddings @ qn.T).squeeze(axis=1)
|
| 160 |
+
idxs = np.argsort(-sims)[:top_k]
|
| 161 |
+
return [(int(i), float(sims[i])) for i in idxs]
|
| 162 |
+
|
| 163 |
+
def generate_from_pdf(
|
| 164 |
+
self,
|
| 165 |
+
pdf_path: str,
|
| 166 |
+
n_questions: int = 10,
|
| 167 |
+
mode: str = "rag", # per_page or rag
|
| 168 |
+
questions_per_page: int = 3, # for per_page mode
|
| 169 |
+
top_k: int = 3, # chunks to retrieve for each question in rag mode
|
| 170 |
+
temperature: float = 0.2,
|
| 171 |
+
) -> Dict[str, Any]:
|
| 172 |
+
# build index
|
| 173 |
+
self.build_index_from_pdf(pdf_path)
|
| 174 |
+
|
| 175 |
+
output: Dict[str, Any] = {}
|
| 176 |
+
qcount = 0
|
| 177 |
+
|
| 178 |
+
if mode == "per_page":
|
| 179 |
+
# iterate pages -> chunks
|
| 180 |
+
for idx, meta in enumerate(self.metadata):
|
| 181 |
+
chunk_text = self.texts[idx]
|
| 182 |
+
|
| 183 |
+
if not chunk_text.strip():
|
| 184 |
+
continue
|
| 185 |
+
to_gen = questions_per_page
|
| 186 |
+
|
| 187 |
+
# ask generator
|
| 188 |
+
try:
|
| 189 |
+
mcq_block = generate_mcqs_from_text(
|
| 190 |
+
chunk_text, n=to_gen, model=self.hf_model, temperature=temperature
|
| 191 |
+
)
|
| 192 |
+
except Exception as e:
|
| 193 |
+
# skip this chunk if generator fails
|
| 194 |
+
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
|
| 195 |
+
continue
|
| 196 |
+
|
| 197 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 198 |
+
qcount += 1
|
| 199 |
+
output[str(qcount)] = mcq_block[item]
|
| 200 |
+
if qcount >= n_questions:
|
| 201 |
+
return output
|
| 202 |
+
|
| 203 |
+
return output
|
| 204 |
+
|
| 205 |
+
elif mode == "rag":
|
| 206 |
+
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
|
| 207 |
+
# create queries by sampling chunk text sentences.
|
| 208 |
+
# stop when n_questions reached or max_attempts exceeded.
|
| 209 |
+
attempts = 0
|
| 210 |
+
max_attempts = n_questions * 4
|
| 211 |
+
|
| 212 |
+
while qcount < n_questions and attempts < max_attempts:
|
| 213 |
+
attempts += 1
|
| 214 |
+
# create a seed query: pick a random chunk, pick a sentence from it
|
| 215 |
+
seed_idx = random.randrange(len(self.texts))
|
| 216 |
+
chunk = self.texts[seed_idx]
|
| 217 |
+
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 218 |
+
seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
|
| 219 |
+
query = f"Create questions about: {seed_sent}"
|
| 220 |
+
|
| 221 |
+
# retrieve top_k chunks
|
| 222 |
+
retrieved = self._retrieve(query, top_k=top_k)
|
| 223 |
+
context_parts = []
|
| 224 |
+
for ridx, score in retrieved:
|
| 225 |
+
md = self.metadata[ridx]
|
| 226 |
+
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
|
| 227 |
+
context = "\n\n".join(context_parts)
|
| 228 |
+
|
| 229 |
+
# call generator for 1 question (or small batch) with the retrieved context
|
| 230 |
+
try:
|
| 231 |
+
# request 1 question at a time to keep diversity
|
| 232 |
+
mcq_block = generate_mcqs_from_text(
|
| 233 |
+
context, n=1, model=self.hf_model, temperature=temperature
|
| 234 |
+
)
|
| 235 |
+
except Exception as e:
|
| 236 |
+
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 237 |
+
continue
|
| 238 |
+
|
| 239 |
+
# append result(s)
|
| 240 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 241 |
+
qcount += 1
|
| 242 |
+
output[str(qcount)] = mcq_block[item]
|
| 243 |
+
if qcount >= n_questions:
|
| 244 |
+
return output
|
| 245 |
+
|
| 246 |
+
return output
|
| 247 |
+
else:
|
| 248 |
+
raise ValueError("mode must be 'per_page' or 'rag'.")
|
| 249 |
+
|
| 250 |
+
def validate_mcqs(
|
| 251 |
+
self,
|
| 252 |
+
mcqs: Dict[str, Any],
|
| 253 |
+
top_k: int = 4,
|
| 254 |
+
similarity_threshold: float = 0.5,
|
| 255 |
+
evidence_score_cutoff: float = 0.5,
|
| 256 |
+
use_model_verification: bool = True,
|
| 257 |
+
model_verification_temperature: float = 0.0,
|
| 258 |
+
) -> Dict[str, Any]:
|
| 259 |
+
if self.embeddings is None or not self.texts:
|
| 260 |
+
raise RuntimeError("Index/embeddings not built. Run build_index_from_pdf() first.")
|
| 261 |
+
|
| 262 |
+
report: Dict[str, Any] = {}
|
| 263 |
+
|
| 264 |
+
# helper: semantic similarity search on statement -> returns list of (idx, score)
|
| 265 |
+
def semantic_search(statement: str, k: int = top_k):
|
| 266 |
+
q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
|
| 267 |
+
|
| 268 |
+
if _HAS_FAISS:
|
| 269 |
+
faiss.normalize_L2(q_emb)
|
| 270 |
+
D_list, I_list = self.index.search(q_emb, k)
|
| 271 |
+
# D are inner products; return list of (idx, score)
|
| 272 |
+
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
|
| 273 |
+
else:
|
| 274 |
+
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
|
| 275 |
+
sims = (self.embeddings @ qn.T).squeeze(axis=1)
|
| 276 |
+
idxs = np.argsort(-sims)[:k]
|
| 277 |
+
return [(int(i), float(sims[i])) for i in idxs]
|
| 278 |
+
|
| 279 |
+
# helper: verify with model (strict JSON in response)
|
| 280 |
+
def _verify_with_model(question_text: str, options: Dict[str, str], correct_text: str, context_text: str):
|
| 281 |
+
system = {
|
| 282 |
+
"role": "system",
|
| 283 |
+
"content": (
|
| 284 |
+
"Bạn là một trợ lý đánh giá tính thực chứng của câu hỏi trắc nghiệm dựa trên đoạn văn được cung cấp. "
|
| 285 |
+
"Hãy trả lời DUY NHẤT bằng JSON hợp lệ (không có văn bản khác) theo schema:\n\n"
|
| 286 |
+
"{\n"
|
| 287 |
+
' "supported": true/false, # câu trả lời đúng có được nội dung chứng thực không\n'
|
| 288 |
+
' "confidence": 0.0-1.0, # mức độ tự tin (số)\n'
|
| 289 |
+
' "evidence": "cụm văn bản ngắn làm bằng chứng hoặc trích dẫn",\n'
|
| 290 |
+
' "reason": "ngắn gọn, vì sao supported hoặc không"\n'
|
| 291 |
+
"}\n\n"
|
| 292 |
+
"Luôn dựa chỉ trên nội dung trong trường 'Context' dưới đây. Nếu nội dung không chứa bằng chứng, trả về supported: false."
|
| 293 |
+
)
|
| 294 |
+
}
|
| 295 |
+
user = {
|
| 296 |
+
"role": "user",
|
| 297 |
+
"content": (
|
| 298 |
+
"Câu hỏi:\n" + question_text + "\n\n"
|
| 299 |
+
"Lựa chọn:\n" + "\n".join([f"{k}: {v}" for k, v in options.items()]) + "\n\n"
|
| 300 |
+
"Đáp án:\n" + correct_text + "\n\n"
|
| 301 |
+
"Context:\n" + context_text + "\n\n"
|
| 302 |
+
"Hãy trả lời như yêu cầu."
|
| 303 |
+
)
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
raw = _post_chat([system, user], model=self.hf_model, temperature=model_verification_temperature)
|
| 307 |
+
|
| 308 |
+
# parse JSON object in response
|
| 309 |
+
try:
|
| 310 |
+
parsed = _safe_extract_json(raw)
|
| 311 |
+
except Exception as e:
|
| 312 |
+
return {"error": f"Model verification failed to return JSON: {e}", "raw": raw}
|
| 313 |
+
return parsed
|
| 314 |
+
|
| 315 |
+
# iterate MCQs
|
| 316 |
+
for qid, item in mcqs.items():
|
| 317 |
+
q_text = item.get("câu hỏi", "").strip()
|
| 318 |
+
options = item.get("lựa chọn", {})
|
| 319 |
+
correct_text = item.get("đáp án", "").strip()
|
| 320 |
+
|
| 321 |
+
# form a short declarative statement to embed: "Question: ... Answer: <correct>"
|
| 322 |
+
statement = f"{q_text} Answer: {correct_text}"
|
| 323 |
+
|
| 324 |
+
retrieved = semantic_search(statement, k=top_k)
|
| 325 |
+
evidence_list = []
|
| 326 |
+
max_sim = 0.0
|
| 327 |
+
for idx, score in retrieved:
|
| 328 |
+
if score >= evidence_score_cutoff:
|
| 329 |
+
evidence_list.append({
|
| 330 |
+
"idx": idx,
|
| 331 |
+
"page": self.metadata[idx].get("page", None),
|
| 332 |
+
"score": float(score),
|
| 333 |
+
"text": (self.texts[idx][:1000] + ("..." if len(self.texts[idx]) > 1000 else "")),
|
| 334 |
+
})
|
| 335 |
+
|
| 336 |
+
if score > max_sim:
|
| 337 |
+
max_sim = float(score)
|
| 338 |
+
|
| 339 |
+
supported_by_embeddings = max_sim >= similarity_threshold
|
| 340 |
+
|
| 341 |
+
model_verdict = None
|
| 342 |
+
if use_model_verification:
|
| 343 |
+
# build a context string from top retrieved chunks (regardless of cutoff)
|
| 344 |
+
context_parts = []
|
| 345 |
+
for ridx, sc in retrieved:
|
| 346 |
+
md = self.metadata[ridx]
|
| 347 |
+
context_parts.append(f"[page {md.get('page')}] {self.texts[ridx]}")
|
| 348 |
+
context_text = "\n\n".join(context_parts)
|
| 349 |
+
|
| 350 |
+
try:
|
| 351 |
+
parsed = _verify_with_model(q_text, options, correct_text, context_text)
|
| 352 |
+
model_verdict = parsed
|
| 353 |
+
except Exception as e:
|
| 354 |
+
model_verdict = {"error": f"verification exception: {e}"}
|
| 355 |
+
|
| 356 |
+
report[qid] = {
|
| 357 |
+
"supported_by_embeddings": bool(supported_by_embeddings),
|
| 358 |
+
"max_similarity": float(max_sim),
|
| 359 |
+
"evidence": evidence_list,
|
| 360 |
+
"model_verdict": model_verdict,
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
return report
|
| 364 |
+
|
| 365 |
+
def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False):
|
| 366 |
+
if not _HAS_QDRANT:
|
| 367 |
+
raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
|
| 368 |
+
self.qdrant_url = url
|
| 369 |
+
self.qdrant_api_key = api_key
|
| 370 |
+
self.qdrant_prefer_grpc = prefer_grpc
|
| 371 |
+
# Create client
|
| 372 |
+
self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
|
| 373 |
+
|
| 374 |
+
def _ensure_collection(self, collection_name: str):
|
| 375 |
+
if self.qdrant is None:
|
| 376 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 377 |
+
try:
|
| 378 |
+
# get_collection will raise if not present
|
| 379 |
+
_ = self.qdrant.get_collection(collection_name)
|
| 380 |
+
except Exception:
|
| 381 |
+
# create collection with vector size = self.dim
|
| 382 |
+
vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
|
| 383 |
+
self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
|
| 384 |
+
# recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
|
| 385 |
+
|
| 386 |
+
def save_pdf_to_qdrant(
|
| 387 |
+
self,
|
| 388 |
+
pdf_path: str,
|
| 389 |
+
filename: str,
|
| 390 |
+
collection: str,
|
| 391 |
+
max_chars: int = 1200,
|
| 392 |
+
batch_size: int = 64,
|
| 393 |
+
overwrite: bool = False,
|
| 394 |
+
):
|
| 395 |
+
if self.qdrant is None:
|
| 396 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 397 |
+
|
| 398 |
+
# extract pages and chunks (re-using your existing helpers)
|
| 399 |
+
pages = self.extract_pages(pdf_path)
|
| 400 |
+
all_chunks = []
|
| 401 |
+
all_meta = []
|
| 402 |
+
for p_idx, page_text in enumerate(pages, start=1):
|
| 403 |
+
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
|
| 404 |
+
for cid, ch in enumerate(chunks, start=1):
|
| 405 |
+
all_chunks.append(ch)
|
| 406 |
+
all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
|
| 407 |
+
|
| 408 |
+
if not all_chunks:
|
| 409 |
+
raise RuntimeError("No text extracted from PDF.")
|
| 410 |
+
|
| 411 |
+
# ensure collection exists
|
| 412 |
+
self._ensure_collection(collection)
|
| 413 |
+
|
| 414 |
+
# optional: delete previous points for this filename if overwrite
|
| 415 |
+
if overwrite:
|
| 416 |
+
# delete by filter: filename == filename
|
| 417 |
+
flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
|
| 418 |
+
try:
|
| 419 |
+
# qdrant-client delete uses delete(
|
| 420 |
+
self.qdrant.delete(collection_name=collection, filter=flt)
|
| 421 |
+
except Exception:
|
| 422 |
+
# ignore if deletion fails
|
| 423 |
+
pass
|
| 424 |
+
|
| 425 |
+
# compute embeddings in batches
|
| 426 |
+
embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
|
| 427 |
+
embeddings = embeddings.astype("float32")
|
| 428 |
+
|
| 429 |
+
# prepare points
|
| 430 |
+
points = []
|
| 431 |
+
for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
|
| 432 |
+
pid = str(uuid4())
|
| 433 |
+
source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
|
| 434 |
+
payload = {
|
| 435 |
+
"filename": filename,
|
| 436 |
+
"page": md["page"],
|
| 437 |
+
"chunk_id": md["chunk_id"],
|
| 438 |
+
"length": md["length"],
|
| 439 |
+
"text": txt,
|
| 440 |
+
"source_id": source_id,
|
| 441 |
+
}
|
| 442 |
+
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
|
| 443 |
+
|
| 444 |
+
# upsert in batches
|
| 445 |
+
if len(points) >= batch_size:
|
| 446 |
+
self.qdrant.upsert(collection_name=collection, points=points)
|
| 447 |
+
points = []
|
| 448 |
+
|
| 449 |
+
# upsert remaining
|
| 450 |
+
if points:
|
| 451 |
+
self.qdrant.upsert(collection_name=collection, points=points)
|
| 452 |
+
|
| 453 |
+
try:
|
| 454 |
+
self.qdrant.create_payload_index(
|
| 455 |
+
collection_name=collection,
|
| 456 |
+
field_name="filename",
|
| 457 |
+
field_schema=rest.PayloadSchemaType.KEYWORD
|
| 458 |
+
)
|
| 459 |
+
except Exception as e:
|
| 460 |
+
print(f"Index creation skipped or failed: {e}")
|
| 461 |
+
|
| 462 |
+
return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
|
| 463 |
+
|
| 464 |
+
def list_files_in_collection(
|
| 465 |
+
self,
|
| 466 |
+
collection: str,
|
| 467 |
+
payload_field: str = "filename",
|
| 468 |
+
batch_size: int = 500,
|
| 469 |
+
) -> List[str]:
|
| 470 |
+
if self.qdrant is None:
|
| 471 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 472 |
+
|
| 473 |
+
# ensure collection exists
|
| 474 |
+
try:
|
| 475 |
+
if not self.qdrant.collection_exists(collection):
|
| 476 |
+
raise RuntimeError(f"Collection '{collection}' does not exist.")
|
| 477 |
+
except Exception:
|
| 478 |
+
# collection_exists may raise if server unreachable
|
| 479 |
+
raise
|
| 480 |
+
|
| 481 |
+
filenames = set()
|
| 482 |
+
offset = None
|
| 483 |
+
|
| 484 |
+
while True:
|
| 485 |
+
# scroll returns (points, next_offset)
|
| 486 |
+
pts, next_offset = self.qdrant.scroll(
|
| 487 |
+
collection_name=collection,
|
| 488 |
+
limit=batch_size,
|
| 489 |
+
offset=offset,
|
| 490 |
+
with_payload=[payload_field],
|
| 491 |
+
with_vectors=False,
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
if not pts:
|
| 495 |
+
break
|
| 496 |
+
|
| 497 |
+
for p in pts:
|
| 498 |
+
# p may be a dict-like or an object with .payload
|
| 499 |
+
payload = None
|
| 500 |
+
if hasattr(p, "payload"):
|
| 501 |
+
payload = p.payload
|
| 502 |
+
elif isinstance(p, dict):
|
| 503 |
+
# older/newer variants might use nested structures: try common keys
|
| 504 |
+
payload = p.get("payload") or p.get("payload", None) or p
|
| 505 |
+
else:
|
| 506 |
+
# best-effort fallback: convert to dict if possible
|
| 507 |
+
try:
|
| 508 |
+
payload = dict(p)
|
| 509 |
+
except Exception:
|
| 510 |
+
payload = None
|
| 511 |
+
|
| 512 |
+
if not payload:
|
| 513 |
+
continue
|
| 514 |
+
|
| 515 |
+
# extract candidate value(s)
|
| 516 |
+
val = None
|
| 517 |
+
if isinstance(payload, dict):
|
| 518 |
+
val = payload.get(payload_field)
|
| 519 |
+
else:
|
| 520 |
+
# Some payload representations store fields differently; try attribute access
|
| 521 |
+
val = getattr(payload, payload_field, None)
|
| 522 |
+
|
| 523 |
+
# If value is list-like, iterate, else add single
|
| 524 |
+
if isinstance(val, (list, tuple, set)):
|
| 525 |
+
for v in val:
|
| 526 |
+
if v is not None:
|
| 527 |
+
filenames.add(str(v))
|
| 528 |
+
elif val is not None:
|
| 529 |
+
filenames.add(str(val))
|
| 530 |
+
|
| 531 |
+
# stop if no more pages
|
| 532 |
+
if not next_offset:
|
| 533 |
+
break
|
| 534 |
+
offset = next_offset
|
| 535 |
+
|
| 536 |
+
return sorted(filenames)
|
| 537 |
+
|
| 538 |
+
def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
|
| 539 |
+
if self.qdrant is None:
|
| 540 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 541 |
+
|
| 542 |
+
results = []
|
| 543 |
+
offset = None
|
| 544 |
+
while True:
|
| 545 |
+
# scroll returns (points, next_offset)
|
| 546 |
+
points, next_offset = self.qdrant.scroll(
|
| 547 |
+
collection_name=collection,
|
| 548 |
+
scroll_filter=Filter(
|
| 549 |
+
must=[
|
| 550 |
+
FieldCondition(key="filename", match=MatchValue(value=filename))
|
| 551 |
+
]
|
| 552 |
+
),
|
| 553 |
+
limit=batch,
|
| 554 |
+
offset=offset,
|
| 555 |
+
with_payload=True,
|
| 556 |
+
with_vectors=False,
|
| 557 |
+
)
|
| 558 |
+
# points are objects (Record / ScoredPoint-like); get id and payload
|
| 559 |
+
for p in points:
|
| 560 |
+
# p.payload is a dict, p.id is point id
|
| 561 |
+
results.append({"point_id": p.id, "payload": p.payload})
|
| 562 |
+
if not next_offset:
|
| 563 |
+
break
|
| 564 |
+
offset = next_offset
|
| 565 |
+
return results
|
| 566 |
+
|
| 567 |
+
def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
|
| 568 |
+
if self.qdrant is None:
|
| 569 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 570 |
+
|
| 571 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
|
| 572 |
+
q_filter = None
|
| 573 |
+
if filename:
|
| 574 |
+
q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
|
| 575 |
+
|
| 576 |
+
search_res = self.qdrant.search(
|
| 577 |
+
collection_name=collection,
|
| 578 |
+
query_vector=q_emb,
|
| 579 |
+
query_filter=q_filter,
|
| 580 |
+
limit=top_k,
|
| 581 |
+
with_payload=True,
|
| 582 |
+
with_vectors=False,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
out = []
|
| 586 |
+
for hit in search_res:
|
| 587 |
+
# hit.payload is the stored payload, hit.score is similarity
|
| 588 |
+
out.append((hit.payload, float(getattr(hit, "score", 0.0))))
|
| 589 |
+
return out
|
| 590 |
+
|
| 591 |
+
def generate_from_qdrant(
|
| 592 |
+
self,
|
| 593 |
+
filename: str,
|
| 594 |
+
collection: str,
|
| 595 |
+
n_questions: int = 10,
|
| 596 |
+
mode: str = "rag", # 'per_chunk' or 'rag'
|
| 597 |
+
questions_per_chunk: int = 3, # used for 'per_chunk'
|
| 598 |
+
top_k: int = 3, # retrieval size used in RAG
|
| 599 |
+
temperature: float = 0.2,
|
| 600 |
+
) -> Dict[str, Any]:
|
| 601 |
+
if self.qdrant is None:
|
| 602 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 603 |
+
|
| 604 |
+
# get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
|
| 605 |
+
file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
|
| 606 |
+
if not file_points:
|
| 607 |
+
raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
|
| 608 |
+
|
| 609 |
+
# create a local list of texts & metadata for sampling
|
| 610 |
+
texts = []
|
| 611 |
+
metas = []
|
| 612 |
+
for p in file_points:
|
| 613 |
+
payload = p.get("payload", {})
|
| 614 |
+
text = payload.get("text", "")
|
| 615 |
+
texts.append(text)
|
| 616 |
+
metas.append(payload)
|
| 617 |
+
|
| 618 |
+
self.texts = texts
|
| 619 |
+
self.metadata = metas
|
| 620 |
+
embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
|
| 621 |
+
if embeddings is None or len(embeddings) == 0:
|
| 622 |
+
self.embeddings = None
|
| 623 |
+
self.index = None
|
| 624 |
+
else:
|
| 625 |
+
self.embeddings = embeddings.astype("float32")
|
| 626 |
+
|
| 627 |
+
# update dim in case embedder changed unexpectedly
|
| 628 |
+
self.dim = int(self.embeddings.shape[1])
|
| 629 |
+
|
| 630 |
+
# build index
|
| 631 |
+
self._build_faiss_index()
|
| 632 |
+
|
| 633 |
+
output = {}
|
| 634 |
+
qcount = 0
|
| 635 |
+
|
| 636 |
+
if mode == "per_chunk":
|
| 637 |
+
# iterate all chunks (in payload order) and request questions_per_chunk from each
|
| 638 |
+
for i, txt in enumerate(texts):
|
| 639 |
+
if not txt.strip():
|
| 640 |
+
continue
|
| 641 |
+
to_gen = questions_per_chunk
|
| 642 |
+
try:
|
| 643 |
+
mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.hf_model, temperature=temperature)
|
| 644 |
+
except Exception as e:
|
| 645 |
+
print(f"Generator failed on chunk (index {i}): {e}")
|
| 646 |
+
continue
|
| 647 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 648 |
+
qcount += 1
|
| 649 |
+
output[str(qcount)] = mcq_block[item]
|
| 650 |
+
if qcount >= n_questions:
|
| 651 |
+
return output
|
| 652 |
+
return output
|
| 653 |
+
|
| 654 |
+
elif mode == "rag":
|
| 655 |
+
attempts = 0
|
| 656 |
+
max_attempts = n_questions * 4
|
| 657 |
+
while qcount < n_questions and attempts < max_attempts:
|
| 658 |
+
attempts += 1
|
| 659 |
+
# sample a seed sentence from a random chunk of this file
|
| 660 |
+
seed_idx = random.randrange(len(texts))
|
| 661 |
+
chunk = texts[seed_idx]
|
| 662 |
+
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 663 |
+
seed_sent = None
|
| 664 |
+
for s in sents:
|
| 665 |
+
if len(s.strip()) > 20:
|
| 666 |
+
seed_sent = s
|
| 667 |
+
break
|
| 668 |
+
if not seed_sent:
|
| 669 |
+
seed_sent = chunk[:200]
|
| 670 |
+
query = f"Create questions about: {seed_sent}"
|
| 671 |
+
|
| 672 |
+
# retrieve top_k chunks from the same file (restricted by filename filter)
|
| 673 |
+
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
|
| 674 |
+
context_parts = []
|
| 675 |
+
for payload, score in retrieved:
|
| 676 |
+
# payload should contain page & chunk_id and text
|
| 677 |
+
page = payload.get("page", "?")
|
| 678 |
+
ctxt = payload.get("text", "")
|
| 679 |
+
context_parts.append(f"[page {page}] {ctxt}")
|
| 680 |
+
context = "\n\n".join(context_parts)
|
| 681 |
+
|
| 682 |
+
try:
|
| 683 |
+
mcq_block = generate_mcqs_from_text(context, n=1, model=self.hf_model, temperature=temperature)
|
| 684 |
+
except Exception as e:
|
| 685 |
+
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 686 |
+
continue
|
| 687 |
+
|
| 688 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 689 |
+
qcount += 1
|
| 690 |
+
output[str(qcount)] = mcq_block[item]
|
| 691 |
+
if qcount >= n_questions:
|
| 692 |
+
return output
|
| 693 |
+
return output
|
| 694 |
+
else:
|
| 695 |
+
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
app/output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
app/software_report_template.md
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Software Report: RAG-based MCQ Generation System
|
| 2 |
+
|
| 3 |
+
## 1. Overview / Abstract
|
| 4 |
+
The project provides an API service that ingests a PDF document and automatically generates multiple–choice questions (MCQs) using a Retrieval-Augmented Generation (RAG) pipeline. It exposes a FastAPI endpoint (`/generate`) that orchestrates: PDF text extraction → chunking → embedding + indexing → (mode-dependent) context selection → MCQ generation via an LLM (Together AI chat completion) → optional semantic + model-based validation.
|
| 5 |
+
|
| 6 |
+
Core components:
|
| 7 |
+
- Controller (FastAPI endpoints) – handles HTTP, file upload, response shaping.
|
| 8 |
+
- Use Case (RAGMCQ class) – encapsulates business logic: indexing, retrieval, generation, validation.
|
| 9 |
+
- Repositories / Data Stores – implicit: in‑memory lists of chunks, embeddings, optional FAISS index.
|
| 10 |
+
|
| 11 |
+
## 2. High-Level Workflow Diagram
|
| 12 |
+
### Mermaid Activity Diagram
|
| 13 |
+
```mermaid
|
| 14 |
+
flowchart LR
|
| 15 |
+
A[Client Uploads PDF -> /generate] --> B{Mode?}
|
| 16 |
+
B -->|rag| R1[Extract & Chunk PDF]
|
| 17 |
+
B -->|per_page| R1
|
| 18 |
+
R1 --> R2[SentenceTransformer Embeddings]
|
| 19 |
+
R2 --> R3{FAISS Available?}
|
| 20 |
+
R3 -->|Yes| R4[Build FAISS Index]
|
| 21 |
+
R3 -->|No| R5[Normalize Embeddings (NumPy)]
|
| 22 |
+
R4 --> R6[Question Generation Loop]
|
| 23 |
+
R5 --> R6
|
| 24 |
+
R6 -->|rag: sample queries + retrieve top-k| R7[Assemble Context]
|
| 25 |
+
R6 -->|per_page: iterate chunks| R7
|
| 26 |
+
R7 --> G1[Prompt LLM (JSON MCQs)]
|
| 27 |
+
G1 --> P1[Parse & Validate JSON shape]
|
| 28 |
+
P1 --> C{Need more?}
|
| 29 |
+
C -->|Yes| R6
|
| 30 |
+
C -->|No| V{Validation requested?}
|
| 31 |
+
V -->|Yes| V1[Semantic Evidence Search + (Optional) Model Verification]
|
| 32 |
+
V -->|No| OUT[Return MCQs]
|
| 33 |
+
V1 --> OUT
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
### Alternative PlantUML Activity (Optional)
|
| 37 |
+
```plantuml
|
| 38 |
+
@startuml
|
| 39 |
+
start
|
| 40 |
+
:Upload PDF (multipart form);
|
| 41 |
+
:Select params (mode, n_questions,...);
|
| 42 |
+
:Extract pages via pdfplumber;
|
| 43 |
+
:Chunk text (sentence pack <= max_chars);
|
| 44 |
+
:Embed chunks (SentenceTransformer);
|
| 45 |
+
if (FAISS installed?) then (yes)
|
| 46 |
+
:Build FAISS IndexFlatIP + L2 normalize;
|
| 47 |
+
else (no)
|
| 48 |
+
:Keep normalized NumPy embeddings;
|
| 49 |
+
endif
|
| 50 |
+
repeat
|
| 51 |
+
if (mode == per_page) then (per_page)
|
| 52 |
+
:Take next chunk;
|
| 53 |
+
else (rag)
|
| 54 |
+
:Sample seed sentence;
|
| 55 |
+
:Encode query & retrieve top-k chunks;
|
| 56 |
+
endif
|
| 57 |
+
:Assemble context;
|
| 58 |
+
:Call Together AI chat completion (prompt -> JSON);
|
| 59 |
+
:Parse JSON + accumulate MCQs;
|
| 60 |
+
repeat while (Need more questions?) is (yes)
|
| 61 |
+
end repeat
|
| 62 |
+
if (validate?) then (yes)
|
| 63 |
+
:For each Q -> build statement;
|
| 64 |
+
:Similarity search top_k evidence;
|
| 65 |
+
if (Insufficient sim & model verify on) then (yes)
|
| 66 |
+
:Call model for verification JSON;
|
| 67 |
+
endif
|
| 68 |
+
:Build validation report;
|
| 69 |
+
endif
|
| 70 |
+
:Return response JSON;
|
| 71 |
+
stop
|
| 72 |
+
@enduml
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## 3. Repository–Controller–Use Case Abstraction
|
| 76 |
+
| Layer | Responsibility | In This Project |
|
| 77 |
+
|-------|---------------|-----------------|
|
| 78 |
+
| Controller | HTTP I/O, request validation, mapping domain results to API schema | `app.py` endpoints (`/health`, `/generate`) |
|
| 79 |
+
| Use Case | Orchestrates domain flow, independent of HTTP details | `RAGMCQ` methods: `build_index_from_pdf`, `generate_from_pdf`, `validate_mcqs` |
|
| 80 |
+
| Repository (implicit) | Data persistence / retrieval | In-memory: `texts`, `metadata`, `embeddings`, `FAISS index` (no external DB) |
|
| 81 |
+
|
| 82 |
+
Data Flow (simplified):
|
| 83 |
+
Client → Controller(`/generate`) → UseCase(`generate_from_pdf`) → (Extract + Chunk + Embed + Index + Retrieve + Generate) → Controller (normalize/optional validation) → Response
|
| 84 |
+
|
| 85 |
+
## 4. Detailed Pipeline Explanation
|
| 86 |
+
### 4.1 PDF Text Extraction & Chunking
|
| 87 |
+
- File saved to a temp path, then `pdfplumber` loads each page.
|
| 88 |
+
- `extract_pages()` returns list of raw page strings.
|
| 89 |
+
- `chunk_text()` packs sentences (regex split on punctuation boundaries) into segments up to `max_chars` (default 1200). If a sentence overflows, the existing chunk is flushed. Residual oversize chunks are hard-split.
|
| 90 |
+
- Metadata collected: page number, chunk id, length.
|
| 91 |
+
|
| 92 |
+
### 4.2 Embedding Generation
|
| 93 |
+
- Model: `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` loaded via `SentenceTransformer`.
|
| 94 |
+
- Batched encoding of all chunks → NumPy array (float32).
|
| 95 |
+
- If FAISS installed: L2 normalize embeddings, create `IndexFlatIP` (inner product ~ cosine after normalization), add embeddings.
|
| 96 |
+
- Else: manually store normalized embeddings for brute-force cosine similarity with matrix multiply.
|
| 97 |
+
|
| 98 |
+
### 4.3 Retrieval Strategy
|
| 99 |
+
Two modes:
|
| 100 |
+
1. `per_page`: Sequentially process each chunk; each call to the LLM asks for `questions_per_page` new MCQs until target `n_questions` reached.
|
| 101 |
+
2. `rag`: Loop builds a synthetic query by sampling a random chunk and a sentence. Retrieval:
|
| 102 |
+
- Encode query → similarity search (FAISS or NumPy).
|
| 103 |
+
- Take top-k chunk texts; join them with page tags as context.
|
| 104 |
+
- Request 1 question per iteration (promotes diversity). Up to `max_attempts = n_questions * 4`.
|
| 105 |
+
|
| 106 |
+
Similarity Metric: Inner product on normalized vectors (equivalent to cosine). Sorting by descending similarity.
|
| 107 |
+
|
| 108 |
+
### 4.4 Question Generation Prompt Template
|
| 109 |
+
Implemented in `generate_mcqs_from_text` (utils):
|
| 110 |
+
- System message (Vietnamese) forcing strict JSON schema:
|
| 111 |
+
```json
|
| 112 |
+
{
|
| 113 |
+
"1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},
|
| 114 |
+
"2": { ... }
|
| 115 |
+
}
|
| 116 |
+
```
|
| 117 |
+
- Constraints: exactly `n` entries; answer must be full text identical to one option; no explanations.
|
| 118 |
+
- User message: instructs generation from provided source text only.
|
| 119 |
+
- Post-processing: Regex extracts first JSON object; attempts `json.loads`; fallback removes trailing commas.
|
| 120 |
+
|
| 121 |
+
### 4.5 Validation (Optional)
|
| 122 |
+
For each MCQ (after normalization in controller):
|
| 123 |
+
1. Construct statement: `Question + Answer`.
|
| 124 |
+
2. Embed query → retrieve top_k evidence chunks.
|
| 125 |
+
3. Mark `supported_by_embeddings` if max similarity ≥ threshold.
|
| 126 |
+
4. If not supported and model verification enabled, call verification LLM prompt (also JSON-only) to assess `supported`, `confidence`, `evidence`, `reason`.
|
| 127 |
+
|
| 128 |
+
### 4.6 Together AI Integration
|
| 129 |
+
- Endpoint: `https://api.together.xyz/v1/chat/completions`.
|
| 130 |
+
- Authorization header uses `TOGETHER_KEY` environment variable.
|
| 131 |
+
- Payload: `{ model, messages, temperature }`.
|
| 132 |
+
- Response Handling: support both OpenAI-like `choices[0].message.content` and fallback `choices[0].text`.
|
| 133 |
+
|
| 134 |
+
## 5. API Endpoints
|
| 135 |
+
### 5.1 Health Check
|
| 136 |
+
GET `/health`
|
| 137 |
+
Response:
|
| 138 |
+
```json
|
| 139 |
+
{ "status": "ok", "ready": true }
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### 5.2 Generate MCQs
|
| 143 |
+
POST `/generate` (multipart/form-data)
|
| 144 |
+
Fields:
|
| 145 |
+
- `file` (PDF) – required
|
| 146 |
+
- `n_questions` (int, default 10)
|
| 147 |
+
- `mode` ("rag" | "per_page", default "rag")
|
| 148 |
+
- `questions_per_page` (int, default 3) – used only in per_page mode
|
| 149 |
+
- `top_k` (int, default 3) – retrieval depth (rag & validation)
|
| 150 |
+
- `temperature` (float, default 0.2)
|
| 151 |
+
- `validate` (bool, default false)
|
| 152 |
+
- `debug` (bool) – if truthy writes `output.json` locally
|
| 153 |
+
|
| 154 |
+
Example Request (curl, PowerShell style quoting simplified):
|
| 155 |
+
```bash
|
| 156 |
+
curl -X POST http://localhost:8000/generate ^
|
| 157 |
+
-F "file=@sample.pdf" ^
|
| 158 |
+
-F "n_questions=5" ^
|
| 159 |
+
-F "mode=rag" ^
|
| 160 |
+
-F "top_k=3" ^
|
| 161 |
+
-F "validate=true"
|
| 162 |
+
```
|
| 163 |
+
|
| 164 |
+
Success Response (validation on, abbreviated):
|
| 165 |
+
```json
|
| 166 |
+
{
|
| 167 |
+
"mcqs": {
|
| 168 |
+
"1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án": "..."},
|
| 169 |
+
"2": { "câu hỏi": "...", "lựa chọn": { ... }, "đáp án": "..." }
|
| 170 |
+
},
|
| 171 |
+
"validation": {
|
| 172 |
+
"1": {
|
| 173 |
+
"supported_by_embeddings": true,
|
| 174 |
+
"max_similarity": 0.83,
|
| 175 |
+
"evidence": [ { "page": 2, "score": 0.81, "text": "Excerpt..." } ],
|
| 176 |
+
"model_verdict": null
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
Error Examples:
|
| 183 |
+
- 400: non-PDF upload
|
| 184 |
+
- 500: generation pipeline error (e.g., empty PDF or model failure)
|
| 185 |
+
- 503: service not initialized
|
| 186 |
+
|
| 187 |
+
## 6. Data Structures & Types (Conceptual)
|
| 188 |
+
- Chunk: `{ text: str, page: int, chunk_id: int, length: int }`
|
| 189 |
+
- MCQ (generated raw): `{ "câu hỏi": str, "lựa chọn": {"a": str, ...}, "đáp án": str }`
|
| 190 |
+
- Normalized MCQ (API shaping): `{ mcq: str, options: { .. }, correct: str }`
|
| 191 |
+
- Validation Entry: `{ supported_by_embeddings: bool, max_similarity: float, evidence: [ {page, score, text}... ], model_verdict?: {...} }`
|
| 192 |
+
|
| 193 |
+
## 7. Configuration Points
|
| 194 |
+
| Parameter | Location | Purpose |
|
| 195 |
+
|-----------|----------|---------|
|
| 196 |
+
| `embedder_model` | `RAGMCQ.__init__` | Pretrained SentenceTransformer model name |
|
| 197 |
+
| `hf_model` | `RAGMCQ.__init__` | LLM model name for generation/verification |
|
| 198 |
+
| `top_k` | API form field & internal methods | Retrieval depth |
|
| 199 |
+
| `temperature` | API form field | Creativity vs determinism |
|
| 200 |
+
| `questions_per_page` | API form field | Batch size per chunk in per_page mode |
|
| 201 |
+
|
| 202 |
+
## 8. Simple Code Improvements (Quick Wins)
|
| 203 |
+
Below are low-risk refactors to make the code cleaner and more maintainable:
|
| 204 |
+
|
| 205 |
+
1. Environment Variable Safety:
|
| 206 |
+
```python
|
| 207 |
+
def _require_env(name: str) -> str:
|
| 208 |
+
val = os.getenv(name)
|
| 209 |
+
if not val:
|
| 210 |
+
raise RuntimeError(f"Missing required environment variable: {name}")
|
| 211 |
+
return val
|
| 212 |
+
TOGETHER_KEY = _require_env("TOGETHER_KEY")
|
| 213 |
+
```
|
| 214 |
+
2. Remove Unused Constant: `API_URL` in `utils.py` is unused (can delete to avoid confusion).
|
| 215 |
+
3. Unify Header Construction: Replace separate `HEADERS` / `TOGETHER_HEADERS` with a single function `auth_headers(provider)` that returns the correct dict.
|
| 216 |
+
4. Add Dataclass for MCQ:
|
| 217 |
+
```python
|
| 218 |
+
from dataclasses import dataclass
|
| 219 |
+
@dataclass
|
| 220 |
+
class MCQ: question: str; options: Dict[str,str]; answer: str
|
| 221 |
+
```
|
| 222 |
+
Helps type clarity in validation.
|
| 223 |
+
5. Extract Prompt Templates: Store system/user template strings as module-level constants to avoid duplication and ease future edits.
|
| 224 |
+
6. Fail-Fast on Empty PDF: Early check after extraction to return a user-friendly error message rather than a generic 500 later.
|
| 225 |
+
7. Replace Random Query Sampling Magic Numbers: Expose `max_attempts_factor` as a parameter (currently `n_questions * 4`).
|
| 226 |
+
8. Vector Normalization Consistency: Always keep an unnormalized copy if future scoring types are needed; currently normalization overwrites original when FAISS absent.
|
| 227 |
+
9. Logging Standardization: Replace scattered `print()` with Python `logging` module (configurable levels; avoids polluting stdout in production).
|
| 228 |
+
10. Validation Normalization: Move `_normalize_mcqs` from `app.py` into `RAGMCQ` (keeps domain logic together; controller stays thin).
|
| 229 |
+
11. Error Message Specificity: On generation failure wrap exceptions with context (page/chunk), but avoid leaking internal stack to clients; log full internally.
|
| 230 |
+
12. Dependency Pinning: Specify versions in `requirements.txt` for reproducibility (e.g., `sentence-transformers==2.2.2`).
|
| 231 |
+
13. Add `/models` Endpoint (Optional): Expose available embedder & generation models for UI introspection.
|
| 232 |
+
14. Add Basic Tests: e.g., a test for `chunk_text` (ensures boundaries) and JSON parsing fallback.
|
| 233 |
+
15. Reusable Retrieval: Expose a public `retrieve(query, top_k)` method to support future features (like user-specified queries) without duplicating private logic.
|
| 234 |
+
|
| 235 |
+
## 9. Potential Medium-Term Enhancements
|
| 236 |
+
| Area | Improvement |
|
| 237 |
+
|------|-------------|
|
| 238 |
+
| Prompt Robustness | Add JSON schema validation (e.g., `jsonschema`) & auto-regeneration for malformed outputs |
|
| 239 |
+
| Performance | Embed asynchronously / stream generation if backend supports it |
|
| 240 |
+
| Multi-Provider | Abstract provider strategy for HuggingFace, Together, OpenAI with pluggable client classes |
|
| 241 |
+
| Caching | Cache embeddings per PDF hash to avoid reprocessing identical documents |
|
| 242 |
+
| Analytics | Track generation latency, validation pass rate, average similarity in structured logs |
|
| 243 |
+
| i18n | Parameterize language; currently prompts in Vietnamese only |
|
| 244 |
+
|
| 245 |
+
## 10. Security & Operational Notes
|
| 246 |
+
- Ensure `TOGETHER_KEY` is not committed; rely on environment variables / secret managers.
|
| 247 |
+
- Limit PDF size and number of pages to prevent excessive memory or token usage.
|
| 248 |
+
- Consider sanitizing extracted text (remove personally identifiable info) before sending to LLM if sensitive documents are used.
|
| 249 |
+
- Add request timeout & retry logic for the LLM API (current single call may raise immediately).
|
| 250 |
+
|
| 251 |
+
## 11. Quick Start (Local)
|
| 252 |
+
1. Set API key: `setx TOGETHER_KEY "your_api_key"` (then restart shell).
|
| 253 |
+
2. Install dependencies: `pip install -r requirements.txt`.
|
| 254 |
+
3. Run API: `uvicorn app:app --reload`.
|
| 255 |
+
4. POST a PDF to `/generate`.
|
| 256 |
+
|
| 257 |
+
## 12. Summary
|
| 258 |
+
The system cleanly separates HTTP handling from the core RAG pipeline. Text is chunked at sentence boundaries, embedded, indexed (FAISS if available), and retrieved to assemble focused contexts that guide a JSON-constrained MCQ generation prompt. Optional validation uses embedding similarity and secondary model verification to flag unsupported questions. Suggested refactors improve safety, clarity, extensibility, and readiness for multi-provider expansion.
|
| 259 |
+
|
| 260 |
+
---
|
| 261 |
+
This report delivers architectural insight, workflow diagrams, detailed pipeline mechanics, API contract, and actionable improvement ideas for rapid comprehension and iteration.
|
app/utils.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
import requests
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
#TODO: allow to choose different provider later + dynamic routing when token expired
|
| 8 |
+
API_URL = "https://api.cerebras.ai/v1/chat/completions"
|
| 9 |
+
CEREBRAS_API_KEY = os.environ['CEREBRAS_API_KEY']
|
| 10 |
+
|
| 11 |
+
HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}"}
|
| 12 |
+
JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
|
| 13 |
+
|
| 14 |
+
def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
|
| 15 |
+
payload = {"model": model, "messages": messages, "temperature": temperature}
|
| 16 |
+
resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
|
| 17 |
+
resp.raise_for_status()
|
| 18 |
+
data = resp.json()
|
| 19 |
+
|
| 20 |
+
# handle various shapes
|
| 21 |
+
if "choices" in data and len(data["choices"]) > 0:
|
| 22 |
+
# prefer message.content
|
| 23 |
+
ch = data["choices"][0]
|
| 24 |
+
|
| 25 |
+
if isinstance(ch, dict) and "message" in ch and "content" in ch["message"]:
|
| 26 |
+
return ch["message"]["content"]
|
| 27 |
+
|
| 28 |
+
if "text" in ch:
|
| 29 |
+
return ch["text"]
|
| 30 |
+
|
| 31 |
+
# final fallback
|
| 32 |
+
raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
|
| 33 |
+
|
| 34 |
+
def _safe_extract_json(text: str) -> dict:
|
| 35 |
+
# remove triple backticks
|
| 36 |
+
text = re.sub(r"```(?:json)?\n?", "", text)
|
| 37 |
+
m = JSON_OBJ_RE.search(text)
|
| 38 |
+
|
| 39 |
+
if not m:
|
| 40 |
+
raise ValueError("No JSON object found in model output.")
|
| 41 |
+
js = m.group(1)
|
| 42 |
+
|
| 43 |
+
# try load, fix trailing commas
|
| 44 |
+
try:
|
| 45 |
+
return json.loads(js)
|
| 46 |
+
except json.JSONDecodeError:
|
| 47 |
+
fixed = re.sub(r",\s*([}\]])", r"\1", js)
|
| 48 |
+
return json.loads(fixed)
|
| 49 |
+
|
| 50 |
+
def generate_mcqs_from_text(
|
| 51 |
+
source_text: str,
|
| 52 |
+
n: int = 3,
|
| 53 |
+
model: str = "gpt-oss-120b",
|
| 54 |
+
temperature: float = 0.2,
|
| 55 |
+
) -> Dict[str, Any]:
|
| 56 |
+
system_message = {
|
| 57 |
+
"role": "system",
|
| 58 |
+
"content": (
|
| 59 |
+
"Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
|
| 60 |
+
"Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
|
| 61 |
+
"{\n"
|
| 62 |
+
' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
|
| 63 |
+
' "2": { ... }\n'
|
| 64 |
+
"}\n\n"
|
| 65 |
+
"Lưu ý:\n"
|
| 66 |
+
f"- Tạo đúng {n} mục, đánh YOUR_API_KEYsố từ 1 tới {n}.\n"
|
| 67 |
+
"- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
|
| 68 |
+
"- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
|
| 69 |
+
"- Không kèm giải thích hay trường thêm.\n"
|
| 70 |
+
"- Các phương án sai (distractors) phải hợp lý và không lặp lại."
|
| 71 |
+
)
|
| 72 |
+
}
|
| 73 |
+
user_message = {
|
| 74 |
+
"role": "user",
|
| 75 |
+
"content": (
|
| 76 |
+
f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
|
| 77 |
+
"Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
|
| 78 |
+
f"Nội dung:\n\n{source_text}"
|
| 79 |
+
)
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
|
| 83 |
+
parsed = _safe_extract_json(raw)
|
| 84 |
+
|
| 85 |
+
# validate structure and length
|
| 86 |
+
if not isinstance(parsed, dict) or len(parsed) != n:
|
| 87 |
+
raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
|
| 88 |
+
return parsed
|
generator.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import random
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import List, Tuple, Dict, Any, Optional
|
| 5 |
+
from sentence_transformers import SentenceTransformer
|
| 6 |
+
from uuid import uuid4
|
| 7 |
+
import pymupdf4llm
|
| 8 |
+
import pymupdf as fitz
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
from qdrant_client import QdrantClient
|
| 12 |
+
from qdrant_client.http.models import (
|
| 13 |
+
PointStruct,
|
| 14 |
+
Filter,
|
| 15 |
+
FieldCondition,
|
| 16 |
+
MatchValue,
|
| 17 |
+
Distance,
|
| 18 |
+
VectorParams,
|
| 19 |
+
)
|
| 20 |
+
from qdrant_client.http import models as rest
|
| 21 |
+
_HAS_QDRANT = True
|
| 22 |
+
except Exception:
|
| 23 |
+
_HAS_QDRANT = False
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
import faiss
|
| 27 |
+
_HAS_FAISS = True
|
| 28 |
+
except Exception:
|
| 29 |
+
_HAS_FAISS = False
|
| 30 |
+
|
| 31 |
+
from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json
|
| 32 |
+
|
| 33 |
+
class RAGMCQ:
|
| 34 |
+
def __init__(
|
| 35 |
+
self,
|
| 36 |
+
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
| 37 |
+
hf_model: str = "openai/gpt-oss-120b:cerebras",
|
| 38 |
+
qdrant_url: str = None,
|
| 39 |
+
qdrant_api_key: str = None,
|
| 40 |
+
qdrant_prefer_grpc: bool = False,
|
| 41 |
+
):
|
| 42 |
+
self.embedder = SentenceTransformer(embedder_model)
|
| 43 |
+
self.hf_model = hf_model
|
| 44 |
+
self.embeddings = None # np.array of shape (N, D)
|
| 45 |
+
self.texts = [] # list of chunk texts
|
| 46 |
+
self.metadata = [] # list of dicts (page, chunk_id, char_range)
|
| 47 |
+
self.index = None
|
| 48 |
+
self.dim = self.embedder.get_sentence_embedding_dimension()
|
| 49 |
+
|
| 50 |
+
self.qdrant = None
|
| 51 |
+
self.qdrant_url = qdrant_url
|
| 52 |
+
self.qdrant_api_key = qdrant_api_key
|
| 53 |
+
self.qdrant_prefer_grpc = qdrant_prefer_grpc
|
| 54 |
+
if qdrant_url:
|
| 55 |
+
self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
|
| 56 |
+
|
| 57 |
+
def extract_pages(
|
| 58 |
+
self,
|
| 59 |
+
pdf_path: str,
|
| 60 |
+
*,
|
| 61 |
+
pages: Optional[List[int]] = None,
|
| 62 |
+
ignore_images: bool = False,
|
| 63 |
+
dpi: int = 150
|
| 64 |
+
) -> List[str]:
|
| 65 |
+
doc = fitz.open(pdf_path)
|
| 66 |
+
try:
|
| 67 |
+
# request page-wise output (page_chunks=True -> list[dict] per page)
|
| 68 |
+
page_dicts = pymupdf4llm.to_markdown(
|
| 69 |
+
doc,
|
| 70 |
+
pages=pages,
|
| 71 |
+
ignore_images=ignore_images,
|
| 72 |
+
dpi=dpi,
|
| 73 |
+
page_chunks=True,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
|
| 77 |
+
pages_md: List[str] = []
|
| 78 |
+
for p in page_dicts:
|
| 79 |
+
txt = p.get("text", "") or ""
|
| 80 |
+
pages_md.append(txt.strip())
|
| 81 |
+
|
| 82 |
+
return pages_md
|
| 83 |
+
finally:
|
| 84 |
+
doc.close()
|
| 85 |
+
|
| 86 |
+
def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]:
|
| 87 |
+
text = text.strip()
|
| 88 |
+
if not text:
|
| 89 |
+
return []
|
| 90 |
+
if len(text) <= max_chars:
|
| 91 |
+
return [text]
|
| 92 |
+
|
| 93 |
+
# split by sentence-like boundaries
|
| 94 |
+
sentences = re.split(r'(?<=[\.\?\!])\s+', text)
|
| 95 |
+
chunks = []
|
| 96 |
+
cur = ""
|
| 97 |
+
for s in sentences:
|
| 98 |
+
if len(cur) + len(s) + 1 <= max_chars:
|
| 99 |
+
cur += (" " if cur else "") + s
|
| 100 |
+
else:
|
| 101 |
+
if cur:
|
| 102 |
+
chunks.append(cur)
|
| 103 |
+
cur = (cur[-overlap:] + " " + s) if overlap > 0 else s
|
| 104 |
+
if cur:
|
| 105 |
+
chunks.append(cur)
|
| 106 |
+
|
| 107 |
+
# if still too long, hard-split
|
| 108 |
+
final = []
|
| 109 |
+
for c in chunks:
|
| 110 |
+
if len(c) <= max_chars:
|
| 111 |
+
final.append(c)
|
| 112 |
+
else:
|
| 113 |
+
for i in range(0, len(c), max_chars):
|
| 114 |
+
final.append(c[i:i+max_chars])
|
| 115 |
+
return final
|
| 116 |
+
|
| 117 |
+
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
|
| 118 |
+
pages = self.extract_pages(pdf_path)
|
| 119 |
+
self.texts = []
|
| 120 |
+
self.metadata = []
|
| 121 |
+
|
| 122 |
+
for p_idx, page_text in enumerate(pages, start=1):
|
| 123 |
+
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
|
| 124 |
+
for cid, ch in enumerate(chunks, start=1):
|
| 125 |
+
self.texts.append(ch)
|
| 126 |
+
self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
|
| 127 |
+
|
| 128 |
+
if not self.texts:
|
| 129 |
+
raise RuntimeError("No text extracted from PDF.")
|
| 130 |
+
|
| 131 |
+
# compute embeddings
|
| 132 |
+
emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
|
| 133 |
+
self.embeddings = emb.astype("float32")
|
| 134 |
+
self._build_faiss_index()
|
| 135 |
+
|
| 136 |
+
def _build_faiss_index(self, ef_construction=200, M=32):
|
| 137 |
+
if _HAS_FAISS:
|
| 138 |
+
d = self.embeddings.shape[1]
|
| 139 |
+
index = faiss.IndexHNSWFlat(d, M)
|
| 140 |
+
faiss.normalize_L2(self.embeddings)
|
| 141 |
+
index.add(self.embeddings)
|
| 142 |
+
index.hnsw.efConstruction = ef_construction
|
| 143 |
+
self.index = index
|
| 144 |
+
else:
|
| 145 |
+
# store normalized embeddings and use brute-force numpy
|
| 146 |
+
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
|
| 147 |
+
self.embeddings = self.embeddings / norms
|
| 148 |
+
self.index = None
|
| 149 |
+
|
| 150 |
+
def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
|
| 151 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
|
| 152 |
+
|
| 153 |
+
if _HAS_FAISS:
|
| 154 |
+
faiss.normalize_L2(q_emb)
|
| 155 |
+
D_list, I_list = self.index.search(q_emb, top_k)
|
| 156 |
+
# D are inner products; return list of (idx, score)
|
| 157 |
+
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
|
| 158 |
+
else:
|
| 159 |
+
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
|
| 160 |
+
sims = (self.embeddings @ qn.T).squeeze(axis=1)
|
| 161 |
+
idxs = np.argsort(-sims)[:top_k]
|
| 162 |
+
return [(int(i), float(sims[i])) for i in idxs]
|
| 163 |
+
|
| 164 |
+
def generate_from_pdf(
|
| 165 |
+
self,
|
| 166 |
+
pdf_path: str,
|
| 167 |
+
n_questions: int = 10,
|
| 168 |
+
mode: str = "rag", # per_page or rag
|
| 169 |
+
questions_per_page: int = 3, # for per_page mode
|
| 170 |
+
top_k: int = 3, # chunks to retrieve for each question in rag mode
|
| 171 |
+
temperature: float = 0.2,
|
| 172 |
+
) -> Dict[str, Any]:
|
| 173 |
+
# build index
|
| 174 |
+
self.build_index_from_pdf(pdf_path)
|
| 175 |
+
|
| 176 |
+
output: Dict[str, Any] = {}
|
| 177 |
+
qcount = 0
|
| 178 |
+
|
| 179 |
+
if mode == "per_page":
|
| 180 |
+
# iterate pages -> chunks
|
| 181 |
+
for idx, meta in enumerate(self.metadata):
|
| 182 |
+
chunk_text = self.texts[idx]
|
| 183 |
+
|
| 184 |
+
if not chunk_text.strip():
|
| 185 |
+
continue
|
| 186 |
+
to_gen = questions_per_page
|
| 187 |
+
|
| 188 |
+
# ask generator
|
| 189 |
+
try:
|
| 190 |
+
mcq_block = generate_mcqs_from_text(
|
| 191 |
+
chunk_text, n=to_gen, model=self.hf_model, temperature=temperature
|
| 192 |
+
)
|
| 193 |
+
except Exception as e:
|
| 194 |
+
# skip this chunk if generator fails
|
| 195 |
+
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 199 |
+
qcount += 1
|
| 200 |
+
output[str(qcount)] = mcq_block[item]
|
| 201 |
+
if qcount >= n_questions:
|
| 202 |
+
return output
|
| 203 |
+
|
| 204 |
+
return output
|
| 205 |
+
|
| 206 |
+
elif mode == "rag":
|
| 207 |
+
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
|
| 208 |
+
# create queries by sampling chunk text sentences.
|
| 209 |
+
# stop when n_questions reached or max_attempts exceeded.
|
| 210 |
+
attempts = 0
|
| 211 |
+
max_attempts = n_questions * 4
|
| 212 |
+
|
| 213 |
+
while qcount < n_questions and attempts < max_attempts:
|
| 214 |
+
attempts += 1
|
| 215 |
+
# create a seed query: pick a random chunk, pick a sentence from it
|
| 216 |
+
seed_idx = random.randrange(len(self.texts))
|
| 217 |
+
chunk = self.texts[seed_idx]
|
| 218 |
+
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 219 |
+
seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
|
| 220 |
+
query = f"Create questions about: {seed_sent}"
|
| 221 |
+
|
| 222 |
+
# retrieve top_k chunks
|
| 223 |
+
retrieved = self._retrieve(query, top_k=top_k)
|
| 224 |
+
context_parts = []
|
| 225 |
+
for ridx, score in retrieved:
|
| 226 |
+
md = self.metadata[ridx]
|
| 227 |
+
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
|
| 228 |
+
context = "\n\n".join(context_parts)
|
| 229 |
+
|
| 230 |
+
# call generator for 1 question (or small batch) with the retrieved context
|
| 231 |
+
try:
|
| 232 |
+
# request 1 question at a time to keep diversity
|
| 233 |
+
mcq_block = generate_mcqs_from_text(
|
| 234 |
+
context, n=1, model=self.hf_model, temperature=temperature
|
| 235 |
+
)
|
| 236 |
+
except Exception as e:
|
| 237 |
+
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 238 |
+
continue
|
| 239 |
+
|
| 240 |
+
# append result(s)
|
| 241 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 242 |
+
qcount += 1
|
| 243 |
+
output[str(qcount)] = mcq_block[item]
|
| 244 |
+
if qcount >= n_questions:
|
| 245 |
+
return output
|
| 246 |
+
|
| 247 |
+
return output
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError("mode must be 'per_page' or 'rag'.")
|
| 250 |
+
|
| 251 |
+
def validate_mcqs(
|
| 252 |
+
self,
|
| 253 |
+
mcqs: Dict[str, Any],
|
| 254 |
+
top_k: int = 4,
|
| 255 |
+
similarity_threshold: float = 0.5,
|
| 256 |
+
evidence_score_cutoff: float = 0.5,
|
| 257 |
+
use_model_verification: bool = True,
|
| 258 |
+
model_verification_temperature: float = 0.0,
|
| 259 |
+
) -> Dict[str, Any]:
|
| 260 |
+
if self.embeddings is None or not self.texts:
|
| 261 |
+
raise RuntimeError("Index/embeddings not built. Run build_index_from_pdf() first.")
|
| 262 |
+
|
| 263 |
+
report: Dict[str, Any] = {}
|
| 264 |
+
|
| 265 |
+
# helper: semantic similarity search on statement -> returns list of (idx, score)
|
| 266 |
+
def semantic_search(statement: str, k: int = top_k):
|
| 267 |
+
q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
|
| 268 |
+
|
| 269 |
+
if _HAS_FAISS:
|
| 270 |
+
faiss.normalize_L2(q_emb)
|
| 271 |
+
D_list, I_list = self.index.search(q_emb, k)
|
| 272 |
+
# D are inner products; return list of (idx, score)
|
| 273 |
+
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
|
| 274 |
+
else:
|
| 275 |
+
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
|
| 276 |
+
sims = (self.embeddings @ qn.T).squeeze(axis=1)
|
| 277 |
+
idxs = np.argsort(-sims)[:k]
|
| 278 |
+
return [(int(i), float(sims[i])) for i in idxs]
|
| 279 |
+
|
| 280 |
+
# helper: verify with model (strict JSON in response)
|
| 281 |
+
def _verify_with_model(question_text: str, options: Dict[str, str], correct_text: str, context_text: str):
|
| 282 |
+
system = {
|
| 283 |
+
"role": "system",
|
| 284 |
+
"content": (
|
| 285 |
+
"Bạn là một trợ lý đánh giá tính thực chứng của câu hỏi trắc nghiệm dựa trên đoạn văn được cung cấp. "
|
| 286 |
+
"Hãy trả lời DUY NHẤT bằng JSON hợp lệ (không có văn bản khác) theo schema:\n\n"
|
| 287 |
+
"{\n"
|
| 288 |
+
' "supported": true/false, # câu trả lời đúng có được nội dung chứng thực không\n'
|
| 289 |
+
' "confidence": 0.0-1.0, # mức độ tự tin (số)\n'
|
| 290 |
+
' "evidence": "cụm văn bản ngắn làm bằng chứng hoặc trích dẫn",\n'
|
| 291 |
+
' "reason": "ngắn gọn, vì sao supported hoặc không"\n'
|
| 292 |
+
"}\n\n"
|
| 293 |
+
"Luôn dựa chỉ trên nội dung trong trường 'Context' dưới đây. Nếu nội dung không chứa bằng chứng, trả về supported: false."
|
| 294 |
+
)
|
| 295 |
+
}
|
| 296 |
+
user = {
|
| 297 |
+
"role": "user",
|
| 298 |
+
"content": (
|
| 299 |
+
"Câu hỏi:\n" + question_text + "\n\n"
|
| 300 |
+
"Lựa chọn:\n" + "\n".join([f"{k}: {v}" for k, v in options.items()]) + "\n\n"
|
| 301 |
+
"Đáp án:\n" + correct_text + "\n\n"
|
| 302 |
+
"Context:\n" + context_text + "\n\n"
|
| 303 |
+
"Hãy trả lời như yêu cầu."
|
| 304 |
+
)
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
raw = _post_chat([system, user], model=self.hf_model, temperature=model_verification_temperature)
|
| 308 |
+
|
| 309 |
+
# parse JSON object in response
|
| 310 |
+
try:
|
| 311 |
+
parsed = _safe_extract_json(raw)
|
| 312 |
+
except Exception as e:
|
| 313 |
+
return {"error": f"Model verification failed to return JSON: {e}", "raw": raw}
|
| 314 |
+
return parsed
|
| 315 |
+
|
| 316 |
+
# iterate MCQs
|
| 317 |
+
for qid, item in mcqs.items():
|
| 318 |
+
q_text = item.get("câu hỏi", "").strip()
|
| 319 |
+
options = item.get("lựa chọn", {})
|
| 320 |
+
correct_text = item.get("đáp án", "").strip()
|
| 321 |
+
|
| 322 |
+
# form a short declarative statement to embed: "Question: ... Answer: <correct>"
|
| 323 |
+
statement = f"{q_text} Answer: {correct_text}"
|
| 324 |
+
|
| 325 |
+
retrieved = semantic_search(statement, k=top_k)
|
| 326 |
+
evidence_list = []
|
| 327 |
+
max_sim = 0.0
|
| 328 |
+
for idx, score in retrieved:
|
| 329 |
+
if score >= evidence_score_cutoff:
|
| 330 |
+
evidence_list.append({
|
| 331 |
+
"idx": idx,
|
| 332 |
+
"page": self.metadata[idx].get("page", None),
|
| 333 |
+
"score": float(score),
|
| 334 |
+
"text": (self.texts[idx][:1000] + ("..." if len(self.texts[idx]) > 1000 else "")),
|
| 335 |
+
})
|
| 336 |
+
|
| 337 |
+
if score > max_sim:
|
| 338 |
+
max_sim = float(score)
|
| 339 |
+
|
| 340 |
+
supported_by_embeddings = max_sim >= similarity_threshold
|
| 341 |
+
|
| 342 |
+
model_verdict = None
|
| 343 |
+
if use_model_verification:
|
| 344 |
+
# build a context string from top retrieved chunks (regardless of cutoff)
|
| 345 |
+
context_parts = []
|
| 346 |
+
for ridx, sc in retrieved:
|
| 347 |
+
md = self.metadata[ridx]
|
| 348 |
+
context_parts.append(f"[page {md.get('page')}] {self.texts[ridx]}")
|
| 349 |
+
context_text = "\n\n".join(context_parts)
|
| 350 |
+
|
| 351 |
+
try:
|
| 352 |
+
parsed = _verify_with_model(q_text, options, correct_text, context_text)
|
| 353 |
+
model_verdict = parsed
|
| 354 |
+
except Exception as e:
|
| 355 |
+
model_verdict = {"error": f"verification exception: {e}"}
|
| 356 |
+
|
| 357 |
+
report[qid] = {
|
| 358 |
+
"supported_by_embeddings": bool(supported_by_embeddings),
|
| 359 |
+
"max_similarity": float(max_sim),
|
| 360 |
+
"evidence": evidence_list,
|
| 361 |
+
"model_verdict": model_verdict,
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
return report
|
| 365 |
+
|
| 366 |
+
def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False):
|
| 367 |
+
if not _HAS_QDRANT:
|
| 368 |
+
raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
|
| 369 |
+
self.qdrant_url = url
|
| 370 |
+
self.qdrant_api_key = api_key
|
| 371 |
+
self.qdrant_prefer_grpc = prefer_grpc
|
| 372 |
+
# Create client
|
| 373 |
+
self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
|
| 374 |
+
|
| 375 |
+
def _ensure_collection(self, collection_name: str):
|
| 376 |
+
if self.qdrant is None:
|
| 377 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 378 |
+
try:
|
| 379 |
+
# get_collection will raise if not present
|
| 380 |
+
_ = self.qdrant.get_collection(collection_name)
|
| 381 |
+
except Exception:
|
| 382 |
+
# create collection with vector size = self.dim
|
| 383 |
+
vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
|
| 384 |
+
self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
|
| 385 |
+
# recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
|
| 386 |
+
|
| 387 |
+
def save_pdf_to_qdrant(
|
| 388 |
+
self,
|
| 389 |
+
pdf_path: str,
|
| 390 |
+
filename: str,
|
| 391 |
+
collection: str,
|
| 392 |
+
max_chars: int = 1200,
|
| 393 |
+
batch_size: int = 64,
|
| 394 |
+
overwrite: bool = False,
|
| 395 |
+
):
|
| 396 |
+
if self.qdrant is None:
|
| 397 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 398 |
+
|
| 399 |
+
# extract pages and chunks (re-using your existing helpers)
|
| 400 |
+
pages = self.extract_pages(pdf_path)
|
| 401 |
+
all_chunks = []
|
| 402 |
+
all_meta = []
|
| 403 |
+
for p_idx, page_text in enumerate(pages, start=1):
|
| 404 |
+
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
|
| 405 |
+
for cid, ch in enumerate(chunks, start=1):
|
| 406 |
+
all_chunks.append(ch)
|
| 407 |
+
all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
|
| 408 |
+
|
| 409 |
+
if not all_chunks:
|
| 410 |
+
raise RuntimeError("No text extracted from PDF.")
|
| 411 |
+
|
| 412 |
+
# ensure collection exists
|
| 413 |
+
self._ensure_collection(collection)
|
| 414 |
+
|
| 415 |
+
# optional: delete previous points for this filename if overwrite
|
| 416 |
+
if overwrite:
|
| 417 |
+
# delete by filter: filename == filename
|
| 418 |
+
flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
|
| 419 |
+
try:
|
| 420 |
+
# qdrant-client delete uses delete(
|
| 421 |
+
self.qdrant.delete(collection_name=collection, filter=flt)
|
| 422 |
+
except Exception:
|
| 423 |
+
# ignore if deletion fails
|
| 424 |
+
pass
|
| 425 |
+
|
| 426 |
+
# compute embeddings in batches
|
| 427 |
+
embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
|
| 428 |
+
embeddings = embeddings.astype("float32")
|
| 429 |
+
|
| 430 |
+
# prepare points
|
| 431 |
+
points = []
|
| 432 |
+
for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
|
| 433 |
+
pid = str(uuid4())
|
| 434 |
+
source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
|
| 435 |
+
payload = {
|
| 436 |
+
"filename": filename,
|
| 437 |
+
"page": md["page"],
|
| 438 |
+
"chunk_id": md["chunk_id"],
|
| 439 |
+
"length": md["length"],
|
| 440 |
+
"text": txt,
|
| 441 |
+
"source_id": source_id,
|
| 442 |
+
}
|
| 443 |
+
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
|
| 444 |
+
|
| 445 |
+
# upsert in batches
|
| 446 |
+
if len(points) >= batch_size:
|
| 447 |
+
self.qdrant.upsert(collection_name=collection, points=points)
|
| 448 |
+
points = []
|
| 449 |
+
|
| 450 |
+
# upsert remaining
|
| 451 |
+
if points:
|
| 452 |
+
self.qdrant.upsert(collection_name=collection, points=points)
|
| 453 |
+
|
| 454 |
+
try:
|
| 455 |
+
self.qdrant.create_payload_index(
|
| 456 |
+
collection_name=collection,
|
| 457 |
+
field_name="filename",
|
| 458 |
+
field_schema=rest.PayloadSchemaType.KEYWORD
|
| 459 |
+
)
|
| 460 |
+
except Exception as e:
|
| 461 |
+
print(f"Index creation skipped or failed: {e}")
|
| 462 |
+
|
| 463 |
+
return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
|
| 464 |
+
|
| 465 |
+
def list_files_in_collection(
|
| 466 |
+
self,
|
| 467 |
+
collection: str,
|
| 468 |
+
payload_field: str = "filename",
|
| 469 |
+
batch_size: int = 500,
|
| 470 |
+
) -> List[str]:
|
| 471 |
+
if self.qdrant is None:
|
| 472 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 473 |
+
|
| 474 |
+
# ensure collection exists
|
| 475 |
+
try:
|
| 476 |
+
if not self.qdrant.collection_exists(collection):
|
| 477 |
+
raise RuntimeError(f"Collection '{collection}' does not exist.")
|
| 478 |
+
except Exception:
|
| 479 |
+
# collection_exists may raise if server unreachable
|
| 480 |
+
raise
|
| 481 |
+
|
| 482 |
+
filenames = set()
|
| 483 |
+
offset = None
|
| 484 |
+
|
| 485 |
+
while True:
|
| 486 |
+
# scroll returns (points, next_offset)
|
| 487 |
+
pts, next_offset = self.qdrant.scroll(
|
| 488 |
+
collection_name=collection,
|
| 489 |
+
limit=batch_size,
|
| 490 |
+
offset=offset,
|
| 491 |
+
with_payload=[payload_field],
|
| 492 |
+
with_vectors=False,
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if not pts:
|
| 496 |
+
break
|
| 497 |
+
|
| 498 |
+
for p in pts:
|
| 499 |
+
# p may be a dict-like or an object with .payload
|
| 500 |
+
payload = None
|
| 501 |
+
if hasattr(p, "payload"):
|
| 502 |
+
payload = p.payload
|
| 503 |
+
elif isinstance(p, dict):
|
| 504 |
+
# older/newer variants might use nested structures: try common keys
|
| 505 |
+
payload = p.get("payload") or p.get("payload", None) or p
|
| 506 |
+
else:
|
| 507 |
+
# best-effort fallback: convert to dict if possible
|
| 508 |
+
try:
|
| 509 |
+
payload = dict(p)
|
| 510 |
+
except Exception:
|
| 511 |
+
payload = None
|
| 512 |
+
|
| 513 |
+
if not payload:
|
| 514 |
+
continue
|
| 515 |
+
|
| 516 |
+
# extract candidate value(s)
|
| 517 |
+
val = None
|
| 518 |
+
if isinstance(payload, dict):
|
| 519 |
+
val = payload.get(payload_field)
|
| 520 |
+
else:
|
| 521 |
+
# Some payload representations store fields differently; try attribute access
|
| 522 |
+
val = getattr(payload, payload_field, None)
|
| 523 |
+
|
| 524 |
+
# If value is list-like, iterate, else add single
|
| 525 |
+
if isinstance(val, (list, tuple, set)):
|
| 526 |
+
for v in val:
|
| 527 |
+
if v is not None:
|
| 528 |
+
filenames.add(str(v))
|
| 529 |
+
elif val is not None:
|
| 530 |
+
filenames.add(str(val))
|
| 531 |
+
|
| 532 |
+
# stop if no more pages
|
| 533 |
+
if not next_offset:
|
| 534 |
+
break
|
| 535 |
+
offset = next_offset
|
| 536 |
+
|
| 537 |
+
return sorted(filenames)
|
| 538 |
+
|
| 539 |
+
def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
|
| 540 |
+
if self.qdrant is None:
|
| 541 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 542 |
+
|
| 543 |
+
results = []
|
| 544 |
+
offset = None
|
| 545 |
+
while True:
|
| 546 |
+
# scroll returns (points, next_offset)
|
| 547 |
+
points, next_offset = self.qdrant.scroll(
|
| 548 |
+
collection_name=collection,
|
| 549 |
+
scroll_filter=Filter(
|
| 550 |
+
must=[
|
| 551 |
+
FieldCondition(key="filename", match=MatchValue(value=filename))
|
| 552 |
+
]
|
| 553 |
+
),
|
| 554 |
+
limit=batch,
|
| 555 |
+
offset=offset,
|
| 556 |
+
with_payload=True,
|
| 557 |
+
with_vectors=False,
|
| 558 |
+
)
|
| 559 |
+
# points are objects (Record / ScoredPoint-like); get id and payload
|
| 560 |
+
for p in points:
|
| 561 |
+
# p.payload is a dict, p.id is point id
|
| 562 |
+
results.append({"point_id": p.id, "payload": p.payload})
|
| 563 |
+
if not next_offset:
|
| 564 |
+
break
|
| 565 |
+
offset = next_offset
|
| 566 |
+
return results
|
| 567 |
+
|
| 568 |
+
def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
|
| 569 |
+
if self.qdrant is None:
|
| 570 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 571 |
+
|
| 572 |
+
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
|
| 573 |
+
q_filter = None
|
| 574 |
+
if filename:
|
| 575 |
+
q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
|
| 576 |
+
|
| 577 |
+
search_res = self.qdrant.search(
|
| 578 |
+
collection_name=collection,
|
| 579 |
+
query_vector=q_emb,
|
| 580 |
+
query_filter=q_filter,
|
| 581 |
+
limit=top_k,
|
| 582 |
+
with_payload=True,
|
| 583 |
+
with_vectors=False,
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
out = []
|
| 587 |
+
for hit in search_res:
|
| 588 |
+
# hit.payload is the stored payload, hit.score is similarity
|
| 589 |
+
out.append((hit.payload, float(getattr(hit, "score", 0.0))))
|
| 590 |
+
return out
|
| 591 |
+
|
| 592 |
+
def generate_from_qdrant(
|
| 593 |
+
self,
|
| 594 |
+
filename: str,
|
| 595 |
+
collection: str,
|
| 596 |
+
n_questions: int = 10,
|
| 597 |
+
mode: str = "rag", # 'per_chunk' or 'rag'
|
| 598 |
+
questions_per_chunk: int = 3, # used for 'per_chunk'
|
| 599 |
+
top_k: int = 3, # retrieval size used in RAG
|
| 600 |
+
temperature: float = 0.2,
|
| 601 |
+
) -> Dict[str, Any]:
|
| 602 |
+
if self.qdrant is None:
|
| 603 |
+
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
|
| 604 |
+
|
| 605 |
+
# get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
|
| 606 |
+
file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
|
| 607 |
+
if not file_points:
|
| 608 |
+
raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
|
| 609 |
+
|
| 610 |
+
# create a local list of texts & metadata for sampling
|
| 611 |
+
texts = []
|
| 612 |
+
metas = []
|
| 613 |
+
for p in file_points:
|
| 614 |
+
payload = p.get("payload", {})
|
| 615 |
+
text = payload.get("text", "")
|
| 616 |
+
texts.append(text)
|
| 617 |
+
metas.append(payload)
|
| 618 |
+
|
| 619 |
+
self.texts = texts
|
| 620 |
+
self.metadata = metas
|
| 621 |
+
embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
|
| 622 |
+
if embeddings is None or len(embeddings) == 0:
|
| 623 |
+
self.embeddings = None
|
| 624 |
+
self.index = None
|
| 625 |
+
else:
|
| 626 |
+
self.embeddings = embeddings.astype("float32")
|
| 627 |
+
|
| 628 |
+
# update dim in case embedder changed unexpectedly
|
| 629 |
+
self.dim = int(self.embeddings.shape[1])
|
| 630 |
+
|
| 631 |
+
# build index
|
| 632 |
+
self._build_faiss_index()
|
| 633 |
+
|
| 634 |
+
output = {}
|
| 635 |
+
qcount = 0
|
| 636 |
+
|
| 637 |
+
if mode == "per_chunk":
|
| 638 |
+
# iterate all chunks (in payload order) and request questions_per_chunk from each
|
| 639 |
+
for i, txt in enumerate(texts):
|
| 640 |
+
if not txt.strip():
|
| 641 |
+
continue
|
| 642 |
+
to_gen = questions_per_chunk
|
| 643 |
+
try:
|
| 644 |
+
mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.hf_model, temperature=temperature)
|
| 645 |
+
except Exception as e:
|
| 646 |
+
print(f"Generator failed on chunk (index {i}): {e}")
|
| 647 |
+
continue
|
| 648 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 649 |
+
qcount += 1
|
| 650 |
+
output[str(qcount)] = mcq_block[item]
|
| 651 |
+
if qcount >= n_questions:
|
| 652 |
+
return output
|
| 653 |
+
return output
|
| 654 |
+
|
| 655 |
+
elif mode == "rag":
|
| 656 |
+
attempts = 0
|
| 657 |
+
max_attempts = n_questions * 4
|
| 658 |
+
while qcount < n_questions and attempts < max_attempts:
|
| 659 |
+
attempts += 1
|
| 660 |
+
# sample a seed sentence from a random chunk of this file
|
| 661 |
+
seed_idx = random.randrange(len(texts))
|
| 662 |
+
chunk = texts[seed_idx]
|
| 663 |
+
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 664 |
+
seed_sent = None
|
| 665 |
+
for s in sents:
|
| 666 |
+
if len(s.strip()) > 20:
|
| 667 |
+
seed_sent = s
|
| 668 |
+
break
|
| 669 |
+
if not seed_sent:
|
| 670 |
+
seed_sent = chunk[:200]
|
| 671 |
+
query = f"Create questions about: {seed_sent}"
|
| 672 |
+
|
| 673 |
+
# retrieve top_k chunks from the same file (restricted by filename filter)
|
| 674 |
+
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
|
| 675 |
+
context_parts = []
|
| 676 |
+
for payload, score in retrieved:
|
| 677 |
+
# payload should contain page & chunk_id and text
|
| 678 |
+
page = payload.get("page", "?")
|
| 679 |
+
ctxt = payload.get("text", "")
|
| 680 |
+
context_parts.append(f"[page {page}] {ctxt}")
|
| 681 |
+
context = "\n\n".join(context_parts)
|
| 682 |
+
|
| 683 |
+
try:
|
| 684 |
+
mcq_block = generate_mcqs_from_text(context, n=1, model=self.hf_model, temperature=temperature)
|
| 685 |
+
except Exception as e:
|
| 686 |
+
print(f"Generator failed during RAG attempt {attempts}: {e}")
|
| 687 |
+
continue
|
| 688 |
+
|
| 689 |
+
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
|
| 690 |
+
qcount += 1
|
| 691 |
+
output[str(qcount)] = mcq_block[item]
|
| 692 |
+
if qcount >= n_questions:
|
| 693 |
+
return output
|
| 694 |
+
return output
|
| 695 |
+
else:
|
| 696 |
+
raise ValueError("mode must be 'per_chunk' or 'rag'.")
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
boto3
|
| 2 |
+
pdfplumber
|
| 3 |
+
faiss-cpu
|
| 4 |
+
sentence-transformers
|
| 5 |
+
fastapi[standard]
|
| 6 |
+
uvicorn[standard]
|
| 7 |
+
qdrant-client
|
| 8 |
+
pymupdf4llm
|
test/cerebras-api.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import os
|
| 2 |
+
# from cerebras.cloud.sdk import Cerebras
|
| 3 |
+
import tiktoken
|
| 4 |
+
|
| 5 |
+
# client = Cerebras(
|
| 6 |
+
# # This is the default and can be omitted
|
| 7 |
+
# api_key=os.environ.get("CEREBRAS_API_KEY")
|
| 8 |
+
# )
|
| 9 |
+
|
| 10 |
+
# stream = client.chat.completions.create(
|
| 11 |
+
# messages=[
|
| 12 |
+
# {
|
| 13 |
+
# "role": "system",
|
| 14 |
+
# "content": ""
|
| 15 |
+
# }
|
| 16 |
+
# ],
|
| 17 |
+
# model="gpt-oss-120b",
|
| 18 |
+
# stream=True,
|
| 19 |
+
# max_completion_tokens=65536,
|
| 20 |
+
# temperature=1,
|
| 21 |
+
# top_p=1
|
| 22 |
+
# )
|
| 23 |
+
import numpy as np
|
| 24 |
+
|
| 25 |
+
INPUT_TOKEN_COUNT = np.array([], dtype=int)
|
| 26 |
+
OUTPUT_TOKEN_COUNT = np.array([], dtype=int)
|
| 27 |
+
|
| 28 |
+
# for chunk in stream:
|
| 29 |
+
# print(chunk.choices[0].delta.content or "", end="")
|
| 30 |
+
with open('../test/mcq_output.json', 'r', encoding='utf-8') as f:
|
| 31 |
+
text = f.read()
|
| 32 |
+
|
| 33 |
+
def count_tokens(text: str, model_name='gpt-oss-120b', encoding_name='cl100k_base') -> int:
|
| 34 |
+
"""Look up model encoding; fallback to encoding_name if model not known."""
|
| 35 |
+
try:
|
| 36 |
+
# encoding_for_model can raise if model is unknown to tiktoken
|
| 37 |
+
enc = tiktoken.encoding_for_model(model_name)
|
| 38 |
+
except Exception:
|
| 39 |
+
enc = None
|
| 40 |
+
|
| 41 |
+
if enc is None:
|
| 42 |
+
enc = tiktoken.get_encoding(encoding_name)
|
| 43 |
+
|
| 44 |
+
return len(enc.encode(text))
|
| 45 |
+
|
| 46 |
+
c = count_tokens(text)
|
| 47 |
+
INPUT_TOKEN_COUNT = np.append(INPUT_TOKEN_COUNT, c)
|
| 48 |
+
print(INPUT_TOKEN_COUNT)
|
test/logging.txt
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
**********************
|
| 2 |
+
Windows PowerShell transcript start
|
| 3 |
+
Start time: 20250815161919
|
| 4 |
+
Username: MACBOOKM5\boboi
|
| 5 |
+
RunAs User: MACBOOKM5\boboi
|
| 6 |
+
Configuration Name:
|
| 7 |
+
Machine: MACBOOKM5 (Microsoft Windows NT 10.0.22631.0)
|
| 8 |
+
Host Application: C:\WINDOWS\System32\WindowsPowerShell\v1.0\powershell.exe
|
| 9 |
+
Process ID: 20936
|
| 10 |
+
PSVersion: 5.1.22621.5697
|
| 11 |
+
PSEdition: Desktop
|
| 12 |
+
PSCompatibleVersions: 1.0, 2.0, 3.0, 4.0, 5.0, 5.1.22621.5697
|
| 13 |
+
BuildVersion: 10.0.22621.5697
|
| 14 |
+
CLRVersion: 4.0.30319.42000
|
| 15 |
+
WSManStackVersion: 3.0
|
| 16 |
+
PSRemotingProtocolVersion: 2.3
|
| 17 |
+
SerializationVersion: 1.1.0.1
|
| 18 |
+
**********************
|
| 19 |
+
Transcript started, output file is D:\graduation_project\mcq-generator\test\logging.txt
|
| 20 |
+
PS D:\graduation_project\mcq-generator\app>
|
| 21 |
+
(rag-api) uvicorn app:app --reload
|
| 22 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
|
| 23 |
+
ne 608, in run_forever
|
| 24 |
+
self._run_once()
|
| 25 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
|
| 26 |
+
ne 1936, in _run_once
|
| 27 |
+
handle._run()
|
| 28 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\events.py", line 84
|
| 29 |
+
, in _run
|
| 30 |
+
self._context.run(self._callback, *self._args)
|
| 31 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\serve
|
| 32 |
+
r.py", line 70, in serve
|
| 33 |
+
with self.capture_signals():
|
| 34 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\contextlib.py", line 144, i
|
| 35 |
+
n __exit__
|
| 36 |
+
next(self.gen)
|
| 37 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\serve
|
| 38 |
+
r.py", line 331, in capture_signals
|
| 39 |
+
signal.raise_signal(captured_signal)
|
| 40 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
|
| 41 |
+
57, in _on_sigint
|
| 42 |
+
raise KeyboardInterrupt()
|
| 43 |
+
KeyboardInterrupt
|
| 44 |
+
|
| 45 |
+
During handling of the above exception, another exception occurred:
|
| 46 |
+
|
| 47 |
+
Traceback (most recent call last):
|
| 48 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\starlette\rou
|
| 49 |
+
ting.py", line 701, in lifespan
|
| 50 |
+
await receive()
|
| 51 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\lifes
|
| 52 |
+
pan\on.py", line 137, in receive
|
| 53 |
+
return await self.receive_queue.get()
|
| 54 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 55 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\queues.py", line 15
|
| 56 |
+
8, in get
|
| 57 |
+
await getter
|
| 58 |
+
asyncio.exceptions.CancelledError
|
| 59 |
+
|
| 60 |
+
INFO: Stopping reloader process [21928]
|
| 61 |
+
(rag-api) TerminatingError(): "The pipeline has been stopped."
|
| 62 |
+
>> TerminatingError(): "The pipeline has been stopped."
|
| 63 |
+
PS D:\graduation_project\mcq-generator\app>
|
| 64 |
+
(rag-api) uvicorn app:app --reload
|
| 65 |
+
warnings.warn(
|
| 66 |
+
INFO: Started server process [20356]
|
| 67 |
+
INFO: Waiting for application startup.
|
| 68 |
+
RAGMCQ instance created on startup.
|
| 69 |
+
INFO: Application startup complete.
|
| 70 |
+
ERROR: Traceback (most recent call last):
|
| 71 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
|
| 72 |
+
18, in run
|
| 73 |
+
return self._loop.run_until_complete(task)
|
| 74 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 75 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
|
| 76 |
+
ne 654, in run_until_complete
|
| 77 |
+
return future.result()
|
| 78 |
+
^^^^^^^^^^^^^^^
|
| 79 |
+
asyncio.exceptions.CancelledError
|
| 80 |
+
|
| 81 |
+
During handling of the above exception, another exception occurred:
|
| 82 |
+
|
| 83 |
+
Traceback (most recent call last):
|
| 84 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
|
| 85 |
+
90, in run
|
| 86 |
+
return runner.run(main)
|
| 87 |
+
^^^^^^^^^^^^^^^^
|
| 88 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
|
| 89 |
+
23, in run
|
| 90 |
+
raise KeyboardInterrupt()
|
| 91 |
+
KeyboardInterrupt
|
| 92 |
+
|
| 93 |
+
During handling of the above exception, another exception occurred:
|
| 94 |
+
|
| 95 |
+
Traceback (most recent call last):
|
| 96 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\starlette\rou
|
| 97 |
+
ting.py", line 701, in lifespan
|
| 98 |
+
await receive()
|
| 99 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\lifes
|
| 100 |
+
pan\on.py", line 137, in receive
|
| 101 |
+
return await self.receive_queue.get()
|
| 102 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 103 |
+
File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\queues.py", line 15
|
| 104 |
+
8, in get
|
| 105 |
+
await getter
|
| 106 |
+
asyncio.exceptions.CancelledError
|
| 107 |
+
|
| 108 |
+
INFO: Stopping reloader process [1968]
|
| 109 |
+
(rag-api) TerminatingError(): "The pipeline has been stopped."
|
| 110 |
+
>> TerminatingError(): "The pipeline has been stopped."
|
| 111 |
+
PS D:\graduation_project\mcq-generator\app>
|
test/mcq_output.json
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mcqs": {
|
| 3 |
+
"1": {
|
| 4 |
+
"câu hỏi": "Trong lớp Str_OutputParser, biểu thức chính quy nào được sử dụng để trích xuất câu trả lời từ chuỗi phản hồi?",
|
| 5 |
+
"lựa chọn": {
|
| 6 |
+
"a": "r\"Answer:\\s*(.*)\"",
|
| 7 |
+
"b": "r\"Respuesta:\\s*(.*)\"",
|
| 8 |
+
"c": "r\"Answer :\\s*(.*)\"",
|
| 9 |
+
"d": "r\"Result:\\s*(.*)\""
|
| 10 |
+
},
|
| 11 |
+
"đáp án": "r\"Answer :\\s*(.*)\""
|
| 12 |
+
},
|
| 13 |
+
"2": {
|
| 14 |
+
"câu hỏi": "Trong dự án RAG, file nào được dùng để khai báo các hàm load file PDF?",
|
| 15 |
+
"lựa chọn": {
|
| 16 |
+
"a": "src/rag/main.py",
|
| 17 |
+
"b": "src/rag/file_loader.py",
|
| 18 |
+
"c": "src/rag/offline_rag.py",
|
| 19 |
+
"d": "src/rag/utils.py"
|
| 20 |
+
},
|
| 21 |
+
"đáp án": "src/rag/file_loader.py"
|
| 22 |
+
},
|
| 23 |
+
"3": {
|
| 24 |
+
"câu hỏi": "Trong file src/rag/vectorstore.py, lớp nào được đặt làm giá trị mặc định cho vector database?",
|
| 25 |
+
"lựa chọn": {
|
| 26 |
+
"a": "FAISS",
|
| 27 |
+
"b": "Chroma",
|
| 28 |
+
"c": "Pinecone",
|
| 29 |
+
"d": "Milvus"
|
| 30 |
+
},
|
| 31 |
+
"đáp án": "Chroma"
|
| 32 |
+
},
|
| 33 |
+
"4": {
|
| 34 |
+
"câu hỏi": "Trong đoạn mã được trích dẫn, tham số nào được sử dụng cho kiểu lượng tử (quantization type) trong cấu hình BitsAndBytesConfig?",
|
| 35 |
+
"lựa chọn": {
|
| 36 |
+
"a": "nf4",
|
| 37 |
+
"b": "int8",
|
| 38 |
+
"c": "fp16",
|
| 39 |
+
"d": "int4"
|
| 40 |
+
},
|
| 41 |
+
"đáp án": "nf4"
|
| 42 |
+
},
|
| 43 |
+
"5": {
|
| 44 |
+
"câu hỏi": "Theo mô tả trong nội dung, bước nào liên quan đến việc tạo cơ sở dữ liệu vector bằng mô hình embedding?",
|
| 45 |
+
"lựa chọn": {
|
| 46 |
+
"a": "Tách danh sách các bài báo khoa học thành các văn bản nhỏ.",
|
| 47 |
+
"b": "Xây dựng một cơ sở dữ liệu vector từ các văn bản nhỏ bằng mô hình embedding.",
|
| 48 |
+
"c": "Truy vấn các mẫu văn bản có liên quan đến câu hỏi đầu vào để làm ngữ cảnh.",
|
| 49 |
+
"d": "Đưa câu prompt (câu hỏi và ngữ cảnh) vào mô hình để nhận câu trả lời."
|
| 50 |
+
},
|
| 51 |
+
"đáp án": "Xây dựng một cơ sở dữ liệu vector từ các văn bản nhỏ bằng mô hình embedding."
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"validation": {
|
| 55 |
+
"1": {
|
| 56 |
+
"supported_by_embeddings": true,
|
| 57 |
+
"max_similarity": 0.5152225494384766,
|
| 58 |
+
"evidence": [
|
| 59 |
+
{
|
| 60 |
+
"idx": 26,
|
| 61 |
+
"page": 15,
|
| 62 |
+
"score": 0.5152225494384766,
|
| 63 |
+
"text": "Ý nghĩa của phương thức `from_template()` trong class PromptTemplate là? ( _a_ ) Đểkhởi tạo prompt template từmột file. ( _b_ ) Đểkhởi tạo prompt template từmột string. ( _c_ ) Đểkhởi tạo prompt template từmột danh sách các tin nhắn. ( _d_ ) Đểkhởi tạo prompt template từmột prompt template có sẵn. 15"
|
| 64 |
+
}
|
| 65 |
+
],
|
| 66 |
+
"model_verdict": {
|
| 67 |
+
"supported": false,
|
| 68 |
+
"confidence": 0.9,
|
| 69 |
+
"evidence": "",
|
| 70 |
+
"reason": "Context không chứa thông tin về lớp Str_OutputParser hay biểu thức chính quy được sử dụng, vì vậy không thể chứng thực đáp án được đưa ra."
|
| 71 |
+
}
|
| 72 |
+
},
|
| 73 |
+
"2": {
|
| 74 |
+
"supported_by_embeddings": true,
|
| 75 |
+
"max_similarity": 0.694902777671814,
|
| 76 |
+
"evidence": [
|
| 77 |
+
{
|
| 78 |
+
"idx": 4,
|
| 79 |
+
"page": 4,
|
| 80 |
+
"score": 0.694902777671814,
|
| 81 |
+
"text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n_ **src/rag/:** Thư mục dùng đểlưu trữcác code liên quan đến xây dựng RAG, bao gồm:\n\n\n1. **src/rag/file_loader.py:** File code dùng đểkhai báo các hàm load file pdf (vì tài\nliệu của chúng ta thu thập thuộc file pdf). 2. **src/rag/main.py:** File code dùng đểkhai báo hàm khởi tạo chains. 3. **src/rag/offline_rag.py:** File code dùng đểkhai báo PromptTemplate. 4. **src/rag/utils.py:** File code dùng đểkhai báo hàm tách câu trảlời từmodel. 5. **src/rag/vectorstore.py:** File code dùng đểkhai báo hàm khởi tạo hệcơ sởdữliệu\n\nvector. _ **src/app.py:** File code dùng đểkhởi tạo API. _ **requirements.txt:** File code dùng đểkhai báo các thư viện cần thiết đểsửdụng source\ncode. ## II.2. Cập nhật file requirements.txt\n\n\nĐểbắt đầu, chúng ta sẽliệt kê các gói thư viện cần thiết đểchạy được chương trình này."
|
| 82 |
+
},
|
| 83 |
+
{
|
| 84 |
+
"idx": 28,
|
| 85 |
+
"page": 16,
|
| 86 |
+
"score": 0.5763600468635559,
|
| 87 |
+
"text": "document_loaders` `import` `PyPDFLoader`\n\n\n2\n\n\n3 `pdf_loader = PyPDFLoader(url, extract_images =` `True` `)`\n\n\n4\n\n\n5 `docs = pdf_loader.load ()`\n\n\nTham số `extract_images` tại dòng code 3 có chức năng gì? ( _a_ ) Trảvềtất cảảnh từfile pdf. ( _b_ ) Bỏqua ảnh, chỉload text. ( _c_ ) Phân tích ảnh thành vector. ( _d_ ) Chuyển đổi ảnh trong file pdf thành text. 16"
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"idx": 16,
|
| 91 |
+
"page": 9,
|
| 92 |
+
"score": 0.5420067310333252,
|
| 93 |
+
"text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n86 `return` `self.load(files, workers=workers)`\n\n## II.6. Cập nhật file src/rag/vectorstore.py\n\n\nTại file này, ta định nghĩa một class đểkhởi tạo hệcơ sởdữliệu vector. Trong project này, chúng\nta sẽsửdụng Chroma. Vềviệc tìm kiếm tài liệu tương đồng, ta sửdụng FAISS. Như vậy, nội\ndung của file như sau:\n\n\nHình 4: Minh họa việc sửdụng vector database Chroma đểtruy vấn các tài liệu có liên quan\n[làm context trong prompt. Ảnh: Link.](https://heidloff.net/article/retrieval-augmented-generation-chroma-langchain/)\n\n\n1 `from` `typing` `import` `Union`\n\n2 `from` `langchain_chroma` `import` `Chroma`\n\n3 `from` `langchain_community .vectorstores` `import` `FAISS`\n\n4 `from` `langchain_community .embeddings` `import` `HuggingFaceEmbeddings`\n\n\n5\n\n\n6 `class` `VectorDB:`\n\n\n7 `def` `__init__(self,`\n\n\n8 `documents = None,`\n\n9 `vector_db: Union[Chroma, FAISS] = Chroma,`\n\n10 `embedding = HuggingFaceEmbeddings (),`\n\n11 `) -> None` `:`\n\n\n12\n\n\n13 `self.vector_db ..."
|
| 94 |
+
}
|
| 95 |
+
],
|
| 96 |
+
"model_verdict": {
|
| 97 |
+
"supported": true,
|
| 98 |
+
"confidence": 0.99,
|
| 99 |
+
"evidence": "src/rag/file_loader.py: File code dùng để khai báo các hàm load file pdf",
|
| 100 |
+
"reason": "Context explicitly states that src/rag/file_loader.py declares functions for loading PDF files, matching the answer."
|
| 101 |
+
}
|
| 102 |
+
},
|
| 103 |
+
"3": {
|
| 104 |
+
"supported_by_embeddings": true,
|
| 105 |
+
"max_similarity": 0.579485297203064,
|
| 106 |
+
"evidence": [
|
| 107 |
+
{
|
| 108 |
+
"idx": 20,
|
| 109 |
+
"page": 11,
|
| 110 |
+
"score": 0.579485297203064,
|
| 111 |
+
"text": "Cập nhật file src/rag/main.py\n\n\nTại file này, ta khởi tạo toàn bộcác instance của các class, các hàm mà ta đã khai báo trước đó\nvà kết nối chúng vào trong một hàm duy nhất gọi là `build_rag_chain()` :\n\n\n1 `from` `pydantic` `import` `BaseModel, Field`\n\n\n2\n\n\n3 `from src.rag.file_loader` `import` `Loader`\n\n4 `from src.rag.vectorstore` `import` `VectorDB`\n\n5 `from src.rag.offline_rag` `import` `Offline_RAG`\n\n\n6\n\n\n7 `class` `InputQA(BaseModel):`\n\n8 `question: str = Field (..., title=` `\"Question to ask the model\"` `)`\n\n\n9\n\n\n10 `class` `OutputQA(BaseModel):`\n\n11 `answer: str = Field (..., title=` `\"Answer` `from the model\"` `)`\n\n\n12\n\n\n13 `def` `build_rag_chain (llm, data_dir, data_type):`\n\n14 `doc_loaded = Loader(file_type=data_type).load_dir(data_dir, workers=2)`\n\n15 `retriever = VectorDB(documents = doc_loaded).get_retriever ()`\n\n16 `rag_chain = Offline_RAG(llm).get_chain(retriever)`\n\n\n17\n\n\n18 `return` `rag_chain`\n\n\n11"
|
| 112 |
+
},
|
| 113 |
+
{
|
| 114 |
+
"idx": 16,
|
| 115 |
+
"page": 9,
|
| 116 |
+
"score": 0.5778905749320984,
|
| 117 |
+
"text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n86 `return` `self.load(files, workers=workers)`\n\n## II.6. Cập nhật file src/rag/vectorstore.py\n\n\nTại file này, ta định nghĩa một class đểkhởi tạo hệcơ sởdữliệu vector. Trong project này, chúng\nta sẽsửdụng Chroma. Vềviệc tìm kiếm tài liệu tương đồng, ta sửdụng FAISS. Như vậy, nội\ndung của file như sau:\n\n\nHình 4: Minh họa việc sửdụng vector database Chroma đểtruy vấn các tài liệu có liên quan\n[làm context trong prompt. Ảnh: Link.](https://heidloff.net/article/retrieval-augmented-generation-chroma-langchain/)\n\n\n1 `from` `typing` `import` `Union`\n\n2 `from` `langchain_chroma` `import` `Chroma`\n\n3 `from` `langchain_community .vectorstores` `import` `FAISS`\n\n4 `from` `langchain_community .embeddings` `import` `HuggingFaceEmbeddings`\n\n\n5\n\n\n6 `class` `VectorDB:`\n\n\n7 `def` `__init__(self,`\n\n\n8 `documents = None,`\n\n9 `vector_db: Union[Chroma, FAISS] = Chroma,`\n\n10 `embedding = HuggingFaceEmbeddings (),`\n\n11 `) -> None` `:`\n\n\n12\n\n\n13 `self.vector_db ..."
|
| 118 |
+
}
|
| 119 |
+
],
|
| 120 |
+
"model_verdict": {
|
| 121 |
+
"supported": true,
|
| 122 |
+
"confidence": 1.0,
|
| 123 |
+
"evidence": "vector_db: Union[Chroma, FAISS] = Chroma",
|
| 124 |
+
"reason": "Mặc định của tham số vector_db trong class VectorDB được đặt là Chroma"
|
| 125 |
+
}
|
| 126 |
+
},
|
| 127 |
+
"4": {
|
| 128 |
+
"supported_by_embeddings": false,
|
| 129 |
+
"max_similarity": 0.43995893001556396,
|
| 130 |
+
"evidence": [],
|
| 131 |
+
"model_verdict": {
|
| 132 |
+
"supported": false,
|
| 133 |
+
"confidence": 0.95,
|
| 134 |
+
"evidence": "",
|
| 135 |
+
"reason": "Trong nội dung Context không có bất kỳ đoạn nào đề cập đến BitsAndBytesConfig hay tham số kiểu lượng tử, vì vậy không thể chứng thực đáp án nf4."
|
| 136 |
+
}
|
| 137 |
+
},
|
| 138 |
+
"5": {
|
| 139 |
+
"supported_by_embeddings": true,
|
| 140 |
+
"max_similarity": 0.6268875598907471,
|
| 141 |
+
"evidence": [
|
| 142 |
+
{
|
| 143 |
+
"idx": 1,
|
| 144 |
+
"page": 2,
|
| 145 |
+
"score": 0.6268875598907471,
|
| 146 |
+
"text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nHình 2: Tổng quan vềpipeline của project.\n\n\n**Theo đó:**\n\n\n1. Từdanh sách các bài báo khoa học, ta tách thành các văn bản nhỏ. Từđó, xây dựng một\nhệcơ sởdữliệu vector với một embedding model.\n\n\n2. Bên cạnh câu hỏi đầu vào (question), ta truy vấn các mẫu văn bản có liên quan đến đến\ncâu hỏi, dùng làm ngữcảnh (context) trong câu prompt. Đây là nguồn thông tin mà LLMs\ncó thểdựa vào đểtrảlời câu hỏi.\n\n\n3. Đưa câu prompt vào mô hình (question và context) đểnhận câu trảlời từmô hình.\n\n\n2"
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"idx": 30,
|
| 150 |
+
"page": 17,
|
| 151 |
+
"score": 0.5708718299865723,
|
| 152 |
+
"text": "split_documents (pdf_pages)`\n\n\n18\n\n\n19 _`# Embedding`_ _`model`_\n\n20 `embedding_model = HuggingFaceEmbeddings ()`\n\n\n21\n\n\n22 _`# vector`_ _`store`_\n\n\n23 `chroma_db = Chroma.from_documents(docs, embedding= embedding_model )`\n\n\nNhiệm vụcủa `embedding_model` là gì? ( _a_ ) Dùng biến đổi chuỗi đầu vào thành các vector cho cơ sởdữliệu vector. ( _b_ ) Dùng đểlập chỉmục cho cơ sởdữliệu. ( _c_ ) Dùng đểtìm kiếm tài liệu. ( _d_ ) Dùng đểtính toán độtương đồng. 17"
|
| 153 |
+
}
|
| 154 |
+
],
|
| 155 |
+
"model_verdict": {
|
| 156 |
+
"supported": true,
|
| 157 |
+
"confidence": 0.99,
|
| 158 |
+
"evidence": "1. Từ danh sách các bài báo khoa học, ta tách thành các văn bản nhỏ. Từ đó, xây dựng một hẹcơ sở dữ liệu vector với một embedding model.",
|
| 159 |
+
"reason": "Context explicitly states that after splitting documents, a vector database is built using an embedding model, matching the chosen answer."
|
| 160 |
+
}
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
}
|
test/output.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
from typing import Dict, Any
|
| 4 |
+
import requests
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
API_URL = "https://router.huggingface.co/v1/chat/completions"
|
| 8 |
+
HF_KEY = os.environ['HF_API_KEY']
|
| 9 |
+
HEADERS = {"Authorization": f"Bearer {HF_KEY}"}
|
| 10 |
+
JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
|
| 11 |
+
|
| 12 |
+
def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
|
| 13 |
+
payload = {"model": model, "messages": messages, "temperature": temperature}
|
| 14 |
+
resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
|
| 15 |
+
resp.raise_for_status()
|
| 16 |
+
data = resp.json()
|
| 17 |
+
|
| 18 |
+
# handle various shapes
|
| 19 |
+
if "choices" in data and len(data["choices"]) > 0:
|
| 20 |
+
# prefer message.content
|
| 21 |
+
ch = data["choices"][0]
|
| 22 |
+
|
| 23 |
+
if isinstance(ch, dict) and "message" in ch and "content" in ch["message"]:
|
| 24 |
+
return ch["message"]["content"]
|
| 25 |
+
|
| 26 |
+
if "text" in ch:
|
| 27 |
+
return ch["text"]
|
| 28 |
+
|
| 29 |
+
# final fallback
|
| 30 |
+
raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
|
| 31 |
+
|
| 32 |
+
def _safe_extract_json(text: str) -> dict:
|
| 33 |
+
# remove triple backticks
|
| 34 |
+
text = re.sub(r"```(?:json)?\n?", "", text)
|
| 35 |
+
m = JSON_OBJ_RE.search(text)
|
| 36 |
+
|
| 37 |
+
if not m:
|
| 38 |
+
raise ValueError("No JSON object found in model output.")
|
| 39 |
+
js = m.group(1)
|
| 40 |
+
|
| 41 |
+
# try load, fix trailing commas
|
| 42 |
+
try:
|
| 43 |
+
return json.loads(js)
|
| 44 |
+
except json.JSONDecodeError:
|
| 45 |
+
fixed = re.sub(r",\s*([}\]])", r"\1", js)
|
| 46 |
+
return json.loads(fixed)
|
| 47 |
+
|
| 48 |
+
def generate_mcqs_from_text(
|
| 49 |
+
source_text: str,
|
| 50 |
+
n: int = 3,
|
| 51 |
+
model: str = "openai/gpt-oss-120b:cerebras",
|
| 52 |
+
temperature: float = 0.2,
|
| 53 |
+
) -> Dict[str, Any]:
|
| 54 |
+
system_message = {
|
| 55 |
+
"role": "system",
|
| 56 |
+
"content": (
|
| 57 |
+
"Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
|
| 58 |
+
"Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
|
| 59 |
+
"{\n"
|
| 60 |
+
' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
|
| 61 |
+
' "2": { ... }\n'
|
| 62 |
+
"}\n\n"
|
| 63 |
+
"Lưu ý:\n"
|
| 64 |
+
f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
|
| 65 |
+
"- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
|
| 66 |
+
"- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
|
| 67 |
+
"- Không kèm giải thích hay trường thêm.\n"
|
| 68 |
+
"- Các phương án sai (distractors) phải hợp lý và không lặp lại."
|
| 69 |
+
)
|
| 70 |
+
}
|
| 71 |
+
user_message = {
|
| 72 |
+
"role": "user",
|
| 73 |
+
"content": (
|
| 74 |
+
f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
|
| 75 |
+
"Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
|
| 76 |
+
f"Nội dung:\n\n{source_text}"
|
| 77 |
+
)
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
|
| 81 |
+
parsed = _safe_extract_json(raw)
|
| 82 |
+
|
| 83 |
+
# validate structure and length
|
| 84 |
+
if not isinstance(parsed, dict) or len(parsed) != n:
|
| 85 |
+
raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
|
| 86 |
+
return parsed
|