import re import random import fitz import string import numpy as np import os from typing import List, Optional, Tuple, Dict, Any from sentence_transformers import SentenceTransformer, CrossEncoder from transformers import pipeline from uuid import uuid4 import pymupdf4llm from typing_extensions import override try: from qdrant_client import QdrantClient from qdrant_client.http.models import ( PointStruct, Filter, FieldCondition, MatchValue, Distance, VectorParams, ) from qdrant_client.http import models as rest _HAS_QDRANT = True except Exception: _HAS_QDRANT = False try: import faiss _HAS_FAISS = True except Exception: _HAS_FAISS = False from utils import generate_mcqs_from_text, structure_context_for_llm, new_generate_mcqs_from_text from huggingface_hub import login login(token=os.environ['HF_MODEL_TOKEN']) class RAGMCQ: def __init__( self, embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", generation_model: str = "openai/gpt-oss-120b", qdrant_url: str = os.environ.get('QDRANT_URL') or "", qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "", qdrant_prefer_grpc: bool = False, ): self.embedder = SentenceTransformer(embedder_model) self.generation_model = generation_model self.qa_pipeline = pipeline("question-answering", model="nguyenvulebinh/vi-mrc-base", tokenizer="nguyenvulebinh/vi-mrc-base") self.cross_entail = CrossEncoder("itdainb/PhoRanker") self.embeddings = None # np.array of shape (N, D) self.texts = [] # list of chunk texts self.metadata = [] # list of dicts (page, chunk_id, char_range) self.index = None self.dim = self.embedder.get_sentence_embedding_dimension() self.qdrant = None self.qdrant_url = qdrant_url self.qdrant_api_key = qdrant_api_key self.qdrant_prefer_grpc = qdrant_prefer_grpc if qdrant_url: self.connect_qdrant(qdrant_url, qdrant_api_key, qdrant_prefer_grpc) def extract_pages( self, pdf_path: str, *, pages: Optional[List[int]] = None, ignore_images: bool = False, dpi: int = 150 ) -> List[str]: doc = fitz.open(pdf_path) try: # request page-wise output (page_chunks=True -> list[dict] per page) page_dicts = pymupdf4llm.to_markdown( doc, pages=pages, ignore_images=ignore_images, dpi=dpi, page_chunks=True, ) # to_markdown(..., page_chunks=True) returns a list of dicts, each has key "text" (markdown) pages_md: List[str] = [] for p in page_dicts: txt = p.get("text", "") or "" pages_md.append(txt.strip()) return pages_md finally: doc.close() def chunk_text(self, text: str, max_chars: int = 1200, overlap: int = 100) -> List[str]: text = text.strip() if not text: return [] if len(text) <= max_chars: return [text] # split by sentence-like boundaries sentences = re.split(r'(?<=[\.\?\!])\s+', text) chunks = [] cur = "" for s in sentences: if len(cur) + len(s) + 1 <= max_chars: cur += (" " if cur else "") + s else: if cur: chunks.append(cur) cur = (cur[-overlap:] + " " + s) if overlap > 0 else s if cur: chunks.append(cur) # if still too long, hard-split final = [] for c in chunks: if len(c) <= max_chars: final.append(c) else: for i in range(0, len(c), max_chars): final.append(c[i:i+max_chars]) return final def build_index_from_pdf(self, pdf_path: str, max_chars: int = 1200): pages = self.extract_pages(pdf_path) self.texts = [] self.metadata = [] for p_idx, page_text in enumerate(pages, start=1): chunks = self.chunk_text(page_text or "", max_chars=max_chars) for cid, ch in enumerate(chunks, start=1): self.texts.append(ch) self.metadata.append({"page": p_idx, "chunk_id": cid, "length": len(ch)}) if not self.texts: raise RuntimeError("No text extracted from PDF.") # save_to_local('test/text_chunks.md', content=self.texts) # compute embeddings emb = self.embedder.encode(self.texts, convert_to_numpy=True, show_progress_bar=True) self.embeddings = emb.astype("float32") self._build_faiss_index() def _build_faiss_index(self, ef_construction=200, M=32): if _HAS_FAISS: d = self.embeddings.shape[1] index = faiss.IndexHNSWFlat(d, M) faiss.normalize_L2(self.embeddings) index.add(self.embeddings) index.hnsw.efConstruction = ef_construction self.index = index else: # store normalized embeddings and use brute-force numpy norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-10 self.embeddings = self.embeddings / norms self.index = None def _retrieve(self, query: str, top_k: int = 3) -> List[Tuple[int, float]]: q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32") if _HAS_FAISS: faiss.normalize_L2(q_emb) D_list, I_list = self.index.search(q_emb, top_k) # D are inner products; return list of (idx, score) return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1] else: qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10) sims = (self.embeddings @ qn.T).squeeze(axis=1) idxs = np.argsort(-sims)[:top_k] return [(int(i), float(sims[i])) for i in idxs] def generate_from_pdf( self, pdf_path: str, n_questions: int = 10, mode: str = "rag", # per_page or rag questions_per_page: int = 3, # for per_page mode top_k: int = 3, # chunks to retrieve for each question in rag mode temperature: float = 0.2, enable_fiddler: bool = False, ) -> Dict[str, Any]: # build index self.build_index_from_pdf(pdf_path) output: Dict[str, Any] = {} qcount = 0 if mode == "per_page": # iterate pages -> chunks for idx, meta in enumerate(self.metadata): chunk_text = self.texts[idx] if not chunk_text.strip(): continue # ask generator try: structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler) mcq_block = generate_mcqs_from_text( structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler ) except Exception as e: # skip this chunk if generator fails print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output elif mode == "rag": # strategy: create a few natural short queries by sampling sentences or using chunk summaries. # create queries by sampling chunk text sentences. # stop when n_questions reached or max_attempts exceeded. attempts = 0 max_attempts = n_questions * 4 while qcount < n_questions and attempts < max_attempts: attempts += 1 # create a seed query: pick a random chunk, pick a sentence from it seed_idx = random.randrange(len(self.texts)) chunk = self.texts[seed_idx] #? investigate better Chunking Strategy #with open("chunks.txt", "a", encoding="utf-8") as f: #f.write(chunk + "\n") sents = re.split(r'(?<=[\.\?\!])\s+', chunk) seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200] query = f"Create questions about: {seed_sent}" # retrieve top_k chunks retrieved = self._retrieve(query, top_k=top_k) context_parts = [] for ridx, score in retrieved: md = self.metadata[ridx] context_parts.append(f"[page {md['page']}] {self.texts[ridx]}") context = "\n\n".join(context_parts) # save_to_local('test/context.md', content=context) # call generator for 1 question (or small batch) with the retrieved context try: # request 1 question at a time to keep diversity structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler) mcq_block = generate_mcqs_from_text( structured_context, n=1, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler ) except Exception as e: print(f"Generator failed during RAG attempt {attempts}: {e}") continue if "error" in list(mcq_block.keys()): return output # append result(s) for item in sorted(mcq_block.keys(), key=lambda x: int(x)): payload = mcq_block[item] q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip() options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {} if isinstance(options, list): options = {str(i+1): o for i, o in enumerate(options)} correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None correct_text = "" if isinstance(correct_key, str) and correct_key.strip() in options: correct_text = options[correct_key.strip()] else: correct_text = payload.get("correct_text") or correct_key or "" diff_score, diff_label = self._estimate_difficulty_for_generation( q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context ) payload["difficulty"] = {"score": diff_score, "label": diff_label} qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output else: raise ValueError("mode must be 'per_page' or 'rag'.") def validate_mcqs( self, mcqs: Dict[str, Any], top_k: int = 4, similarity_threshold: float = 0.5, evidence_score_cutoff: float = 0.5, use_cross_encoder: bool = True, use_qa: bool = True, auto_accept_threshold: float = 0.7, review_threshold: float = 0.5, distractor_too_similar: float = 0.8, distractor_too_different: float = 0.15, model_verification_temperature: float = 0.0, ) -> Dict[str, Any]: """ Upgraded validation pipeline: - embedding retrieval (self.index / self.embeddings) - cross-encoder entailment scoring (optional) - extractive QA consistency check (optional) - distractor similarity and type checks - aggregate into quality_score and triage_action Returns a dict keyed by qid with detailed info and triage decision. """ cross_entail = None qa_pipeline = None if use_cross_encoder: try: cross_entail = self.cross_entail except Exception as e: cross_entail = None if use_qa: try: qa_pipeline = self.qa_pipeline except Exception: qa_pipeline = None # --- helpers --- def _norm_text(s: str) -> str: if s is None: return "" s = s.strip().lower() # remove punctuation s = s.translate(str.maketrans("", "", string.punctuation)) # collapse whitespace s = " ".join(s.split()) return s def _semantic_search(statement: str, k: int = top_k): # returns list of (idx, score) using current embeddings/index q_emb = self.embedder.encode([statement], convert_to_numpy=True).astype("float32") if _HAS_FAISS and getattr(self, "index", None) is not None: try: faiss.normalize_L2(q_emb) D_list, I_list = self.index.search(q_emb, k) return [(int(i), float(d)) for i, d in zip(I_list[0], D_list[0]) if i != -1] except Exception: pass # fallback to brute force qn = q_emb / (np.linalg.norm(q_emb, axis=1, keepdims=True) + 1e-10) sims = (self.embeddings @ qn.T).squeeze(axis=1) idxs = np.argsort(-sims)[:k] return [(int(i), float(sims[i])) for i in idxs] def _compose_context_from_retrieved(retrieved): parts = [] for ridx, score in retrieved: md = self.metadata[ridx] if ridx < len(self.metadata) else {} page = md.get("page", "?") text = self.texts[ridx] parts.append(f"[page {page}] {text}") return "\n\n".join(parts) def _compute_option_embeddings(options_map: Dict[str, str]): # returns dict key->embedding keys = list(options_map.keys()) texts = [options_map[k] for k in keys] embs = self.embedder.encode(texts, convert_to_numpy=True) return dict(zip(keys, embs)) def _cosine(a, b): a = np.asarray(a, dtype=float) b = np.asarray(b, dtype=float) denom = (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12) return float(np.dot(a, b) / denom) # --- main loop --- report = {} for qid, item in mcqs.items(): # support both Vietnamese keys and English keys q_text = (item.get("câu hỏi") or item.get("question") or item.get("q") or item.get("stem") or "").strip() options = item.get("lựa chọn") or item.get("options") or item.get("choices") or {} # options may be dict mapping letters to text, or list: normalize to dict if isinstance(options, list): options = {str(i+1): o for i, o in enumerate(options)} # correct answer may be a key (like "A") or the text; try both correct_key = item.get("đáp án") or item.get("answer") or item.get("correct") or item.get("ans") correct_text = "" if isinstance(correct_key, str) and correct_key.strip() in options: correct_text = options[correct_key.strip()] else: # maybe the answer is full text if isinstance(correct_key, str): correct_text = correct_key.strip() else: # fallback to 'correct_text' field correct_text = item.get("correct_text") or item.get("đáp án_text") or "" # default empty guard options = {k: str(v) for k, v in options.items()} correct_text = str(correct_text) # prepare statement for retrieval statement = f"{q_text} Answer: {correct_text}" retrieved = _semantic_search(statement, k=top_k) # build context from top retrieved context_parts = [] for ridx, score in retrieved: md = self.metadata[ridx] if ridx < len(self.metadata) else {} context_parts.append({"idx": ridx, "score": float(score), "page": md.get("page", None), "text": self.texts[ridx]}) context_text = "\n\n".join([f"[page {p['page']}] {p['text']}" for p in context_parts]) # Evidence list (embedding-based) evidence_list = [] max_sim = 0.0 for r in context_parts: if r["score"] >= evidence_score_cutoff: snippet = r["text"] evidence_list.append({ "idx": r["idx"], "page": r["page"], "score": r["score"], "text": (snippet[:1000] + ("..." if len(snippet) > 1000 else "")), }) if r["score"] > max_sim: max_sim = float(r["score"]) supported_by_embeddings = max_sim >= similarity_threshold # Cross-encoder entailment scores for each option entailment_scores = {} correct_entail = 0.0 try: if cross_entail is not None and context_text.strip(): # prepare list of (premise, hypothesis) pairs = [] opt_keys = list(options.keys()) for k in opt_keys: hyp = f"{q_text} Answer: {options[k]}" pairs.append((context_text, hyp)) scores = cross_entail.predict(pairs) # returns list of floats # normalize scores to 0-1 if needed (cross-encoder may return arbitrary positive) # do a min-max normalization across the returned scores # but avoid division by zero min_s = float(min(scores)) if len(scores) else 0.0 max_s = float(max(scores)) if len(scores) else 1.0 denom = max_s - min_s if max_s - min_s > 1e-6 else 1.0 for k, raw in zip(opt_keys, scores): scaled = (raw - min_s) / denom entailment_scores[k] = float(scaled) # find correct key if available # if `correct_text` exactly matches one of options, find that key matched_key = None for k, v in options.items(): if _norm_text(v) == _norm_text(correct_text): matched_key = k break if matched_key: correct_entail = entailment_scores.get(matched_key, 0.0) else: # fallback: treat 'correct_text' as a separate hypothesis hyp = f"{q_text} Answer: {correct_text}" raw = cross_entail.predict([(context_text, hyp)])[0] # scale relative to min/max used above correct_entail = float((raw - min_s) / denom) else: entailment_scores = {} correct_entail = 0.0 except Exception as e: entailment_scores = {} correct_entail = 0.0 def embed_cosine_sim(a, b): emb = self.embedder.encode([a, b], convert_to_numpy=True, normalize_embeddings=True) return float(np.dot(emb[0], emb[1])) # QA consistency qa_answer = None qa_score = 0.0 qa_agrees = False if qa_pipeline is not None and context_text.strip(): try: qa_res = qa_pipeline(question=q_text, context=context_text) # some QA pipelines return list of answers or dict if isinstance(qa_res, list) and len(qa_res) > 0: top = qa_res[0] qa_answer = top.get("answer") if isinstance(top, dict) else str(top) # qa_score = float(top.get("score", 0.0) if isinstance(top, dict) else 0.0) elif isinstance(qa_res, dict): qa_answer = qa_res.get("answer", "") qa_score = float(qa_res.get("score", 0.0)) else: qa_answer = str(qa_res) qa_score = 0.0 qa_score = embed_cosine_sim(qa_answer, correct_text) qa_agrees = (qa_score >= 0.5) except Exception: qa_answer = None qa_score = 0.0 qa_agrees = False try: opt_embs = _compute_option_embeddings({**options, "__CORRECT__": correct_text}) correct_emb = opt_embs.pop("__CORRECT__") distractor_similarities = {} for k, emb in opt_embs.items(): distractor_similarities[k] = float(_cosine(correct_emb, emb)) except Exception: distractor_similarities = {k: None for k in options.keys()} # distractor flags distractor_penalty = 0.0 distractor_flags = [] for k, sim in distractor_similarities.items(): if sim is None or sim >= 0.999999 or (sim >= -0.01 and sim <= 0): continue if sim >= distractor_too_similar: distractor_flags.append({"key": k, "reason": "too_similar", "similarity": sim}) distractor_penalty += 0.25 elif sim <= distractor_too_different: distractor_flags.append({"key": k, "reason": "too_different", "similarity": sim}) distractor_penalty += 0.15 # clamp penalty distractor_penalty = min(distractor_penalty, 1.0) # Ambiguity detection: how many options have entailment >= threshold ambiguous = False ambiguous_options = [] if entailment_scores: # count options whose entailment >= max(correct_entail * 0.9, 0.6) amb_thresh = max(correct_entail * 0.9, 0.6) for k, sc in entailment_scores.items(): if sc >= amb_thresh and (options.get(k, "") != correct_text): ambiguous_options.append({"key": k, "score": sc, "text": options[k]}) ambiguous = len(ambiguous_options) > 0 # Compose aggregated quality score # Components: # - embedding_support: normalized max_sim (0..1) # - entailment: correct_entail (0..1) # - qa_agree: boolean -> 1 or 0 times qa_score # - distractor_penalty: subtracted emb_support_norm = max_sim # embedding similarity typically already 0..1 (inner product normalized) entail_component = float(correct_entail) qa_component = float(qa_score) if qa_agrees else 0.0 # weighted sum quality_score = ( 0.40 * emb_support_norm + 0.35 * entail_component + 0.20 * qa_component - 0.05 * distractor_penalty ) # clamp to 0..1 quality_score = max(0.0, min(1.0, quality_score)) # triage decision triage_action = "reject" if quality_score >= auto_accept_threshold and not ambiguous: triage_action = "pass" elif quality_score >= review_threshold: triage_action = "review" else: triage_action = "reject" # compile flags/reasons flag_reasons = [] if not supported_by_embeddings: flag_reasons.append("no_strong_embedding_evidence") if entailment_scores and correct_entail < 0.6: flag_reasons.append("low_entailment_score_for_correct") if qa_pipeline is not None and qa_score > 0.6 and not qa_agrees: flag_reasons.append("qa_contradiction") if ambiguous: flag_reasons.append("ambiguous_options_supported") if distractor_flags: flag_reasons.append({"distractor_issues": distractor_flags}) # assemble per-question report report[qid] = { "supported_by_embeddings": bool(supported_by_embeddings), "max_similarity": float(max_sim), "evidence": evidence_list, "entailment_scores": entailment_scores, "correct_entailment": float(correct_entail), "qa_answer": qa_answer, "qa_score": float(qa_score), "qa_agrees": bool(qa_agrees), "distractor_similarities": distractor_similarities, "distractor_flags": distractor_flags, "distractor_penalty": float(distractor_penalty), "ambiguous_options": ambiguous_options, "quality_score": float(quality_score), "triage_action": triage_action, "flag_reasons": flag_reasons, } return report def connect_qdrant(self, url: str, api_key: str = None, prefer_grpc: bool = False): if not _HAS_QDRANT: raise RuntimeError("qdrant-client is not installed. Install with `pip install qdrant-client`.") self.qdrant_url = url self.qdrant_api_key = api_key self.qdrant_prefer_grpc = prefer_grpc # Create client self.qdrant = QdrantClient(url=url, api_key=api_key, prefer_grpc=prefer_grpc) def _ensure_collection(self, collection_name: str): if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") try: # get_collection will raise if not present _ = self.qdrant.get_collection(collection_name) except Exception: # create collection with vector size = self.dim vect_params = VectorParams(size=self.dim, distance=Distance.COSINE) self.qdrant.recreate_collection(collection_name=collection_name, vectors_config=vect_params) # recreate_collection ensures a clean collection; if you prefer to avoid wiping use create_collection instead. def save_pdf_to_qdrant( self, pdf_path: str, filename: str, collection: str, max_chars: int = 1200, batch_size: int = 64, overwrite: bool = False, ): if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") # extract pages and chunks (re-using your existing helpers) pages = self.extract_pages(pdf_path) all_chunks = [] all_meta = [] for p_idx, page_text in enumerate(pages, start=1): chunks = self.chunk_text(page_text or "", max_chars=max_chars) for cid, ch in enumerate(chunks, start=1): all_chunks.append(ch) all_meta.append({"page": p_idx, "chunk_id": cid, "length": len(ch)}) if not all_chunks: raise RuntimeError("No tSext extracted from PDF.") # ensure collection exists self._ensure_collection(collection) # optional: delete previous points for this filename if overwrite if overwrite: # delete by filter: filename == filename flt = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))]) try: # qdrant-client delete uses delete( self.qdrant.delete(collection_name=collection, filter=flt) except Exception: # ignore if deletion fails pass # compute embeddings in batches embeddings = self.embedder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True) embeddings = embeddings.astype("float32") # prepare points points = [] for i, (emb, md, txt) in enumerate(zip(embeddings, all_meta, all_chunks)): pid = str(uuid4()) source_id = f"{filename}__p{md['page']}__c{md['chunk_id']}" payload = { "filename": filename, "page": md["page"], "chunk_id": md["chunk_id"], "length": md["length"], "text": txt, "source_id": source_id, } points.append(PointStruct(id=pid, vector=emb.tolist(), payload=payload)) # pyright: ignore[reportPossiblyUnboundVariable] # upsert in batches if len(points) >= batch_size: self.qdrant.upsert(collection_name=collection, points=points) points = [] # upsert remaining if points: self.qdrant.upsert(collection_name=collection, points=points) try: self.qdrant.create_payload_index( collection_name=collection, field_name="filename", field_schema=rest.PayloadSchemaType.KEYWORD ) except Exception as e: print(f"Index creation skipped or failed: {e}") return {"status": "ok", "uploaded_chunks": len(all_chunks), "collection": collection, "filename": filename} def list_files_in_collection( self, collection: str, payload_field: str = "filename", batch_size: int = 500, ) -> List[str]: if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") # ensure collection exists try: if not self.qdrant.collection_exists(collection): raise RuntimeError(f"Collection '{collection}' does not exist.") except Exception: # collection_exists may raise if server unreachable raise filenames = set() offset = None while True: # scroll returns (points, next_offset) pts, next_offset = self.qdrant.scroll( collection_name=collection, limit=batch_size, offset=offset, with_payload=[payload_field], with_vectors=False, ) if not pts: break for p in pts: # p may be a dict-like or an object with .payload payload = None if hasattr(p, "payload"): payload = p.payload elif isinstance(p, dict): # older/newer variants might use nested structures: try common keys payload = p.get("payload") or p.get("payload", None) or p else: # best-effort fallback: convert to dict if possible try: payload = dict(p) except Exception: payload = None if not payload: continue # extract candidate value(s) val = None if isinstance(payload, dict): val = payload.get(payload_field) else: # Some payload representations store fields differently; try attribute access val = getattr(payload, payload_field, None) # If value is list-like, iterate, else add single if isinstance(val, (list, tuple, set)): for v in val: if v is not None: filenames.add(str(v)) elif val is not None: filenames.add(str(val)) # stop if no more pages if not next_offset: break offset = next_offset return sorted(filenames) def list_chunks_for_filename(self, collection: str, filename: str, batch: int = 256) -> List[Dict[str, Any]]: if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") results = [] offset = None while True: # scroll returns (points, next_offset) points, next_offset = self.qdrant.scroll( collection_name=collection, scroll_filter=Filter( must=[ FieldCondition(key="filename", match=MatchValue(value=filename)) ] ), limit=batch, offset=offset, with_payload=True, with_vectors=False, ) # points are objects (Record / ScoredPoint-like); get id and payload for p in points: # p.payload is a dict, p.id is point id results.append({"point_id": p.id, "payload": p.payload}) if not next_offset: break offset = next_offset return results def _retrieve_qdrant(self, query: str, collection: str, filename: str = None, top_k: int = 3) -> List[Tuple[Dict[str, Any], float]]: if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") q_emb = self.embedder.encode([query], convert_to_numpy=True).astype("float32")[0].tolist() q_filter = None if filename: q_filter = Filter(must=[FieldCondition(key="filename", match=MatchValue(value=filename))]) search_res = self.qdrant.search( collection_name=collection, query_vector=q_emb, query_filter=q_filter, limit=top_k, with_payload=True, with_vectors=False, ) out = [] for hit in search_res: # hit.payload is the stored payload, hit.score is similarity out.append((hit.payload, float(getattr(hit, "score", 0.0)))) return out def generate_from_qdrant( self, filename: str, collection: str, n_questions: int = 10, mode: str = "rag", # 'per_chunk' or 'rag' questions_per_chunk: int = 3, # used for 'per_chunk' top_k: int = 3, # retrieval size used in RAG temperature: float = 0.2, enable_fiddler: bool = False, ) -> Dict[str, Any]: if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.) file_points = self.list_chunks_for_filename(collection=collection, filename=filename) if not file_points: raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.") # create a local list of texts & metadata for sampling texts = [] metas = [] for p in file_points: payload = p.get("payload", {}) text = payload.get("text", "") texts.append(text) metas.append(payload) self.texts = texts self.metadata = metas embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True) if embeddings is None or len(embeddings) == 0: self.embeddings = None self.index = None else: self.embeddings = embeddings.astype("float32") # update dim in case embedder changed unexpectedly self.dim = int(self.embeddings.shape[1]) # build index self._build_faiss_index() output = {} qcount = 0 if mode == "per_chunk": # iterate all chunks (in payload order) and request questions_per_chunk from each for i, txt in enumerate(texts): if not txt.strip(): continue try: structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler) mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler) except Exception as e: print(f"Generator failed on chunk (index {i}): {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output elif mode == "rag": attempts = 0 max_attempts = n_questions * 4 while qcount < n_questions and attempts < max_attempts: attempts += 1 # create a seed query: pick a random chunk, pick a sentence from it seed_idx = random.randrange(len(self.texts)) chunk = self.texts[seed_idx] sents = re.split(r'(?<=[\.\?\!])\s+', chunk) candidate = [s for s in sents if len(s.strip()) > 20] if candidate: seed_sent = random.choice(candidate) else: stripped = chunk.strip() seed_sent = (stripped[:200] if stripped else "[no text available]") query = f"Create questions about: {seed_sent}" # retrieve top_k chunks from the same file (restricted by filename filter) retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k) context_parts = [] for payload, score in retrieved: # payload should contain page & chunk_id and text page = payload.get("page", "?") ctxt = payload.get("text", "") context_parts.append(f"[page {page}] {ctxt}") context = "\n\n".join(context_parts) try: structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler) mcq_block = generate_mcqs_from_text(structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, enable_fiddler=enable_fiddler) except Exception as e: print(f"Generator failed during RAG attempt {attempts}: {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): payload = mcq_block[item] q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip() options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {} if isinstance(options, list): options = {str(i+1): o for i, o in enumerate(options)} correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None correct_text = "" if isinstance(correct_key, str) and correct_key.strip() in options: correct_text = options[correct_key.strip()] else: correct_text = payload.get("correct_text") or correct_key or "" diff_score, diff_label = self._estimate_difficulty_for_generation( q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=context ) payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label} qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output else: raise ValueError("mode must be 'per_chunk' or 'rag'.") def _estimate_difficulty_for_generation( self, q_text: str, options: Dict[str, str], correct_text: str, context_text: str, ) -> Tuple[float, str]: def safe_map_sim(s): # map potentially [-1,1] cosine-like to [0,1], clamp try: s = float(s) except Exception: return 0.0 mapped = (s + 1.0) / 2.0 return max(0.0, min(1.0, mapped)) # embedding support emb_support = 0.0 try: stmt = (q_text or "").strip() if correct_text: stmt = f"{stmt} Answer: {correct_text}" # use internal retrieve but map returned score res = [] try: res = self._retrieve(stmt, top_k=1) except Exception: res = [] if res: raw_score = float(res[0][1]) emb_support = safe_map_sim(raw_score) else: emb_support = 0.0 except Exception: emb_support = 0.0 # distractor sims mean_sim = 0.0 distractor_penalty = 0.0 amb_flag = 0.0 try: keys = list(options.keys()) texts = [options[k] for k in keys] if correct_text is None: correct_text = "" all_texts = [correct_text] + texts embs = self.embedder.encode(all_texts, convert_to_numpy=True) embs = np.asarray(embs, dtype=float) norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12 embs = embs / norms corr = embs[0] opts = embs[1:] if opts.size == 0: mean_sim = 0.0 distractor_penalty = 0.0 gap = 0.0 else: sims = (opts @ corr).tolist() # [-1,1] sims_mapped = [safe_map_sim(s) for s in sims] # [0,1] mean_sim = float(sum(sims_mapped) / len(sims_mapped)) # gap between best distractor and second best (higher gap -> easier) sorted_s = sorted(sims_mapped, reverse=True) top = sorted_s[0] second = sorted_s[1] if len(sorted_s) > 1 else 0.0 gap = top - second # penalties: if distractors are extremely close to correct -> higher penalty too_close_count = sum(1 for s in sims_mapped if s >= 0.85) too_far_count = sum(1 for s in sims_mapped if s <= 0.15) distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped)))) amb_flag = 1.0 if top >= 0.9 else 0.0 except Exception: mean_sim = 0.0 distractor_penalty = 0.0 amb_flag = 0.0 gap = 0.0 # stem length normalized qlen = len((q_text or "").strip()) qlen_norm = min(1.0, qlen / 300.0) # combine signals using safer semantics: # higher emb_support -> easier (so we subtract a term) # higher distractor_penalty -> harder (add) # better gap -> easier (subtract) # compute score (higher -> harder) score = 0 score += 0.35 * float(distractor_penalty) score += 0.20 * float(mean_sim) score += 0.22 * float(amb_flag) score += 0.05 * float(qlen_norm) score -= 0.20 * float(gap) # clamp score = max(0.0, min(1.0, float(score))) # label if score <= 0.33: label = "dễ" elif score <= 0.66 and score > 0.33: label = "trung bình" else: label = "khó" return score, label class RAGMCQWithDifficulty(RAGMCQ): def __init__( self, embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", generation_model: str = "openai/gpt-oss-120b", qdrant_url: str = os.environ.get('QDRANT_URL') or "", qdrant_api_key: str = os.environ.get('QDRANT_API_KEY') or "", qdrant_prefer_grpc: bool = False, ): super().__init__(embedder_model, generation_model, qdrant_url, qdrant_api_key, qdrant_prefer_grpc) @override def generate_from_pdf( self, pdf_path: str, n_questions: int = 10, mode: str = "rag", # per_page or rag questions_per_page: int = 3, # for per_page mode top_k: int = 3, # chunks to retrieve for each question in rag mode temperature: float = 0.2, enable_fiddler: bool = False, target_difficulty: str = 'easy' # easy, mid, difficult ) -> Dict[str, Any]: # build index self.build_index_from_pdf(pdf_path) output: Dict[str, Any] = {} qcount = 0 if mode == "per_page": # iterate pages -> chunks for idx, meta in enumerate(self.metadata): chunk_text = self.texts[idx] if not chunk_text.strip(): continue # ask generator try: structured_context = structure_context_for_llm(chunk_text, model=self.generation_model, temperature=0.2, enable_fiddler=enable_fiddler) mcq_block = new_generate_mcqs_from_text( source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler ) except Exception as e: # skip this chunk if generator fails print(f"Generator failed on page {meta['page']} chunk {meta['chunk_id']}: {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output # pdf gene elif mode == "rag": # strategy: create a few natural short queries by sampling sentences or using chunk summaries. # create queries by sampling chunk text sentences. # stop when n_questions reached or max_attempts exceeded. attempts = 0 max_attempts = n_questions * 4 while qcount < n_questions and attempts < max_attempts: attempts += 1 # create a seed query: pick a random chunk, pick a sentence from it seed_idx = random.randrange(len(self.texts)) chunk = self.texts[seed_idx] #? investigate better Chunking Strategy #with open("chunks.txt", "a", encoding="utf-8") as f: #f.write(chunk + "\n") sents = re.split(r'(?<=[\.\?\!])\s+', chunk) seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200] query = f"Create questions about: {seed_sent}" # retrieve top_k chunks retrieved = self._retrieve(query, top_k=top_k) context_parts = [] for ridx, score in retrieved: md = self.metadata[ridx] context_parts.append(f"[page {md['page']}] {self.texts[ridx]}") context = "\n\n".join(context_parts) # save_to_local('test/context.md', content=context) # call generator for 1 question (or small batch) with the retrieved context try: structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False) mcq_block = new_generate_mcqs_from_text( source_text=structured_context, n=questions_per_page, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler ) except Exception as e: print(f"Generator failed during RAG attempt {attempts}: {e}") continue if "error" in list(mcq_block.keys()): return output # append result(s) for item in sorted(mcq_block.keys(), key=lambda x: int(x)): payload = mcq_block[item] q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip() options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {} if isinstance(options, list): options = {str(i+1): o for i, o in enumerate(options)} correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None correct_text = "" if isinstance(correct_key, str) and correct_key.strip() in options: correct_text = options[correct_key.strip()] else: correct_text = payload.get("correct_text") or correct_key or "" diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts ) payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label} qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output else: raise ValueError("mode must be 'per_page' or 'rag'.") @override def generate_from_qdrant( self, filename: str, collection: str, n_questions: int = 10, mode: str = "rag", # 'per_chunk' or 'rag' questions_per_chunk: int = 3, # used for 'per_chunk' top_k: int = 3, # retrieval size used in RAG temperature: float = 0.2, enable_fiddler: bool = False, target_difficulty: str = 'easy', ) -> Dict[str, Any]: if self.qdrant is None: raise RuntimeError("Qdrant client not connected. Call connect_qdrant(...) first.") # get all chunks for this filename (payload should contain 'text', 'page', 'chunk_id', etc.) file_points = self.list_chunks_for_filename(collection=collection, filename=filename) if not file_points: raise RuntimeError(f"No chunks found for filename={filename} in collection={collection}.") # create a local list of texts & metadata for sampling texts = [] metas = [] for p in file_points: payload = p.get("payload", {}) text = payload.get("text", "") texts.append(text) metas.append(payload) self.texts = texts self.metadata = metas embeddings = self.embedder.encode(texts, convert_to_numpy=True, show_progress_bar=True) if embeddings is None or len(embeddings) == 0: self.embeddings = None self.index = None else: self.embeddings = embeddings.astype("float32") # update dim in case embedder changed unexpectedly self.dim = int(self.embeddings.shape[1]) # build index self._build_faiss_index() output = {} qcount = 0 if mode == "per_chunk": # iterate all chunks (in payload order) and request questions_per_chunk from each for i, txt in enumerate(texts): if not txt.strip(): continue try: structured_context = structure_context_for_llm(txt, model=self.generation_model, temperature=0.2, enable_fiddler=False) mcq_block = new_generate_mcqs_from_text( source_text=structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler ) except Exception as e: print(f"Generator failed on chunk (index {i}): {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): qcount += 1 output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output return output elif mode == "rag": attempts = 0 max_attempts = n_questions * 4 while qcount < n_questions and attempts < max_attempts: attempts += 1 # create a seed query: pick a random chunk, pick a sentence from it seed_idx = random.randrange(len(self.texts)) chunk = self.texts[seed_idx] sents = re.split(r'(?<=[\.\?\!])\s+', chunk) candidate = [s for s in sents if len(s.strip()) > 20] if candidate: seed_sent = random.choice(candidate) else: stripped = chunk.strip() seed_sent = (stripped[:200] if stripped else "[no text available]") query = f"Create questions about: {seed_sent}" # retrieve top_k chunks from the same file (restricted by filename filter) retrieved = self._retrieve_qdrant(query=query, collection=collection, filename=filename, top_k=top_k) context_parts = [] for payload, score in retrieved: # payload should contain page & chunk_id and text page = payload.get("page", "?") ctxt = payload.get("text", "") context_parts.append(f"[page {page}] {ctxt}") context = "\n\n".join(context_parts) # q generation try: structured_context = structure_context_for_llm(context, model=self.generation_model, temperature=0.2, enable_fiddler=False) mcq_block = new_generate_mcqs_from_text( source_text=structured_context, n=questions_per_chunk, model=self.generation_model, temperature=temperature, target_difficulty=target_difficulty ,enable_fiddler=enable_fiddler ) except Exception as e: print(f"Generator failed during RAG attempt {attempts}: {e}") continue if "error" in list(mcq_block.keys()): return output for item in sorted(mcq_block.keys(), key=lambda x: int(x)): payload = mcq_block[item] q_text = (payload.get("câu hỏi") or payload.get("question") or payload.get("stem") or "").strip() options = payload.get("lựa chọn") or payload.get("options") or payload.get("choices") or {} if isinstance(options, list): options = {str(i+1): o for i, o in enumerate(options)} correct_key = payload.get("đáp án") or payload.get("answer") or payload.get("correct") or None concepts = payload.get("khái niệm sử dụng") or payload.get("concepts") or payload.get("concepts used") or None correct_text = "" if isinstance(correct_key, str) and correct_key.strip() in options: correct_text = options[correct_key.strip()] else: correct_text = payload.get("correct_text") or correct_key or "" #? change estimate diff_score, diff_label, components = self._estimate_difficulty_for_generation( # type: ignore q_text=q_text, options={k: str(v) for k,v in options.items()}, correct_text=str(correct_text), context_text=structured_context, concepts_used=concepts # type: ignore ) payload["độ khó"] = {"điểm": diff_score, "mức độ": diff_label} # CHECK n generation: if number of request mcqs < default generation number e.g. 5 - 3 = 2 < 3 then only genearate 2 mcqs if n_questions - qcount < questions_per_chunk: questions_per_chunk = n_questions - qcount qcount += 1 # count number of question # print('qcount:', qcount) # print('questions_per_chunk:', questions_per_chunk) output[str(qcount)] = mcq_block[item] if qcount >= n_questions: return output if output is not None: print("output available") return output else: raise ValueError("mode must be 'per_chunk' or 'rag'.") @override def _estimate_difficulty_for_generation( self, q_text: str, options: Dict[str, str], correct_text: str, context_text: str, concepts_used: Dict = {} ) -> Tuple[float, str]: def safe_map_sim(s): # map potentially [-1,1] cosine-like to [0,1], clamp try: s = float(s) except Exception: return 0.0 mapped = (s + 1.0) / 2.0 return max(0.0, min(1.0, mapped)) # embedding support emb_support = 0.0 try: stmt = (q_text or "").strip() if correct_text: stmt = f"{stmt} Answer: {correct_text}" # use internal retrieve but map returned score res = [] try: res = self._retrieve(stmt, top_k=1) except Exception: res = [] if res: raw_score = float(res[0][1]) emb_support = safe_map_sim(raw_score) else: emb_support = 0.0 except Exception: emb_support = 0.0 # distractor sims mean_sim = 0.0 distractor_penalty = 0.0 amb_flag = 0.0 try: keys = list(options.keys()) texts = [options[k] for k in keys] if correct_text is None: correct_text = "" all_texts = [correct_text] + texts embs = self.embedder.encode(all_texts, convert_to_numpy=True) embs = np.asarray(embs, dtype=float) norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-12 embs = embs / norms corr = embs[0] opts = embs[1:] if opts.size == 0: mean_sim = 0.0 distractor_penalty = 0.0 gap = 0.0 else: sims = (opts @ corr).tolist() # [-1,1] sims_mapped = [safe_map_sim(s) for s in sims] # [0,1] mean_sim = float(sum(sims_mapped) / len(sims_mapped)) # gap between best distractor and second best (higher gap -> easier) sorted_s = sorted(sims_mapped, reverse=True) top = sorted_s[0] second = sorted_s[1] if len(sorted_s) > 1 else 0.0 gap = top - second # penalties: if distractors are extremely close to correct -> higher penalty too_close_count = sum(1 for s in sims_mapped if s >= 0.85) too_far_count = sum(1 for s in sims_mapped if s <= 0.15) distractor_penalty = min(1.0, 0.5 * mean_sim + 0.2 * (too_close_count / max(1, len(sims_mapped))) - 0.2 * (too_far_count / max(1, len(sims_mapped)))) amb_flag = 1.0 if top >= 0.8 else 0.0 except Exception: mean_sim = 0.0 distractor_penalty = 0.0 amb_flag = 0.0 gap = 0.0 # question length normalized question_len = len((q_text or "").strip()) question_len_norm = min(1.0, question_len / 300.0) # count number of concept from string concepts_num = len(concepts_used.keys()) if concepts_num < 2: concepts_penalty = 0 else: concepts_penalty = concepts_num # combine signals using safer semantics: # higher emb_support -> easier (so we subtract a term) # higher distractor_penalty -> harder (add) # better gap -> easier (subtract) # compute score (higher -> harder) score = 0 score += 0.35 * float(distractor_penalty) score += 0.20 * float(mean_sim) score += 0.22 * float(amb_flag) score += 0.08 * float(question_len_norm) score -= 0.20 * float(gap) # clamp score = max(0.0, min(1.0, float(score))) components = { "base": 0.3, "distractor_penalty": 0.35 * float(distractor_penalty), "mean_sim": 0.15 * float(mean_sim), "amb_flag": 0.05 * float(amb_flag), "concepts_num": 0.1 * float(concepts_num), "gap": -0.12 * float(gap), "question_len_norm": 0.05 * float(question_len_norm), "emb_support": -0.45 * float(emb_support), "total_score": score, } # label if score <= 0.56: label = "dễ" elif score <= 0.755 and score > 0.56: label = "trung bình" else: label = "khó" return score, label, components # type: ignore