Ratnesh-dev commited on
Commit
cf57473
·
1 Parent(s): fdfe940

Add speaker inference and transcript cleanup using OpenAI GPT-5

Browse files
cleanup_transcript_openai.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+
9
+ def _log(message: str) -> None:
10
+ print(f"[cleanup] {message}", flush=True)
11
+
12
+
13
+ def _load_json(path: Path) -> dict[str, Any]:
14
+ with path.open("r", encoding="utf-8") as f:
15
+ return json.load(f)
16
+
17
+
18
+ def _save_json(path: Path, payload: Any) -> None:
19
+ path.parent.mkdir(parents=True, exist_ok=True)
20
+ with path.open("w", encoding="utf-8") as f:
21
+ json.dump(payload, f, indent=2, ensure_ascii=False)
22
+
23
+
24
+ def _extract_json_object(text: str) -> dict[str, Any]:
25
+ text = text.strip()
26
+ if not text:
27
+ raise ValueError("Model returned empty text.")
28
+
29
+ try:
30
+ parsed = json.loads(text)
31
+ if isinstance(parsed, dict):
32
+ return parsed
33
+ except Exception:
34
+ pass
35
+
36
+ start = text.find("{")
37
+ while start >= 0:
38
+ depth = 0
39
+ for idx in range(start, len(text)):
40
+ ch = text[idx]
41
+ if ch == "{":
42
+ depth += 1
43
+ elif ch == "}":
44
+ depth -= 1
45
+ if depth == 0:
46
+ candidate = text[start : idx + 1]
47
+ try:
48
+ parsed = json.loads(candidate)
49
+ if isinstance(parsed, dict):
50
+ return parsed
51
+ except Exception:
52
+ break
53
+ start = text.find("{", start + 1)
54
+
55
+ raise ValueError("Could not parse a JSON object from model output.")
56
+
57
+
58
+ def _response_to_dict(response: Any) -> dict[str, Any]:
59
+ if hasattr(response, "model_dump") and callable(response.model_dump):
60
+ return response.model_dump()
61
+ if hasattr(response, "to_dict") and callable(response.to_dict):
62
+ return response.to_dict()
63
+ return {"raw_response": str(response)}
64
+
65
+
66
def _response_text(response: Any) -> str:
    """Return the best available text from *response*, or "" if none is found.

    Checks the object's ``output_text`` attribute first, then the
    ``output_text``/``text`` keys of its dict form.
    """
    direct = getattr(response, "output_text", None)
    if isinstance(direct, str) and direct.strip():
        return direct

    data = _response_to_dict(response)
    if isinstance(data, dict):
        for key in ("output_text", "text"):
            candidate = data.get(key)
            if isinstance(candidate, str) and candidate.strip():
                return candidate
    return ""
