Spaces:
Paused
Paused
| from typing import List, Dict | |
| import logging | |
| import json | |
| import re | |
| import math | |
| from json_repair import repair_json | |
| from pydantic import parse_obj_as | |
| from collections import defaultdict | |
| from config import get_settings | |
| from routes.schemas.Exam_Models import * | |
| from stores.llm.LLMProviderFactory import LLMProviderFactory | |
| from generation.AssistantRagGenerator import ProviderLLMWrapper | |
| from generation.prompts import ExamPromptBuilder | |
| from indexing.indexingController import IndexingController | |
class ExamService:
    """Builds exams with an LLM from retrieved course chunks and scores them with an evaluation prompt."""

    # Per-chunk character cap applied to chunk text in build_exam_context.
    MAX_CHUNK_CHARS = 2000
    # Character budget for the context assembled by build_exam_context (topic headers plus chunks).
    MAX_TOTAL_CONTEXT = 8000
    # Only used here to derive PASS_THRESHOLD; presumably the evaluator's maximum score -- TODO confirm.
    MAX_SCORE = 40
    # Evaluation score at or above which exam_task stops retrying (80% of MAX_SCORE, truncated to int).
    PASS_THRESHOLD = int(MAX_SCORE * 0.8)
    # Upper bound on generate/evaluate rounds in exam_task.
    MAX_GENERATION_ATTEMPTS = 3
def __init__(self):
    """Set up logging, settings, the LLM/embedding providers, the prompt builder and the vector store."""
    # Number of questions requested from the LLM per generation batch.
    self.BATCH_SIZE = 10
    # Must be False before _init_models() runs; that method flips it to True.
    self._models_initialized = False
    self.logger = logging.getLogger(__name__)
    self.settings = get_settings()
    self._init_models()
    self.prompts = ExamPromptBuilder()
    indexing = IndexingController()
    self.controller = indexing
    self.store = indexing.vector_store
def _init_models(self):
    """Create the generation and embedding providers once; repeated calls do nothing."""
    if not self._models_initialized:
        cfg = self.settings
        provider_factory = LLMProviderFactory(cfg)

        # Text-generation backend.
        generation_backend = provider_factory.create(cfg.GENERATION_BACKEND)
        self.generator = generation_backend
        generation_backend.set_generation_model(cfg.GENERATION_MODEL_ID)

        # Embedding backend (model id plus embedding size).
        embedding_backend = provider_factory.create(cfg.EMBEDDING_BACKEND)
        self.embedding_provider = embedding_backend
        embedding_backend.set_embedding_model(
            cfg.EMBEDDING_MODEL_ID,
            cfg.EMBEDDING_MODEL_SIZE,
        )

        self.llm = ProviderLLMWrapper(provider=self.generator)
        self._models_initialized = True
def _extract_json(self, text: str) -> dict:
    """
    Extract a JSON object from LLM output.

    The object that starts at the first ``{`` is decoded strictly first, so
    any text that follows a well-formed object (including text that itself
    contains braces) is ignored. If strict decoding fails, the span from the
    first ``{`` to the last ``}`` is passed through ``repair_json`` and
    parsed again.

    Args:
        text: Raw (fence-stripped) LLM response.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If ``text`` contains no ``{...}`` span at all.
        Exception: Whatever ``repair_json`` / ``json.loads`` raised when the
            repair attempt fails (logged, then re-raised unchanged).
    """
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        self.logger.error("No JSON found in LLM response:\n%s", text)
        raise ValueError("LLM returned no JSON")

    # Strict pass: raw_decode stops at the end of the first complete object,
    # so trailing chatter after the JSON does not force a repair.
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text, match.start())
        return parsed
    except json.JSONDecodeError:
        self.logger.warning("Invalid JSON extracted, attempting repair...")

    # Lenient pass: repair the widest brace-delimited span.
    json_str = match.group(0)
    try:
        repaired_str = repair_json(json_str)
        return json.loads(repaired_str)
    except Exception as e:
        self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e)
        raise
