# aileen3-core / demo /context_biased_transcription_cell.py
# ndurner's picture
# add comments
# 0c163b8
from __future__ import annotations
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Tuple
import gradio as gr
from layout import cell
from health import GEMINI_ENV_VAR
from problem_cell import (
DEFAULT_VIDEO_URL,
SEARCH_TERM,
CORRECT_TERM,
render_status_box,
)
log = logging.getLogger(__name__)

# Context-biased transcription can take a bit longer; use more generous
# polling defaults here than in the other cells.
# Default upper bound on status-poll calls per job (see `_poll_until_done`).
MAX_POLL_ATTEMPTS = 20
# Sent to the MCP tools as `wait_seconds` on every start/poll call.
# NOTE(review): presumably the server blocks for up to this many seconds per
# call, which would make the worst case per job roughly
# MAX_POLL_ATTEMPTS * POLL_WAIT_SECONDS -- confirm against the MCP server.
POLL_WAIT_SECONDS = 58
def _unwrap_tool_result(result: object) -> dict:
"""Adapt FastMCP CallToolResult objects into plain dicts."""
payload = getattr(result, "data", None) or getattr(result, "structured_content", None) or result
if isinstance(payload, dict):
return payload
return {
"status": "error",
"is_error": True,
"detail": f"Unexpected tool result type: {type(payload)!r}",
}
def _status(payload: dict) -> str:
return str(payload.get("status") or "").lower()
def _is_done(payload: dict) -> bool:
    """Tell whether the payload's normalized status is exactly "done"."""
    state = _status(payload)
    return state == "done"
def _needs_poll(payload: dict) -> bool:
    """Tell whether the job is still in flight ("pending" or "running")."""
    state = _status(payload)
    return state == "pending" or state == "running"
async def _poll_until_done(
    client,
    *,
    tool_name: str,
    reference: str,
    wait_seconds: int,
    max_attempts: int = MAX_POLL_ATTEMPTS,
) -> dict:
    """Call a get_* MCP status tool repeatedly until the job settles.

    Args:
        client: Object exposing an awaitable ``call_tool(name, arguments)``.
        tool_name: Name of the status tool to invoke.
        reference: Job reference forwarded to the tool.
        wait_seconds: Forwarded to the tool as ``wait_seconds`` on each call.
        max_attempts: Maximum number of tool calls before giving up.

    Returns:
        The last unwrapped payload. Polling stops early on an error payload
        or on any status other than "pending"/"running". If a call raises,
        or no call was made at all, a synthetic error payload is returned.
        When attempts run out, the last payload is returned with a default
        ``detail`` filled in if it had none.
    """
    payload: dict = {}
    for _ in range(max_attempts):
        try:
            raw = await client.call_tool(
                tool_name,
                {"reference": reference, "wait_seconds": wait_seconds},
            )
            payload = _unwrap_tool_result(raw)
        except Exception as exc:  # pragma: no cover - defensive
            return {
                "status": "error",
                "is_error": True,
                "detail": f"Polling {tool_name} failed: {exc}",
            }
        # A finished job is, by definition, no longer pollable, so a single
        # "not pollable" check covers both "done" and unknown statuses.
        if payload.get("is_error") or not _needs_poll(payload):
            return payload

    if not payload:
        return {
            "status": "error",
            "is_error": True,
            "detail": f"{tool_name} did not return a response.",
        }

    # Attempts exhausted while the job was still pending/running.
    payload.setdefault("detail", f"{tool_name} never reported completion; try again later.")
    return payload
