File size: 11,681 Bytes
9f94d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c163b8
 
5ed6a9a
 
9f94d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c163b8
 
 
 
 
 
 
9f94d04
 
 
 
 
 
 
0c163b8
 
 
9f94d04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
from __future__ import annotations

import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Tuple

import gradio as gr

from layout import cell
from health import GEMINI_ENV_VAR
from problem_cell import (
    DEFAULT_VIDEO_URL,
    SEARCH_TERM,
    CORRECT_TERM,
    render_status_box,
)

# Module-level logger, named after this module per project convention.
log = logging.getLogger(__name__)

# Context-biased transcription can take a bit longer; use more generous
# polling defaults here than in the other cells.
MAX_POLL_ATTEMPTS = 20  # upper bound on get_* status calls per job (~19 min worst case)
POLL_WAIT_SECONDS = 58  # passed as wait_seconds to the MCP tools; presumably a server-side long-poll window — confirm against the server docs


def _unwrap_tool_result(result: object) -> dict:
    """Adapt FastMCP CallToolResult objects into plain dicts."""
    payload = getattr(result, "data", None) or getattr(result, "structured_content", None) or result
    if isinstance(payload, dict):
        return payload
    return {
        "status": "error",
        "is_error": True,
        "detail": f"Unexpected tool result type: {type(payload)!r}",
    }


def _status(payload: dict) -> str:
    return str(payload.get("status") or "").lower()


def _is_done(payload: dict) -> bool:
    return _status(payload) == "done"


def _needs_poll(payload: dict) -> bool:
    return _status(payload) in {"pending", "running"}


async def _poll_until_done(
    client,
    *,
    tool_name: str,
    reference: str,
    wait_seconds: int,
    max_attempts: int = MAX_POLL_ATTEMPTS,
) -> dict:
    """Poll the get_* MCP tools until a job finishes or attempts are exhausted.

    Returns the last payload seen; on transport failure or an empty response
    a synthetic error payload is returned instead.
    """
    last_seen: dict = {}
    attempts_left = max_attempts
    while attempts_left > 0:
        attempts_left -= 1
        try:
            raw = await client.call_tool(
                tool_name,
                {"reference": reference, "wait_seconds": wait_seconds},
            )
        except Exception as exc:  # pragma: no cover - defensive
            return {
                "status": "error",
                "is_error": True,
                "detail": f"Polling {tool_name} failed: {exc}",
            }
        last_seen = _unwrap_tool_result(raw)

        # Stop on explicit errors, on completion, or on any unrecognised
        # status; keep looping only while the job reports pending/running.
        if last_seen.get("is_error") or not _needs_poll(last_seen):
            return last_seen

    # Attempts exhausted without a terminal status.
    if not last_seen:
        return {
            "status": "error",
            "is_error": True,
            "detail": f"{tool_name} did not return a response.",
        }
    last_seen.setdefault("detail", f"{tool_name} never reported completion; try again later.")
    return last_seen