def normalize_exam_dict(self, data: dict):
    """
    Normalize a raw exam dict parsed from LLM output and return it.

    ``data`` and the question dicts inside it are modified in place.

    * ``difficulty``: a string such as ``"Difficulty.HARD"`` becomes ``"hard"``.
    * ``questions``: entries that are not dicts are dropped; the keys ``id``,
      ``question_id`` and ``points`` are removed; ``type`` is lower-cased and
      stripped (``"truefalse"`` becomes ``"true_false"``); ``question`` is
      stringified and stripped. Type-specific clean-up then runs:

      - ``mcq``: ``options`` is coerced to a list of stripped strings (dict
        values, or a string split on ``A)``..``D)`` markers / newlines); a
        non-empty ``correct_answer`` missing from the options is appended.
      - ``true_false``: common string spellings of true/false become bools.
      - ``short_answer``: ``answer`` is stringified and stripped.
      - ``essay``: ``expected_keywords`` is folded into ``answer_guidelines``.
      - ``code``: ``solution`` is stringified; ``starter_code`` defaults to None.

    If ``questions`` is missing or not a list, only the difficulty is touched.

    Args:
        data: Exam dict as decoded from the LLM response.

    Returns:
        The same ``data`` object, normalized.
    """
    true_tokens = ("true", "t", "1", "yes")
    false_tokens = ("false", "f", "0", "no")

    # Normalize difficulty enum ("Difficulty.HARD" -> "hard").
    difficulty = data.get("difficulty")
    if isinstance(difficulty, str):
        data["difficulty"] = difficulty.split(".")[-1].lower()

    questions = data.get("questions")
    if not isinstance(questions, list):
        return data

    normalized_questions = []
    for q in questions:
        if not isinstance(q, dict):
            continue

        # Identifiers and points are not taken from the LLM output.
        for key in ("id", "question_id", "points"):
            q.pop(key, None)

        q_type = q.get("type")
        if isinstance(q_type, str):
            q_type = q_type.lower().strip()
            if q_type == "truefalse":
                q_type = "true_false"
            q["type"] = q_type

        if "question" in q:
            q["question"] = str(q["question"]).strip()

        if q_type == "mcq":
            options = q.get("options")
            if isinstance(options, dict):
                options = list(options.values())
            elif isinstance(options, str):
                parts = re.split(r"[A-D]\)|\n|\r", options)
                options = [p.strip(" .-") for p in parts if p.strip()]

            if isinstance(options, list):
                options = [str(o).strip() for o in options]
            else:
                options = []
            q["options"] = options

            correct = q.get("correct_answer")
            if correct is not None:
                correct = str(correct).strip()
                q["correct_answer"] = correct
                # Keep the answer selectable, but never add a blank option.
                if correct and correct not in options:
                    options.append(correct)

            q.setdefault("explanation", "")

        elif q_type == "true_false":
            ans = q.get("correct_answer")
            if isinstance(ans, str):
                # Strip first so values like " Yes " are still recognized.
                ans = ans.strip().lower()
                if ans in true_tokens:
                    ans = True
                elif ans in false_tokens:
                    ans = False
            q["correct_answer"] = ans
            q.setdefault("explanation", "")

        elif q_type == "short_answer":
            if "answer" in q:
                q["answer"] = str(q["answer"]).strip()
            q.setdefault("explanation", "")

        elif q_type == "essay":
            if "expected_keywords" in q:
                keywords = q.pop("expected_keywords")
                if isinstance(keywords, list):
                    # str() each item: join() raises TypeError on non-strings.
                    q["answer_guidelines"] = ", ".join(str(k) for k in keywords)
                else:
                    q["answer_guidelines"] = str(keywords)
            q.setdefault("answer_guidelines", "")

        elif q_type == "code":
            if "solution" in q:
                q["solution"] = str(q["solution"])
            q.setdefault("starter_code", None)
            q.setdefault("explanation", "")

        normalized_questions.append(q)

    data["questions"] = normalized_questions
    return data
