from typing import List, Dict
import logging
import json
import re
import math
from collections import defaultdict

from json_repair import repair_json
from pydantic import parse_obj_as

from config import get_settings
from routes.schemas.Exam_Models import *
from stores.llm.LLMProviderFactory import LLMProviderFactory
from generation.AssistantRagGenerator import ProviderLLMWrapper
from generation.prompts import ExamPromptBuilder
from indexing.indexingController import IndexingController


class ExamService:
    """Generate and self-evaluate exams from retrieved course material.

    Questions are produced by an LLM in batches of at most ``BATCH_SIZE``.
    Each batch is grounded on a subset of the chunks retrieved from the
    vector store and validated into ``QuestionUnion`` models. The assembled
    exam is then scored by the LLM. Generation is retried up to
    ``MAX_GENERATION_ATTEMPTS`` times, feeding the evaluator's feedback into
    the next attempt, until a score reaches ``PASS_THRESHOLD``.
    """

    # Character budgets for the prompt context (per chunk / whole context).
    MAX_CHUNK_CHARS = 2000
    MAX_TOTAL_CONTEXT = 8000
    # NOTE(review): assumed to be the maximum evaluator score; confirm it
    # matches the scale requested by the evaluation prompt.
    MAX_SCORE = 40
    # An exam scoring at least 80% of MAX_SCORE ends the retry loop early.
    PASS_THRESHOLD = int(MAX_SCORE * 0.8)
    MAX_GENERATION_ATTEMPTS = 3
    # Consecutive batches allowed to contribute zero usable questions before
    # an attempt stops asking the LLM for more (prevents an endless loop).
    MAX_EMPTY_BATCHES = 3

    # Loose spellings of question types mapped to the canonical names.
    _QUESTION_TYPE_ALIASES = {
        "truefalse": "true_false",
        "true/false": "true_false",
        "multiple_choice": "mcq",
        "multiplechoice": "mcq",
    }

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._models_initialized = False
        self.settings = get_settings()
        self._init_models()
        self.prompts = ExamPromptBuilder()
        self.controller = IndexingController()
        self.store = self.controller.vector_store
        # Maximum number of questions requested from the LLM per call.
        self.BATCH_SIZE = 10

    def _init_models(self):
        """Create the generation/embedding providers once per instance."""
        if self._models_initialized:
            return

        factory = LLMProviderFactory(self.settings)

        self.generator = factory.create(self.settings.GENERATION_BACKEND)
        self.generator.set_generation_model(self.settings.GENERATION_MODEL_ID)

        self.embedding_provider = factory.create(self.settings.EMBEDDING_BACKEND)
        self.embedding_provider.set_embedding_model(
            self.settings.EMBEDDING_MODEL_ID,
            self.settings.EMBEDDING_MODEL_SIZE,
        )

        self.llm = ProviderLLMWrapper(provider=self.generator)
        self._models_initialized = True

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove Markdown code-fence markers (e.g. ```json, ```) and trim."""
        return re.sub(r"```[a-zA-Z]*|```", "", text).strip()

    def _extract_json(self, text: str) -> dict:
        """Extract the first JSON object from LLM output.

        Decoding starts at the first ``{``. If the object there cannot be
        decoded as-is, the greedy ``{ ... }`` span is passed through
        ``repair_json`` and decoded again.

        Raises:
            ValueError: *text* contains no ``{ ... }`` span.
            Exception: whatever decoding the repaired text raises (typically
                ``json.JSONDecodeError``); it is logged and re-raised.
        """
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            self.logger.error("No JSON found in LLM response:\n%s", text)
            raise ValueError("LLM returned no JSON")

        json_str = match.group(0)

        # raw_decode stops at the end of the first complete object, so
        # trailing prose or a second object (both captured by the greedy
        # regex above) no longer force the repair path.
        try:
            obj, _ = json.JSONDecoder().raw_decode(json_str)
            return obj
        except json.JSONDecodeError:
            self.logger.warning("Invalid JSON extracted, attempting repair...")

        try:
            return json.loads(repair_json(json_str))
        except Exception as e:
            self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e)
            raise

    def normalize_exam_dict(self, data: dict):
        """Coerce loosely formatted LLM exam output toward the expected schema.

        Mutates *data* in place and returns it:
        - ``difficulty`` strings such as ``"Difficulty.HARD"`` become ``"hard"``.
        - Non-dict entries in ``questions`` are dropped.
        - ``id``, ``question_id`` and ``points`` are removed from each question.
        - Each question's ``type`` is canonicalised, and type-specific fields
          are normalised (see the ``_normalize_*`` helpers).
        If ``questions`` is not a list, *data* is returned after only the
        difficulty step.
        """
        if "difficulty" in data:
            diff = data["difficulty"]
            if isinstance(diff, str):
                # Enum reprs like "Difficulty.HARD" -> "HARD".
                if "." in diff:
                    diff = diff.split(".")[-1]
                data["difficulty"] = diff.lower()

        questions = data.get("questions")
        if not isinstance(questions, list):
            return data

        normalizers = {
            "mcq": self._normalize_mcq,
            "true_false": self._normalize_true_false,
            "short_answer": self._normalize_short_answer,
            "essay": self._normalize_essay,
            "code": self._normalize_code,
        }

        normalized_questions = []
        for q in questions:
            if not isinstance(q, dict):
                continue

            # Identifiers and scoring are assigned by the service, not the LLM.
            q.pop("id", None)
            q.pop("question_id", None)
            q.pop("points", None)

            q_type = q.get("type")
            if isinstance(q_type, str):
                q_type = q_type.lower().strip().replace(" ", "_").replace("-", "_")
                q_type = self._QUESTION_TYPE_ALIASES.get(q_type, q_type)
                q["type"] = q_type

            if "question" in q:
                q["question"] = str(q["question"]).strip()

            normalizer = normalizers.get(q_type)
            if normalizer is not None:
                normalizer(q)

            normalized_questions.append(q)

        data["questions"] = normalized_questions
        return data

    @staticmethod
    def _normalize_mcq(q: dict) -> None:
        """Force ``options`` to a list of strings and align ``correct_answer``."""
        options = q.get("options")
        if isinstance(options, dict):
            options = list(options.values())
        elif isinstance(options, str):
            # "A) foo B) bar" or newline-separated options.
            parts = re.split(r"[A-D]\)|\n|\r", options)
            options = [p.strip(" .-") for p in parts if p.strip()]

        if isinstance(options, list):
            options = [str(o).strip() for o in options]
        else:
            options = []
        q["options"] = options

        correct = q.get("correct_answer")
        if correct is not None:
            correct = str(correct).strip()
            # LLMs often answer with the option label ("B", "b)", "(C)")
            # instead of the option text; resolve it when it names an option.
            if correct not in options:
                label = re.fullmatch(r"\(?([A-Za-z])[).:]?", correct)
                if label:
                    idx = ord(label.group(1).upper()) - ord("A")
                    if 0 <= idx < len(options):
                        correct = options[idx]
            q["correct_answer"] = correct
            # Guarantee the answer is selectable, but never add a blank option.
            if correct and correct not in options:
                options.append(correct)

        q.setdefault("explanation", "")

    @staticmethod
    def _normalize_true_false(q: dict) -> None:
        """Convert common textual booleans in ``correct_answer`` to bool."""
        ans = q.get("correct_answer")
        if isinstance(ans, str):
            ans = ans.strip().lower()
            if ans in ("true", "t", "1", "yes"):
                ans = True
            elif ans in ("false", "f", "0", "no"):
                ans = False
        q["correct_answer"] = ans
        q.setdefault("explanation", "")

    @staticmethod
    def _normalize_short_answer(q: dict) -> None:
        """Stringify and trim ``answer``; ensure an ``explanation`` key."""
        if "answer" in q:
            q["answer"] = str(q["answer"]).strip()
        q.setdefault("explanation", "")

    @staticmethod
    def _normalize_essay(q: dict) -> None:
        """Fold ``expected_keywords`` into the ``answer_guidelines`` string."""
        if "expected_keywords" in q:
            keywords = q.pop("expected_keywords")
            if isinstance(keywords, list):
                q["answer_guidelines"] = ", ".join(str(k) for k in keywords)
            else:
                q["answer_guidelines"] = str(keywords)
        q.setdefault("answer_guidelines", "")

    @staticmethod
    def _normalize_code(q: dict) -> None:
        """Stringify ``solution``; default ``starter_code`` and ``explanation``."""
        if "solution" in q:
            q["solution"] = str(q["solution"])
        q.setdefault("starter_code", None)
        q.setdefault("explanation", "")

    def generate_exam(self, request: ExamGenerationRequest, context: str, llm, batch_size: int) -> List[QuestionUnion]:
        """Ask the LLM for one batch of questions and validate them.

        Args:
            request: Exam request; copied with ``total_questions=batch_size``
                to build the prompt.
            context: Source material (plus optional evaluator feedback).
            llm: Object exposing ``_call(prompt) -> str``.
            batch_size: Maximum number of questions to keep from the reply.

        Returns:
            The questions that validate as ``QuestionUnion``. MCQs without
            options and questions that fail validation are skipped (with a
            warning), so the list may be shorter than *batch_size*.

        Raises:
            RuntimeError: the LLM returned an empty response.
            ValueError: the response contains no JSON object (including
                ``json.JSONDecodeError`` when it cannot be repaired).
        """
        batch_request = request.model_copy()
        batch_request.total_questions = batch_size

        prompt = self.prompts.build_exam_generation_prompt(batch_request, context)
        raw_text = llm._call(prompt)
        if not raw_text:
            raise RuntimeError("LLM generation failed")

        cleaned = self._strip_code_fences(raw_text)
        try:
            exam_dict = self._extract_json(cleaned)
        except json.JSONDecodeError:
            self.logger.error("Invalid JSON from LLM:\n%s", raw_text)
            raise
        if not isinstance(exam_dict, dict):
            raise ValueError("LLM JSON is not an object")

        exam_dict = self.normalize_exam_dict(exam_dict)
        raw_questions = exam_dict.get("questions") or []
        if not isinstance(raw_questions, list):
            raw_questions = []

        questions: List[QuestionUnion] = []
        for q in raw_questions[:batch_size]:
            if not isinstance(q, dict):
                continue

            if q.get("type") == "mcq":
                if not q.get("options"):
                    self.logger.warning("Skipping MCQ with no options: %s", q)
                    continue
                if not q.get("correct_answer"):
                    # Best-effort placeholder so the question stays usable;
                    # logged because the chosen answer may well be wrong.
                    self.logger.warning(
                        "MCQ without correct answer, defaulting to first option: %s",
                        q.get("question"),
                    )
                    q["correct_answer"] = q["options"][0]

            # Validate per question so one malformed item doesn't sink the
            # whole batch (pydantic's ValidationError subclasses ValueError).
            try:
                questions.append(parse_obj_as(QuestionUnion, q))
            except ValueError as e:
                self.logger.warning("Skipping invalid question %s: %s", q, e)

        self.logger.info(
            "Batch requested=%d | received=%d | kept=%d",
            batch_size,
            len(raw_questions),
            len(questions),
        )
        return questions

    def evaluate_exam(self, request: ExamGenerationRequest, exam: ExamResponse, llm):
        """Score *exam* with the LLM and return a validated ``EvaluationResult``.

        Raises:
            RuntimeError: the LLM returned an empty response.
            ValueError: no (repairable) JSON object in the response.
        """
        prompt = self.prompts.build_exam_evaluation_prompt(request, exam)
        raw_text = llm._call(prompt)
        if not raw_text:
            raise RuntimeError("Evaluation generation failed")

        cleaned = self._strip_code_fences(raw_text)
        try:
            evaluation_dict = self._extract_json(cleaned)
        except json.JSONDecodeError:
            self.logger.error("Invalid evaluation JSON:\n%s", raw_text)
            raise

        return EvaluationResult.model_validate(evaluation_dict)

    def split_chunks_by_topic_batches(self, exam_chunks, num_batches):
        """Distribute each topic's chunks round-robin over *num_batches* lists.

        Every topic starts at batch 0, so batches with lower indexes may hold
        more chunks. *num_batches* below 1 is treated as 1.

        Args:
            exam_chunks: Mapping of topic -> list of chunk dicts.
            num_batches: Number of batches to produce.

        Returns:
            A list of *num_batches* chunk lists.
        """
        num_batches = max(1, num_batches)
        self.logger.info("Topics retrieved: %s", list(exam_chunks.keys()))
        self.logger.info("Number of batches: %d", num_batches)

        batches = [[] for _ in range(num_batches)]
        for topic, chunks in exam_chunks.items():
            self.logger.info(
                "Topic '%s' -> %d chunks distributed across batches", topic, len(chunks)
            )
            for idx, chunk in enumerate(chunks):
                batches[idx % num_batches].append(chunk)

        for i, batch in enumerate(batches, start=1):
            topic_counter = defaultdict(int)
            for chunk in batch:
                topic = (chunk.get("metadata") or {}).get("topic", "unknown")
                topic_counter[topic] += 1
            self.logger.info(
                "Batch %d contains %d chunks -> %s", i, len(batch), dict(topic_counter)
            )

        return batches

    def exam_task(self, request_dict: dict) -> ExamResponse:
        """Generate a full exam, retrying with evaluator feedback.

        Each attempt generates questions batch by batch; each batch gets its
        own slice of the retrieved chunks. Each attempt is scored, and the
        highest-scoring exam is returned. Attempts stop early once a score
        reaches ``PASS_THRESHOLD``.

        Raises:
            RuntimeError: no exam could be produced.
            Exception: re-raised when the assembled exam fails
                ``ExamResponse`` validation, or when generation/evaluation fails.
        """
        request = ExamGenerationRequest.model_validate(request_dict)

        topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics)
        exam_chunks = self.store.retrieve_for_exam(
            topics_with_embeddings, request.username, request.course, request.references
        )

        num_batches = math.ceil(request.total_questions / self.BATCH_SIZE)

        self.logger.info("Raw exam_chunks structure: %s", type(exam_chunks))
        for k, v in exam_chunks.items():
            self.logger.info(
                "Topic=%s | type=%s | len=%s",
                k,
                type(v),
                len(v) if hasattr(v, "__len__") else "NA",
            )

        chunk_batches = self.split_chunks_by_topic_batches(exam_chunks, num_batches)

        feedback_context = ""
        best_exam = None
        best_score = None

        for attempt in range(self.MAX_GENERATION_ATTEMPTS):
            self.logger.info("Generating exam attempt %d", attempt + 1)

            questions = self._generate_attempt_questions(
                request, chunk_batches, feedback_context
            )

            exam_dict = {
                "exam_id": request.exam_id,
                "difficulty": request.difficulty,
                "total_questions": request.total_questions,
                "expected_distribution": request.question_types_distribution,
                "questions": questions[:request.total_questions],
            }
            try:
                exam = ExamResponse.model_validate(exam_dict)
            except Exception as e:
                self.logger.error("Exam validation failed: %s", e)
                raise

            evaluation = self.evaluate_exam(request, exam, self.llm)
            score = evaluation.overall_score
            self.logger.info("Evaluation score: %s", score)

            # Compare against None so a first attempt scoring 0 is still kept.
            if best_exam is None or score > best_score:
                best_score = score
                best_exam = exam

            if score >= self.PASS_THRESHOLD:
                break

            feedback_context = evaluation.feedback

        if best_exam is None:
            raise RuntimeError("Exam generation failed after retries")

        return best_exam

    def _generate_attempt_questions(self, request: ExamGenerationRequest, chunk_batches, feedback_context: str) -> List[QuestionUnion]:
        """Run the batch loop for one attempt; return at most total_questions.

        Stops when the requested count is reached, when the type distribution
        has no slots left, or after ``MAX_EMPTY_BATCHES`` consecutive batches
        contributed no usable question. In the last two cases fewer questions
        than requested may be returned.
        """
        remaining_distribution: Dict[QuestionType, int] = dict(request.question_types_distribution)
        all_questions: List[QuestionUnion] = []
        batch_index = 0
        empty_batches = 0

        while len(all_questions) < request.total_questions:
            remaining = request.total_questions - len(all_questions)
            batch_distribution = self._plan_batch_distribution(
                remaining_distribution, min(self.BATCH_SIZE, remaining)
            )
            if not batch_distribution:
                break

            batch_request = request.model_copy()
            batch_request.total_questions = sum(batch_distribution.values())
            batch_request.question_types_distribution = batch_distribution

            chunk_subset = chunk_batches[batch_index % len(chunk_batches)] if chunk_batches else []
            self._log_chunk_subset(batch_index + 1, chunk_subset)
            batch_index += 1

            batch_context = self.build_exam_context(chunk_subset)
            if feedback_context:
                batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}"

            batch_questions = self.generate_exam(
                batch_request, batch_context, self.llm, batch_request.total_questions
            )

            # Keep only questions whose type still has open slots.
            accepted = 0
            for q in batch_questions:
                if remaining_distribution.get(q.type, 0) > 0:
                    all_questions.append(q)
                    remaining_distribution[q.type] -= 1
                    accepted += 1
                    if len(all_questions) >= request.total_questions:
                        break

            if accepted:
                empty_batches = 0
                continue

            empty_batches += 1
            self.logger.warning(
                "Batch %d produced no usable questions (%d/%d consecutive)",
                batch_index,
                empty_batches,
                self.MAX_EMPTY_BATCHES,
            )
            if empty_batches >= self.MAX_EMPTY_BATCHES:
                self.logger.error(
                    "Stopping batch generation early with %d/%d questions",
                    len(all_questions),
                    request.total_questions,
                )
                break

        return all_questions[:request.total_questions]

    @staticmethod
    def _plan_batch_distribution(remaining_distribution: Dict, batch_size: int) -> Dict:
        """Greedily fill *batch_size* slots from *remaining_distribution*, in its order."""
        plan = {}
        slots_left = batch_size
        for qtype, count in remaining_distribution.items():
            if slots_left <= 0:
                break
            take = min(count, slots_left)
            if take > 0:
                plan[qtype] = take
                slots_left -= take
        return plan

    def _log_chunk_subset(self, batch_number: int, chunk_subset) -> None:
        """Log topic, page and a 200-char text preview of each chunk in a batch."""
        self.logger.info("\n===== BATCH %d CHUNKS =====", batch_number)
        for i, chunk in enumerate(chunk_subset, start=1):
            meta = chunk.get("metadata") or {}
            # Try common text keys; the preview is diagnostic only.
            text = chunk.get("text") or chunk.get("content") or chunk.get("page_content") or ""
            preview = str(text)[:200].replace("\n", " ")
            self.logger.info(
                "Chunk %d | Topic=%s | Page=%s | Preview=%s",
                i,
                meta.get("topic", "unknown"),
                meta.get("page", "NA"),
                preview,
            )
        self.logger.info("=====================================\n")

    def build_exam_context(self, exam_chunks) -> str:
        """Render chunks into a prompt context bounded by MAX_TOTAL_CONTEXT.

        Accepts either ``{topic: [chunks]}`` or a flat ``[chunks]`` list (which
        is grouped by ``metadata.topic``). Each chunk's text is truncated to
        ``MAX_CHUNK_CHARS``. Chunks whose text is missing, not a string or
        blank are skipped. Rendering stops adding to a topic once the next
        piece would exceed the total budget.
        """
        if isinstance(exam_chunks, list):
            topic_chunks = defaultdict(list)
            for c in exam_chunks:
                topic = (c.get("metadata") or {}).get("topic", "Unknown")
                topic_chunks[topic].append(c)
            exam_chunks = topic_chunks

        context_parts = []
        total_length = 0

        for topic, chunks in exam_chunks.items():
            topic_header = f"\n### Topic: {topic}\n"
            if total_length + len(topic_header) > self.MAX_TOTAL_CONTEXT:
                break
            context_parts.append(topic_header)
            total_length += len(topic_header)

            for c in chunks:
                # NOTE(review): text is read from payload.text while metadata
                # is read from the chunk root; confirm the store's chunk schema.
                text = (c.get("payload") or {}).get("text", "")
                meta = c.get("metadata") or {}
                source = meta.get("source", "")
                bookmark = meta.get("bookmark_path", "")

                if not isinstance(text, str) or not text.strip():
                    continue
                text = text[:self.MAX_CHUNK_CHARS]

                formatted_chunk = f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n"
                if total_length + len(formatted_chunk) > self.MAX_TOTAL_CONTEXT:
                    break
                context_parts.append(formatted_chunk)
                total_length += len(formatted_chunk)

        return "\n".join(context_parts)

    def prepare_topics_with_embeddings(self, topics: List[str]):
        """Embed each topic; return ``(topic, embedding)`` pairs.

        Topics whose embedding call raises are logged and skipped.
        """
        results = []
        for topic in topics:
            try:
                embedding = self.embedding_provider.embed_text(topic)
            except Exception as e:
                self.logger.warning("Embedding failed for topic '%s': %s", topic, e)
                continue
            results.append((topic, embedding))

        self.logger.info("Prepared %d topic embeddings", len(results))
        return results