Clone_Lm / backend /artifacts.py
skumar54's picture
NotebookLM clone: Gradio app, backend, Gemini artifacts
9c9ce67
"""
Artifact generation: report, quiz, podcast (transcript + mp3).
Uses Gemini API only for text; grounded in retrieved chunks with [1], [2] citations.
Citations block is built from chunk metadata (never trust model-made bibliographies).
"""
import time
from typing import Any, Dict, List, Optional, Set
from backend.config import (
ARTIFACT_PODCAST,
ARTIFACT_QUIZ,
ARTIFACT_REPORT,
TOP_K,
)
from backend.gemini_client import (
build_citations_block,
generate_with_gemini,
is_gemini_error_response,
parse_citation_numbers,
)
from backend.retriever import retrieve
from backend.storage import (
artifacts_index_path,
podcasts_dir,
quizzes_dir,
reports_dir,
)
from backend.utils import new_uuid, read_json, write_json, logger
from backend import tts as tts_module
# Artifact-specific retrieval queries. Each generate_* function passes its query
# to _retrieve_chunks(), which appends the user's extra instruction (if any)
# before retrieving the context chunks that are handed to Gemini.
REPORT_QUERY = "Summarize and explain the most important concepts"
QUIZ_QUERY = "Create assessment questions covering the key concepts"
PODCAST_QUERY = "Create an engaging conversational explanation of key concepts"
# Output budget per artifact type; passed to generate_with_gemini() as
# max_output_tokens.
MAX_TOKENS_REPORT = 2000
MAX_TOKENS_QUIZ = 2500
MAX_TOKENS_PODCAST = 3000
def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with a trailing "Z".

    The format is the one produced by ``datetime.isoformat()`` for a naive
    datetime (e.g. ``2024-05-01T12:34:56.789012``) followed by ``"Z"``.
    """
    from datetime import datetime, timezone

    # datetime.utcnow() is deprecated since Python 3.12. Read an aware UTC
    # clock instead, then drop tzinfo so isoformat() does not emit "+00:00"
    # in front of the explicit "Z" suffix.
    utc_now = datetime.now(timezone.utc).replace(tzinfo=None)
    return utc_now.isoformat() + "Z"
def _retrieve_chunks(
    username: str,
    notebook_id: str,
    query: str,
    extra_instruction: str,
    strategy: str = "similarity",
) -> List[dict]:
    """Fetch context chunks for an artifact from the notebook's retriever.

    The search text is ``query`` alone, or ``"<query>. <extra_instruction>"``
    (stripped) when an extra instruction is given. ``retrieve()`` is asked for
    ``TOP_K * 2`` results and returns a pair; only its first element (the
    chunks) is returned here. Restricting results to enabled sources is left
    to ``retrieve()`` itself -- nothing in this function filters them.
    """
    if extra_instruction:
        search_text = "{}. {}".format(query, extra_instruction).strip()
    else:
        search_text = query
    hits, _ignored = retrieve(
        username,
        notebook_id,
        search_text,
        top_k=TOP_K * 2,
        strategy=strategy,
    )
    return hits
def _build_numbered_context(chunks: List[dict]) -> str:
    """Render chunks as a numbered context block for the Gemini prompt.

    Each chunk becomes::

        [n] <document text>
        (Source: <source_name>, <location>)

    with ``n`` starting at 1, and entries separated by a ``---`` rule.

    Args:
        chunks: Dicts with an optional ``"document"`` string and an optional
            ``"metadata"`` dict holding ``"source_name"`` and
            ``"page_or_slide"``.

    Returns:
        The joined context string; empty when ``chunks`` is empty.
    """
    parts = []
    for number, chunk in enumerate(chunks, start=1):
        meta = chunk.get("metadata") or {}
        # Fall back when the name is missing, None or empty, so the prompt
        # never contains "Source: None".
        name = meta.get("source_name") or "Source"
        loc = meta.get("page_or_slide", "")
        if loc == "web":
            loc_label = "url"
        elif loc:
            loc_label = f"page {loc}"
        else:
            # Em dash placeholder for "no location". Written as an escape
            # because the previous literal was stored mis-decoded ("β€”")
            # and leaked that garbage into the model prompt.
            loc_label = "\u2014"
        doc = (chunk.get("document") or "").strip()
        parts.append(f"[{number}] {doc}\n(Source: {name}, {loc_label})")
    return "\n\n---\n\n".join(parts)
def _append_citations_to_markdown(body: str, used_numbers: Set[int], chunks: List[dict]) -> str:
    """Return ``body`` with the citations block appended after a blank line.

    The block comes from ``build_citations_block(used_numbers, chunks)``, i.e.
    from the retrieved chunks rather than any bibliography the model wrote.
    Trailing whitespace is stripped from ``body`` in every case; when the
    block is empty the stripped body is returned on its own.
    """
    citations = build_citations_block(used_numbers, chunks)
    trimmed = body.rstrip()
    if citations:
        return "\n\n".join((trimmed, citations))
    return trimmed
def generate_report(
    username: str,
    notebook_id: str,
    extra_instruction: str = "",
    strategy: str = "similarity",
) -> Dict[str, Any]:
    """Create a Markdown report from the notebook's retrieved chunks.

    Steps: retrieve chunks for ``REPORT_QUERY``, send them to Gemini as a
    numbered context, append a citations block built from the chunks, write
    the result to ``report_<id8>.md`` in the notebook's reports directory and
    record the artifact in the index.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook whose sources are used.
        extra_instruction: Optional user text; used for retrieval and appended
            to the user prompt.
        strategy: Retrieval strategy forwarded to the retriever.

    Returns:
        ``{"error": ...}`` when no chunks are retrieved or Gemini reports an
        error; otherwise a dict with ``path``, ``filename``, ``content``,
        ``entry`` (the index record) and ``generation_time`` (seconds spent in
        the Gemini call).
    """
    context_chunks = _retrieve_chunks(username, notebook_id, REPORT_QUERY, extra_instruction, strategy)
    if not context_chunks:
        return {"error": "No sources to generate report from. Add and enable sources in this notebook."}

    context_text = _build_numbered_context(context_chunks)

    # Sentences of the system prompt, joined with single spaces.
    system_prompt = " ".join(
        (
            "You write clear, structured reports in Markdown. You may ONLY use the provided numbered context.",
            "Do not invent facts. If the context is insufficient to answer a point, say so.",
            "Use citations by referencing the numbered chunks: [1], [2], etc. at the end of the sentence or paragraph.",
            "Output must include: a short executive summary, main sections with headings and bullet points, and key takeaways.",
            "Use only the given numbers that correspond to the context chunks.",
        )
    )

    # Sections of the user prompt, separated by blank lines.
    user_sections = [
        "Context (cite using [1], [2], ...):",
        context_text,
        "Write a structured report in Markdown with executive summary, headings, bullet points, and key takeaways. Cite sources with [n].",
    ]
    if extra_instruction:
        user_sections.append(f"Additional instruction: {extra_instruction}")
    user_prompt = "\n\n".join(user_sections)

    # Time only the model call.
    started = time.perf_counter()
    model_text = generate_with_gemini(system_prompt, user_prompt, max_output_tokens=MAX_TOKENS_REPORT)
    elapsed = time.perf_counter() - started
    if is_gemini_error_response(model_text):
        return {"error": model_text}

    cited_numbers = parse_citation_numbers(model_text)
    markdown = _append_citations_to_markdown(model_text, cited_numbers, context_chunks)

    report_id = new_uuid()
    target_dir = reports_dir(username, notebook_id)
    out_name = "report_{}.md".format(report_id[:8])
    out_path = target_dir / out_name
    out_path.write_text(markdown, encoding="utf-8")

    index_entry = {
        "id": report_id,
        "type": ARTIFACT_REPORT,
        "filename": out_name,
        "created_at": _now_iso(),
        "prompt": extra_instruction,
        "retrieval_strategy": strategy,
    }
    _append_artifact_index(username, notebook_id, index_entry)

    return {
        "path": str(out_path),
        "filename": out_name,
        "content": markdown,
        "entry": index_entry,
        "generation_time": elapsed,
    }
def generate_quiz(
    username: str,
    notebook_id: str,
    extra_instruction: str = "",
    strategy: str = "similarity",
) -> Dict[str, Any]:
    """Generate a Markdown quiz (10–15 questions plus answer key) via Gemini.

    Retrieves chunks for ``QUIZ_QUERY``, prompts Gemini with them as numbered
    context, appends a citations block built from the chunks, writes the quiz
    to ``quiz_<id8>.md`` in the notebook's quizzes directory and records the
    artifact in the index.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook whose sources are used.
        extra_instruction: Optional user text; used for retrieval and appended
            to the user prompt.
        strategy: Retrieval strategy forwarded to the retriever.

    Returns:
        ``{"error": ...}`` when no chunks are retrieved or Gemini reports an
        error; otherwise a dict with ``path``, ``filename``, ``content``,
        ``entry`` (the index record) and ``generation_time`` (seconds spent in
        the Gemini call).
    """
    chunks = _retrieve_chunks(username, notebook_id, QUIZ_QUERY, extra_instruction, strategy)
    if not chunks:
        return {"error": "No sources to generate quiz from. Add and enable sources in this notebook."}
    numbered_ctx = _build_numbered_context(chunks)
    # The "10–15" ranges below use a real en dash. They were previously stored
    # mis-decoded, which sent garbage characters to the model.
    system = (
        "You create quizzes in Markdown. You may ONLY use the provided numbered context. "
        "Do not invent facts. If context is insufficient, say so. "
        "Cite sources with [1], [2], etc. when a question or answer comes from a chunk. "
        "Include 10–15 questions: a mix of multiple choice (MCQ) and short answer. "
        "At the end, include a section: ## Answer Key with answers for all questions. "
        "Use only the given numbers that correspond to the context chunks."
    )
    user = (
        f"Context (cite using [1], [2], ...):\n\n{numbered_ctx}\n\n"
        "Create a quiz in Markdown: 10–15 questions (MCQ + short answer), then ## Answer Key."
    )
    if extra_instruction:
        user += f"\n\nAdditional instruction: {extra_instruction}"

    # Time only the model call.
    t0 = time.perf_counter()
    raw = generate_with_gemini(system, user, max_output_tokens=MAX_TOKENS_QUIZ)
    generation_time = time.perf_counter() - t0
    if is_gemini_error_response(raw):
        return {"error": raw}

    # Citations are rebuilt from chunk metadata for the [n] markers found in
    # the model output, not taken from the model's own text.
    used = parse_citation_numbers(raw)
    quiz_md = _append_citations_to_markdown(raw, used, chunks)

    artifact_id = new_uuid()
    quizzes_dir_path = quizzes_dir(username, notebook_id)
    filename = f"quiz_{artifact_id[:8]}.md"
    path = quizzes_dir_path / filename
    path.write_text(quiz_md, encoding="utf-8")

    entry = {
        "id": artifact_id,
        "type": ARTIFACT_QUIZ,
        "filename": filename,
        "created_at": _now_iso(),
        "prompt": extra_instruction,
        "retrieval_strategy": strategy,
    }
    _append_artifact_index(username, notebook_id, entry)
    return {
        "path": str(path),
        "filename": filename,
        "content": quiz_md,
        "entry": entry,
        "generation_time": generation_time,
    }
def generate_podcast(
    username: str,
    notebook_id: str,
    extra_instruction: str = "",
    strategy: str = "similarity",
) -> Dict[str, Any]:
    """Generate a 2-speaker podcast transcript via Gemini, then an .mp3 via TTS.

    Retrieves chunks for ``PODCAST_QUERY``, asks Gemini for a 4–8 minute
    dialogue grounded in them, appends a citations block built from the
    chunks, writes the transcript to ``transcript_<id8>.md`` and requests
    ``podcast_<id8>.mp3`` from the TTS module, both in the notebook's
    podcasts directory. The artifact is then recorded in the index.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook whose sources are used.
        extra_instruction: Optional user text; used for retrieval and appended
            to the user prompt.
        strategy: Retrieval strategy forwarded to the retriever.

    Returns:
        ``{"error": ...}`` when no chunks are retrieved or Gemini reports an
        error; otherwise a dict with ``path`` (mp3), ``filename`` (mp3),
        ``transcript_path``, ``transcript_content``, ``entry`` (the index
        record), ``audio_ok`` (result of the TTS call) and
        ``generation_time`` (seconds spent in the Gemini call, TTS excluded).
    """
    chunks = _retrieve_chunks(username, notebook_id, PODCAST_QUERY, extra_instruction, strategy)
    if not chunks:
        return {"error": "No sources to generate podcast from. Add and enable sources in this notebook."}
    numbered_ctx = _build_numbered_context(chunks)
    # The ranges below ("4–8", "600–1200") use a real en dash. They were
    # previously stored mis-decoded, which sent garbage characters to the model.
    system = (
        "You write a 2-speaker podcast transcript. You may ONLY use the provided numbered context. "
        "Do not invent facts. Use 'Speaker A:' and 'Speaker B:' before each line. "
        "Natural dialogue; 4–8 minutes worth of text when read aloud (roughly 600–1200 words). "
        "Include occasional citations like [3] when referring to a source. "
        "Use only the given numbers that correspond to the context chunks."
    )
    user = (
        f"Context (cite using [1], [2], ...):\n\n{numbered_ctx}\n\n"
        "Write a 2-speaker podcast transcript (Speaker A:, Speaker B:). "
        "4–8 minutes of dialogue, grounded in the context, with occasional [n] citations."
    )
    if extra_instruction:
        user += f"\n\nAdditional instruction: {extra_instruction}"

    # Time only the model call.
    t0 = time.perf_counter()
    raw = generate_with_gemini(system, user, max_output_tokens=MAX_TOKENS_PODCAST)
    generation_time = time.perf_counter() - t0
    if is_gemini_error_response(raw):
        return {"error": raw}

    # Citations are rebuilt from chunk metadata for the [n] markers found in
    # the model output, not taken from the model's own text.
    used = parse_citation_numbers(raw)
    transcript_md = _append_citations_to_markdown(raw, used, chunks)

    artifact_id = new_uuid()
    podcasts_dir_path = podcasts_dir(username, notebook_id)
    transcript_filename = f"transcript_{artifact_id[:8]}.md"
    transcript_path = podcasts_dir_path / transcript_filename
    transcript_path.write_text(transcript_md, encoding="utf-8")

    mp3_filename = f"podcast_{artifact_id[:8]}.mp3"
    mp3_path = podcasts_dir_path / mp3_filename
    # Only the first 5000 characters of the transcript are sent to TTS.
    # NOTE(review): this text includes speaker labels, [n] markers and
    # possibly the appended citations block -- confirm that is intended.
    success = tts_module.text_to_speech(transcript_md[:5000], mp3_path, lang="en")
    if not success:
        # Best effort: the transcript is kept and the failure is reported to
        # the caller through "audio_ok".
        logger.warning("TTS failed for podcast %s", artifact_id)

    entry = {
        "id": artifact_id,
        "type": ARTIFACT_PODCAST,
        "filename": mp3_filename,
        "transcript_filename": transcript_filename,
        "created_at": _now_iso(),
        "prompt": extra_instruction,
        "retrieval_strategy": strategy,
    }
    _append_artifact_index(username, notebook_id, entry)
    return {
        "path": str(mp3_path),
        "filename": mp3_filename,
        "transcript_path": str(transcript_path),
        "transcript_content": transcript_md,
        "entry": entry,
        "audio_ok": success,
        "generation_time": generation_time,
    }
def _append_artifact_index(username: str, notebook_id: str, entry: Dict[str, Any]) -> None:
    """Add ``entry`` to the end of the notebook's artifact index and save it.

    The index is read with ``read_json`` (defaulting to an empty
    ``{"artifacts": []}`` document) and the ``"artifacts"`` list is created
    if the stored document lacks one. The whole document is then rewritten
    with ``write_json``.
    """
    index_file = artifacts_index_path(username, notebook_id)
    index_doc = read_json(index_file, default={"artifacts": []})
    records = index_doc.setdefault("artifacts", [])
    records.append(entry)
    write_json(index_file, index_doc)
def list_artifacts(username: str, notebook_id: str) -> List[Dict[str, Any]]:
    """Return the notebook's artifact index entries in reverse stored order.

    Entries are appended to the index as artifacts are generated, so the
    reversed list puts the most recently added entry first. A missing index
    or missing ``"artifacts"`` key yields an empty list.
    """
    index_doc = read_json(
        artifacts_index_path(username, notebook_id),
        default={"artifacts": []},
    )
    stored = index_doc.get("artifacts", [])
    reversed_copy = stored[::-1]
    return list(reversed_copy)
def get_report_content(username: str, notebook_id: str, filename: str) -> str:
    """Return the text of a saved report, or "" when it cannot be found.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook the report belongs to.
        filename: Report file name inside the notebook's reports directory.

    Returns:
        The file's UTF-8 text, or an empty string when ``filename`` is empty
        or does not name an existing regular file.
    """
    if not filename:
        # Nothing selected: an empty name would resolve to the directory itself.
        return ""
    path = reports_dir(username, notebook_id) / filename
    # is_file() rather than exists(): a directory passes exists() and would
    # then make read_text() raise.
    if not path.is_file():
        return ""
    return path.read_text(encoding="utf-8")
def get_quiz_content(username: str, notebook_id: str, filename: str) -> str:
    """Return the text of a saved quiz, or "" when it cannot be found.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook the quiz belongs to.
        filename: Quiz file name inside the notebook's quizzes directory.

    Returns:
        The file's UTF-8 text, or an empty string when ``filename`` is empty
        or does not name an existing regular file.
    """
    if not filename:
        # Nothing selected: an empty name would resolve to the directory itself.
        return ""
    path = quizzes_dir(username, notebook_id) / filename
    # is_file() rather than exists(): a directory passes exists() and would
    # then make read_text() raise.
    if not path.is_file():
        return ""
    return path.read_text(encoding="utf-8")
def get_podcast_transcript(username: str, notebook_id: str, transcript_filename: str) -> str:
    """Return the text of a saved podcast transcript, or "" when it cannot be found.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook the podcast belongs to.
        transcript_filename: Transcript file name inside the notebook's
            podcasts directory.

    Returns:
        The file's UTF-8 text, or an empty string when
        ``transcript_filename`` is empty or does not name an existing
        regular file.
    """
    if not transcript_filename:
        # Nothing selected: an empty name would resolve to the directory itself.
        return ""
    path = podcasts_dir(username, notebook_id) / transcript_filename
    # is_file() rather than exists(): a directory passes exists() and would
    # then make read_text() raise.
    if not path.is_file():
        return ""
    return path.read_text(encoding="utf-8")
def get_podcast_audio_path(username: str, notebook_id: str, mp3_filename: str) -> Optional[str]:
    """Return the path of a saved podcast .mp3 as a string, or None if absent.

    Args:
        username: Owner of the notebook.
        notebook_id: Notebook the podcast belongs to.
        mp3_filename: Audio file name inside the notebook's podcasts directory.

    Returns:
        ``str(path)`` when the file exists as a regular file; ``None`` when
        ``mp3_filename`` is empty or the file is missing. The file can be
        missing for an indexed podcast, because the index entry is written
        even when TTS fails.
    """
    if not mp3_filename:
        # Nothing selected: an empty name would resolve to the directory itself.
        return None
    path = podcasts_dir(username, notebook_id) / mp3_filename
    # is_file() rather than exists(): never hand a directory back to the
    # caller as if it were an audio file.
    if not path.is_file():
        return None
    return str(path)