| """Grounded learning features: summarization, quiz, and flashcard generation.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
|
|
| from loguru import logger |
| from pydantic import ValidationError |
|
|
| from src.config import settings |
| from src.llm import invoke_llm |
| from src.rag import fetch_all_chunks, format_citations, render_prompt, retrieve |
| from src.schemas import Flashcard, FlashcardSet, QuizItem, QuizSet, RetrievedChunk, Summary |
|
|
| SUMMARY_SINGLE_TEMPLATE = "summary_single.jinja2" |
| SUMMARY_MAP_TEMPLATE = "summary_map.jinja2" |
| SUMMARY_REDUCE_TEMPLATE = "summary_reduce.jinja2" |
| QUIZ_TEMPLATE = "quiz.jinja2" |
| FLASHCARDS_TEMPLATE = "flashcards.jinja2" |
|
|
|
|
def _parse_json(text: str) -> dict | list:
    """Parse a JSON object/array from model output, allowing optional markdown code fences."""
    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (e.g. ```json) and the closing fence.
        cleaned = cleaned.split("\n", 1)[-1].removesuffix("```").strip()

    try:
        obj = json.loads(cleaned)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Invalid JSON from model output: {cleaned}") from e

    if not isinstance(obj, (dict, list)):
        raise RuntimeError(f"Expected JSON object or array, got {type(obj).__name__}.")

    return obj
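# Illustrative only (not executed): _parse_json accepts fenced or bare JSON,
# so both of these hypothetical model outputs yield {"summary": "..."}:
#
#     _parse_json('{"summary": "..."}')
#     _parse_json('```json\n{"summary": "..."}\n```')
#
# Output that is valid JSON but not an object/array (e.g. a bare string)
# raises RuntimeError.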
def _resolve_target(
    document: str | None,
    query: str | None,
    filters: dict[str, object] | None,
    k: int | None,
    retrieval_k: int,
) -> tuple[list[RetrievedChunk], str, str | None]:
    """Resolve the input options into (chunks, scope, target_label)."""
    effective_filters: dict[str, object] = dict(filters or {})
    if document:
        effective_filters["filename"] = document

    if query:
        # Query given: retrieve the top-k chunks, optionally narrowed by filters.
        chunks = retrieve(query, k=k or retrieval_k, filters=effective_filters)
        target: str | None = query
        scope = "query"
    elif effective_filters:
        # No query, but filters: pull every chunk that matches them.
        chunks = fetch_all_chunks(filters=effective_filters)
        target = ", ".join(f"{fk}={fv}" for fk, fv in effective_filters.items())
        scope = "document" if document else "filter"
    else:
        # Neither query nor filters: operate over the whole corpus.
        chunks = fetch_all_chunks(filters=None)
        target = None
        scope = "corpus"

    return chunks, scope, target
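# Illustrative resolution rules (not executed; the filename is hypothetical):
#
#     _resolve_target("notes.pdf", None, None, None, 8)  -> scope "document"
#     _resolve_target(None, "what is RAG?", None, 5, 8)  -> scope "query", k=5
#     _resolve_target(None, None, None, None, 8)         -> scope "corpus"
#
# A document name is folded into the filters as filename=..., so it also
# narrows query-scoped retrieval.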
def _validate_items(
    payload: object,
    key: str,
    model_class: type,
    dedup_field: str,
    label: str,
    valid_markers: set[str],
) -> list:
    """Validate raw items against model_class, dedupe on dedup_field, and filter source markers."""
    if not isinstance(payload, dict):
        raise RuntimeError(f"Expected a JSON object for {label}.")
    raw_items = payload.get(key)
    if not isinstance(raw_items, list):
        raise RuntimeError(f"Expected '{key}' to be a list for {label}.")

    items: list = []
    seen: set[str] = set()
    for raw in raw_items:
        if not isinstance(raw, dict):
            continue
        try:
            item = model_class.model_validate(raw)
        except ValidationError as e:
            logger.warning("Dropping invalid {}: {}", label, e)
            continue
        # Deduplicate on a normalized form of the chosen field (e.g. question text).
        norm = str(getattr(item, dedup_field, "")).strip().lower()
        if not norm or norm in seen:
            continue
        seen.add(norm)
        # Keep only source markers that point at chunks actually shown in the prompt.
        markers = [m for m in item.source_markers if m in valid_markers]
        items.append(item.model_copy(update={"source_markers": markers}))

    if not items:
        raise RuntimeError(f"No valid {label} produced.")
    return items
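# Sketch of the filtering behavior (hypothetical payload, not executed):
# given valid_markers = {"S1", "S2"}, an item with source_markers
# ["S1", "S7"] is kept but trimmed to ["S1"]; a later item whose dedup
# field normalizes to the same string as an earlier one is dropped; and
# RuntimeError is raised only if nothing survives.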
def _validate_summary_payload(payload: object) -> tuple[str, list[str]]:
    if not isinstance(payload, dict):
        raise RuntimeError("Expected a JSON object for summary.")
    summary = payload.get("summary")
    key_points = payload.get("key_points", [])
    if not isinstance(summary, str):
        raise RuntimeError("Summary payload missing 'summary' string.")
    if not isinstance(key_points, list) or not all(isinstance(x, str) for x in key_points):
        raise RuntimeError("Summary payload 'key_points' must be a list of strings.")
    return summary.strip(), [kp.strip() for kp in key_points if kp.strip()]
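# The payload shape this validator accepts (illustrative values):
#
#     {"summary": "...", "key_points": ["point one", "point two"]}
#
# Empty or whitespace-only key points are silently discarded.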
def summarize(
    document: str | None = None,
    query: str | None = None,
    filters: dict[str, object] | None = None,
    k: int | None = None,
) -> Summary:
    """Produce a grounded summary; uses map-reduce when the chunk count exceeds the batch size."""
    chunks, scope, target = _resolve_target(
        document=document,
        query=query,
        filters=filters,
        k=k,
        retrieval_k=settings.summarize_retrieval_k,
    )

    if not chunks:
        raise RuntimeError("No chunks available for summarization.")

    batch_size = settings.summarize_batch_size
    if len(chunks) <= batch_size:
        # Small enough to summarize in a single pass.
        prompt = render_prompt(SUMMARY_SINGLE_TEMPLATE, chunks=chunks)
        payload = _parse_json(invoke_llm(prompt))
        summary_text, key_points = _validate_summary_payload(payload)
    else:
        # Map-reduce: summarize each batch, then merge the partial summaries.
        n_batches = (len(chunks) + batch_size - 1) // batch_size  # ceiling division
        partials: list[dict] = []

        for batch_index, start in enumerate(range(0, len(chunks), batch_size), start=1):
            logger.info("Summarizing batch {}/{}", batch_index, n_batches)
            batch = chunks[start : start + batch_size]
            prompt = render_prompt(SUMMARY_MAP_TEMPLATE, chunks=batch)
            payload = _parse_json(invoke_llm(prompt))
            summary_text, key_points = _validate_summary_payload(payload)
            partials.append({"summary": summary_text, "key_points": key_points})

        reduce_prompt = render_prompt(SUMMARY_REDUCE_TEMPLATE, partials=partials)
        payload = _parse_json(invoke_llm(reduce_prompt))
        summary_text, key_points = _validate_summary_payload(payload)

    return Summary(
        scope=scope,
        target=target,
        summary=summary_text,
        key_points=key_points,
        citations=format_citations(chunks),
    )
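# Batching example (hypothetical settings): with 23 chunks and a batch size
# of 10, the map stage runs ceil(23 / 10) = 3 LLM calls over chunk slices
# [0:10], [10:20], and [20:23], and the reduce stage merges the 3 partials
# in one final call, for 4 LLM invocations in total.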
def generate_quiz(
    document: str | None = None,
    query: str | None = None,
    filters: dict[str, object] | None = None,
    count: int | None = None,
    k: int | None = None,
) -> QuizSet:
    """Generate a grounded multiple-choice quiz; raise RuntimeError if the output is unparseable."""
    chunks, scope, target = _resolve_target(
        document=document,
        query=query,
        filters=filters,
        k=k,
        retrieval_k=settings.generation_retrieval_k,
    )
    if not chunks:
        raise RuntimeError("No chunks available for quiz generation.")

    n = count or settings.quiz_default_count
    # Markers S1..S{len(chunks)} are the only citations accepted from the model.
    valid_markers = {f"S{i}" for i in range(1, len(chunks) + 1)}

    prompt = render_prompt(QUIZ_TEMPLATE, chunks=chunks, count=n)
    payload = _parse_json(invoke_llm(prompt))
    items = _validate_items(payload, "items", QuizItem, "question", "quiz items", valid_markers)

    return QuizSet(
        scope=scope,
        target=target,
        items=items,
        citations=format_citations(chunks),
    )
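# If count is omitted, settings.quiz_default_count is used. The returned
# QuizSet's citations cover every retrieved chunk, while each item's
# source_markers cite only the chunks that support that specific question.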
def generate_flashcards(
    document: str | None = None,
    query: str | None = None,
    filters: dict[str, object] | None = None,
    count: int | None = None,
    k: int | None = None,
) -> FlashcardSet:
    """Generate a grounded flashcard set for spaced repetition; raise RuntimeError if the output is unparseable."""
    chunks, scope, target = _resolve_target(
        document=document,
        query=query,
        filters=filters,
        k=k,
        retrieval_k=settings.generation_retrieval_k,
    )
    if not chunks:
        raise RuntimeError("No chunks available for flashcard generation.")

    n = count or settings.flashcards_default_count
    # Markers S1..S{len(chunks)} are the only citations accepted from the model.
    valid_markers = {f"S{i}" for i in range(1, len(chunks) + 1)}

    prompt = render_prompt(FLASHCARDS_TEMPLATE, chunks=chunks, count=n)
    payload = _parse_json(invoke_llm(prompt))
    cards = _validate_items(payload, "cards", Flashcard, "front", "flashcards", valid_markers)

    return FlashcardSet(
        scope=scope,
        target=target,
        cards=cards,
        citations=format_citations(chunks),
    )
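# A minimal end-to-end sketch, assuming the corpus has already been ingested
# and settings are configured; the document name and query are hypothetical.
if __name__ == "__main__":
    overview = summarize(document="lecture_01.pdf")
    print(overview.summary)
    print(overview.key_points)

    quiz = generate_quiz(document="lecture_01.pdf", count=5)
    for item in quiz.items:
        print(item.question, item.source_markers)

    cards = generate_flashcards(query="key definitions", count=10)
    for card in cards.cards:
        print(card.front)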