78
+
79
+
80
+ def _usage_from_response_dict(payload: dict[str, Any]) -> dict[str, int | None]:
81
+ usage = payload.get("usage")
82
+ if not isinstance(usage, dict):
83
+ return {
84
+ "input_tokens": None,
85
+ "output_tokens": None,
86
+ "total_tokens": None,
87
+ "cached_input_tokens": None,
88
+ "reasoning_tokens": None,
89
+ }
90
+
91
+ input_details = usage.get("input_tokens_details", {})
92
+ output_details = usage.get("output_tokens_details", {})
93
+ return {
94
+ "input_tokens": usage.get("input_tokens"),
95
+ "output_tokens": usage.get("output_tokens"),
96
+ "total_tokens": usage.get("total_tokens"),
97
+ "cached_input_tokens": input_details.get("cached_tokens") if isinstance(input_details, dict) else None,
98
+ "reasoning_tokens": output_details.get("reasoning_tokens") if isinstance(output_details, dict) else None,
99
+ }
100
+
101
+
102
+ def _sum_usage(
103
+ first: dict[str, int | None],
104
+ second: dict[str, int | None],
105
+ ) -> dict[str, int | None]:
106
+ def _sum_key(key: str) -> int | None:
107
+ a = first.get(key)
108
+ b = second.get(key)
109
+ if isinstance(a, int) and isinstance(b, int):
110
+ return a + b
111
+ if isinstance(a, int):
112
+ return a
113
+ if isinstance(b, int):
114
+ return b
115
+ return None
116
+
117
+ total = _sum_key("total_tokens")
118
+ input_tokens = _sum_key("input_tokens")
119
+ output_tokens = _sum_key("output_tokens")
120
+ if total is None and isinstance(input_tokens, int) and isinstance(output_tokens, int):
121
+ total = input_tokens + output_tokens
122
+
123
+ return {
124
+ "input_tokens": input_tokens,
125
+ "output_tokens": output_tokens,
126
+ "total_tokens": total,
127
+ "cached_input_tokens": _sum_key("cached_input_tokens"),
128
+ "reasoning_tokens": _sum_key("reasoning_tokens"),
129
+ }
130
+
131
+
132
+ def _parse_executive_names(
133
+ *,
134
+ names_csv: str | None,
135
+ ) -> list[str]:
136
+ out: list[str] = []
137
+
138
+ if names_csv:
139
+ for item in names_csv.split(","):
140
+ name = item.strip().strip('"').strip("'")
141
+ if name:
142
+ out.append(name)
143
+
144
+ # Preserve order while removing duplicates.
145
+ seen = set()
146
+ deduped: list[str] = []
147
+ for name in out:
148
+ key = name.lower()
149
+ if key in seen:
150
+ continue
151
+ seen.add(key)
152
+ deduped.append(name)
153
+ return deduped
154
+
155
+
156
+ def _build_intro_payload(turns: list[dict[str, Any]], intro_turn_limit: int) -> list[dict[str, Any]]:
157
+ sampled = turns[: max(1, intro_turn_limit)]
158
+ payload: list[dict[str, Any]] = []
159
+ for idx, turn in enumerate(sampled):
160
+ payload.append(
161
+ {
162
+ "turn_index": idx,
163
+ "speaker": turn.get("speaker"),
164
+ "start": turn.get("start"),
165
+ "end": turn.get("end"),
166
+ "text": turn.get("text"),
167
+ }
168
+ )
169
+ return payload
170
+
171
+
172
+ def _extract_qna_announcements(turns: list[dict[str, Any]], max_items: int = 200) -> list[dict[str, Any]]:
173
+ announcements: list[dict[str, Any]] = []
174
+ for idx, turn in enumerate(turns):
175
+ text = str(turn.get("text", "")).strip()
176
+ if not text:
177
+ continue
178
+ lowered = text.lower()
179
+ if "line of" in lowered and ("please go ahead" in lowered or "question" in lowered):
180
+ announcements.append(
181
+ {
182
+ "turn_index": idx,
183
+ "speaker": turn.get("speaker"),
184
+ "text": text,
185
+ }
186
+ )
187
+ if len(announcements) >= max_items:
188
+ break
189
+ return announcements
190
+
191
+
192
+ def _extract_response_id(response: Any, response_dict: dict[str, Any]) -> str | None:
193
+ rid = getattr(response, "id", None)
194
+ if isinstance(rid, str) and rid:
195
+ return rid
196
+ candidate = response_dict.get("id")
197
+ if isinstance(candidate, str) and candidate:
198
+ return candidate
199
+ return None
200
+
201
+
202
def run_cleanup_pipeline(
    *,
    input_file: Path,
    api_key: str,
    model: str,
    output_dir: Path,
    intro_turn_limit: int,
    executive_names_csv: str | None,
) -> dict[str, Any]:
    """Run the two-call OpenAI transcript-cleanup pipeline over a merged transcript.

    Call 1 infers a speaker-label -> name mapping from intro turns and Q&A
    announcements; call 2 produces a cleaned, speaker-attributed transcript,
    optionally chaining onto call 1 via ``previous_response_id``. All artifacts
    (mapping, raw responses, cleaned transcript JSON/text, run summary) are
    written to a timestamped subdirectory of *output_dir*.

    Args:
        input_file: Path to a merged transcript JSON containing a `turns` list.
        api_key: OpenAI API key used to construct the client.
        model: OpenAI model ID for both calls.
        output_dir: Root directory for run artifacts.
        intro_turn_limit: Number of initial turns sent for intro inference.
        executive_names_csv: Optional comma-separated canonical executive names.

    Returns:
        The run summary dict (also saved as `run_summary.json`).

    Raises:
        RuntimeError: If the `openai` package is not installed.
        ValueError: If the input JSON lacks a non-empty `turns` list, or a
            model response contains no parseable JSON object.
    """
    # Import lazily so the module can be imported without the openai package.
    try:
        from openai import OpenAI
    except ImportError as exc:
        raise RuntimeError(
            "Missing dependency: openai. Install with `pip install openai`."
        ) from exc

    _log("Loading transcript JSON...")
    transcript_json = _load_json(input_file)
    turns = transcript_json.get("turns")
    if not isinstance(turns, list) or not turns:
        raise ValueError("Input JSON must contain a non-empty `turns` list.")

    _log("Parsing executive names input...")
    executive_names = _parse_executive_names(
        names_csv=executive_names_csv,
    )
    intro_turns_payload = _build_intro_payload(turns, intro_turn_limit=intro_turn_limit)
    qna_announcements = _extract_qna_announcements(turns)

    # Each run gets its own timestamped directory so reruns never collide.
    run_dir = output_dir / datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir.mkdir(parents=True, exist_ok=True)
    executive_names_out_path = run_dir / "executive_names.json"
    _save_json(executive_names_out_path, {"names": executive_names})
    _log(f"Run directory: {run_dir}")
    _log(f"Saved executive names file: {executive_names_out_path}")

    client = OpenAI(api_key=api_key)

    # --- Call 1: infer SPEAKER_XX -> name mapping from transcript context. ---
    speaker_map_system = (
        "You are a transcript entity-resolution assistant. "
        "Return strict JSON only, no markdown. "
        "Infer speaker identities from transcript context."
    )
    speaker_map_user = json.dumps(
        {
            "task": "Infer speaker mapping from transcript context (intro + Q&A announcements).",
            "rules": [
                "Use explicit or near-explicit intro context ('I now hand over to ...', self-intros, operator intros).",
                "Label any conference host/queue-management voice as exactly 'Operator' when they do call control.",
                "Do not map Operator to an executive name.",
                "Do not guess beyond evidence.",
                "Prefer names from `executive_names` when they match context.",
                "In Q&A, infer non-executive participant names from operator announcements such as 'line of <name> from <firm>', even if absent in executive list.",
                "Keep unknown speakers as null names if evidence is weak.",
            ],
            "output_schema": {
                "speaker_mapping": [
                    {
                        "speaker_label": "SPEAKER_XX",
                        "inferred_name": "string or null",
                        "confidence": "number 0..1",
                        "evidence_turn_indexes": ["int"],
                        "reason": "short string",
                    }
                ],
                "notes": ["string"],
            },
            "executive_names": executive_names,
            "intro_turns": intro_turns_payload,
            "qna_announcements": qna_announcements,
            "transcript_turns": turns,
        },
        ensure_ascii=False,
    )

    _log("OpenAI call 1/2: inferring speaker mapping...")
    speaker_map_response = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": speaker_map_system},
            {"role": "user", "content": speaker_map_user},
        ],
    )
    speaker_map_raw = _response_to_dict(speaker_map_response)
    # The id is needed for optional previous_response_id chaining in call 2.
    first_response_id = _extract_response_id(speaker_map_response, speaker_map_raw)
    speaker_map_usage = _usage_from_response_dict(speaker_map_raw)
    speaker_map_text = _response_text(speaker_map_response)
    speaker_map_json = _extract_json_object(speaker_map_text)

    speaker_map_path = run_dir / "speaker_mapping.json"
    speaker_map_raw_path = run_dir / "speaker_mapping_raw_response.json"
    _save_json(speaker_map_path, speaker_map_json)
    _save_json(speaker_map_raw_path, speaker_map_raw)

    # --- Call 2: clean the transcript using the inferred mapping. ---
    cleanup_system = (
        "You are a transcript cleanup and diarization refinement assistant. "
        "Return strict JSON only, no markdown."
    )
    cleanup_payload_base = {
        "task": "Clean transcript and produce final speaker-attributed turns.",
        "rules": [
            "Correct likely misspellings and improve punctuation/casing.",
            "Remove false starts and repeated filler where safe, but keep meaning.",
            "Standardize executive names to the canonical forms in `executive_names` where applicable.",
            "Use `speaker_mapping` from call 1, but keep unknown labels if unsupported.",
            "Label the conference host/control speaker as exactly 'Operator' when they are handling queue/instructions.",
            "In Q&A, infer names not present in `executive_names` from context and operator announcements.",
            "If a very short mid-sentence speaker switch is likely diarization noise, merge/reassign using sentence continuity.",
            "Preserve turn order and timing progression.",
            "Output speaker labels as inferred names when confidence is sufficient; otherwise keep SPEAKER_XX.",
            "Do not invent facts not present in transcript context.",
        ],
        "output_schema": {
            "speaker_mapping_final": [
                {
                    "source_label": "SPEAKER_XX",
                    "final_label": "Name or SPEAKER_XX",
                    "confidence": "number 0..1",
                    "reason": "short string",
                }
            ],
            "turns": [
                {
                    "speaker": "Name or SPEAKER_XX",
                    "start": "float",
                    "end": "float",
                    "text": "cleaned text",
                }
            ],
            "summary": {
                "turn_count": "int",
                "speaker_count": "int",
                "notes": ["string"],
            },
        },
        "executive_names": executive_names,
        "speaker_mapping": speaker_map_json.get("speaker_mapping", []),
    }
    # Two payload variants: one resends the full transcript, the other relies
    # on server-side context from call 1 (used with previous_response_id).
    cleanup_payload_with_turns = dict(cleanup_payload_base)
    cleanup_payload_with_turns["transcript_turns"] = turns
    cleanup_payload_context_only = dict(cleanup_payload_base)
    cleanup_payload_context_only["context_hint"] = (
        "Use the transcript context from the previous response. "
        "Do not request retransmission."
    )

    _log("OpenAI call 2/2: cleaning transcript and refining speaker labels...")
    cleanup_response = None
    used_context_chaining = False
    if first_response_id:
        _log("Using previous_response_id context chaining for call 2.")
        try:
            cleanup_response = client.responses.create(
                model=model,
                previous_response_id=first_response_id,
                input=[
                    {"role": "system", "content": cleanup_system},
                    {"role": "user", "content": json.dumps(cleanup_payload_context_only, ensure_ascii=False)},
                ],
            )
            used_context_chaining = True
        except TypeError:
            # Older SDKs do not accept previous_response_id at all.
            _log("Client does not support previous_response_id; falling back to explicit transcript payload.")
        except Exception as exc:
            # Any API-side failure: retry below with the full transcript inline.
            _log(f"Context-chained call failed ({exc}); falling back to explicit transcript payload.")

    if cleanup_response is None:
        # Fallback path: send the entire transcript in the request payload.
        cleanup_response = client.responses.create(
            model=model,
            input=[
                {"role": "system", "content": cleanup_system},
                {"role": "user", "content": json.dumps(cleanup_payload_with_turns, ensure_ascii=False)},
            ],
        )
    cleanup_raw = _response_to_dict(cleanup_response)
    cleanup_usage = _usage_from_response_dict(cleanup_raw)
    cleanup_text = _response_text(cleanup_response)
    cleaned_json = _extract_json_object(cleanup_text)
    token_usage = {
        "speaker_mapping_call": speaker_map_usage,
        "cleanup_call": cleanup_usage,
        "combined": _sum_usage(speaker_map_usage, cleanup_usage),
    }

    # Annotate the model's cleaned output with provenance and usage metadata.
    cleaned_json["inputs"] = {
        "source_file": str(input_file),
        "speaker_mapping_file": str(speaker_map_path),
        "context_chaining_used_for_cleanup": used_context_chaining,
    }
    cleaned_json["openai_token_usage"] = token_usage

    cleaned_path = run_dir / "cleaned_transcript.json"
    cleaned_raw_path = run_dir / "cleanup_raw_response.json"
    cleaned_text_path = run_dir / "cleaned_transcript.txt"

    _save_json(cleaned_path, cleaned_json)
    _save_json(cleaned_raw_path, cleanup_raw)

    # Render a plain-text "Speaker: text" transcript from the cleaned turns.
    output_turns = cleaned_json.get("turns", [])
    lines: list[str] = []
    if isinstance(output_turns, list):
        for turn in output_turns:
            if not isinstance(turn, dict):
                continue
            speaker = str(turn.get("speaker", "SPEAKER_XX"))
            text = str(turn.get("text", "")).strip()
            if text:
                lines.append(f"{speaker}: {text}")
    cleaned_text_path.write_text("\n".join(lines), encoding="utf-8")
    _log("Saved cleaned transcript outputs.")

    run_summary = {
        "run_dir": str(run_dir),
        "input_file": str(input_file),
        "model": model,
        "speaker_mapping_file": str(speaker_map_path),
        "speaker_mapping_raw_file": str(speaker_map_raw_path),
        "cleaned_transcript_file": str(cleaned_path),
        "cleaned_transcript_raw_file": str(cleaned_raw_path),
        "cleaned_text_file": str(cleaned_text_path),
        "intro_turn_limit": intro_turn_limit,
        "executive_names_file": str(executive_names_out_path),
        "context_chaining_used_for_cleanup": used_context_chaining,
        "openai_token_usage": token_usage,
    }
    _save_json(run_dir / "run_summary.json", run_summary)
    _log("Completed.")
    return run_summary
