# src/artifacts/report_generator.py
"""
Report generator using RAG context from ingested notebook content.
"""
from __future__ import annotations
import os
from datetime import datetime, timezone
from pathlib import Path
from dotenv import load_dotenv
from openai import OpenAI
import requests
from src.ingestion.vectorstore import ChromaAdapter
load_dotenv()
SUPPORTED_REPORT_LLM_PROVIDERS = {"openai", "groq", "ollama"}
DEFAULT_REPORT_MODELS = {
    "openai": "gpt-4o-mini",
    "groq": "llama-3.1-8b-instant",
    "ollama": "qwen2.5:3b",
}
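
# Example environment configuration (a sketch with illustrative placeholder
# values; every variable below is read somewhere in this module):
#
#   REPORT_LLM_PROVIDER=openai              # one of: openai, groq, ollama
#   REPORT_LLM_MODEL=gpt-4o-mini            # optional; DEFAULT_REPORT_MODELS is the fallback
#   OPENAI_API_KEY=sk-...                   # required when the provider is openai
#   GROQ_API_KEY=gsk_...                    # required when the provider is groq
#   OLLAMA_BASE_URL=http://127.0.0.1:11434  # used only by the ollama provider
#   STORAGE_BASE_DIR=data                   # root directory for per-user notebook storage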

REPORT_SYSTEM_PROMPT = (
    "You write high-quality reports grounded only in the provided source context. "
    "Do not invent facts."
)


class ReportGenerator:
    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        llm_provider: str | None = None,
    ):
        # Provider resolution order: explicit argument first, then the
        # report-specific env var, then the quiz/transcript fallbacks,
        # then the "openai" default.
        provider_default = (
            llm_provider
            or os.getenv("REPORT_LLM_PROVIDER", "").strip()
            or os.getenv("QUIZ_LLM_PROVIDER", "").strip()
            or os.getenv("TRANSCRIPT_LLM_PROVIDER", "").strip()
            or "openai"
        )
        self.llm_provider = provider_default.strip().lower()
        if self.llm_provider not in SUPPORTED_REPORT_LLM_PROVIDERS:
            raise ValueError(
                f"Unsupported REPORT_LLM_PROVIDER='{self.llm_provider}'. "
                f"Choose from: {sorted(SUPPORTED_REPORT_LLM_PROVIDERS)}"
            )
        self.model = self._resolve_model_name(model)
        self._openai_client: OpenAI | None = None
        self._groq_client = None
        self._ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/")
        if self.llm_provider == "openai":
            self.api_key = api_key or os.getenv("OPENAI_API_KEY")
            if not self.api_key:
                raise ValueError("OPENAI_API_KEY is required when REPORT_LLM_PROVIDER=openai")
            self._openai_client = OpenAI(api_key=self.api_key)
        elif self.llm_provider == "groq":
            # Imported lazily so the groq SDK stays an optional dependency.
            from groq import Groq
            groq_api_key = os.getenv("GROQ_API_KEY")
            if not groq_api_key:
                raise ValueError("GROQ_API_KEY is required when REPORT_LLM_PROVIDER=groq")
            # Keep the attribute consistent across providers.
            self.api_key = groq_api_key
            self._groq_client = Groq(api_key=groq_api_key)
        else:
            # Ollama runs locally and needs no API key.
            self.api_key = None

    def _resolve_model_name(self, explicit_model: str | None) -> str:
        if explicit_model and explicit_model.strip():
            return explicit_model.strip()
        configured = os.getenv("REPORT_LLM_MODEL", "").strip()
        if configured:
            return configured
        legacy = os.getenv("LLM_MODEL", "").strip()
        if legacy:
            return legacy
        return DEFAULT_REPORT_MODELS.get(self.llm_provider, "gpt-4o-mini")

    def generate_report(
        self,
        user_id: str,
        notebook_id: str,
        detail_level: str = "medium",
        topic_focus: str | None = None,
    ) -> dict[str, str]:
        context = self._get_report_context(user_id, notebook_id, topic_focus)
        if not context:
            return {"error": "No content found in notebook. Please ingest documents first."}
        report_markdown = self._generate_markdown(
            context=context, detail_level=detail_level, topic_focus=topic_focus
        )
        if not report_markdown.strip():
            return {"error": "Failed to generate report content."}
        return {
            "content": report_markdown,
            "detail_level": detail_level,
            "llm_provider": self.llm_provider,
            "llm_model": self.model,
        }

    def _get_report_context(self, user_id: str, notebook_id: str, topic_focus: str | None) -> str:
        storage_base = os.getenv("STORAGE_BASE_DIR", "data")
        chroma_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "chroma"
        if not chroma_dir.exists():
            return ""
        store = ChromaAdapter(persist_directory=str(chroma_dir))
        if topic_focus:
            queries = [topic_focus]
        else:
            # Without a focus topic, probe the notebook from several angles
            # so the report covers summary, evidence, and conclusions.
            queries = [
                "main ideas and summary",
                "key evidence and facts",
                "conclusions and action items",
            ]
        chunks: list[str] = []
        for query in queries:
            try:
                results = store.query(user_id, notebook_id, query, top_k=6)
            except Exception:
                # A failed query for one angle should not abort the others.
                continue
            for _, _, chunk_data in results:
                text = str(chunk_data.get("document", "")).strip()
                if text:
                    chunks.append(text)
        if not chunks:
            return ""
        # Deduplicate while preserving retrieval order, then cap the context size.
        unique_chunks = list(dict.fromkeys(chunks))
        return "\n\n".join(unique_chunks[:14])

    def _generate_markdown(self, context: str, detail_level: str, topic_focus: str | None) -> str:
        target_words = {
            "short": 400,
            "medium": 800,
            "long": 1400,
        }.get(detail_level, 800)
        focus_line = (
            f"Focus area: {topic_focus}"
            if topic_focus
            else "Focus area: broad summary of notebook content"
        )
        prompt = f"""
Write a polished Markdown report from the notebook context below.
{focus_line}
Target length: around {target_words} words.
Notebook context:
{context}
Requirements:
- Use Markdown headings and concise sections.
- Include: Executive Summary, Key Insights, Evidence/Examples, Risks or Open Questions, Next Steps.
- Stay faithful to the provided context. Do not fabricate unsupported claims.
- Keep the tone professional and clear.
- Return Markdown only (no code fences).
"""
        try:
            return self._generate_report_content(prompt)
        except Exception:
            # Swallow provider errors; generate_report turns the empty
            # string into a user-facing error message.
            return ""

    def _generate_report_content(self, prompt: str) -> str:
        if self.llm_provider == "openai":
            assert self._openai_client is not None
            response = self._openai_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": REPORT_SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
            )
            return str(response.choices[0].message.content or "").strip()
        if self.llm_provider == "groq":
            assert self._groq_client is not None
            # The groq SDK mirrors the OpenAI chat-completions interface.
            response = self._groq_client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": REPORT_SYSTEM_PROMPT},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.4,
            )
            return str(response.choices[0].message.content or "").strip()
        # Remaining provider is Ollama: call its non-streaming /api/generate endpoint.
        payload = {
            "model": self.model,
            "system": REPORT_SYSTEM_PROMPT,
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.4},
        }
        response = requests.post(
            f"{self._ollama_base_url}/api/generate",
            json=payload,
            timeout=120,
        )
        response.raise_for_status()
        body = response.json()
        return str(body.get("response", "")).strip()

    def save_report(self, markdown_text: str, user_id: str, notebook_id: str) -> str:
        storage_base = os.getenv("STORAGE_BASE_DIR", "data")
        report_dir = (
            Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "reports"
        )
        report_dir.mkdir(parents=True, exist_ok=True)
        # UTC timestamps keep filenames sortable and unique per second.
        timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
        path = report_dir / f"report_{timestamp}.md"
        path.write_text(markdown_text, encoding="utf-8")
        return str(path)
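

if __name__ == "__main__":
    # Minimal usage sketch. Assumes a .env with provider credentials and a
    # notebook that has already been ingested; "demo-user" and "demo-notebook"
    # are placeholder IDs, not fixtures shipped with the project.
    generator = ReportGenerator()
    result = generator.generate_report(
        user_id="demo-user",
        notebook_id="demo-notebook",
        detail_level="short",
        topic_focus=None,
    )
    if "error" in result:
        print(result["error"])
    else:
        saved_path = generator.save_report(result["content"], "demo-user", "demo-notebook")
        print(f"Report ({result['llm_provider']}/{result['llm_model']}) saved to {saved_path}")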