Spaces:
Configuration error
Configuration error
File size: 2,096 Bytes
b2a0ae9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 | from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
import torch
import torchaudio
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
from TTS.api import TTS
import uvicorn
import tempfile
# FastAPI application exposing the /asr/, /llm/ and /tts/ endpoints below.
app = FastAPI()
# Models are loaded once, at import time, into module-level objects that the
# request handlers below reuse.
# Load ASR (Whisper small), placed on the CPU.
asr_processor = AutoProcessor.from_pretrained("openai/whisper-small")
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to("cpu")
# Load LLM (Flan-T5 small), placed on the CPU.
llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cpu")
# Load TTS (Coqui TTS "your_tts" model — not Facebook MMS).
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts") # lightweight multilingual
class LLMInput(BaseModel):
    """JSON request body carrying a single text ``prompt``.

    Despite the name, this model is also the request body of the /tts/
    endpoint, where ``prompt`` is the text to synthesize.
    """

    # Text sent to the model (LLM prompt or text-to-speech input).
    prompt: str
@app.post("/asr/")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Whisper-small model.

    The upload is spooled to a temporary file (``torchaudio.load`` reads from
    a path), downmixed to mono, resampled to 16 kHz when needed, and decoded
    with ``asr_model.generate``.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"transcription": <text>}``.
    """
    import os  # local import keeps this block independent of the file's import list

    # Keep a short, plain extension from the client's filename so torchaudio's
    # backend can recognise the container format; ignore anything odd.
    ext = os.path.splitext(file.filename or "")[1]
    if not (ext[1:].isalnum() and len(ext) <= 8):
        ext = ""

    # delete=False so the file can be reopened by name after closing; leaving
    # the `with` closes (and flushes) it before torchaudio reads it.
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
        tmp.write(await file.read())
    try:
        waveform, rate = torchaudio.load(tmp.name)
    finally:
        # Previously never removed: every request leaked a temp file.
        os.remove(tmp.name)

    # torchaudio returns (channels, frames). Average the channels into one mono
    # track; squeeze() alone would leave stereo 2-D, which the Whisper processor
    # treats as a batch of two clips (only the first would be returned).
    waveform = waveform.mean(dim=0)
    if rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
        waveform = resampler(waveform)

    inputs = asr_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
    # NOTE(review): model inference runs synchronously inside an async handler
    # and blocks the event loop for its duration.
    with torch.no_grad():
        predicted_ids = asr_model.generate(inputs["input_features"])
    transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"transcription": transcription}
@app.post("/llm/")
async def generate(input: LLMInput):
    """Answer ``input.prompt`` with the Flan-T5 model.

    Output length is bounded by ``max_length=100``; special tokens are
    stripped when decoding.

    Returns:
        ``{"response": <decoded text>}``.
    """
    prompt_tensor = llm_tokenizer.encode(input.prompt, return_tensors="pt")
    sequences = llm_model.generate(prompt_tensor, max_length=100)
    best = sequences[0]
    reply = llm_tokenizer.decode(best, skip_special_tokens=True)
    return {"response": reply}
@app.post("/tts/")
async def synthesize(input: LLMInput):
    """Synthesize ``input.prompt`` to ``output.wav`` and return that path.

    The loaded YourTTS model is multi-speaker and multi-lingual, and Coqui's
    ``TTS.tts_to_file`` raises ``ValueError`` for such models unless a speaker
    and a language are passed. When the model requires them, use its first
    built-in speaker and English (falling back to the first listed language).
    Models that need neither still receive ``None``.

    Returns:
        ``{"audio_path": <path returned by tts_to_file>}`` — a path on the
        server's filesystem, not a download URL.
    """
    speaker = tts.speakers[0] if tts.is_multi_speaker else None
    language = None
    if tts.is_multi_lingual:
        available = tts.languages
        language = "en" if "en" in available else available[0]
    # NOTE(review): every request writes the same relative "output.wav", so a
    # later request overwrites an earlier result; consider per-request files or
    # streaming the audio back if clients need to retrieve it.
    path = tts.tts_to_file(
        text=input.prompt,
        speaker=speaker,
        language=language,
        file_path="output.wav",
    )
    return {"audio_path": path}
if __name__ == "__main__":
    # Pass the app object, not the "app:app" import string. Given a string,
    # uvicorn imports this file again under the module name "app" (this run is
    # "__main__"), re-executing the top level and loading every model twice.
    # The string form also fails if the file is not literally named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|