430
+
431
+
432
def main() -> None:
    """CLI entry point: parse arguments and run the two-call cleanup pipeline."""
    parser = argparse.ArgumentParser(
        description=(
            "Run two OpenAI calls over a merged transcript JSON: "
            "(1) speaker mapping inference, (2) cleaned/re-labeled transcript."
        )
    )
    parser.add_argument("--input-file", required=True, help="Path to merged transcript JSON.")
    parser.add_argument("--api-key", required=True, help="OpenAI API key.")
    parser.add_argument("--model", default="gpt-5", help="OpenAI model ID (default: gpt-5).")
    parser.add_argument(
        "--intro-turn-limit",
        type=int,
        default=80,
        help="Number of initial turns to use for speaker-introduction inference.",
    )
    parser.add_argument(
        "--executive-names-csv",
        default=None,
        help='Comma-separated executive names, e.g. "Name A,Name B,Name C".',
    )
    parser.add_argument(
        "--output-dir",
        default="benchmark_outputs/cleanup_openai",
        help="Directory to store outputs.",
    )

    cli = parser.parse_args()
    report = run_cleanup_pipeline(
        input_file=Path(cli.input_file),
        api_key=cli.api_key,
        model=cli.model,
        output_dir=Path(cli.output_dir),
        intro_turn_limit=cli.intro_turn_limit,
        executive_names_csv=cli.executive_names_csv,
    )
    # Echo the run summary so callers can capture artifact paths from stdout.
    print(json.dumps(report, indent=2))
