namberino committed
Commit dfa5afb · 1 Parent(s): 073c79b

Initial commit
.dockerignore ADDED
@@ -0,0 +1,15 @@
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ *.sqlite3
+ .env
+ .env.*
+ .git
+ .gitignore
+ .wheelhouse
+ wheels
+ dist
+ build
+ .vscode
+ .idea
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
.github/workflows/hf_deploy.yml ADDED
@@ -0,0 +1,19 @@
+ name: Sync to HuggingFace
+ on:
+   push:
+     branches: [ main ]
+
+   workflow_dispatch:
+
+ jobs:
+   sync-to-hub:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+         with:
+           fetch-depth: 0
+           lfs: true
+       - name: Push to hub
+         env:
+           HF_TOKEN: ${{ secrets.HF_TOKEN }}
+         run: git push https://namberino:$HF_TOKEN@huggingface.co/spaces/namberino/mcq-generator main
.github/workflows/image_scan.yml ADDED
@@ -0,0 +1,25 @@
+ name: Scan docker image for security issues
+
+ on:
+   push:
+     branches: [ main ]
+
+   pull_request:
+
+   workflow_dispatch:
+
+ jobs:
+   trivy:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Build image
+         run: docker build -t mcq-gen:ci .
+
+       - name: Run Trivy scan
+         uses: aquasecurity/trivy-action@0.32.0
+         with:
+           image-ref: mcq-gen:ci
+           format: 'table'
+           exit-code: '1'
.github/workflows/lint_test.yml ADDED
@@ -0,0 +1,30 @@
+ name: Lint, Typecheck, Tests
+
+ on:
+   push:
+     branches: [ main ]
+
+   pull_request:
+
+   workflow_dispatch:
+
+ jobs:
+   lint-and-test:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: "3.11"
+
+       - name: Install dependencies
+         run: |
+           pip install ruff
+
+       - name: Run ruff (lint)
+         run: |
+           ruff check .
.github/workflows/sast.yml ADDED
@@ -0,0 +1,31 @@
+ name: SAST
+
+ on:
+   push:
+   pull_request:
+   workflow_dispatch:
+
+ jobs:
+   sast:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v4
+
+       - name: Install tools
+         run: |
+           python -m pip install --upgrade pip
+           pip install semgrep bandit
+
+       - name: Run semgrep
+         run: semgrep --config auto --output semgrep-results.txt || true
+
+       - name: Run bandit
+         run: bandit -r . -f json -o bandit-results.json || true
+
+       - name: Upload SARIF/artifacts
+         uses: actions/upload-artifact@v4
+         with:
+           name: security-reports
+           path: |
+             semgrep-results.txt
+             bandit-results.json
.github/workflows/security_scan.yml ADDED
@@ -0,0 +1,23 @@
+ name: Scan for security issues
+ on:
+   push:
+   pull_request:
+   workflow_dispatch:
+
+ jobs:
+   gitleaks-scan:
+     runs-on: ubuntu-latest
+     name: Scan for secrets and sensitive information
+     steps:
+       - name: Checkout repo
+         uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+
+       - name: Run gitleaks
+         uses: gitleaks/gitleaks-action@v2
+         env:
+           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+           # GITLEAKS_CONFIG: .gitleaks.toml
+           GITLEAKS_ENABLE_UPLOAD_ARTIFACT: true
+           GITLEAKS_ENABLE_SUMMARY: true
.gitignore ADDED
@@ -0,0 +1 @@
+ .vscode
Dockerfile ADDED
@@ -0,0 +1,32 @@
+ FROM python:3.11-slim
+
+ # set HF cache to /tmp for writable FS on Spaces
+ ENV HF_HOME=/tmp/huggingface
+ ENV TOKENIZERS_PARALLELISM=false
+
+ # install system packages needed by some python libs
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     build-essential \
+     git \
+     wget \
+     libsndfile1 \
+     libgl1 \
+     libglib2.0-0 \
+     poppler-utils \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # copy requirements and install
+ COPY requirements.txt /app/requirements.txt
+ RUN pip install --upgrade pip
+ # try to be robust to wheels/build issues
+ # RUN pip wheel --no-cache-dir --wheel-dir=/wheels -r /app/requirements.txt || true
+ RUN pip install --no-cache-dir -r /app/requirements.txt
+
+ # copy app code
+ COPY . /app
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,179 @@
+ import os
+ import shutil
+ import tempfile
+ from typing import Optional
+
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ # Import the user's RAGMCQ implementation
+ from generator import RAGMCQ
+
+ app = FastAPI(title="RAG MCQ Generator API")
+
+ # allow cross-origin requests (adjust in production)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # global rag instance
+ rag: Optional[RAGMCQ] = None
+
+ class GenerateResponse(BaseModel):
+     mcqs: dict
+     validation: Optional[dict] = None
+
+ class ListResponse(BaseModel):
+     files: list
+
+ @app.on_event("startup")
+ def startup_event():
+     global rag
+
+     # instantiate the heavy object once
+     rag = RAGMCQ(
+         qdrant_url=os.environ['QDRANT_URL'],
+         qdrant_api_key=os.environ['QDRANT_API_KEY']
+     )
+     print("RAGMCQ instance created on startup.")
+
+ @app.get("/health")
+ def health():
+     return {"status": "ok", "ready": rag is not None}
+
+ def _save_upload_to_temp(upload: UploadFile) -> str:
+     suffix = ".pdf"
+     fd, path = tempfile.mkstemp(suffix=suffix)
+     os.close(fd)
+     with open(path, "wb") as out_file:
+         shutil.copyfileobj(upload.file, out_file)
+     return path
+
+ @app.get("/list_collection_files", response_model=ListResponse)
+ async def list_collection_files_endpoint(
+     collection_name: str = "programming"
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     files = rag.list_files_in_collection(collection_name)
+
+     return {"files": files}
+
+ @app.post("/generate_saved", response_model=GenerateResponse)
+ async def generate_saved_endpoint(
+     n_questions: int = Form(10),
+     qdrant_filename: str = Form("default_filename"),
+     collection_name: str = Form("programming"),
+     mode: str = Form("rag"),
+     questions_per_chunk: int = Form(3),
+     top_k: int = Form(3),
+     temperature: float = Form(0.2),
+     validate: bool = Form(False),
+     use_model_verification: bool = Form(False)
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     try:
+         mcqs = rag.generate_from_qdrant(
+             filename=qdrant_filename,
+             collection=collection_name,
+             n_questions=n_questions,
+             mode=mode,
+             questions_per_chunk=questions_per_chunk,
+             top_k=top_k,
+             temperature=temperature
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
+
+     validation_report = None
+
+     if validate:
+         try:
+             # validate_mcqs expects keys as strings and the normalized content
+             validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
+         except Exception as e:
+             # don't fail the whole request for a validation error — return generator output and note the error
+             validation_report = {"error": f"Validation failed: {e}"}
+
+     return {"mcqs": mcqs, "validation": validation_report}
+
+ @app.post("/generate", response_model=GenerateResponse)
+ async def generate_endpoint(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(...),
+     n_questions: int = Form(10),
+     qdrant_filename: str = Form("default_filename"),
+     collection_name: str = Form("programming"),
+     mode: str = Form("rag"),
+     questions_per_page: int = Form(3),
+     top_k: int = Form(3),
+     temperature: float = Form(0.2),
+     validate: bool = Form(False),
+     use_model_verification: bool = Form(False)
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     # basic file validation
+     if not file.filename.lower().endswith(".pdf"):
+         raise HTTPException(status_code=400, detail="Only PDF files are supported.")
+
+     # save uploaded file to a temp location
+     tmp_path = _save_upload_to_temp(file)
+
+     # ensure file removed afterward
+     def _cleanup(path: str):
+         try:
+             os.remove(path)
+         except Exception:
+             pass
+
+     background_tasks.add_task(_cleanup, tmp_path)
+
+     # save pdf
+     try:
+         rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
+
+     # generate
+     try:
+         mcqs = rag.generate_from_pdf(
+             tmp_path,
+             n_questions=n_questions,
+             mode=mode,
+             questions_per_page=questions_per_page,
+             top_k=top_k,
+             temperature=temperature,
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
+
+     validation_report = None
+
+     if validate:
+         try:
+             # rag.build_index_from_pdf(tmp_path)
+             # validate_mcqs expects keys as strings and the normalized content
+             validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
+         except Exception as e:
+             # don't fail the whole request for a validation error — return generator output and note the error
+             validation_report = {"error": f"Validation failed: {e}"}
+
+     return {"mcqs": mcqs, "validation": validation_report}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
app/app.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ import shutil
+ import tempfile
+ from typing import Optional
+
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+ # Import the user's RAGMCQ implementation
+ from generator import RAGMCQ
+
+ app = FastAPI(title="RAG MCQ Generator API")
+
+ # allow cross-origin requests (adjust in production)
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # global rag instance
+ rag: Optional[RAGMCQ] = None
+
+ class GenerateResponse(BaseModel):
+     mcqs: dict
+     validation: Optional[dict] = None
+
+ class ListResponse(BaseModel):
+     files: list
+
+ @app.on_event("startup")
+ def startup_event():
+     global rag
+
+     # instantiate the heavy object once
+     rag = RAGMCQ()
+     print("RAGMCQ instance created on startup.")
+
+ @app.get("/health")
+ def health():
+     return {"status": "ok", "ready": rag is not None}
+
+ def _save_upload_to_temp(upload: UploadFile) -> str:
+     suffix = ".pdf"
+     fd, path = tempfile.mkstemp(suffix=suffix)
+     os.close(fd)
+     with open(path, "wb") as out_file:
+         shutil.copyfileobj(upload.file, out_file)
+     return path
+
+ @app.get("/list_collection_files", response_model=ListResponse)
+ async def list_collection_files_endpoint(
+     collection_name: str = "programming"
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     files = rag.list_files_in_collection(collection_name)
+
+     return {"files": files}
+
+ @app.post("/generate_saved", response_model=GenerateResponse)
+ async def generate_saved_endpoint(
+     n_questions: int = Form(10),
+     qdrant_filename: str = Form("default_filename"),
+     collection_name: str = Form("programming"),
+     mode: str = Form("rag"),
+     questions_per_chunk: int = Form(3),
+     top_k: int = Form(3),
+     temperature: float = Form(0.2),
+     validate: bool = Form(False),
+     use_model_verification: bool = Form(False)
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     try:
+         mcqs = rag.generate_from_qdrant(
+             filename=qdrant_filename,
+             collection=collection_name,
+             n_questions=n_questions,
+             mode=mode,
+             questions_per_chunk=questions_per_chunk,
+             top_k=top_k,
+             temperature=temperature
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
+
+     validation_report = None
+
+     if validate:
+         try:
+             # validate_mcqs expects keys as strings and the normalized content
+             validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
+         except Exception as e:
+             # don't fail the whole request for a validation error — return generator output and note the error
+             validation_report = {"error": f"Validation failed: {e}"}
+
+     return {"mcqs": mcqs, "validation": validation_report}
+
+ @app.post("/generate", response_model=GenerateResponse)
+ async def generate_endpoint(
+     background_tasks: BackgroundTasks,
+     file: UploadFile = File(...),
+     n_questions: int = Form(10),
+     qdrant_filename: str = Form("default_filename"),
+     collection_name: str = Form("programming"),
+     mode: str = Form("rag"),
+     questions_per_page: int = Form(3),
+     top_k: int = Form(3),
+     temperature: float = Form(0.2),
+     validate: bool = Form(False),
+     use_model_verification: bool = Form(False)
+ ):
+     global rag
+     if rag is None:
+         raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
+
+     # basic file validation
+     if not file.filename.lower().endswith(".pdf"):
+         raise HTTPException(status_code=400, detail="Only PDF files are supported.")
+
+     # save uploaded file to a temp location
+     tmp_path = _save_upload_to_temp(file)
+
+     # ensure file removed afterward
+     def _cleanup(path: str):
+         try:
+             os.remove(path)
+         except Exception:
+             pass
+
+     background_tasks.add_task(_cleanup, tmp_path)
+
+     # save pdf
+     try:
+         rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
+
+     # generate
+     try:
+         mcqs = rag.generate_from_pdf(
+             tmp_path,
+             n_questions=n_questions,
+             mode=mode,
+             questions_per_page=questions_per_page,
+             top_k=top_k,
+             temperature=temperature,
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
+
+     validation_report = None
+
+     if validate:
+         try:
+             # rag.build_index_from_pdf(tmp_path)
+             # validate_mcqs expects keys as strings and the normalized content
+             validation_report = rag.validate_mcqs(mcqs, top_k=top_k, use_model_verification=use_model_verification)
+         except Exception as e:
+             # don't fail the whole request for a validation error — return generator output and note the error
+             validation_report = {"error": f"Validation failed: {e}"}
+
+     return {"mcqs": mcqs, "validation": validation_report}
+
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
app/generator.py ADDED
@@ -0,0 +1,695 @@
+ import re
+ import random
+ import numpy as np
+ from typing import List, Tuple, Dict, Any, Optional
+ from sentence_transformers import SentenceTransformer
+ from uuid import uuid4
+ import pymupdf4llm
+ import pymupdf as fitz
+
+ try:
+     from qdrant_client import QdrantClient
+     from qdrant_client.http.models import (
+         PointStruct,
+         Filter,
+         FieldCondition,
+         MatchValue,
+         Distance,
+         VectorParams,
+     )
+     from qdrant_client.http import models as rest
+     _HAS_QDRANT = True
+ except Exception:
+     _HAS_QDRANT = False
+
+ try:
+     import faiss
+     _HAS_FAISS = True
+ except Exception:
+     _HAS_FAISS = False
+
+ from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json
+
+ class RAGMCQ:
+     def __init__(
+         self,
+         embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+         hf_model: str = "gpt-oss-120b",
+         qdrant_url: str = None,
+         qdrant_api_key: str = None,
+         qdrant_prefer_grpc: bool = False,
+     ):
+         self.embedder = SentenceTransformer(embedder_model)
+         self.hf_model = hf_model
+         self.embeddings = None  # np.array of shape (N, D)
+         self.texts = []  # list of chunk texts
+         self.metadata = []  # list of dicts (page, chunk_id, char_range)
+         self.index = None
+         self.dim = self.embedder.get_sentence_embedding_dimension()
+
+         self.qdrant = None
+         self.qdrant_url = qdrant_url
+         self.qdrant_api_key = qdrant_api_key
+         self.qdrant_prefer_grpc = qdrant_prefer_grpc
+         if qdrant_url:
+             self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
+
+     def extract_pages(
+         self,
+         pdf_path: str,
+         *,
+         pages: Optional[List[int]] = None,
+         ignore_images: bool = False,
+         dpi: int = 150
+     ) -> List[str]:
+         doc = fitz.open(pdf_path)
+         try:
+             # request page-wise output (page_chunks=True -> list[dict] per page)
+             page_dicts = pymupdf4llm.to_markdown(
+                 doc,
+                 pages=pages,
+                 ignore_images=ignore_images,
+                 dpi=dpi,
+                 page_chunks=True,
+             )
+
+             # to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
+             pages_md: List[str] = []
+             for p in page_dicts:
+                 txt = p.get("text", "") or ""
+                 pages_md.append(txt.strip())
+
+             return pages_md
+         finally:
+             doc.close()
+
+     def chunk_text(self, text: str, max_chars: int = 1200) -> List[str]:
+         text = text.strip()
+         if not text:
+             return []
+         if len(text) <= max_chars:
+             return [text]
+
+         # split by sentence-like boundaries
+         sentences = re.split(r'(?<=[\.\?\!])\s+', text)
+         chunks = []
+         cur = ""
+         for s in sentences:
+             if len(cur) + len(s) + 1 <= max_chars:
+                 cur += (" " if cur else "") + s
+             else:
+                 if cur:
+                     chunks.append(cur)
+                 cur = s
+         if cur:
+             chunks.append(cur)
+
+         # if still too long, hard-split
+         final = []
+         for c in chunks:
+             if len(c) <= max_chars:
+                 final.append(c)
+             else:
+                 for i in range(0, len(c), max_chars):
+                     final.append(c[i:i+max_chars])
+         return final
+
+     def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
+         pages = self.extract_pages(pdf_path)
+         self.texts = []
+         self.metadata = []
+
+         for p_idx, page_text in enumerate(pages, start=1):
+             chunks = self.chunk_text(page_text or "", max_chars=max_chars)
+             for cid, ch in enumerate(chunks, start=1):
+                 self.texts.append(ch)
+                 self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
+
+         if not self.texts:
+             raise RuntimeError("No text extracted from PDF.")
+
+         # compute embeddings
+         emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
+         self.embeddings = emb.astype("float32")
+         self._build_faiss_index()
+
+     def _build_faiss_index(self):
+         if _HAS_FAISS:
+             d = self.embeddings.shape[1]
+             index = faiss.IndexFlatIP(d)  # inner product -> cosine if vectors normalized
+             faiss.normalize_L2(self.embeddings)
+             index.add(self.embeddings)
+             self.index = index
+         else:
+             # store normalized embeddings and use brute-force numpy
+             norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
+             self.embeddings = self.embeddings / norms
+             self.index = None
+
+     def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
+         q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
+
+         if _HAS_FAISS:
+             faiss.normalize_L2(q_emb)
+             D_list, I_list = self.index.search(q_emb, top_k)
+             # D are inner products; return list of (idx, score)
+             return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
+         else:
+             qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
+             sims = (self.embeddings @ qn.T).squeeze(axis=1)
+             idxs = np.argsort(-sims)[:top_k]
+             return [(int(i), float(sims[i])) for i in idxs]
+
+     def generate_from_pdf(
+         self,
+         pdf_path: str,
+         n_questions: int = 10,
+         mode: str = "rag",  # per_page or rag
+         questions_per_page: int = 3,  # for per_page mode
+         top_k: int = 3,  # chunks to retrieve for each question in rag mode
+         temperature: float = 0.2,
+     ) -> Dict[str, Any]:
+         # build index
+         self.build_index_from_pdf(pdf_path)
+
+         output: Dict[str, Any] = {}
+         qcount = 0
+
+         if mode == "per_page":
+             # iterate pages -> chunks
+             for idx, meta in enumerate(self.metadata):
+                 chunk_text = self.texts[idx]
+
+                 if not chunk_text.strip():
+                     continue
+                 to_gen = questions_per_page
+
+                 # ask generator
+                 try:
+                     mcq_block = generate_mcqs_from_text(
+                         chunk_text, n=to_gen, model=self.hf_model, temperature=temperature
+                     )
+                 except Exception as e:
+                     # skip this chunk if generator fails
+                     print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
+                     continue
+
+                 for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                     qcount += 1
+                     output[str(qcount)] = mcq_block[item]
+                     if qcount >= n_questions:
+                         return output
+
+             return output
+
+         elif mode == "rag":
+             # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
+             # create queries by sampling chunk text sentences.
+             # stop when n_questions reached or max_attempts exceeded.
+             attempts = 0
+             max_attempts = n_questions * 4
+
+             while qcount < n_questions and attempts < max_attempts:
+                 attempts += 1
+                 # create a seed query: pick a random chunk, pick a sentence from it
+                 seed_idx = random.randrange(len(self.texts))
+                 chunk = self.texts[seed_idx]
+                 sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                 # guard against chunks with no sentence longer than 20 chars
+                 candidates = [s for s in sents if len(s.strip()) > 20]
+                 seed_sent = random.choice(candidates) if candidates else chunk[:200]
+                 query = f"Create questions about: {seed_sent}"
+
+                 # retrieve top_k chunks
+                 retrieved = self._retrieve(query, top_k=top_k)
+                 context_parts = []
+                 for ridx, score in retrieved:
+                     md = self.metadata[ridx]
+                     context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
+                 context = "\n\n".join(context_parts)
+
+                 # call generator for 1 question (or small batch) with the retrieved context
+                 try:
+                     # request 1 question at a time to keep diversity
+                     mcq_block = generate_mcqs_from_text(
+                         context, n=1, model=self.hf_model, temperature=temperature
+                     )
+                 except Exception as e:
+                     print(f"Generator failed during RAG attempt {attempts}: {e}")
+                     continue
+
+                 # append result(s)
+                 for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                     qcount += 1
+                     output[str(qcount)] = mcq_block[item]
+                     if qcount >= n_questions:
+                         return output
+
+             return output
+         else:
+             raise ValueError("mode must be 'per_page' or 'rag'.")
+
+     def validate_mcqs(
+         self,
+         mcqs: Dict[str, Any],
+         top_k: int = 4,
+         similarity_threshold: float = 0.5,
+         evidence_score_cutoff: float = 0.5,
+         use_model_verification: bool = True,
+         model_verification_temperature: float = 0.0,
+     ) -> Dict[str, Any]:
+         if self.embeddings is None or not self.texts:
+             raise RuntimeError("Index/embeddings not built. Run build_index_from_pdf() first.")
+
+         report: Dict[str, Any] = {}
+
+         # helper: semantic similarity search on statement -> returns list of (idx, score)
+         def semantic_search(statement: str, k: int = top_k):
+             q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
+
+             if _HAS_FAISS:
+                 faiss.normalize_L2(q_emb)
+                 D_list, I_list = self.index.search(q_emb, k)
+                 # D are inner products; return list of (idx, score)
+                 return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
+             else:
+                 qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
+                 sims = (self.embeddings @ qn.T).squeeze(axis=1)
+                 idxs = np.argsort(-sims)[:k]
+                 return [(int(i), float(sims[i])) for i in idxs]
+
+         # helper: verify with model (strict JSON in response)
+         def _verify_with_model(question_text: str, options: Dict[str, str], correct_text: str, context_text: str):
+             system = {
+                 "role": "system",
+                 "content": (
+                     "Bạn là một trợ lý đánh giá tính thực chứng của câu hỏi trắc nghiệm dựa trên đoạn văn được cung cấp. "
+                     "Hãy trả lời DUY NHẤT bằng JSON hợp lệ (không có văn bản khác) theo schema:\n\n"
+                     "{\n"
+                     '  "supported": true/false,  # câu trả lời đúng có được nội dung chứng thực không\n'
+                     '  "confidence": 0.0-1.0,  # mức độ tự tin (số)\n'
+                     '  "evidence": "cụm văn bản ngắn làm bằng chứng hoặc trích dẫn",\n'
+                     '  "reason": "ngắn gọn, vì sao supported hoặc không"\n'
+                     "}\n\n"
+                     "Luôn dựa chỉ trên nội dung trong trường 'Context' dưới đây. Nếu nội dung không chứa bằng chứng, trả về supported: false."
+                 )
+             }
+             user = {
+                 "role": "user",
+                 "content": (
+                     "Câu hỏi:\n" + question_text + "\n\n"
+                     "Lựa chọn:\n" + "\n".join([f"{k}: {v}" for k, v in options.items()]) + "\n\n"
+                     "Đáp án:\n" + correct_text + "\n\n"
+                     "Context:\n" + context_text + "\n\n"
+                     "Hãy trả lời như yêu cầu."
+                 )
+             }
+
+             raw = _post_chat([system, user], model=self.hf_model, temperature=model_verification_temperature)
+
+             # parse JSON object in response
+             try:
+                 parsed = _safe_extract_json(raw)
+             except Exception as e:
+                 return {"error": f"Model verification failed to return JSON: {e}", "raw": raw}
+             return parsed
+
+         # iterate MCQs
+         for qid, item in mcqs.items():
+             q_text = item.get("câu hỏi", "").strip()
+             options = item.get("lựa chọn", {})
+             correct_text = item.get("đáp án", "").strip()
+
+             # form a short declarative statement to embed: "Question: ... Answer: <correct>"
+             statement = f"{q_text} Answer: {correct_text}"
+
+             retrieved = semantic_search(statement, k=top_k)
+             evidence_list = []
+             max_sim = 0.0
+             for idx, score in retrieved:
+                 if score >= evidence_score_cutoff:
+                     evidence_list.append({
+                         "idx": idx,
+                         "page": self.metadata[idx].get("page", None),
+                         "score": float(score),
+                         "text": (self.texts[idx][:1000] + ("..." if len(self.texts[idx]) > 1000 else "")),
+                     })
+
+                 if score > max_sim:
+                     max_sim = float(score)
+
+             supported_by_embeddings = max_sim >= similarity_threshold
+
+             model_verdict = None
+             if use_model_verification:
+                 # build a context string from top retrieved chunks (regardless of cutoff)
+                 context_parts = []
+                 for ridx, sc in retrieved:
+                     md = self.metadata[ridx]
+                     context_parts.append(f"[page {md.get('page')}] {self.texts[ridx]}")
+                 context_text = "\n\n".join(context_parts)
+
+                 try:
+                     parsed = _verify_with_model(q_text, options, correct_text, context_text)
+                     model_verdict = parsed
+                 except Exception as e:
+                     model_verdict = {"error": f"verification exception: {e}"}
+
+             report[qid] = {
+                 "supported_by_embeddings": bool(supported_by_embeddings),
+                 "max_similarity": float(max_sim),
+                 "evidence": evidence_list,
+                 "model_verdict": model_verdict,
+             }
+
+         return report
+
+     def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False):
+         if not _HAS_QDRANT:
+             raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
+         self.qdrant_url = url
+         self.qdrant_api_key = api_key
+         self.qdrant_prefer_grpc = prefer_grpc
+         # Create client
+         self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
+
+     def _ensure_collection(self, collection_name: str):
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+         try:
+             # get_collection will raise if not present
+             _ = self.qdrant.get_collection(collection_name)
+         except Exception:
+             # create collection with vector size = self.dim
+             vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
+             self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
+             # recreate_collection ensures a clean collection; if you prefer to avoid wiping, use create_collection instead.
+
+     def save_pdf_to_qdrant(
+         self,
+         pdf_path: str,
+         filename: str,
+         collection: str,
+         max_chars: int = 1200,
+         batch_size: int = 64,
+         overwrite: bool = False,
+     ):
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+         # extract pages and chunks (re-using the existing helpers)
+         pages = self.extract_pages(pdf_path)
+         all_chunks = []
+         all_meta = []
+         for p_idx, page_text in enumerate(pages, start=1):
+             chunks = self.chunk_text(page_text or "", max_chars=max_chars)
+             for cid, ch in enumerate(chunks, start=1):
+                 all_chunks.append(ch)
+                 all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
+
+         if not all_chunks:
+             raise RuntimeError("No text extracted from PDF.")
+
+         # ensure collection exists
+         self._ensure_collection(collection)
+
+         # optional: delete previous points for this filename if overwrite
+         if overwrite:
+             # delete by filter: filename == filename
+             flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
+             try:
+                 # qdrant-client accepts a Filter as the points selector
+                 self.qdrant.delete(collection_name=collection, points_selector=flt)
+             except Exception:
+                 # ignore if deletion fails
+                 pass
+
+         # compute embeddings in batches
+         embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
+         embeddings = embeddings.astype("float32")
+
+         # prepare points
+         points = []
+         for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
+             pid = str(uuid4())
+             source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
+             payload = {
+                 "filename": filename,
+                 "page": md["page"],
+                 "chunk_id": md["chunk_id"],
+                 "length": md["length"],
+                 "text": txt,
+                 "source_id": source_id,
+             }
+             points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
+
+             # upsert in batches
+             if len(points) >= batch_size:
+                 self.qdrant.upsert(collection_name=collection, points=points)
+                 points = []
+
+         # upsert remaining
+         if points:
+             self.qdrant.upsert(collection_name=collection, points=points)
+
+         try:
+             self.qdrant.create_payload_index(
+                 collection_name=collection,
+                 field_name="filename",
+                 field_schema=rest.PayloadSchemaType.KEYWORD
+             )
+         except Exception as e:
+             print(f"Index creation skipped or failed: {e}")
+
+         return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
+
+     def list_files_in_collection(
+         self,
+         collection: str,
+         payload_field: str = "filename",
+         batch_size: int = 500,
+     ) -> List[str]:
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+         # ensure collection exists
+         try:
+             if not self.qdrant.collection_exists(collection):
+                 raise RuntimeError(f"Collection '{collection}' does not exist.")
+         except Exception:
+             # collection_exists may raise if server unreachable
+             raise
+
+         filenames = set()
+         offset = None
+
+         while True:
+             # scroll returns (points, next_offset)
+             pts, next_offset = self.qdrant.scroll(
+                 collection_name=collection,
+                 limit=batch_size,
+                 offset=offset,
+                 with_payload=[payload_field],
+                 with_vectors=False,
+             )
+
+             if not pts:
+                 break
+
+             for p in pts:
+                 # p may be a dict-like or an object with .payload
+                 payload = None
+                 if hasattr(p, "payload"):
+                     payload = p.payload
+                 elif isinstance(p, dict):
+                     # older/newer variants might use nested structures: try common keys
+                     payload = p.get("payload") or p
+                 else:
+                     # best-effort fallback: convert to dict if possible
+                     try:
+                         payload = dict(p)
+                     except Exception:
+                         payload = None
+
+                 if not payload:
+                     continue
+
+                 # extract candidate value(s)
+                 val = None
+                 if isinstance(payload, dict):
+                     val = payload.get(payload_field)
+                 else:
+                     # Some payload representations store fields differently; try attribute access
+                     val = getattr(payload, payload_field, None)
+
+                 # If value is list-like, iterate, else add single
+                 if isinstance(val, (list, tuple, set)):
+                     for v in val:
+                         if v is not None:
+                             filenames.add(str(v))
+                 elif val is not None:
+                     filenames.add(str(val))
+
+             # stop if no more pages
+             if not next_offset:
+                 break
+             offset = next_offset
+
+         return sorted(filenames)
+
+     def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+         results = []
+         offset = None
+         while True:
+             # scroll returns (points, next_offset)
+             points, next_offset = self.qdrant.scroll(
+                 collection_name=collection,
+                 scroll_filter=Filter(
+                     must=[
+                         FieldCondition(key="filename", match=MatchValue(value=filename))
+                     ]
+                 ),
+                 limit=batch,
+                 offset=offset,
+                 with_payload=True,
+                 with_vectors=False,
+             )
+             # points are objects (Record / ScoredPoint-like); get id and payload
+             for p in points:
+                 # p.payload is a dict, p.id is point id
+                 results.append({"point_id": p.id, "payload": p.payload})
+             if not next_offset:
+                 break
+             offset = next_offset
+         return results
+
+     def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+         q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
+         q_filter = None
+         if filename:
+             q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
+
+         search_res = self.qdrant.search(
+             collection_name=collection,
+             query_vector=q_emb,
+             query_filter=q_filter,
+             limit=top_k,
+             with_payload=True,
+             with_vectors=False,
+         )
+
+         out = []
+         for hit in search_res:
+             # hit.payload is the stored payload, hit.score is similarity
+             out.append((hit.payload, float(getattr(hit, "score", 0.0))))
+         return out
+
+     def generate_from_qdrant(
+         self,
+         filename: str,
+         collection: str,
+         n_questions: int = 10,
+         mode: str = "rag",  # 'per_chunk' or 'rag'
+         questions_per_chunk: int = 3,  # used for 'per_chunk'
+         top_k: int = 3,  # retrieval size used in RAG
+         temperature: float = 0.2,
+     ) -> Dict[str, Any]:
+         if self.qdrant is None:
+             raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
+
+         # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
+         file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
+         if not file_points:
+             raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
+
+         # create a local list of texts & metadata for sampling
+         texts = []
+         metas = []
+         for p in file_points:
+             payload = p.get("payload", {})
+             text = payload.get("text", "")
+             texts.append(text)
+             metas.append(payload)
+
+         self.texts = texts
+         self.metadata = metas
+         embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
+         if embeddings is None or len(embeddings) == 0:
+             self.embeddings = None
+             self.index = None
+         else:
+             self.embeddings = embeddings.astype("float32")
+
+             # update dim in case embedder changed unexpectedly
+             self.dim = int(self.embeddings.shape[1])
+
+             # build index
+             self._build_faiss_index()
+
+         output = {}
+         qcount = 0
+
+         if mode == "per_chunk":
+             # iterate all chunks (in payload order) and request questions_per_chunk from each
+             for i, txt in enumerate(texts):
+                 if not txt.strip():
+                     continue
+                 to_gen = questions_per_chunk
+                 try:
+                     mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.hf_model, temperature=temperature)
+                 except Exception as e:
+                     print(f"Generator failed on chunk (index {i}): {e}")
+                     continue
+                 for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                     qcount += 1
+                     output[str(qcount)] = mcq_block[item]
+                     if qcount >= n_questions:
+                         return output
+             return output
+
+         elif mode == "rag":
+             attempts = 0
+             max_attempts = n_questions * 4
+             while qcount < n_questions and attempts < max_attempts:
+                 attempts += 1
+                 # sample a seed sentence from a random chunk of this file
+                 seed_idx = random.randrange(len(texts))
+                 chunk = texts[seed_idx]
+                 sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+                 seed_sent = None
+                 for s in sents:
+                     if len(s.strip()) > 20:
+                         seed_sent = s
+                         break
+                 if not seed_sent:
+                     seed_sent = chunk[:200]
+                 query = f"Create questions about: {seed_sent}"
+
+                 # retrieve top_k chunks from the same file (restricted by filename filter)
+                 retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
+                 context_parts = []
+                 for payload, score in retrieved:
+                     # payload should contain page & chunk_id and text
+                     page = payload.get("page", "?")
+                     ctxt = payload.get("text", "")
+                     context_parts.append(f"[page {page}] {ctxt}")
+                 context = "\n\n".join(context_parts)
+
+                 try:
+                     mcq_block = generate_mcqs_from_text(context, n=1, model=self.hf_model, temperature=temperature)
+                 except Exception as e:
+                     print(f"Generator failed during RAG attempt {attempts}: {e}")
+                     continue
+
+                 for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
+                     qcount += 1
+                     output[str(qcount)] = mcq_block[item]
+                     if qcount >= n_questions:
+                         return output
+             return output
+         else:
+             raise ValueError("mode must be 'per_chunk' or 'rag'.")
app/output.json ADDED
The diff for this file is too large to render. See raw diff
 
app/software_report_template.md ADDED
@@ -0,0 +1,261 @@
+ # Software Report: RAG-based MCQ Generation System
+
+ ## 1. Overview / Abstract
+ The project provides an API service that ingests a PDF document and automatically generates multiple-choice questions (MCQs) using a Retrieval-Augmented Generation (RAG) pipeline. It exposes a FastAPI endpoint (`/generate`) that orchestrates: PDF text extraction → chunking → embedding + indexing → (mode-dependent) context selection → MCQ generation via an LLM (Together AI chat completion) → optional semantic + model-based validation.
+
+ Core components:
+ - Controller (FastAPI endpoints) – handles HTTP, file upload, response shaping.
+ - Use Case (RAGMCQ class) – encapsulates business logic: indexing, retrieval, generation, validation.
+ - Repositories / Data Stores – implicit: in-memory lists of chunks, embeddings, optional FAISS index.
+
+ ## 2. High-Level Workflow Diagram
+ ### Mermaid Activity Diagram
+ ```mermaid
+ flowchart LR
+     A["Client Uploads PDF -> /generate"] --> B{Mode?}
+     B -->|rag| R1["Extract & Chunk PDF"]
+     B -->|per_page| R1
+     R1 --> R2["SentenceTransformer Embeddings"]
+     R2 --> R3{FAISS Available?}
+     R3 -->|Yes| R4["Build FAISS Index"]
+     R3 -->|No| R5["Normalize Embeddings (NumPy)"]
+     R4 --> R6["Question Generation Loop"]
+     R5 --> R6
+     R6 -->|"rag: sample queries + retrieve top-k"| R7["Assemble Context"]
+     R6 -->|"per_page: iterate chunks"| R7
+     R7 --> G1["Prompt LLM (JSON MCQs)"]
+     G1 --> P1["Parse & Validate JSON shape"]
+     P1 --> C{Need more?}
+     C -->|Yes| R6
+     C -->|No| V{Validation requested?}
+     V -->|Yes| V1["Semantic Evidence Search + (Optional) Model Verification"]
+     V -->|No| OUT["Return MCQs"]
+     V1 --> OUT
+ ```
+
+ ### Alternative PlantUML Activity (Optional)
+ ```plantuml
+ @startuml
+ start
+ :Upload PDF (multipart form);
+ :Select params (mode, n_questions,...);
+ :Extract pages via pymupdf4llm;
+ :Chunk text (sentence pack <= max_chars);
+ :Embed chunks (SentenceTransformer);
+ if (FAISS installed?) then (yes)
+   :Build FAISS IndexFlatIP + L2 normalize;
+ else (no)
+   :Keep normalized NumPy embeddings;
+ endif
+ repeat
+   if (mode == per_page) then (per_page)
+     :Take next chunk;
+   else (rag)
+     :Sample seed sentence;
+     :Encode query & retrieve top-k chunks;
+   endif
+   :Assemble context;
+   :Call Together AI chat completion (prompt -> JSON);
+   :Parse JSON + accumulate MCQs;
+ repeat while (Need more questions?) is (yes)
+ if (validate?) then (yes)
+   :For each Q -> build statement;
+   :Similarity search top_k evidence;
+   if (Insufficient sim & model verify on) then (yes)
+     :Call model for verification JSON;
+   endif
+   :Build validation report;
+ endif
+ :Return response JSON;
+ stop
+ @enduml
+ ```
+
+ ## 3. Repository–Controller–Use Case Abstraction
+ | Layer | Responsibility | In This Project |
+ |-------|---------------|-----------------|
+ | Controller | HTTP I/O, request validation, mapping domain results to API schema | `app.py` endpoints (`/health`, `/generate`) |
+ | Use Case | Orchestrates domain flow, independent of HTTP details | `RAGMCQ` methods: `build_index_from_pdf`, `generate_from_pdf`, `validate_mcqs` |
+ | Repository (implicit) | Data persistence / retrieval | In-memory: `texts`, `metadata`, `embeddings`, FAISS index (no external DB) |
+
+ Data Flow (simplified):
+ Client → Controller (`/generate`) → Use Case (`generate_from_pdf`) → (Extract + Chunk + Embed + Index + Retrieve + Generate) → Controller (normalize/optional validation) → Response
+
+ ## 4. Detailed Pipeline Explanation
+ ### 4.1 PDF Text Extraction & Chunking
+ - The file is saved to a temp path, then PyMuPDF (`pymupdf4llm.to_markdown` with `page_chunks=True`) loads each page.
+ - `extract_pages()` returns a list of raw page strings.
+ - `chunk_text()` packs sentences (regex split on punctuation boundaries) into segments up to `max_chars` (default 1200). If a sentence overflows, the existing chunk is flushed. Residual oversize chunks are hard-split. A minimal usage sketch follows this list.
+ - Metadata collected: page number, chunk id, length.
+
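+ As a quick illustration of the packing behavior (a sketch, not part of the committed code; it assumes the provider key env var is set, since `utils` reads it when `generator` is imported):
+
+ ```python
+ from generator import RAGMCQ
+
+ rag = RAGMCQ()  # loads the SentenceTransformer embedder on construction
+
+ # Three short sentences and a tiny budget, so the packer must split.
+ text = "Python is dynamically typed. It supports multiple paradigms. Lists are mutable."
+ for c in rag.chunk_text(text, max_chars=60):
+     # each chunk ends on a sentence boundary and stays under 60 chars,
+     # unless a single sentence is itself longer (then it is hard-split)
+     print(len(c), repr(c))
+ ```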
+ ### 4.2 Embedding Generation
+ - Model: `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` loaded via `SentenceTransformer`.
+ - Batched encoding of all chunks → NumPy array (float32).
+ - If FAISS is installed: L2-normalize embeddings, create `IndexFlatIP` (inner product ~ cosine after normalization), add embeddings.
+ - Else: store manually normalized embeddings for brute-force cosine similarity via matrix multiply; a standalone sketch follows.
+
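+ The equivalence that makes the fallback work: after L2 normalization, inner product equals cosine similarity, so a plain matrix multiply reproduces what `IndexFlatIP` computes. A standalone sketch (toy data, independent of the class):
+
+ ```python
+ import numpy as np
+
+ def build_search(emb: np.ndarray):
+     """Return a top-k search function over row-normalized embeddings."""
+     emb = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-10)
+
+     def search(q: np.ndarray, k: int = 3):
+         q = q / (np.linalg.norm(q) + 1e-10)
+         sims = emb @ q  # inner product == cosine on unit vectors
+         idxs = np.argsort(-sims)[:k]
+         return [(int(i), float(sims[i])) for i in idxs]
+
+     return search
+
+ # toy corpus: 4 vectors in 5 dimensions
+ rng = np.random.default_rng(0)
+ search = build_search(rng.normal(size=(4, 5)).astype("float32"))
+ print(search(rng.normal(size=5).astype("float32"), k=2))
+ ```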
+ ### 4.3 Retrieval Strategy
+ Two modes:
+ 1. `per_page`: Sequentially process each chunk; each call to the LLM asks for `questions_per_page` new MCQs until the target `n_questions` is reached.
+ 2. `rag`: The loop builds a synthetic query by sampling a random chunk and a sentence from it (see the sketch after this section). Retrieval:
+    - Encode the query → similarity search (FAISS or NumPy).
+    - Take the top-k chunk texts; join them with page tags as context.
+    - Request 1 question per iteration (promotes diversity). Up to `max_attempts = n_questions * 4`.
+
+ Similarity Metric: Inner product on normalized vectors (equivalent to cosine), sorted by descending similarity.
+
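+ The seed-query construction in rag mode is the part that is easiest to misread in prose; this hypothetical standalone variant of the sampling step shows the whole idea:
+
+ ```python
+ import random
+ import re
+
+ def make_seed_query(chunks: list) -> str:
+     """Sample one chunk, then one reasonably long sentence from it."""
+     chunk = random.choice(chunks)
+     sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
+     candidates = [s for s in sents if len(s.strip()) > 20]
+     seed = random.choice(candidates) if candidates else chunk[:200]
+     return f"Create questions about: {seed}"
+
+ chunks = [
+     "Stacks are LIFO structures. Queues are FIFO structures.",
+     "A binary search tree keeps keys ordered for O(log n) average lookup.",
+ ]
+ print(make_seed_query(chunks))  # this string feeds the similarity search as the query
+ ```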
+ ### 4.4 Question Generation Prompt Template
+ Implemented in `generate_mcqs_from_text` (utils):
+ - System message (Vietnamese) forcing a strict JSON schema:
+ ```json
+ {
+   "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},
+   "2": { ... }
+ }
+ ```
+ - Constraints: exactly `n` entries; the answer must be full text identical to one option; no explanations.
+ - User message: instructs generation from the provided source text only.
+ - Post-processing: a regex extracts the first JSON object; attempts `json.loads`; a fallback removes trailing commas. A structural check like the sketch below can catch the remaining failure modes.
+
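+ Because the prompt can only ask for this shape, a defensive structural check on the parsed block is cheap. A sketch (the repo does not ship this helper):
+
+ ```python
+ def check_mcq_block(block: dict, n: int) -> list:
+     """Return a list of structural problems; an empty list means the block looks sane."""
+     problems = []
+     if set(block.keys()) != {str(i) for i in range(1, n + 1)}:
+         problems.append(f"expected keys 1..{n}, got {sorted(block.keys())}")
+     for qid, item in block.items():
+         opts = item.get("lựa chọn", {})
+         if set(opts.keys()) != {"a", "b", "c", "d"}:
+             problems.append(f"Q{qid}: options must be exactly a/b/c/d")
+         if item.get("đáp án") not in opts.values():
+             problems.append(f"Q{qid}: answer text does not match any option")
+     return problems
+ ```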
+ ### 4.5 Validation (Optional)
+ For each MCQ (after normalization in the controller):
+ 1. Construct a statement: `Question + Answer`.
+ 2. Embed the statement → retrieve top_k evidence chunks.
+ 3. Mark `supported_by_embeddings` if max similarity ≥ threshold.
+ 4. If model verification is enabled, call the verification LLM prompt (also JSON-only) to assess `supported`, `confidence`, `evidence`, `reason`.
+
+ ### 4.6 Together AI Integration
+ - Endpoint: `https://api.together.xyz/v1/chat/completions`.
+ - The Authorization header uses the `TOGETHER_KEY` environment variable.
+ - Payload: `{ model, messages, temperature }`.
+ - Response handling: support both OpenAI-like `choices[0].message.content` and fallback `choices[0].text`. A minimal call is sketched below.
+
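+ End to end, a call through the repo's own helper looks like this (a sketch; it requires the configured provider's key env var and network access, and `"gpt-oss-120b"` is simply the repo's default model name):
+
+ ```python
+ from utils import _post_chat
+
+ messages = [
+     {"role": "system", "content": "Reply with one word."},
+     {"role": "user", "content": "Ping?"},
+ ]
+ # _post_chat prefers choices[0].message.content and falls back to choices[0].text
+ print(_post_chat(messages, model="gpt-oss-120b", temperature=0.0))
+ ```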
+ ## 5. API Endpoints
+ ### 5.1 Health Check
+ GET `/health`
+ Response:
+ ```json
+ { "status": "ok", "ready": true }
+ ```
+
+ ### 5.2 Generate MCQs
+ POST `/generate` (multipart/form-data)
+ Fields:
+ - `file` (PDF) – required
+ - `n_questions` (int, default 10)
+ - `qdrant_filename` (string, default "default_filename") – name under which chunks are stored in Qdrant
+ - `collection_name` (string, default "programming")
+ - `mode` ("rag" | "per_page", default "rag")
+ - `questions_per_page` (int, default 3) – used only in per_page mode
+ - `top_k` (int, default 3) – retrieval depth (rag & validation)
+ - `temperature` (float, default 0.2)
+ - `validate` (bool, default false)
+ - `use_model_verification` (bool, default false) – secondary LLM check during validation
+
+ Example Request (curl):
+ ```bash
+ curl -X POST http://localhost:8000/generate \
+   -F "file=@sample.pdf" \
+   -F "n_questions=5" \
+   -F "mode=rag" \
+   -F "top_k=3" \
+   -F "validate=true"
+ ```
+
+ Success Response (validation on, abbreviated):
+ ```json
+ {
+   "mcqs": {
+     "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án": "..."},
+     "2": { "câu hỏi": "...", "lựa chọn": { ... }, "đáp án": "..." }
+   },
+   "validation": {
+     "1": {
+       "supported_by_embeddings": true,
+       "max_similarity": 0.83,
+       "evidence": [ { "page": 2, "score": 0.81, "text": "Excerpt..." } ],
+       "model_verdict": null
+     }
+   }
+ }
+ ```
+
+ Error Examples:
+ - 400: non-PDF upload
+ - 500: generation pipeline error (e.g., empty PDF or model failure)
+ - 503: service not initialized
+
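+ For programmatic use, the same request can be made with `requests` (a sketch; assumes the server from the Quick Start is listening on localhost:8000):
+
+ ```python
+ import requests
+
+ with open("sample.pdf", "rb") as f:
+     resp = requests.post(
+         "http://localhost:8000/generate",
+         files={"file": ("sample.pdf", f, "application/pdf")},
+         data={
+             "n_questions": 5,
+             "mode": "rag",
+             "top_k": 3,
+             "validate": "true",
+         },
+         timeout=600,  # generation can be slow on large PDFs
+     )
+ resp.raise_for_status()
+ print(resp.json()["mcqs"])
+ ```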
+ ## 6. Data Structures & Types (Conceptual)
+ - Chunk: `{ text: str, page: int, chunk_id: int, length: int }`
+ - MCQ (generated raw): `{ "câu hỏi": str, "lựa chọn": {"a": str, ...}, "đáp án": str }`
+ - Normalized MCQ (API shaping): `{ mcq: str, options: { .. }, correct: str }`
+ - Validation Entry: `{ supported_by_embeddings: bool, max_similarity: float, evidence: [ {page, score, text}... ], model_verdict?: {...} }`
+
+ ## 7. Configuration Points
+ | Parameter | Location | Purpose |
+ |-----------|----------|---------|
+ | `embedder_model` | `RAGMCQ.__init__` | Pretrained SentenceTransformer model name |
+ | `hf_model` | `RAGMCQ.__init__` | LLM model name for generation/verification |
+ | `top_k` | API form field & internal methods | Retrieval depth |
+ | `temperature` | API form field | Creativity vs determinism |
+ | `questions_per_page` | API form field | Batch size per chunk in per_page mode |
+
+ ## 8. Simple Code Improvements (Quick Wins)
+ Below are low-risk refactors to make the code cleaner and more maintainable:
+
+ 1. Environment Variable Safety:
+    ```python
+    def _require_env(name: str) -> str:
+        val = os.getenv(name)
+        if not val:
+            raise RuntimeError(f"Missing required environment variable: {name}")
+        return val
+
+    TOGETHER_KEY = _require_env("TOGETHER_KEY")
+    ```
+ 2. Remove Unused Constant: `API_URL` in `utils.py` is unused (can delete to avoid confusion).
+ 3. Unify Header Construction: Replace separate `HEADERS` / `TOGETHER_HEADERS` with a single function `auth_headers(provider)` that returns the correct dict.
+ 4. Add Dataclass for MCQ:
+    ```python
+    from dataclasses import dataclass
+
+    @dataclass
+    class MCQ:
+        question: str
+        options: Dict[str, str]
+        answer: str
+    ```
+    Helps type clarity in validation.
+ 5. Extract Prompt Templates: Store system/user template strings as module-level constants to avoid duplication and ease future edits.
+ 6. Fail-Fast on Empty PDF: Early check after extraction to return a user-friendly error message rather than a generic 500 later.
+ 7. Replace Random Query Sampling Magic Numbers: Expose `max_attempts_factor` as a parameter (currently `n_questions * 4`).
+ 8. Vector Normalization Consistency: Always keep an unnormalized copy if future scoring types are needed; currently normalization overwrites the original when FAISS is absent.
+ 9. Logging Standardization: Replace scattered `print()` with the Python `logging` module (configurable levels; avoids polluting stdout in production).
+ 10. Validation Normalization: Move `_normalize_mcqs` from `app.py` into `RAGMCQ` (keeps domain logic together; controller stays thin).
+ 11. Error Message Specificity: On generation failure, wrap exceptions with context (page/chunk), but avoid leaking internal stack traces to clients; log the full detail internally.
+ 12. Dependency Pinning: Specify versions in `requirements.txt` for reproducibility (e.g., `sentence-transformers==2.2.2`).
+ 13. Add `/models` Endpoint (Optional): Expose available embedder & generation models for UI introspection.
+ 14. Add Basic Tests: e.g., a test for `chunk_text` (ensures boundaries) and the JSON parsing fallback; a starter sketch follows this list.
+ 15. Reusable Retrieval: Expose a public `retrieve(query, top_k)` method to support future features (like user-specified queries) without duplicating private logic.
+
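+ For item 14, the JSON-parsing fallback is the cheapest behavior to pin down first. A starter sketch (assumes pytest; `chunk_text` boundary tests would follow the same pattern):
+
+ ```python
+ # test_utils.py: run with `pytest`
+ import os
+
+ # utils reads the provider key at import time, so stub it for tests
+ os.environ.setdefault("CEREBRAS_API_KEY", "test-key")
+
+ from utils import _safe_extract_json
+
+ def test_extracts_json_from_fenced_output():
+     raw = '```json\n{"1": {"đáp án": "x"}}\n```'
+     assert _safe_extract_json(raw) == {"1": {"đáp án": "x"}}
+
+ def test_trailing_comma_is_repaired():
+     raw = '{"a": 1, "b": 2,}'
+     assert _safe_extract_json(raw) == {"a": 1, "b": 2}
+ ```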
+ ## 9. Potential Medium-Term Enhancements
+ | Area | Improvement |
+ |------|-------------|
+ | Prompt Robustness | Add JSON schema validation (e.g., `jsonschema`) & auto-regeneration for malformed outputs |
+ | Performance | Embed asynchronously / stream generation if the backend supports it |
+ | Multi-Provider | Abstract a provider strategy for HuggingFace, Together, OpenAI with pluggable client classes |
+ | Caching | Cache embeddings per PDF hash to avoid reprocessing identical documents |
+ | Analytics | Track generation latency, validation pass rate, average similarity in structured logs |
+ | i18n | Parameterize language; currently prompts are in Vietnamese only |
+
+ ## 10. Security & Operational Notes
+ - Ensure `TOGETHER_KEY` is not committed; rely on environment variables / secret managers.
+ - Limit PDF size and number of pages to prevent excessive memory or token usage.
+ - Consider sanitizing extracted text (remove personally identifiable info) before sending it to the LLM if sensitive documents are used.
+ - Add request timeout & retry logic for the LLM API (the current single call may raise immediately); a sketch follows.
+
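+ One minimal shape for that retry logic, sketched as a wrapper around the existing `_post_chat` (the backoff parameters are illustrative):
+
+ ```python
+ import time
+
+ from utils import _post_chat
+
+ def post_chat_with_retry(messages, model, retries: int = 3, backoff: float = 2.0) -> str:
+     """Retry transient LLM API failures with exponential backoff."""
+     last_exc = None
+     for attempt in range(retries):
+         try:
+             return _post_chat(messages, model=model)
+         except Exception as exc:  # narrow to requests.RequestException in real code
+             last_exc = exc
+             time.sleep(backoff ** attempt)
+     raise RuntimeError(f"LLM call failed after {retries} attempts") from last_exc
+ ```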
+ ## 11. Quick Start (Local)
+ 1. Set the API key: `setx TOGETHER_KEY "your_api_key"` (then restart the shell).
+ 2. Install dependencies: `pip install -r requirements.txt`.
+ 3. Run the API: `uvicorn app:app --reload`.
+ 4. POST a PDF to `/generate`.
+
+ ## 12. Summary
+ The system cleanly separates HTTP handling from the core RAG pipeline. Text is chunked at sentence boundaries, embedded, indexed (FAISS if available), and retrieved to assemble focused contexts that guide a JSON-constrained MCQ generation prompt. Optional validation uses embedding similarity and secondary model verification to flag unsupported questions. The suggested refactors improve safety, clarity, extensibility, and readiness for multi-provider expansion.
+
+ ---
+ This report delivers architectural insight, workflow diagrams, detailed pipeline mechanics, the API contract, and actionable improvement ideas for rapid comprehension and iteration.
app/utils.py ADDED
@@ -0,0 +1,88 @@
1
+ import re
2
+ import json
3
+ from typing import Dict, Any
4
+ import requests
5
+ import os
6
+
7
+ # TODO: allow choosing a different provider later, plus dynamic routing when a token expires
8
+ API_URL = "https://api.cerebras.ai/v1/chat/completions"
9
+ CEREBRAS_API_KEY = os.environ['CEREBRAS_API_KEY']
10
+
11
+ HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}"}
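+ # greedy regex: captures the outermost {...} span in the model output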
12
+ JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
13
+
14
+ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
15
+ payload = {"model": model, "messages": messages, "temperature": temperature}
16
+ resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
17
+ resp.raise_for_status()
18
+ data = resp.json()
19
+
20
+ # handle various shapes
21
+ if "choices" in data and len(data["choices"]) > 0:
22
+ # prefer message.content
23
+ ch = data["choices"][0]
24
+
25
+ if isinstance(ch, dict) and "message" in ch and "content" in ch["message"]:
26
+ return ch["message"]["content"]
27
+
28
+ if "text" in ch:
29
+ return ch["text"]
30
+
31
+ # nothing recognizable in the response: surface the payload for debugging
32
+ raise RuntimeError("Unexpected Cerebras response shape: " + json.dumps(data)[:200])
33
+
34
+ def _safe_extract_json(text: str) -> dict:
35
+ # remove triple backticks
36
+ text = re.sub(r"```(?:json)?\n?", "", text)
37
+ m = JSON_OBJ_RE.search(text)
38
+
39
+ if not m:
40
+ raise ValueError("No JSON object found in model output.")
41
+ js = m.group(1)
42
+
43
+ # try load, fix trailing commas
44
+ try:
45
+ return json.loads(js)
46
+ except json.JSONDecodeError:
47
+ fixed = re.sub(r",\s*([}\]])", r"\1", js)
48
+ return json.loads(fixed)
49
+
50
+ def generate_mcqs_from_text(
51
+ source_text: str,
52
+ n: int = 3,
53
+ model: str = "gpt-oss-120b",
54
+ temperature: float = 0.2,
55
+ ) -> Dict[str, Any]:
56
+ system_message = {
57
+ "role": "system",
58
+ "content": (
59
+ "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
60
+ "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
61
+ "{\n"
62
+ ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
63
+ ' "2": { ... }\n'
64
+ "}\n\n"
65
+ "Lưu ý:\n"
66
+ f"- Tạo đúng {n} mục, đánh YOUR_API_KEYsố từ 1 tới {n}.\n"
67
+ "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
68
+ "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
69
+ "- Không kèm giải thích hay trường thêm.\n"
70
+ "- Các phương án sai (distractors) phải hợp lý và không lặp lại."
71
+ )
72
+ }
73
+ user_message = {
74
+ "role": "user",
75
+ "content": (
76
+ f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
77
+ "Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
78
+ f"Nội dung:\n\n{source_text}"
79
+ )
80
+ }
81
+
82
+ raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
83
+ parsed = _safe_extract_json(raw)
84
+
85
+ # validate structure and length
86
+ if not isinstance(parsed, dict) or len(parsed) != n:
87
+ raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
88
+ return parsed
generator.py ADDED
@@ -0,0 +1,696 @@
1
+ import re
2
+ import random
3
+ import numpy as np
4
+ from typing import List, Tuple, Dict, Any, Optional
5
+ from sentence_transformers import SentenceTransformer
6
+ from uuid import uuid4
7
+ import pymupdf4llm
8
+ import pymupdf as fitz
9
+
10
+ try:
11
+ from qdrant_client import QdrantClient
12
+ from qdrant_client.http.models import (
13
+ PointStruct,
14
+ Filter,
15
+ FieldCondition,
16
+ MatchValue,
17
+ Distance,
18
+ VectorParams,
19
+ )
20
+ from qdrant_client.http import models as rest
21
+ _HAS_QDRANT = True
22
+ except Exception:
23
+ _HAS_QDRANT = False
24
+
25
+ try:
26
+ import faiss
27
+ _HAS_FAISS = True
28
+ except Exception:
29
+ _HAS_FAISS = False
30
+
31
+ from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json
32
+
33
+ class RAGMCQ:
34
+ def __init__(
35
+ self,
36
+ embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
37
+ hf_model: str = "openai/gpt-oss-120b:cerebras",
38
+ qdrant_url: Optional[str] = None,
39
+ qdrant_api_key: Optional[str] = None,
40
+ qdrant_prefer_grpc: bool = False,
41
+ ):
42
+ self.embedder = SentenceTransformer(embedder_model)
43
+ self.hf_model = hf_model
44
+ self.embeddings = None # np.array of shape (N, D)
45
+ self.texts = [] # list of chunk texts
46
+ self.metadata = [] # list of dicts (page, chunk_id, char_range)
47
+ self.index = None
48
+ self.dim = self.embedder.get_sentence_embedding_dimension()
49
+
50
+ self.qdrant = None
51
+ self.qdrant_url = qdrant_url
52
+ self.qdrant_api_key = qdrant_api_key
53
+ self.qdrant_prefer_grpc = qdrant_prefer_grpc
54
+ if qdrant_url:
55
+ self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
56
+
57
+ def extract_pages(
58
+ self,
59
+ pdf_path: str,
60
+ *,
61
+ pages: Optional[List[int]] = None,
62
+ ignore_images: bool = False,
63
+ dpi: int = 150
64
+ ) -> List[str]:
65
+ doc = fitz.open(pdf_path)
66
+ try:
67
+ # request page-wise output (page_chunks=True -> list[dict] per page)
68
+ page_dicts = pymupdf4llm.to_markdown(
69
+ doc,
70
+ pages=pages,
71
+ ignore_images=ignore_images,
72
+ dpi=dpi,
73
+ page_chunks=True,
74
+ )
75
+
76
+ # to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
77
+ pages_md: List[str] = []
78
+ for p in page_dicts:
79
+ txt = p.get("text", "") or ""
80
+ pages_md.append(txt.strip())
81
+
82
+ return pages_md
83
+ finally:
84
+ doc.close()
85
+
86
+ def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]:
87
+ text = text.strip()
88
+ if not text:
89
+ return []
90
+ if len(text) <= max_chars:
91
+ return [text]
92
+
93
+ # split by sentence-like boundaries
94
+ sentences = re.split(r'(?<=[\.\?\!])\s+', text)
95
+ chunks = []
96
+ cur = ""
97
+ for s in sentences:
98
+ if len(cur) + len(s) + 1 <= max_chars:
99
+ cur += (" " if cur else "") + s
100
+ else:
101
+ if cur:
102
+ chunks.append(cur)
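+ # seed the next chunk with the tail of the previous one so context carries across boundaries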
103
+ cur = (cur[-overlap:] + " " + s) if overlap > 0 else s
104
+ if cur:
105
+ chunks.append(cur)
106
+
107
+ # if still too long, hard-split
108
+ final = []
109
+ for c in chunks:
110
+ if len(c) <= max_chars:
111
+ final.append(c)
112
+ else:
113
+ for i in range(0, len(c), max_chars):
114
+ final.append(c[i:i+max_chars])
115
+ return final
116
+
117
+ def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
118
+ pages = self.extract_pages(pdf_path)
119
+ self.texts = []
120
+ self.metadata = []
121
+
122
+ for p_idx, page_text in enumerate(pages, start=1):
123
+ chunks = self.chunk_text(page_text or "", max_chars=max_chars)
124
+ for cid, ch in enumerate(chunks, start=1):
125
+ self.texts.append(ch)
126
+ self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
127
+
128
+ if not self.texts:
129
+ raise RuntimeError("No text extracted from PDF.")
130
+
131
+ # compute embeddings
132
+ emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
133
+ self.embeddings = emb.astype("float32")
134
+ self._build_faiss_index()
135
+
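+ # HNSW parameters: M sets graph connectivity, ef_construction the build-time search
+ # depth; larger values raise recall at the cost of memory and indexing time.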
136
+ def _build_faiss_index(self, ef_construction=200, M=32):
137
+ if _HAS_FAISS:
138
+ d = self.embeddings.shape[1]
139
+ # inner-product metric + L2-normalized vectors => search scores are cosine similarities
+ index = faiss.IndexHNSWFlat(d, M, faiss.METRIC_INNER_PRODUCT)
140
+ index.hnsw.efConstruction = ef_construction # must be set before add()
141
+ faiss.normalize_L2(self.embeddings)
142
+ index.add(self.embeddings)
143
+ self.index = index
144
+ else:
145
+ # store normalized embeddings and use brute-force numpy
146
+ norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
147
+ self.embeddings = self.embeddings / norms
148
+ self.index = None
149
+
150
+ def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
151
+ q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
152
+
153
+ if _HAS_FAISS:
154
+ faiss.normalize_L2(q_emb)
155
+ D_list, I_list = self.index.search(q_emb, top_k)
156
+ # D are inner products; return list of (idx, score)
157
+ return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
158
+ else:
159
+ qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
160
+ sims = (self.embeddings @ qn.T).squeeze(axis=1)
161
+ idxs = np.argsort(-sims)[:top_k]
162
+ return [(int(i), float(sims[i])) for i in idxs]
163
+
164
+ def generate_from_pdf(
165
+ self,
166
+ pdf_path: str,
167
+ n_questions: int = 10,
168
+ mode: str = "rag", # per_page or rag
169
+ questions_per_page: int = 3, # for per_page mode
170
+ top_k: int = 3, # chunks to retrieve for each question in rag mode
171
+ temperature: float = 0.2,
172
+ ) -> Dict[str, Any]:
173
+ # build index
174
+ self.build_index_from_pdf(pdf_path)
175
+
176
+ output: Dict[str, Any] = {}
177
+ qcount = 0
178
+
179
+ if mode == "per_page":
180
+ # iterate pages -> chunks
181
+ for idx, meta in enumerate(self.metadata):
182
+ chunk_text = self.texts[idx]
183
+
184
+ if not chunk_text.strip():
185
+ continue
186
+ to_gen = questions_per_page
187
+
188
+ # ask generator
189
+ try:
190
+ mcq_block = generate_mcqs_from_text(
191
+ chunk_text, n=to_gen, model=self.hf_model, temperature=temperature
192
+ )
193
+ except Exception as e:
194
+ # skip this chunk if generator fails
195
+ print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
196
+ continue
197
+
198
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
199
+ qcount += 1
200
+ output[str(qcount)] = mcq_block[item]
201
+ if qcount >= n_questions:
202
+ return output
203
+
204
+ return output
205
+
206
+ elif mode == "rag":
207
+ # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
208
+ # create queries by sampling chunk text sentences.
209
+ # stop when n_questions reached or max_attempts exceeded.
210
+ attempts = 0
211
+ max_attempts = n_questions * 4
212
+
213
+ while qcount < n_questions and attempts < max_attempts:
214
+ attempts += 1
215
+ # create a seed query: pick a random chunk, pick a sentence from it
216
+ seed_idx = random.randrange(len(self.texts))
217
+ chunk = self.texts[seed_idx]
218
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
219
+ long_sents = [s for s in sents if len(s.strip()) > 20]
+ seed_sent = random.choice(long_sents) if long_sents else chunk[:200]
220
+ query = f"Create questions about: {seed_sent}"
221
+
222
+ # retrieve top_k chunks
223
+ retrieved = self._retrieve(query, top_k=top_k)
224
+ context_parts = []
225
+ for ridx, score in retrieved:
226
+ md = self.metadata[ridx]
227
+ context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
228
+ context = "\n\n".join(context_parts)
229
+
230
+ # call generator for 1 question (or small batch) with the retrieved context
231
+ try:
232
+ # request 1 question at a time to keep diversity
233
+ mcq_block = generate_mcqs_from_text(
234
+ context, n=1, model=self.hf_model, temperature=temperature
235
+ )
236
+ except Exception as e:
237
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
238
+ continue
239
+
240
+ # append result(s)
241
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
242
+ qcount += 1
243
+ output[str(qcount)] = mcq_block[item]
244
+ if qcount >= n_questions:
245
+ return output
246
+
247
+ return output
248
+ else:
249
+ raise ValueError("mode must be 'per_page' or 'rag'.")
250
+
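+ # Validation combines two signals: embedding similarity between "question + answer"
+ # and the indexed chunks, plus an optional LLM verdict constrained to strict JSON.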
251
+ def validate_mcqs(
252
+ self,
253
+ mcqs: Dict[str, Any],
254
+ top_k: int = 4,
255
+ similarity_threshold: float = 0.5,
256
+ evidence_score_cutoff: float = 0.5,
257
+ use_model_verification: bool = True,
258
+ model_verification_temperature: float = 0.0,
259
+ ) -> Dict[str, Any]:
260
+ if self.embeddings is None or not self.texts:
261
+ raise RuntimeError("Index/embeddings not built. Run build_index_from_pdf() first.")
262
+
263
+ report: Dict[str, Any] = {}
264
+
265
+ # helper: semantic similarity search on statement -> returns list of (idx, score)
266
+ def semantic_search(statement: str, k: int = top_k):
267
+ q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
268
+
269
+ if _HAS_FAISS:
270
+ faiss.normalize_L2(q_emb)
271
+ D_list, I_list = self.index.search(q_emb, k)
272
+ # D are inner products; return list of (idx, score)
273
+ return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
274
+ else:
275
+ qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
276
+ sims = (self.embeddings @ qn.T).squeeze(axis=1)
277
+ idxs = np.argsort(-sims)[:k]
278
+ return [(int(i), float(sims[i])) for i in idxs]
279
+
280
+ # helper: verify with model (strict JSON in response)
281
+ def _verify_with_model(question_text: str, options: Dict[str, str], correct_text: str, context_text: str):
282
+ system = {
283
+ "role": "system",
284
+ "content": (
285
+ "Bạn là một trợ lý đánh giá tính thực chứng của câu hỏi trắc nghiệm dựa trên đoạn văn được cung cấp. "
286
+ "Hãy trả lời DUY NHẤT bằng JSON hợp lệ (không có văn bản khác) theo schema:\n\n"
287
+ "{\n"
288
+ ' "supported": true/false, # câu trả lời đúng có được nội dung chứng thực không\n'
289
+ ' "confidence": 0.0-1.0, # mức độ tự tin (số)\n'
290
+ ' "evidence": "cụm văn bản ngắn làm bằng chứng hoặc trích dẫn",\n'
291
+ ' "reason": "ngắn gọn, vì sao supported hoặc không"\n'
292
+ "}\n\n"
293
+ "Luôn dựa chỉ trên nội dung trong trường 'Context' dưới đây. Nếu nội dung không chứa bằng chứng, trả về supported: false."
294
+ )
295
+ }
296
+ user = {
297
+ "role": "user",
298
+ "content": (
299
+ "Câu hỏi:\n" + question_text + "\n\n"
300
+ "Lựa chọn:\n" + "\n".join([f"{k}: {v}" for k, v in options.items()]) + "\n\n"
301
+ "Đáp án:\n" + correct_text + "\n\n"
302
+ "Context:\n" + context_text + "\n\n"
303
+ "Hãy trả lời như yêu cầu."
304
+ )
305
+ }
306
+
307
+ raw = _post_chat([system, user], model=self.hf_model, temperature=model_verification_temperature)
308
+
309
+ # parse JSON object in response
310
+ try:
311
+ parsed = _safe_extract_json(raw)
312
+ except Exception as e:
313
+ return {"error": f"Model verification failed to return JSON: {e}", "raw": raw}
314
+ return parsed
315
+
316
+ # iterate MCQs
317
+ for qid, item in mcqs.items():
318
+ q_text = item.get("câu hỏi", "").strip()
319
+ options = item.get("lựa chọn", {})
320
+ correct_text = item.get("đáp án", "").strip()
321
+
322
+ # form a short declarative statement to embed: "Question: ... Answer: <correct>"
323
+ statement = f"{q_text} Answer: {correct_text}"
324
+
325
+ retrieved = semantic_search(statement, k=top_k)
326
+ evidence_list = []
327
+ max_sim = 0.0
328
+ for idx, score in retrieved:
329
+ if score >= evidence_score_cutoff:
330
+ evidence_list.append({
331
+ "idx": idx,
332
+ "page": self.metadata[idx].get("page", None),
333
+ "score": float(score),
334
+ "text": (self.texts[idx][:1000] + ("..." if len(self.texts[idx]) > 1000 else "")),
335
+ })
336
+
337
+ if score > max_sim:
338
+ max_sim = float(score)
339
+
340
+ supported_by_embeddings = max_sim >= similarity_threshold
341
+
342
+ model_verdict = None
343
+ if use_model_verification:
344
+ # build a context string from top retrieved chunks (regardless of cutoff)
345
+ context_parts = []
346
+ for ridx, sc in retrieved:
347
+ md = self.metadata[ridx]
348
+ context_parts.append(f"[page {md.get('page')}] {self.texts[ridx]}")
349
+ context_text = "\n\n".join(context_parts)
350
+
351
+ try:
352
+ parsed = _verify_with_model(q_text, options, correct_text, context_text)
353
+ model_verdict = parsed
354
+ except Exception as e:
355
+ model_verdict = {"error": f"verification exception: {e}"}
356
+
357
+ report[qid] = {
358
+ "supported_by_embeddings": bool(supported_by_embeddings),
359
+ "max_similarity": float(max_sim),
360
+ "evidence": evidence_list,
361
+ "model_verdict": model_verdict,
362
+ }
363
+
364
+ return report
365
+
366
+ def connect_qdrant(self, url: str, api_key: Optional[str] = None, prefer_grpc: bool = False):
367
+ if not _HAS_QDRANT:
368
+ raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
369
+ self.qdrant_url = url
370
+ self.qdrant_api_key = api_key
371
+ self.qdrant_prefer_grpc = prefer_grpc
372
+ # Create client
373
+ self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
374
+
375
+ def _ensure_collection(self, collection_name: str):
376
+ if self.qdrant is None:
377
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
378
+ try:
379
+ # get_collection will raise if not present
380
+ _ = self.qdrant.get_collection(collection_name)
381
+ except Exception:
382
+ # create collection with vector size = self.dim
383
+ vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
384
+ self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
385
+ # recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
386
+
387
+ def save_pdf_to_qdrant(
388
+ self,
389
+ pdf_path: str,
390
+ filename: str,
391
+ collection: str,
392
+ max_chars: int = 1200,
393
+ batch_size: int = 64,
394
+ overwrite: bool = False,
395
+ ):
396
+ if self.qdrant is None:
397
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
398
+
399
+ # extract pages and chunks (re-using your existing helpers)
400
+ pages = self.extract_pages(pdf_path)
401
+ all_chunks = []
402
+ all_meta = []
403
+ for p_idx, page_text in enumerate(pages, start=1):
404
+ chunks = self.chunk_text(page_text or "", max_chars=max_chars)
405
+ for cid, ch in enumerate(chunks, start=1):
406
+ all_chunks.append(ch)
407
+ all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
408
+
409
+ if not all_chunks:
410
+ raise RuntimeError("No text extracted from PDF.")
411
+
412
+ # ensure collection exists
413
+ self._ensure_collection(collection)
414
+
415
+ # optional: delete previous points for this filename if overwrite
416
+ if overwrite:
417
+ # delete by filter: filename == filename
418
+ flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
419
+ try:
420
+ # qdrant-client's delete() takes a points selector, not a raw filter
421
+ self.qdrant.delete(collection_name=collection, points_selector=rest.FilterSelector(filter=flt))
422
+ except Exception:
423
+ # ignore if deletion fails
424
+ pass
425
+
426
+ # compute embeddings in batches
427
+ embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
428
+ embeddings = embeddings.astype("float32")
429
+
430
+ # prepare points
431
+ points = []
432
+ for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
433
+ pid = str(uuid4())
434
+ source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
435
+ payload = {
436
+ "filename": filename,
437
+ "page": md["page"],
438
+ "chunk_id": md["chunk_id"],
439
+ "length": md["length"],
440
+ "text": txt,
441
+ "source_id": source_id,
442
+ }
443
+ points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
444
+
445
+ # upsert in batches
446
+ if len(points) >= batch_size:
447
+ self.qdrant.upsert(collection_name=collection, points=points)
448
+ points = []
449
+
450
+ # upsert remaining
451
+ if points:
452
+ self.qdrant.upsert(collection_name=collection, points=points)
453
+
454
+ try:
455
+ self.qdrant.create_payload_index(
456
+ collection_name=collection,
457
+ field_name="filename",
458
+ field_schema=rest.PayloadSchemaType.KEYWORD
459
+ )
460
+ except Exception as e:
461
+ print(f"Index creation skipped or failed: {e}")
462
+
463
+ return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
464
+
465
+ def list_files_in_collection(
466
+ self,
467
+ collection: str,
468
+ payload_field: str = "filename",
469
+ batch_size: int = 500,
470
+ ) -> List[str]:
471
+ if self.qdrant is None:
472
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
473
+
474
+ # ensure collection exists
475
+ # collection_exists may itself raise if the server is unreachable; let that propagate
476
+ if not self.qdrant.collection_exists(collection):
477
+ raise RuntimeError(f"Collection '{collection}' does not exist.")
481
+
482
+ filenames = set()
483
+ offset = None
484
+
485
+ while True:
486
+ # scroll returns (points, next_offset)
487
+ pts, next_offset = self.qdrant.scroll(
488
+ collection_name=collection,
489
+ limit=batch_size,
490
+ offset=offset,
491
+ with_payload=[payload_field],
492
+ with_vectors=False,
493
+ )
494
+
495
+ if not pts:
496
+ break
497
+
498
+ for p in pts:
499
+ # p may be a dict-like or an object with .payload
500
+ payload = None
501
+ if hasattr(p, "payload"):
502
+ payload = p.payload
503
+ elif isinstance(p, dict):
504
+ # older/newer variants might use nested structures: try common keys
505
+ payload = p.get("payload") or p
506
+ else:
507
+ # best-effort fallback: convert to dict if possible
508
+ try:
509
+ payload = dict(p)
510
+ except Exception:
511
+ payload = None
512
+
513
+ if not payload:
514
+ continue
515
+
516
+ # extract candidate value(s)
517
+ val = None
518
+ if isinstance(payload, dict):
519
+ val = payload.get(payload_field)
520
+ else:
521
+ # Some payload representations store fields differently; try attribute access
522
+ val = getattr(payload, payload_field, None)
523
+
524
+ # If value is list-like, iterate, else add single
525
+ if isinstance(val, (list, tuple, set)):
526
+ for v in val:
527
+ if v is not None:
528
+ filenames.add(str(v))
529
+ elif val is not None:
530
+ filenames.add(str(val))
531
+
532
+ # stop if no more pages
533
+ if not next_offset:
534
+ break
535
+ offset = next_offset
536
+
537
+ return sorted(filenames)
538
+
539
+ def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
540
+ if self.qdrant is None:
541
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
542
+
543
+ results = []
544
+ offset = None
545
+ while True:
546
+ # scroll returns (points, next_offset)
547
+ points, next_offset = self.qdrant.scroll(
548
+ collection_name=collection,
549
+ scroll_filter=Filter(
550
+ must=[
551
+ FieldCondition(key="filename", match=MatchValue(value=filename))
552
+ ]
553
+ ),
554
+ limit=batch,
555
+ offset=offset,
556
+ with_payload=True,
557
+ with_vectors=False,
558
+ )
559
+ # points are objects (Record / ScoredPoint-like); get id and payload
560
+ for p in points:
561
+ # p.payload is a dict, p.id is point id
562
+ results.append({"point_id": p.id, "payload": p.payload})
563
+ if not next_offset:
564
+ break
565
+ offset = next_offset
566
+ return results
567
+
568
+ def _retrieve_qdrant(self, query: str, collection: str, filename: Optional[str] = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
569
+ if self.qdrant is None:
570
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
571
+
572
+ q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
573
+ q_filter = None
574
+ if filename:
575
+ q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
576
+
577
+ search_res = self.qdrant.search(
578
+ collection_name=collection,
579
+ query_vector=q_emb,
580
+ query_filter=q_filter,
581
+ limit=top_k,
582
+ with_payload=True,
583
+ with_vectors=False,
584
+ )
585
+
586
+ out = []
587
+ for hit in search_res:
588
+ # hit.payload is the stored payload, hit.score is similarity
589
+ out.append((hit.payload, float(getattr(hit, "score", 0.0))))
590
+ return out
591
+
592
+ def generate_from_qdrant(
593
+ self,
594
+ filename: str,
595
+ collection: str,
596
+ n_questions: int = 10,
597
+ mode: str = "rag", # 'per_chunk' or 'rag'
598
+ questions_per_chunk: int = 3, # used for 'per_chunk'
599
+ top_k: int = 3, # retrieval size used in RAG
600
+ temperature: float = 0.2,
601
+ ) -> Dict[str, Any]:
602
+ if self.qdrant is None:
603
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
604
+
605
+ # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
606
+ file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
607
+ if not file_points:
608
+ raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
609
+
610
+ # create a local list of texts & metadata for sampling
611
+ texts = []
612
+ metas = []
613
+ for p in file_points:
614
+ payload = p.get("payload", {})
615
+ text = payload.get("text", "")
616
+ texts.append(text)
617
+ metas.append(payload)
618
+
619
+ self.texts = texts
620
+ self.metadata = metas
621
+ embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
622
+ if embeddings is None or len(embeddings) == 0:
623
+ # fail fast: the downstream dim lookup and index building would crash on None
624
+ raise RuntimeError("No embeddings computed for the retrieved chunks.")
625
+ self.embeddings = embeddings.astype("float32")
627
+
628
+ # update dim in case embedder changed unexpectedly
629
+ self.dim = int(self.embeddings.shape[1])
630
+
631
+ # build index
632
+ self._build_faiss_index()
633
+
634
+ output = {}
635
+ qcount = 0
636
+
637
+ if mode == "per_chunk":
638
+ # iterate all chunks (in payload order) and request questions_per_chunk from each
639
+ for i, txt in enumerate(texts):
640
+ if not txt.strip():
641
+ continue
642
+ to_gen = questions_per_chunk
643
+ try:
644
+ mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.hf_model, temperature=temperature)
645
+ except Exception as e:
646
+ print(f"Generator failed on chunk (index {i}): {e}")
647
+ continue
648
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
649
+ qcount += 1
650
+ output[str(qcount)] = mcq_block[item]
651
+ if qcount >= n_questions:
652
+ return output
653
+ return output
654
+
655
+ elif mode == "rag":
656
+ attempts = 0
657
+ max_attempts = n_questions * 4
658
+ while qcount < n_questions and attempts < max_attempts:
659
+ attempts += 1
660
+ # sample a seed sentence from a random chunk of this file
661
+ seed_idx = random.randrange(len(texts))
662
+ chunk = texts[seed_idx]
663
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
664
+ seed_sent = None
665
+ for s in sents:
666
+ if len(s.strip()) > 20:
667
+ seed_sent = s
668
+ break
669
+ if not seed_sent:
670
+ seed_sent = chunk[:200]
671
+ query = f"Create questions about: {seed_sent}"
672
+
673
+ # retrieve top_k chunks from the same file (restricted by filename filter)
674
+ retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
675
+ context_parts = []
676
+ for payload, score in retrieved:
677
+ # payload should contain page & chunk_id and text
678
+ page = payload.get("page", "?")
679
+ ctxt = payload.get("text", "")
680
+ context_parts.append(f"[page {page}] {ctxt}")
681
+ context = "\n\n".join(context_parts)
682
+
683
+ try:
684
+ mcq_block = generate_mcqs_from_text(context, n=1, model=self.hf_model, temperature=temperature)
685
+ except Exception as e:
686
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
687
+ continue
688
+
689
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
690
+ qcount += 1
691
+ output[str(qcount)] = mcq_block[item]
692
+ if qcount >= n_questions:
693
+ return output
694
+ return output
695
+ else:
696
+ raise ValueError("mode must be 'per_chunk' or 'rag'.")
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ boto3
2
+ pdfplumber
3
+ faiss-cpu
4
+ sentence-transformers
5
+ fastapi[standard]
6
+ uvicorn[standard]
7
+ qdrant-client
8
+ pymupdf4llm
9
+ requests
10
+ numpy
test/cerebras-api.py ADDED
@@ -0,0 +1,48 @@
1
+ # import os
2
+ # from cerebras.cloud.sdk import Cerebras
3
+ import tiktoken
4
+
5
+ # client = Cerebras(
6
+ # # This is the default and can be omitted
7
+ # api_key=os.environ.get("CEREBRAS_API_KEY")
8
+ # )
9
+
10
+ # stream = client.chat.completions.create(
11
+ # messages=[
12
+ # {
13
+ # "role": "system",
14
+ # "content": ""
15
+ # }
16
+ # ],
17
+ # model="gpt-oss-120b",
18
+ # stream=True,
19
+ # max_completion_tokens=65536,
20
+ # temperature=1,
21
+ # top_p=1
22
+ # )
23
+ import numpy as np
24
+
25
+ INPUT_TOKEN_COUNT = np.array([], dtype=int)
26
+ OUTPUT_TOKEN_COUNT = np.array([], dtype=int)
27
+
28
+ # for chunk in stream:
29
+ # print(chunk.choices[0].delta.content or "", end="")
30
+ with open('../test/mcq_output.json', 'r', encoding='utf-8') as f:
31
+ text = f.read()
32
+
33
+ def count_tokens(text: str, model_name='gpt-oss-120b', encoding_name='cl100k_base') -> int:
34
+ """Look up model encoding; fallback to encoding_name if model not known."""
35
+ try:
36
+ # encoding_for_model can raise if model is unknown to tiktoken
37
+ enc = tiktoken.encoding_for_model(model_name)
38
+ except Exception:
39
+ enc = None
40
+
41
+ if enc is None:
42
+ enc = tiktoken.get_encoding(encoding_name)
43
+
44
+ return len(enc.encode(text))
45
+
46
+ c = count_tokens(text)
47
+ INPUT_TOKEN_COUNT = np.append(INPUT_TOKEN_COUNT, c)
48
+ print(INPUT_TOKEN_COUNT)
test/logging.txt ADDED
@@ -0,0 +1,111 @@
1
+ **********************
2
+ Windows PowerShell transcript start
3
+ Start time: 20250815161919
4
+ Username: MACBOOKM5\boboi
5
+ RunAs User: MACBOOKM5\boboi
6
+ Configuration Name:
7
+ Machine: MACBOOKM5 (Microsoft Windows NT 10.0.22631.0)
8
+ Host Application: C:\WINDOWS\System32\WindowsPowerShell\v1.0\powershell.exe
9
+ Process ID: 20936
10
+ PSVersion: 5.1.22621.5697
11
+ PSEdition: Desktop
12
+ PSCompatibleVersions: 1.0, 2.0, 3.0, 4.0, 5.0, 5.1.22621.5697
13
+ BuildVersion: 10.0.22621.5697
14
+ CLRVersion: 4.0.30319.42000
15
+ WSManStackVersion: 3.0
16
+ PSRemotingProtocolVersion: 2.3
17
+ SerializationVersion: 1.1.0.1
18
+ **********************
19
+ Transcript started, output file is D:\graduation_project\mcq-generator\test\logging.txt
20
+ PS D:\graduation_project\mcq-generator\app>
21
+ (rag-api) uvicorn app:app --reload
22
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
23
+ ne 608, in run_forever
24
+ self._run_once()
25
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
26
+ ne 1936, in _run_once
27
+ handle._run()
28
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\events.py", line 84
29
+ , in _run
30
+ self._context.run(self._callback, *self._args)
31
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\serve
32
+ r.py", line 70, in serve
33
+ with self.capture_signals():
34
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\contextlib.py", line 144, i
35
+ n __exit__
36
+ next(self.gen)
37
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\serve
38
+ r.py", line 331, in capture_signals
39
+ signal.raise_signal(captured_signal)
40
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
41
+ 57, in _on_sigint
42
+ raise KeyboardInterrupt()
43
+ KeyboardInterrupt
44
+
45
+ During handling of the above exception, another exception occurred:
46
+
47
+ Traceback (most recent call last):
48
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\starlette\rou
49
+ ting.py", line 701, in lifespan
50
+ await receive()
51
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\lifes
52
+ pan\on.py", line 137, in receive
53
+ return await self.receive_queue.get()
54
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\queues.py", line 15
56
+ 8, in get
57
+ await getter
58
+ asyncio.exceptions.CancelledError
59
+
60
+ INFO: Stopping reloader process [21928]
61
+ (rag-api) TerminatingError(): "The pipeline has been stopped."
62
+ >> TerminatingError(): "The pipeline has been stopped."
63
+ PS D:\graduation_project\mcq-generator\app>
64
+ (rag-api) uvicorn app:app --reload
65
+ warnings.warn(
66
+ INFO: Started server process [20356]
67
+ INFO: Waiting for application startup.
68
+ RAGMCQ instance created on startup.
69
+ INFO: Application startup complete.
70
+ ERROR: Traceback (most recent call last):
71
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
72
+ 18, in run
73
+ return self._loop.run_until_complete(task)
74
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
75
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\base_events.py", li
76
+ ne 654, in run_until_complete
77
+ return future.result()
78
+ ^^^^^^^^^^^^^^^
79
+ asyncio.exceptions.CancelledError
80
+
81
+ During handling of the above exception, another exception occurred:
82
+
83
+ Traceback (most recent call last):
84
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
85
+ 90, in run
86
+ return runner.run(main)
87
+ ^^^^^^^^^^^^^^^^
88
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\runners.py", line 1
89
+ 23, in run
90
+ raise KeyboardInterrupt()
91
+ KeyboardInterrupt
92
+
93
+ During handling of the above exception, another exception occurred:
94
+
95
+ Traceback (most recent call last):
96
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\starlette\rou
97
+ ting.py", line 701, in lifespan
98
+ await receive()
99
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\site-packages\uvicorn\lifes
100
+ pan\on.py", line 137, in receive
101
+ return await self.receive_queue.get()
102
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
103
+ File "D:\CODE\IDE\Anaconda\envs\rag-api\Lib\asyncio\queues.py", line 15
104
+ 8, in get
105
+ await getter
106
+ asyncio.exceptions.CancelledError
107
+
108
+ INFO: Stopping reloader process [1968]
109
+ (rag-api) TerminatingError(): "The pipeline has been stopped."
110
+ >> TerminatingError(): "The pipeline has been stopped."
111
+ PS D:\graduation_project\mcq-generator\app>
test/mcq_output.json ADDED
@@ -0,0 +1,163 @@
1
+ {
2
+ "mcqs": {
3
+ "1": {
4
+ "câu hỏi": "Trong lớp Str_OutputParser, biểu thức chính quy nào được sử dụng để trích xuất câu trả lời từ chuỗi phản hồi?",
5
+ "lựa chọn": {
6
+ "a": "r\"Answer:\\s*(.*)\"",
7
+ "b": "r\"Respuesta:\\s*(.*)\"",
8
+ "c": "r\"Answer :\\s*(.*)\"",
9
+ "d": "r\"Result:\\s*(.*)\""
10
+ },
11
+ "đáp án": "r\"Answer :\\s*(.*)\""
12
+ },
13
+ "2": {
14
+ "câu hỏi": "Trong dự án RAG, file nào được dùng để khai báo các hàm load file PDF?",
15
+ "lựa chọn": {
16
+ "a": "src/rag/main.py",
17
+ "b": "src/rag/file_loader.py",
18
+ "c": "src/rag/offline_rag.py",
19
+ "d": "src/rag/utils.py"
20
+ },
21
+ "đáp án": "src/rag/file_loader.py"
22
+ },
23
+ "3": {
24
+ "câu hỏi": "Trong file src/rag/vectorstore.py, lớp nào được đặt làm giá trị mặc định cho vector database?",
25
+ "lựa chọn": {
26
+ "a": "FAISS",
27
+ "b": "Chroma",
28
+ "c": "Pinecone",
29
+ "d": "Milvus"
30
+ },
31
+ "đáp án": "Chroma"
32
+ },
33
+ "4": {
34
+ "câu hỏi": "Trong đoạn mã được trích dẫn, tham số nào được sử dụng cho kiểu lượng tử (quantization type) trong cấu hình BitsAndBytesConfig?",
35
+ "lựa chọn": {
36
+ "a": "nf4",
37
+ "b": "int8",
38
+ "c": "fp16",
39
+ "d": "int4"
40
+ },
41
+ "đáp án": "nf4"
42
+ },
43
+ "5": {
44
+ "câu hỏi": "Theo mô tả trong nội dung, bước nào liên quan đến việc tạo cơ sở dữ liệu vector bằng mô hình embedding?",
45
+ "lựa chọn": {
46
+ "a": "Tách danh sách các bài báo khoa học thành các văn bản nhỏ.",
47
+ "b": "Xây dựng một cơ sở dữ liệu vector từ các văn bản nhỏ bằng mô hình embedding.",
48
+ "c": "Truy vấn các mẫu văn bản có liên quan đến câu hỏi đầu vào để làm ngữ cảnh.",
49
+ "d": "Đưa câu prompt (câu hỏi và ngữ cảnh) vào mô hình để nhận câu trả lời."
50
+ },
51
+ "đáp án": "Xây dựng một cơ sở dữ liệu vector từ các văn bản nhỏ bằng mô hình embedding."
52
+ }
53
+ },
54
+ "validation": {
55
+ "1": {
56
+ "supported_by_embeddings": true,
57
+ "max_similarity": 0.5152225494384766,
58
+ "evidence": [
59
+ {
60
+ "idx": 26,
61
+ "page": 15,
62
+ "score": 0.5152225494384766,
63
+ "text": "Ý nghĩa của phương thức `from_template()` trong class PromptTemplate là? ( _a_ ) Đểkhởi tạo prompt template từmột file. ( _b_ ) Đểkhởi tạo prompt template từmột string. ( _c_ ) Đểkhởi tạo prompt template từmột danh sách các tin nhắn. ( _d_ ) Đểkhởi tạo prompt template từmột prompt template có sẵn. 15"
64
+ }
65
+ ],
66
+ "model_verdict": {
67
+ "supported": false,
68
+ "confidence": 0.9,
69
+ "evidence": "",
70
+ "reason": "Context không chứa thông tin về lớp Str_OutputParser hay biểu thức chính quy được sử dụng, vì vậy không thể chứng thực đáp án được đưa ra."
71
+ }
72
+ },
73
+ "2": {
74
+ "supported_by_embeddings": true,
75
+ "max_similarity": 0.694902777671814,
76
+ "evidence": [
77
+ {
78
+ "idx": 4,
79
+ "page": 4,
80
+ "score": 0.694902777671814,
81
+ "text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n_ **src/rag/:** Thư mục dùng đểlưu trữcác code liên quan đến xây dựng RAG, bao gồm:\n\n\n1. **src/rag/file_loader.py:** File code dùng đểkhai báo các hàm load file pdf (vì tài\nliệu của chúng ta thu thập thuộc file pdf). 2. **src/rag/main.py:** File code dùng đểkhai báo hàm khởi tạo chains. 3. **src/rag/offline_rag.py:** File code dùng đểkhai báo PromptTemplate. 4. **src/rag/utils.py:** File code dùng đểkhai báo hàm tách câu trảlời từmodel. 5. **src/rag/vectorstore.py:** File code dùng đểkhai báo hàm khởi tạo hệcơ sởdữliệu\n\nvector. _ **src/app.py:** File code dùng đểkhởi tạo API. _ **requirements.txt:** File code dùng đểkhai báo các thư viện cần thiết đểsửdụng source\ncode. ## II.2. Cập nhật file requirements.txt\n\n\nĐểbắt đầu, chúng ta sẽliệt kê các gói thư viện cần thiết đểchạy được chương trình này."
82
+ },
83
+ {
84
+ "idx": 28,
85
+ "page": 16,
86
+ "score": 0.5763600468635559,
87
+ "text": "document_loaders` `import` `PyPDFLoader`\n\n\n2\n\n\n3 `pdf_loader = PyPDFLoader(url, extract_images =` `True` `)`\n\n\n4\n\n\n5 `docs = pdf_loader.load ()`\n\n\nTham số `extract_images` tại dòng code 3 có chức năng gì? ( _a_ ) Trảvềtất cảảnh từfile pdf. ( _b_ ) Bỏqua ảnh, chỉload text. ( _c_ ) Phân tích ảnh thành vector. ( _d_ ) Chuyển đổi ảnh trong file pdf thành text. 16"
88
+ },
89
+ {
90
+ "idx": 16,
91
+ "page": 9,
92
+ "score": 0.5420067310333252,
93
+ "text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n86 `return` `self.load(files, workers=workers)`\n\n## II.6. Cập nhật file src/rag/vectorstore.py\n\n\nTại file này, ta định nghĩa một class đểkhởi tạo hệcơ sởdữliệu vector. Trong project này, chúng\nta sẽsửdụng Chroma. Vềviệc tìm kiếm tài liệu tương đồng, ta sửdụng FAISS. Như vậy, nội\ndung của file như sau:\n\n\nHình 4: Minh họa việc sửdụng vector database Chroma đểtruy vấn các tài liệu có liên quan\n[làm context trong prompt. Ảnh: Link.](https://heidloff.net/article/retrieval-augmented-generation-chroma-langchain/)\n\n\n1 `from` `typing` `import` `Union`\n\n2 `from` `langchain_chroma` `import` `Chroma`\n\n3 `from` `langchain_community .vectorstores` `import` `FAISS`\n\n4 `from` `langchain_community .embeddings` `import` `HuggingFaceEmbeddings`\n\n\n5\n\n\n6 `class` `VectorDB:`\n\n\n7 `def` `__init__(self,`\n\n\n8 `documents = None,`\n\n9 `vector_db: Union[Chroma, FAISS] = Chroma,`\n\n10 `embedding = HuggingFaceEmbeddings (),`\n\n11 `) -> None` `:`\n\n\n12\n\n\n13 `self.vector_db ..."
94
+ }
95
+ ],
96
+ "model_verdict": {
97
+ "supported": true,
98
+ "confidence": 0.99,
99
+ "evidence": "src/rag/file_loader.py: File code dùng để khai báo các hàm load file pdf",
100
+ "reason": "Context explicitly states that src/rag/file_loader.py declares functions for loading PDF files, matching the answer."
101
+ }
102
+ },
103
+ "3": {
104
+ "supported_by_embeddings": true,
105
+ "max_similarity": 0.579485297203064,
106
+ "evidence": [
107
+ {
108
+ "idx": 20,
109
+ "page": 11,
110
+ "score": 0.579485297203064,
111
+ "text": "Cập nhật file src/rag/main.py\n\n\nTại file này, ta khởi tạo toàn bộcác instance của các class, các hàm mà ta đã khai báo trước đó\nvà kết nối chúng vào trong một hàm duy nhất gọi là `build_rag_chain()` :\n\n\n1 `from` `pydantic` `import` `BaseModel, Field`\n\n\n2\n\n\n3 `from src.rag.file_loader` `import` `Loader`\n\n4 `from src.rag.vectorstore` `import` `VectorDB`\n\n5 `from src.rag.offline_rag` `import` `Offline_RAG`\n\n\n6\n\n\n7 `class` `InputQA(BaseModel):`\n\n8 `question: str = Field (..., title=` `\"Question to ask the model\"` `)`\n\n\n9\n\n\n10 `class` `OutputQA(BaseModel):`\n\n11 `answer: str = Field (..., title=` `\"Answer` `from the model\"` `)`\n\n\n12\n\n\n13 `def` `build_rag_chain (llm, data_dir, data_type):`\n\n14 `doc_loaded = Loader(file_type=data_type).load_dir(data_dir, workers=2)`\n\n15 `retriever = VectorDB(documents = doc_loaded).get_retriever ()`\n\n16 `rag_chain = Offline_RAG(llm).get_chain(retriever)`\n\n\n17\n\n\n18 `return` `rag_chain`\n\n\n11"
112
+ },
113
+ {
114
+ "idx": 16,
115
+ "page": 9,
116
+ "score": 0.5778905749320984,
117
+ "text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n86 `return` `self.load(files, workers=workers)`\n\n## II.6. Cập nhật file src/rag/vectorstore.py\n\n\nTại file này, ta định nghĩa một class đểkhởi tạo hệcơ sởdữliệu vector. Trong project này, chúng\nta sẽsửdụng Chroma. Vềviệc tìm kiếm tài liệu tương đồng, ta sửdụng FAISS. Như vậy, nội\ndung của file như sau:\n\n\nHình 4: Minh họa việc sửdụng vector database Chroma đểtruy vấn các tài liệu có liên quan\n[làm context trong prompt. Ảnh: Link.](https://heidloff.net/article/retrieval-augmented-generation-chroma-langchain/)\n\n\n1 `from` `typing` `import` `Union`\n\n2 `from` `langchain_chroma` `import` `Chroma`\n\n3 `from` `langchain_community .vectorstores` `import` `FAISS`\n\n4 `from` `langchain_community .embeddings` `import` `HuggingFaceEmbeddings`\n\n\n5\n\n\n6 `class` `VectorDB:`\n\n\n7 `def` `__init__(self,`\n\n\n8 `documents = None,`\n\n9 `vector_db: Union[Chroma, FAISS] = Chroma,`\n\n10 `embedding = HuggingFaceEmbeddings (),`\n\n11 `) -> None` `:`\n\n\n12\n\n\n13 `self.vector_db ..."
118
+ }
119
+ ],
120
+ "model_verdict": {
121
+ "supported": true,
122
+ "confidence": 1.0,
123
+ "evidence": "vector_db: Union[Chroma, FAISS] = Chroma",
124
+ "reason": "Mặc định của tham số vector_db trong class VectorDB được đặt là Chroma"
125
+ }
126
+ },
127
+ "4": {
128
+ "supported_by_embeddings": false,
129
+ "max_similarity": 0.43995893001556396,
130
+ "evidence": [],
131
+ "model_verdict": {
132
+ "supported": false,
133
+ "confidence": 0.95,
134
+ "evidence": "",
135
+ "reason": "Trong nội dung Context không có bất kỳ đoạn nào đề cập đến BitsAndBytesConfig hay tham số kiểu lượng tử, vì vậy không thể chứng thực đáp án nf4."
136
+ }
137
+ },
138
+ "5": {
139
+ "supported_by_embeddings": true,
140
+ "max_similarity": 0.6268875598907471,
141
+ "evidence": [
142
+ {
143
+ "idx": 1,
144
+ "page": 2,
145
+ "score": 0.6268875598907471,
146
+ "text": "**AI VIETNAM (AIO2024)** **aivietnam.edu.vn**\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\nHình 2: Tổng quan vềpipeline của project.\n\n\n**Theo đó:**\n\n\n1. Từdanh sách các bài báo khoa học, ta tách thành các văn bản nhỏ. Từđó, xây dựng một\nhệcơ sởdữliệu vector với một embedding model.\n\n\n2. Bên cạnh câu hỏi đầu vào (question), ta truy vấn các mẫu văn bản có liên quan đến đến\ncâu hỏi, dùng làm ngữcảnh (context) trong câu prompt. Đây là nguồn thông tin mà LLMs\ncó thểdựa vào đểtrảlời câu hỏi.\n\n\n3. Đưa câu prompt vào mô hình (question và context) đểnhận câu trảlời từmô hình.\n\n\n2"
147
+ },
148
+ {
149
+ "idx": 30,
150
+ "page": 17,
151
+ "score": 0.5708718299865723,
152
+ "text": "split_documents (pdf_pages)`\n\n\n18\n\n\n19 _`# Embedding`_ _`model`_\n\n20 `embedding_model = HuggingFaceEmbeddings ()`\n\n\n21\n\n\n22 _`# vector`_ _`store`_\n\n\n23 `chroma_db = Chroma.from_documents(docs, embedding= embedding_model )`\n\n\nNhiệm vụcủa `embedding_model` là gì? ( _a_ ) Dùng biến đổi chuỗi đầu vào thành các vector cho cơ sởdữliệu vector. ( _b_ ) Dùng đểlập chỉmục cho cơ sởdữliệu. ( _c_ ) Dùng đểtìm kiếm tài liệu. ( _d_ ) Dùng đểtính toán độtương đồng. 17"
153
+ }
154
+ ],
155
+ "model_verdict": {
156
+ "supported": true,
157
+ "confidence": 0.99,
158
+ "evidence": "1. Từ danh sách các bài báo khoa học, ta tách thành các văn bản nhỏ. Từ đó, xây dựng một hẹcơ sở dữ liệu vector với một embedding model.",
159
+ "reason": "Context explicitly states that after splitting documents, a vector database is built using an embedding model, matching the chosen answer."
160
+ }
161
+ }
162
+ }
163
+ }
test/output.json ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import re
2
+ import json
3
+ from typing import Dict, Any
4
+ import requests
5
+ import os
6
+
7
+ API_URL = "https://router.huggingface.co/v1/chat/completions"
8
+ HF_KEY = os.environ['HF_API_KEY']
9
+ HEADERS = {"Authorization": f"Bearer {HF_KEY}"}
10
+ JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
11
+
12
+ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
13
+ payload = {"model": model, "messages": messages, "temperature": temperature}
14
+ resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
15
+ resp.raise_for_status()
16
+ data = resp.json()
17
+
18
+ # handle various shapes
19
+ if "choices" in data and len(data["choices"]) > 0:
20
+ # prefer message.content
21
+ ch = data["choices"][0]
22
+
23
+ if isinstance(ch, dict) and "message" in ch and "content" in ch["message"]:
24
+ return ch["message"]["content"]
25
+
26
+ if "text" in ch:
27
+ return ch["text"]
28
+
29
+ # nothing recognizable in the response
30
+ raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
31
+
32
+ def _safe_extract_json(text: str) -> dict:
33
+ # remove triple backticks
34
+ text = re.sub(r"```(?:json)?\n?", "", text)
35
+ m = JSON_OBJ_RE.search(text)
36
+
37
+ if not m:
38
+ raise ValueError("No JSON object found in model output.")
39
+ js = m.group(1)
40
+
41
+ # try load, fix trailing commas
42
+ try:
43
+ return json.loads(js)
44
+ except json.JSONDecodeError:
45
+ fixed = re.sub(r",\s*([}\]])", r"\1", js)
46
+ return json.loads(fixed)
47
+
48
+ def generate_mcqs_from_text(
49
+ source_text: str,
50
+ n: int = 3,
51
+ model: str = "openai/gpt-oss-120b:cerebras",
52
+ temperature: float = 0.2,
53
+ ) -> Dict[str, Any]:
54
+ system_message = {
55
+ "role": "system",
56
+ "content": (
57
+ "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
58
+ "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
59
+ "{\n"
60
+ ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
61
+ ' "2": { ... }\n'
62
+ "}\n\n"
63
+ "Lưu ý:\n"
64
+ f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
65
+ "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
66
+ "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
67
+ "- Không kèm giải thích hay trường thêm.\n"
68
+ "- Các phương án sai (distractors) phải hợp lý và không lặp lại."
69
+ )
70
+ }
71
+ user_message = {
72
+ "role": "user",
73
+ "content": (
74
+ f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
75
+ "Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
76
+ f"Nội dung:\n\n{source_text}"
77
+ )
78
+ }
79
+
80
+ raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
81
+ parsed = _safe_extract_json(raw)
82
+
83
+ # validate structure and length
84
+ if not isinstance(parsed, dict) or len(parsed) != n:
85
+ raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
86
+ return parsed