ndurner commited on
Commit
9f94d04
·
1 Parent(s): b604263

context-biased transcription

Browse files
demo/app.py CHANGED
@@ -9,6 +9,7 @@ from layout import CELL_CSS, cell
9
  from problem_cell import render_problem_cell
10
  from solution_cell import render_solution_cell
11
  from setup_cell import render_setup_cell
 
12
 
13
 
14
  def render_health_panel(gemini_api_key: str | None = None) -> str:
@@ -77,6 +78,8 @@ Think of this interface as a lightweight Jupyter notebook: instead of code cells
77
  queue=False,
78
  )
79
 
 
 
80
  return demo
81
 
82
 
 
9
  from problem_cell import render_problem_cell
10
  from solution_cell import render_solution_cell
11
  from setup_cell import render_setup_cell
12
+ from context_biased_transcription_cell import render_context_biased_transcription_cell
13
 
14
 
15
  def render_health_panel(gemini_api_key: str | None = None) -> str:
 
78
  queue=False,
79
  )
80
 
81
+ render_context_biased_transcription_cell(gemini_key_box)
82
+
83
  return demo
84
 
85
 
demo/context_biased_transcription_cell.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import os
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Tuple
9
+
10
+ import gradio as gr
11
+
12
+ from layout import cell
13
+ from health import GEMINI_ENV_VAR
14
+ from problem_cell import (
15
+ DEFAULT_VIDEO_URL,
16
+ SEARCH_TERM,
17
+ CORRECT_TERM,
18
+ render_status_box,
19
+ )
20
+
21
+ log = logging.getLogger(__name__)
22
+
23
+ MAX_POLL_ATTEMPTS = 3
24
+ POLL_WAIT_SECONDS = 54
25
+
26
+
27
+ def _unwrap_tool_result(result: object) -> dict:
28
+ """Adapt FastMCP CallToolResult objects into plain dicts."""
29
+ payload = getattr(result, "data", None) or getattr(result, "structured_content", None) or result
30
+ if isinstance(payload, dict):
31
+ return payload
32
+ return {
33
+ "status": "error",
34
+ "is_error": True,
35
+ "detail": f"Unexpected tool result type: {type(payload)!r}",
36
+ }
37
+
38
+
39
+ def _status(payload: dict) -> str:
40
+ return str(payload.get("status") or "").lower()
41
+
42
+
43
+ def _is_done(payload: dict) -> bool:
44
+ return _status(payload) == "done"
45
+
46
+
47
+ def _needs_poll(payload: dict) -> bool:
48
+ return _status(payload) in {"pending", "running"}
49
+
50
+
51
+ async def _poll_until_done(
52
+ client,
53
+ *,
54
+ tool_name: str,
55
+ reference: str,
56
+ wait_seconds: int,
57
+ max_attempts: int = MAX_POLL_ATTEMPTS,
58
+ ) -> dict:
59
+ """Poll the get_* MCP tools until a job finishes or attempts are exhausted."""
60
+ latest: dict = {}
61
+ for attempt in range(max_attempts):
62
+ try:
63
+ latest = _unwrap_tool_result(
64
+ await client.call_tool(
65
+ tool_name,
66
+ {"reference": reference, "wait_seconds": wait_seconds},
67
+ )
68
+ )
69
+ except Exception as exc: # pragma: no cover - defensive
70
+ return {
71
+ "status": "error",
72
+ "is_error": True,
73
+ "detail": f"Polling {tool_name} failed: {exc}",
74
+ }
75
+
76
+ if latest.get("is_error") or _is_done(latest):
77
+ return latest
78
+
79
+ if not _needs_poll(latest):
80
+ return latest
81
+
82
+ if latest:
83
+ latest.setdefault("detail", f"{tool_name} never reported completion; try again later.")
84
+ else:
85
+ latest = {
86
+ "status": "error",
87
+ "is_error": True,
88
+ "detail": f"{tool_name} did not return a response.",
89
+ }
90
+ return latest
91
+
92
+
93
+ async def _run_transcription_flow(gemini_api_key: str) -> Tuple[str, str]:
94
+ """Drive the MCP media tools to run a context-biased transcription demo."""
95
+ try:
96
+ from fastmcp import Client # type: ignore[import-untyped]
97
+ from fastmcp.client.transports import StdioTransport # type: ignore[import-untyped]
98
+ except Exception as exc: # pragma: no cover - defensive
99
+ status = render_status_box(f"fastmcp is not available in this environment: {exc}", "fail")
100
+ return status, ""
101
+
102
+ repo_root = Path(__file__).resolve().parents[1]
103
+ mcp_src = repo_root / "mcp" / "src"
104
+ existing_py_path = os.environ.get("PYTHONPATH", "")
105
+ py_path = f"{mcp_src}{os.pathsep}{existing_py_path}" if existing_py_path else str(mcp_src)
106
+
107
+ env = os.environ.copy()
108
+ env["PYTHONPATH"] = py_path
109
+ env[GEMINI_ENV_VAR] = gemini_api_key
110
+
111
+ server_entry = ["-m", "aileen3_mcp.server"]
112
+
113
+ log.warning(
114
+ "Context-biased transcription demo spawning MCP server: cmd=%s args=%s PYTHONPATH=%s cwd=%s",
115
+ sys.executable,
116
+ server_entry,
117
+ py_path,
118
+ repo_root,
119
+ )
120
+
121
+ transport = StdioTransport(
122
+ command=sys.executable,
123
+ args=server_entry,
124
+ env=env,
125
+ cwd=str(repo_root),
126
+ )
127
+
128
+ from_text = f"Using YouTube URL {DEFAULT_VIDEO_URL} as media source and its description as prior."
129
+
130
+ async with Client(transport) as client:
131
+ retrieval_start = _unwrap_tool_result(
132
+ await client.call_tool(
133
+ "start_media_retrieval",
134
+ {
135
+ "source": DEFAULT_VIDEO_URL,
136
+ "prefer_audio_only": True,
137
+ "wait_seconds": POLL_WAIT_SECONDS,
138
+ },
139
+ )
140
+ )
141
+
142
+ if retrieval_start.get("is_error"):
143
+ detail = retrieval_start.get("detail") or "Media retrieval failed."
144
+ status = render_status_box(detail, "fail")
145
+ return status, from_text
146
+
147
+ reference = retrieval_start.get("reference")
148
+ if not reference:
149
+ status = render_status_box(
150
+ "Media retrieval did not return a reference token.", "fail"
151
+ )
152
+ return status, from_text
153
+
154
+ retrieval = retrieval_start
155
+ if not _is_done(retrieval_start):
156
+ retrieval = await _poll_until_done(
157
+ client,
158
+ tool_name="get_media_retrieval_status",
159
+ reference=reference,
160
+ wait_seconds=POLL_WAIT_SECONDS,
161
+ )
162
+
163
+ if retrieval.get("is_error") or not _is_done(retrieval):
164
+ detail = retrieval.get("detail") or retrieval.get("status") or "Retrieval incomplete."
165
+ status = render_status_box(
166
+ f"Media retrieval did not complete successfully: {detail}", "fail"
167
+ )
168
+ return status, from_text
169
+
170
+ metadata = retrieval.get("metadata") or {}
171
+ description = metadata.get("description") or ""
172
+
173
+ context_text = description.strip()
174
+ if not context_text:
175
+ context_text = (
176
+ "No YouTube description was available for this video; using an empty prior instead."
177
+ )
178
+
179
+ transcription_start = _unwrap_tool_result(
180
+ await client.call_tool(
181
+ "start_media_transcription",
182
+ {
183
+ "reference": reference,
184
+ "context": context_text,
185
+ "prefer_audio_only": True,
186
+ "wait_seconds": POLL_WAIT_SECONDS,
187
+ },
188
+ )
189
+ )
190
+
191
+ if transcription_start.get("is_error"):
192
+ detail = transcription_start.get("detail") or "Transcription job failed to start."
193
+ status = render_status_box(
194
+ f"Transcription job did not complete successfully: {detail}", "fail"
195
+ )
196
+ return status, from_text
197
+
198
+ transcription = transcription_start
199
+ if not _is_done(transcription_start):
200
+ transcription = await _poll_until_done(
201
+ client,
202
+ tool_name="get_media_transcription_result",
203
+ reference=reference,
204
+ wait_seconds=POLL_WAIT_SECONDS,
205
+ )
206
+
207
+ if transcription.get("is_error") or not _is_done(transcription):
208
+ detail = transcription.get("detail") or transcription.get("status") or "Transcription incomplete."
209
+ status = render_status_box(
210
+ f"Transcription job did not complete successfully: {detail}", "fail"
211
+ )
212
+ return status, from_text
213
+
214
+ transcript_text = transcription.get("transcription") or ""
215
+ normalized = transcript_text.lower()
216
+ found_term = SEARCH_TERM.lower() in normalized
217
+
218
+ if found_term:
219
+ headline = (
220
+ f"🚨 Even with contextual priors, the transcript still contains “{SEARCH_TERM}”."
221
+ )
222
+ tone = "fail"
223
+ else:
224
+ headline = (
225
+ f"✅ With contextual priors, “{SEARCH_TERM}” does **not** appear; "
226
+ f"the model stays on {CORRECT_TERM}."
227
+ )
228
+ tone = "success"
229
+
230
+ status_html = render_status_box(headline, tone)
231
+
232
+ snippet = transcript_text.strip()
233
+ if len(snippet) > 1200:
234
+ snippet = snippet[:1200].rsplit(" ", 1)[0] + " …"
235
+
236
+ details_lines = [
237
+ from_text,
238
+ "",
239
+ f"**Search term checked**: “{SEARCH_TERM}”",
240
+ "",
241
+ "Below is a snippet of the transcription output (truncated for readability):",
242
+ "",
243
+ "```text",
244
+ snippet or "[Transcription was empty]",
245
+ "```",
246
+ ]
247
+ return status_html, "\n".join(details_lines)
248
+
249
+
250
+ def run_context_biased_transcription(gemini_api_key: str | None) -> Tuple[str, str]:
251
+ """Gradio callback entry point for the contextual transcription demo."""
252
+ key = (gemini_api_key or "").strip()
253
+ if not key:
254
+ status = render_status_box(
255
+ "Please provide a Gemini API key in the setup cell above before running this demo.",
256
+ "fail",
257
+ )
258
+ details = (
259
+ "The contextual transcription demo relies on Gemini via the Aileen MCP server. "
260
+ "Set `GEMINI_API_KEY` in the setup cell, run the health check to verify it, "
261
+ "then try this demo again."
262
+ )
263
+ return status, details
264
+
265
+ try:
266
+ return asyncio.run(_run_transcription_flow(key))
267
+ except Exception as exc: # pragma: no cover - defensive
268
+ log.warning("Context-biased transcription demo failed: %s", exc)
269
+ status = render_status_box(f"Context-biased transcription failed: {exc}", "fail")
270
+ details = (
271
+ "Something went wrong while talking to the Aileen MCP media tools. "
272
+ "Check the Space logs for more detail and ensure that ffmpeg, yt-dlp and Gemini "
273
+ "are all available."
274
+ )
275
+ return status, details
276
+
277
+
278
+ def render_context_biased_transcription_cell(gemini_key_input: gr.Textbox) -> None:
279
+ """Render the notebook-style cell for the contextual transcription demo."""
280
+ with cell("🧪 Context-biased transcription with Gemini"):
281
+ gr.Markdown(
282
+ f"""
283
+ ### 💁🏻‍♀️ Demo
284
+ This cell reuses the Smart Country Convention talk highlighted in the problem statement. The **Aileen MCP media tools** call Gemini to
285
+ transcribe a slice of the audio *while seeing the YouTube description as a prior*.
286
+
287
+ - The media is fetched via `start_media_retrieval` for the same video as above.
288
+ - The YouTube **description** from that retrieval is passed as the `context` argument to `start_media_transcription`.
289
+ - Gemini receives both the audio and this textual prior, increasing the chance that it sticks with **{CORRECT_TERM}** instead of
290
+ hallucinating **{SEARCH_TERM}**.
291
+
292
+ The goal is to observe how much a realistic prior (here: the video description) can nudge the transcription away from dramatic but wrong
293
+ tokens and toward the terminology the speaker actually uses.
294
+ """
295
+ )
296
+
297
+ gr.Textbox(
298
+ label="YouTube video URL",
299
+ value=DEFAULT_VIDEO_URL,
300
+ interactive=False,
301
+ )
302
+
303
+ run_button = gr.Button("Run context-biased transcription demo", variant="primary")
304
+ result_panel = gr.HTML(
305
+ value=render_status_box(
306
+ "👉 Click the button to retrieve the media, run a Gemini-backed transcription with priors, and check for “Notstaatsvertrag”.",
307
+ "placeholder",
308
+ )
309
+ )
310
+ result_details = gr.Markdown(visible=True)
311
+
312
+ run_button.click(
313
+ fn=run_context_biased_transcription,
314
+ inputs=[gemini_key_input],
315
+ outputs=[result_panel, result_details],
316
+ queue=False,
317
+ )
mcp/src/aileen3_mcp/media_tools.py CHANGED
@@ -988,7 +988,7 @@ def _analysis_flow(metadata: dict, priors_obj: Priors | dict) -> dict:
988
  # ---------------------------------------------------------------------------------------------------------------------
