EXAM_RAG_API / generation /ExamRagGenerator.py
MinaNasser's picture
1st
1bc3f18
from typing import List, Dict
import logging
import json
import re
import math
from json_repair import repair_json
from pydantic import parse_obj_as
from collections import defaultdict
from config import get_settings
from routes.schemas.Exam_Models import *
from stores.llm.LLMProviderFactory import LLMProviderFactory
from generation.AssistantRagGenerator import ProviderLLMWrapper
from generation.prompts import ExamPromptBuilder
from indexing.indexingController import IndexingController
# Service that retrieves course chunks, asks an LLM to generate exam questions in batches,
# has the LLM evaluate the resulting exam, and retries with the evaluator feedback.
class ExamService:
# Longest slice of a single chunk's text that build_exam_context() puts into the prompt.
MAX_CHUNK_CHARS = 2000
# Character budget for the whole context string assembled by build_exam_context().
MAX_TOTAL_CONTEXT = 8000
# Score scale used only to derive PASS_THRESHOLD below.
MAX_SCORE = 40
# exam_task() stops retrying once evaluation.overall_score reaches this value (80% of MAX_SCORE = 32).
PASS_THRESHOLD = int(MAX_SCORE * 0.8)
# Number of generate-then-evaluate rounds exam_task() runs before returning the best exam so far.
MAX_GENERATION_ATTEMPTS = 3
def __init__(self):
    """
    Set up logging, settings, the LLM/embedding providers, the prompt builder and the vector store.
    """
    self.logger = logging.getLogger(__name__)

    # Upper bound on the number of questions requested from the LLM per batch.
    self.BATCH_SIZE = 10

    # _init_models() reads both the flag and the settings, so they are set first.
    self._models_initialized = False
    self.settings = get_settings()
    self._init_models()

    self.prompts = ExamPromptBuilder()

    # The indexing controller owns the vector store used for chunk retrieval.
    indexing = IndexingController()
    self.controller = indexing
    self.store = indexing.vector_store
def _init_models(self):
    """
    Create the generation and embedding providers and wrap the generator for prompting.

    Does nothing when the providers have already been initialised.
    """
    if not self._models_initialized:
        cfg = self.settings
        provider_factory = LLMProviderFactory(cfg)

        # Provider used for text generation.
        self.generator = provider_factory.create(cfg.GENERATION_BACKEND)
        self.generator.set_generation_model(cfg.GENERATION_MODEL_ID)

        # Provider used for embeddings.
        self.embedding_provider = provider_factory.create(cfg.EMBEDDING_BACKEND)
        self.embedding_provider.set_embedding_model(
            cfg.EMBEDDING_MODEL_ID,
            cfg.EMBEDDING_MODEL_SIZE,
        )

        self.llm = ProviderLLMWrapper(provider=self.generator)
        self._models_initialized = True
def _extract_json(self, text: str) -> dict:
    """
    Extract the first JSON object from LLM output.

    The object that starts at the first ``{`` is decoded on its own, so prose that follows
    the JSON (even prose that contains braces) does not break parsing. Only when that
    object is malformed is the span from the first ``{`` to the last ``}`` handed to
    `repair_json`.

    Args:
        text: Raw LLM response, possibly with prose around the JSON.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no ``{...}`` span exists, or the repaired payload is not an object.
        Exception: Whatever `repair_json` / `json.loads` raised while repairing
            (logged, then re-raised unchanged).
    """
    start = text.find("{")
    if start != -1:
        try:
            # raw_decode stops at the end of the first complete value and ignores
            # anything after it; a value starting at "{" is always an object.
            parsed, _ = json.JSONDecoder().raw_decode(text, start)
            return parsed
        except json.JSONDecodeError:
            pass  # malformed object: fall through to the repair path

    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        self.logger.error("No JSON found in LLM response:\n%s", text)
        raise ValueError("LLM returned no JSON")

    json_str = match.group(0)
    self.logger.warning("Invalid JSON extracted, attempting repair...")
    try:
        repaired = json.loads(repair_json(json_str))
    except Exception as e:
        self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e)
        raise

    # Callers use dict methods on the result, so reject repaired payloads of other types.
    if not isinstance(repaired, dict):
        self.logger.error("Repaired JSON is not an object:\n%s", json_str)
        raise ValueError("LLM returned JSON that is not an object")
    return repaired
