podify / research /graph.py
jayaspjacob
Add OpenAudio emotion/tone cues to generated scripts
9756234
Raw
History Blame Contribute Delete
7.61 kB
"""LangGraph research graph: plan -> research -> outline -> write.
Produces a speaker-tagged podcast script from a topic, grounded in live DuckDuckGo
search results. Kept deliberately linear and lightweight so it runs fast on a CPU Space.
"""
from __future__ import annotations
import json
import re
from typing import List, Tuple, TypedDict
from langgraph.graph import StateGraph, START, END
from .llm import complete
from .search import web_search, SearchResult
class ResearchState(TypedDict, total=False):
    """State dictionary threaded through the research graph.

    Declared with ``total=False`` so every key is optional: the caller seeds
    the request keys and each node contributes its own keys as it runs.
    """

    # --- request, seeded by generate_script() ---
    topic: str  # podcast subject; nodes read it with state["topic"]
    style: str  # nodes fall back to "conversational" when absent
    duration_min: int  # target length in minutes; nodes fall back to 5
    num_speakers: int  # used to pick default names when none are given
    speaker_names: List[str]  # names used as line prefixes in the script

    # --- produced by the nodes ---
    queries: List[str]  # plan_node: web search queries
    findings: str  # research_node: markdown digest of search results
    sources: List[str]  # research_node: de-duplicated result URLs
    outline: str  # outline_node: markdown outline
    script: str  # write_node: final speaker-tagged script
# --------------------------------------------------------------------------- nodes
def plan_node(state: ResearchState) -> dict:
    """Ask the LLM for web search queries that cover the topic.

    Reads ``state["topic"]`` and returns a ``{"queries": [...]}`` update with
    at most six entries. If nothing usable can be parsed out of the reply, the
    topic itself becomes the single query.
    """
    subject = state["topic"]
    planner_prompt = (
        "You are a research planner. Given a podcast topic, produce 3-6 focused web "
        "search queries that together cover the key angles. Respond ONLY with a JSON "
        'array of strings, e.g. ["query one", "query two"].'
    )
    reply = complete(
        system=planner_prompt,
        user=f"Topic: {subject}",
        temperature=0.4,
        max_tokens=400,
    )

    planned = _parse_json_list(reply)
    if not planned:
        # Nothing parseable came back: search for the topic verbatim.
        planned = [subject]
    return {"queries": planned[:6]}
def research_node(state: ResearchState) -> dict:
    """Run each planned query through web search and collect the results.

    Returns a ``{"findings": ..., "sources": ...}`` update. ``findings`` is a
    markdown digest with one ``### Query:`` section per query that returned
    results, or a placeholder sentence when none did. ``sources`` holds the
    result URLs once each, in first-seen order.
    """
    sections: List[str] = []
    urls: List[str] = []
    for query in state.get("queries", []):
        hits = web_search(query, max_results=4)
        if hits:
            rendered = "\n".join(hit.as_markdown() for hit in hits)
            sections.append(f"### Query: {query}\n" + rendered)
            urls.extend(hit.url for hit in hits)

    if sections:
        findings = "\n\n".join(sections)
    else:
        findings = "(No web results were available.)"

    # dict.fromkeys keeps the first occurrence of every URL and preserves order.
    return {"findings": findings, "sources": list(dict.fromkeys(urls))}
def outline_node(state: ResearchState) -> dict:
    """Have the LLM turn the research findings into a podcast outline.

    Reads ``topic`` (required) plus ``style``, ``duration_min`` and
    ``findings`` (defaulting to ``"conversational"``, ``5`` and ``""``), and
    returns the model's reply unchanged as ``{"outline": ...}``.
    """
    producer_prompt = (
        "You are a podcast producer. Using the research findings, write a tight "
        "outline (intro, 3-5 segments, outro) for the podcast. Use markdown bullets."
    )

    subject = state["topic"]
    tone = state.get("style", "conversational")
    minutes = state.get("duration_min", 5)
    research = state.get("findings", "")
    brief = (
        f"Topic: {subject}\n"
        f"Style: {tone}\n"
        f"Target length: ~{minutes} minutes\n\n"
        f"Research findings:\n{research}"
    )

    return {
        "outline": complete(
            system=producer_prompt,
            user=brief,
            temperature=0.6,
            max_tokens=800,
        )
    }