469
+
470
+
471
# Script entry-point guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()
merge_parakeet_pyannote.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ from dataclasses import dataclass
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
@dataclass
class Word:
    """A single ASR word with its absolute timing in seconds."""

    # Recognized word text.
    text: str
    # Start time in seconds from the beginning of the audio.
    start: float
    # End time in seconds; extraction code filters out end < start entries.
    end: float
15
+
16
+
17
@dataclass
class Segment:
    """A diarization segment: one speaker label active over [start, end] seconds."""

    # Diarization speaker label (e.g. "SPEAKER_00"); "SPEAKER_XX" when missing.
    speaker: str
    # Segment start time in seconds.
    start: float
    # Segment end time in seconds.
    end: float
22
+
23
+
24
+ def _load_json(path: str) -> dict[str, Any]:
25
+ with open(path, "r", encoding="utf-8") as f:
26
+ return json.load(f)
27
+
28
+
29
+ def _unwrap_result(payload: dict[str, Any], model_hint: str | None = None) -> dict[str, Any]:
30
+ if isinstance(payload.get("results"), list):
31
+ results = payload["results"]
32
+ preferred = None
33
+ if model_hint:
34
+ for item in results:
35
+ if str(item.get("model")) == model_hint and isinstance(item.get("result"), dict):
36
+ preferred = item["result"]
37
+ break
38
+ if preferred is not None:
39
+ return preferred
40
+ for item in results:
41
+ if item.get("status") == "ok" and isinstance(item.get("result"), dict):
42
+ return item["result"]
43
+ if isinstance(payload.get("result"), dict):
44
+ return payload["result"]
45
+ return payload
46
+
47
+
48
+ def _as_float(value: Any) -> float | None:
49
+ try:
50
+ out = float(value)
51
+ except Exception:
52
+ return None
53
+ if out != out: # NaN
54
+ return None
55
+ return out
56
+
57
+
58
def _extract_parakeet_words(payload: dict[str, Any]) -> list[Word]:
    """Pull word-level timestamps out of a Parakeet result payload.

    Walks raw_output -> output -> timestamp -> word, skipping entries with
    empty text, unparseable times, or end < start, and returns the words
    sorted by (start, end).
    """
    result = _unwrap_result(payload, model_hint="NVIDIA Parakeet v3")

    # Descend the nested output shape defensively; any non-dict level
    # collapses to the empty default.
    node: Any = result.get("raw_output", {}) if isinstance(result, dict) else {}
    for key, default in (("output", {}), ("timestamp", {}), ("word", [])):
        node = node.get(key, default) if isinstance(node, dict) else default

    words: list[Word] = []
    if isinstance(node, list):
        for entry in node:
            if not isinstance(entry, dict):
                continue
            text = str(entry.get("word", "")).strip()
            start = _as_float(entry.get("start"))
            end = _as_float(entry.get("end"))
            if text and start is not None and end is not None and end >= start:
                words.append(Word(text=text, start=start, end=end))

    words.sort(key=lambda w: (w.start, w.end))
    return words