def normalize_exam_dict(self, data: dict):
    """
    Normalize a raw exam dict produced by the LLM, in place, and return it.

    What is normalized:
      * ``difficulty``: ``"Difficulty.HARD"`` / ``"HARD"`` becomes ``"hard"``.
      * ``questions``: non-dict entries are dropped; ``id``, ``question_id`` and ``points``
        are removed; ``type`` is lower-cased with separators unified (``"true/false"``,
        ``"TrueFalse"`` -> ``"true_false"``); the question text is stripped.
      * ``mcq``: options given as a dict or a single string become a list of strings.
        A correct answer given as an option label ("B", "b)", or a key of the options
        dict) is replaced by the option text it points to; any other non-empty answer
        missing from the options is appended to them.
      * ``true_false``: textual answers such as "yes" / "F" become booleans.
      * ``short_answer``: the answer is stripped.
      * ``essay``: ``expected_keywords`` is folded into ``answer_guidelines``.
      * ``code``: the solution is stringified and ``starter_code`` defaults to ``None``.

    Args:
        data: Exam dict decoded from the LLM response. It is mutated.

    Returns:
        The same ``data`` object.
    """
    type_aliases = {
        "truefalse": "true_false",
        "shortanswer": "short_answer",
        "multiple_choice": "mcq",
    }

    # Normalize difficulty enum
    difficulty = data.get("difficulty")
    if isinstance(difficulty, str):
        data["difficulty"] = difficulty.split(".")[-1].strip().lower()

    # Normalize questions
    questions = data.get("questions")
    if not isinstance(questions, list):
        return data

    normalized_questions = []
    for q in questions:
        if not isinstance(q, dict):
            continue

        # Identifiers and points are not taken from the LLM output.
        for key in ("id", "question_id", "points"):
            q.pop(key, None)

        # normalize type
        q_type = q.get("type")
        if isinstance(q_type, str):
            q_type = re.sub(r"[\s/\-]+", "_", q_type.strip().lower())
            q_type = type_aliases.get(q_type, q_type)
            q["type"] = q_type

        # normalize question text
        if "question" in q:
            q["question"] = str(q["question"]).strip()

        if q_type == "mcq":
            raw_options = q.get("options")
            option_labels = None
            if isinstance(raw_options, dict):
                # Keep the labels so a letter answer can be mapped back to its option.
                option_labels = [str(k).strip().strip("().:").lower() for k in raw_options]
                options = [str(v).strip() for v in raw_options.values()]
            elif isinstance(raw_options, str):
                pieces = (p.strip().strip(" .-") for p in re.split(r"[A-D]\)|\n|\r", raw_options))
                options = [p for p in pieces if p]
            elif isinstance(raw_options, list):
                options = [str(o).strip() for o in raw_options]
            else:
                options = []

            correct = q.get("correct_answer")
            if correct is not None:
                correct = str(correct).strip()
                if correct and correct not in options:
                    label_match = re.fullmatch(r"\(?([A-Za-z])[).:]?", correct)
                    label = (label_match.group(1) if label_match else correct).lower()
                    resolved = None
                    if option_labels and label in option_labels:
                        resolved = options[option_labels.index(label)]
                    elif label_match:
                        # "A" -> first option, "B" -> second, ...
                        position = ord(label) - ord("a")
                        if 0 <= position < len(options):
                            resolved = options[position]
                    if resolved is None:
                        # Not a label: make sure the stated answer is selectable.
                        options.append(correct)
                    else:
                        correct = resolved
                q["correct_answer"] = correct

            q["options"] = options
            q.setdefault("explanation", "")

        elif q_type == "true_false":
            ans = q.get("correct_answer")
            if isinstance(ans, str):
                ans = ans.strip().lower()
                if ans in ("true", "t", "1", "yes"):
                    ans = True
                elif ans in ("false", "f", "0", "no"):
                    ans = False
            q["correct_answer"] = ans
            q.setdefault("explanation", "")

        elif q_type == "short_answer":
            if "answer" in q:
                q["answer"] = str(q["answer"]).strip()
            q.setdefault("explanation", "")

        elif q_type == "essay":
            if "expected_keywords" in q:
                keywords = q.pop("expected_keywords")
                if isinstance(keywords, list):
                    # Keywords may be non-strings; str() keeps join() from raising.
                    q["answer_guidelines"] = ", ".join(str(k) for k in keywords)
                else:
                    q["answer_guidelines"] = str(keywords)
            q.setdefault("answer_guidelines", "")

        elif q_type == "code":
            if "solution" in q:
                q["solution"] = str(q["solution"])
            q.setdefault("starter_code", None)
            q.setdefault("explanation", "")

        normalized_questions.append(q)

    data["questions"] = normalized_questions
    return data
