"""Grounded learning features: summarization, quiz, and flashcard generation."""
from __future__ import annotations
import json
from loguru import logger
from pydantic import ValidationError
from src.config import settings
from src.llm import invoke_llm
from src.rag import fetch_all_chunks, format_citations, render_prompt, retrieve
from src.schemas import Flashcard, FlashcardSet, QuizItem, QuizSet, RetrievedChunk, Summary
SUMMARY_SINGLE_TEMPLATE = "summary_single.jinja2"
SUMMARY_MAP_TEMPLATE = "summary_map.jinja2"
SUMMARY_REDUCE_TEMPLATE = "summary_reduce.jinja2"
QUIZ_TEMPLATE = "quiz.jinja2"
FLASHCARDS_TEMPLATE = "flashcards.jinja2"
def _parse_json(text: str) -> dict | list:
"""Parse JSON object/array from model output, allowing optional markdown code fences."""
cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the opening fence line (e.g. "```json") and any trailing fence.
        cleaned = cleaned.split("\n", 1)[-1].removesuffix("```").strip()
try:
obj = json.loads(cleaned)
except json.JSONDecodeError as e:
raise RuntimeError(f"Invalid JSON from model output: {cleaned}") from e
if not isinstance(obj, (dict, list)):
raise RuntimeError(f"Expected JSON object or array, got {type(obj).__name__}.")
return obj
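# Illustrative examples (not part of the module API): fenced and unfenced
# model output parse identically.
#
#     _parse_json('{"summary": "ok"}')                -> {"summary": "ok"}
#     _parse_json('```json\n{"summary": "ok"}\n```')  -> {"summary": "ok"}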
def _resolve_target(
document: str | None,
query: str | None,
filters: dict[str, object] | None,
k: int | None,
retrieval_k: int,
) -> tuple[list[RetrievedChunk], str, str | None]:
"""Resolve input options into (chunks, scope, target_label)."""
effective_filters: dict[str, object] = dict(filters or {})
if document:
effective_filters["filename"] = document
if query:
chunks = retrieve(query, k=k or retrieval_k, filters=effective_filters)
target: str | None = query
scope = "query"
elif effective_filters:
chunks = fetch_all_chunks(filters=effective_filters)
target = ", ".join(f"{fk}={fv}" for fk, fv in effective_filters.items())
scope = "document" if document else "filter"
else:
chunks = fetch_all_chunks(filters=None)
target = None
scope = "corpus"
return chunks, scope, target
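# Scope resolution at a glance (illustrative; retrieval_k=8 is an arbitrary
# example value):
#
#     _resolve_target(document="a.pdf", query=None, filters=None, k=None, retrieval_k=8)
#         -> (all chunks where filename == "a.pdf", "document", "filename=a.pdf")
#     _resolve_target(document=None, query="transformers", filters=None, k=None, retrieval_k=8)
#         -> (top-8 retrieved chunks, "query", "transformers")
#     _resolve_target(document=None, query=None, filters=None, k=None, retrieval_k=8)
#         -> (every chunk in the corpus, "corpus", None)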
def _validate_items(
payload: object,
key: str,
model_class: type,
dedup_field: str,
label: str,
valid_markers: set[str],
) -> list:
    """Validate items from a model payload, dropping invalid or duplicate entries and unknown source markers."""
if not isinstance(payload, dict):
raise RuntimeError(f"Expected JSON object for {label}.")
raw_items = payload.get(key)
if not isinstance(raw_items, list):
raise RuntimeError(f"Expected '{key}' to be a list for {label}.")
items: list = []
seen: set[str] = set()
for raw in raw_items:
if not isinstance(raw, dict):
continue
try:
item = model_class.model_validate(raw)
except ValidationError as e:
logger.warning("Dropping invalid {}: {}", label, e)
continue
norm = str(getattr(item, dedup_field, "")).strip().lower()
if not norm or norm in seen:
continue
seen.add(norm)
markers = [m for m in item.source_markers if m in valid_markers]
items.append(item.model_copy(update={"source_markers": markers}))
if not items:
raise RuntimeError(f"No valid {label} produced.")
return items
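# Expected payload shape (illustrative; fields beyond the dedup field and
# source_markers depend on the Pydantic model in src.schemas):
#
#     {"items": [{"question": "...", "source_markers": ["S1", "S3"], ...}]}
#
# Invalid entries are logged and dropped, duplicates are detected
# case-insensitively on dedup_field, and unknown markers are filtered out.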
def _validate_summary_payload(payload: object) -> tuple[str, list[str]]:
    """Validate a summary payload and return (summary_text, key_points)."""
if not isinstance(payload, dict):
raise RuntimeError("Expected a JSON object for summary.")
summary = payload.get("summary")
key_points = payload.get("key_points", [])
if not isinstance(summary, str):
raise RuntimeError("Summary payload missing 'summary' string.")
if not isinstance(key_points, list) or not all(isinstance(x, str) for x in key_points):
raise RuntimeError("Summary payload 'key_points' must be a list of strings.")
return summary.strip(), [kp.strip() for kp in key_points if kp.strip()]
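# Expected payload shape (illustrative):
#
#     {"summary": "One-paragraph overview ...", "key_points": ["point 1", "point 2"]}
#
# Whitespace is trimmed and empty key points are dropped.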
def summarize(
document: str | None = None,
query: str | None = None,
filters: dict[str, object] | None = None,
k: int | None = None,
) -> Summary:
"""Grounded summary; uses map-reduce when chunk count exceeds batch size."""
chunks, scope, target = _resolve_target(
document=document,
query=query,
filters=filters,
k=k,
retrieval_k=settings.summarize_retrieval_k,
)
if not chunks:
raise RuntimeError("No chunks available for summarization.")
batch_size = settings.summarize_batch_size
if len(chunks) <= batch_size:
prompt = render_prompt(SUMMARY_SINGLE_TEMPLATE, chunks=chunks)
payload = _parse_json(invoke_llm(prompt))
summary_text, key_points = _validate_summary_payload(payload)
else:
        n_batches = (len(chunks) + batch_size - 1) // batch_size  # ceiling division
partials: list[dict] = []
for batch_index, start in enumerate(range(0, len(chunks), batch_size), start=1):
logger.info("Summarizing batch {}/{}", batch_index, n_batches)
batch = chunks[start : start + batch_size]
prompt = render_prompt(SUMMARY_MAP_TEMPLATE, chunks=batch)
payload = _parse_json(invoke_llm(prompt))
summary_text, key_points = _validate_summary_payload(payload)
partials.append({"summary": summary_text, "key_points": key_points})
reduce_prompt = render_prompt(SUMMARY_REDUCE_TEMPLATE, partials=partials)
payload = _parse_json(invoke_llm(reduce_prompt))
summary_text, key_points = _validate_summary_payload(payload)
return Summary(
scope=scope,
target=target,
summary=summary_text,
key_points=key_points,
citations=format_citations(chunks),
)
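# Usage sketch ("lecture_notes.pdf" is a hypothetical filename; assumes an
# indexed corpus and a configured LLM backend):
#
#     summary = summarize(document="lecture_notes.pdf")
#     print(summary.summary)
#     for point in summary.key_points:
#         print("-", point)
#
# When the document yields more chunks than settings.summarize_batch_size,
# the map-reduce path runs automatically.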
def generate_quiz(
document: str | None = None,
query: str | None = None,
filters: dict[str, object] | None = None,
count: int | None = None,
k: int | None = None,
) -> QuizSet:
"""Grounded multiple-choice quiz; raises RuntimeError if output is unparseable."""
chunks, scope, target = _resolve_target(
document=document,
query=query,
filters=filters,
k=k,
retrieval_k=settings.generation_retrieval_k,
)
if not chunks:
raise RuntimeError("No chunks available for quiz generation.")
n = count or settings.quiz_default_count
    # Chunks are cited in the prompt as S1..Sn, in render order.
    valid_markers = {f"S{i}" for i in range(1, len(chunks) + 1)}
prompt = render_prompt(QUIZ_TEMPLATE, chunks=chunks, count=n)
payload = _parse_json(invoke_llm(prompt))
items = _validate_items(payload, "items", QuizItem, "question", "quiz items", valid_markers)
return QuizSet(
scope=scope,
target=target,
items=items,
citations=format_citations(chunks),
)
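# Usage sketch (hypothetical query; requires an indexed corpus):
#
#     quiz = generate_quiz(query="backpropagation", count=5)
#     for item in quiz.items:
#         print(item.question, item.source_markers)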
def generate_flashcards(
document: str | None = None,
query: str | None = None,
filters: dict[str, object] | None = None,
count: int | None = None,
k: int | None = None,
) -> FlashcardSet:
"""Grounded flashcard set for spaced repetition; raises RuntimeError if output is unparseable."""
chunks, scope, target = _resolve_target(
document=document,
query=query,
filters=filters,
k=k,
retrieval_k=settings.generation_retrieval_k,
)
if not chunks:
raise RuntimeError("No chunks available for flashcard generation.")
n = count or settings.flashcards_default_count
valid_markers = {f"S{i}" for i in range(1, len(chunks) + 1)}
prompt = render_prompt(FLASHCARDS_TEMPLATE, chunks=chunks, count=n)
payload = _parse_json(invoke_llm(prompt))
cards = _validate_items(payload, "cards", Flashcard, "front", "flashcards", valid_markers)
return FlashcardSet(
scope=scope,
target=target,
cards=cards,
citations=format_citations(chunks),
)
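# Usage sketch ("chapter_3.pdf" is a hypothetical filename):
#
#     cards = generate_flashcards(document="chapter_3.pdf", count=10)
#     for card in cards.cards:
#         print(card.front)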