Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import os | |
| import re | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Iterable | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer | |
| from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList | |
# Configure logging once at import time; the LOG_LEVEL environment variable
# overrides the default "INFO" level.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="[%(asctime)s] %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("ark_asr_space")

# Runtime settings. Each one can be overridden through the named environment
# variable; the second argument is the default used when it is unset.
MODEL_ID = os.getenv("ARK_ASR_MODEL_ID", "AutoArk-AI/ARK-ASR-0.6B")
ASR_INSTRUCTION = os.getenv("ARK_ASR_INSTRUCTION", "Please transcribe this audio.")
# MAX_AUDIO_SECONDS * SAMPLING_RATE is passed to the processor as audio_max_length.
MAX_AUDIO_SECONDS = int(os.getenv("ARK_ASR_MAX_AUDIO_SECONDS", "30"))
SAMPLING_RATE = int(os.getenv("ARK_ASR_SAMPLING_RATE", "16000"))
# Initial value of the "Max new tokens" slider in the UI.
MAX_NEW_TOKENS = int(os.getenv("ARK_ASR_MAX_NEW_TOKENS", "256"))
# Name understood by resolve_torch_dtype: "auto", "float16", "bfloat16" or "float32".
DTYPE = os.getenv("ARK_ASR_DTYPE", "float16")
# Attention backend handed to load_model.
ATTN_IMPL = os.getenv("ARK_ASR_ATTN_IMPL", "sdpa")
# Token ids at or above this value are masked during generation; a negative
# value disables the range mask.
ASR_BLOCK_TOKEN_ID_FROM = int(os.getenv("ARK_ASR_BLOCK_TOKEN_ID_FROM", "151670"))

