# test-generator/generator.py
import re
import random
import fitz
import string
import numpy as np
import os
from typing import List, Optional, Tuple, Dict, Any
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import pipeline
from uuid import uuid4
import pymupdf4llm
from typing_extensions import override
try:
from qdrant_client import QdrantClient
from qdrant_client.http.models import (
PointStruct,
Filter,
FieldCondition,
MatchValue,
Distance,
VectorParams,
)
from qdrant_client.http import models as rest
_HAS_QDRANT = True
except Exception:
_HAS_QDRANT = False
try:
import faiss
_HAS_FAISS = True
except Exception:
_HAS_FAISS = False
from utils import generate_mcqs_from_text, structure_context_for_llm, new_generate_mcqs_from_text
from huggingface_hub import login
# log in only when a token is configured so importing this module doesn't crash on a missing env var
hf_token = os.environ.get("HF_MODEL_TOKEN")
if hf_token:
    login(token=hf_token)
class RAGMCQ:
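    """RAG pipeline for generating multiple-choice questions (MCQs) from PDFs.

    Workflow: extract pages as markdown, chunk them, embed the chunks
    (FAISS HNSW index when available, brute-force cosine search otherwise),
    retrieve context for sampled seed queries, and prompt the generation
    model for Vietnamese MCQs. Chunks can also be persisted to and
    retrieved from Qdrant.
    """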
def __init__(
self,
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
generation_model: str = "openai/gpt-oss-120b",
qdrant_url: str = os.environ.get('QDRANT_URL') or "",
qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
qdrant_prefer_grpc: bool = False,
):
self.embedder = SentenceTransformer(embedder_model)
self.generation_model = generation_model
self.qa_pipeline = pipeline("question-answering", model="nguyenvulebinh/vi-mrc-base", tokenizer="nguyenvulebinh/vi-mrc-base")
self.cross_entail = CrossEncoder("itdainb/PhoRanker")
self.embeddings = None # np.array of shape (N, D)
self.texts = [] # list of chunk texts
self.metadata = [] # list of dicts (page, chunk_id, char_range)
self.index = None
self.dim = self.embedder.get_sentence_embedding_dimension()
self.qdrant = None
self.qdrant_url = qdrant_url
self.qdrant_api_key = qdrant_api_key
self.qdrant_prefer_grpc = qdrant_prefer_grpc
if qdrant_url:
self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
def extract_pages(
self,
pdf_path: str,
*,
pages: Optional[List[int]] = None,
ignore_images: bool = False,
dpi: int = 150
) -> List[str]:
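        """Return one markdown string per page of the PDF, via pymupdf4llm."""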
doc = fitz.open(pdf_path)
try:
# request page-wise output (page_chunks=True -> list[dict] per page)
page_dicts = pymupdf4llm.to_markdown(
doc,
pages=pages,
ignore_images=ignore_images,
dpi=dpi,
page_chunks=True,
)
# to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown)
pages_md: List[str] = []
for p in page_dicts:
txt = p.get("text", "") or ""
pages_md.append(txt.strip())
return pages_md
finally:
doc.close()
def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]:
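        """Split ``text`` into chunks of at most ``max_chars`` characters.

        Splitting prefers sentence-like boundaries; ``overlap`` trailing
        characters of each chunk are carried into the next one, and any
        chunk still longer than ``max_chars`` is hard-split.
        """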
text = text.strip()
if not text:
return []
if len(text) <= max_chars:
return [text]
# split by sentence-like boundaries
sentences = re.split(r'(?<=[\.\?\!])\s+', text)
chunks = []
cur = ""
for s in sentences:
if len(cur) + len(s) + 1 <= max_chars:
cur += (" " if cur else "") + s
else:
if cur:
chunks.append(cur)
cur = (cur[-overlap:] + " " + s) if overlap > 0 else s
if cur:
chunks.append(cur)
# if still too long, hard-split
final = []
for c in chunks:
if len(c) <= max_chars:
final.append(c)
else:
for i in range(0, len(c), max_chars):
final.append(c[i:i+max_chars])
return final
def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200):
pages = self.extract_pages(pdf_path)
self.texts = []
self.metadata = []
for p_idx, page_text in enumerate(pages, start=1):
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
for cid, ch in enumerate(chunks, start=1):
self.texts.append(ch)
self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
if not self.texts:
raise RuntimeError("No text extracted from PDF.")
# save_to_local('test/text_chunks.md', content=self.texts)
# compute embeddings
emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True)
self.embeddings = emb.astype("float32")
self._build_faiss_index()
def _build_faiss_index(self, ef_construction=200, M=32):
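        """Build an HNSW index over L2-normalized embeddings when faiss is
        available (``M`` neighbors per node, ``ef_construction`` build-time
        search width); otherwise normalize the stored embeddings so
        ``_retrieve`` can fall back to brute-force cosine search."""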
        if _HAS_FAISS:
            d = self.embeddings.shape[1]
            index = faiss.IndexHNSWFlat(d, M)
            # efConstruction must be set before index.add() for it to take effect
            index.hnsw.efConstruction = ef_construction
            faiss.normalize_L2(self.embeddings)
            index.add(self.embeddings)
            self.index = index
else:
# store normalized embeddings and use brute-force numpy
norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10
self.embeddings = self.embeddings / norms
self.index = None
def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]:
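        """Return ``top_k`` (chunk_index, cosine_similarity) pairs for ``query``."""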
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")
if _HAS_FAISS:
faiss.normalize_L2(q_emb)
D_list, I_list = self.index.search(q_emb, top_k)
# D are inner products; return list of (idx, score)
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
else:
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
sims = (self.embeddings @ qn.T).squeeze(axis=1)
idxs = np.argsort(-sims)[:top_k]
return [(int(i), float(sims[i])) for i in idxs]
def generate_from_pdf(
self,
pdf_path: str,
n_questions: int = 10,
mode: str = "rag", # per_page or rag
questions_per_page: int = 3, # for per_page mode
top_k: int = 3, # chunks to retrieve for each question in rag mode
temperature: float = 0.2,
enable_fiddler: bool = False,
) -> Dict[str, Any]:
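        """Generate up to ``n_questions`` MCQs from a PDF.

        ``mode='per_page'`` walks chunks in order and requests
        ``questions_per_page`` MCQs from each; ``mode='rag'`` samples a seed
        sentence, retrieves ``top_k`` supporting chunks, and requests one
        MCQ per attempt. Returns a dict keyed by question number as a string.
        """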
# build index
self.build_index_from_pdf(pdf_path)
output: Dict[str, Any] = {}
qcount = 0
if mode == "per_page":
# iterate pages -> chunks
for idx, meta in enumerate(self.metadata):
chunk_text = self.texts[idx]
if not chunk_text.strip():
continue
# ask generator
try:
structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
mcq_block = generate_mcqs_from_text(
structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
)
except Exception as e:
# skip this chunk if generator fails
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
# create queries by sampling chunk text sentences.
# stop when n_questions reached or max_attempts exceeded.
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# create a seed query: pick a random chunk, pick a sentence from it
seed_idx = random.randrange(len(self.texts))
chunk = self.texts[seed_idx]
#? investigate better Chunking Strategy
#with open("chunks.txt", "a", encoding="utf-8") as f:
#f.write(chunk + "\n")
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
                candidates = [s for s in sents if len(s.strip()) > 20]
                seed_sent = random.choice(candidates) if candidates else chunk[:200]
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks
retrieved = self._retrieve(query, top_k=top_k)
context_parts = []
for ridx, score in retrieved:
md = self.metadata[ridx]
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
context = "\n\n".join(context_parts)
# save_to_local('test/context.md', content=context)
# call generator for 1 question (or small batch) with the retrieved context
try:
# request 1 question at a time to keep diversity
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
mcq_block = generate_mcqs_from_text(
structured_context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler
)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
# append result(s)
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
payload = mcq_block[item]
q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
if isinstance(options, list):
options = {str(i+1): o for i, o in enumerate(options)}
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
correct_text = ""
if isinstance(correct_key, str) and correct_key.strip() in options:
correct_text = options[correct_key.strip()]
else:
correct_text = payload.get("correct_text") or correct_key or ""
diff_score, diff_label = self._estimate_difficulty_for_generation(
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
)
payload["difficulty"] = {"score": diff_score, "label": diff_label}
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
else:
raise ValueError("mode must be 'per_page' or 'rag'.")
def validate_mcqs(
self,
mcqs: Dict[str, Any],
top_k: int = 4,
similarity_threshold: float = 0.5,
evidence_score_cutoff: float = 0.5,
use_cross_encoder: bool = True,
use_qa: bool = True,
auto_accept_threshold: float = 0.7,
review_threshold: float = 0.5,
distractor_too_similar: float = 0.8,
distractor_too_different: float = 0.15,
model_verification_temperature: float = 0.0,
) -> Dict[str, Any]:
"""
Upgraded validation pipeline:
- embedding retrieval (self.index / self.embeddings)
- cross-encoder entailment scoring (optional)
- extractive QA consistency check (optional)
- distractor similarity and type checks
- aggregate into quality_score and triage_action
Returns a dict keyed by qid with detailed info and triage decision.
"""
cross_entail = None
qa_pipeline = None
if use_cross_encoder:
try:
cross_entail = self.cross_entail
            except Exception:
cross_entail = None
if use_qa:
try:
qa_pipeline = self.qa_pipeline
except Exception:
qa_pipeline = None
# --- helpers ---
def _norm_text(s: str) -> str:
if s is None:
return ""
s = s.strip().lower()
# remove punctuation
s = s.translate(str.maketrans("", "", string.punctuation))
# collapse whitespace
s = " ".join(s.split())
return s
def _semantic_search(statement: str, k: int = top_k):
# returns list of (idx, score) using current embeddings/index
q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32")
if _HAS_FAISS and getattr(self, "index", None) is not None:
try:
faiss.normalize_L2(q_emb)
D_list, I_list = self.index.search(q_emb, k)
return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1]
except Exception:
pass
# fallback to brute force
qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10)
sims = (self.embeddings @ qn.T).squeeze(axis=1)
idxs = np.argsort(-sims)[:k]
return [(int(i), float(sims[i])) for i in idxs]
def _compose_context_from_retrieved(retrieved):
parts = []
for ridx, score in retrieved:
md = self.metadata[ridx] if ridx < len(self.metadata) else {}
page = md.get("page", "?")
text = self.texts[ridx]
parts.append(f"[page {page}] {text}")
return "\n\n".join(parts)
def _compute_option_embeddings(options_map: Dict[str, str]):
# returns dict key->embedding
keys = list(options_map.keys())
texts = [options_map[k] for k in keys]
embs = self.embedder.encode(texts, convert_to_numpy=True)
return dict(zip(keys, embs))
def _cosine(a, b):
a = np.asarray(a, dtype=float)
b = np.asarray(b, dtype=float)
denom = (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12)
return float(np.dot(a, b) / denom)
# --- main loop ---
report = {}
for qid, item in mcqs.items():
# support both Vietnamese keys and English keys
q_text = (item.get("câu hỏi") or item.get("question") or item.get("q") or item.get("stem") or "").strip()
options = item.get("lựa chọn") or item.get("options") or item.get("choices") or {}
# options may be dict mapping letters to text, or list: normalize to dict
if isinstance(options, list):
options = {str(i+1): o for i, o in enumerate(options)}
# correct answer may be a key (like "A") or the text; try both
correct_key = item.get("đáp án") or item.get("answer") or item.get("correct") or item.get("ans")
correct_text = ""
if isinstance(correct_key, str) and correct_key.strip() in options:
correct_text = options[correct_key.strip()]
else:
# maybe the answer is full text
if isinstance(correct_key, str):
correct_text = correct_key.strip()
else:
# fallback to 'correct_text' field
correct_text = item.get("correct_text") or item.get("đáp án_text") or ""
# default empty guard
options = {k: str(v) for k, v in options.items()}
correct_text = str(correct_text)
# prepare statement for retrieval
statement = f"{q_text} Answer: {correct_text}"
retrieved = _semantic_search(statement, k=top_k)
# build context from top retrieved
context_parts = []
for ridx, score in retrieved:
md = self.metadata[ridx] if ridx < len(self.metadata) else {}
context_parts.append({"idx": ridx, "score": float(score), "page": md.get("page", None), "text": self.texts[ridx]})
context_text = "\n\n".join([f"[page {p['page']}] {p['text']}" for p in context_parts])
# Evidence list (embedding-based)
evidence_list = []
max_sim = 0.0
for r in context_parts:
if r["score"] >= evidence_score_cutoff:
snippet = r["text"]
evidence_list.append({
"idx": r["idx"],
"page": r["page"],
"score": r["score"],
"text": (snippet[:1000] + ("..." if len(snippet) > 1000 else "")),
})
if r["score"] > max_sim:
max_sim = float(r["score"])
supported_by_embeddings = max_sim >= similarity_threshold
# Cross-encoder entailment scores for each option
entailment_scores = {}
correct_entail = 0.0
try:
if cross_entail is not None and context_text.strip():
# prepare list of (premise, hypothesis)
pairs = []
opt_keys = list(options.keys())
for k in opt_keys:
hyp = f"{q_text} Answer: {options[k]}"
pairs.append((context_text, hyp))
scores = cross_entail.predict(pairs) # returns list of floats
# normalize scores to 0-1 if needed (cross-encoder may return arbitrary positive)
# do a min-max normalization across the returned scores
# but avoid division by zero
min_s = float(min(scores)) if len(scores) else 0.0
max_s = float(max(scores)) if len(scores) else 1.0
denom = max_s - min_s if max_s - min_s > 1e-6 else 1.0
for k, raw in zip(opt_keys, scores):
scaled = (raw - min_s) / denom
entailment_scores[k] = float(scaled)
# find correct key if available
# if `correct_text` exactly matches one of options, find that key
matched_key = None
for k, v in options.items():
if _norm_text(v) == _norm_text(correct_text):
matched_key = k
break
if matched_key:
correct_entail = entailment_scores.get(matched_key, 0.0)
else:
# fallback: treat 'correct_text' as a separate hypothesis
hyp = f"{q_text} Answer: {correct_text}"
raw = cross_entail.predict([(context_text, hyp)])[0]
# scale relative to min/max used above
correct_entail = float((raw - min_s) / denom)
else:
entailment_scores = {}
correct_entail = 0.0
except Exception as e:
entailment_scores = {}
correct_entail = 0.0
def embed_cosine_sim(a, b):
emb = self.embedder.encode([a, b], convert_to_numpy=True, normalize_embeddings=True)
return float(np.dot(emb[0], emb[1]))
# QA consistency
qa_answer = None
qa_score = 0.0
qa_agrees = False
if qa_pipeline is not None and context_text.strip():
try:
qa_res = qa_pipeline(question=q_text, context=context_text)
# some QA pipelines return list of answers or dict
if isinstance(qa_res, list) and len(qa_res) > 0:
top = qa_res[0]
qa_answer = top.get("answer") if isinstance(top, dict) else str(top)
# qa_score = float(top.get("score", 0.0) if isinstance(top, dict) else 0.0)
elif isinstance(qa_res, dict):
qa_answer = qa_res.get("answer", "")
qa_score = float(qa_res.get("score", 0.0))
else:
qa_answer = str(qa_res)
qa_score = 0.0
qa_score = embed_cosine_sim(qa_answer, correct_text)
qa_agrees = (qa_score >= 0.5)
except Exception:
qa_answer = None
qa_score = 0.0
qa_agrees = False
try:
opt_embs = _compute_option_embeddings({**options, "__CORRECT__": correct_text})
correct_emb = opt_embs.pop("__CORRECT__")
distractor_similarities = {}
for k, emb in opt_embs.items():
distractor_similarities[k] = float(_cosine(correct_emb, emb))
except Exception:
distractor_similarities = {k: None for k in options.keys()}
# distractor flags
distractor_penalty = 0.0
distractor_flags = []
for k, sim in distractor_similarities.items():
                # skip failed embeddings, the correct option itself (self-similarity ~1.0),
                # and degenerate near-zero similarities
                if sim is None or sim >= 0.999999 or (-0.01 <= sim <= 0):
continue
if sim >= distractor_too_similar:
distractor_flags.append({"key": k, "reason": "too_similar", "similarity": sim})
distractor_penalty += 0.25
elif sim <= distractor_too_different:
distractor_flags.append({"key": k, "reason": "too_different", "similarity": sim})
distractor_penalty += 0.15
# clamp penalty
distractor_penalty = min(distractor_penalty, 1.0)
# Ambiguity detection: how many options have entailment >= threshold
ambiguous = False
ambiguous_options = []
if entailment_scores:
# count options whose entailment >= max(correct_entail * 0.9, 0.6)
amb_thresh = max(correct_entail * 0.9, 0.6)
for k, sc in entailment_scores.items():
if sc >= amb_thresh and (options.get(k, "") != correct_text):
ambiguous_options.append({"key": k, "score": sc, "text": options[k]})
ambiguous = len(ambiguous_options) > 0
# Compose aggregated quality score
# Components:
# - embedding_support: normalized max_sim (0..1)
# - entailment: correct_entail (0..1)
# - qa_agree: boolean -> 1 or 0 times qa_score
# - distractor_penalty: subtracted
emb_support_norm = max_sim # embedding similarity typically already 0..1 (inner product normalized)
entail_component = float(correct_entail)
qa_component = float(qa_score) if qa_agrees else 0.0
# weighted sum
quality_score = (
0.40 * emb_support_norm +
0.35 * entail_component +
0.20 * qa_component -
0.05 * distractor_penalty
)
# clamp to 0..1
quality_score = max(0.0, min(1.0, quality_score))
# triage decision
triage_action = "reject"
if quality_score >= auto_accept_threshold and not ambiguous:
triage_action = "pass"
elif quality_score >= review_threshold:
triage_action = "review"
else:
triage_action = "reject"
# compile flags/reasons
flag_reasons = []
if not supported_by_embeddings:
flag_reasons.append("no_strong_embedding_evidence")
if entailment_scores and correct_entail < 0.6:
flag_reasons.append("low_entailment_score_for_correct")
            # qa_agrees already encodes qa_score >= 0.5, so flag a contradiction
            # whenever the QA model produced an answer that disagrees
            if qa_pipeline is not None and qa_answer is not None and not qa_agrees:
                flag_reasons.append("qa_contradiction")
if ambiguous:
flag_reasons.append("ambiguous_options_supported")
if distractor_flags:
flag_reasons.append({"distractor_issues": distractor_flags})
# assemble per-question report
report[qid] = {
"supported_by_embeddings": bool(supported_by_embeddings),
"max_similarity": float(max_sim),
"evidence": evidence_list,
"entailment_scores": entailment_scores,
"correct_entailment": float(correct_entail),
"qa_answer": qa_answer,
"qa_score": float(qa_score),
"qa_agrees": bool(qa_agrees),
"distractor_similarities": distractor_similarities,
"distractor_flags": distractor_flags,
"distractor_penalty": float(distractor_penalty),
"ambiguous_options": ambiguous_options,
"quality_score": float(quality_score),
"triage_action": triage_action,
"flag_reasons": flag_reasons,
}
return report
    def connect_qdrant(self, url: str, api_key: Optional[str] = None, prefer_grpc: bool = False):
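        """Create a Qdrant client for ``url`` and remember the connection settings."""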
if not _HAS_QDRANT:
raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.")
self.qdrant_url = url
self.qdrant_api_key = api_key
self.qdrant_prefer_grpc = prefer_grpc
# Create client
self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc)
def _ensure_collection(self, collection_name: str):
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
try:
# get_collection will raise if not present
_ = self.qdrant.get_collection(collection_name)
except Exception:
# create collection with vector size = self.dim
vect_params = VectorParams(size=self.dim, distance=Distance.COSINE)
self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params)
# recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead.
def save_pdf_to_qdrant(
self,
pdf_path: str,
filename: str,
collection: str,
max_chars: int = 1200,
batch_size: int = 64,
overwrite: bool = False,
):
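        """Chunk and embed a PDF, then upsert the chunks into a Qdrant collection.

        Each point's payload carries filename, page, chunk_id, length, text and
        a ``source_id`` of the form ``{filename}__p{page}__c{chunk}``. With
        ``overwrite=True``, existing points for ``filename`` are deleted first.
        """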
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# extract pages and chunks (re-using your existing helpers)
pages = self.extract_pages(pdf_path)
all_chunks = []
all_meta = []
for p_idx, page_text in enumerate(pages, start=1):
chunks = self.chunk_text(page_text or "", max_chars=max_chars)
for cid, ch in enumerate(chunks, start=1):
all_chunks.append(ch)
all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)})
if not all_chunks:
raise RuntimeError("No tSext extracted from PDF.")
# ensure collection exists
self._ensure_collection(collection)
# optional: delete previous points for this filename if overwrite
if overwrite:
# delete by filter: filename == filename
flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
try:
                # qdrant-client's delete() takes the filter via points_selector
                self.qdrant.delete(collection_name=collection, points_selector=flt)
except Exception:
# ignore if deletion fails
pass
# compute embeddings in batches
embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
embeddings = embeddings.astype("float32")
# prepare points
points = []
for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)):
pid = str(uuid4())
source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}"
payload = {
"filename": filename,
"page": md["page"],
"chunk_id": md["chunk_id"],
"length": md["length"],
"text": txt,
"source_id": source_id,
}
points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload)) # pyright: ignore[reportPossiblyUnboundVariable]
# upsert in batches
if len(points) >= batch_size:
self.qdrant.upsert(collection_name=collection, points=points)
points = []
# upsert remaining
if points:
self.qdrant.upsert(collection_name=collection, points=points)
try:
self.qdrant.create_payload_index(
collection_name=collection,
field_name="filename",
field_schema=rest.PayloadSchemaType.KEYWORD
)
except Exception as e:
print(f"Index creation skipped or failed: {e}")
return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename}
def list_files_in_collection(
self,
collection: str,
payload_field: str = "filename",
batch_size: int = 500,
) -> List[str]:
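        """Scroll the collection and return the sorted distinct values of ``payload_field``."""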
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# ensure collection exists
        # collection_exists raises on its own if the server is unreachable
        if not self.qdrant.collection_exists(collection):
            raise RuntimeError(f"Collection '{collection}' does not exist.")
filenames = set()
offset = None
while True:
# scroll returns (points, next_offset)
pts, next_offset = self.qdrant.scroll(
collection_name=collection,
limit=batch_size,
offset=offset,
with_payload=[payload_field],
with_vectors=False,
)
if not pts:
break
for p in pts:
# p may be a dict-like or an object with .payload
payload = None
if hasattr(p, "payload"):
payload = p.payload
elif isinstance(p, dict):
# older/newer variants might use nested structures: try common keys
                    payload = p.get("payload") or p
else:
# best-effort fallback: convert to dict if possible
try:
payload = dict(p)
except Exception:
payload = None
if not payload:
continue
# extract candidate value(s)
val = None
if isinstance(payload, dict):
val = payload.get(payload_field)
else:
# Some payload representations store fields differently; try attribute access
val = getattr(payload, payload_field, None)
# If value is list-like, iterate, else add single
if isinstance(val, (list, tuple, set)):
for v in val:
if v is not None:
filenames.add(str(v))
elif val is not None:
filenames.add(str(val))
# stop if no more pages
if not next_offset:
break
offset = next_offset
return sorted(filenames)
def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]:
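        """Return all points whose payload ``filename`` matches, as {point_id, payload} dicts."""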
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
results = []
offset = None
while True:
# scroll returns (points, next_offset)
points, next_offset = self.qdrant.scroll(
collection_name=collection,
scroll_filter=Filter(
must=[
FieldCondition(key="filename", match=MatchValue(value=filename))
]
),
limit=batch,
offset=offset,
with_payload=True,
with_vectors=False,
)
# points are objects (Record / ScoredPoint-like); get id and payload
for p in points:
# p.payload is a dict, p.id is point id
results.append({"point_id": p.id, "payload": p.payload})
if not next_offset:
break
offset = next_offset
return results
    def _retrieve_qdrant(self, query: str, collection: str, filename: Optional[str] = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]:
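        """Vector-search the collection for ``query``; returns (payload, score)
        pairs, optionally restricted to a single ``filename`` via a payload filter."""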
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist()
q_filter = None
if filename:
q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))])
search_res = self.qdrant.search(
collection_name=collection,
query_vector=q_emb,
query_filter=q_filter,
limit=top_k,
with_payload=True,
with_vectors=False,
)
out = []
for hit in search_res:
# hit.payload is the stored payload, hit.score is similarity
out.append((hit.payload, float(getattr(hit, "score", 0.0))))
return out
def generate_from_qdrant(
self,
filename: str,
collection: str,
n_questions: int = 10,
mode: str = "rag", # 'per_chunk' or 'rag'
questions_per_chunk: int = 3, # used for 'per_chunk'
top_k: int = 3, # retrieval size used in RAG
temperature: float = 0.2,
enable_fiddler: bool = False,
) -> Dict[str, Any]:
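        """Generate MCQs from chunks previously stored in Qdrant for ``filename``.

        Chunks are loaded into memory and re-embedded so that local retrieval
        and difficulty estimation work as in ``generate_from_pdf``; retrieval
        for RAG attempts goes back to Qdrant with a filename filter.
        """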
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
if not file_points:
raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
# create a local list of texts & metadata for sampling
texts = []
metas = []
for p in file_points:
payload = p.get("payload", {})
text = payload.get("text", "")
texts.append(text)
metas.append(payload)
self.texts = texts
self.metadata = metas
embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
if embeddings is None or len(embeddings) == 0:
self.embeddings = None
self.index = None
else:
self.embeddings = embeddings.astype("float32")
# update dim in case embedder changed unexpectedly
self.dim = int(self.embeddings.shape[1])
# build index
self._build_faiss_index()
output = {}
qcount = 0
if mode == "per_chunk":
# iterate all chunks (in payload order) and request questions_per_chunk from each
for i, txt in enumerate(texts):
if not txt.strip():
continue
try:
structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
except Exception as e:
print(f"Generator failed on chunk (index {i}): {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# create a seed query: pick a random chunk, pick a sentence from it
seed_idx = random.randrange(len(self.texts))
chunk = self.texts[seed_idx]
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
candidate = [s for s in sents if len(s.strip()) > 20]
if candidate:
seed_sent = random.choice(candidate)
else:
stripped = chunk.strip()
seed_sent = (stripped[:200] if stripped else "[no text available]")
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks from the same file (restricted by filename filter)
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
context_parts = []
for payload, score in retrieved:
# payload should contain page & chunk_id and text
page = payload.get("page", "?")
ctxt = payload.get("text", "")
context_parts.append(f"[page {page}] {ctxt}")
context = "\n\n".join(context_parts)
try:
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
payload = mcq_block[item]
q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
if isinstance(options, list):
options = {str(i+1): o for i, o in enumerate(options)}
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
correct_text = ""
if isinstance(correct_key, str) and correct_key.strip() in options:
correct_text = options[correct_key.strip()]
else:
correct_text = payload.get("correct_text") or correct_key or ""
diff_score, diff_label = self._estimate_difficulty_for_generation(
q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context
)
payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
else:
raise ValueError("mode must be 'per_chunk' or 'rag'.")
def _estimate_difficulty_for_generation(
self,
q_text: str,
options: Dict[str, str],
correct_text: str,
context_text: str,
) -> Tuple[float, str]:
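        """Heuristically score question difficulty in [0, 1] (higher = harder).

        Signals: mean distractor-to-answer similarity, an ambiguity flag when
        a distractor is nearly as similar as the answer, the gap between the
        two most similar distractors, and normalized stem length. Returns
        (score, label) with Vietnamese labels dễ / trung bình / khó.
        """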
def safe_map_sim(s):
# map potentially [-1,1] cosine-like to [0,1], clamp
try:
s = float(s)
except Exception:
return 0.0
mapped = (s + 1.0) / 2.0
return max(0.0, min(1.0, mapped))
# embedding support
emb_support = 0.0
try:
stmt = (q_text or "").strip()
if correct_text:
stmt = f"{stmt} Answer: {correct_text}"
# use internal retrieve but map returned score
res = []
try:
res = self._retrieve(stmt, top_k=1)
except Exception:
res = []
if res:
raw_score = float(res[0][1])
emb_support = safe_map_sim(raw_score)
else:
emb_support = 0.0
except Exception:
emb_support = 0.0
# distractor sims
mean_sim = 0.0
distractor_penalty = 0.0
amb_flag = 0.0
try:
keys = list(options.keys())
texts = [options[k] for k in keys]
if correct_text is None:
correct_text = ""
all_texts = [correct_text] + texts
embs = self.embedder.encode(all_texts, convert_to_numpy=True)
embs = np.asarray(embs, dtype=float)
norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
embs = embs / norms
corr = embs[0]
opts = embs[1:]
if opts.size == 0:
mean_sim = 0.0
distractor_penalty = 0.0
gap = 0.0
else:
sims = (opts @ corr).tolist() # [-1,1]
sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
mean_sim = float(sum(sims_mapped) / len(sims_mapped))
# gap between best distractor and second best (higher gap -> easier)
sorted_s = sorted(sims_mapped, reverse=True)
top = sorted_s[0]
second = sorted_s[1] if len(sorted_s) > 1 else 0.0
gap = top - second
# penalties: if distractors are extremely close to correct -> higher penalty
too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
amb_flag = 1.0 if top >= 0.9 else 0.0
except Exception:
mean_sim = 0.0
distractor_penalty = 0.0
amb_flag = 0.0
gap = 0.0
# stem length normalized
qlen = len((q_text or "").strip())
qlen_norm = min(1.0, qlen / 300.0)
        # combine signals (higher score -> harder):
        #   higher distractor similarity / ambiguity -> harder (added)
        #   larger gap between the two most similar distractors -> easier (subtracted)
        # note: emb_support is computed above but is not currently part of the
        # weighted sum
        score = 0.0
        score += 0.35 * float(distractor_penalty)
        score += 0.20 * float(mean_sim)
        score += 0.22 * float(amb_flag)
        score += 0.05 * float(qlen_norm)
        score -= 0.20 * float(gap)
# clamp
score = max(0.0, min(1.0, float(score)))
# label
if score <= 0.33:
label = "dễ"
        elif score <= 0.66:
label = "trung bình"
else:
label = "khó"
return score, label
class RAGMCQWithDifficulty(RAGMCQ):
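    """RAGMCQ variant that targets a requested difficulty level.

    Generation calls go through ``new_generate_mcqs_from_text`` with a
    ``target_difficulty`` hint, and the difficulty estimator additionally
    returns a per-signal component breakdown.
    """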
def __init__(
self,
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
generation_model: str = "openai/gpt-oss-120b",
qdrant_url: str = os.environ.get('QDRANT_URL') or "",
qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "",
qdrant_prefer_grpc: bool = False,
):
super().__init__(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc)
@override
def generate_from_pdf(
self,
pdf_path: str,
n_questions: int = 10,
mode: str = "rag", # per_page or rag
questions_per_page: int = 3, # for per_page mode
top_k: int = 3, # chunks to retrieve for each question in rag mode
temperature: float = 0.2,
enable_fiddler: bool = False,
target_difficulty: str = 'easy' # easy, mid, difficult
) -> Dict[str, Any]:
# build index
self.build_index_from_pdf(pdf_path)
output: Dict[str, Any] = {}
qcount = 0
if mode == "per_page":
# iterate pages -> chunks
for idx, meta in enumerate(self.metadata):
chunk_text = self.texts[idx]
if not chunk_text.strip():
continue
# ask generator
try:
structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler)
mcq_block = new_generate_mcqs_from_text(
                        source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty, enable_fiddler=enable_fiddler
)
except Exception as e:
# skip this chunk if generator fails
print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
# strategy: create a few natural short queries by sampling sentences or using chunk summaries.
# create queries by sampling chunk text sentences.
# stop when n_questions reached or max_attempts exceeded.
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# create a seed query: pick a random chunk, pick a sentence from it
seed_idx = random.randrange(len(self.texts))
chunk = self.texts[seed_idx]
#? investigate better Chunking Strategy
#with open("chunks.txt", "a", encoding="utf-8") as f:
#f.write(chunk + "\n")
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
                candidates = [s for s in sents if len(s.strip()) > 20]
                seed_sent = random.choice(candidates) if candidates else chunk[:200]
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks
retrieved = self._retrieve(query, top_k=top_k)
context_parts = []
for ridx, score in retrieved:
md = self.metadata[ridx]
context_parts.append(f"[page {md['page']}] {self.texts[ridx]}")
context = "\n\n".join(context_parts)
# save_to_local('test/context.md', content=context)
# call generator for 1 question (or small batch) with the retrieved context
try:
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
mcq_block = new_generate_mcqs_from_text(
                        source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty, enable_fiddler=enable_fiddler
)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
# append result(s)
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
payload = mcq_block[item]
q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
if isinstance(options, list):
options = {str(i+1): o for i, o in enumerate(options)}
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
correct_text = ""
if isinstance(correct_key, str) and correct_key.strip() in options:
correct_text = options[correct_key.strip()]
else:
correct_text = payload.get("correct_text") or correct_key or ""
                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(
                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
)
payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
else:
raise ValueError("mode must be 'per_page' or 'rag'.")
@override
def generate_from_qdrant(
self,
filename: str,
collection: str,
n_questions: int = 10,
mode: str = "rag", # 'per_chunk' or 'rag'
questions_per_chunk: int = 3, # used for 'per_chunk'
top_k: int = 3, # retrieval size used in RAG
temperature: float = 0.2,
enable_fiddler: bool = False,
target_difficulty: str = 'easy',
) -> Dict[str, Any]:
if self.qdrant is None:
raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.")
# get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.)
file_points = self.list_chunks_for_filename(collection=collection, filename=filename)
if not file_points:
raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.")
# create a local list of texts & metadata for sampling
texts = []
metas = []
for p in file_points:
payload = p.get("payload", {})
text = payload.get("text", "")
texts.append(text)
metas.append(payload)
self.texts = texts
self.metadata = metas
embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
if embeddings is None or len(embeddings) == 0:
self.embeddings = None
self.index = None
else:
self.embeddings = embeddings.astype("float32")
# update dim in case embedder changed unexpectedly
self.dim = int(self.embeddings.shape[1])
# build index
self._build_faiss_index()
output = {}
qcount = 0
if mode == "per_chunk":
# iterate all chunks (in payload order) and request questions_per_chunk from each
for i, txt in enumerate(texts):
if not txt.strip():
continue
try:
structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=False)
mcq_block = new_generate_mcqs_from_text(
source_text=structured_context, n=questions_per_chunk, model=self.generation_model,
                        temperature=temperature, target_difficulty=target_difficulty, enable_fiddler=enable_fiddler
)
except Exception as e:
print(f"Generator failed on chunk (index {i}): {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
qcount += 1
output[str(qcount)] = mcq_block[item]
if qcount >= n_questions:
return output
return output
elif mode == "rag":
attempts = 0
max_attempts = n_questions * 4
while qcount < n_questions and attempts < max_attempts:
attempts += 1
# create a seed query: pick a random chunk, pick a sentence from it
seed_idx = random.randrange(len(self.texts))
chunk = self.texts[seed_idx]
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
candidate = [s for s in sents if len(s.strip()) > 20]
if candidate:
seed_sent = random.choice(candidate)
else:
stripped = chunk.strip()
seed_sent = (stripped[:200] if stripped else "[no text available]")
query = f"Create questions about: {seed_sent}"
# retrieve top_k chunks from the same file (restricted by filename filter)
retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k)
context_parts = []
for payload, score in retrieved:
# payload should contain page & chunk_id and text
page = payload.get("page", "?")
ctxt = payload.get("text", "")
context_parts.append(f"[page {page}] {ctxt}")
context = "\n\n".join(context_parts)
# q generation
try:
structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False)
mcq_block = new_generate_mcqs_from_text(
source_text=structured_context, n=questions_per_chunk, model=self.generation_model,
                        temperature=temperature, target_difficulty=target_difficulty, enable_fiddler=enable_fiddler
)
except Exception as e:
print(f"Generator failed during RAG attempt {attempts}: {e}")
continue
if "error" in list(mcq_block.keys()):
return output
for item in sorted(mcq_block.keys(), key=lambda x: int(x)):
payload = mcq_block[item]
q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip()
options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {}
if isinstance(options, list):
options = {str(i+1): o for i, o in enumerate(options)}
correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None
concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None
correct_text = ""
if isinstance(correct_key, str) and correct_key.strip() in options:
correct_text = options[correct_key.strip()]
else:
correct_text = payload.get("correct_text") or correct_key or ""
#? change estimate
                    diff_score, diff_label, components = self._estimate_difficulty_for_generation(
                        q_text=q_text, options={k: str(v) for k, v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts
)
payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label}
                    # cap the next request so we never ask for more MCQs than remain,
                    # e.g. with n_questions=5 and 3 already produced, only request 2 more
                    if n_questions - qcount < questions_per_chunk:
                        questions_per_chunk = n_questions - qcount
                    qcount += 1
                    output[str(qcount)] = mcq_block[item]
                    if qcount >= n_questions:
                        return output
            return output
else:
raise ValueError("mode must be 'per_chunk' or 'rag'.")
@override
def _estimate_difficulty_for_generation(
self,
q_text: str,
options: Dict[str, str],
correct_text: str,
context_text: str,
        concepts_used: Optional[Any] = None,
    ) -> Tuple[float, str, Dict[str, float]]:
def safe_map_sim(s):
# map potentially [-1,1] cosine-like to [0,1], clamp
try:
s = float(s)
except Exception:
return 0.0
mapped = (s + 1.0) / 2.0
return max(0.0, min(1.0, mapped))
# embedding support
emb_support = 0.0
try:
stmt = (q_text or "").strip()
if correct_text:
stmt = f"{stmt} Answer: {correct_text}"
# use internal retrieve but map returned score
res = []
try:
res = self._retrieve(stmt, top_k=1)
except Exception:
res = []
if res:
raw_score = float(res[0][1])
emb_support = safe_map_sim(raw_score)
else:
emb_support = 0.0
except Exception:
emb_support = 0.0
# distractor sims
mean_sim = 0.0
distractor_penalty = 0.0
amb_flag = 0.0
try:
keys = list(options.keys())
texts = [options[k] for k in keys]
if correct_text is None:
correct_text = ""
all_texts = [correct_text] + texts
embs = self.embedder.encode(all_texts, convert_to_numpy=True)
embs = np.asarray(embs, dtype=float)
norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12
embs = embs / norms
corr = embs[0]
opts = embs[1:]
if opts.size == 0:
mean_sim = 0.0
distractor_penalty = 0.0
gap = 0.0
else:
sims = (opts @ corr).tolist() # [-1,1]
sims_mapped = [safe_map_sim(s) for s in sims] # [0,1]
mean_sim = float(sum(sims_mapped) / len(sims_mapped))
# gap between best distractor and second best (higher gap -> easier)
sorted_s = sorted(sims_mapped, reverse=True)
top = sorted_s[0]
second = sorted_s[1] if len(sorted_s) > 1 else 0.0
gap = top - second
# penalties: if distractors are extremely close to correct -> higher penalty
too_close_count = sum(1 for s in sims_mapped if s >= 0.85)
too_far_count = sum(1 for s in sims_mapped if s <= 0.15)
distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped))))
amb_flag = 1.0 if top >= 0.8 else 0.0
except Exception:
mean_sim = 0.0
distractor_penalty = 0.0
amb_flag = 0.0
gap = 0.0
# question length normalized
question_len = len((q_text or "").strip())
question_len_norm = min(1.0, question_len / 300.0)
        # count the concepts used; callers may pass a dict, a list, a
        # comma-separated string, or None, so count defensively
        if isinstance(concepts_used, (dict, list, tuple, set)):
            concepts_num = len(concepts_used)
        elif isinstance(concepts_used, str):
            concepts_num = len([c for c in concepts_used.split(",") if c.strip()])
        else:
            concepts_num = 0
        # single-concept questions carry no concept penalty (reported only, see components)
        concepts_penalty = concepts_num if concepts_num >= 2 else 0
        # combine signals (higher score -> harder):
        #   higher distractor similarity / ambiguity -> harder (added)
        #   larger gap between the two most similar distractors -> easier (subtracted)
        score = 0.0
        score += 0.35 * float(distractor_penalty)
        score += 0.20 * float(mean_sim)
        score += 0.22 * float(amb_flag)
        score += 0.08 * float(question_len_norm)
        score -= 0.20 * float(gap)
        # clamp
        score = max(0.0, min(1.0, float(score)))
        # per-signal breakdown of the weighted sum above; emb_support,
        # concepts_num and concepts_penalty are reported as raw signals but
        # are not weighted into the score
        components = {
            "distractor_penalty": 0.35 * float(distractor_penalty),
            "mean_sim": 0.20 * float(mean_sim),
            "amb_flag": 0.22 * float(amb_flag),
            "question_len_norm": 0.08 * float(question_len_norm),
            "gap": -0.20 * float(gap),
            "emb_support": float(emb_support),
            "concepts_num": float(concepts_num),
            "concepts_penalty": float(concepts_penalty),
            "total_score": score,
        }
# label
if score <= 0.56:
label = "dễ"
        elif score <= 0.755:
label = "trung bình"
else:
label = "khó"
        return score, label, components
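

if __name__ == "__main__":
    # Minimal usage sketch (assumptions: "sample.pdf" is a placeholder path,
    # and the HF_MODEL_TOKEN / QDRANT_* environment variables are configured
    # for the generation backend). Not part of the service endpoints.
    generator = RAGMCQWithDifficulty()
    questions = generator.generate_from_pdf(
        pdf_path="sample.pdf",
        n_questions=3,
        mode="rag",
        target_difficulty="easy",
    )
    for qid, q in questions.items():
        print(qid, q.get("câu hỏi"), "->", q.get("độ khó"))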