80
+
81
+
82
def _extract_pyannote_segments(payload: dict[str, Any], diarization_key: str) -> list[Segment]:
    """Pull diarization segments for *diarization_key* from a Pyannote payload.

    Prefers the stitched output shape (raw_output["stitched"][key]); falls
    back to the direct chunk shape (raw_output[key]["segments"]). Entries with
    invalid times are skipped; a missing speaker label becomes "SPEAKER_XX".
    Segments are returned sorted by (start, end).
    """
    result = _unwrap_result(payload, model_hint="pyannote/speaker-diarization-community-1")
    raw_output = result.get("raw_output", {}) if isinstance(result, dict) else {}

    stitched = raw_output.get("stitched", {}) if isinstance(raw_output, dict) else {}
    seg_items: Any = stitched.get(diarization_key, []) if isinstance(stitched, dict) else []

    if not seg_items and isinstance(raw_output, dict):
        # Fallback for the direct (non-stitched) chunk output shape.
        direct = raw_output.get(diarization_key)
        seg_items = direct.get("segments", []) if isinstance(direct, dict) else []

    segments: list[Segment] = []
    for entry in seg_items if isinstance(seg_items, list) else []:
        if not isinstance(entry, dict):
            continue
        start = _as_float(entry.get("start"))
        end = _as_float(entry.get("end"))
        if start is None or end is None or end < start:
            continue
        label = str(entry.get("speaker", "")).strip() or "SPEAKER_XX"
        segments.append(Segment(speaker=label, start=start, end=end))

    segments.sort(key=lambda s: (s.start, s.end))
    return segments