def generate_exam(self, request: ExamGenerationRequest, context: str, llm, batch_size: int) -> List[QuestionUnion]:
    """
    Generate one batch of questions from the LLM and return them as validated models.

    The LLM response is stripped of markdown fences, decoded, normalized with
    `normalize_exam_dict`, and each question is validated on its own so that a single
    malformed question does not discard the rest of the batch.

    Args:
        request: Exam request; a copy with ``total_questions = batch_size`` builds the prompt.
        context: Retrieved course text (plus any evaluator feedback) for the prompt.
        llm: Object exposing ``_call(prompt)`` that returns the raw completion text.
        batch_size: Maximum number of questions to return.

    Returns:
        At most ``batch_size`` validated questions, in the order the LLM produced them.

    Raises:
        RuntimeError: If the LLM returns an empty response.
        ValueError / json.JSONDecodeError: If no usable JSON can be extracted.
        pydantic.ValidationError: If questions were returned but none of them validates.
    """
    # The file already uses pydantic v2 APIs, where parse_obj_as is deprecated in
    # favour of TypeAdapter.
    from pydantic import TypeAdapter, ValidationError

    # Prepare the prompt for the batch
    batch_request = request.model_copy()
    batch_request.total_questions = batch_size
    prompt = self.prompts.build_exam_generation_prompt(batch_request, context)

    raw_text = llm._call(prompt)
    if not raw_text:
        raise RuntimeError("LLM generation failed")

    cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()
    try:
        exam_dict = self._extract_json(cleaned)
    except json.JSONDecodeError:
        self.logger.error("Invalid JSON from LLM:\n%s", raw_text)
        raise

    exam_dict = self.normalize_exam_dict(exam_dict)
    received = exam_dict.get("questions")
    if not isinstance(received, list):
        received = []

    adapter = TypeAdapter(QuestionUnion)
    questions = []
    last_error = None
    for q in received:
        if len(questions) >= batch_size:
            break
        if not isinstance(q, dict):
            continue  # skip invalid entries

        if q.get("type") == "mcq":
            if not q.get("options"):
                self.logger.warning("Skipping MCQ with no options: %s", q)
                continue
            if not q.get("correct_answer"):
                # Placeholder so the question still validates; flagged because the
                # first option is an arbitrary choice.
                self.logger.warning("MCQ without correct answer, defaulting to first option: %s", q)
                q["correct_answer"] = q["options"][0]

        try:
            questions.append(adapter.validate_python(q))
        except ValidationError as exc:
            last_error = exc
            self.logger.warning("Skipping question that failed validation: %s | %s", exc, q)

    # If nothing at all validated, surface the schema problem instead of returning
    # an empty batch that the caller would simply request again.
    if not questions and last_error is not None:
        raise last_error

    self.logger.info(
        "Batch requested=%d | received=%d | kept=%d",
        batch_size,
        len(received),
        len(questions),
    )
    return questions
def evaluate_exam(self, request: ExamGenerationRequest, exam: ExamResponse, llm):
    """
    Ask the LLM to evaluate a generated exam and return the parsed evaluation.

    Raises RuntimeError when the LLM returns nothing; JSON decoding errors from
    `_extract_json` are logged and re-raised.
    """
    evaluation_prompt = self.prompts.build_exam_evaluation_prompt(request, exam)
    response_text = llm._call(evaluation_prompt)
    if not response_text:
        raise RuntimeError("Evaluation generation failed")

    # Drop markdown code fences before looking for the JSON payload.
    without_fences = re.sub(r"```[a-zA-Z]*|```", "", response_text).strip()
    try:
        parsed = self._extract_json(without_fences)
    except json.JSONDecodeError:
        self.logger.error("Invalid evaluation JSON:\n%s", response_text)
        raise
    else:
        return EvaluationResult.model_validate(parsed)
def split_chunks_by_topic_batches(self, exam_chunks, num_batches):
    """
    Distribute retrieved chunks over ``num_batches`` batches, round-robin.

    Chunks are dealt out topic by topic and the round-robin position carries over from
    one topic to the next. Restarting at batch 0 for every topic would pile all chunks
    into the first batches and leave later batches without any context when topics have
    few chunks each.

    Args:
        exam_chunks: Mapping of topic -> list of chunk dicts.
        num_batches: Number of batches wanted; values below 1 are treated as 1.

    Returns:
        A list of ``num_batches`` lists of chunks (some may be empty when there are
        fewer chunks than batches).
    """
    # Guard against modulo-by-zero when no batches were requested.
    num_batches = max(1, num_batches)

    self.logger.info("Topics retrieved: %s", list(exam_chunks.keys()))
    self.logger.info("Number of batches: %s", num_batches)

    batches = [[] for _ in range(num_batches)]
    cursor = 0  # running round-robin position shared by all topics
    for topic, chunks in exam_chunks.items():
        self.logger.info("Topic '%s' -> %d chunks distributed across batches", topic, len(chunks))
        for chunk in chunks:
            batches[cursor % num_batches].append(chunk)
            cursor += 1

    # Log batch composition
    for number, batch in enumerate(batches, start=1):
        topic_counter = defaultdict(int)
        for chunk in batch:
            topic = (chunk.get("metadata") or {}).get("topic", "unknown")
            topic_counter[topic] += 1
        self.logger.info("Batch %d contains %d chunks -> %s", number, len(batch), dict(topic_counter))

    return batches
