Spaces:
Sleeping
Sleeping
| """ | |
| Quiz generator using RAG context from ingested documents. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| import requests | |
| from src.ingestion.vectorstore import ChromaAdapter | |
| load_dotenv() | |
| SUPPORTED_QUIZ_LLM_PROVIDERS = {"openai", "groq", "ollama"} | |
| DEFAULT_QUIZ_MODELS = { | |
| "openai": "gpt-4o-mini", | |
| "groq": "llama-3.1-8b-instant", | |
| "ollama": "qwen2.5:3b", | |
| } | |
| QUIZ_SYSTEM_PROMPT = ( | |
| "You are an expert quiz generator. Create clear, educational, and well-structured " | |
| "multiple-choice questions. Return valid JSON only with a top-level 'questions' array." | |
| ) | |
| class QuizGenerator: | |
| def __init__( | |
| self, | |
| api_key: Optional[str] = None, | |
| model: Optional[str] = None, | |
| llm_provider: Optional[str] = None, | |
| ): | |
| """ | |
| Initialize quiz generator. | |
| Args: | |
| api_key: OpenAI API key (defaults to OPENAI_API_KEY from .env) | |
| model: LLM model to use (defaults to LLM_MODEL from .env) | |
| llm_provider: Quiz LLM provider (openai, groq, ollama) | |
| """ | |
| self.llm_provider = (llm_provider or os.getenv("QUIZ_LLM_PROVIDER", "openai")).strip().lower() | |
| if self.llm_provider not in SUPPORTED_QUIZ_LLM_PROVIDERS: | |
| raise ValueError( | |
| f"Unsupported QUIZ_LLM_PROVIDER='{self.llm_provider}'. " | |
| f"Choose from: {sorted(SUPPORTED_QUIZ_LLM_PROVIDERS)}" | |
| ) | |
| self.model = self._resolve_model_name(model) | |
| self._openai_client: OpenAI | None = None | |
| self._groq_client = None | |
| self._ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/") | |
| if self.llm_provider == "openai": | |
| self.api_key = api_key or os.getenv("OPENAI_API_KEY") | |
| self._openai_client = OpenAI(api_key=self.api_key) | |
| elif self.llm_provider == "groq": | |
| from groq import Groq | |
| groq_api_key = os.getenv("GROQ_API_KEY") | |
| if not groq_api_key: | |
| raise ValueError("GROQ_API_KEY is required when QUIZ_LLM_PROVIDER=groq") | |
| self._groq_client = Groq(api_key=groq_api_key) | |
| else: | |
| self.api_key = None | |
| # Default settings from .env | |
| self.default_num_questions = int(os.getenv("DEFAULT_QUIZ_QUESTIONS", "5")) | |
| self.default_difficulty = os.getenv("DEFAULT_QUIZ_DIFFICULTY", "medium") | |
| def _resolve_model_name(self, explicit_model: Optional[str]) -> str: | |
| if explicit_model and explicit_model.strip(): | |
| return explicit_model.strip() | |
| configured = os.getenv("QUIZ_LLM_MODEL", "").strip() | |
| if configured: | |
| return configured | |
| legacy = os.getenv("LLM_MODEL", "").strip() | |
| if legacy: | |
| return legacy | |
| return DEFAULT_QUIZ_MODELS.get(self.llm_provider, "gpt-4o-mini") | |
| def generate_quiz( | |
| self, | |
| user_id: str, | |
| notebook_id: str, | |
| num_questions: Optional[int] = None, | |
| difficulty: Optional[str] = None, | |
| topic_focus: Optional[str] = None, | |
| ) -> Dict[str, Any]: | |
| """ | |
| Generate quiz questions from notebook content. | |
| Args: | |
| user_id: User identifier | |
| notebook_id: Notebook to generate quiz from | |
| num_questions: Number of questions (defaults to DEFAULT_QUIZ_QUESTIONS) | |
| difficulty: "easy", "medium", or "hard" (defaults to DEFAULT_QUIZ_DIFFICULTY) | |
| topic_focus: Optional specific topic to focus on | |
| Returns: | |
| Dict with questions, answers, and metadata | |
| """ | |
| num_questions = num_questions or self.default_num_questions | |
| difficulty = difficulty or self.default_difficulty | |
| print(f"🎯 Generating {num_questions} {difficulty} quiz questions...") | |
| # 1. Retrieve context from vector store | |
| context = self._get_quiz_context(user_id, notebook_id, topic_focus) | |
| if not context: | |
| return { | |
| "error": "No content found in notebook. Please ingest documents first.", | |
| "questions": [], | |
| "metadata": {}, | |
| } | |
| # 2. Generate quiz using LLM | |
| print("🤖 Generating questions with LLM...") | |
| quiz_data = self._generate_with_llm(context, num_questions, difficulty) | |
| questions = quiz_data.get("questions", []) if isinstance(quiz_data, dict) else [] | |
| if not questions: | |
| return { | |
| "error": "Failed to generate quiz questions from notebook context.", | |
| "questions": [], | |
| "metadata": { | |
| "notebook_id": notebook_id, | |
| "num_questions": num_questions, | |
| "difficulty": difficulty, | |
| "topic_focus": topic_focus, | |
| "llm_provider": self.llm_provider, | |
| "llm_model": self.model, | |
| "generated_at": datetime.now(timezone.utc).isoformat(), | |
| }, | |
| } | |
| # 3. Format and return | |
| return { | |
| "questions": questions, | |
| "metadata": { | |
| "notebook_id": notebook_id, | |
| "num_questions": num_questions, | |
| "difficulty": difficulty, | |
| "topic_focus": topic_focus, | |
| "llm_provider": self.llm_provider, | |
| "llm_model": self.model, | |
| "generated_at": datetime.now(timezone.utc).isoformat(), | |
| }, | |
| } | |
| def _get_quiz_context( | |
| self, | |
| user_id: str, | |
| notebook_id: str, | |
| topic_focus: Optional[str] = None, | |
| ) -> str: | |
| """Retrieve relevant context from notebook.""" | |
| data_base = os.getenv("STORAGE_BASE_DIR", "data") | |
| chroma_dir = str( | |
| Path(data_base) / "users" / user_id / "notebooks" / notebook_id / "chroma" | |
| ) | |
| if not Path(chroma_dir).exists(): | |
| print(f"⚠️ Chroma directory not found: {chroma_dir}") | |
| return "" | |
| store = ChromaAdapter(persist_directory=chroma_dir) | |
| # Get diverse chunks for quiz generation | |
| if topic_focus: | |
| sample_queries = [topic_focus] | |
| else: | |
| sample_queries = [ | |
| "main concepts and definitions", | |
| "key principles and theories", | |
| "important facts and details", | |
| "examples and applications", | |
| "processes and mechanisms", | |
| ] | |
| all_chunks: List[str] = [] | |
| for query in sample_queries: | |
| try: | |
| results = store.query(user_id, notebook_id, query, top_k=3) | |
| for _, _, chunk_data in results: | |
| all_chunks.append(chunk_data["document"]) | |
| except Exception as e: | |
| print(f"⚠️ Error querying: {e}") | |
| continue | |
| if not all_chunks: | |
| return "" | |
| # Deduplicate and combine | |
| unique_chunks = list(set(all_chunks)) | |
| context = "\n\n".join(unique_chunks[:10]) # Use top 10 unique chunks | |
| print(f"✓ Retrieved {len(unique_chunks)} unique chunks ({len(context)} chars)") | |
| return context | |
| def _generate_with_llm( | |
| self, | |
| context: str, | |
| num_questions: int, | |
| difficulty: str, | |
| ) -> Dict[str, Any]: | |
| """Generate quiz using LLM.""" | |
| prompt = self._build_quiz_prompt(context, num_questions, difficulty) | |
| try: | |
| raw_response = self._generate_quiz_json(prompt) | |
| payload = self._extract_json_object(raw_response) | |
| questions = payload.get("questions") if isinstance(payload, dict) else None | |
| if not isinstance(questions, list): | |
| return {"questions": []} | |
| return { | |
| "questions": self._normalize_questions( | |
| questions=questions, | |
| expected_count=num_questions, | |
| difficulty=difficulty, | |
| ) | |
| } | |
| except Exception as e: | |
| print(f"❌ Error generating quiz: {e}") | |
| return {"questions": []} | |
| def _generate_quiz_json(self, prompt: str) -> str: | |
| if self.llm_provider == "openai": | |
| assert self._openai_client is not None | |
| response = self._openai_client.chat.completions.create( | |
| model=self.model, | |
| messages=[ | |
| {"role": "system", "content": QUIZ_SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.7, | |
| response_format={"type": "json_object"}, | |
| ) | |
| return str(response.choices[0].message.content or "") | |
| if self.llm_provider == "groq": | |
| assert self._groq_client is not None | |
| response = self._groq_client.chat.completions.create( | |
| model=self.model, | |
| messages=[ | |
| {"role": "system", "content": QUIZ_SYSTEM_PROMPT}, | |
| {"role": "user", "content": prompt}, | |
| ], | |
| temperature=0.7, | |
| ) | |
| return str(response.choices[0].message.content or "") | |
| payload = { | |
| "model": self.model, | |
| "system": QUIZ_SYSTEM_PROMPT, | |
| "prompt": prompt, | |
| "stream": False, | |
| "options": {"temperature": 0.7}, | |
| } | |
| response = requests.post( | |
| f"{self._ollama_base_url}/api/generate", | |
| json=payload, | |
| timeout=120, | |
| ) | |
| response.raise_for_status() | |
| body = response.json() | |
| return str(body.get("response", "")) | |
| def _extract_json_object(self, raw_response: str) -> Dict[str, Any]: | |
| content = str(raw_response or "").strip() | |
| if not content: | |
| return {} | |
| if content.startswith("```"): | |
| content = re.sub(r"^```(?:json)?\s*", "", content) | |
| content = re.sub(r"\s*```$", "", content) | |
| try: | |
| parsed = json.loads(content) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except json.JSONDecodeError: | |
| pass | |
| match = re.search(r"\{.*\}", content, re.DOTALL) | |
| if match: | |
| try: | |
| parsed = json.loads(match.group(0)) | |
| if isinstance(parsed, dict): | |
| return parsed | |
| except json.JSONDecodeError: | |
| return {} | |
| return {} | |
| def _normalize_questions( | |
| self, | |
| questions: List[Any], | |
| expected_count: int, | |
| difficulty: str, | |
| ) -> List[Dict[str, Any]]: | |
| normalized: List[Dict[str, Any]] = [] | |
| for item in questions: | |
| if not isinstance(item, dict): | |
| continue | |
| prompt = str(item.get("question", "")).strip() | |
| if not prompt: | |
| continue | |
| options = self._normalize_options(item.get("options")) | |
| if not options: | |
| continue | |
| answer = self._normalize_answer_letter(str(item.get("correct_answer", "")).strip()) | |
| explanation = str(item.get("explanation", "")).strip() | |
| topic = str(item.get("topic", "")).strip() | |
| normalized.append( | |
| { | |
| "id": len(normalized) + 1, | |
| "question": prompt, | |
| "options": options, | |
| "correct_answer": answer or "N/A", | |
| "explanation": explanation, | |
| "difficulty": difficulty, | |
| "topic": topic, | |
| } | |
| ) | |
| if len(normalized) >= expected_count: | |
| break | |
| return normalized | |
| def _normalize_options(self, options_raw: Any) -> List[str]: | |
| """ | |
| Normalize options into 2-6 labeled choices: A) ... B) ... etc. | |
| Accepts list, dict, or multiline string. | |
| """ | |
| parsed: List[str] = [] | |
| if isinstance(options_raw, list): | |
| parsed = [str(opt).strip() for opt in options_raw if str(opt).strip()] | |
| elif isinstance(options_raw, dict): | |
| for key in sorted(options_raw.keys()): | |
| value = str(options_raw.get(key, "")).strip() | |
| if not value: | |
| continue | |
| parsed.append(f"{str(key).strip().upper()}) {value}") | |
| elif isinstance(options_raw, str): | |
| lines = [line.strip() for line in options_raw.replace("\r", "\n").split("\n")] | |
| parsed = [line for line in lines if line] | |
| if not parsed: | |
| return [] | |
| cleaned: List[str] = [] | |
| letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |
| for idx, text in enumerate(parsed): | |
| body = self._option_text_only(text) | |
| if not body: | |
| continue | |
| label = letters[idx] if idx < len(letters) else str(idx + 1) | |
| cleaned.append(f"{label}) {body}") | |
| # Keep a reasonable range and ensure common quiz shape is preserved. | |
| return cleaned[:6] | |
| def _option_text_only(self, value: str) -> str: | |
| text = str(value or "").strip() | |
| # Strip prefixes like "A)", "A.", "(A)", "1)", "1." | |
| text = re.sub(r"^\(?[A-Z0-9]\)?[\.\):\-]\s*", "", text, flags=re.IGNORECASE) | |
| return text.strip() | |
| def _normalize_answer_letter(self, value: str) -> str: | |
| match = re.search(r"[A-Z]", value.upper()) | |
| return match.group(0) if match else "" | |
| def _build_quiz_prompt(self, context: str, num_questions: int, difficulty: str) -> str: | |
| """Build the quiz generation prompt.""" | |
| difficulty_guidelines = { | |
| "easy": "Focus on basic recall and understanding. Questions should test fundamental concepts.", | |
| "medium": "Mix recall with application. Questions should require understanding and basic analysis.", | |
| "hard": "Focus on analysis, synthesis, and application. Questions should require deep understanding.", | |
| } | |
| guideline = difficulty_guidelines.get(difficulty, difficulty_guidelines["medium"]) | |
| return f""" | |
| Based on the following content, generate {num_questions} multiple-choice quiz questions at {difficulty} difficulty level. | |
| Content: | |
| {context} | |
| Difficulty Guidelines: | |
| {guideline} | |
| Generate questions in this exact JSON format: | |
| {{ | |
| "questions": [ | |
| {{ | |
| "id": 1, | |
| "question": "Clear question text here?", | |
| "options": [ | |
| "A) First option", | |
| "B) Second option", | |
| "C) Third option", | |
| "D) Fourth option" | |
| ], | |
| "correct_answer": "A", | |
| "explanation": "Clear explanation of why this answer is correct and why others are wrong.", | |
| "difficulty": "{difficulty}", | |
| "topic": "Main topic this question covers" | |
| }} | |
| ] | |
| }} | |
| Requirements: | |
| - Generate exactly {num_questions} questions | |
| - All options must be plausible and relevant | |
| - Correct answer must be clearly identifiable | |
| - Explanations should be educational and thorough | |
| - Questions should test {difficulty}-level understanding | |
| - Vary question types (definition, application, analysis) | |
| - Ensure questions are based solely on the provided content | |
| """ | |
| def format_quiz_markdown(self, quiz_data: Dict[str, Any], title: str | None = None) -> str: | |
| """Render quiz questions and answer key as Markdown.""" | |
| resolved_title = title or "Quiz" | |
| questions = quiz_data.get("questions", []) | |
| metadata = quiz_data.get("metadata", {}) if isinstance(quiz_data.get("metadata"), dict) else {} | |
| difficulty = metadata.get("difficulty") | |
| lines: list[str] = [f"# {resolved_title}", ""] | |
| if difficulty: | |
| lines.append(f"Difficulty: **{difficulty}**") | |
| lines.append("") | |
| lines.append("## Questions") | |
| lines.append("") | |
| for idx, question in enumerate(questions, 1): | |
| prompt = str(question.get("question", "")).strip() | |
| lines.append(f"### {idx}. {prompt or 'Question'}") | |
| options = question.get("options", []) | |
| for option in options if isinstance(options, list) else []: | |
| option_text = str(option).strip() | |
| if option_text: | |
| lines.append(f"- {option_text}") | |
| lines.append("") | |
| lines.append("## Answer Key") | |
| lines.append("") | |
| for idx, question in enumerate(questions, 1): | |
| correct = str(question.get("correct_answer", "")).strip() or "N/A" | |
| explanation = str(question.get("explanation", "")).strip() | |
| lines.append(f"{idx}. **{correct}**") | |
| if explanation: | |
| lines.append(f" - {explanation}") | |
| lines.append("") | |
| return "\n".join(lines) | |
| def save_quiz(self, quiz_markdown: str, user_id: str, notebook_id: str) -> str: | |
| """Save generated quiz Markdown to file.""" | |
| storage_base = os.getenv("STORAGE_BASE_DIR", "data") | |
| quiz_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "quizzes" | |
| quiz_dir.mkdir(parents=True, exist_ok=True) | |
| timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") | |
| filename = f"quiz_{timestamp}.md" | |
| filepath = quiz_dir / filename | |
| filepath.write_text(quiz_markdown, encoding="utf-8") | |
| print(f"✓ Quiz saved to: {filepath}") | |
| return str(filepath) | |
| # === CLI for testing === | |
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser(description="Generate quiz from notebook") | |
| parser.add_argument("--user", required=True, help="User ID") | |
| parser.add_argument("--notebook", required=True, help="Notebook ID") | |
| parser.add_argument("--num-questions", type=int, help="Number of questions") | |
| parser.add_argument("--difficulty", choices=["easy", "medium", "hard"], help="Difficulty level") | |
| parser.add_argument("--topic", help="Focus on specific topic") | |
| parser.add_argument("--save", action="store_true", help="Save quiz to file") | |
| args = parser.parse_args() | |
| generator = QuizGenerator() | |
| quiz = generator.generate_quiz( | |
| args.user, | |
| args.notebook, | |
| args.num_questions, | |
| args.difficulty, | |
| args.topic, | |
| ) | |
| if "error" in quiz: | |
| print(f"\n❌ {quiz['error']}") | |
| else: | |
| print(f"\n✓ Generated {len(quiz['questions'])} questions") | |
| print(json.dumps(quiz, indent=2)) | |
| if args.save: | |
| markdown = generator.format_quiz_markdown(quiz, title="Quiz") | |
| generator.save_quiz(markdown, args.user, args.notebook) | |