114
+
115
+
116
+ def _segment_distance_to_time(seg: Segment, t: float) -> float:
117
+ if seg.start <= t <= seg.end:
118
+ return 0.0
119
+ if t < seg.start:
120
+ return seg.start - t
121
+ return t - seg.end
122
+
123
+
124
def _assign_words_to_segments(words: list[Word], segments: list[Segment]) -> list[list[Word]]:
    """Assign each word to the closest diarization segment.

    Maintains a forward-moving cursor over the time-sorted segments; for each
    word midpoint the cursor segment and its immediate neighbours compete, and
    the one with the smallest time distance wins. Returns one word bucket per
    segment, index-aligned with *segments*.
    """
    buckets: list[list[Word]] = [[] for _ in segments]
    if not words or not segments:
        return buckets

    last = len(segments) - 1
    cursor = 0
    for word in words:
        midpoint = (word.start + word.end) / 2.0

        # Advance the cursor past segments that ended before this midpoint.
        while cursor < last and segments[cursor].end <= midpoint:
            cursor += 1

        neighbourhood = [i for i in (cursor - 1, cursor, cursor + 1) if 0 <= i <= last]
        winner = min(neighbourhood, key=lambda i: _segment_distance_to_time(segments[i], midpoint))
        buckets[winner].append(word)

    return buckets