# Matches <|bicodec_semantic_N|>, <|bicodec_global_N|> and the
# <|start_...|> / <|end_...|> markers that are stripped from decoded text.
SPECIAL_TOKEN_PATTERN = re.compile(
    r"<\|(?:"
    r"bicodec_(?:semantic|global)_\d+|"
    r"(?:start|end)_(?:global_token|glm_token|semantic_token|content)"
    r")\|>"
)
# Markers that end the useful part of a generation: decoded text is cut at the
# first one found, and their token ids are added to the EOS id list.
TURN_END_MARKERS = ("<|user|>", "<|assistant|>", "<|im_end|>")
# Leading whitespace and punctuation removed from the final transcript.
LEADING_NOISE_PATTERN = re.compile(r"^[\s,.;:!?-]+")
# An added token whose whole content is wrapped in angle brackets is treated
# as a control token and becomes a candidate for blocking.
CONTROL_TOKEN_PATTERN = re.compile(r"^<.*>$")
class BlockTokenIdsFromLogitsProcessor(LogitsProcessor):
    """Logits processor that masks a tail range of the vocabulary and an explicit id list.

    Every token id at or above ``block_from_id`` and every id listed in
    ``block_token_ids`` has its score set to negative infinity.
    """

    def __init__(self, block_from_id: int | None, block_token_ids: Iterable[int] | None = None):
        """Store the blocking configuration.

        Args:
            block_from_id: First id of the blocked tail range. ``None`` or a
                negative value disables the range mask.
            block_token_ids: Additional individual ids to block. Duplicates are
                dropped and the result is kept sorted.
        """
        threshold = None
        if block_from_id is not None:
            as_int = int(block_from_id)
            if as_int >= 0:
                threshold = as_int
        self.block_from_id = threshold

        unique_ids = {int(token_id) for token_id in (block_token_ids or [])}
        self.block_token_ids = sorted(unique_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask blocked ids in ``scores`` in place and return the same tensor."""
        masked = -float("inf")
        vocab_size = scores.shape[-1]

        start = self.block_from_id
        if start is not None and start < vocab_size:
            scores[:, start:] = masked

        # Ids outside the score width are ignored so indexing never goes out of range.
        in_range = [token_id for token_id in self.block_token_ids if 0 <= token_id < vocab_size]
        if in_range:
            scores[:, in_range] = masked
        return scores
@dataclass
class AppState:
    """Mutable container for the loaded ASR stack and its derived settings.

    Declared as a dataclass so the annotated defaults become real per-instance
    fields with a generated ``__init__`` and ``__repr__``. ``AppState()`` with
    no arguments still yields the same defaults as before.
    """

    # Local directory or hub id the stack was loaded from.
    model_path: str = ""
    # "cuda" or "cpu", chosen when the stack is loaded.
    device: str = "cpu"
    # dtype the model was loaded with; audio inputs are cast to it.
    torch_dtype: torch.dtype = torch.float32
    # Loaded model; None means nothing has been loaded yet.
    model: Any = None
    processor: Any = None
    tokenizer: Any = None
    # Token ids passed to generate() as eos_token_id.
    eos_token_ids: list[int] | None = None
    # Individual token ids masked during generation.
    extra_block_token_ids: list[int] | None = None
    # Attention implementation the model was actually loaded with.
    resolved_attn_impl: str = ""
    # time.time() stamp recorded when loading finished.
    loaded_at: float = 0.0
# Process-wide holder for the loaded model, tokenizer and processor.
state = AppState()
# load_lock guards the one-time load in ensure_loaded; infer_lock lets only
# one generation call run at a time in transcribe.
load_lock = asyncio.Lock()
infer_lock = asyncio.Lock()
def normalize_token_ids(token_ids: Any) -> list[int]:
    """Coerce a token-id value of varying shape into a flat list of ints.

    ``None`` gives an empty list. A list, tuple or set is converted element by
    element with ``None`` entries skipped. Any other value is treated as a
    single id.
    """
    if token_ids is None:
        return []
    if not isinstance(token_ids, (list, tuple, set)):
        return [int(token_ids)]

    collected: list[int] = []
    for item in token_ids:
        if item is None:
            continue
        collected.append(int(item))
    return collected
def build_eos_token_ids(tokenizer: Any) -> list[int]:
    """Collect the ids that should stop generation.

    Starts from the tokenizer's ``eos_token_id`` and appends the id of every
    entry in ``TURN_END_MARKERS`` that the tokenizer maps to a non-negative
    int. The result keeps first-seen order with duplicates removed.
    """
    candidates = normalize_token_ids(getattr(tokenizer, "eos_token_id", None))
    for marker in TURN_END_MARKERS:
        marker_id = tokenizer.convert_tokens_to_ids(marker)
        if not isinstance(marker_id, int) or marker_id < 0:
            continue
        candidates.append(int(marker_id))

    seen: set[int] = set()
    ordered: list[int] = []
    for token_id in candidates:
        if token_id in seen:
            continue
        seen.add(token_id)
        ordered.append(token_id)
    return ordered
def build_asr_keep_token_ids(model: Any, tokenizer: Any) -> list[int]:
    """Return the sorted union of EOS ids that must never be blocked.

    The ids are read from the ``eos_token_id`` attribute of the tokenizer, of
    ``model.config`` and of ``model.generation_config``; a missing attribute
    contributes nothing.
    """
    config = getattr(model, "config", None)
    generation_config = getattr(model, "generation_config", None)

    keep: set[int] = set()
    for owner in (tokenizer, config, generation_config):
        keep.update(normalize_token_ids(getattr(owner, "eos_token_id", None)))
    return sorted(keep)
def build_asr_extra_block_token_ids(
    tokenizer: Any,
    keep_token_ids: Iterable[int] | None = None,
    block_from_id: int | None = None,
) -> list[int]:
    """Build the sorted list of individual token ids to mask during generation.

    The list contains every id in ``tokenizer.all_special_ids`` plus every
    added token whose content matches ``CONTROL_TOKEN_PATTERN``. Added tokens
    with an id at or above ``block_from_id`` are skipped here. Ids listed in
    ``keep_token_ids`` are removed from the result.

    Args:
        tokenizer: Object optionally exposing ``all_special_ids`` and
            ``added_tokens_decoder``.
        keep_token_ids: Ids that must stay allowed.
        block_from_id: Threshold for skipping added tokens; ``None`` or a
            negative value means no threshold.
    """
    protected = {int(token_id) for token_id in (keep_token_ids or [])}

    range_start = None
    if block_from_id is not None and int(block_from_id) >= 0:
        range_start = int(block_from_id)

    blocked = {
        int(token_id)
        for token_id in getattr(tokenizer, "all_special_ids", [])
        if token_id is not None
    }

    added_tokens = getattr(tokenizer, "added_tokens_decoder", {}) or {}
    for raw_id, meta in added_tokens.items():
        added_id = int(raw_id)
        if range_start is not None and added_id >= range_start:
            continue
        # Token metadata may be an object with a .content attribute or a plain dict.
        content = getattr(meta, "content", None)
        if content is None and isinstance(meta, dict):
            content = meta.get("content")
        if content and CONTROL_TOKEN_PATTERN.match(content):
            blocked.add(added_id)

    return sorted(blocked - protected)
def truncate_generation_text(text: str) -> str:
    """Cut ``text`` at the earliest turn-end marker and strip the result.

    When no entry of ``TURN_END_MARKERS`` occurs, the whole text is kept.
    Empty input yields an empty string.
    """
    if not text:
        return ""
    hits = [position for position in (text.find(marker) for marker in TURN_END_MARKERS) if position != -1]
    cut = min(hits, default=len(text))
    return text[:cut].strip()
def remove_special_tokens(text: str) -> str:
    """Drop everything up to the first ``<|text|>`` tag and strip special tokens.

    If the tag is absent the text is used as is. Every match of
    ``SPECIAL_TOKEN_PATTERN`` is removed and the result is stripped.
    """
    if not text:
        return ""
    _, tag, remainder = text.partition("<|text|>")
    if tag:
        text = remainder
    return SPECIAL_TOKEN_PATTERN.sub("", text).strip()
def normalize_prediction_text(text: str) -> str:
    """Turn raw decoded model output into the transcript shown to the user.

    Steps, in order: cut at the first turn-end marker, remove special tokens,
    collapse whitespace runs to single spaces, then drop leading
    whitespace and punctuation.
    """
    if not text:
        return ""
    cleaned = remove_special_tokens(truncate_generation_text(text))
    collapsed = re.sub(r"\s+", " ", cleaned).strip()
    without_noise = LEADING_NOISE_PATTERN.sub("", collapsed)
    return without_noise.strip()
def as_dict(value: Any) -> dict[str, Any]:
    """Return ``value`` as a plain dict.

    A real dict is returned unchanged (same object). Any object offering
    ``keys()`` and ``__getitem__`` is copied into a new dict.

    Raises:
        TypeError: If ``value`` is neither a dict nor mapping-like.
    """
    if isinstance(value, dict):
        return value
    mapping_like = hasattr(value, "keys") and hasattr(value, "__getitem__")
    if not mapping_like:
        raise TypeError(f"Unexpected processor output type: {type(value)}")
    return {key: value[key] for key in value.keys()}
def resolve_torch_dtype(dtype_name: str, device: str) -> torch.dtype:
    """Map a dtype name to the torch dtype to use on ``device``.

    ``"auto"`` picks float16 on CUDA and float32 elsewhere. A named dtype is
    honoured only on CUDA; on any other device float32 is returned.

    Raises:
        ValueError: If ``dtype_name`` is not "auto", "float16", "bfloat16"
            or "float32".
    """
    on_cuda = device == "cuda"
    if dtype_name == "auto":
        return torch.float16 if on_cuda else torch.float32

    supported = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    requested = supported.get(dtype_name)
    if requested is None:
        raise ValueError(f"Unsupported dtype: {dtype_name}")
    return requested if on_cuda else torch.float32
def maybe_gpu_memory_text() -> str:
    """Describe the current CUDA device's memory in one line for logging.

    Returns a fixed "not available" message when CUDA is absent; otherwise
    reports the device name plus total, reserved and allocated memory in GiB.
    """
    if not torch.cuda.is_available():
        return "GPU: not available in this runtime."

    bytes_per_gib = 1024**3
    index = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(index)
    total_gb = props.total_memory / bytes_per_gib
    reserved_gb = torch.cuda.memory_reserved(index) / bytes_per_gib
    allocated_gb = torch.cuda.memory_allocated(index) / bytes_per_gib
    summary = (
        f"GPU: {props.name}, total={total_gb:.1f}G, "
        f"reserved={reserved_gb:.1f}G, allocated={allocated_gb:.1f}G."
    )
    return summary
def resolve_model_path() -> str:
    """Decide whether ``MODEL_ID`` names a local path or a hub id.

    If ``MODEL_ID`` (after ``~`` expansion) exists on disk, its resolved
    absolute path is returned; otherwise ``MODEL_ID`` is returned unchanged.
    The choice is logged either way.
    """
    candidate = Path(MODEL_ID).expanduser()
    if not candidate.exists():
        logger.info("Using Hugging Face model id: %s", MODEL_ID)
        return MODEL_ID

    resolved = candidate.resolve()
    logger.info("Using local model path: %s", resolved)
    return str(resolved)
def load_model(model_path: str, device: str, torch_dtype: torch.dtype, attn_impl: str):
    """Load the causal LM, trying attention implementations in order.

    Candidate order:
      * ``"auto"``              -> sdpa, then eager
      * ``"flash_attention_2"`` -> flash_attention_2, then sdpa, then eager
      * anything else           -> that implementation only

    A candidate that fails with ImportError, RuntimeError or ValueError is
    logged and the next one is tried. The error is re-raised only when no
    candidate is left, so every listed fallback is actually reachable.

    Args:
        model_path: Local directory or hub id passed to ``from_pretrained``.
        device: Target device for ``.to(...)``.
        torch_dtype: dtype the weights are loaded in.
        attn_impl: Requested attention implementation (see above).

    Returns:
        ``(model, implementation)`` where ``implementation`` is the candidate
        that loaded successfully. The model is in eval mode.

    Raises:
        ImportError, RuntimeError, ValueError: From the last candidate tried.
    """
    if attn_impl == "auto":
        candidates = ["sdpa", "eager"]
    elif attn_impl == "flash_attention_2":
        candidates = ["flash_attention_2", "sdpa", "eager"]
    else:
        candidates = [attn_impl]

    final_position = len(candidates) - 1
    for position, candidate in enumerate(candidates):
        try:
            logger.info("Loading model with attn_implementation=%s", candidate)
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch_dtype,
                attn_implementation=candidate,
            ).to(device)
            model.eval()
            return model, candidate
        except (ImportError, RuntimeError, ValueError) as exc:
            if position == final_position:
                raise
            # An exception may carry an empty message; fall back to its class name
            # instead of indexing into an empty list of lines.
            summary = next(iter(str(exc).splitlines()), exc.__class__.__name__)
            logger.warning(
                "attn_implementation=%s unavailable, falling back: %s",
                candidate,
                summary,
            )
    # Only reachable if the candidate list were empty.
    raise RuntimeError("Failed to load model")
async def ensure_loaded() -> None:
    """Load model, tokenizer and processor into ``state`` exactly once.

    ``state.model is not None`` is the readiness flag checked by every caller,
    both before and after taking ``load_lock``. Everything is therefore built
    in local variables first and ``state.model`` is assigned last: if any step
    raises, ``state`` is left untouched and the next call retries the load
    instead of finding a model with no tokenizer or processor.

    Raises:
        Whatever the underlying loaders raise; nothing is swallowed.
    """
    if state.model is not None:
        return
    async with load_lock:
        # Another coroutine may have finished loading while we waited for the lock.
        if state.model is not None:
            return

        started = time.perf_counter()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch_dtype = resolve_torch_dtype(DTYPE, device)
        model_path = resolve_model_path()
        logger.info(
            "Loading Transformers ASR stack: model=%s device=%s dtype=%s",
            model_path,
            device,
            torch_dtype,
        )

        # Model loading is the slow part, so it runs off the event loop.
        model, resolved_attn_impl = await asyncio.to_thread(
            load_model,
            model_path,
            device,
            torch_dtype,
            ATTN_IMPL,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            fix_mistral_regex=True,
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.padding_side = "left"

        processor = AutoProcessor.from_pretrained(
            model_path,
            trust_remote_code=True,
            fix_mistral_regex=True,
        )
        # Keep the processor's own tokenizer aligned with the standalone one.
        if hasattr(processor, "tokenizer"):
            if processor.tokenizer.pad_token_id is None:
                processor.tokenizer.pad_token_id = tokenizer.pad_token_id
            processor.tokenizer.padding_side = "left"

        eos_token_ids = build_eos_token_ids(tokenizer)
        keep_token_ids = build_asr_keep_token_ids(model, tokenizer)
        extra_block_token_ids = build_asr_extra_block_token_ids(
            tokenizer,
            keep_token_ids=keep_token_ids,
            block_from_id=ASR_BLOCK_TOKEN_ID_FROM,
        )

        # Publish. There is no await below, so other coroutines cannot observe a
        # half-filled state; the model goes last because it is the readiness flag.
        state.device = device
        state.torch_dtype = torch_dtype
        state.model_path = model_path
        state.resolved_attn_impl = resolved_attn_impl
        state.tokenizer = tokenizer
        state.processor = processor
        state.eos_token_ids = eos_token_ids
        state.extra_block_token_ids = extra_block_token_ids
        state.loaded_at = time.time()
        state.model = model

        logger.info(
            "Transformers ASR stack loaded in %.2fs with attn=%s",
            time.perf_counter() - started,
            state.resolved_attn_impl,
        )
def build_conversation(audio_path: str, begin_time: float, end_time: float) -> list[dict[str, Any]]:
    """Build the single-turn chat payload for one transcription request.

    The user message holds an audio part (path plus begin/end times) followed
    by a text part carrying ``ASR_INSTRUCTION``.
    """
    audio_part = {
        "type": "audio",
        "path": audio_path,
        "begin_time": begin_time,
        "end_time": end_time,
    }
    instruction_part = {"type": "text", "text": ASR_INSTRUCTION}
    user_turn = {"role": "user", "content": [audio_part, instruction_part]}
    return [user_turn]
def audio_to_path(audio: str | tuple[int, Any] | None) -> tuple[str, str | None]:
    """Turn a Gradio audio value into a file path the processor can read.

    Args:
        audio: A file path, a ``(sample_rate, samples)`` tuple, or ``None``.

    Returns:
        ``(audio_path, tmp_path)``. For a path input ``tmp_path`` is ``None``.
        For a tuple input the samples are written to a new ``.wav`` temp file
        and both elements are that file's path; the caller must delete it.

    Raises:
        gr.Error: If ``audio`` is ``None``.
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")
    if isinstance(audio, str):
        return audio, None

    # Unpack and import before creating the temp file, so a malformed tuple or
    # a missing soundfile package cannot leave an orphaned file behind.
    sample_rate, data = audio
    import soundfile as sf

    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        tmp_name = tmp.name
    try:
        sf.write(tmp_name, data, sample_rate)
    except BaseException:
        # The path never reaches the caller on failure, so clean it up here.
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise
    return tmp_name, tmp_name
def run_transformers_generation(
    audio_path: str,
    begin_time: float,
    end_time: float,
    max_new_tokens: int,
) -> tuple[str, int]:
    """Run one greedy ASR generation and return the cleaned transcript.

    Uses the processor, tokenizer and model held in ``state``.

    Returns:
        ``(transcript, prompt_tokens)`` where ``prompt_tokens`` is the last
        dimension of the prompt's ``input_ids``.

    Raises:
        RuntimeError: If the processor returns a bare tensor or its output
            has no ``"audios"`` entry.
    """
    conversation = build_conversation(audio_path, begin_time, end_time)
    inputs_raw = state.processor.apply_chat_template(
        [conversation],
        return_tensors="pt",
        sampling_rate=SAMPLING_RATE,
        audio_padding="longest",
        add_generation_prompt=True,
        text_kwargs={"padding": "longest"},
        audio_max_length=int(MAX_AUDIO_SECONDS * SAMPLING_RATE),
    )
    if torch.is_tensor(inputs_raw):
        raise RuntimeError("ASR apply_chat_template returned Tensor-only; audio was not encoded.")

    inputs = as_dict(inputs_raw)
    if "audios" not in inputs:
        raise RuntimeError(f"ASR inputs missing 'audios'; processor keys={list(inputs.keys())}")

    # Supply an all-ones mask when the processor did not provide one.
    has_tensor_ids = "input_ids" in inputs and torch.is_tensor(inputs["input_ids"])
    if "attention_mask" not in inputs and has_tensor_ids:
        inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long)

    # Move tensors to the model device; only the audio features are cast to the model dtype.
    for key in list(inputs):
        tensor = inputs[key]
        if not torch.is_tensor(tensor):
            continue
        if key == "audios":
            inputs[key] = tensor.to(device=state.device, dtype=state.torch_dtype)
        else:
            inputs[key] = tensor.to(state.device)

    generate_kwargs: dict[str, Any] = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": False,
        "pad_token_id": state.tokenizer.pad_token_id,
    }
    if state.eos_token_ids:
        generate_kwargs["eos_token_id"] = state.eos_token_ids
    if ASR_BLOCK_TOKEN_ID_FROM >= 0 or state.extra_block_token_ids:
        blocker = BlockTokenIdsFromLogitsProcessor(
            block_from_id=ASR_BLOCK_TOKEN_ID_FROM,
            block_token_ids=state.extra_block_token_ids,
        )
        generate_kwargs["logits_processor"] = LogitsProcessorList([blocker])

    with torch.inference_mode():
        outputs = state.model.generate(**inputs, **generate_kwargs)

    input_ids = inputs["input_ids"]
    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = len(input_ids[0])
    generated_ids = outputs[0][prompt_length:]
    prediction_raw = state.tokenizer.decode(generated_ids, skip_special_tokens=False)
    transcript = normalize_prediction_text(prediction_raw)
    return transcript, int(input_ids.shape[-1])
