# mcq-generator/app/generator.py
import re
import random
import fitz
import numpy as np
import os
from typing import List, Optional, Tuple, Dict, Any
from sentence_transformers import SentenceTransformer
from uuid import uuid4
import pymupdf4llm
try:
from qdrant_client import QdrantClient
from qdrant_client.http.models import (
PointStruct,
Filter,
FieldCondition,
MatchValue,
Distance,
VectorParams,
)
from qdrant_client.http import models as rest
_HAS_QDRANT = True
except Exception:
_HAS_QDRANT = False
try:
import faiss
_HAS_FAISS = True
except Exception:
_HAS_FAISS = False
from utils import generate_mcqs_from_text, _post_chat, _safe_extract_json, save_to_local
class RAGMCQ:
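    """RAG-based multiple-choice question (MCQ) generator for PDFs.

    Pipeline: extract page-wise markdown with pymupdf4llm, chunk it, embed
    the chunks with a SentenceTransformer, retrieve context (via FAISS, a
    NumPy brute-force fallback, or a Qdrant collection), and prompt an LLM
    through utils.generate_mcqs_from_text to produce MCQs.
    """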
def __init__(
self,
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
generation_model: str = "gpt-oss-120b",
qdrant_url: str = os.environ.get('QDRANT_URL') or "",
qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
qdrant_prefer_grpc: bool = False,
):
self.embedder = SentenceTransformer(embedder_model)
self.generation_model = generation_model
self.embeddings = None # np.array of shape (N, D)
self.texts = [] # list of chunk texts
self.metadata = [] # list of dicts (page, chunk_id, char_range)
self.index = None
self.dim = self.embedder.get_sentence_embedding_dimension()
self.qdrant = None
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
self.qdrant_prefer_grpc = qdrant_prefer_grpc
if qdrant_url:
self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
def extract_pages(
self,
pdf_path: str,
*,
pages: Optional[List[int]] = None,
ignore_images: bool = False,
dpi: int = 150
) -> List[str]:
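        """Return one stripped markdown string per PDF page (via pymupdf4llm)."""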
doc = fitz.open(pdf_path)
try:
# request page-wise output (page_chunks=True -> list[dict] per page)
page_dicts = pymupdf4llm.to_markdown(
doc,
pages=pages,
ignore_images=ignore_images,
dpi=dpi,
page_chunks=True,
)
# to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
pages_md: List[str] = []
for p in page_dicts:
txt = p.get("text", "") or ""
pages_md.append(txt.strip())
return pages_md
finally:
doc.close()
def chunk_text(self, text: str, max_chars: int = 1200) -> List[str]:
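        """Split text into chunks of at most max_chars characters.

        Splits on sentence-like boundaries first, then hard-splits any
        chunk that still exceeds max_chars.
        """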
text = text.strip()
if not text:
return []
if len(text) <= max_chars:
return [text]
# split by sentence-like boundaries
sentences = re.split(r'(?<=[\.\?\!])\s+', text)
chunks = []
cur = ""
for s in sentences:
if len(cur) + len(s) + 1 <= max_chars:
cur += (" " if cur else "") + s
else:
if cur:
chunks.append(cur)
cur = s
if cur:
chunks.append(cur)
# if still too long, hard-split
final = []
for c in chunks:
if len(c) <= max_chars:
final.append(c)
else:
for i in range(0, len(c), max_chars):
final.append(c[i:i+max_chars])
return final
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
pages = self.extract_pages(pdf_path)
self.texts = []
self.metadata = []
for p_idx, page_text in enumerate(pages, start=1):
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
for cid, ch in enumerate(chunks, start=1):
self.texts.append(ch)
self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
if not self.texts:
raise RuntimeError("No text extracted from PDF.")
        save_to_local('../test/text_chunks.md', file_content=self.texts)  # debug dump of extracted chunks
# compute embeddings
emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
self.embeddings = emb.astype("float32")
self._build_faiss_index()
def _build_faiss_index(self):
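        """Build a FAISS inner-product index over L2-normalized embeddings,
        falling back to brute-force NumPy cosine search when FAISS is absent."""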
if _HAS_FAISS:
d = self.embeddings.shape[1]
index = faiss.IndexFlatIP(d) # inner product -> cosine if vectors normalized
faiss.normalize_L2(self.embeddings)
index.add(self.embeddings)
self.index = index
else:
# store normalized embeddings and use brute-force numpy
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
self.embeddings = self.embeddings / norms
self.index = None
def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
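        """Return the top_k (chunk_index, cosine_similarity) pairs for query."""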
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
if _HAS_FAISS:
faiss.normalize_L2(q_emb)
D_list, I_list = self.index.search(q_emb, top_k)
# D are inner products; return list of (idx, score)
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
else:
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
sims = (self.embeddings @ qn.T).squeeze(axis=1)
idxs = np.argsort(-sims)[:top_k]
return [(int(i), float(sims[i])) for i in idxs]
def generate_from_pdf(
self,
pdf_path: str,
n_questions: int = 10,
mode: str = "rag", # per_page or rag
questions_per_page: int = 3, # for per_page mode
top_k: int = 3, # chunks to retrieve for each question in rag mode
temperature: float = 0.2,
) -> Dict[str, Any]:
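        """Generate up to n_questions MCQs from a PDF.

        mode="per_page": walk the chunks in page order, requesting
        questions_per_page questions from each.
        mode="rag": repeatedly sample a seed sentence, retrieve the top_k
        most similar chunks, and request one question per retrieved
        context, with at most n_questions * 4 attempts.
        """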
# build index
self.build_index_from_pdf(pdf_path)
output: Dict[str, Any] = {}
qcount = 0
if mode == "per_page":
# iterate pages -> chunks
for idx, meta in enumerate(self.metadata):
chunk_text = self.texts[idx]
if not chunk_text.strip():
continue
to_gen = questions_per_page
# ask generator
try:
mcq_block = generate_mcqs_from_text(
chunk_text, n=to_gen, model=self.generation_model, temperature=temperature
)
except Exception as e:
# skip this chunk if generator fails
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
continue
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
# create queries by sampling chunk text sentences.
# stop when n_questions reached or max_attempts exceeded.
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# create a seed query: pick a random chunk, pick a sentence from it
seed_idx = random.randrange(len(self.texts))
chunk = self.texts[seed_idx]
                # TODO: investigate chunking strategy; debug-dump sampled chunks
with open("chunks.txt", "a", encoding="utf-8") as f:
f.write(chunk + "\n")
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
                candidates = [s for s in sents if len(s.strip()) > 20]
                seed_sent = random.choice(candidates) if candidates else chunk[:200]
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks
retrieved = self._retrieve(query, top_k=top_k)
context_parts = []
for ridx, score in retrieved:
md = self.metadata[ridx]
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
context = "\n\n".join(context_parts)
save_to_local('../test/context.md', context)
# call generator for 1 question (or small batch) with the retrieved context
try:
# request 1 question at a time to keep diversity
mcq_block = generate_mcqs_from_text(
context, n=1, model=self.generation_model, temperature=temperature
)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
# append result(s)
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
else:
raise ValueError("mode must be 'per_page' or 'rag'.")
def validate_mcqs(
self,
mcqs: Dict[str, Any],
top_k: int = 4,
similarity_threshold: float = 0.5,
evidence_score_cutoff: float = 0.5,
use_model_verification: bool = True,
model_verification_temperature: float = 0.0,
) -> Dict[str, Any]:
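        """Validate MCQs against the indexed source text.

        Embeds "question + correct answer" per item, retrieves the top_k
        closest chunks, marks embedding support when the best similarity
        reaches similarity_threshold, keeps evidence chunks scoring at
        least evidence_score_cutoff, and optionally asks the generation
        model for a strict-JSON verdict. Returns a per-question report.
        """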
if self.embeddings is None or not self.texts:
raise RuntimeError("Index/embeddings not built. Run build_index_from_pdf() first.")
report: Dict[str, Any] = {}
# helper: semantic similarity search on statement -> returns list of (idx, score)
def semantic_search(statement: str, k: int = top_k):
q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
if _HAS_FAISS:
faiss.normalize_L2(q_emb)
D_list, I_list = self.index.search(q_emb, k)
# D are inner products; return list of (idx, score)
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
else:
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
sims = (self.embeddings @ qn.T).squeeze(axis=1)
idxs = np.argsort(-sims)[:k]
return [(int(i), float(sims[i])) for i in idxs]
# helper: verify with model (strict JSON in response)
def _verify_with_model(question_text: str, options: Dict[str, str], correct_text: str, context_text: str):
            # System prompt (Vietnamese): judge whether the correct answer is
            # supported by the supplied context; reply with strict JSON only.
            system = {
                "role": "system",
                "content": (
                    "Bạn là một trợ lý đánh giá tính thực chứng của câu hỏi trắc nghiệm "
                    "dựa trên đoạn văn được cung cấp. Luôn trả lời bằng Tiếng Việt.\n"
                    "Hãy trả lời DUY NHẤT bằng JSON hợp lệ (không có văn bản khác) theo schema:\n\n"
                    "{\n"
                    '  "supported": true/false,   # câu trả lời đúng có được nội dung chứng thực không\n'
                    '  "confidence": 0.0-1.0,     # mức độ tự tin (số)\n'
                    '  "evidence": "cụm văn bản ngắn làm bằng chứng hoặc trích dẫn",\n'
                    '  "reason": "ngắn gọn, vì sao supported hoặc không"\n'
                    "}\n\n"
                    "Chỉ dựa trên nội dung trong trường 'Context' dưới đây. "
                    "Nếu nội dung không chứa bằng chứng, trả về supported: false."
                )
            }
user = {
"role": "user",
"content": (
"Câu hỏi:\n" + question_text + "\n\n"
"Lựa chọn:\n" + "\n".join([f"{k}: {v}" for k, v in options.items()]) + "\n\n"
"Đáp án:\n" + correct_text + "\n\n"
"Context:\n" + context_text + "\n\n"
"Hãy trả lời như yêu cầu."
)
}
raw = _post_chat([system, user], model=self.generation_model, temperature=model_verification_temperature)
# parse JSON object in response
try:
parsed = _safe_extract_json(raw)
except Exception as e:
return {"error": f"Model verification failed to return JSON: {e}", "raw": raw}
return parsed
# iterate MCQs
for qid, item in mcqs.items():
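            # MCQ items use Vietnamese keys: "câu hỏi" (question),
            # "lựa chọn" (options), "đáp án" (correct answer).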
q_text = item.get("câu hỏi", "").strip()
options = item.get("lựa chọn", {})
correct_text = item.get("đáp án", "").strip()
# form a short declarative statement to embed: "Question: ... Answer: <correct>"
statement = f"{q_text} Answer: {correct_text}"
retrieved = semantic_search(statement, k=top_k)
evidence_list = []
max_sim = 0.0
for idx, score in retrieved:
if score >= evidence_score_cutoff:
evidence_list.append({
"idx": idx,
"page": self.metadata[idx].get("page", None),
"score": float(score),
"text": (self.texts[idx][:1000] + ("..." if len(self.texts[idx]) > 1000 else "")),
})
if score > max_sim:
max_sim = float(score)
supported_by_embeddings = max_sim >= similarity_threshold
model_verdict = None
if use_model_verification:
# build a context string from top retrieved chunks (regardless of cutoff)
context_parts = []
for ridx, sc in retrieved:
md = self.metadata[ridx]
context_parts.append(f"[page {md.get('page')}] {self.texts[ridx]}")
context_text = "\n\n".join(context_parts)
try:
parsed = _verify_with_model(q_text, options, correct_text, context_text)
model_verdict = parsed
except Exception as e:
model_verdict = {"error": f"verification exception: {e}"}
report[qid] = {
"supported_by_embeddings": bool(supported_by_embeddings),
"max_similarity": float(max_sim),
"evidence": evidence_list,
"model_verdict": model_verdict,
}
return report
    def connect_qdrant(self, url: str, api_key: Optional[str] = None, prefer_grpc: bool = False):
if not _HAS_QDRANT:
raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
self.qdrant_url = url
self.qdrant_api_key = api_key
self.qdrant_prefer_grpc = prefer_grpc
# Create client
self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
def _ensure_collection(self, collection_name: str):
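        """Create collection_name with a COSINE vector config if it is missing."""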
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
try:
# get_collection will raise if not present
_ = self.qdrant.get_collection(collection_name)
except Exception:
# create collection with vector size = self.dim
vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
# recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
def save_pdf_to_qdrant(
self,
pdf_path: str,
filename: str,
collection: str,
max_chars: int = 1200,
batch_size: int = 64,
overwrite: bool = False,
):
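        """Chunk a PDF, embed the chunks, and upsert them into Qdrant.

        Each point carries the chunk text plus filename/page/chunk_id
        payload fields; overwrite=True first deletes existing points
        for this filename. Returns a small status dict.
        """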
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# extract pages and chunks (re-using your existing helpers)
pages = self.extract_pages(pdf_path)
all_chunks = []
all_meta = []
for p_idx, page_text in enumerate(pages, start=1):
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
for cid, ch in enumerate(chunks, start=1):
all_chunks.append(ch)
all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
if not all_chunks:
raise RuntimeError("No tSext extracted from PDF.")
# ensure collection exists
self._ensure_collection(collection)
# optional: delete previous points for this filename if overwrite
if overwrite:
# delete by filter: filename == filename
flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
            try:
                # qdrant-client deletes by filter via a FilterSelector points selector
                self.qdrant.delete(
                    collection_name=collection,
                    points_selector=rest.FilterSelector(filter=flt),
                )
            except Exception:
                # ignore if deletion fails (e.g., nothing to delete)
                pass
# compute embeddings in batches
embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
embeddings = embeddings.astype("float32")
# prepare points
points = []
for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
pid = str(uuid4())
source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
payload = {
"filename": filename,
"page": md["page"],
"chunk_id": md["chunk_id"],
"length": md["length"],
"text": txt,
"source_id": source_id,
}
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload))
# upsert in batches
if len(points) >= batch_size:
self.qdrant.upsert(collection_name=collection, points=points)
points = []
# upsert remaining
if points:
self.qdrant.upsert(collection_name=collection, points=points)
try:
self.qdrant.create_payload_index(
collection_name=collection,
field_name="filename",
field_schema=rest.PayloadSchemaType.KEYWORD
)
except Exception as e:
print(f"Index creation skipped or failed: {e}")
return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
def list_files_in_collection(
self,
collection: str,
payload_field: str = "filename",
batch_size: int = 500,
) -> List[str]:
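        """Scroll the collection and return the sorted distinct values of
        payload_field (filenames by default)."""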
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# ensure collection exists
        # collection_exists raises on its own if the server is unreachable
        if not self.qdrant.collection_exists(collection):
            raise RuntimeError(f"Collection '{collection}' does not exist.")
filenames = set()
offset = None
while True:
# scroll returns (points, next_offset)
pts, next_offset = self.qdrant.scroll(
collection_name=collection,
limit=batch_size,
offset=offset,
with_payload=[payload_field],
with_vectors=False,
)
if not pts:
break
            for p in pts:
                # Records may be objects with a .payload attribute or dict-like
                if hasattr(p, "payload"):
                    payload = p.payload
                elif isinstance(p, dict):
                    payload = p.get("payload") or p
                else:
                    # best-effort fallback: convert to dict if possible
                    try:
                        payload = dict(p)
                    except Exception:
                        payload = None
if not payload:
continue
# extract candidate value(s)
val = None
if isinstance(payload, dict):
val = payload.get(payload_field)
else:
# Some payload representations store fields differently; try attribute access
val = getattr(payload, payload_field, None)
# If value is list-like, iterate, else add single
if isinstance(val, (list, tuple, set)):
for v in val:
if v is not None:
filenames.add(str(v))
elif val is not None:
filenames.add(str(val))
# stop if no more pages
if not next_offset:
break
offset = next_offset
return sorted(filenames)
def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
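        """Return all stored points (point_id + payload) for filename."""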
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
results = []
offset = None
while True:
# scroll returns (points, next_offset)
points, next_offset = self.qdrant.scroll(
collection_name=collection,
scroll_filter=Filter(
must=[
FieldCondition(key="filename", match=MatchValue(value=filename))
]
),
limit=batch,
offset=offset,
with_payload=True,
with_vectors=False,
)
# points are objects (Record / ScoredPoint-like); get id and payload
for p in points:
# p.payload is a dict, p.id is point id
results.append({"point_id": p.id, "payload": p.payload})
if not next_offset:
break
offset = next_offset
return results
    def _retrieve_qdrant(self, query: str, collection: str, filename: Optional[str] = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
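        """Search the collection for query, optionally restricted to one
        filename; returns (payload, score) pairs."""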
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
q_filter = None
if filename:
q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
search_res = self.qdrant.search(
collection_name=collection,
query_vector=q_emb,
query_filter=q_filter,
limit=top_k,
with_payload=True,
with_vectors=False,
)
out = []
for hit in search_res:
# hit.payload is the stored payload, hit.score is similarity
out.append((hit.payload, float(getattr(hit, "score", 0.0))))
return out
def generate_from_qdrant(
self,
filename: str,
collection: str,
n_questions: int = 10,
mode: str = "rag", # 'per_chunk' or 'rag'
questions_per_chunk: int = 3, # used for 'per_chunk'
top_k: int = 3, # retrieval size used in RAG
temperature: float = 0.2,
) -> Dict[str, Any]:
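        """Generate MCQs from chunks previously stored in Qdrant.

        Loads every chunk for filename, rebuilds a local index for
        sampling, then mirrors generate_from_pdf: "per_chunk" walks the
        chunks in order; "rag" samples seed sentences and retrieves
        top_k chunks through Qdrant, filtered to this file.
        """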
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
if not file_points:
raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
# create a local list of texts & metadata for sampling
texts = []
metas = []
for p in file_points:
payload = p.get("payload", {})
text = payload.get("text", "")
texts.append(text)
metas.append(payload)
self.texts = texts
self.metadata = metas
embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
if embeddings is None or len(embeddings) == 0:
self.embeddings = None
self.index = None
else:
self.embeddings = embeddings.astype("float32")
# update dim in case embedder changed unexpectedly
self.dim = int(self.embeddings.shape[1])
# build index
self._build_faiss_index()
output = {}
qcount = 0
if mode == "per_chunk":
# iterate all chunks (in payload order) and request questions_per_chunk from each
for i, txt in enumerate(texts):
if not txt.strip():
continue
to_gen = questions_per_chunk
try:
mcq_block = generate_mcqs_from_text(txt, n=to_gen, model=self.generation_model, temperature=temperature)
except Exception as e:
print(f"Generator failed on chunk (index {i}): {e}")
continue
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# sample a seed sentence from a random chunk of this file
seed_idx = random.randrange(len(texts))
chunk = texts[seed_idx]
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
                seed_sent = next((s for s in sents if len(s.strip()) > 20), chunk[:200])
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks from the same file (restricted by filename filter)
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
context_parts = []
for payload, score in retrieved:
# payload should contain page & chunk_id and text
page = payload.get("page", "?")
ctxt = payload.get("text", "")
context_parts.append(f"[page {page}] {ctxt}")
context = "\n\n".join(context_parts)
try:
mcq_block = generate_mcqs_from_text(context, n=1, model=self.generation_model, temperature=temperature)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
else:
raise ValueError("mode must be 'per_chunk' or 'rag'.")
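

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): assumes a local "sample.pdf",
# that utils provides working generation helpers, and, for the Qdrant flow,
# that QDRANT_URL / QDRANT_API_KEY point at a reachable instance. The file
# and collection names below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rag = RAGMCQ()

    # Local flow: build an in-memory index, generate, then validate.
    mcqs = rag.generate_from_pdf("sample.pdf", n_questions=5, mode="rag", top_k=3)
    report = rag.validate_mcqs(mcqs, top_k=4)
    print(mcqs)
    print(report)

    # Qdrant flow: persist chunks, then generate from the stored file.
    if rag.qdrant is not None:
        rag.save_pdf_to_qdrant("sample.pdf", filename="sample.pdf", collection="docs", overwrite=True)
        print(rag.list_files_in_collection("docs"))
        print(rag.generate_from_qdrant(filename="sample.pdf", collection="docs", n_questions=5, mode="rag"))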