Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import gc | |
| import os | |
| import re | |
| import tempfile | |
| import time | |
| import unicodedata | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from faster_whisper import BatchedInferencePipeline, WhisperModel | |
| from jiwer import cer, wer | |
| from transformers import AutoModel | |
# Hugging Face identifiers of the two ASR checkpoints that get_model() loads.
WHISPER_MODEL_ID = "Sh1man/whisper-large-v3-russian-ties-podlodka-v1.2-ct"
GIGAAM_MODEL_ID = "ai-sage/GigaAM-v3"
GIGAAM_REVISION = "e2e_rnnt"

# Sample rate (Hz) used when the prepared audio file is written.
TARGET_SAMPLE_RATE = 16_000

# Whisper decoding settings. Batch size, device and compute type all follow
# CUDA availability, which is probed once at import time.
_CUDA_AVAILABLE = torch.cuda.is_available()
WHISPER_BEAM_SIZE = 5
WHISPER_BATCH_SIZE = 8 if _CUDA_AVAILABLE else 4
WHISPER_DEVICE = "cuda" if _CUDA_AVAILABLE else "cpu"
WHISPER_COMPUTE_TYPE = "float16" if _CUDA_AVAILABLE else "int8"

# Display names keyed by the internal model name; used in the results table.
MODEL_LABELS = dict(
    whisper="Sh1man Whisper Large V3 CT",
    gigaam="GigaAM v3 e2e RNNT",
)

