# NotebookLMClone — src/artifacts/podcast_generator.py
# Synced from GitHub by github-actions[bot],
# commit 214f9ed998ff8d82e81656fab8d69dcd637cd425 (46e5b37).
"""
Podcast generator - creates conversational audio from notebook content.
"""
from __future__ import annotations
import json
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from dotenv import load_dotenv
from openai import OpenAI
import requests
from src.ingestion.vectorstore import ChromaAdapter
from .tts_adapter import get_tts_adapter, TTSProvider
load_dotenv()
SUPPORTED_TRANSCRIPT_LLM_PROVIDERS = {"openai", "groq", "ollama"}
DEFAULT_TRANSCRIPT_MODELS = {
"openai": "gpt-4o-mini",
"groq": "llama-3.1-8b-instant",
"ollama": "qwen2.5:3b",
}
TRANSCRIPT_SYSTEM_PROMPT = (
"You are an expert podcast script writer. Create engaging, natural, educational conversations. "
"Return valid JSON only with a top-level 'segments' array."
)
class PodcastGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model: Optional[str] = None,
tts_provider: Optional[TTSProvider] = None,
llm_provider: Optional[str] = None,
):
"""
Initialize podcast generator.
Args:
api_key: OpenAI API key (defaults to OPENAI_API_KEY from .env)
model: LLM model to use (defaults to LLM_MODEL from .env)
tts_provider: TTS provider (defaults to TTS_PROVIDER from .env)
llm_provider: Transcript LLM provider (openai, groq, ollama)
"""
self.llm_provider = (llm_provider or os.getenv("TRANSCRIPT_LLM_PROVIDER", "openai")).strip().lower()
if self.llm_provider not in SUPPORTED_TRANSCRIPT_LLM_PROVIDERS:
raise ValueError(
f"Unsupported TRANSCRIPT_LLM_PROVIDER='{self.llm_provider}'. "
f"Choose from: {sorted(SUPPORTED_TRANSCRIPT_LLM_PROVIDERS)}"
)
self.model = self._resolve_model_name(model)
self._openai_client: OpenAI | None = None
self._groq_client = None
self._ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/")
if self.llm_provider == "openai":
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self._openai_client = OpenAI(api_key=self.api_key)
elif self.llm_provider == "groq":
from groq import Groq
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
raise ValueError("GROQ_API_KEY is required when TRANSCRIPT_LLM_PROVIDER=groq")
self._groq_client = Groq(api_key=groq_api_key)
else:
self.api_key = None
# TTS configuration
tts_provider = tts_provider or os.getenv("TTS_PROVIDER", "edge")
self.tts = get_tts_adapter(tts_provider)
self.tts_provider = tts_provider
self._last_tts_errors: List[str] = []
# Default settings from .env
self.default_duration = os.getenv("DEFAULT_PODCAST_DURATION", "5min")
self.host_1 = os.getenv("PODCAST_HOST_1", "Alex")
self.host_2 = os.getenv("PODCAST_HOST_2", "Jordan")
def _resolve_model_name(self, explicit_model: Optional[str]) -> str:
if explicit_model and explicit_model.strip():
return explicit_model.strip()
configured = os.getenv("TRANSCRIPT_LLM_MODEL", "").strip()
if configured:
return configured
return DEFAULT_TRANSCRIPT_MODELS.get(self.llm_provider, "gpt-4o-mini")
def generate_podcast(
self,
user_id: str,
notebook_id: str,
duration_target: Optional[str] = None,
hosts: Optional[List[str]] = None,
topic_focus: Optional[str] = None,
) -> Dict[str, Any]:
"""
Generate a podcast-style conversation from notebook content.
Args:
user_id: User identifier
notebook_id: Notebook to generate podcast from
duration_target: Target length ("5min", "10min", "15min")
hosts: List of host names (defaults to PODCAST_HOST_1/2 from .env)
topic_focus: Optional specific topic to focus on
Returns:
Dict with transcript, audio_path, and metadata
"""
duration_target = duration_target or self.default_duration
hosts = hosts or [self.host_1, self.host_2]
print(f"🎙️ Generating {duration_target} podcast with {hosts[0]} & {hosts[1]}...")
# 1. Retrieve comprehensive context
context = self._get_notebook_context(user_id, notebook_id, topic_focus)
if not context:
return {
"error": "No content found in notebook. Please ingest documents first.",
"transcript": [],
"audio_path": None,
"metadata": {},
}
# 2. Generate conversational script
print("🤖 Generating podcast script...")
script = self._generate_script(context, duration_target, hosts)
if not script:
return {
"error": "Failed to generate podcast script.",
"transcript": [],
"audio_path": None,
"metadata": {},
}
# 3. Synthesize audio segments
print(f"🎵 Synthesizing audio with {self.tts_provider}...")
self._last_tts_errors = []
audio_segments = self._synthesize_segments(script, user_id, notebook_id, hosts)
if not audio_segments:
tts_error_preview = "; ".join(self._last_tts_errors[:3]).strip()
failure_message = (
"Transcript generated but audio synthesis failed for all segments. "
"Check TTS provider credentials, quota, and configured voices."
)
if tts_error_preview:
failure_message = f"{failure_message} Provider errors: {tts_error_preview}"
return {
"error": failure_message,
"transcript": script,
"audio_path": None,
"metadata": {
"notebook_id": notebook_id,
"duration_target": duration_target,
"hosts": hosts,
"tts_provider": self.tts_provider,
"llm_provider": self.llm_provider,
"llm_model": self.model,
"num_segments": len(script),
"topic_focus": topic_focus,
"tts_errors": self._last_tts_errors[:20],
"generated_at": datetime.now(timezone.utc).isoformat(),
},
}
# 4. Combine audio
print("🔗 Combining audio segments...")
final_audio = self._combine_audio(audio_segments, user_id, notebook_id)
if not final_audio or not Path(final_audio).exists():
return {
"error": (
"Transcript generated but final audio file was not created. "
"Check ffmpeg/pydub setup and TTS output."
),
"transcript": script,
"audio_path": None,
"metadata": {
"notebook_id": notebook_id,
"duration_target": duration_target,
"hosts": hosts,
"tts_provider": self.tts_provider,
"llm_provider": self.llm_provider,
"llm_model": self.model,
"num_segments": len(script),
"topic_focus": topic_focus,
"generated_at": datetime.now(timezone.utc).isoformat(),
},
}
return {
"transcript": script,
"audio_path": final_audio,
"metadata": {
"notebook_id": notebook_id,
"duration_target": duration_target,
"hosts": hosts,
"tts_provider": self.tts_provider,
"llm_provider": self.llm_provider,
"llm_model": self.model,
"num_segments": len(script),
"topic_focus": topic_focus,
"generated_at": datetime.now(timezone.utc).isoformat(),
},
}
def _get_notebook_context(
self,
user_id: str,
notebook_id: str,
topic_focus: Optional[str] = None,
) -> str:
"""Retrieve comprehensive context from notebook."""
data_base = os.getenv("STORAGE_BASE_DIR", "data")
chroma_dir = str(
Path(data_base) / "users" / user_id / "notebooks" / notebook_id / "chroma"
)
if not Path(chroma_dir).exists():
print(f"⚠️ Chroma directory not found: {chroma_dir}")
return ""
store = ChromaAdapter(persist_directory=chroma_dir)
# Get diverse chunks for comprehensive coverage
if topic_focus:
sample_queries = [topic_focus]
else:
sample_queries = [
"main topics and concepts",
"key principles and ideas",
"important details and facts",
"conclusions and insights",
"examples and applications",
]
all_chunks: List[str] = []
for query in sample_queries:
try:
results = store.query(user_id, notebook_id, query, top_k=5)
for _, _, chunk_data in results:
all_chunks.append(chunk_data["document"])
except Exception as e:
print(f"⚠️ Error querying: {e}")
continue
if not all_chunks:
return ""
# Deduplicate and combine
unique_chunks = list(set(all_chunks))
context = "\n\n".join(unique_chunks[:15]) # Top 15 chunks
print(f"✓ Retrieved {len(unique_chunks)} unique chunks ({len(context)} chars)")
return context
def _generate_script(
self,
context: str,
duration: str,
hosts: List[str],
) -> List[Dict[str, str]]:
"""Generate conversational script using LLM."""
word_count_map = {
"5min": 750,
"10min": 1500,
"15min": 2250,
"20min": 3000,
}
target_words = word_count_map.get(duration, 750)
prompt = self._build_podcast_prompt(context, target_words, hosts)
try:
raw_response = self._generate_transcript_json(prompt)
segments = self._extract_segments(raw_response)
print(f"✓ Generated script with {len(segments)} segments")
return segments
except Exception as e:
print(f"❌ Error generating script: {e}")
return []
def _generate_transcript_json(self, prompt: str) -> str:
if self.llm_provider == "openai":
assert self._openai_client is not None
response = self._openai_client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": TRANSCRIPT_SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
temperature=0.8,
response_format={"type": "json_object"},
)
return str(response.choices[0].message.content or "")
if self.llm_provider == "groq":
assert self._groq_client is not None
response = self._groq_client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": TRANSCRIPT_SYSTEM_PROMPT},
{"role": "user", "content": prompt},
],
temperature=0.8,
)
return str(response.choices[0].message.content or "")
payload = {
"model": self.model,
"system": TRANSCRIPT_SYSTEM_PROMPT,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.8},
}
response = requests.post(
f"{self._ollama_base_url}/api/generate",
json=payload,
timeout=120,
)
response.raise_for_status()
body = response.json()
return str(body.get("response", ""))
def _extract_segments(self, raw_response: str) -> List[Dict[str, str]]:
payload = self._extract_json_object(raw_response)
segments = payload.get("segments") if isinstance(payload, dict) else None
if not isinstance(segments, list):
return []
cleaned: List[Dict[str, str]] = []
for item in segments:
if not isinstance(item, dict):
continue
speaker = str(item.get("speaker", "")).strip()
text = str(item.get("text", "")).strip()
if not speaker or not text:
continue
cleaned.append({"speaker": speaker, "text": text})
return cleaned
def _extract_json_object(self, raw_response: str) -> Dict[str, Any]:
content = str(raw_response or "").strip()
if not content:
return {}
if content.startswith("```"):
content = re.sub(r"^```(?:json)?\s*", "", content)
content = re.sub(r"\s*```$", "", content)
try:
parsed = json.loads(content)
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", content, re.DOTALL)
if match:
try:
parsed = json.loads(match.group(0))
if isinstance(parsed, dict):
return parsed
except json.JSONDecodeError:
return {}
return {}
    def _build_podcast_prompt(
        self,
        context: str,
        target_words: int,
        hosts: List[str],
    ) -> str:
        """Build the podcast script generation prompt.

        Args:
            context: Concatenated notebook chunks to ground the conversation.
            target_words: Approximate word budget for the whole script; the
                prompt estimates minutes at ~150 words per minute.
            hosts: Two host names — hosts[0] asks questions, hosts[1] explains.

        Returns:
            The rendered prompt string, instructing the model to reply with a
            JSON object containing a top-level "segments" array.
        """
        # The prompt text below is part of the LLM contract; keep it verbatim.
        # Doubled braces ({{ }}) render as literal braces in the JSON example.
        return f"""
Create a natural, engaging podcast conversation between {hosts[0]} and {hosts[1]} about the following content.
Content to discuss:
{context}
Requirements:
- Target length: approximately {target_words} words ({target_words // 150} minutes)
- Make it conversational, engaging, and educational
- {hosts[0]} is the curious host who asks insightful questions
- {hosts[1]} is the knowledgeable host who provides clear explanations
- Include natural reactions, follow-up questions, and transitions
- Break down complex topics into digestible explanations
- Use analogies and examples to clarify concepts
- Maintain an upbeat, friendly, and enthusiastic tone
- End with a brief summary and key takeaways
Structure:
1. Opening (introduce topic, set context)
2. Main discussion (explore key concepts)
3. Deep dive (detailed explanations with examples)
4. Closing (summary and takeaways)
Generate the script in this exact JSON format:
{{
"segments": [
{{
"speaker": "{hosts[0]}",
"text": "Welcome to today's episode! We're diving into..."
}},
{{
"speaker": "{hosts[1]}",
"text": "Thanks for having me! This is such a fascinating topic..."
}}
]
}}
IMPORTANT:
- Each segment should be 1-3 sentences (natural speaking chunks)
- Alternate between speakers naturally
- Include pauses and transitions like "That's interesting..." or "Let me explain..."
- Make it sound like a real conversation, not a lecture
"""
@staticmethod
def _is_fatal_tts_error(exc: Exception) -> bool:
"""
Detect provider/configuration errors where retrying further segments is pointless.
"""
text = " ".join(str(exc).lower().split())
fatal_markers = [
"voice_not_found",
"no compatible elevenlabs synthesis method found",
"invalid_api_key",
"unauthorized",
"authentication",
"forbidden",
"insufficient_credits",
"quota",
"status_code: 401",
"status_code: 403",
]
return any(marker in text for marker in fatal_markers)
def _synthesize_segments(
self,
script: List[Dict[str, str]],
user_id: str,
notebook_id: str,
hosts: List[str],
) -> List[str]:
"""Synthesize each script segment to audio."""
storage_base = os.getenv("STORAGE_BASE_DIR", "data")
audio_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "podcasts"
audio_dir.mkdir(parents=True, exist_ok=True)
voice_maps: Dict[str, Dict[str, str]] = {
"openai": {
hosts[0]: os.getenv("TTS_OPENAI_VOICE_1", "alloy"),
hosts[1]: os.getenv("TTS_OPENAI_VOICE_2", "echo"),
},
"edge": {
hosts[0]: os.getenv("TTS_EDGE_VOICE_1", "en-US-GuyNeural"),
hosts[1]: os.getenv("TTS_EDGE_VOICE_2", "en-US-AriaNeural"),
},
"elevenlabs": {
hosts[0]: os.getenv("TTS_ELEVENLABS_VOICE_1", "Antoni"),
hosts[1]: os.getenv("TTS_ELEVENLABS_VOICE_2", "Rachel"),
},
}
voices = voice_maps.get(self.tts_provider, voice_maps["edge"])
audio_files: List[str] = []
self._last_tts_errors = []
total = len(script)
for i, segment in enumerate(script, 1):
speaker = segment["speaker"]
text = segment["text"]
voice = voices.get(speaker, list(voices.values())[0])
output_path = str(audio_dir / f"segment_{i:03d}_{speaker}.mp3")
try:
self.tts.synthesize(text, output_path, voice=voice)
audio_files.append(output_path)
print(f" ✓ Segment {i}/{total}: {speaker}")
except Exception as e:
error_detail = (
f"segment={i}/{total}, speaker={speaker}, voice={voice}, "
f"error={type(e).__name__}: {' '.join(str(e).split())}"
)
self._last_tts_errors.append(error_detail)
print(f" ⚠️ Failed {error_detail}")
if self._is_fatal_tts_error(e):
print(" ⛔ Fatal TTS configuration/provider error detected. Stopping remaining segments.")
break
continue
return audio_files
def _combine_audio(
self,
audio_segments: List[str],
user_id: str,
notebook_id: str,
) -> str:
"""Combine audio segments into single file."""
try:
from pydub import AudioSegment
except ImportError:
print("⚠️ pydub not installed. Skipping audio combination.")
print(" Install with: pip install pydub")
return audio_segments[0] if audio_segments else ""
if not audio_segments:
return ""
combined = AudioSegment.empty()
for i, segment_path in enumerate(audio_segments, 1):
try:
audio = AudioSegment.from_file(segment_path)
combined += audio
combined += AudioSegment.silent(duration=500) # 0.5s pause
except Exception as e:
print(f" ⚠️ Error processing segment {i}: {e}")
continue
storage_base = os.getenv("STORAGE_BASE_DIR", "data")
output_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "podcasts"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
final_path = str(output_dir / f"podcast_{timestamp}.mp3")
combined.export(final_path, format="mp3")
print(f"✓ Final podcast: {final_path}")
print(f" Duration: {len(combined) / 1000:.1f} seconds")
return final_path
def save_transcript(
self,
podcast_data: Dict[str, Any],
user_id: str,
notebook_id: str,
) -> str:
"""Save podcast transcript Markdown to file."""
storage_base = os.getenv("STORAGE_BASE_DIR", "data")
transcript_dir = Path(storage_base) / "users" / user_id / "notebooks" / notebook_id / "artifacts" / "podcasts"
transcript_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
filename = f"transcript_{timestamp}.md"
filepath = transcript_dir / filename
filepath.write_text(self.format_transcript_markdown(podcast_data), encoding="utf-8")
print(f"✓ Transcript saved to: {filepath}")
return str(filepath)
def format_transcript_markdown(
self,
podcast_data: Dict[str, Any],
title: str | None = None,
) -> str:
"""Render podcast transcript as Markdown."""
metadata = podcast_data.get("metadata", {}) if isinstance(podcast_data.get("metadata"), dict) else {}
transcript = podcast_data.get("transcript", [])
resolved_title = title or "Podcast Transcript"
lines: list[str] = [f"# {resolved_title}", ""]
duration = metadata.get("duration_target")
if duration:
lines.append(f"Target duration: **{duration}**")
lines.append("")
topic_focus = metadata.get("topic_focus")
if topic_focus:
lines.append(f"Topic focus: {topic_focus}")
lines.append("")
lines.append("## Conversation")
lines.append("")
for segment in transcript if isinstance(transcript, list) else []:
speaker = str(segment.get("speaker", "Speaker")).strip() or "Speaker"
text = str(segment.get("text", "")).strip()
if not text:
continue
lines.append(f"**{speaker}:** {text}")
lines.append("")
return "\n".join(lines)
# === CLI for testing ===
def _main() -> None:
    """Command-line entry point: generate a podcast for one notebook."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate podcast from notebook")
    parser.add_argument("--user", required=True, help="User ID")
    parser.add_argument("--notebook", required=True, help="Notebook ID")
    parser.add_argument(
        "--duration",
        choices=["5min", "10min", "15min", "20min"],
        help="Target podcast duration",
    )
    parser.add_argument("--topic", help="Focus on specific topic")
    parser.add_argument(
        "--tts-provider",
        choices=["openai", "edge", "elevenlabs"],
        help="TTS provider (defaults to TTS_PROVIDER in .env)",
    )
    parser.add_argument(
        "--llm-provider",
        choices=["openai", "groq", "ollama"],
        help="Transcript LLM provider (defaults to TRANSCRIPT_LLM_PROVIDER in .env)",
    )
    parser.add_argument("--model", help="Override transcript LLM model")
    parser.add_argument("--save-transcript", action="store_true", help="Save transcript to file")
    args = parser.parse_args()

    generator = PodcastGenerator(
        tts_provider=args.tts_provider,
        llm_provider=args.llm_provider,
        model=args.model,
    )
    result = generator.generate_podcast(
        args.user,
        args.notebook,
        args.duration,
        topic_focus=args.topic,
    )
    if "error" in result:
        print(f"\n❌ {result['error']}")
    else:
        print("\n✓ Podcast generated!")
        print(f" Audio: {result['audio_path']}")
        print(f" Segments: {len(result['transcript'])}")
        print(f" Provider: {result['metadata']['tts_provider']}")
        if args.save_transcript:
            generator.save_transcript(result, args.user, args.notebook)


if __name__ == "__main__":
    _main()