| import json |
| import os.path |
| import tempfile |
| import sys |
| import re |
| import uuid |
| import requests |
| import librosa |
| import numpy as np |
| import torch |
| import uvicorn |
| import torchaudio |
| import base64 |
| import io |
| from argparse import ArgumentParser |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel |
| from transformers import WhisperFeatureExtractor, AutoTokenizer |
| from speech_tokenizer.modeling_whisper import WhisperVQEncoder |
|
|
| sys.path.insert(0, "./cosyvoice") |
| sys.path.insert(0, "./third_party/Matcha-TTS") |
|
|
| from speech_tokenizer.utils import extract_speech_token |
| from flow_inference import AudioDecoder |
|
|
| |
# ASGI application object; route handlers and lifecycle hooks attach to it.
app = FastAPI()

# Fully permissive CORS policy: any origin, method, and header is accepted,
# and credentialed requests are allowed.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard values with allow_credentials=True -- confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
class AudioRequest(BaseModel):
    """Request body accepted by the voice-chat endpoint."""

    # Base64-encoded audio payload; the handler base64-decodes it and reads
    # the resulting bytes as int16 samples.
    audio_data: str
    # Sample rate of the submitted audio; used as the source rate when the
    # samples are resampled to 16000.
    sample_rate: int
|
|
class AudioResponse(BaseModel):
    """Response body returned by the voice-chat endpoint."""

    # Base64-encoded WAV file containing the synthesized reply.
    audio_data: str
    # Text decoded from the non-audio tokens of the model's reply.
    text_transcript: str
|
|
| |
# Device the models are placed on. CUDA is preferred (the previous hard-coded
# value); fall back to CPU so the service can still start on hosts without a
# visible GPU instead of failing during model loading.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Model handles shared by the request handlers. They stay None until
# initialize_models() populates them at application startup.
audio_decoder = None
whisper_model = None
feature_extractor = None
glm_tokenizer = None
|
|
def initialize_models():
    """Load every model the service needs into the module-level globals.

    Populates ``glm_tokenizer``, ``whisper_model``, ``feature_extractor`` and
    ``audio_decoder``. The speech encoder is switched to eval mode and moved
    to ``DEVICE``; the audio decoder is constructed with ``device=DEVICE``.
    """
    global audio_decoder, feature_extractor, whisper_model, glm_tokenizer

    # Tokenizer used to map the audio marker token to an id and to decode text.
    glm_tokenizer = AutoTokenizer.from_pretrained(
        "THUDM/glm-4-voice-9b",
        trust_remote_code=True,
    )

    # Speech encoder plus its matching feature extractor.
    speech_encoder = WhisperVQEncoder.from_pretrained("THUDM/glm-4-voice-tokenizer")
    speech_encoder = speech_encoder.eval()
    whisper_model = speech_encoder.to(DEVICE)
    feature_extractor = WhisperFeatureExtractor.from_pretrained("THUDM/glm-4-voice-tokenizer")

    # Decoder that turns generated audio tokens back into a waveform.
    decoder_options = {
        "config_path": "./glm-4-voice-decoder/config.yaml",
        "flow_ckpt_path": "./glm-4-voice-decoder/flow.pt",
        "hift_ckpt_path": "./glm-4-voice-decoder/hift.pt",
        "device": DEVICE,
    }
    audio_decoder = AudioDecoder(**decoder_options)
|
|
@app.on_event("startup")
async def startup_event():
    """Load all models once when the application starts.

    Raises:
        RuntimeError: If any model fails to load. The original exception is
            chained so the full traceback is preserved.
    """
    try:
        initialize_models()
    except Exception as e:
        # HTTPException describes an HTTP response and only has meaning while
        # a request is being handled; there is no request during startup, so
        # surface the failure as an ordinary error that aborts startup.
        raise RuntimeError(f"Model initialization failed: {e}") from e
|
|
def process_audio(audio_bytes: bytes, target_sr: int = 16000):
    """Decode raw 16-bit PCM bytes and resample them to 16 kHz.

    Args:
        audio_bytes: Raw int16 samples, interpreted in the byte order
            ``np.frombuffer`` uses for ``np.int16``.
        target_sr: Sample rate of ``audio_bytes``. Despite its name this is
            the *source* rate; the output rate is always 16000. The name is
            kept for backward compatibility with existing callers.

    Returns:
        A writable one-dimensional ``np.int16`` array sampled at 16 kHz.

    Raises:
        ValueError: If ``target_sr`` is not positive or ``audio_bytes`` does
            not contain a whole number of 16-bit samples.
    """
    if target_sr <= 0:
        raise ValueError("sample rate must be a positive integer")
    if len(audio_bytes) % 2:
        raise ValueError("audio data must contain whole 16-bit samples (even number of bytes)")

    # np.frombuffer returns a read-only view over the input bytes; copy so
    # downstream consumers (e.g. torch.from_numpy) get a writable array.
    samples = np.frombuffer(audio_bytes, dtype=np.int16).copy()

    # Nothing to resample when the audio is already at 16 kHz or is empty.
    if target_sr == 16000 or samples.size == 0:
        return samples

    # librosa.resample only accepts floating-point audio, so scale the int16
    # samples to [-1.0, 1.0) first and convert back afterwards so the return
    # dtype is the same on both paths.
    as_float = samples.astype(np.float32) / 32768.0
    resampled = librosa.resample(as_float, orig_sr=target_sr, target_sr=16000)
    return np.clip(np.rint(resampled * 32768.0), -32768, 32767).astype(np.int16)