# Single-slot cache: the name and instance of the model currently loaded.
MODEL_STATE: dict[str, Any] = {"name": None, "instance": None}
def cleanup_loaded_model() -> None:
    """Forget the cached model and try to release the memory it occupied.

    Resets both ``MODEL_STATE`` slots to ``None``, runs the garbage collector
    and, when CUDA is available, empties the CUDA cache.
    """
    previous = MODEL_STATE.get("instance")
    MODEL_STATE.update(name=None, instance=None)
    if previous is not None:
        # Drop this function's own reference before collecting.
        del previous
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def get_model(model_name: str) -> Any:
    """Return the model registered under *model_name*, loading it on demand.

    Only one model is cached at a time: when a different model is requested,
    the cached one is released through ``cleanup_loaded_model()`` before the
    new one is built.

    Args:
        model_name: ``"whisper"`` or ``"gigaam"``.

    Returns:
        A ``BatchedInferencePipeline`` for Whisper, or the object returned by
        ``AutoModel.from_pretrained`` for GigaAM.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    cached = MODEL_STATE["instance"]
    if cached is not None and MODEL_STATE["name"] == model_name:
        return cached

    # The previous model is released even when the requested name turns out
    # to be unsupported.
    cleanup_loaded_model()

    if model_name == "whisper":
        backend = WhisperModel(
            WHISPER_MODEL_ID,
            device=WHISPER_DEVICE,
            compute_type=WHISPER_COMPUTE_TYPE,
        )
        model = BatchedInferencePipeline(model=backend)
    elif model_name == "gigaam":
        model = AutoModel.from_pretrained(
            GIGAAM_MODEL_ID,
            revision=GIGAAM_REVISION,
            trust_remote_code=True,
        )
        # The remote-code class is not known statically, so probe for the
        # usual module methods before calling them.
        if hasattr(model, "eval"):
            model.eval()
        if torch.cuda.is_available() and hasattr(model, "to"):
            model = model.to("cuda")
    else:
        raise ValueError(f"Unsupported model name: {model_name}")

    MODEL_STATE["name"] = model_name
    MODEL_STATE["instance"] = model
    return model
def collapse_spaces(text: str) -> str:
    """Return *text* with each whitespace run replaced by one space and the ends trimmed."""
    tokens = text.split()
    return " ".join(tokens)
def normalize_for_metrics(text: str, enabled: bool) -> str:
    """Prepare *text* for WER/CER comparison.

    NFKC normalisation and whitespace collapsing are always applied. When
    *enabled* is true the text is additionally lower-cased, "ё" is mapped to
    "е", and every character that is neither a word character nor whitespace
    (as well as the underscore) is replaced with a space.
    """
    canonical = unicodedata.normalize("NFKC", text.strip())
    if enabled:
        lowered = canonical.lower().replace("ё", "е")
        without_punctuation = re.sub(r"[^\w\s]", " ", lowered, flags=re.UNICODE)
        # "_" counts as a word character for the regex, so drop it separately.
        canonical = without_punctuation.replace("_", " ")
    return collapse_spaces(canonical)
def extract_text(result: Any) -> str:
    """Pull the transcript string out of a loosely-typed model result.

    Handled shapes, in order:

    * ``None`` -> ``""`` (previously stringified to the literal ``"None"``,
      which would have ended up in the transcript and the metrics);
    * ``str`` -> returned unchanged;
    * ``dict`` -> the first string found under ``"text"``,
      ``"transcription"`` or ``"prediction"``; otherwise, if ``"chunks"`` is
      a list, the space-joined text of its non-``None`` items;
    * ``list`` -> the space-joined text of its non-``None`` items;
    * anything else (including a dict with none of the keys above) ->
      ``str(result)``.

    Args:
        result: Raw value returned by a transcription backend.

    Returns:
        The extracted text; joined results are stripped at both ends.
    """
    if result is None:
        return ""
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        for key in ("text", "transcription", "prediction"):
            value = result.get(key)
            if isinstance(value, str):
                return value
        chunks = result.get("chunks")
        if isinstance(chunks, list):
            return " ".join(
                extract_text(chunk) for chunk in chunks if chunk is not None
            ).strip()
        return str(result)
    if isinstance(result, list):
        return " ".join(
            extract_text(item) for item in result if item is not None
        ).strip()
    return str(result)
def prepare_audio_file(audio_path: str) -> tuple[tempfile.TemporaryDirectory, str, float]:
    """Write a single-channel copy of *audio_path* at ``TARGET_SAMPLE_RATE``.

    Dimension 0 of the loaded tensor is treated as the channel axis and
    dimension 1 as the sample axis.

    Returns:
        A ``(temp_dir, wav_path, duration_seconds)`` tuple. ``temp_dir`` is
        the ``TemporaryDirectory`` that owns ``wav_path``; the caller is
        responsible for calling ``cleanup()`` on it.
    """
    audio, source_rate = torchaudio.load(audio_path)

    channel_count = audio.shape[0]
    if channel_count > 1:
        # Down-mix by averaging the channels, keeping the channel axis.
        audio = audio.mean(dim=0, keepdim=True)

    if source_rate != TARGET_SAMPLE_RATE:
        audio = torchaudio.functional.resample(audio, source_rate, TARGET_SAMPLE_RATE)

    sample_count = audio.shape[1]
    duration_seconds = sample_count / TARGET_SAMPLE_RATE

    temp_dir = tempfile.TemporaryDirectory()
    wav_path = str(Path(temp_dir.name) / "prepared_audio.wav")
    torchaudio.save(wav_path, audio, TARGET_SAMPLE_RATE)
    return temp_dir, wav_path, duration_seconds
def transcribe_with_whisper(prepared_audio_path: str) -> tuple[str, str]:
    """Transcribe *prepared_audio_path* with the batched Whisper pipeline.

    Returns:
        ``(transcription, mode_note)``: the whitespace-collapsed transcript
        and a Markdown sentence describing the decoding settings used.
    """
    pipeline = get_model("whisper")
    segments, _info = pipeline.transcribe(
        prepared_audio_path,
        batch_size=WHISPER_BATCH_SIZE,
        beam_size=WHISPER_BEAM_SIZE,
        language="ru",
        word_timestamps=False,
    )
    # Segments with empty text are skipped before joining.
    pieces = [segment.text for segment in segments if segment.text]
    transcription = collapse_spaces(" ".join(pieces))
    mode_note = (
        "Whisper использовал `faster-whisper` + `BatchedInferencePipeline` "
        f"с VAD по умолчанию, `beam_size={WHISPER_BEAM_SIZE}`, "
        f"`batch_size={WHISPER_BATCH_SIZE}`, `compute_type={WHISPER_COMPUTE_TYPE}`."
    )
    return transcription, mode_note
def format_boundary(boundary: Any) -> str:
    """Render a two-element tuple or list as ``"[start-end]"`` with two decimals.

    Returns an empty string when *boundary* is not a tuple or list of
    exactly two elements.
    """
    is_pair = isinstance(boundary, (tuple, list)) and len(boundary) == 2
    if not is_pair:
        return ""
    begin, finish = boundary
    return f"[{begin:.2f}-{finish:.2f}]"
def extract_longform_text(result: Any) -> str:
    """Flatten a long-form transcription result into one clean string.

    A non-list result is passed through ``extract_text`` directly. For a list,
    the text of every segment is extracted, segments with empty text are
    dropped, and the rest are joined with single spaces.
    """
    if not isinstance(result, list):
        return collapse_spaces(extract_text(result))
    # Dict and non-dict segments are handled the same way by extract_text.
    parts = [
        collapse_spaces(segment_text)
        for segment_text in map(extract_text, result)
        if segment_text
    ]
    return collapse_spaces(" ".join(parts))
def transcribe_with_gigaam(audio_path: str) -> tuple[str, int]:
    """Transcribe *audio_path* with GigaAM's ``transcribe_longform``.

    Returns:
        ``(transcription, segment_count)``; the count is the length of the
        raw result when it is a list and ``0`` otherwise.

    Raises:
        ValueError: If the ``HF_TOKEN`` environment variable is unset or empty.
    """
    if not os.getenv("HF_TOKEN"):
        raise ValueError(
            "Для GigaAM longform нужен секрет HF_TOKEN с доступом к "
            "'pyannote/segmentation-3.0'. Добавь его в Settings -> Variables and secrets."
        )
    model = get_model("gigaam")
    with torch.inference_mode():
        result = model.transcribe_longform(audio_path)
    segment_count = len(result) if isinstance(result, list) else 0
    return extract_longform_text(result), segment_count
def load_reference_text(reference_text: str, reference_file: str | None) -> str:
    """Return the reference transcript, preferring typed text over the file.

    Args:
        reference_text: Text typed by the user; ``None`` is treated as empty.
        reference_file: Path to a text file used when *reference_text* is
            blank, or ``None``/empty when no file was supplied.

    Returns:
        The stripped reference text, or ``""`` when neither source has any.

    Raises:
        ValueError: If the file cannot be decoded with any supported encoding.
    """
    typed_text = (reference_text or "").strip()
    if typed_text:
        return typed_text
    if not reference_file:
        return ""
    # "utf-8-sig" must come first: it also reads BOM-less UTF-8, whereas plain
    # "utf-8" accepts a BOM-prefixed file and leaves U+FEFF at the start of the
    # text (strip() does not remove it), which skews the first word.
    for encoding in ("utf-8-sig", "cp1251"):
        try:
            return Path(reference_file).read_text(encoding=encoding).strip()
        except UnicodeDecodeError:
            continue
    raise ValueError("Не удалось прочитать эталонный текстовый файл.")
def format_metric(value: float | None) -> str:
    """Format a metric with four decimals, or ``"n/a"`` when it is ``None``."""
    return "n/a" if value is None else f"{value:.4f}"
def benchmark_audio(
    audio_path: str | None,
    reference_text: str,
    reference_file: str | None,
    selected_models: list[str],
    normalize_metrics: bool,
) -> tuple[list[list[Any]], str, str, str]:
    """Transcribe one recording with the selected models and score the results.

    Args:
        audio_path: Path to the uploaded or recorded audio file.
        reference_text: Reference transcript typed by the user (may be empty).
        reference_file: Optional path to a text file with the reference.
        selected_models: Internal model names to run (``"whisper"``,
            ``"gigaam"``); unknown names are skipped.
        normalize_metrics: Whether to apply the full text normalisation
            before computing WER and CER.

    Returns:
        ``(rows, whisper_text, gigaam_text, summary)`` where each row is
        ``[label, WER, CER, seconds]`` and ``summary`` is a Markdown list.
        The measured time spans the whole ``transcribe_with_*`` call, which
        also obtains the model through ``get_model``.

    Raises:
        gr.Error: When no audio or no model is selected, or when any step of
            the processing (including reading the reference file) fails.
    """
    if not audio_path:
        raise gr.Error("Загрузи аудиофайл для транскрибации.")
    if not selected_models:
        raise gr.Error("Выбери хотя бы одну модель.")

    temporary_dir: tempfile.TemporaryDirectory | None = None
    try:
        # The reference is loaded inside the try block so that an unreadable
        # file is reported through gr.Error like every other failure.
        reference = load_reference_text(reference_text, reference_file)
        normalized_reference = (
            normalize_for_metrics(reference, normalize_metrics) if reference else ""
        )

        temporary_dir, prepared_audio_path, duration_seconds = prepare_audio_file(audio_path)
        whisper_text = "Модель не запускалась."
        gigaam_text = "Модель не запускалась."
        rows: list[list[Any]] = []
        whisper_mode_note: str | None = None
        gigaam_segment_count: int | None = None

        for model_name in selected_models:
            started_at = time.perf_counter()
            if model_name == "whisper":
                transcription, whisper_mode_note = transcribe_with_whisper(prepared_audio_path)
                whisper_text = transcription or "Пустой результат."
            elif model_name == "gigaam":
                transcription, gigaam_segment_count = transcribe_with_gigaam(prepared_audio_path)
                gigaam_text = transcription or "Пустой результат."
            else:
                continue
            elapsed = time.perf_counter() - started_at

            current_wer: float | None = None
            current_cer: float | None = None
            # Metrics need a reference that is still non-empty after
            # normalisation.
            if normalized_reference:
                normalized_prediction = normalize_for_metrics(transcription, normalize_metrics)
                current_wer = wer(normalized_reference, normalized_prediction)
                current_cer = cer(normalized_reference, normalized_prediction)

            rows.append(
                [
                    MODEL_LABELS[model_name],
                    format_metric(current_wer),
                    format_metric(current_cer),
                    round(elapsed, 2),
                ]
            )

        summary_lines = [
            f"- Длительность аудио: `{duration_seconds:.1f}` сек.",
        ]
        if whisper_mode_note is not None:
            summary_lines.append(f"- {whisper_mode_note}")
        if gigaam_segment_count is not None:
            summary_lines.append(
                f"- GigaAM использовал встроенный `transcribe_longform` и собрал `{gigaam_segment_count}` сегментов через VAD."
            )
        # Report metrics as computed only when they actually were: the same
        # condition guards the WER/CER calls above.
        if normalized_reference:
            normalization_note = "с нормализацией" if normalize_metrics else "без нормализации"
            summary_lines.append(f"- `WER` и `CER` посчитаны {normalization_note}.")
        elif reference:
            summary_lines.append("- Эталонный текст пуст после нормализации, метрики пропущены.")
        else:
            summary_lines.append("- Эталонный текст не задан, метрики пропущены.")

        return rows, whisper_text, gigaam_text, "\n".join(summary_lines)
    except Exception as error:
        # Top-level boundary: surface any failure to the UI, keeping the cause.
        raise gr.Error(f"Ошибка обработки: {error}") from error
    finally:
        if temporary_dir is not None:
            temporary_dir.cleanup()
# Gradio UI: audio and reference inputs, model selection, a results table and
# one transcript box per model, all wired to benchmark_audio.
# NOTE(review): the nesting of the Row/Column containers below is reconstructed
# from a flattened listing - confirm it against the deployed layout.
with gr.Blocks(title="Russian ASR Benchmark Space") as demo:
    # Introductory text shown at the top of the page.
    gr.Markdown(
        """