def generate_exam(self, request: ExamGenerationRequest, context: str, llm, batch_size: int) -> List[QuestionUnion]:
    """
    Generate one batch of questions from the LLM and validate them.

    The request is copied with ``total_questions`` set to ``batch_size``, the
    generation prompt is built from that copy and ``context``, and the LLM
    response is stripped of code fences, parsed, normalized and truncated to
    ``batch_size`` entries. MCQs without options are skipped; MCQs without a
    correct answer get their first option as a placeholder.

    Args:
        request: Exam request used as the template for the batch prompt.
        context: Source material (and optional feedback) given to the LLM.
        llm: Object exposing ``_call(prompt)`` that returns the raw response text.
        batch_size: Maximum number of questions to keep from this response.

    Returns:
        The validated questions of this batch.

    Raises:
        RuntimeError: If the LLM returns an empty response.
        ValueError: If no JSON object can be found in the response.
        json.JSONDecodeError: If the JSON cannot be parsed or repaired.
        pydantic.ValidationError: If a question does not match ``QuestionUnion``.
    """
    # Local import: TypeAdapter is the Pydantic v2 replacement for the
    # deprecated parse_obj_as helper.
    from pydantic import TypeAdapter

    batch_request = request.model_copy()
    batch_request.total_questions = batch_size
    prompt = self.prompts.build_exam_generation_prompt(batch_request, context)

    raw_text = llm._call(prompt)
    if not raw_text:
        raise RuntimeError("LLM generation failed")

    cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()
    try:
        exam_dict = self._extract_json(cleaned)
    except json.JSONDecodeError:
        self.logger.error("Invalid JSON from LLM:\n%s", raw_text)
        raise

    exam_dict = self.normalize_exam_dict(exam_dict)

    # "questions" may be missing, null or a non-list; treat all of those as
    # an empty batch instead of failing later while logging the count.
    received = exam_dict.get("questions")
    if not isinstance(received, list):
        received = []

    # Repair incomplete MCQs or missing fields.
    repaired_questions = []
    for q in received[:batch_size]:
        if not isinstance(q, dict):
            continue  # skip invalid entries
        if q.get("type") == "mcq":
            if not q.get("options"):
                self.logger.warning(f"Skipping MCQ with no options: {q}")
                continue
            if not q.get("correct_answer"):
                q["correct_answer"] = q["options"][0]  # safe placeholder
        repaired_questions.append(q)

    # Convert to Pydantic QuestionUnion objects.
    questions = TypeAdapter(List[QuestionUnion]).validate_python(repaired_questions)

    self.logger.info(
        "Batch requested=%d | received=%d | kept=%d",
        batch_size,
        len(received),
        len(questions),
    )
    return questions
def evaluate_exam(self, request: ExamGenerationRequest, exam: ExamResponse, llm):
    """
    Ask the LLM to evaluate ``exam`` against ``request`` and return the parsed result.

    Raises ``RuntimeError`` when the LLM returns nothing; JSON decoding errors
    are logged with the raw response and re-raised.
    """
    evaluation_prompt = self.prompts.build_exam_evaluation_prompt(request, exam)
    response_text = llm._call(evaluation_prompt)
    if not response_text:
        raise RuntimeError("Evaluation generation failed")

    # Drop Markdown code fences before looking for the JSON payload.
    without_fences = re.sub(r"```[a-zA-Z]*|```", "", response_text).strip()
    try:
        parsed = self._extract_json(without_fences)
    except json.JSONDecodeError:
        self.logger.error("Invalid evaluation JSON:\n%s", response_text)
        raise
    else:
        return EvaluationResult.model_validate(parsed)
def split_chunks_by_topic_batches(self, exam_chunks, num_batches):
    """
    Spread retrieved chunks across ``num_batches`` batches, round-robin.

    Chunks are dealt out topic by topic, and the round-robin position carries
    over from one topic to the next. Restarting at batch 0 for every topic
    would pile small topics into the first batches and could leave later
    batches without any context.

    Args:
        exam_chunks: Mapping ``{topic: [chunk, ...]}``.
        num_batches: Number of batches to produce; values below 1 are
            treated as 1 so that chunks are never dropped.

    Returns:
        A list of ``num_batches`` lists of chunks (some may be empty when
        there are fewer chunks than batches).
    """
    num_batches = max(1, num_batches)

    self.logger.info(f"Topics retrieved: {list(exam_chunks.keys())}")
    self.logger.info(f"Number of batches: {num_batches}")

    batches = [[] for _ in range(num_batches)]
    next_slot = 0  # running round-robin position shared by all topics
    for topic, chunks in exam_chunks.items():
        total_chunks = len(chunks)
        self.logger.info(f"Topic '{topic}' -> {total_chunks} chunks distributed across batches")
        for chunk in chunks:
            batches[next_slot % num_batches].append(chunk)
            next_slot += 1

    # Log batch composition.
    for i, batch in enumerate(batches):
        topic_counter = defaultdict(int)
        for chunk in batch:
            topic = chunk.get("metadata", {}).get("topic", "unknown")
            topic_counter[topic] += 1
        self.logger.info(f"Batch {i+1} contains {len(batch)} chunks -> {dict(topic_counter)}")

    return batches