async def _run_transcription_flow(gemini_api_key: str) -> Tuple[str, str]:
    """Drive the MCP media tools to run a context-biased transcription demo.

    This mirrors a typical client-side flow:

    - retrieve media via `start_media_retrieval`,
    - derive a textual prior from the YouTube description, and
    - call `start_media_transcription` with that prior as context.

    Args:
        gemini_api_key: Key exported to the spawned MCP server process via
            the environment variable named by ``GEMINI_ENV_VAR``.

    Returns:
        A ``(status_html, details_markdown)`` pair. On failure the status box
        uses the "fail" tone and the details carry only the source note (or
        an empty string when fastmcp cannot be imported).
    """
    try:
        from fastmcp import Client  # type: ignore[import-untyped]
        from fastmcp.client.transports import StdioTransport  # type: ignore[import-untyped]
    except Exception as exc:  # pragma: no cover - defensive
        status = render_status_box(f"fastmcp is not available in this environment: {exc}", "fail")
        return status, ""

    # As in the other cells we spawn the MCP server as a subprocess and
    # point PYTHONPATH at `mcp/src` so that editable installs are not
    # required to run the demo locally.
    repo_root = Path(__file__).resolve().parents[1]
    mcp_src = repo_root / "mcp" / "src"
    existing_py_path = os.environ.get("PYTHONPATH", "")
    py_path = f"{mcp_src}{os.pathsep}{existing_py_path}" if existing_py_path else str(mcp_src)
    env = os.environ.copy()
    env["PYTHONPATH"] = py_path
    env[GEMINI_ENV_VAR] = gemini_api_key
    server_entry = ["-m", "aileen3_mcp.server"]
    # Only the command line and paths are logged; the environment (which
    # holds the API key) is deliberately not.
    log.warning(
        "Context-biased transcription demo spawning MCP server: cmd=%s args=%s PYTHONPATH=%s cwd=%s",
        sys.executable,
        server_entry,
        py_path,
        repo_root,
    )
    transport = StdioTransport(
        command=sys.executable,
        args=server_entry,
        env=env,
        cwd=str(repo_root),
    )
    from_text = f"Using YouTube URL {DEFAULT_VIDEO_URL} as media source and its description as prior."

    async with Client(transport) as client:
        # --- Step 1: fetch the media -------------------------------------
        retrieval_start = _unwrap_tool_result(
            await client.call_tool(
                "start_media_retrieval",
                {
                    "source": DEFAULT_VIDEO_URL,
                    "prefer_audio_only": True,
                    "wait_seconds": POLL_WAIT_SECONDS,
                },
            )
        )
        if retrieval_start.get("is_error"):
            detail = retrieval_start.get("detail") or "Media retrieval failed."
            return render_status_box(detail, "fail"), from_text

        reference = retrieval_start.get("reference")
        if not reference:
            status = render_status_box(
                "Media retrieval did not return a reference token.", "fail"
            )
            return status, from_text

        # The start call may already report completion; poll only otherwise.
        retrieval = retrieval_start
        if not _is_done(retrieval_start):
            retrieval = await _poll_until_done(
                client,
                tool_name="get_media_retrieval_status",
                reference=reference,
                wait_seconds=POLL_WAIT_SECONDS,
            )
        if retrieval.get("is_error") or not _is_done(retrieval):
            detail = retrieval.get("detail") or retrieval.get("status") or "Retrieval incomplete."
            status = render_status_box(
                f"Media retrieval did not complete successfully: {detail}", "fail"
            )
            return status, from_text

        # --- Step 2: derive the textual prior ----------------------------
        metadata = retrieval.get("metadata") or {}
        description = metadata.get("description") or ""
        context_text = description.strip()
        if not context_text:
            # NOTE(review): this fallback sentence is itself sent as the
            # `context` argument, so the prior is not literally empty --
            # confirm that this is what the transcription tool should see.
            context_text = (
                "No YouTube description was available for this video; using an empty prior instead."
            )

        # --- Step 3: transcribe with the prior as context ----------------
        transcription_start = _unwrap_tool_result(
            await client.call_tool(
                "start_media_transcription",
                {
                    "reference": reference,
                    "context": context_text,
                    "prefer_audio_only": True,
                    "wait_seconds": POLL_WAIT_SECONDS,
                },
            )
        )
        if transcription_start.get("is_error"):
            detail = transcription_start.get("detail") or "Transcription job failed to start."
            status = render_status_box(
                f"Transcription job did not complete successfully: {detail}", "fail"
            )
            return status, from_text

        transcription = transcription_start
        if not _is_done(transcription_start):
            transcription = await _poll_until_done(
                client,
                tool_name="get_media_transcription_result",
                reference=reference,
                wait_seconds=POLL_WAIT_SECONDS,
            )
        if transcription.get("is_error") or not _is_done(transcription):
            detail = transcription.get("detail") or transcription.get("status") or "Transcription incomplete."
            status = render_status_box(
                f"Transcription job did not complete successfully: {detail}", "fail"
            )
            return status, from_text

    # --- Step 4: check the transcript for the misheard term --------------
    transcript_text = transcription.get("transcription") or ""
    normalized = transcript_text.lower()
    found_term = SEARCH_TERM.lower() in normalized
    if found_term:
        headline = (
            f"🚨 Even with contextual priors, the transcript still contains “{SEARCH_TERM}”."
        )
        tone = "fail"
    else:
        headline = (
            f"✅ With contextual priors, “{SEARCH_TERM}” does **not** appear; "
            f"the model stays on {CORRECT_TERM}."
        )
        tone = "success"
    status_html = render_status_box(headline, tone)

    # Keep the preview short and cut at a word boundary where possible.
    snippet = transcript_text.strip()
    if len(snippet) > 1200:
        snippet = snippet[:1200].rsplit(" ", 1)[0] + " …"

    details_lines = [
        from_text,
        "",
        f"**Search term checked**: “{SEARCH_TERM}”",
        "",
        "Below is a snippet of the transcription output (truncated for readability):",
        "",
        "```text",
        snippet or "[Transcription was empty]",
        "```",
    ]
    return status_html, "\n".join(details_lines)
