from io import BytesIO
from typing import Optional, Dict, List, Set, Union, Tuple
import os
import time

from fastapi import FastAPI, File, UploadFile, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.responses import HTMLResponse
import numpy as np
import torch
import torchaudio
from funasr import AutoModel
from dotenv import load_dotenv
import gradio as gr

# Load environment variables from a local .env file
load_dotenv()

# Fail fast if the authentication token is missing
API_TOKEN: Optional[str] = os.getenv("API_TOKEN")
if not API_TOKEN:
    raise RuntimeError("API_TOKEN environment variable is not set")
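
# An illustrative .env for local development (placeholder value, not a real
# credential):
#
#   API_TOKEN=change-me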

# HTTP Bearer authentication scheme for protected endpoints
security = HTTPBearer()

app = FastAPI(
    title="SenseVoice API",
    description="Speech recognition API service",
    version="1.0.0"
)

# Allow cross-origin requests from any origin
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load SenseVoiceSmall from the Hugging Face hub, with an FSMN VAD front end
# for segmenting long-form audio
model = AutoModel(
    model="FunAudioLLM/SenseVoiceSmall",
    vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_kwargs={"max_single_segment_time": 30000},
    hub="hf",
    device="cuda"
)
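
# A quick smoke test of the loaded model (illustrative; the example path
# below is hypothetical):
#
#   res = model.generate(input="examples/zh.mp3", language="auto", use_itn=True)
#   print(res[0]["text"])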

# Emotion tokens emitted by the model, mapped to emojis
emotion_dict: Dict[str, str] = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

# Audio-event tokens, mapped to emojis
event_dict: Dict[str, str] = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

# Full token-to-emoji mapping used for inline replacement
emoji_dict: Dict[str, str] = {
    "<|nospeech|><|Event_UNK|>": "❓",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

# Language tokens that are collapsed into a single segment separator
lang_dict: Dict[str, str] = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

# Emoji sets used when tidying whitespace around markers
emo_set: Set[str] = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set: Set[str] = {"🎼", "👏", "😀", "😭", "🤧", "😷"}


def format_text_basic(text: str) -> str:
    """Replace special tokens with their corresponding emojis"""
    for token, emoji in emoji_dict.items():
        text = text.replace(token, emoji)
    return text


def format_text_with_emotion(text: str) -> str:
    """Format text with emotion and event markers"""
    # Count each special token, stripping it from the text as we go so the
    # markers are not duplicated inline later
    token_count: Dict[str, int] = {}
    for token in emoji_dict:
        token_count[token] = text.count(token)
        text = text.replace(token, "")

    # Pick the most frequent emotion token, defaulting to neutral
    dominant_emotion = "<|NEUTRAL|>"
    for emotion in emotion_dict:
        if token_count[emotion] > token_count[dominant_emotion]:
            dominant_emotion = emotion

    # Prefix the text with an emoji for every event that occurred
    for event in event_dict:
        if token_count[event] > 0:
            text = event_dict[event] + text

    # Append the emoji for the dominant emotion
    text = text + emotion_dict[dominant_emotion]

    # Remove stray spaces around emojis
    for emoji in emo_set.union(event_set):
        text = text.replace(" " + emoji, emoji)
        text = text.replace(emoji + " ", emoji)
    return text.strip()


def format_text_advanced(text: str) -> str:
    """Advanced text formatting with multilingual and complex token handling"""

    def get_emotion(text: str) -> Optional[str]:
        return text[-1] if text[-1] in emo_set else None

    def get_event(text: str) -> Optional[str]:
        return text[0] if text[0] in event_set else None

    # Normalize the unknown-event marker and collapse every language token
    # into a single segment separator
    text = text.replace("<|nospeech|><|Event_UNK|>", "❓")
    for lang in lang_dict:
        text = text.replace(lang, "<|lang|>")

    # Format each language segment independently
    text_segments: List[str] = [
        format_text_with_emotion(segment).strip()
        for segment in text.split("<|lang|>")
    ]
    formatted_text = " " + text_segments[0]
    current_event = get_event(formatted_text)

    # Merge segments, deduplicating consecutive event and emotion markers
    for i in range(1, len(text_segments)):
        if not text_segments[i]:
            continue

        if get_event(text_segments[i]) == current_event and get_event(text_segments[i]) is not None:
            text_segments[i] = text_segments[i][1:]
        current_event = get_event(text_segments[i])

        if get_emotion(text_segments[i]) is not None and get_emotion(text_segments[i]) == get_emotion(formatted_text):
            formatted_text = formatted_text[:-1]
        formatted_text += text_segments[i].strip()

    # Strip a stray "The." artifact sometimes emitted by the model
    formatted_text = formatted_text.replace("The.", " ")
    return formatted_text.strip()
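
# An illustrative end-to-end transformation (the input token sequence is
# hypothetical, and assumes the emotion/event post-processing above):
#
#   format_text_advanced("<|en|><|HAPPY|><|Speech|>hello world")
#   # -> "hello world😊"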


async def audio_stt(audio: np.ndarray, sample_rate: int, language: str = "auto") -> str:
    """Transcribe a float32 waveform normalized to [-1, 1]."""
    input_wav = audio.astype(np.float32)

    # Mix multi-channel audio down to mono
    if input_wav.ndim > 1:
        input_wav = input_wav.mean(-1)

    # Resample to the 16 kHz rate the model expects
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
        input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()

    text = model.generate(
        input=input_wav,
        cache={},
        language=language,
        use_itn=True,
        batch_size_s=500,
        merge_vad=True
    )

    result = text[0]["text"]
    return format_text_advanced(result)


async def process_audio(audio_data: bytes, language: str = "auto") -> str:
    """Decode raw audio bytes and return the transcription result"""
    try:
        # torchaudio.load returns a float32 tensor of shape (channels, frames)
        # already normalized to [-1, 1]
        audio_buffer = BytesIO(audio_data)
        waveform, sample_rate = torchaudio.load(audio_buffer)

        # Mix down to mono and hand a 1-D numpy array to the recognizer
        audio = waveform.mean(dim=0).numpy()
        result = await audio_stt(audio, sample_rate, language)

        return result

    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Audio processing failed: {str(e)}")


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)) -> HTTPAuthorizationCredentials:
    """Verify Bearer token authentication"""
    if credentials.credentials != API_TOKEN:
        raise HTTPException(
            status_code=401,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"}
        )
    return credentials
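
# A hardened variant could compare tokens in constant time, e.g. with
# secrets.compare_digest(credentials.credentials, API_TOKEN), to blunt
# timing side channels; plain equality is kept above for simplicity.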


@app.post("/v1/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: Optional[str] = "FunAudioLLM/SenseVoiceSmall",
    language: Optional[str] = "auto",
    token: HTTPAuthorizationCredentials = Depends(verify_token)
) -> Dict[str, Union[str, int, float]]:
    """Audio transcription endpoint

    Args:
        file: Audio file (common audio formats are supported)
        model: Model name; currently only FunAudioLLM/SenseVoiceSmall is supported
        language: Language code; one of auto/zh/en/yue/ja/ko/nospeech

    Returns:
        Dict[str, Union[str, int, float]]: {
            "text": "Transcription result",
            "error_code": 0,
            "error_msg": "",
            "process_time": 1.234  # Processing time in seconds
        }
    """
    start_time = time.time()

    try:
        # Validate the file extension
        if not file.filename or not file.filename.lower().endswith((".mp3", ".wav", ".flac", ".ogg", ".m4a")):
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported audio format",
                "process_time": time.time() - start_time
            }

        # Validate the requested model name
        if model != "FunAudioLLM/SenseVoiceSmall":
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported model",
                "process_time": time.time() - start_time
            }

        # Validate the language code
        if language not in ["auto", "zh", "en", "yue", "ja", "ko", "nospeech"]:
            return {
                "text": "",
                "error_code": 400,
                "error_msg": "Unsupported language",
                "process_time": time.time() - start_time
            }

        # Read the upload and run recognition
        content = await file.read()
        text = await process_audio(content, language)

        return {
            "text": text,
            "error_code": 0,
            "error_msg": "",
            "process_time": time.time() - start_time
        }

    except Exception as e:
        return {
            "text": "",
            "error_code": 500,
            "error_msg": str(e),
            "process_time": time.time() - start_time
        }
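
# An illustrative request (assumes the server listens on localhost:7860, that
# $API_TOKEN holds the configured token, and a hypothetical example file):
#
#   curl -X POST "http://localhost:7860/v1/audio/transcriptions?language=zh" \
#     -H "Authorization: Bearer $API_TOKEN" \
#     -F "file=@examples/zh.mp3"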


def transcribe_audio_gradio(audio: Optional[Tuple[int, np.ndarray]], language: str = "auto") -> str:
    """Gradio interface for audio transcription"""
    try:
        if audio is None:
            return "Please upload an audio file"

        sample_rate, input_wav = audio

        # Gradio delivers int16 PCM; normalize to float32 in [-1, 1]
        input_wav = input_wav.astype(np.float32) / np.iinfo(np.int16).max

        # Mix multi-channel audio down to mono
        if input_wav.ndim > 1:
            input_wav = input_wav.mean(-1)

        # Resample to the 16 kHz rate the model expects
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            input_wav_tensor = torch.from_numpy(input_wav).to(torch.float32)
            input_wav = resampler(input_wav_tensor[None, :])[0, :].numpy()

        text = model.generate(
            input=input_wav,
            cache={},
            language=language,
            use_itn=True,
            batch_size_s=500,
            merge_vad=True
        )

        result = text[0]["text"]
        return format_text_advanced(result)
    except Exception as e:
        return f"Processing failed: {str(e)}"


# Gradio web UI for interactive use
demo = gr.Interface(
    fn=transcribe_audio_gradio,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="numpy",
            label="Upload audio or record from microphone"
        ),
        gr.Dropdown(
            choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"],
            value="auto",
            label="Select Language"
        )
    ],
    outputs=gr.Textbox(label="Recognition Result"),
    title="SenseVoice Speech Recognition",
    description="Multi-language speech transcription service supporting Chinese, English, Cantonese, Japanese, and Korean",
    examples=[
        ["examples/zh.mp3", "zh"],
        ["examples/en.mp3", "en"],
    ]
)

# Serve the Gradio UI at the application root
app = gr.mount_gradio_app(app, demo, path="/")


# Redirect helper for the API docs; the Gradio app mounted at "/" can
# otherwise shadow the default Swagger UI route
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui_html():
    return HTMLResponse("""
    <!DOCTYPE html>
    <html>
        <head>
            <title>SenseVoice API Documentation</title>
            <meta http-equiv="refresh" content="0;url=/docs/" />
        </head>
        <body>
            <p>Redirecting to API documentation...</p>
        </body>
    </html>
    """)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
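
# To launch (the module name "app" is an assumption about this file's name):
#   python app.py
# or, equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 7860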