|
|
@app.post("/api/voice_chat")
async def voice_chat(request: AudioRequest):
    """Run one voice-chat turn: speech in, speech and transcript out.

    The base64 payload is decoded to 16 kHz samples, converted to speech
    tokens, sent to the generation server as a prompt, and the streamed reply
    tokens are split into audio tokens (decoded to a waveform) and text
    tokens (decoded to a transcript).

    Args:
        request: Base64-encoded int16 audio and its sample rate.

    Returns:
        AudioResponse with the base64-encoded WAV reply and the transcript.

    Raises:
        HTTPException: 400 for an undecodable/empty payload or when no speech
            tokens are extracted; 502 when the generation server cannot be
            reached or returns an error status; 500 for any other failure.

    NOTE: the HTTP call to the generation server and the model inference are
    blocking operations inside an async handler, so the event loop is held
    for the duration of a request.
    """
    try:
        # Malformed base64 or a truncated sample buffer is a client error.
        try:
            audio_bytes = base64.b64decode(request.audio_data)
            audio_np = process_audio(audio_bytes, request.sample_rate)
        except ValueError as e:
            raise HTTPException(400, f"Invalid audio payload: {e}") from e

        if audio_np.size == 0:
            raise HTTPException(400, "No speech detected")

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "audio.wav")
            # process_audio always returns 16 kHz samples, so the file must be
            # labelled 16000 rather than the client's original sample rate.
            # np.array() copies, guaranteeing a writable buffer for torch.
            waveform = torch.from_numpy(np.array(audio_np)).unsqueeze(0)
            torchaudio.save(tmp_path, waveform, 16000)
            input_tokens = extract_speech_token(whisper_model, feature_extractor, [tmp_path])[0]

        if not input_tokens:
            raise HTTPException(400, "No speech detected")

        audio_prompt = ' '.join(f'<|audio_{x}|>' for x in input_tokens)
        payload = {
            "prompt": f"<|system|>Respond<|user|>{audio_prompt}<|assistant|>",
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 256,
        }

        # Token ids at or above the id of '<|audio_0|>' are audio tokens;
        # everything below is text.
        text_tokens = []
        output_audio_tokens = []
        audio_offset = glm_tokenizer.convert_tokens_to_ids('<|audio_0|>')

        try:
            # The context manager releases the streamed connection; the
            # timeout (connect, read) keeps a stalled server from hanging
            # the request forever.
            with requests.post(
                "http://localhost:10000/generate_stream",
                json=payload,
                stream=True,
                timeout=(10, 300),
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_lines():
                    # iter_lines can yield empty keep-alive lines.
                    if not chunk:
                        continue
                    token_id = json.loads(chunk)["token_id"]
                    if token_id >= audio_offset:
                        output_audio_tokens.append(token_id - audio_offset)
                    else:
                        text_tokens.append(token_id)
        except requests.RequestException as e:
            raise HTTPException(502, f"Generation server error: {e}") from e

        # An explicit integer dtype keeps an empty token list from silently
        # becoming a float tensor.
        tts_token = torch.tensor(output_audio_tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
        # NOTE(review): AudioDecoder.token2wav is defined outside this file;
        # confirm this call matches its signature.
        tts_speech, _ = audio_decoder.token2wav(tts_token)

        buffer = io.BytesIO()
        # NOTE(review): assumes the decoder emits 22050 Hz audio -- TODO confirm.
        torchaudio.save(buffer, tts_speech.cpu(), 22050, format="wav")

        return AudioResponse(
            audio_data=base64.b64encode(buffer.getvalue()).decode(),
            text_transcript=glm_tokenizer.decode(text_tokens, skip_special_tokens=True),
        )

    except HTTPException:
        # Preserve deliberately chosen status codes (400/502) instead of
        # letting the generic handler below rewrite them as 500.
        raise
    except Exception as e:
        raise HTTPException(500, str(e)) from e
| |
if __name__ == "__main__":
    # Bind address and port are configurable from the command line; the
    # defaults are the previously hard-coded values, so running the script
    # without arguments behaves as before. uvicorn is already imported at
    # the top of the file, so it is not re-imported here.
    parser = ArgumentParser(description="GLM-4-Voice chat API server")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Interface to bind to")
    parser.add_argument("--port", type=int, default=8000, help="TCP port to listen on")
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)