def run_context_biased_transcription(gemini_api_key: str | None) -> Tuple[str, str]:
    """Gradio callback entry point for the contextual transcription demo.

    Args:
        gemini_api_key: Raw value of the API-key input; ``None`` or blank
            input short-circuits with a "fail" status box.

    Returns:
        A ``(status_html, details_markdown)`` pair. Exceptions raised by the
        async flow are caught, logged with their traceback, and rendered as
        a "fail" status box instead of propagating to Gradio.
    """
    key = (gemini_api_key or "").strip()
    if not key:
        status = render_status_box(
            "Please provide a Gemini API key in the setup cell above before running this demo.",
            "fail",
        )
        details = (
            "The contextual transcription demo relies on Gemini via the Aileen MCP server. "
            "Set `GEMINI_API_KEY` in the setup cell, run the health check to verify it, "
            "then try this demo again."
        )
        return status, details

    try:
        return asyncio.run(_run_transcription_flow(key))
    except Exception as exc:  # pragma: no cover - defensive
        # The UI points users at the logs "for more detail", so record the
        # traceback there rather than just the message already shown on screen.
        log.warning("Context-biased transcription demo failed: %s", exc, exc_info=True)
        status = render_status_box(f"Context-biased transcription failed: {exc}", "fail")
        details = (
            "Something went wrong while talking to the Aileen MCP media tools. "
            "Check the Space logs for more detail and ensure that ffmpeg, yt-dlp and Gemini "
            "are all available."
        )
        return status, details
def render_context_biased_transcription_cell(gemini_key_input: gr.Textbox) -> None:
    """Render the notebook-style cell for the contextual transcription demo.

    Args:
        gemini_key_input: Textbox whose value is passed to
            ``run_context_biased_transcription`` as the Gemini API key when
            the run button is clicked.
    """
    # Built line by line so the markdown does not depend on source
    # indentation; blank entries separate heading, paragraphs and list.
    intro_markdown = "\n".join(
        [
            "### 💁🏻‍♀️ Demo",
            "",
            "This cell reuses the Smart Country Convention talk highlighted in the problem statement. The **Aileen MCP media tools** call Gemini to",
            "transcribe a slice of the audio *while seeing the YouTube description as a prior*.",
            "",
            "- The media is fetched via `start_media_retrieval` for the same video as above.",
            "- The YouTube **description** from that retrieval is passed as the `context` argument to `start_media_transcription`.",
            f"- Gemini receives both the audio and this textual prior, increasing the chance that it sticks with **{CORRECT_TERM}** instead of",
            f"  hallucinating **{SEARCH_TERM}**.",
            "",
            "The goal is to observe how much a realistic prior (here: the video description) can nudge the transcription away from dramatic but wrong",
            "tokens and toward the terminology the speaker actually uses.",
        ]
    )

    with cell("🧪 Context-biased transcription with Gemini"):
        gr.Markdown(intro_markdown)
        # Read-only: the demo always runs against the fixed default video.
        gr.Textbox(
            label="YouTube video URL",
            value=DEFAULT_VIDEO_URL,
            interactive=False,
        )
        run_button = gr.Button("Run context-biased transcription demo", variant="primary")
        result_panel = gr.HTML(
            value=render_status_box(
                # Uses SEARCH_TERM so the placeholder names the term the flow actually checks.
                f"👉 Click the button to retrieve the media, run a Gemini-backed transcription with priors, and check for “{SEARCH_TERM}”.",
                "placeholder",
            )
        )
        result_details = gr.Markdown(visible=True)
        # NOTE(review): queue=False bypasses Gradio's queue for a callback
        # that may run for many polling rounds; confirm the hosting setup
        # tolerates such long-lived plain requests.
        run_button.click(
            fn=run_context_biased_transcription,
            inputs=[gemini_key_input],
            outputs=[result_panel, result_details],
            queue=False,
        )