def write_node(state: ResearchState) -> dict:
    """Have the LLM write the final speaker-tagged script.

    Speaker names come from ``state["speaker_names"]``; when that is missing
    or empty, defaults are derived from ``num_speakers`` (itself defaulting to
    2). The system prompt fixes the ``Name: text`` line format and the list of
    allowed parenthetical emotion/tone cues. Returns ``{"script": ...}`` with
    surrounding whitespace stripped from the model's reply.
    """
    cast = state.get("speaker_names")
    if not cast:
        cast = _default_speakers(state.get("num_speakers", 2))
    roster = ", ".join(cast)
    # One example line per speaker, shown to the model as the required format.
    turn_template = "\n".join(f"{name}: <what they say>" for name in cast)

    writer_prompt = (
        "You are a professional podcast scriptwriter. Write a natural, engaging, "
        "factually-grounded podcast script based on the outline and findings.\n"
        f"Speakers: {roster}.\n"
        "Format STRICTLY as one line per turn, prefixed with the speaker name and a "
        f"colon, like:\n{turn_template}\n"
        "Make the delivery feel human by adding OpenAudio emotion/tone cues IN "
        "PARENTHESES, inline, right before the words they color (or at the very start "
        "of a turn). Use ONLY these cues: (excited) (curious) (surprised) (amused) "
        "(interested) (confident) (empathetic) (joyful) (serious) (sarcastic) "
        "(thoughtful) (laughing) (chuckling) (sighing) (whispering) (soft tone) "
        "(in a hurry tone). Use them sparingly — about one every few lines, only where "
        "it genuinely fits the moment. Do NOT invent other cues and do NOT use square "
        "brackets. There is no pause or emphasis marker: convey pauses and emphasis with "
        "natural punctuation (commas, em-dashes —, ellipses …).\n"
        "Apart from these inline parenthetical cues, output only spoken dialogue — no "
        "markdown, headings, or stand-alone stage directions. Keep each line to a few "
        "sentences. Open with a hook and close with a sign-off."
    )

    subject = state["topic"]
    tone = state.get("style", "conversational")
    minutes = state.get("duration_min", 5)
    outline = state.get("outline", "")
    research = state.get("findings", "")
    brief = (
        f"Topic: {subject}\n"
        f"Style: {tone}\n"
        f"Target length: ~{minutes} minutes\n\n"
        f"Outline:\n{outline}\n\n"
        f"Findings:\n{research}"
    )

    draft = complete(
        system=writer_prompt,
        user=brief,
        temperature=0.8,
        max_tokens=3000,
    )
    return {"script": draft.strip()}
# --------------------------------------------------------------------------- helpers
def _parse_json_list(text: str) -> List[str]:
match = re.search(r"\[.*\]", text, re.DOTALL)
if not match:
return [line.strip("-* ").strip() for line in text.splitlines() if line.strip()]
try:
data = json.loads(match.group(0))
return [str(x).strip() for x in data if str(x).strip()]
except json.JSONDecodeError:
return []
def _default_speakers(n: int) -> List[str]:
names = ["Host", "Guest", "Co-host", "Expert"]
if n <= 1:
return ["Narrator"]
return names[:n]
def parse_script(script: str) -> List[Tuple[str, str]]:
    """Split a ``Speaker: text`` transcript into ``(speaker, text)`` pairs.

    Blank lines are skipped. A non-blank line that does not look like a new
    turn is appended, separated by one space, to the previous turn's text;
    such lines are dropped when they appear before the first turn.
    """
    turn_re = re.compile(r"^\s*([\w .'-]{1,30}?)\s*:\s*(.+)$")
    turns: List[Tuple[str, List[str]]] = []
    for line in map(str.strip, script.splitlines()):
        if not line:
            continue
        hit = turn_re.match(line)
        if hit is not None:
            turns.append((hit.group(1).strip(), [hit.group(2).strip()]))
        elif turns:
            # Continuation of the current speaker's turn.
            turns[-1][1].append(line)
    return [(speaker, " ".join(chunks)) for speaker, chunks in turns]
# --------------------------------------------------------------------------- graph
def build_graph():
    """Assemble and compile the linear plan -> research -> outline -> write graph."""
    stages = [
        ("plan", plan_node),
        ("research", research_node),
        ("outline", outline_node),
        ("write", write_node),
    ]

    builder = StateGraph(ResearchState)
    for stage_name, stage_fn in stages:
        builder.add_node(stage_name, stage_fn)

    # Chain START -> each stage in order -> END.
    path = [START, *(stage_name for stage_name, _ in stages), END]
    for src, dst in zip(path, path[1:]):
        builder.add_edge(src, dst)
    return builder.compile()
# Compiled graph, built on the first generate_script() call and then reused.
_GRAPH = None


def generate_script(
    topic: str,
    *,
    style: str = "conversational",
    duration_min: int = 5,
    num_speakers: int = 2,
    speaker_names: List[str] | None = None,
) -> dict:
    """Run the whole research graph for ``topic`` and return its final state.

    Args:
        topic: Subject of the podcast.
        style: Style label passed to the outline and script prompts.
        duration_min: Target length in minutes passed to the prompts.
        num_speakers: Number of speakers; used to pick default names when
            ``speaker_names`` is missing or empty.
        speaker_names: Explicit speaker names.

    Returns:
        Whatever the compiled graph's ``invoke`` returns for the seeded state.
    """
    global _GRAPH
    if _GRAPH is None:
        _GRAPH = build_graph()

    initial_state = {
        "topic": topic,
        "style": style,
        "duration_min": duration_min,
        "num_speakers": num_speakers,
        "speaker_names": speaker_names or _default_speakers(num_speakers),
    }
    return _GRAPH.invoke(initial_state)
if __name__ == "__main__":  # manual smoke test (needs HF_TOKEN)
    import sys

    # Topic comes from the first CLI argument, with a canned default.
    if len(sys.argv) > 1:
        cli_topic = sys.argv[1]
    else:
        cli_topic = "The history and future of electric cars"

    final_state = generate_script(cli_topic, duration_min=3)

    print("\n=== SCRIPT ===\n")
    print(final_state["script"])

    print("\n=== SOURCES ===\n")
    print("\n".join(final_state.get("sources", [])))

    print("\n=== PARSED LINES ===\n")
    for speaker, spoken in parse_script(final_state["script"]):
        # Preview only the first 80 characters of each turn.
        print(f"[{speaker}] {spoken[:80]}")