def exam_task(self, request_dict: dict) -> ExamResponse:
    """
    Generate a full exam in batches, evaluate it, and retry with evaluator feedback.

    Each attempt fills the requested question-type distribution batch by batch, each
    batch receiving a portion of the retrieved chunks as context. The assembled exam is
    evaluated by the LLM; the loop stops as soon as the score reaches PASS_THRESHOLD,
    otherwise the evaluator feedback is appended to the context of the next attempt.

    Args:
        request_dict: Raw request payload, validated into an ExamGenerationRequest.

    Returns:
        The highest-scoring exam produced across the attempts.

    Raises:
        RuntimeError: If no attempt produced any question.
        Exception: Validation, generation and evaluation errors propagate unchanged.
    """
    request = ExamGenerationRequest.model_validate(request_dict)

    # Prepare context from knowledge store
    topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics)
    exam_chunks = self.store.retrieve_for_exam(
        topics_with_embeddings, request.username, request.course, request.references
    )

    # Determine number of batches
    num_batches = math.ceil(request.total_questions / self.BATCH_SIZE)
    self.logger.info("Raw exam_chunks structure: %s", type(exam_chunks))
    for k, v in exam_chunks.items():
        self.logger.info(
            "Topic=%s | type=%s | len=%s", k, type(v), len(v) if hasattr(v, "__len__") else "NA"
        )
    chunk_batches = self.split_chunks_by_topic_batches(exam_chunks, num_batches)

    # Safety break: give up on an attempt after this many consecutive batches that
    # contributed no usable question; otherwise the fill loop could run forever.
    max_stalled_batches = 3

    feedback_context = ""
    best_exam = None
    best_score = None

    for attempt in range(self.MAX_GENERATION_ATTEMPTS):
        self.logger.info("Generating exam attempt %d", attempt + 1)
        remaining_distribution: Dict[QuestionType, int] = dict(request.question_types_distribution)
        all_questions: List[QuestionUnion] = []
        batch_index = 0
        stalled_batches = 0

        # Batch generation loop
        while len(all_questions) < request.total_questions:
            remaining = request.total_questions - len(all_questions)
            batch_size = min(self.BATCH_SIZE, remaining)

            # Determine batch distribution
            batch_distribution: Dict[QuestionType, int] = {}
            slots_left = batch_size
            for qtype, count in remaining_distribution.items():
                if slots_left <= 0:
                    break
                take = min(count, slots_left)
                if take > 0:
                    batch_distribution[qtype] = take
                    slots_left -= take
            if not batch_distribution:
                break

            batch_request = request.model_copy()
            batch_request.total_questions = sum(batch_distribution.values())
            batch_request.question_types_distribution = batch_distribution

            # Select chunk subset for this batch
            chunk_subset = chunk_batches[batch_index % len(chunk_batches)] if chunk_batches else []
            self.logger.info("\n===== BATCH %d CHUNKS =====", batch_index + 1)
            for i, chunk in enumerate(chunk_subset, start=1):
                meta = chunk.get("metadata") or {}
                # Same text location build_exam_context() reads first, then common keys.
                text = (
                    (chunk.get("payload") or {}).get("text")
                    or chunk.get("text")
                    or chunk.get("content")
                    or chunk.get("page_content")
                    or ""
                )
                preview = str(text)[:200].replace("\n", " ")
                self.logger.info(
                    "Chunk %d | Topic=%s | Page=%s | Preview=%s",
                    i, meta.get("topic", "unknown"), meta.get("page", "NA"), preview,
                )
            self.logger.info("=====================================\n")
            batch_index += 1

            batch_context = self.build_exam_context(chunk_subset)
            if feedback_context:
                batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}"

            # Generate questions
            batch_questions = self.generate_exam(
                batch_request, batch_context, self.llm, batch_request.total_questions
            )

            # Keep only questions whose type is still needed
            accepted_before = len(all_questions)
            for q in batch_questions:
                if remaining_distribution.get(q.type, 0) > 0:
                    all_questions.append(q)
                    remaining_distribution[q.type] -= 1
                if len(all_questions) >= request.total_questions:
                    break

            if len(all_questions) == accepted_before:
                stalled_batches += 1
                if stalled_batches >= max_stalled_batches:
                    self.logger.warning(
                        "Attempt %d: %d consecutive batches produced no usable question; "
                        "stopping with %d/%d questions",
                        attempt + 1, stalled_batches, len(all_questions), request.total_questions,
                    )
                    break
            else:
                stalled_batches = 0

        if request.total_questions > 0 and not all_questions:
            # Nothing to evaluate; spend the next attempt instead.
            self.logger.warning("Attempt %d produced no questions", attempt + 1)
            continue

        # Build final exam
        exam_dict = {
            "exam_id": request.exam_id,
            "difficulty": request.difficulty,
            "total_questions": request.total_questions,
            "expected_distribution": request.question_types_distribution,
            "questions": all_questions[:request.total_questions],
        }
        try:
            exam = ExamResponse.model_validate(exam_dict)
        except Exception as e:
            self.logger.error("Exam validation failed: %s", e)
            raise

        evaluation = self.evaluate_exam(request, exam, self.llm)
        self.logger.info("Evaluation score: %s", evaluation.overall_score)

        # Keep the first exam unconditionally so that a score of 0 does not discard it.
        if best_exam is None or evaluation.overall_score > best_score:
            best_score = evaluation.overall_score
            best_exam = exam

        if evaluation.overall_score >= self.PASS_THRESHOLD:
            break
        feedback_context = evaluation.feedback

    if best_exam is None:
        raise RuntimeError("Exam generation failed after retries")
    return best_exam