def exam_task(self, request_dict: dict) -> ExamResponse:
    """
    Generate a full exam in batches, evaluate it, and retry with feedback.

    Chunks for the requested topics are retrieved once and split into
    batches. Each attempt then generates questions batch by batch until the
    requested number is reached, keeping only questions whose type still has
    open slots in the requested distribution. The assembled exam is
    evaluated; the best-scoring exam across attempts is returned, and the
    loop stops early once the score reaches ``PASS_THRESHOLD``. Evaluator
    feedback from a failed attempt is appended to the context of the next.

    Args:
        request_dict: Raw request, validated as ``ExamGenerationRequest``.

    Returns:
        The best ``ExamResponse`` produced within ``MAX_GENERATION_ATTEMPTS``.

    Raises:
        RuntimeError: If no exam was produced at all.
        Exception: Validation, generation and evaluation errors propagate.
    """
    # Consecutive batches that may add no question before the batch loop
    # gives up; without this an unhelpful LLM would be called forever.
    max_stalled_batches = 2

    request = ExamGenerationRequest.model_validate(request_dict)

    # Prepare context from knowledge store.
    topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics)
    exam_chunks = self.store.retrieve_for_exam(
        topics_with_embeddings, request.username, request.course, request.references
    )

    # Determine number of batches.
    num_batches = math.ceil(request.total_questions / self.BATCH_SIZE)
    self.logger.info(f"Raw exam_chunks structure: {type(exam_chunks)}")
    for k, v in exam_chunks.items():
        self.logger.info(f"Topic={k} | type={type(v)} | len={len(v) if hasattr(v,'__len__') else 'NA'}")
    chunk_batches = self.split_chunks_by_topic_batches(exam_chunks, num_batches)

    feedback_context = ""
    best_exam = None
    best_score = None  # None until the first exam has been evaluated

    for attempt in range(self.MAX_GENERATION_ATTEMPTS):
        self.logger.info(f"Generating exam attempt {attempt+1}")
        remaining_distribution: Dict[QuestionType, int] = dict(request.question_types_distribution)
        all_questions: List[QuestionUnion] = []
        batch_index = 0
        stalled_batches = 0

        # Batch generation loop.
        while len(all_questions) < request.total_questions:
            remaining = request.total_questions - len(all_questions)
            batch_size = min(self.BATCH_SIZE, remaining)

            # Determine batch distribution from the still-open slots.
            batch_distribution: Dict[QuestionType, int] = {}
            slots_left = batch_size
            for qtype, count in remaining_distribution.items():
                if slots_left <= 0:
                    break
                take = min(count, slots_left)
                if take > 0:
                    batch_distribution[qtype] = take
                    slots_left -= take
            if not batch_distribution:
                break

            batch_request = request.model_copy()
            batch_request.total_questions = sum(batch_distribution.values())
            batch_request.question_types_distribution = batch_distribution

            # Select chunk subset for this batch (cycling through the batches).
            chunk_subset = chunk_batches[batch_index % len(chunk_batches)] if chunk_batches else []
            self.logger.info(f"\n===== BATCH {batch_index+1} CHUNKS =====")
            for i, chunk in enumerate(chunk_subset):
                meta = chunk.get("metadata", {})
                topic = meta.get("topic", "unknown")
                page = meta.get("page", "NA")
                # Try common text keys, then the payload text that
                # build_exam_context actually feeds to the LLM.
                text = (
                    chunk.get("text")
                    or chunk.get("content")
                    or chunk.get("page_content")
                    or (chunk.get("payload") or {}).get("text")
                    or ""
                )
                # str() keeps a non-string payload from crashing a log line.
                preview = str(text)[:200].replace("\n", " ")
                self.logger.info(
                    f"Chunk {i+1} | Topic={topic} | Page={page} | Preview={preview}"
                )
            self.logger.info("=====================================\n")
            batch_index += 1

            batch_context = self.build_exam_context(chunk_subset)
            if feedback_context:
                batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}"

            # Generate questions.
            batch_questions = self.generate_exam(
                batch_request, batch_context, self.llm, batch_request.total_questions
            )

            # Keep only questions whose type still has open slots.
            accepted_before = len(all_questions)
            for q in batch_questions:
                if remaining_distribution.get(q.type, 0) > 0:
                    all_questions.append(q)
                    remaining_distribution[q.type] -= 1
                if len(all_questions) >= request.total_questions:
                    break

            # Safety break: stop once batches repeatedly add nothing.
            if len(all_questions) == accepted_before:
                stalled_batches += 1
                if stalled_batches >= max_stalled_batches:
                    self.logger.warning(
                        "No usable questions in %d consecutive batches; stopping with %d of %d questions",
                        stalled_batches,
                        len(all_questions),
                        request.total_questions,
                    )
                    break
            else:
                stalled_batches = 0

        # Build final exam.
        exam_dict = {
            "exam_id": request.exam_id,
            "difficulty": request.difficulty,
            "total_questions": request.total_questions,
            "expected_distribution": request.question_types_distribution,
            "questions": all_questions[:request.total_questions],
        }
        try:
            exam = ExamResponse.model_validate(exam_dict)
        except Exception as e:
            self.logger.error(f"Exam validation failed: {e}")
            raise

        evaluation = self.evaluate_exam(request, exam, self.llm)
        self.logger.info(f"Evaluation score: {evaluation.overall_score}")

        # Always keep the first evaluated exam, even when it scores 0.
        if best_score is None or evaluation.overall_score > best_score:
            best_score = evaluation.overall_score
            best_exam = exam
        if evaluation.overall_score >= self.PASS_THRESHOLD:
            break
        feedback_context = evaluation.feedback

    if best_exam is None:
        raise RuntimeError("Exam generation failed after retries")
    return best_exam