# Russian ASR Benchmark
Сравнение двух ASR-моделей:
- `Sh1man/whisper-large-v3-russian-ties-podlodka-v1.2-ct`
- `ai-sage/GigaAM-v3` c revision `e2e_rnnt`
Загрузи аудио, вставь эталонный текст или приложи `.txt`, и Space посчитает `WER` / `CER` для каждой модели.
Для `GigaAM` используется встроенный `transcribe_longform`. Для него нужен `HF_TOKEN`
в секретах Space с доступом к `pyannote/segmentation-3.0`.
"""
    )
    # Inputs: the recording plus the reference transcript (typed or uploaded).
    with gr.Row():
        audio_input = gr.Audio(
            label="Аудиофайл",
            type="filepath",
            sources=["upload", "microphone"],
        )
        with gr.Column():
            reference_input = gr.Textbox(
                label="Эталонный текст",
                placeholder="Вставь правильную расшифровку сюда",
                lines=10,
            )
            reference_file_input = gr.File(
                label="Или загрузи эталонный текст (.txt)",
                file_types=[".txt"],
                type="filepath",
            )
    # Run options. The checkbox values are the internal model names that
    # benchmark_audio and MODEL_LABELS are keyed by.
    with gr.Row():
        model_selector = gr.CheckboxGroup(
            label="Модели для запуска",
            choices=[
                ("Sh1man Whisper Large V3 CT", "whisper"),
                ("GigaAM v3 e2e RNNT", "gigaam"),
            ],
            value=["whisper", "gigaam"],
        )
        normalize_checkbox = gr.Checkbox(
            label="Нормализовать текст перед подсчётом метрик",
            value=True,
            info="Приводит текст к нижнему регистру, схлопывает пробелы и убирает пунктуацию.",
        )
    run_button = gr.Button("Транскрибировать и посчитать метрики", variant="primary")
    # Outputs: one table row per executed model, then the Markdown summary.
    results_table = gr.Dataframe(
        headers=["Модель", "WER", "CER", "Время (сек)"],
        datatype=["str", "str", "str", "number"],
        label="Результаты сравнения",
    )
    status_output = gr.Markdown("Статус появится после запуска.")
    with gr.Row():
        whisper_output = gr.Textbox(
            label="Транскрипт: Sh1man Whisper Large V3 CT",
            lines=12,
        )
        gigaam_output = gr.Textbox(
            label="Транскрипт: GigaAM v3 e2e RNNT",
            lines=12,
        )
    # The order of inputs and outputs must match benchmark_audio's parameters
    # and its returned tuple (rows, whisper text, gigaam text, summary).
    run_button.click(
        fn=benchmark_audio,
        inputs=[
            audio_input,
            reference_input,
            reference_file_input,
            model_selector,
            normalize_checkbox,
        ],
        outputs=[
            results_table,
            whisper_output,
            gigaam_output,
            status_output,
        ],
    )
    # Footer note about first-run latency and the decoding setup.
    gr.Markdown(
        """
Первая инференс-сессия может идти заметно дольше из-за скачивания весов.
`Whisper` здесь настроен как `faster-whisper` на CTranslate2 через `BatchedInferencePipeline`
с VAD по умолчанию и `beam_size=5`. `GigaAM` использует встроенный longform-режим через
`transcribe_longform` и VAD из `pyannote/segmentation-3.0`.
"""
    )


if __name__ == "__main__":
    # NOTE(review): the concurrency limit of 1 presumably keeps requests from
    # contending for the single cached model in MODEL_STATE - confirm.
    demo.queue(default_concurrency_limit=1).launch()