148
+
149
+
150
+ def _join_words(words: list[Word]) -> str:
151
+ if not words:
152
+ return ""
153
+ out = words[0].text
154
+ for w in words[1:]:
155
+ if w.text and w.text[0] in ",.!?;:)]}":
156
+ out += w.text
157
+ else:
158
+ out += " " + w.text
159
+ return out.strip()
160
+
161
+
162
def merge_parakeet_with_pyannote(
    parakeet_json: dict[str, Any],
    pyannote_json: dict[str, Any],
    diarization_key: str = "exclusive_speaker_diarization",
) -> dict[str, Any]:
    """Merge Parakeet word timestamps with Pyannote speaker segments into turns.

    Pyannote segments lead speaker attribution: each word is assigned to its
    closest segment, per-segment words are joined into text, and consecutive
    turns with the same speaker are merged.

    Args:
        parakeet_json: Raw Parakeet output payload (word-level timestamps).
        pyannote_json: Raw Pyannote output payload (diarization segments).
        diarization_key: Which diarization variant to read from the payload.

    Returns:
        Dict with a `summary` (counts), speaker-attributed `turns`, and a
        plain-text `transcript_text` rendering.

    Raises:
        ValueError: If no words or no segments can be extracted.
    """
    words = _extract_parakeet_words(parakeet_json)
    segments = _extract_pyannote_segments(pyannote_json, diarization_key=diarization_key)

    if not words:
        raise ValueError("No Parakeet word-level timestamps found.")
    if not segments:
        raise ValueError(f"No Pyannote segments found for key '{diarization_key}'.")

    words_by_segment = _assign_words_to_segments(words, segments)

    turns: list[dict[str, Any]] = []
    for seg, seg_words in zip(segments, words_by_segment):
        # Segments that attracted no words produce no turn.
        if not seg_words:
            continue
        text = _join_words(seg_words)
        if not text:
            continue

        # Widen the turn to cover both the segment and its words' extremes.
        first_word_start = seg_words[0].start
        last_word_end = seg_words[-1].end
        start = min(seg.start, first_word_start)
        end = max(seg.end, last_word_end)

        if turns and turns[-1]["speaker"] == seg.speaker:
            # Same speaker as the previous turn: extend it instead of
            # appending a new one (keeps one turn per continuous speaker).
            turns[-1]["end"] = round(end, 4)
            turns[-1]["text"] = (turns[-1]["text"] + " " + text).strip()
        else:
            turns.append(
                {
                    "speaker": seg.speaker,
                    "start": round(start, 4),
                    "end": round(end, 4),
                    "text": text,
                }
            )

    assigned_word_count = sum(len(seg_words) for seg_words in words_by_segment)
    transcript_lines = [f'{t["speaker"]}: {t["text"]}' for t in turns]

    return {
        "summary": {
            "diarization_key_used": diarization_key,
            "parakeet_word_count": len(words),
            "pyannote_segment_count": len(segments),
            "turn_count": len(turns),
            "assigned_word_count": assigned_word_count,
            # NOTE(review): the current assignment places every word in some
            # bucket, so this is expected to be 0 — kept as a sanity counter.
            "unassigned_word_count": len(words) - assigned_word_count,
        },
        "turns": turns,
        "transcript_text": "\n".join(transcript_lines),
    }