async def _run_transcription_flow(gemini_api_key: str) -> Tuple[str, str]:
    """Drive the MCP media tools to run a context-biased transcription demo.

    This mirrors a typical client-side flow:
    - retrieve media via `start_media_retrieval`,
    - derive a textual prior from the YouTube description, and
    - call `start_media_transcription` with that prior as context.

    Args:
        gemini_api_key: Key exported to the spawned MCP server process via
            the environment variable named by ``GEMINI_ENV_VAR``.

    Returns:
        A ``(status_html, details_markdown)`` pair for the Gradio outputs.
    """
    # Import lazily so the cell still renders a friendly error when fastmcp
    # is missing from the environment instead of failing at module import.
    try:
        from fastmcp import Client  # type: ignore[import-untyped]
        from fastmcp.client.transports import StdioTransport  # type: ignore[import-untyped]
    except Exception as exc:  # pragma: no cover - defensive
        status = render_status_box(f"fastmcp is not available in this environment: {exc}", "fail")
        return status, ""

    # As in the other cells we spawn the MCP server as a subprocess and
    # point PYTHONPATH at `mcp/src` so that editable installs are not
    # required to run the demo locally.
    repo_root = Path(__file__).resolve().parents[1]
    mcp_src = repo_root / "mcp" / "src"
    existing_py_path = os.environ.get("PYTHONPATH", "")
    # Prepend mcp/src so the vendored server package wins over anything else.
    py_path = f"{mcp_src}{os.pathsep}{existing_py_path}" if existing_py_path else str(mcp_src)

    # The child inherits our full environment plus the API key override.
    env = os.environ.copy()
    env["PYTHONPATH"] = py_path
    env[GEMINI_ENV_VAR] = gemini_api_key

    server_entry = ["-m", "aileen3_mcp.server"]

    # NOTE(review): warning level appears chosen for visibility in Space
    # logs rather than to signal a problem — confirm before downgrading.
    log.warning(
        "Context-biased transcription demo spawning MCP server: cmd=%s args=%s PYTHONPATH=%s cwd=%s",
        sys.executable,
        server_entry,
        py_path,
        repo_root,
    )

    transport = StdioTransport(
        command=sys.executable,
        args=server_entry,
        env=env,
        cwd=str(repo_root),
    )

    from_text = f"Using YouTube URL {DEFAULT_VIDEO_URL} as media source and its description as prior."

    # The Client context manager owns the server subprocess lifecycle; all
    # tool calls must happen inside this block.
    async with Client(transport) as client:
        retrieval_start = _unwrap_tool_result(
            await client.call_tool(
                "start_media_retrieval",
                {
                    "source": DEFAULT_VIDEO_URL,
                    "prefer_audio_only": True,
                    "wait_seconds": POLL_WAIT_SECONDS,
                },
            )
        )

        if retrieval_start.get("is_error"):
            detail = retrieval_start.get("detail") or "Media retrieval failed."
            status = render_status_box(detail, "fail")
            return status, from_text

        # The reference token identifies this job in all follow-up calls.
        reference = retrieval_start.get("reference")
        if not reference:
            status = render_status_box(
                "Media retrieval did not return a reference token.", "fail"
            )
            return status, from_text

        # start_* may already report "done" if the job finished within
        # wait_seconds; only poll when it did not.
        retrieval = retrieval_start
        if not _is_done(retrieval_start):
            retrieval = await _poll_until_done(
                client,
                tool_name="get_media_retrieval_status",
                reference=reference,
                wait_seconds=POLL_WAIT_SECONDS,
            )

        if retrieval.get("is_error") or not _is_done(retrieval):
            detail = retrieval.get("detail") or retrieval.get("status") or "Retrieval incomplete."
            status = render_status_box(
                f"Media retrieval did not complete successfully: {detail}", "fail"
            )
            return status, from_text

        # The YouTube description becomes the textual prior ("context") for
        # the transcription call.
        metadata = retrieval.get("metadata") or {}
        description = metadata.get("description") or ""

        context_text = description.strip()
        if not context_text:
            context_text = (
                "No YouTube description was available for this video; using an empty prior instead."
            )

        transcription_start = _unwrap_tool_result(
            await client.call_tool(
                "start_media_transcription",
                {
                    "reference": reference,
                    "context": context_text,
                    "prefer_audio_only": True,
                    "wait_seconds": POLL_WAIT_SECONDS,
                },
            )
        )

        if transcription_start.get("is_error"):
            detail = transcription_start.get("detail") or "Transcription job failed to start."
            status = render_status_box(
                f"Transcription job did not complete successfully: {detail}", "fail"
            )
            return status, from_text

        # Same start-then-poll pattern as retrieval above.
        transcription = transcription_start
        if not _is_done(transcription_start):
            transcription = await _poll_until_done(
                client,
                tool_name="get_media_transcription_result",
                reference=reference,
                wait_seconds=POLL_WAIT_SECONDS,
            )

        if transcription.get("is_error") or not _is_done(transcription):
            detail = transcription.get("detail") or transcription.get("status") or "Transcription incomplete."
            status = render_status_box(
                f"Transcription job did not complete successfully: {detail}", "fail"
            )
            return status, from_text

        # Case-insensitive check for the hallucinated term in the transcript.
        transcript_text = transcription.get("transcription") or ""
        normalized = transcript_text.lower()
        found_term = SEARCH_TERM.lower() in normalized

        if found_term:
            headline = (
                f"🚨 Even with contextual priors, the transcript still contains β€œ{SEARCH_TERM}”."
            )
            tone = "fail"
        else:
            headline = (
                f"βœ… With contextual priors, β€œ{SEARCH_TERM}” does **not** appear; "
                f"the model stays on {CORRECT_TERM}."
            )
            tone = "success"

        status_html = render_status_box(headline, tone)

        # Truncate at a word boundary so the markdown snippet stays readable.
        snippet = transcript_text.strip()
        if len(snippet) > 1200:
            snippet = snippet[:1200].rsplit(" ", 1)[0] + " …"

        details_lines = [
            from_text,
            "",
            f"**Search term checked**: β€œ{SEARCH_TERM}”",
            "",
            "Below is a snippet of the transcription output (truncated for readability):",
            "",
            "```text",
            snippet or "[Transcription was empty]",
            "```",
        ]
        return status_html, "\n".join(details_lines)