async def transcribe(
    audio: str | tuple[int, Any] | None,
    max_new_tokens: int,
    begin_time: float,
    end_time: float,
) -> str:
    """Gradio click handler: transcribe one audio clip.

    Makes sure the stack is loaded, runs generation in a worker thread while
    holding ``infer_lock``, logs timing and metadata, and returns the text.
    Any temp file created for the audio is removed afterwards.

    Raises:
        gr.Error: Passed through unchanged, or wrapping any other exception
            as ``"<ClassName>: <message>"``.
    """
    started = time.perf_counter()
    tmp_path: str | None = None
    try:
        logger.info("Transcribe request started")
        await ensure_loaded()
        audio_path, tmp_path = audio_to_path(audio)
        token_budget = int(max_new_tokens)
        async with infer_lock:
            text, prompt_tokens = await asyncio.to_thread(
                run_transformers_generation,
                audio_path,
                begin_time,
                end_time,
                token_budget,
            )
        logger.info("Transcribe request finished in %.2fs", time.perf_counter() - started)
        logger.info(
            "Generation metadata: prompt_tokens=%s model=%s backend=transformers/%s %s",
            prompt_tokens,
            MODEL_ID,
            state.resolved_attn_impl,
            maybe_gpu_memory_text(),
        )
        return text
    except gr.Error:
        # Already user-facing; do not wrap it again.
        raise
    except Exception as exc:
        logger.exception("ASR request failed")
        raise gr.Error(f"{exc.__class__.__name__}: {exc}") from exc
    finally:
        # Best-effort removal of the temp file; a failure here must not mask the result.
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
# Stylesheet passed to gr.Blocks(css=...). The .ark-* classes are used by the
# header HTML and by the columns created with elem_classes=["ark-panel"].
APP_CSS = """
.gradio-container {
max-width: 1120px !important;
margin: 0 auto !important;
background:
radial-gradient(circle at top left, rgba(27, 99, 146, 0.12), transparent 34rem),
linear-gradient(180deg, #f6f8fb 0%, #eef3f7 100%);
color: #172033;
}
.ark-header {
padding: 26px 4px 20px;
border-bottom: 1px solid rgba(23, 32, 51, 0.12);
margin-bottom: 18px;
}
.ark-eyebrow {
margin: 0 0 7px;
color: #536579;
font-size: 15px;
font-weight: 700;
letter-spacing: 0.04em;
text-transform: uppercase;
}
.ark-title {
margin: 0;
color: #101828;
font-size: 40px;
line-height: 1.1;
font-weight: 800;
}
.ark-subtitle {
max-width: 920px;
margin: 12px 0 0;
color: #405166;
font-size: 17px;
line-height: 1.5;
}
.ark-opd {
color: #0b5cad;
font-weight: 800;
}
.ark-badges {
display: flex;
flex-wrap: wrap;
justify-content: flex-start;
gap: 8px;
margin-top: 14px;
}
.ark-badge {
display: inline-flex;
align-items: center;
height: 28px;
border-radius: 4px;
text-decoration: none !important;
background: transparent;
box-shadow: 0 1px 2px rgba(16, 24, 40, 0.12);
}
.ark-badge img {
display: block;
height: 28px;
}
.ark-panel {
border: 1px solid rgba(23, 32, 51, 0.12);
border-radius: 8px;
background: rgba(255, 255, 255, 0.9);
padding: 16px;
box-shadow: 0 10px 32px rgba(16, 24, 40, 0.06);
}
.ark-panel textarea {
font-size: 17px !important;
line-height: 1.65 !important;
}
.ark-panel button.primary,
.ark-panel button[variant="primary"] {
border-radius: 8px !important;
}
@media (max-width: 760px) {
.ark-badges {
max-width: 100%;
}
.ark-title {
font-size: 32px;
}
.ark-subtitle {
font-size: 16px;
}
}
"""
# UI definition: a header, an input panel (audio, time range, token budget,
# button) and an output panel (transcript).
with gr.Blocks(title="Ark ASR 0.6B", css=APP_CSS) as demo:
    # Static header with project links.
    gr.HTML(
        """
<header class="ark-header">
<p class="ark-eyebrow">Industrial Audio Online Policy Distillation</p>
<h1 class="ark-title">Ark ASR 0.6B</h1>
<p class="ark-subtitle"><span class="ark-opd">Open Audio OPD</span> brings online policy distillation to ASR, with the best overall results among the 0.6B-scale ASR models compared in the project.</p>
<nav class="ark-badges" aria-label="Project links">
<a class="ark-badge" href="https://github.com/AutoArk/open-audio-opd" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/GitHub-open--audio--opd-black?style=for-the-badge&logo=github&logoColor=white" alt="GitHub open-audio-opd"></a>
<a class="ark-badge" href="https://huggingface.co/AutoArk-AI/ARK-ASR-0.6B" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-ARK--ASR--0.6B-yellow?style=for-the-badge" alt="Hugging Face ARK-ASR-0.6B"></a>
<a class="ark-badge" href="https://github.com/AutoArk/open-audio-opd/blob/main/paper/arxiv_ark_asr_opd/main.pdf" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Paper-PDF-b31b1b?style=for-the-badge&logo=readthedocs&logoColor=white" alt="Paper PDF"></a>
<a class="ark-badge" href="https://github.com/AutoArk/open-audio-opd/blob/main/LICENSE" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/License-See%20LICENSE-blue?style=for-the-badge" alt="License"></a>
</nav>
</header>
"""
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=1, elem_classes=["ark-panel"]):
            # type="filepath" means the handler receives the audio as a path string.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Audio",
            )
            with gr.Row():
                # Both default to -1 and are forwarded unchanged into the audio
                # part of the conversation payload.
                begin_input = gr.Number(value=-1, label="Begin time")
                end_input = gr.Number(value=-1, label="End time")
            max_tokens_input = gr.Slider(
                minimum=16,
                maximum=512,
                value=MAX_NEW_TOKENS,
                step=16,
                label="Max new tokens",
            )
            transcribe_button = gr.Button("Transcribe", variant="primary")
        with gr.Column(scale=1, elem_classes=["ark-panel"]):
            text_output = gr.Textbox(label="Transcript", lines=8)
    # The input order here must match the parameter order of transcribe().
    transcribe_button.click(
        transcribe,
        inputs=[audio_input, max_tokens_input, begin_input, end_input],
        outputs=text_output,
    )
if __name__ == "__main__":
    # Load the model before launching the UI so a load failure surfaces at startup.
    asyncio.run(ensure_loaded())
    # Queue requests with a default concurrency limit of one per event.
    demo.queue(default_concurrency_limit=1).launch()