989
 
990
 
991
- def _transcription_flow(metadata: dict, context: str) -> str:
992
  reference = metadata["reference"]
993
  video_path = Path(metadata["download_path"])
994
  audio_path = _ensure_audio_sidecar(video_path, reference)
@@ -998,7 +998,9 @@ def _transcription_flow(metadata: dict, context: str) -> str:
998
  priors.media_context = _media_context_from_metadata(metadata)
999
  priors_text = priors.as_prompt_text()
1000
 
1001
- slides = _load_or_extract_slides(metadata)
 
 
1002
 
1003
  client = _build_gemini_client()
1004
  uploaded_slides = _upload_slides_to_gemini(client, slides, reference)
@@ -1451,6 +1453,7 @@ def register_media_tools(app: FastMCP) -> None:
1451
  ctx: Context,
1452
  reference: str,
1453
  context: str = "",
 
1454
  wait_seconds: int = 55,
1455
  ) -> dict:
1456
  """
@@ -1463,6 +1466,8 @@ def register_media_tools(app: FastMCP) -> None:
1463
  Parameters:
1464
  - reference: Token from `start_media_retrieval` pointing at the downloaded media blob.
1465
  - context: Free-form grounding text that improves names, jargon, or expected topics.
 
 
1466
  - wait_seconds: Time to wait for the background job. Set to 0 to always return immediately.
1467
 
1468
  Note:
@@ -1480,6 +1485,8 @@ def register_media_tools(app: FastMCP) -> None:
1480
 
1481
  if context is not None and not isinstance(context, str):
1482
  return _error("context must be a string", reference)
 
 
1483
 
1484
  context_text = str(context or "")
1485
  return await _start_media_processing_job(
@@ -1489,7 +1496,7 @@ def register_media_tools(app: FastMCP) -> None:
1489
  result_field="transcription",
1490
  cache_path_fn=_transcription_json_path,
1491
  flow_callable=_transcription_flow,
1492
- flow_args=(metadata, context_text),
1493
  )
1494
 
1495
  @app.tool()
 
988
  # ---------------------------------------------------------------------------------------------------------------------
989
 
990
 
991
+ def _transcription_flow(metadata: dict, context: str, prefer_audio_only: bool) -> str:
992
  reference = metadata["reference"]
993
  video_path = Path(metadata["download_path"])
994
  audio_path = _ensure_audio_sidecar(video_path, reference)
 
998
  priors.media_context = _media_context_from_metadata(metadata)
999
  priors_text = priors.as_prompt_text()
1000
 
1001
+ slides: list[dict] = []
1002
+ if not prefer_audio_only:
1003
+ slides = _load_or_extract_slides(metadata)
1004
 
1005
  client = _build_gemini_client()
1006
  uploaded_slides = _upload_slides_to_gemini(client, slides, reference)
 
1453
  ctx: Context,
1454
  reference: str,
1455
  context: str = "",
1456
+ prefer_audio_only: bool = False,
1457
  wait_seconds: int = 55,
1458
  ) -> dict:
1459
  """
 
1466
  Parameters:
1467
  - reference: Token from `start_media_retrieval` pointing at the downloaded media blob.
1468
  - context: Free-form grounding text that improves names, jargon, or expected topics.
1469
+ - prefer_audio_only: If true, run transcription using only the audio track and ignore visual slide context.
1470
+ This avoids slide extraction and upload for cheaper, audio-only runs. Defaults to False.
1471
  - wait_seconds: Time to wait for the background job. Set to 0 to always return immediately.
1472
 
1473
  Note:
 
1485
 
1486
  if context is not None and not isinstance(context, str):
1487
  return _error("context must be a string", reference)
1488
+ if not isinstance(prefer_audio_only, bool):
1489
+ return _error("prefer_audio_only must be a boolean", reference)
1490
 
1491
  context_text = str(context or "")
1492
  return await _start_media_processing_job(
 
1496
  result_field="transcription",
1497
  cache_path_fn=_transcription_json_path,
1498
  flow_callable=_transcription_flow,
1499
+ flow_args=(metadata, context_text, prefer_audio_only),
1500
  )
1501
 
1502
  @app.tool()