218
+
219
+
220
def main() -> None:
    """CLI entry point: merge one Parakeet and one Pyannote JSON, saving both variants."""
    parser = argparse.ArgumentParser(
        description=(
            "Merge Parakeet word-level transcript with Pyannote exclusive diarization. "
            "Pyannote segments are the leader for speaker attribution."
        )
    )
    parser.add_argument("--parakeet", required=True, help="Path to Parakeet JSON output file")
    parser.add_argument("--pyannote", required=True, help="Path to Pyannote JSON output file")
    parser.add_argument(
        "--output",
        default=None,
        help="Output JSON path. Defaults to benchmark_outputs/merged_transcript_<timestamp>.json",
    )
    cli = parser.parse_args()

    parakeet_data = _load_json(cli.parakeet)
    pyannote_data = _load_json(cli.pyannote)
    provenance = {"parakeet_file": str(cli.parakeet), "pyannote_file": str(cli.pyannote)}

    # Produce both diarization variants from the same pair of inputs.
    merged: dict[str, dict[str, Any]] = {}
    for key in ("exclusive_speaker_diarization", "speaker_diarization"):
        merged[key] = merge_parakeet_with_pyannote(
            parakeet_data,
            pyannote_data,
            diarization_key=key,
        )
        merged[key]["inputs"] = dict(provenance)

    if cli.output:
        exclusive_path = Path(cli.output)
    else:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        exclusive_path = Path("benchmark_outputs") / f"merged_transcript_{stamp}.json"
    standard_path = exclusive_path.with_name(
        f"{exclusive_path.stem}_speaker_diarization{exclusive_path.suffix}"
    )

    exclusive_path.parent.mkdir(parents=True, exist_ok=True)
    exclusive_path.write_text(
        json.dumps(merged["exclusive_speaker_diarization"], indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    standard_path.write_text(
        json.dumps(merged["speaker_diarization"], indent=2, ensure_ascii=False),
        encoding="utf-8",
    )

    print(
        json.dumps(
            {
                "output_file_exclusive_speaker_diarization": str(exclusive_path),
                "summary_exclusive_speaker_diarization": merged["exclusive_speaker_diarization"]["summary"],
                "output_file_speaker_diarization": str(standard_path),
                "summary_speaker_diarization": merged["speaker_diarization"]["summary"],
            },
            indent=2,
        )
    )
274
+
275
+
276
# Script entry-point guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()