| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import Dict, Iterable, List, Optional, Tuple |
|
|
| import soundfile as sf |
| from fastapi import HTTPException |
|
|
| from src.asr import transcribe_file |
| from src.diarization import ( |
| get_diarization_stats, |
| init_speaker_embedding_extractor, |
| merge_consecutive_utterances, |
| merge_transcription_with_diarization, |
| perform_speaker_diarization_on_utterances, |
| ) |
| from src.utils import sensevoice_models |
|
|
| from ..core.config import get_settings |
| from ..models.transcription import DiarizationOptions, TranscriptionRequest |
|
|
# Application settings resolved once at import time (see ..core.config).
settings = get_settings()
|
|
|
|
| def _serialize_utterance(utt: Tuple[float, float, str], speaker: Optional[int] = None) -> Dict[str, object]: |
| start, end, text = utt |
| payload: Dict[str, object] = { |
| "start": round(float(start), 3), |
| "end": round(float(end), 3), |
| "text": text, |
| } |
| if speaker is not None: |
| payload["speaker"] = int(speaker) |
| return payload |
|
|
|
|
| def _prepare_model_name(options: TranscriptionRequest) -> str: |
| if options.backend == "sensevoice": |
| |
| return sensevoice_models.get(options.model_name, options.model_name) |
| return options.model_name |
|
|
|
|
def iter_transcription_events(
    audio_path: Path,
    audio_url: str,
    options: TranscriptionRequest,
) -> Iterable[Dict[str, object]]:
    """Stream transcription (and optional diarization) events for *audio_path*.

    Yields JSON-serializable event dicts in order:
      * ``ready``     — backend/model resolved, transcription about to start.
      * ``status``    — human-readable progress messages.
      * ``utterance`` — one event per transcribed utterance, with progress %.
      * ``progress``  — diarization progress (only when diarization is enabled).
      * ``complete``  — final utterances, full transcript, diarization payload
                        (``None`` when diarization was disabled or produced nothing).

    Raises:
        HTTPException: 500 wrapping any failure during transcription or
            diarization, chained to the original exception.
    """
    model_name = _prepare_model_name(options)

    try:
        generator = transcribe_file(
            audio_path=str(audio_path),
            vad_threshold=options.vad_threshold,
            model_name=model_name,
            backend=options.backend,
            language=options.language,
            textnorm=options.textnorm,
        )

        yield {
            "type": "ready",
            "audioUrl": audio_url,
            "backend": options.backend,
            "model": model_name,
        }

        yield {
            "type": "status",
            "message": "Transcribing audio...",
        }

        final_utterances: List[Tuple[float, float, str]] = []

        for current_utterance, all_utterances, progress in generator:
            if current_utterance:
                start, end, text = current_utterance
                yield {
                    "type": "utterance",
                    "utterance": _serialize_utterance((start, end, text)),
                    "index": len(all_utterances) - 1,
                    "progress": round(progress, 1),
                }
            # Snapshot the cumulative utterance list each iteration so the final
            # state is available once the generator is exhausted.
            final_utterances = list(all_utterances)

        diarization_payload = None
        if options.diarization.enable:
            yield {
                "type": "status",
                "message": "Performing speaker diarization...",
            }
            diarization_gen = _run_diarization(audio_path, final_utterances, options.diarization)
            for event in diarization_gen:
                if event["type"] == "progress":
                    # Forward diarization progress events straight to the client.
                    yield event
                elif event["type"] == "result":
                    diarization_payload = event["payload"]
                    break

        transcript_text = "\n".join([utt[2] for utt in final_utterances])

        yield {
            "type": "complete",
            "utterances": [_serialize_utterance(utt) for utt in final_utterances],
            "transcript": transcript_text,
            "diarization": diarization_payload,
        }

    except Exception as exc:
        # Chain the original exception (`from exc`) so the 500 response's server
        # log shows the root cause instead of a bare HTTPException.
        raise HTTPException(status_code=500, detail=f"Transcription failed: {exc}") from exc
|
|
|
|
def _run_diarization(
    audio_path: Path,
    utterances: List[Tuple[float, float, str]],
    options: DiarizationOptions,
) -> Iterable[Dict[str, object]]:
    """Run speaker diarization over *utterances*, yielding progress and a result.

    Yields zero or more ``{"type": "progress", ...}`` events while diarization
    runs, then exactly one terminal ``{"type": "result", "payload": ...}``
    event. The payload is ``None`` when there is nothing to diarize, the
    embedding extractor is unavailable, or no segments were produced.
    """
    # Nothing to diarize: emit an empty result and stop.
    if not utterances:
        yield {"type": "result", "payload": None}
        return

    extractor_result = init_speaker_embedding_extractor(
        cluster_threshold=options.cluster_threshold,
        num_speakers=options.num_speakers,
    )
    # Extractor could not be initialized — degrade gracefully to "no diarization".
    if not extractor_result:
        yield {"type": "result", "payload": None}
        return

    embedding_extractor, config_dict = extractor_result

    audio, sample_rate = sf.read(str(audio_path), dtype="float32")
    # Downmix multi-channel audio to mono by averaging the channels.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if sample_rate != 16000:
        # Lazy import: scipy is only needed on the resampling path.
        from scipy.signal import resample

        # Resample to 16 kHz before embedding extraction.
        target_num_samples = int(len(audio) * 16000 / sample_rate)
        audio = resample(audio, target_num_samples)
        sample_rate = 16000

    diarization_gen = perform_speaker_diarization_on_utterances(
        audio=audio,
        sample_rate=sample_rate,
        utterances=utterances,
        embedding_extractor=embedding_extractor,
        config_dict=config_dict,
        progress_callback=None,
    )

    # Drive the diarization generator by hand: it appears to yield float
    # progress fractions (0..1) until it produces the segment collection —
    # NOTE(review): confirm against perform_speaker_diarization_on_utterances.
    diarization_segments = None
    try:
        while True:
            item = next(diarization_gen)
            if isinstance(item, float):
                yield {"type": "progress", "stage": "diarization", "progress": round(item * 100, 1)}
            else:
                diarization_segments = item
                break
    except StopIteration as e:
        # The generator may also finish via `return segments`; StopIteration
        # carries that return value in `.value`.
        diarization_segments = e.value

    if not diarization_segments:
        yield {"type": "result", "payload": None}
        return

    # Attach speaker labels to utterances, fuse adjacent same-speaker
    # utterances (gaps up to 1.0 s), and compute summary statistics.
    merged = merge_transcription_with_diarization(utterances, diarization_segments)
    merged = merge_consecutive_utterances(merged, max_gap=1.0)
    stats = get_diarization_stats(merged)

    # Merged entries are 4-tuples: (start, end, text, speaker).
    yield {"type": "result", "payload": {
        "utterances": [
            _serialize_utterance((start, end, text), speaker)
            for start, end, text, speaker in merged
        ],
        "stats": stats,
    }}
|
|