import asyncio
import os
import secrets
import time
from io import BytesIO
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import gradio as gr
import numpy as np
import torch
import torchaudio
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from funasr import AutoModel
# Load variables from a local .env file into the process environment.
load_dotenv()

# Fail fast at import time if the shared secret is missing. Narrowing the
# Optional[str] returned by os.getenv through a checked local keeps the
# API_TOKEN: str annotation honest (the original annotated the Optional
# directly, which is wrong before the emptiness check).
_api_token = os.getenv("API_TOKEN")
if not _api_token:
    raise RuntimeError("API_TOKEN environment variable is not set")
API_TOKEN: str = _api_token

# HTTP Bearer scheme used by verify_token to authenticate API requests.
security = HTTPBearer()
# FastAPI application exposing the transcription API (and, further below,
# the mounted Gradio UI).
app = FastAPI(
    title="SenseVoice API",
    description="Speech To Text API Service",
    version="1.0.0"
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# maximally permissive CORS policy (Starlette echoes the request Origin in this
# combination, effectively allowing any site to send credentialed requests).
# Confirm this is intended before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load SenseVoiceSmall from the Hugging Face hub with the FSMN VAD front-end
# (segments capped at 30 s). Loaded once at import time and shared by all
# requests.
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    # Fall back to CPU so the service can still start on hosts without a GPU;
    # the original hard-coded "cuda" and crashed on CPU-only machines.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# Emotion tokens emitted by SenseVoice, mapped to a trailing emoji.
# NOTE: every emoji in this file was mojibake in the original source; values
# are restored from the upstream FunASR SenseVoice demo mapping tables, whose
# structure matches this file entry-for-entry.
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Audio-event tokens mapped to a leading emoji (empty string = no marker).
event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Full token -> replacement table used for inline substitution of every
# special token the model can emit (language, emotion, event, ITN markers).
emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Language tokens are collapsed to a single segment separator so the text can
# be split into per-language segments by format_text_advanced.
lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

# Emoji alphabets the formatters use to recognize a trailing emotion marker
# and a leading event marker; must stay consistent with the dicts above.
emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}
def format_text_with_emotion(text: str) -> str:
    """Post-process one text segment: prefix event emojis, substitute all
    special tokens, and append the emoji of the dominant emotion."""
    counts: Dict[str, int] = {tok: text.count(tok) for tok in emoji_dict}

    # Dominant emotion: a candidate must strictly beat the current best, with
    # NEUTRAL as the baseline — so on a tie the text stays unmarked (NEUTRAL
    # maps to an empty suffix).
    dominant = "<|NEUTRAL|>"
    for emo in emotion_dict:
        if counts[emo] > counts[dominant]:
            dominant = emo

    # Prepend one emoji per event type present; each prepend pushes earlier
    # ones right, so later dict entries end up leftmost (original ordering).
    result = text
    for evt in event_dict:
        if counts[evt]:
            result = event_dict[evt] + result

    # Replace every special token with its emoji (or drop it).
    for tok, rep in emoji_dict.items():
        result = result.replace(tok, rep)

    result += emotion_dict[dominant]

    # Tighten whitespace touching any marker emoji.
    for mark in emo_set | event_set:
        result = result.replace(" " + mark, mark).replace(mark + " ", mark)
    return result.strip()
def format_text_advanced(text: str) -> str:
    """Format multilingual model output: collapse language tokens into segment
    separators, format each segment, and merge segments while deduplicating
    adjacent event/emotion markers.

    Args:
        text: Raw token-annotated text from the model.

    Returns:
        Human-readable text with emoji markers.
    """
    def _trailing_emotion(s: str) -> Optional[str]:
        # Guarded against empty strings — the original indexed s[-1] directly
        # and raised IndexError once an event marker was stripped from a
        # single-character segment.
        return s[-1] if s and s[-1] in emo_set else None

    def _leading_event(s: str) -> Optional[str]:
        return s[0] if s and s[0] in event_set else None

    # "No speech + unknown event" marker (restored from mojibake; matches the
    # emoji_dict entry for the same combined token).
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang_token in lang_dict:
        text = text.replace(lang_token, "<|lang|>")

    # Format each language segment independently.
    segments: List[str] = [
        format_text_with_emotion(seg).strip() for seg in text.split("<|lang|>")
    ]
    formatted_text = " " + segments[0]
    current_event = _leading_event(formatted_text)

    for seg in segments[1:]:
        if not seg:
            continue

        # Drop a leading event marker repeated from the previous segment.
        if _leading_event(seg) == current_event and _leading_event(seg) is not None:
            seg = seg[1:]
        current_event = _leading_event(seg)

        # Drop a trailing emotion marker duplicated across the join.
        if _trailing_emotion(seg) is not None and _trailing_emotion(seg) == _trailing_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += seg.strip()

    # Upstream quirk kept as-is: scrub a spurious "The." artifact.
    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