def run_context_biased_transcription(gemini_api_key: str | None) -> Tuple[str, str]:
    """Gradio callback entry point for the contextual transcription demo.

    Validates the API key, then drives the async MCP flow via ``asyncio.run``.
    Always returns a ``(status_html, details_markdown)`` pair, even on failure.
    """
    api_key = (gemini_api_key or "").strip()

    if api_key:
        try:
            return asyncio.run(_run_transcription_flow(api_key))
        except Exception as err:  # pragma: no cover - defensive
            log.warning("Context-biased transcription demo failed: %s", err)
            failure_box = render_status_box(
                f"Context-biased transcription failed: {err}", "fail"
            )
            failure_details = (
                "Something went wrong while talking to the Aileen MCP media tools. "
                "Check the Space logs for more detail and ensure that ffmpeg, yt-dlp and Gemini "
                "are all available."
            )
            return failure_box, failure_details

    # No usable key: point the user back at the setup cell.
    missing_key_box = render_status_box(
        "Please provide a Gemini API key in the setup cell above before running this demo.",
        "fail",
    )
    missing_key_details = (
        "The contextual transcription demo relies on Gemini via the Aileen MCP server. "
        "Set `GEMINI_API_KEY` in the setup cell, run the health check to verify it, "
        "then try this demo again."
    )
    return missing_key_box, missing_key_details


def render_context_biased_transcription_cell(gemini_key_input: gr.Textbox) -> None:
    """Render the notebook-style cell for the contextual transcription demo.

    Args:
        gemini_key_input: The shared API-key textbox from the setup cell,
            wired in as the callback input below.
    """
    # Component creation order inside the cell defines the on-page layout.
    with cell("πŸ§ͺ Context-biased transcription with Gemini"):
        gr.Markdown(
            f"""
### πŸ’πŸ»β€β™€οΈ Demo
This cell reuses the Smart Country Convention talk highlighted in the problem statement. The **Aileen MCP media tools** call Gemini to
transcribe a slice of the audio *while seeing the YouTube description as a prior*.

- The media is fetched via `start_media_retrieval` for the same video as above.
- The YouTube **description** from that retrieval is passed as the `context` argument to `start_media_transcription`.
- Gemini receives both the audio and this textual prior, increasing the chance that it sticks with **{CORRECT_TERM}** instead of
  hallucinating **{SEARCH_TERM}**.

The goal is to observe how much a realistic prior (here: the video description) can nudge the transcription away from dramatic but wrong
tokens and toward the terminology the speaker actually uses.
            """
        )

        # Read-only display of the fixed demo video; the flow always uses
        # DEFAULT_VIDEO_URL regardless of this widget.
        gr.Textbox(
            label="YouTube video URL",
            value=DEFAULT_VIDEO_URL,
            interactive=False,
        )

        run_button = gr.Button("Run context-biased transcription demo", variant="primary")
        # Placeholder status box shown until the first run completes.
        result_panel = gr.HTML(
            value=render_status_box(
                "πŸ‘‰ Click the button to retrieve the media, run a Gemini-backed transcription with priors, and check for β€œNotstaatsvertrag”.",
                "placeholder",
            )
        )
        result_details = gr.Markdown(visible=True)

        # queue=False runs the callback outside Gradio's request queue;
        # presumably intentional for this long-running demo — confirm.
        run_button.click(
            fn=run_context_biased_transcription,
            inputs=[gemini_key_input],
            outputs=[result_panel, result_details],
            queue=False,
        )