| | from transformers import AutoModelForSpeechSeq2Seq, AutoTokenizer |
| | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
| | from llama_cpp import Llama |
| | import torch |
| | import soundfile as sf |
| | import io |
| | import os |
| | from pydantic import BaseModel |
| | from fastapi import FastAPI, File, UploadFile, Response |
# Single FastAPI application exposing the /tts, /sst, and /llm endpoints below.
app = FastAPI()
| |
|
| | |
| | if os.path.exists("./models/tts_model"): |
| | tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("./models/tts_model") |
| | tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model") |
| | else: |
| | tts_model = AutoModelForSpeechSeq2Seq.from_pretrained("facebook/tts_transformer-en-ljspeech") |
| | tts_tokenizer = AutoTokenizer.from_pretrained("facebook/tts_transformer-en-ljspeech") |
| |
|
| | |
| | if os.path.exists("./models/sst_model"): |
| | sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model") |
| | sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model") |
| | else: |
| | sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") |
| | sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") |
| |
|
| | |
| | if os.path.exists("./models/llama.gguf"): |
| | llm = Llama("./models/llama.gguf") |
| | else: |
| | raise FileNotFoundError("Please upload llama.gguf to models/ directory") |
| |
|
| | |
class TTSRequest(BaseModel):
    """Request body for the /tts endpoint."""

    text: str  # text to synthesize into speech
| |
|
class LLMRequest(BaseModel):
    """Request body for the /llm endpoint."""

    prompt: str  # prompt passed verbatim to the local LLM
| |
|
| | @app.post("/tts") |
| | async def tts_endpoint(request: TTSRequest): |
| | text = request.text |
| | inputs = tts_tokenizer(text, return_tensors="pt") |
| | with torch.no_grad(): |
| | audio = tts_model.generate(**inputs) |
| | audio = audio.squeeze().cpu().numpy() |
| | buffer = io.BytesIO() |
| | sf.write(buffer, audio, 22050, format="WAV") |
| | buffer.seek(0) |
| | return Response(content=buffer.getvalue(), media_type="audio/wav") |
| |
|
| | |
| | @app.post("/sst") |
| | async def sst_endpoint(file: UploadFile = File(...)): |
| | audio_bytes = await file.read() |
| | audio, sr = sf.read(io.BytesIO(audio_bytes)) |
| | inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt") |
| | with torch.no_grad(): |
| | logits = sst_model(inputs.input_values).logits |
| | predicted_ids = torch.argmax(logits, dim=-1) |
| | transcription = sst_processor.batch_decode(predicted_ids)[0] |
| | return {"text": transcription} |
| |
|
| | @app.post("/llm") |
| | async def llm_endpoint(request: LLMRequest): |
| | prompt = request.prompt |
| | output = llm(prompt, max_tokens=50) |
| | return {"text": output["choices"][0]["text"]} |