def build_exam_context(self, exam_chunks) -> str:
    """
    Render chunks into the context string handed to the LLM.

    ``exam_chunks`` is either ``{topic: [chunks]}`` or a flat ``[chunks]``
    list; a flat list is first grouped by ``metadata.topic`` ("Unknown" when
    absent). Each topic contributes a header followed by its chunks, each
    chunk's text capped at ``MAX_CHUNK_CHARS``. A chunk that would push the
    running length past ``MAX_TOTAL_CONTEXT`` ends its topic; a header that
    would do so ends the whole context. Non-string chunk text is skipped.
    """
    # Flat list -> {topic: [chunks]}, preserving first-seen topic order.
    if isinstance(exam_chunks, list):
        grouped = {}
        for item in exam_chunks:
            key = item.get("metadata", {}).get("topic", "Unknown")
            grouped.setdefault(key, []).append(item)
        exam_chunks = grouped

    budget = self.MAX_TOTAL_CONTEXT
    chunk_cap = self.MAX_CHUNK_CHARS
    pieces = []
    used = 0

    for topic, members in exam_chunks.items():
        header = f"\n### Topic: {topic}\n"
        if used + len(header) > budget:
            break
        pieces.append(header)
        used += len(header)

        for item in members:
            text = item.get("payload", {}).get("text", "")
            meta = item.get("metadata", {})
            source = meta.get("source", "")
            bookmark = meta.get("bookmark_path", "")
            if not isinstance(text, str):
                continue
            text = text[:chunk_cap]
            rendered = f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n"
            if used + len(rendered) > budget:
                break  # only abandons the current topic
            pieces.append(rendered)
            used += len(rendered)

    return "\n".join(pieces)
def prepare_topics_with_embeddings(self, topics: List[str]):
    """
    Embed each topic and return ``(topic, embedding)`` pairs in input order.

    A topic whose embedding call raises is logged as a warning and left out
    of the result, so one failing topic does not abort the others.
    """
    results = []
    for topic in topics:
        try:
            vector = self.embedding_provider.embed_text(topic)
        except Exception as e:
            self.logger.warning(f"Embedding failed for topic '{topic}': {e}")
        else:
            results.append((topic, vector))

    self.logger.info(f"Prepared {len(results)} topic embeddings")
    return results