def build_exam_context(self, exam_chunks) -> str:
    """
    Build the prompt context string from retrieved chunks, grouped under topic headers.

    Accepts either:
    1) {topic: [chunks]}
    2) [chunks]  (grouped here by ``metadata["topic"]``, "Unknown" when absent)

    Chunk text is read from ``payload["text"]`` and, when that is missing, from the
    top-level ``text`` / ``content`` / ``page_content`` keys. Each chunk's text is cut to
    MAX_CHUNK_CHARS. Chunks without text are skipped, and a topic header is only emitted
    together with the first chunk that fits, so no budget is spent on empty entries.
    The MAX_TOTAL_CONTEXT budget counts the parts themselves, not the newlines joining them.

    Returns:
        The parts joined with newlines, or "" when nothing fits or no chunk has text.
    """
    # Normalize structure
    if isinstance(exam_chunks, list):
        grouped = defaultdict(list)
        for c in exam_chunks:
            topic = (c.get("metadata") or {}).get("topic", "Unknown")
            grouped[topic].append(c)
        exam_chunks = grouped

    context_parts = []
    total_length = 0

    for topic, chunks in exam_chunks.items():
        topic_header = f"\n### Topic: {topic}\n"
        header_written = False

        for c in chunks:
            payload = c.get("payload") or {}
            metadata = c.get("metadata") or {}
            text = (
                payload.get("text")
                or c.get("text")
                or c.get("content")
                or c.get("page_content")
                or ""
            )
            if not isinstance(text, str) or not text.strip():
                continue
            text = text[:self.MAX_CHUNK_CHARS]

            source = metadata.get("source", "")
            bookmark = metadata.get("bookmark_path", "")
            formatted_chunk = f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n"

            # The header is charged together with the topic's first chunk.
            needed = len(formatted_chunk) + (0 if header_written else len(topic_header))
            if total_length + needed > self.MAX_TOTAL_CONTEXT:
                break

            if not header_written:
                context_parts.append(topic_header)
                header_written = True
            context_parts.append(formatted_chunk)
            total_length += needed

    return "\n".join(context_parts)
def prepare_topics_with_embeddings(self, topics: List[str]):
    """
    Embed every topic and pair it with its embedding.

    Topics whose embedding call raises are logged with a warning and left out.
    Returns a list of ``(topic, embedding)`` tuples in input order.
    """
    pairs = []
    for name in topics:
        try:
            vector = self.embedding_provider.embed_text(name)
        except Exception as err:
            # Best effort: one failing topic must not abort the whole request.
            self.logger.warning(f"Embedding failed for topic '{name}': {err}")
            continue
        pairs.append((name, vector))

    self.logger.info(f"Prepared {len(pairs)} topic embeddings")
    return pairs