namberino commited on
Commit
a8b213d
·
1 Parent(s): d3f7049

Literally everything

Browse files
Files changed (6) hide show
  1. Dockerfile +32 -0
  2. README.md +8 -3
  3. app.py +259 -0
  4. generator.py +1125 -0
  5. requirements.txt +10 -0
  6. utils.py +408 -0
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # set HF cache to /tmp for writable FS on Spaces
4
+ ENV HF_HOME=/tmp/huggingface
5
+ ENV TOKENIZERS_PARALLELISM=false
6
+
7
+ # install system packages needed by some python libs
8
+ RUN apt-get update && apt-get install -y --no-install-recommends \
9
+ build-essential \
10
+ git \
11
+ wget \
12
+ libsndfile1 \
13
+ libgl1 \
14
+ libglib2.0-0 \
15
+ poppler-utils \
16
+ && rm -rf /var/lib/apt/lists/*
17
+
18
+ WORKDIR /app
19
+
20
+ # copy requirements and install
21
+ COPY requirements.txt /app/requirements.txt
22
+ RUN pip install --upgrade pip
23
+ # try to be robust to wheels/build issues
24
+ # RUN pip wheel --no-cache-dir --wheel-dir=/wheels -r /app/requirements.txt || true
25
+ RUN pip install --no-cache-dir -r /app/requirements.txt
26
+
27
+ # copy app code
28
+ COPY . /app
29
+
30
+ EXPOSE 7860
31
+
32
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,15 @@
1
  ---
2
- title: Test Generator
3
- emoji: 👀
4
- colorFrom: blue
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
1
  ---
2
+ title: Mcq Generator
3
+ emoji: 📚
4
+ colorFrom: purple
5
  colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
11
+
12
+ ---
13
+ TODO:
14
+ + Apply Cohen's Kappa to measure rate of aggreement between human and AI.
15
+ + Improve function transparency by adding Documents
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ from typing import List, Optional, Union
5
+
6
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+
10
+ # Import the user's RAGMCQ implementation
11
+ from generator import RAGMCQ
12
+ from utils import log_pipeline
13
+
14
+ app = FastAPI(title="RAG MCQ Generator API")
15
+
16
+ # allow cross-origin requests (adjust in production)
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"],
20
+ allow_credentials=True,
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
+ )
24
+
25
+ # global rag instance
26
+ rag: Optional[RAGMCQ] = None
27
+
28
+ class GenerateResponse(BaseModel):
29
+ mcqs: dict
30
+ validation: Optional[dict] = None
31
+
32
+ class ListResponse(BaseModel):
33
+ files: list
34
+
35
+ @app.on_event("startup")
36
+ def startup_event():
37
+ global rag
38
+
39
+ # instantiate the heavy object once
40
+ rag = RAGMCQ()
41
+ print("RAGMCQ instance created on startup.")
42
+
43
+ @app.get("/health")
44
+ def health():
45
+ return {"status": "ok", "ready": rag is not None}
46
+
47
+ def _save_upload_to_temp(upload: UploadFile) -> str:
48
+ suffix = ".pdf"
49
+ fd, path = tempfile.mkstemp(suffix=suffix)
50
+ os.close(fd)
51
+ with open(path, "wb") as out_file:
52
+ shutil.copyfileobj(upload.file, out_file)
53
+ return path
54
+
55
+
56
+ @app.get("/list_collection_files", response_model=ListResponse)
57
+ async def list_collection_files_endpoint(
58
+ collection_name: str = "programming"
59
+ ):
60
+ global rag
61
+ if rag is None:
62
+ raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
63
+
64
+ files = rag.list_files_in_collection(collection_name)
65
+
66
+ return {"files": files}
67
+
68
+
69
+ @app.post("/upload_multiple_files", response_model=ListResponse)
70
+ async def upload_multiple_files(
71
+ background_tasks: BackgroundTasks,
72
+ files: List[UploadFile] = File(...), # get multiple files
73
+ collection_name: str = Form("programming"),
74
+ overwrite: bool = Form(True),
75
+ qdrant_filename_prefix: Optional[str] = Form(None),
76
+ ):
77
+ """
78
+ Upload multiple PDF files and save their chunks to Qdrant.
79
+ - files: one or more PDF files (multipart/form-data, repeated 'files' fields)
80
+ - collection_name: Qdrant collection to save into
81
+ - overwrite: if true, existing points for each filename will be removed
82
+ - qdrant_filename_prefix: optional prefix; if provided each file will be saved under "<prefix>_<original_filename>"
83
+ """
84
+ global rag
85
+ if rag is None:
86
+ raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
87
+
88
+ saved_files = []
89
+
90
+ def _cleanup(path: str):
91
+ try:
92
+ os.remove(path)
93
+ except Exception:
94
+ pass
95
+
96
+ for idx, upload in enumerate(files):
97
+ if isinstance(upload, str):
98
+ continue
99
+
100
+ if not upload.filename:
101
+ raise HTTPException(status_code=400, detail=f"Uploaded file #{idx+1} missing filename.")
102
+
103
+ if not upload.filename.lower().endswith(".pdf"):
104
+ raise HTTPException(status_code=400, detail=f"Only PDF files supported: {upload.filename}, error at file number: {idx}")
105
+
106
+ tmp_path = _save_upload_to_temp(upload)
107
+ background_tasks.add_task(_cleanup, tmp_path)
108
+
109
+ # decide filename to use in Qdrant payload
110
+ qdrant_filename = str(
111
+ f"{qdrant_filename_prefix}_{upload.filename}" if qdrant_filename_prefix else upload.filename
112
+ )
113
+
114
+ try:
115
+ rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=overwrite)
116
+ saved_files.append(qdrant_filename)
117
+ except Exception as e:
118
+ # collect failure info rather than aborting all uploads
119
+ saved_files.append({"filename": upload.filename, "error": str(e)})
120
+
121
+ return {"files": saved_files}
122
+
123
+
124
+
125
+ @app.post("/generate_saved", response_model=GenerateResponse)
126
+ async def generate_saved_endpoint(
127
+ n_easy_questions: int = Form(3),
128
+ n_medium_questions: int = Form(5),
129
+ n_hard_questions: int = Form(2),
130
+ qdrant_filename: str = Form("default_filename"),
131
+ collection_name: str = Form("programming"),
132
+ mode: str = Form("rag"),
133
+ questions_per_chunk: int = Form(3),
134
+ top_k: int = Form(3),
135
+ temperature: float = Form(0.2),
136
+ validate_mcqs: bool = Form(False),
137
+ enable_fiddler: bool = Form(False),
138
+ ):
139
+ global rag
140
+ if rag is None:
141
+ raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
142
+
143
+ difficulty_counts = {
144
+ "easy": n_easy_questions,
145
+ "medium": n_medium_questions,
146
+ "hard": n_hard_questions
147
+ }
148
+
149
+ all_mcqs = {}
150
+
151
+ for difficulty, n_questions in difficulty_counts.items():
152
+ try:
153
+ mcqs = rag.generate_from_qdrant(
154
+ filename=qdrant_filename,
155
+ collection=collection_name,
156
+ n_questions=n_questions,
157
+ mode=mode,
158
+ questions_per_chunk=questions_per_chunk,
159
+ top_k=top_k,
160
+ temperature=temperature,
161
+ enable_fiddler=enable_fiddler,
162
+ target_difficulty=difficulty,
163
+ )
164
+ all_mcqs.update(mcqs)
165
+
166
+ except Exception as e:
167
+ raise HTTPException(status_code=500, detail=f"Generation from saved file failed: {e}")
168
+
169
+ validation_report = None
170
+
171
+ if validate_mcqs:
172
+ try:
173
+ # validate_mcqs expects keys as strings and the normalized content
174
+ validation_report = rag.validate_mcqs(all_mcqs, top_k=top_k)
175
+ except Exception as e:
176
+ # don't fail the whole request for a validation error — return generator output and note the error
177
+ validation_report = {"error": f"Validation failed: {e}"}
178
+
179
+ # log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
180
+
181
+ return {"mcqs": all_mcqs, "validation": validation_report}
182
+
183
+
184
+
185
+
186
+ @app.post("/generate", response_model=GenerateResponse)
187
+ async def generate_endpoint(
188
+ background_tasks: BackgroundTasks,
189
+ file: UploadFile = File(...),
190
+ n_questions: int = Form(10),
191
+ qdrant_filename: str = Form("default_filename"),
192
+ collection_name: str = Form("programming"),
193
+ mode: str = Form("rag"),
194
+ questions_per_page: int = Form(3),
195
+ top_k: int = Form(3),
196
+ temperature: float = Form(0.2),
197
+ validate_mcqs: bool = Form(False),
198
+ enable_fiddler: bool = Form(False)
199
+ ):
200
+ global rag
201
+ if rag is None:
202
+ raise HTTPException(status_code=503, detail="RAGMCQ not ready on server.")
203
+
204
+ # basic file validation
205
+ if not file.filename.lower().endswith(".pdf"):
206
+ raise HTTPException(status_code=400, detail="Only PDF files are supported.")
207
+
208
+ # save uploaded file to a temp location
209
+ tmp_path = _save_upload_to_temp(file)
210
+
211
+ # ensure file removed afterward
212
+ def _cleanup(path: str):
213
+ try:
214
+ os.remove(path)
215
+ except Exception:
216
+ pass
217
+
218
+ background_tasks.add_task(_cleanup, tmp_path)
219
+
220
+ # save pdf
221
+ try:
222
+ rag.save_pdf_to_qdrant(tmp_path, filename=qdrant_filename, collection=collection_name, overwrite=True)
223
+ except Exception as e:
224
+ raise HTTPException(status_code=500, detail=f"Could not save file to Qdrant Cloud: {e}")
225
+
226
+ # generate
227
+ try:
228
+ mcqs = rag.generate_from_pdf(
229
+ tmp_path,
230
+ n_questions=n_questions,
231
+ mode=mode,
232
+ questions_per_page=questions_per_page,
233
+ top_k=top_k,
234
+ temperature=temperature,
235
+ enable_fiddler=enable_fiddler
236
+ )
237
+ except Exception as e:
238
+ raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
239
+
240
+ validation_report = None
241
+
242
+ if validate_mcqs:
243
+ try:
244
+ # rag.build_index_from_pdf(tmp_path)
245
+ # validate_mcqs expects keys as strings and the normalized content
246
+ validation_report = rag.validate_mcqs(mcqs, top_k=top_k)
247
+ except Exception as e:
248
+ # don't fail the whole request for a validation error — return generator output and note the error
249
+ validation_report = {"error": f"Validation failed: {e}"}
250
+
251
+
252
+ # log_pipeline('test/mcq_output.json', content={"mcqs": mcqs, "validation": validation_report})
253
+
254
+ return {"mcqs": mcqs, "validation": validation_report}
255
+
256
+
257
+ if __name__ == "__main__":
258
+ import uvicorn
259
+ uvicorn.run("app:app", host="0.0.0.0", port=8000, log_level="info")
generator.py ADDED
@@ -0,0 +1,1125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import random
3
+ import fitz
4
+ import string
5
+ import numpy as np
6
+ import os
7
+ from typing import List, Optional, Tuple, Dict, Any
8
+ from sentence_transformers import SentenceTransformer, CrossEncoder
9
+ from transformers import pipeline
10
+ from uuid import uuid4
11
+ import pymupdf4llm
12
+
13
+ try:
14
+ from qdrant_client import QdrantClient
15
+ from qdrant_client.http.models import (
16
+ PointStruct,
17
+ Filter,
18
+ FieldCondition,
19
+ MatchValue,
20
+ Distance,
21
+ VectorParams,
22
+ )
23
+ from qdrant_client.http import models as rest
24
+ _HAS_QDRANT = True
25
+ except Exception:
26
+ _HAS_QDRANT = False
27
+
28
+ try:
29
+ import faiss
30
+ _HAS_FAISS = True
31
+ except Exception:
32
+ _HAS_FAISS = False
33
+
34
+ from utils import generate_mcqs_from_text, new_generate_mcqs_from_text, structure_context_for_llm
35
+
36
+ from huggingface_hub import login
37
+ login(token=os.environ['HF_MODEL_TOKEN'])
38
+
39
+ class RAGMCQ:
40
+ def __init__(
41
+ self,
42
+ embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
43
+ generation_model: str = "openai/gpt-oss-120b",
44
+ qdrant_url: str = os.environ.get('QDRANT_URL') or "",
45
+ qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
46
+ qdrant_prefer_grpc: bool = False,
47
+ ):
48
+ self.embedder = SentenceTransformer(embedder_model)
49
+ self.generation_model = generation_model
50
+ self.qa_pipeline = pipeline("question-answering", model="nguyenvulebinh/vi-mrc-base", tokenizer="nguyenvulebinh/vi-mrc-base")
51
+ self.cross_entail = CrossEncoder("itdainb/PhoRanker")
52
+ self.embeddings = None # np.array of shape (N, D)
53
+ self.texts = [] # list of chunk texts
54
+ self.metadata = [] # list of dicts (page, chunk_id, char_range)
55
+ self.index = None
56
+ self.dim = self.embedder.get_sentence_embedding_dimension()
57
+
58
+ self.qdrant = None
59
+ self.qdrant_url = qdrant_url
60
+ self.qdrant_api_key = qdrant_api_key
61
+ self.qdrant_prefer_grpc = qdrant_prefer_grpc
62
+
63
+ if qdrant_url:
64
+ self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
65
+
66
+ def extract_pages(
67
+ self,
68
+ pdf_path: str,
69
+ *,
70
+ pages: Optional[List[int]] = None,
71
+ ignore_images: bool = False,
72
+ dpi: int = 150
73
+ ) -> List[str]:
74
+ doc = fitz.open(pdf_path)
75
+ try:
76
+ # request page-wise output (page_chunks=True -> list[dict] per page)
77
+ page_dicts = pymupdf4llm.to_markdown(
78
+ doc,
79
+ pages=pages,
80
+ ignore_images=ignore_images,
81
+ dpi=dpi,
82
+ page_chunks=True,
83
+ )
84
+
85
+ # to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
86
+ pages_md: List[str] = []
87
+ for p in page_dicts:
88
+ txt = p.get("text", "") or ""
89
+ pages_md.append(txt.strip())
90
+
91
+ return pages_md
92
+ finally:
93
+ doc.close()
94
+
95
+ def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]:
96
+ text = text.strip()
97
+ if not text:
98
+ return []
99
+
100
+ if len(text) <= max_chars:
101
+ return [text]
102
+
103
+ # split by sentence-like boundaries
104
+ sentences = re.split(r'(?<=[\.\?\!])\s+', text)
105
+ chunks = []
106
+ cur = ""
107
+
108
+ for s in sentences:
109
+ if len(cur) + len(s) + 1 <= max_chars:
110
+ cur += (" " if cur else "") + s
111
+ else:
112
+ if cur:
113
+ chunks.append(cur)
114
+
115
+ cur = (cur[-overlap:] + " " + s) if overlap > 0 else s
116
+
117
+ if cur:
118
+ chunks.append(cur)
119
+
120
+ # if still too long, hard-split
121
+ final = []
122
+ for c in chunks:
123
+ if len(c) <= max_chars:
124
+ final.append(c)
125
+ else:
126
+ for i in range(0, len(c), max_chars):
127
+ final.append(c[i:i+max_chars])
128
+
129
+ return final
130
+
131
+
132
+ def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
133
+ pages = self.extract_pages(pdf_path)
134
+
135
+ self.texts = []
136
+ self.metadata = []
137
+
138
+ for p_idx, page_text in enumerate(pages, start=1):
139
+ chunks = self.chunk_text(page_text or "", max_chars=max_chars)
140
+ for cid, ch in enumerate(chunks, start=1):
141
+ self.texts.append(ch)
142
+ self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
143
+
144
+ if not self.texts:
145
+ raise RuntimeError("No text extracted from PDF.")
146
+
147
+ # save_to_local('test/text_chunks.md', content=self.texts)
148
+
149
+ # compute embeddings
150
+ emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
151
+ self.embeddings = emb.astype("float32")
152
+ self._build_faiss_index()
153
+
154
+ def _build_faiss_index(self, ef_construction=200, M=32):
155
+ if _HAS_FAISS:
156
+ d = self.embeddings.shape[1]
157
+ index = faiss.IndexHNSWFlat(d, M)
158
+ faiss.normalize_L2(self.embeddings)
159
+ index.add(self.embeddings)
160
+ index.hnsw.efConstruction = ef_construction
161
+ self.index = index
162
+ else:
163
+ # store normalized embeddings and use brute-force numpy
164
+ norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
165
+ self.embeddings = self.embeddings / norms
166
+ self.index = None
167
+
168
+ def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
169
+ q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
170
+
171
+ if _HAS_FAISS:
172
+ faiss.normalize_L2(q_emb)
173
+ D_list, I_list = self.index.search(q_emb, top_k)
174
+ # D are inner products; return list of (idx, score)
175
+ return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
176
+ else:
177
+ qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
178
+ sims = (self.embeddings @ qn.T).squeeze(axis=1)
179
+ idxs = np.argsort(-sims)[:top_k]
180
+ return [(int(i), float(sims[i])) for i in idxs]
181
+
182
+ def generate_from_pdf(
183
+ self,
184
+ pdf_path: str,
185
+ n_questions: int = 10,
186
+ mode: str = "rag", # per_page or rag
187
+ questions_per_page: int = 3, # for per_page mode
188
+ top_k: int = 3, # chunks to retrieve for each question in rag mode
189
+ temperature: float = 0.2,
190
+ enable_fiddler: bool = False,
191
+ target_difficulty: str = 'easy' # easy, mid, difficult
192
+ ) -> Dict[str, Any]:
193
+ # build index
194
+ self.build_index_from_pdf(pdf_path)
195
+
196
+ output: Dict[str, Any] = {}
197
+ qcount = 0
198
+
199
+ if mode == "per_page":
200
+ # iterate pages -> chunks
201
+ for idx, meta in enumerate(self.metadata):
202
+ chunk_text = self.texts[idx]
203
+
204
+ if not chunk_text.strip():
205
+ continue
206
+ to_gen = questions_per_page
207
+
208
+ # ask generator
209
+ try:
210
+ structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
211
+ mcq_block = generate_mcqs_from_text(
212
+ source_text=chunk_text, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
213
+ )
214
+ except Exception as e:
215
+ # skip this chunk if generator fails
216
+ print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
217
+ continue
218
+
219
+ if "error" in list(mcq_block.keys()):
220
+ return output
221
+
222
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
223
+ qcount += 1
224
+ output[str(qcount)] = mcq_block[item]
225
+ if qcount >= n_questions:
226
+ return output
227
+
228
+ return output
229
+
230
+ # pdf gene
231
+ elif mode == "rag":
232
+ # strategy: create a few natural short queries by sampling sentences or using chunk summaries.
233
+ # create queries by sampling chunk text sentences.
234
+ # stop when n_questions reached or max_attempts exceeded.
235
+ attempts = 0
236
+ max_attempts = n_questions * 4
237
+
238
+ while qcount < n_questions and attempts < max_attempts:
239
+ attempts += 1
240
+ # create a seed query: pick a random chunk, pick a sentence from it
241
+ seed_idx = random.randrange(len(self.texts))
242
+ chunk = self.texts[seed_idx]
243
+
244
+ #? investigate better Chunking Strategy
245
+ #with open("chunks.txt", "a", encoding="utf-8") as f:
246
+ #f.write(chunk + "\n")
247
+
248
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
249
+ seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
250
+ query = f"Create questions about: {seed_sent}"
251
+
252
+ # retrieve top_k chunks
253
+ retrieved = self._retrieve(query, top_k=top_k)
254
+ context_parts = []
255
+ for ridx, score in retrieved:
256
+ md = self.metadata[ridx]
257
+ context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
258
+ context = "\n\n".join(context_parts)
259
+
260
+ # save_to_local('test/context.md', content=context)
261
+
262
+ # call generator for 1 question (or small batch) with the retrieved context
263
+ try:
264
+ # request 1 question at a time to keep diversity
265
+ structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
266
+ mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
267
+ except Exception as e:
268
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
269
+ continue
270
+
271
+ if "error" in list(mcq_block.keys()):
272
+ return output
273
+
274
+ # append result(s)
275
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
276
+ payload = mcq_block[item]
277
+ q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
278
+ options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
279
+ if isinstance(options, list):
280
+ options = {str(i+1): o for i, o in enumerate(options)}
281
+ correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
282
+ correct_text = ""
283
+ if isinstance(correct_key, str) and correct_key.strip() in options:
284
+ correct_text = options[correct_key.strip()]
285
+ else:
286
+ correct_text = payload.get("correct_text") or correct_key or ""
287
+
288
+ diff_score, diff_label = self._estimate_difficulty_for_generation(
289
+ q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
290
+ )
291
+ payload["difficulty"] = {"score": diff_score, "label": diff_label}
292
+
293
+ qcount += 1
294
+ output[str(qcount)] = mcq_block[item]
295
+ if qcount >= n_questions:
296
+ return output
297
+
298
+ return output
299
+ else:
300
+ raise ValueError("mode must be 'per_page' or 'rag'.")
301
+
302
+ def validate_mcqs(
303
+ self,
304
+ mcqs: Dict[str, Any],
305
+ top_k: int = 4,
306
+ similarity_threshold: float = 0.5,
307
+ evidence_score_cutoff: float = 0.5,
308
+ use_cross_encoder: bool = True,
309
+ use_qa: bool = True,
310
+ auto_accept_threshold: float = 0.7,
311
+ review_threshold: float = 0.5,
312
+ distractor_too_similar: float = 0.8,
313
+ distractor_too_different: float = 0.15,
314
+ model_verification_temperature: float = 0.0,
315
+ ) -> Dict[str, Any]:
316
+ """
317
+ Upgraded validation pipeline:
318
+ - embedding retrieval (self.index / self.embeddings)
319
+ - cross-encoder entailment scoring (optional)
320
+ - extractive QA consistency check (optional)
321
+ - distractor similarity and type checks
322
+ - aggregate into quality_score and triage_action
323
+
324
+ Returns a dict keyed by qid with detailed info and triage decision.
325
+ """
326
+ cross_entail = None
327
+ qa_pipeline = None
328
+ if use_cross_encoder:
329
+ try:
330
+ cross_entail = self.cross_entail
331
+ except Exception as e:
332
+ cross_entail = None
333
+ if use_qa:
334
+ try:
335
+ qa_pipeline = self.qa_pipeline
336
+ except Exception:
337
+ qa_pipeline = None
338
+
339
+ # --- helpers ---
340
+ def _norm_text(s: str) -> str:
341
+ if s is None:
342
+ return ""
343
+ s = s.strip().lower()
344
+ # remove punctuation
345
+ s = s.translate(str.maketrans("", "", string.punctuation))
346
+ # collapse whitespace
347
+ s = " ".join(s.split())
348
+ return s
349
+
350
+ def _semantic_search(statement: str, k: int = top_k):
351
+ # returns list of (idx, score) using current embeddings/index
352
+ q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
353
+ if _HAS_FAISS and getattr(self, "index", None) is not None:
354
+ try:
355
+ faiss.normalize_L2(q_emb)
356
+ D_list, I_list = self.index.search(q_emb, k)
357
+ return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
358
+ except Exception:
359
+ pass
360
+ # fallback to brute force
361
+ qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
362
+ sims = (self.embeddings @ qn.T).squeeze(axis=1)
363
+ idxs = np.argsort(-sims)[:k]
364
+ return [(int(i), float(sims[i])) for i in idxs]
365
+
366
+ def _compose_context_from_retrieved(retrieved):
367
+ parts = []
368
+ for ridx, score in retrieved:
369
+ md = self.metadata[ridx] if ridx < len(self.metadata) else {}
370
+ page = md.get("page", "?")
371
+ text = self.texts[ridx]
372
+ parts.append(f"[page {page}] {text}")
373
+ return "\n\n".join(parts)
374
+
375
+ def _compute_option_embeddings(options_map: Dict[str, str]):
376
+ # returns dict key->embedding
377
+ keys = list(options_map.keys())
378
+ texts = [options_map[k] for k in keys]
379
+ embs = self.embedder.encode(texts, convert_to_numpy=True)
380
+ return dict(zip(keys, embs))
381
+
382
+ def _cosine(a, b):
383
+ a = np.asarray(a, dtype=float)
384
+ b = np.asarray(b, dtype=float)
385
+ denom = (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)
386
+ return float(np.dot(a, b) / denom)
387
+
388
+ # --- main loop ---
389
+ report = {}
390
+ for qid, item in mcqs.items():
391
+ # support both Vietnamese keys and English keys
392
+ q_text = (item.get("câu hỏi") or item.get("question") or item.get("q") or item.get("stem") or "").strip()
393
+ options = item.get("lựa chọn") or item.get("options") or item.get("choices") or {}
394
+ # options may be dict mapping letters to text, or list: normalize to dict
395
+ if isinstance(options, list):
396
+ options = {str(i+1): o for i, o in enumerate(options)}
397
+ # correct answer may be a key (like "A") or the text; try both
398
+ correct_key = item.get("đáp án") or item.get("answer") or item.get("correct") or item.get("ans")
399
+ correct_text = ""
400
+ if isinstance(correct_key, str) and correct_key.strip() in options:
401
+ correct_text = options[correct_key.strip()]
402
+ else:
403
+ # maybe the answer is full text
404
+ if isinstance(correct_key, str):
405
+ correct_text = correct_key.strip()
406
+ else:
407
+ # fallback to 'correct_text' field
408
+ correct_text = item.get("correct_text") or item.get("đáp án_text") or ""
409
+
410
+ # default empty guard
411
+ options = {k: str(v) for k, v in options.items()}
412
+ correct_text = str(correct_text)
413
+
414
+ # prepare statement for retrieval
415
+ statement = f"{q_text} Answer: {correct_text}"
416
+ retrieved = _semantic_search(statement, k=top_k)
417
+ # build context from top retrieved
418
+ context_parts = []
419
+ for ridx, score in retrieved:
420
+ md = self.metadata[ridx] if ridx < len(self.metadata) else {}
421
+ context_parts.append({"idx": ridx, "score": float(score), "page": md.get("page", None), "text": self.texts[ridx]})
422
+ context_text = "\n\n".join([f"[page {p['page']}] {p['text']}" for p in context_parts])
423
+
424
+ # Evidence list (embedding-based)
425
+ evidence_list = []
426
+ max_sim = 0.0
427
+ for r in context_parts:
428
+ if r["score"] >= evidence_score_cutoff:
429
+ snippet = r["text"]
430
+ evidence_list.append({
431
+ "idx": r["idx"],
432
+ "page": r["page"],
433
+ "score": r["score"],
434
+ "text": (snippet[:1000] + ("..." if len(snippet) > 1000 else "")),
435
+ })
436
+ if r["score"] > max_sim:
437
+ max_sim = float(r["score"])
438
+ supported_by_embeddings = max_sim >= similarity_threshold
439
+
440
+ # Cross-encoder entailment scores for each option
441
+ entailment_scores = {}
442
+ correct_entail = 0.0
443
+ try:
444
+ if cross_entail is not None and context_text.strip():
445
+ # prepare list of (premise, hypothesis)
446
+ pairs = []
447
+ opt_keys = list(options.keys())
448
+ for k in opt_keys:
449
+ hyp = f"{q_text} Answer: {options[k]}"
450
+ pairs.append((context_text, hyp))
451
+ scores = cross_entail.predict(pairs) # returns list of floats
452
+ # normalize scores to 0-1 if needed (cross-encoder may return arbitrary positive)
453
+ # do a min-max normalization across the returned scores
454
+ # but avoid division by zero
455
+ min_s = float(min(scores)) if len(scores) else 0.0
456
+ max_s = float(max(scores)) if len(scores) else 1.0
457
+ denom = max_s - min_s if max_s - min_s > 1e-6 else 1.0
458
+ for k, raw in zip(opt_keys, scores):
459
+ scaled = (raw - min_s) / denom
460
+ entailment_scores[k] = float(scaled)
461
+ # find correct key if available
462
+ # if `correct_text` exactly matches one of options, find that key
463
+ matched_key = None
464
+ for k, v in options.items():
465
+ if _norm_text(v) == _norm_text(correct_text):
466
+ matched_key = k
467
+ break
468
+ if matched_key:
469
+ correct_entail = entailment_scores.get(matched_key, 0.0)
470
+ else:
471
+ # fallback: treat 'correct_text' as a separate hypothesis
472
+ hyp = f"{q_text} Answer: {correct_text}"
473
+ raw = cross_entail.predict([(context_text, hyp)])[0]
474
+ # scale relative to min/max used above
475
+ correct_entail = float((raw - min_s) / denom)
476
+ else:
477
+ entailment_scores = {}
478
+ correct_entail = 0.0
479
+ except Exception as e:
480
+ entailment_scores = {}
481
+ correct_entail = 0.0
482
+
483
+ def embed_cosine_sim(a, b):
484
+ emb = self.embedder.encode([a, b], convert_to_numpy=True, normalize_embeddings=True)
485
+ return float(np.dot(emb[0], emb[1]))
486
+
487
+ # QA consistency
488
+ qa_answer = None
489
+ qa_score = 0.0
490
+ qa_agrees = False
491
+ if qa_pipeline is not None and context_text.strip():
492
+ try:
493
+ qa_res = qa_pipeline(question=q_text, context=context_text)
494
+ # some QA pipelines return list of answers or dict
495
+ if isinstance(qa_res, list) and len(qa_res) > 0:
496
+ top = qa_res[0]
497
+ qa_answer = top.get("answer") if isinstance(top, dict) else str(top)
498
+ # qa_score = float(top.get("score", 0.0) if isinstance(top, dict) else 0.0)
499
+ elif isinstance(qa_res, dict):
500
+ qa_answer = qa_res.get("answer", "")
501
+ qa_score = float(qa_res.get("score", 0.0))
502
+ else:
503
+ qa_answer = str(qa_res)
504
+ qa_score = 0.0
505
+ qa_score = embed_cosine_sim(qa_answer, correct_text)
506
+ qa_agrees = (qa_score >= 0.5)
507
+ except Exception:
508
+ qa_answer = None
509
+ qa_score = 0.0
510
+ qa_agrees = False
511
+
512
+ try:
513
+ opt_embs = _compute_option_embeddings({**options, "__CORRECT__": correct_text})
514
+ correct_emb = opt_embs.pop("__CORRECT__")
515
+ distractor_similarities = {}
516
+ for k, emb in opt_embs.items():
517
+ distractor_similarities[k] = float(_cosine(correct_emb, emb))
518
+ except Exception:
519
+ distractor_similarities = {k: None for k in options.keys()}
520
+
521
+ # distractor flags
522
+ distractor_penalty = 0.0
523
+ distractor_flags = []
524
+ for k, sim in distractor_similarities.items():
525
+ if sim is None or sim >= 0.999999 or (sim >= -0.01 and sim <= 0):
526
+ continue
527
+ if sim >= distractor_too_similar:
528
+ distractor_flags.append({"key": k, "reason": "too_similar", "similarity": sim})
529
+ distractor_penalty += 0.25
530
+ elif sim <= distractor_too_different:
531
+ distractor_flags.append({"key": k, "reason": "too_different", "similarity": sim})
532
+ distractor_penalty += 0.15
533
+ # clamp penalty
534
+ distractor_penalty = min(distractor_penalty, 1.0)
535
+
536
+ # Ambiguity detection: how many options have entailment >= threshold
537
+ ambiguous = False
538
+ ambiguous_options = []
539
+ if entailment_scores:
540
+ # count options whose entailment >= max(correct_entail * 0.9, 0.6)
541
+ amb_thresh = max(correct_entail * 0.9, 0.6)
542
+ for k, sc in entailment_scores.items():
543
+ if sc >= amb_thresh and (options.get(k, "") != correct_text):
544
+ ambiguous_options.append({"key": k, "score": sc, "text": options[k]})
545
+ ambiguous = len(ambiguous_options) > 0
546
+
547
+ # Compose aggregated quality score
548
+ # Components:
549
+ # - embedding_support: normalized max_sim (0..1)
550
+ # - entailment: correct_entail (0..1)
551
+ # - qa_agree: boolean -> 1 or 0 times qa_score
552
+ # - distractor_penalty: subtracted
553
+ emb_support_norm = max_sim # embedding similarity typically already 0..1 (inner product normalized)
554
+ entail_component = float(correct_entail)
555
+ qa_component = float(qa_score) if qa_agrees else 0.0
556
+
557
+ # weighted sum
558
+ quality_score = (
559
+ 0.40 * emb_support_norm +
560
+ 0.35 * entail_component +
561
+ 0.20 * qa_component -
562
+ 0.05 * distractor_penalty
563
+ )
564
+ # clamp to 0..1
565
+ quality_score = max(0.0, min(1.0, quality_score))
566
+
567
+ # triage decision
568
+ triage_action = "reject"
569
+ if quality_score >= auto_accept_threshold and not ambiguous:
570
+ triage_action = "pass"
571
+ elif quality_score >= review_threshold:
572
+ triage_action = "review"
573
+ else:
574
+ triage_action = "reject"
575
+
576
+ # compile flags/reasons
577
+ flag_reasons = []
578
+ if not supported_by_embeddings:
579
+ flag_reasons.append("no_strong_embedding_evidence")
580
+ if entailment_scores and correct_entail < 0.6:
581
+ flag_reasons.append("low_entailment_score_for_correct")
582
+ if qa_pipeline is not None and qa_score > 0.6 and not qa_agrees:
583
+ flag_reasons.append("qa_contradiction")
584
+ if ambiguous:
585
+ flag_reasons.append("ambiguous_options_supported")
586
+ if distractor_flags:
587
+ flag_reasons.append({"distractor_issues": distractor_flags})
588
+
589
+ # assemble per-question report
590
+ report[qid] = {
591
+ "supported_by_embeddings": bool(supported_by_embeddings),
592
+ "max_similarity": float(max_sim),
593
+ "evidence": evidence_list,
594
+ "entailment_scores": entailment_scores,
595
+ "correct_entailment": float(correct_entail),
596
+ "qa_answer": qa_answer,
597
+ "qa_score": float(qa_score),
598
+ "qa_agrees": bool(qa_agrees),
599
+ "distractor_similarities": distractor_similarities,
600
+ "distractor_flags": distractor_flags,
601
+ "distractor_penalty": float(distractor_penalty),
602
+ "ambiguous_options": ambiguous_options,
603
+ "quality_score": float(quality_score),
604
+ "triage_action": triage_action,
605
+ "flag_reasons": flag_reasons,
606
+ }
607
+
608
+ return report
609
+
610
+ def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False):
611
+ if not _HAS_QDRANT:
612
+ raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
613
+ self.qdrant_url = url
614
+ self.qdrant_api_key = api_key
615
+ self.qdrant_prefer_grpc = prefer_grpc
616
+ # Create client
617
+ self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
618
+
619
+ def _ensure_collection(self, collection_name: str):
620
+ if self.qdrant is None:
621
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
622
+ try:
623
+ # get_collection will raise if not present
624
+ _ = self.qdrant.get_collection(collection_name)
625
+ except Exception:
626
+ # create collection with vector size = self.dim
627
+ vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
628
+ self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
629
+ # recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
630
+
631
+ def save_pdf_to_qdrant(
632
+ self,
633
+ pdf_path: str,
634
+ filename: str,
635
+ collection: str,
636
+ max_chars: int = 1200,
637
+ batch_size: int = 64,
638
+ overwrite: bool = False,
639
+ ):
640
+ if self.qdrant is None:
641
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
642
+
643
+ # extract pages and chunks (re-using your existing helpers)
644
+ pages = self.extract_pages(pdf_path)
645
+
646
+ all_chunks = []
647
+ all_meta = []
648
+ for p_idx, page_text in enumerate(pages, start=1):
649
+ chunks = self.chunk_text(page_text or "", max_chars=max_chars)
650
+ for cid, ch in enumerate(chunks, start=1):
651
+ all_chunks.append(ch)
652
+ all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
653
+
654
+ if not all_chunks:
655
+ raise RuntimeError("No tSext extracted from PDF.")
656
+
657
+ # ensure collection exists
658
+ self._ensure_collection(collection)
659
+
660
+ # optional: delete previous points for this filename if overwrite
661
+ if overwrite:
662
+ # delete by filter: filename == filename
663
+ flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
664
+ try:
665
+ # qdrant-client delete uses delete(
666
+ self.qdrant.delete(collection_name=collection, filter=flt)
667
+ except Exception:
668
+ # ignore if deletion fails
669
+ pass
670
+
671
+ # compute embeddings in batches
672
+ embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
673
+ embeddings = embeddings.astype("float32")
674
+
675
+ # prepare points
676
+ points = []
677
+ for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
678
+ pid = str(uuid4())
679
+ source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
680
+ payload = {
681
+ "filename": filename,
682
+ "page": md["page"],
683
+ "chunk_id": md["chunk_id"],
684
+ "length": md["length"],
685
+ "text": txt,
686
+ "source_id": source_id,
687
+ }
688
+ points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
689
+
690
+ # upsert in batches
691
+ if len(points) >= batch_size:
692
+ self.qdrant.upsert(collection_name=collection, points=points)
693
+ points = []
694
+
695
+ # upsert remaining
696
+ if points:
697
+ self.qdrant.upsert(collection_name=collection, points=points)
698
+
699
+ try:
700
+ self.qdrant.create_payload_index(
701
+ collection_name=collection,
702
+ field_name="filename",
703
+ field_schema=rest.PayloadSchemaType.KEYWORD
704
+ )
705
+ except Exception as e:
706
+ print(f"Index creation skipped or failed: {e}")
707
+
708
+ return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
709
+
710
+
711
+ def list_files_in_collection(
712
+ self,
713
+ collection: str,
714
+ payload_field: str = "filename",
715
+ batch_size: int = 500,
716
+ ) -> List[str]:
717
+ if self.qdrant is None:
718
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
719
+
720
+ # ensure collection exists
721
+ try:
722
+ if not self.qdrant.collection_exists(collection):
723
+ raise RuntimeError(f"Collection '{collection}' does not exist.")
724
+ except Exception:
725
+ # collection_exists may raise if server unreachable
726
+ raise
727
+
728
+ filenames = set()
729
+ offset = None
730
+
731
+ while True:
732
+ # scroll returns (points, next_offset)
733
+ pts, next_offset = self.qdrant.scroll(
734
+ collection_name=collection,
735
+ limit=batch_size,
736
+ offset=offset,
737
+ with_payload=[payload_field],
738
+ with_vectors=False,
739
+ )
740
+
741
+ if not pts:
742
+ break
743
+
744
+ for p in pts:
745
+ # p may be a dict-like or an object with .payload
746
+ payload = None
747
+ if hasattr(p, "payload"):
748
+ payload = p.payload
749
+ elif isinstance(p, dict):
750
+ # older/newer variants might use nested structures: try common keys
751
+ payload = p.get("payload") or p.get("payload", None) or p
752
+ else:
753
+ # best-effort fallback: convert to dict if possible
754
+ try:
755
+ payload = dict(p)
756
+ except Exception:
757
+ payload = None
758
+
759
+ if not payload:
760
+ continue
761
+
762
+ # extract candidate value(s)
763
+ val = None
764
+ if isinstance(payload, dict):
765
+ val = payload.get(payload_field)
766
+ else:
767
+ # Some payload representations store fields differently; try attribute access
768
+ val = getattr(payload, payload_field, None)
769
+
770
+ # If value is list-like, iterate, else add single
771
+ if isinstance(val, (list, tuple, set)):
772
+ for v in val:
773
+ if v is not None:
774
+ filenames.add(str(v))
775
+ elif val is not None:
776
+ filenames.add(str(val))
777
+
778
+ # stop if no more pages
779
+ if not next_offset:
780
+ break
781
+ offset = next_offset
782
+
783
+ return sorted(filenames)
784
+
785
+
786
+ def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
787
+ if self.qdrant is None:
788
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
789
+
790
+ results = []
791
+ offset = None
792
+ while True:
793
+ # scroll returns (points, next_offset)
794
+ points, next_offset = self.qdrant.scroll(
795
+ collection_name=collection,
796
+ scroll_filter=Filter(
797
+ must=[
798
+ FieldCondition(key="filename", match=MatchValue(value=filename))
799
+ ]
800
+ ),
801
+ limit=batch,
802
+ offset=offset,
803
+ with_payload=True,
804
+ with_vectors=False,
805
+ )
806
+ # points are objects (Record / ScoredPoint-like); get id and payload
807
+ for p in points:
808
+ # p.payload is a dict, p.id is point id
809
+ results.append({"point_id": p.id, "payload": p.payload})
810
+ if not next_offset:
811
+ break
812
+ offset = next_offset
813
+ return results
814
+
815
+
816
+ def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
817
+ if self.qdrant is None:
818
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
819
+
820
+ q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
821
+ q_filter = None
822
+ if filename:
823
+ q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
824
+
825
+ search_res = self.qdrant.search(
826
+ collection_name=collection,
827
+ query_vector=q_emb,
828
+ query_filter=q_filter,
829
+ limit=top_k,
830
+ with_payload=True,
831
+ with_vectors=False,
832
+ )
833
+
834
+ out = []
835
+ for hit in search_res:
836
+ # hit.payload is the stored payload, hit.score is similarity
837
+ out.append((hit.payload, float(getattr(hit, "score", 0.0))))
838
+ return out
839
+
840
+
841
+ def generate_from_qdrant(
842
+ self,
843
+ filename: str,
844
+ collection: str,
845
+ n_questions: int = 10,
846
+ mode: str = "rag", # 'per_chunk' or 'rag'
847
+ questions_per_chunk: int = 3, # used for 'per_chunk'
848
+ top_k: int = 3, # retrieval size used in RAG
849
+ temperature: float = 0.2,
850
+ enable_fiddler: bool = False,
851
+ target_difficulty: str = 'easy',
852
+
853
+ ) -> Dict[str, Any]:
854
+ if self.qdrant is None:
855
+ raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
856
+
857
+ # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
858
+ file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
859
+ if not file_points:
860
+ raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
861
+
862
+ # create a local list of texts & metadata for sampling
863
+ texts = []
864
+ metas = []
865
+ for p in file_points:
866
+ payload = p.get("payload", {})
867
+ text = payload.get("text", "")
868
+ texts.append(text)
869
+ metas.append(payload)
870
+
871
+ self.texts = texts
872
+ self.metadata = metas
873
+ embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
874
+ if embeddings is None or len(embeddings) == 0:
875
+ self.embeddings = None
876
+ self.index = None
877
+ else:
878
+ self.embeddings = embeddings.astype("float32")
879
+
880
+ # update dim in case embedder changed unexpectedly
881
+ self.dim = int(self.embeddings.shape[1])
882
+
883
+ # build index
884
+ self._build_faiss_index()
885
+
886
+ output = {}
887
+ qcount = 0
888
+
889
+ if mode == "per_chunk":
890
+ # iterate all chunks (in payload order) and request questions_per_chunk from each
891
+ for i, txt in enumerate(texts):
892
+ if not txt.strip():
893
+ continue
894
+ to_gen = questions_per_chunk
895
+ try:
896
+ mcq_block = new_generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature, enable_fiddler=False)
897
+ except Exception as e:
898
+ print(f"Generator failed on chunk (index {i}): {e}")
899
+ continue
900
+
901
+ if "error" in list(mcq_block.keys()):
902
+ return output
903
+
904
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
905
+ qcount += 1
906
+ output[str(qcount)] = mcq_block[item]
907
+ if qcount >= n_questions:
908
+ return output
909
+ return output
910
+
911
+ elif mode == "rag":
912
+ attempts = 0
913
+ max_attempts = n_questions * 4
914
+ while qcount < n_questions and attempts < max_attempts:
915
+ attempts += 1
916
+ # create a seed query: pick a random chunk, pick a sentence from it
917
+ seed_idx = random.randrange(len(self.texts))
918
+ chunk = self.texts[seed_idx]
919
+ sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
920
+ candidate = [s for s in sents if len(s.strip()) > 20]
921
+ if candidate:
922
+ seed_sent = random.choice(candidate)
923
+ else:
924
+ stripped = chunk.strip()
925
+ seed_sent = (stripped[:200] if stripped else "[no text available]")
926
+ query = f"Create questions about: {seed_sent}"
927
+
928
+
929
+ # retrieve top_k chunks from the same file (restricted by filename filter)
930
+ retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
931
+ print('retrieved qdrant', retrieved)
932
+ context_parts = []
933
+ for payload, score in retrieved:
934
+ # payload should contain page & chunk_id and text
935
+ page = payload.get("page", "?")
936
+ ctxt = payload.get("text", "")
937
+ context_parts.append(f"[page {page}] {ctxt}")
938
+ context = "\n\n".join(context_parts)
939
+
940
+
941
+ # q generation
942
+ try:
943
+ # Difficulty pipeline: easy, mid, difficult
944
+ structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False, target_difficulty=target_difficulty)
945
+ mcq_block = new_generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=False, target_difficulty=target_difficulty)
946
+ except Exception as e:
947
+ print(f"Generator failed during RAG attempt {attempts}: {e}")
948
+ continue
949
+
950
+ if "error" in list(mcq_block.keys()):
951
+ return output
952
+
953
+ for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
954
+ payload = mcq_block[item]
955
+ q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
956
+ options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
957
+
958
+ if isinstance(options, list):
959
+ options = {str(i+1): o for i, o in enumerate(options)}
960
+
961
+ correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
962
+ concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
963
+
964
+ correct_text = ""
965
+ if isinstance(correct_key, str) and correct_key.strip() in options:
966
+ correct_text = options[correct_key.strip()]
967
+ else:
968
+ correct_text = payload.get("correct_text") or correct_key or ""
969
+
970
+ #? change estimate
971
+ diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore
972
+ q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used = concepts
973
+ )
974
+
975
+ payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
976
+
977
+ # CHECK n generation: if number of request mcqs < default generation number e.g. 5 - 3 = 2 < 3 then only genearate 2 mcqs
978
+ if n_questions - qcount < questions_per_chunk:
979
+ questions_per_chunk = n_questions - qcount
980
+
981
+ qcount += 1 # count number of question
982
+ print('qcount:', qcount)
983
+ print('questions_per_chunk:', questions_per_chunk)
984
+
985
+ output[str(qcount)] = mcq_block[item]
986
+ if qcount >= n_questions:
987
+ return output
988
+
989
+ if output is not None:
990
+ print("output available")
991
+ return output
992
+ else:
993
+ raise ValueError("mode must be 'per_chunk' or 'rag'.")
994
+
995
+
996
+ def _estimate_difficulty_for_generation(
997
+ self,
998
+ q_text: str,
999
+ options: Dict[str, str],
1000
+ correct_text: str,
1001
+ context_text: str = "",
1002
+ concepts_used: Dict = {}
1003
+ ) -> Tuple[float, str]:
1004
+ def safe_map_sim(s):
1005
+ # map potentially [-1,1] cosine-like to [0,1], clamp
1006
+ try:
1007
+ s = float(s)
1008
+ except Exception:
1009
+ return 0.0
1010
+ mapped = (s + 1.0) / 2.0
1011
+ return max(0.0, min(1.0, mapped))
1012
+
1013
+ # embedding support
1014
+ emb_support = 0.0
1015
+ try:
1016
+ stmt = (q_text or "").strip()
1017
+ if correct_text:
1018
+ stmt = f"{stmt} Answer: {correct_text}"
1019
+
1020
+ # use internal retrieve but map returned score
1021
+ res = []
1022
+ try:
1023
+ res = self._retrieve(stmt, top_k=1)
1024
+ except Exception:
1025
+ res = []
1026
+
1027
+ if res:
1028
+ raw_score = float(res[0][1])
1029
+ emb_support = safe_map_sim(raw_score)
1030
+ else:
1031
+ emb_support = 0.0
1032
+ except Exception:
1033
+ emb_support = 0.0
1034
+
1035
+ # distractor sims
1036
+ mean_sim = 0.0
1037
+ distractor_penalty = 0.0
1038
+ amb_flag = 0.0
1039
+ try:
1040
+ keys = list(options.keys())
1041
+ texts = [options[k] for k in keys]
1042
+ if correct_text is None:
1043
+ correct_text = ""
1044
+
1045
+ all_texts = [correct_text] + texts
1046
+ embs = self.embedder.encode(all_texts, convert_to_numpy=True)
1047
+ embs = np.asarray(embs, dtype=float)
1048
+ norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
1049
+ embs = embs / norms
1050
+ corr = embs[0]
1051
+ opts = embs[1:]
1052
+
1053
+ if opts.size == 0:
1054
+ mean_sim = 0.0
1055
+ distractor_penalty = 0.0
1056
+ gap = 0.0
1057
+ else:
1058
+ sims = (opts @ corr).tolist() # [-1,1]
1059
+ sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
1060
+ mean_sim = float(sum(sims_mapped) / len(sims_mapped))
1061
+ # gap between best distractor and second best (higher gap -> easier)
1062
+ sorted_s = sorted(sims_mapped, reverse=True)
1063
+ top = sorted_s[0]
1064
+ second = sorted_s[1] if len(sorted_s) > 1 else 0.0
1065
+ gap = top - second
1066
+ # penalties: if distractors are extremely close to correct -> higher penalty
1067
+ too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
1068
+ too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
1069
+ distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
1070
+ amb_flag = 1.0 if top >= 0.8 else 0.0
1071
+ except Exception:
1072
+ mean_sim = 0.0
1073
+ distractor_penalty = 0.0
1074
+ amb_flag = 0.0
1075
+ gap = 0.0
1076
+
1077
+ # question length normalized
1078
+ question_len = len((q_text or "").strip())
1079
+ question_len_norm = min(1.0, question_len / 300.0)
1080
+
1081
+ # count number of concept from string
1082
+ concepts_num = len(concepts_used.keys())
1083
+ if concepts_num < 2:
1084
+ concepts_penalty = 0
1085
+ else:
1086
+ concepts_penalty = concepts_num
1087
+
1088
+ # combine signals using safer semantics:
1089
+ # higher emb_support -> easier (so we subtract a term)
1090
+ # higher distractor_penalty -> harder (add)
1091
+ # better gap -> easier (subtract)
1092
+ # compute score (higher -> harder)
1093
+
1094
+ score = 0.3 # more toward easy
1095
+ score += 0.35 * float(distractor_penalty) # stronger penalty for similar distractors
1096
+ score += 0.2 * float(mean_sim) # emphasizes average distractor similarity (harder if distractors are close, per "khó" criteria)
1097
+ score += 0.12 * float(amb_flag) # penalty if the best distractors hard to distinguish
1098
+ score += 0.1 * float(concepts_penalty) # boost difficuty if more concept used in a question
1099
+ score -= 0.15 * float(gap) # less emphasis on "dễ" if gap is large but other factors are hard
1100
+ score += 0.05 * float(question_len_norm)
1101
+ score -= 0.45 * float(emb_support) # easy ques is obvious while hard question get penalty because they often get rephrase from the original concet -> harder for embedding suppport to be meaningful.
1102
+
1103
+ # clamp
1104
+ score = max(0.0, min(1.0, float(score)))
1105
+ components = {
1106
+ "base": 0.3,
1107
+ "distractor_penalty": 0.35 * float(distractor_penalty),
1108
+ "mean_sim": 0.15 * float(mean_sim),
1109
+ "amb_flag": 0.05 * float(amb_flag),
1110
+ "concepts_num": 0.1 * float(concepts_num),
1111
+ "gap": -0.12 * float(gap),
1112
+ "question_len_norm": 0.05 * float(question_len_norm),
1113
+ "emb_support": -0.45 * float(emb_support),
1114
+ "total_score": score,
1115
+ }
1116
+
1117
+ # label
1118
+ if score <= 0.56:
1119
+ label = "dễ"
1120
+ elif score <= 0.755 and score > 0.56:
1121
+ label = "trung bình"
1122
+ else:
1123
+ label = "khó"
1124
+
1125
+ return score, label, components # type: ignore
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ boto3
2
+ faiss-cpu
3
+ transformers
4
+ sentence-transformers
5
+ fastapi[standard]
6
+ uvicorn
7
+ qdrant-client
8
+ pymupdf4llm
9
+ uuid
10
+ huggingface_hub
utils.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import json
3
+ from typing import Dict, Any
4
+ import requests
5
+ import os
6
+ import numpy as np
7
+ import uuid
8
+ import datetime
9
+ import pathlib
10
+ import time
11
+
12
+ #TODO: allow to choose different provider later + dynamic routing when token expired
13
+ API_URL = "https://openrouter.ai/api/v1/chat/completions"
14
+ CEREBRAS_API_KEY = os.environ['OPENROUTER_KEY']
15
+
16
+ HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}", "Content-Type": "application/json"}
17
+ JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
18
+
19
+ INPUT_TOKEN_COUNT = np.array([], dtype=int)
20
+ OUTPUT_TOKEN_COUNT = np.array([], dtype=int)
21
+ TOTAL_TOKEN_COUNT = np.array([], dtype=int)
22
+ TOTAL_TOKEN_COUNT_EACH_GENERATION = np.array([])
23
+ TIME_INFOs = {}
24
+
25
+
26
+ FIDDLER_GUARDRAILS_TOKEN = os.environ['FIDDLER_TOKEN']
27
+ SAFETY_GUARDRAILS_URL = "https://guardrails.cloud.fiddler.ai/v3/guardrails/ftl-safety"
28
+ GUARDRAILS_HEADERS = {
29
+ 'Content-Type': 'application/json',
30
+ 'Authorization': f'Bearer {FIDDLER_GUARDRAILS_TOKEN}',
31
+ }
32
+
33
+ def get_safety_response(text, sleep_seconds: float = 0.5):
34
+ time.sleep(sleep_seconds) # rate limited
35
+ response = requests.post(
36
+ SAFETY_GUARDRAILS_URL,
37
+ headers=GUARDRAILS_HEADERS,
38
+ json={'data': {'input': text}},
39
+ )
40
+ response.raise_for_status()
41
+ response_dict = response.json()
42
+ return response_dict
43
+
44
+ def text_safety_check(text: str, sleep_seconds: float = 0.5):
45
+ confs = get_safety_response(text, sleep_seconds)
46
+ max_conf = max(confs.values())
47
+ max_category = list(confs.keys())[list(confs.values()).index(max_conf)]
48
+ return max_conf, max_category
49
+
50
+ def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
51
+ if model == 'openai/gpt-oss-120b': # OpenRouter's version of Cerebras-GPT
52
+ payload = {"model": model, "messages": messages, "temperature": temperature, "provider": {"only": ["Cerebras", "together", "baseten", "deepinfra/fp4"]}}
53
+ else: # default to Cerebras
54
+ payload = {"model": model, "messages": messages, "temperature": temperature}
55
+
56
+ resp = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
57
+ resp.raise_for_status()
58
+ data = resp.json()
59
+
60
+ # handle various shapes
61
+ if "choices" in data and len(data["choices"]) > 0:
62
+ # prefer message.content
63
+ ch = data["choices"][0]
64
+ if isinstance(ch, dict) and "message" in ch and "content" in ch["message"]:
65
+ return ch["message"]["content"]
66
+ if "text" in ch:
67
+ return ch["text"]
68
+ # final fallback
69
+ raise RuntimeError("Unexpected HF response shape: " + json.dumps(data)[:200])
70
+
71
+ def _safe_extract_json(text: str) -> dict:
72
+ # remove triple backticks
73
+ text = re.sub(r"```(?:json)?\n?", "", text)
74
+ m = JSON_OBJ_RE.search(text)
75
+ if not m:
76
+ raise ValueError("No JSON object found in model output.")
77
+ js = m.group(1)
78
+ # try load, fix trailing commas
79
+ try:
80
+ return json.loads(js)
81
+ except json.JSONDecodeError:
82
+ fixed = re.sub(r",\s*([}\]])", r"\1", js)
83
+ return json.loads(fixed)
84
+
85
+
86
+ def structure_context_for_llm(
87
+ source_text: str,
88
+ model: str = "openai/gpt-oss-120b",
89
+ temperature: float = 0.2,
90
+ enable_fiddler = False,
91
+ ) -> Dict[str, Any]:
92
+ """
93
+ Take a long source_text, split into N chunks, and restructure them
94
+ so each chunk is self-contained, structured, and semantically meaningful.
95
+ """
96
+
97
+ system_message = {
98
+ "role": "system",
99
+ "content": (
100
+ "Bạn là một trợ lý hữu ích chuyên xử lý và cấu trúc văn bản để phục vụ mô hình ngôn ngữ (LLM). Trả lời bằng Tiếng Việt\n"
101
+ "Nhiệm vụ của bạn là:\n"
102
+ "- Nếu văn bản dài trên 500 từ chia văn bản thành 2 đoạn (chunk) có ý nghĩa rõ ràng.\n"
103
+ "- Mỗi chunk phải **tự chứa đủ thông tin** (self-contained) để LLM có thể hiểu độc lập.\n"
104
+ "- Xác định **chủ đề chính (topic)** của mỗi chunk và dùng nó làm KEY trong JSON.\n"
105
+ "- Trong mỗi topic, tổ chức thông tin thành cấu trúc rõ ràng gồm các trường:\n"
106
+ " - 'đoạn văn': nội dung gốc đã cấu trúc đầy đủ\n"
107
+ " - 'khái niệm chính': từ điểm chứa các khái niệm chính với khái niệm phụ hỗ trợ khái niệm chính đi kèm nếu có\n"
108
+ " - 'công thức': danh sách công thức (nếu có)\n"
109
+ " - 'ví dụ': ví dụ minh họa (nếu có)\n"
110
+ " - 'tóm tắt': tóm tắt nội dung, dễ hiểu\n"
111
+ "- Giữ ngữ nghĩa liền mạch.\n"
112
+ "- Chỉ TRẢ VỀ MỘT JSON hợp lệ theo schema, không kèm văn bản khác.\n\n"
113
+
114
+ "Chỉ TRẢ VỀ duy nhất MỘT đối tượng JSON theo schema sau và không có bất kỳ văn bản nào khác:\n\n"
115
+ "{\n"
116
+ ' "Tên topic": {"đoạn văn": "nội dung đã cấu trúc của topic 1", "khái niệm chính": {"khái niệm chính 1":["khái niệm phụ", "..."],"khái niệm chính 2":["khái niệm phụ", "..."]}, "công thức": ["..."], "ví dụ": ["..."], "tóm tắt": "tóm tắt ngắn gọn"},\n'
117
+ "}\n"
118
+ )
119
+ }
120
+
121
+ user_message = {
122
+ "role": "user",
123
+ "content": (
124
+ "Hãy chia văn bản sau thành nhiều chunk theo hướng dẫn trên và xuất JSON hợp lệ.\n"
125
+ f"### Văn bản nguồn:\n{source_text}"
126
+ )
127
+ }
128
+
129
+ if enable_fiddler:
130
+ max_conf, max_cat = text_safety_check(user_message['content'])
131
+ if max_conf > 0.5:
132
+ print(f"Harmful content detected: ({max_cat} : {max_conf})")
133
+ return {}
134
+
135
+ raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
136
+ parsed = _safe_extract_json(raw)
137
+ if not isinstance(parsed, dict):
138
+ raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
139
+ return parsed
140
+
141
+
142
+ def new_generate_mcqs_from_text(
143
+ source_text: str,
144
+ n: int = 3,
145
+ model: str = "openai/openai/gpt-oss-120b",
146
+ temperature: float = 0.2,
147
+ enable_fiddler = False,
148
+ target_difficulty: str = "easy",
149
+
150
+ ) -> Dict[str, Any]:
151
+
152
+
153
+ expected_concepts = {
154
+ "easy": 1,
155
+ "medium": 2,
156
+ "hard": (3, 4)
157
+ }
158
+ if isinstance(expected_concepts[target_difficulty], tuple):
159
+ min_concepts, max_concepts = expected_concepts[target_difficulty]
160
+ concept_range = f"{min_concepts}-{max_concepts}"
161
+ else:
162
+ concept_range = expected_concepts[target_difficulty]
163
+
164
+
165
+ difficulty_prompts = {
166
+ "easy": (
167
+ "- Câu hỏi DỄ: kiểm tra duy nhất 1 khái niệm chính cơ bản dễ hiểu, định nghĩa, hoặc công thức đơn giản."
168
+ "- Đáp án có thể tìm thấy trực tiếp trong văn bản."
169
+ "- Ngữ cảnh đủ để hiểu khái niệm chính."
170
+ "- Distractors khác biệt rõ ràng, dễ loại bỏ."
171
+ "- Độ dài câu hỏi ngắn gọn không quá 10-20 từ hoặc ít hơn 120 ký tự, tập trung vào một ý duy nhất.\n"
172
+ ),
173
+ "medium": (
174
+ "- Câu hỏi TRUNG BÌNH kiểm tra khái niệm chính trong văn bản"
175
+ "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
176
+ "- Các Distractors không quá giống nhau."
177
+ "- Độ dài câu hỏi vừa phải khoảng 23–30 từ hoặc khoảng 150 - 180 ký tự, có thêm chi tiết phụ để suy luận.\n"
178
+ ),
179
+ "hard": (
180
+ "- Câu hỏi KHÓ kiểm tra thông tin được phân tích/tổng hợp"
181
+ "- Nếu câu hỏi thuộc dạng áp dụng và suy luận thiếu dữ liệu để trả lời câu hỏi, thêm nội dung hoặc ví dụ từ văn bản nguồn."
182
+ "- Ít nhất 2 distractors gần giống đáp án đúng, độ tương đồng cao. "
183
+ f"- Đáp án yêu cầu học sinh suy luận hoặc áp dụng công thức vào ví dụ nếu có."
184
+ "- Độ dài câu hỏi dài hơn 35 từ hoặc hơn 200 ký tự.\n \n"
185
+ )
186
+ }
187
+
188
+ difficult_criteria = difficulty_prompts[target_difficulty] # "easy", "medium", "hard"
189
+ print(concept_range)
190
+ system_message = {
191
+ "role": "system",
192
+ "content": (
193
+ "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm (MCQ). Luôn trả lời bằng tiếng việt"
194
+ f"Đảm bảo chỉ tạo sinh câu trắc nghiệm có độ khó sau {difficult_criteria}"
195
+ f"Quan trọng: Mỗi câu hỏi chỉ sử dụng chính xác {concept_range} khái niệm chính (mỗi khái niệm chính có 1 danh sách khái niệm phụ) từ văn bản nguồn. "
196
+ "Mỗi câu hỏi và đáp án phải dựa trên thông tin từ văn bản nguồn. Không được đưa kiến thức ngoài vào."
197
+ "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không kèm giải thích hay trường thêm:\n\n"
198
+ "{\n"
199
+ ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"...", "khái niệm sử dụng": {"khái niệm chính":["khái niệm phụ", "..."], "..."]}},\n'
200
+ ' "2": { ... }\n'
201
+ "}\n\n"
202
+ "Lưu ý:\n"
203
+ f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
204
+ "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
205
+ "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'options'.\n"
206
+ "- Toàn bộ thông tin cần thiết để trả lời phải nằm trong chính câu hỏi, không tham chiếu lại văn bản nguồn."
207
+ f"- Sử dụng chính xác {concept_range} khái niệm chính"
208
+ )
209
+ }
210
+
211
+ user_message = {
212
+ "role": "user",
213
+ "content": (
214
+ f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Chỉ sử dụng nội dung này làm nguồn duy nhất để xây dựng câu hỏi.\n\n"
215
+
216
+ "### Yêu cầu:\n"
217
+ "- Bám sát vào thông tin trong văn bản; không thêm kiến thức ngoài.\n"
218
+ "- Nếu văn bản thiếu chi tiết, hãy tạo phương án nhiễu (distractors) hợp lý, nhưng phải có thể biện minh từ nội dung hoặc ngữ cảnh.\n"
219
+ f"### Văn bản nguồn:\n{source_text}"
220
+ )
221
+ }
222
+
223
+
224
+ if enable_fiddler:
225
+ max_conf, max_cat = text_safety_check(user_message['content'])
226
+ if max_conf > 0.5:
227
+ print(f"Harmful content detected: ({max_cat} : {max_conf})")
228
+ return {}
229
+
230
+ raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
231
+ # print('\n\n',raw)
232
+ parsed = _safe_extract_json(raw)
233
+ # basic validation
234
+ if not isinstance(parsed, dict) or len(parsed) != n:
235
+ raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
236
+ return parsed
237
+
238
+
239
+
240
+ def generate_mcqs_from_text(
241
+ source_text: str,
242
+ n: int = 3,
243
+ model: str = "openai/gpt-oss-120b",
244
+ temperature: float = 0.2,
245
+ enable_fiddler: bool = False,
246
+ ) -> Dict[str, Any]:
247
+ system_message = {
248
+ "role": "system",
249
+ "content": (
250
+ "Bạn là một trợ lý hữu ích chuyên tạo câu hỏi trắc nghiệm. "
251
+ "Chỉ TRẢ VỀ duy nhất một đối tượng JSON theo đúng schema sau và không có bất kỳ văn bản nào khác:\n\n"
252
+ "{\n"
253
+ ' "1": { "câu hỏi": "...", "lựa chọn": {"a":"...","b":"...","c":"...","d":"..."}, "đáp án":"..."},\n'
254
+ ' "2": { ... }\n'
255
+ "}\n\n"
256
+ "Lưu ý:\n"
257
+ f"- Tạo đúng {n} mục, đánh số từ 1 tới {n}.\n"
258
+ "- Khóa 'lựa chọn' phải có các phím a, b, c, d.\n"
259
+ "- 'đáp án' phải là toàn văn đáp án đúng (không phải ký tự chữ cái), và giá trị này phải khớp chính xác với một trong các giá trị trong 'lựa chọn'.\n"
260
+ "- Không kèm giải thích hay trường thêm.\n"
261
+ "- Các phương án sai (distractors) phải hợp lý và không lặp lại."
262
+ )
263
+ }
264
+ user_message = {
265
+ "role": "user",
266
+ "content": (
267
+ f"Hãy tạo {n} câu hỏi trắc nghiệm từ nội dung dưới đây. Dùng nội dung này làm nguồn duy nhất để trả lời."
268
+ "Nếu nội dung quá ít để tạo câu hỏi chính xác, hãy tạo các phương án hợp lý nhưng có thể biện minh được.\n\n"
269
+ f"Nội dung:\n\n{source_text}"
270
+ )
271
+ }
272
+
273
+ if enable_fiddler:
274
+ max_conf, max_cat = text_safety_check(user_message['content'])
275
+ if max_conf > 0.5:
276
+ print(f"Harmful content detected: ({max_cat} : {max_conf})")
277
+ return {"error": "Harmful content detected", f"{max_cat}": f"{str(max_conf)}"}
278
+
279
+ raw = _post_chat([system_message, user_message], model=model, temperature=temperature)
280
+ parsed = _safe_extract_json(raw)
281
+
282
+ # validate structure and length
283
+ if not isinstance(parsed, dict) or len(parsed) != n:
284
+ raise ValueError(f"Generator returned invalid structure. Raw:\n{raw}")
285
+ return parsed
286
+
287
+
288
+ # helpers to read/reset token counts
289
+ def get_token_count_record():
290
+ global TOTAL_TOKEN_COUNT_EACH_GENERATION
291
+ TOTAL_TOKEN_COUNT_EACH_GENERATION = np.append(TOTAL_TOKEN_COUNT_EACH_GENERATION, np.sum(TOTAL_TOKEN_COUNT))
292
+
293
+ token_record = {
294
+ 'INPUT_token_count': np.sum(INPUT_TOKEN_COUNT),
295
+ 'OUTPUT_token_count': np.sum(OUTPUT_TOKEN_COUNT),
296
+ 'AVG_INPUT_token_count': np.average(INPUT_TOKEN_COUNT),
297
+ 'AVG_OUTPUT_token_count': np.average(OUTPUT_TOKEN_COUNT),
298
+ 'TOTAL_token_count': TOTAL_TOKEN_COUNT,
299
+ 'TOTAL_token_count_PER_GENERATION - ': TOTAL_TOKEN_COUNT_EACH_GENERATION,
300
+ 'AVG_TOTAL_token_count_PER_GENERATION': [np.average(TOTAL_TOKEN_COUNT_EACH_GENERATION), len(TOTAL_TOKEN_COUNT_EACH_GENERATION)],
301
+ }
302
+
303
+ return token_record
304
+
305
+
306
+ def reset_token_count(reset_all=None):
307
+ """Call in app.py. For Reset Token Count after 1 Generation Session"""
308
+ global INPUT_TOKEN_COUNT, OUTPUT_TOKEN_COUNT, TOTAL_TOKEN_COUNT, TOTAL_TOKEN_COUNT_EACH_GENERATION
309
+
310
+ INPUT_TOKEN_COUNT = np.array([])
311
+ OUTPUT_TOKEN_COUNT = np.array([])
312
+ TOTAL_TOKEN_COUNT = np.array([])
313
+
314
+ if reset_all:
315
+ TOTAL_TOKEN_COUNT_EACH_GENERATION = np.array([])
316
+
317
+
318
+ def update_token_count(token_usage):
319
+ """Update Token Count for each generation
320
+ "usage": {
321
+ "prompt_tokens": 1209,
322
+ "completion_tokens": 313,
323
+ "total_tokens": 1522,
324
+ "prompt_tokens_details": {
325
+ "cached_tokens": 0
326
+ }
327
+ """
328
+ global INPUT_TOKEN_COUNT, OUTPUT_TOKEN_COUNT, TOTAL_TOKEN_COUNT # get value from global
329
+ prompt_tokens = token_usage['prompt_tokens'] # INPUT token
330
+ completion_tokens = token_usage['completion_tokens'] # OUTPUT token
331
+ total_tokens = token_usage['total_tokens'] # TOTAL token
332
+
333
+ INPUT_TOKEN_COUNT = np.append(INPUT_TOKEN_COUNT, prompt_tokens)
334
+ OUTPUT_TOKEN_COUNT = np.append(OUTPUT_TOKEN_COUNT, completion_tokens)
335
+ TOTAL_TOKEN_COUNT = np.append(TOTAL_TOKEN_COUNT, total_tokens)
336
+
337
+ # print("Input Token Increase:", INPUT_TOKEN_COUNT)
338
+ # print("Output Token Increase:", OUTPUT_TOKEN_COUNT)
339
+
340
+
341
+ def save_logs(record: dict, log_path:str = "logs/generation_log.jsonl"):
342
+ """
343
+ Append log to log_path
344
+ record: dict with keys you want to store (e.g. filename, input/output token_count, collection, etc..)
345
+ """
346
+ # create file if not exist
347
+ p = pathlib.Path(log_path)
348
+ p.parent.mkdir(parents=True, exist_ok=True)
349
+
350
+ # add id/timestampt if missing
351
+ record.setdefault('id', str(uuid.uuid4()))
352
+ record.setdefault('timestamp_utc', datetime.datetime.now(datetime.timezone.utc).isoformat() + "Z") # get current time at timezone
353
+
354
+ # append as 1 json file for each generation
355
+ with open(p, "a", encoding='utf-8') as f:
356
+ f.write(json.dumps(record, ensure_ascii=False) + "\n")
357
+
358
+
359
+ def update_time_info(time_info):
360
+ """
361
+ "time_info": {
362
+ "queue_time": 0.000600429,
363
+ "prompt_time": 0.052739054,
364
+ "completion_time": 0.15692187,
365
+ "total_time": 0.2117476463317871,
366
+ "created": 1755599458
367
+ }
368
+ """
369
+ time_info['created'] = time_info
370
+ time_info['created'].pop('created')
371
+
372
+
373
+ def get_time_info():
374
+ global TIME_INFOs
375
+ return TIME_INFOs
376
+ # token_record = {
377
+ # 'completion_time': np.sum(INPUT_TOKEN_COUNT),
378
+ # 'total_time': np.sum(OUTPUT_TOKEN_COUNT),
379
+ # }
380
+
381
+ def log_pipeline(path, content):
382
+ print("Save result to test/mcq_output.json")
383
+ #save_to_local(path=path, content=content)
384
+ token_record = get_token_count_record()
385
+
386
+ print("Token Record:")
387
+ for record, value in token_record.items():
388
+ print(f'{record}:{value}', '\n')
389
+
390
+ reset_token_count()
391
+
392
+ def save_to_local(path, content):
393
+ """
394
+ path = 'test/raw_data.json'
395
+ path = 'test/mcq_output.json'
396
+ path = 'test/extract_output.md'
397
+
398
+ """
399
+ p = pathlib.Path(path)
400
+ p.parent.mkdir(parents=True, exist_ok=True) # create folder if missing
401
+ p.touch(exist_ok=True) # create file if missing
402
+
403
+ if path.lower().endswith('.json'):
404
+ with open(path, 'w', encoding='utf-8') as f:
405
+ f.write(json.dumps(content, ensure_ascii=False, indent=2))
406
+ else:
407
+ with open(path, 'w', encoding='utf-8') as f:
408
+ f.write(f'{content}') # md, txt