async def audio_stt(audio: torch.Tensor, sample_rate: int, language: str = "auto") -> str:
    """Run SenseVoice on a waveform tensor and return the formatted transcript.

    Args:
        audio: Waveform tensor; int16/int32 PCM or float samples. Multi-channel
            input is downmixed assuming channels-first layout (as produced by
            torchaudio.load — TODO confirm for other callers).
        sample_rate: Sample rate of `audio` in Hz; resampled to 16 kHz if needed.
        language: Language hint (auto/zh/en/yue/ja/ko/nospeech).

    Returns:
        Transcription text after emoji/token post-processing.

    Raises:
        HTTPException: 500 wrapping any failure during processing.
    """
    try:
        # Map integer PCM onto [-1, 1]; any other non-float32 dtype is cast.
        if audio.dtype == torch.int16:
            audio = audio.float() / torch.iinfo(torch.int16).max
        elif audio.dtype == torch.int32:
            audio = audio.float() / torch.iinfo(torch.int32).max
        elif audio.dtype != torch.float32:
            audio = audio.float()

        # Rescale if the signal exceeds full scale.
        peak = audio.abs().max()
        if peak > 1.0:
            audio = audio / peak

        # Downmix multi-channel audio to mono and drop singleton dims.
        if audio.dim() > 1:
            audio = audio.mean(dim=0)
        audio = audio.squeeze()

        # SenseVoice expects 16 kHz input.
        if sample_rate != 16000:
            resample = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000,
            )
            audio = resample(audio.unsqueeze(0)).squeeze(0)

        # NOTE(review): model.generate is synchronous heavy work inside an
        # async function — it blocks the event loop for the call's duration.
        output = model.generate(
            input=audio,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        return format_text_advanced(output[0]["text"])

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed in audio_stt: {str(e)}"
        )
async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode raw audio bytes and transcribe them.

    Args:
        audio_data: Encoded audio file content (any format torchaudio reads).
        language: Language hint forwarded to the model.

    Returns:
        Formatted transcription text.

    Raises:
        HTTPException: 500 if decoding or transcription fails (failures from
            audio_stt are re-wrapped here as well).
    """
    try:
        waveform, sample_rate = torchaudio.load(
            uri=BytesIO(audio_data),
            normalize=True,
            channels_first=True,
        )
        return await audio_stt(waveform, sample_rate, language)

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Audio processing failed: {str(e)}"
        )
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify the Bearer token against API_TOKEN.

    Uses secrets.compare_digest for a constant-time comparison, so the token
    cannot be recovered character-by-character via a timing side channel
    (the original used a plain != comparison).

    Raises:
        HTTPException: 401 with a WWW-Authenticate challenge on mismatch.
    """
    if not secrets.compare_digest(credentials.credentials, API_TOKEN):
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = "FunAudioLLM/SenseVoiceSmall",
    language: str = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint.

    Args:
        file: Audio file (supports mp3, wav, flac, ogg, m4a).
        model: Model name; only the default SenseVoiceSmall is accepted.
        language: Language code (auto/zh/en/yue/ja/ko/nospeech).
        token: Bearer credentials validated by verify_token.

    Returns:
        Dict with keys text, error_code (0 on success), error_msg, and
        process_time in seconds. Validation and processing errors are
        reported in-band (HTTP 200 with a non-zero error_code), matching
        the original API contract.
    """
    start_time = time.time()

    try:
        # file.filename may be None depending on the client; treat as invalid
        # instead of crashing on .lower().
        filename = (file.filename or "").lower()
        if not filename.endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                # "Unsupported audio format" (restored from mojibake)
                "error_msg": "不支持的音频格式",
                "process_time": time.time() - start_time
            }

        # Only the bundled model is served.
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                # "Unsupported model" (restored from mojibake)
                "error_msg": "不支持的模型",
                "process_time": time.time() - start_time
            }

        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                # "Unsupported language" (restored from mojibake)
                "error_msg": "不支持的语言",
                "process_time": time.time() - start_time
            }

        content = await file.read()
        text = await process_audio(content, language)

        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }

    except Exception as e:
        # process_audio raises HTTPException(500); it is surfaced in-band like
        # the validation errors so the response shape stays uniform.
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio callback: transcribe a (sample_rate, samples) numpy tuple.

    Args:
        audio: Tuple from gr.Audio(type="numpy"), or None if nothing was
            provided.
        language: Language code selected in the dropdown.

    Returns:
        The transcription, or an error message string on failure.
    """
    try:
        if audio is None:
            return "Please upload an audio file"

        sample_rate, samples = audio

        # Gradio typically delivers int16 PCM. Scale integer dtypes to
        # [-1, 1] by their own max; float input is assumed already normalized
        # and passed through (the original divided floats by 32767 too,
        # effectively silencing them).
        if np.issubdtype(samples.dtype, np.integer):
            samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
        else:
            samples = samples.astype(np.float32)

        waveform = torch.from_numpy(samples)
        # Bridge the sync Gradio callback to the async STT pipeline.
        return asyncio.run(audio_stt(waveform, sample_rate, language))
    except Exception as e:
        return f"Processing failed: {str(e)}"
# Gradio UI: audio input (upload or microphone) plus a language selector,
# wired to transcribe_audio_gradio above.
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",  # delivers (sample_rate, np.ndarray) tuples
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    # NOTE(review): these paths are resolved relative to the working directory
    # at startup — confirm an examples/ folder ships alongside the app.
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)

# Mount the Gradio UI at the web root; API routes already registered on `app`
# (e.g. /v1/audio/transcriptions) remain reachable alongside it.
app = gr.mount_gradio_app(app, demo, path="/")
# NOTE(review): this route is registered after the Gradio app is mounted at
# "/", so depending on route-matching order it may be shadowed and never
# reached. The page also meta-refreshes to "/docs/" (trailing slash) — verify
# that path actually serves the Swagger UI rather than looping back here or
# landing in the Gradio mount.
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    """Serve a tiny HTML page that immediately redirects the browser to /docs/."""
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
        <head>
            <title>SenseVoice API Documentation</title>
            <meta http-equiv="refresh" content="0;url=/docs/" />
        </head>
        <body>
            <p>Redirecting to API documentation...</p>
        </body>
    </html>
    """)
if __name__ == "__main__":
    # Local/dev entry point; port 7860 follows the Hugging Face Spaces
    # convention for Gradio apps.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)