File size: 2,096 Bytes
b2a0ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import tempfile

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
from TTS.api import TTS

app = FastAPI()

# All models are loaded once at import time and shared across requests.
# NOTE(review): everything is pinned to CPU — confirm no GPU is expected.

# Load ASR (Whisper small)
asr_processor = AutoProcessor.from_pretrained("openai/whisper-small")
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to("cpu")

# Load LLM (Flan-T5 small)
llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cpu")

# Load TTS (Facebook MMS)
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts")  # lightweight multilingual

class LLMInput(BaseModel):
    """Request body carrying a single text prompt (used by /llm/ and /tts/)."""

    prompt: str

@app.post("/asr/")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper-small.

    Saves the upload to a temporary file, resamples to the 16 kHz rate
    Whisper expects, and returns {"transcription": <text>}.
    """
    # Preserve the upload's suffix so torchaudio can detect the container format.
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name
        # The file is now closed (and fully flushed) before torchaudio reads it;
        # the original loaded it while the buffered handle was still open.
        waveform, rate = torchaudio.load(tmp_path)
        if rate != 16000:
            # Whisper models expect 16 kHz input.
            resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
            waveform = resampler(waveform)
        inputs = asr_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
        with torch.no_grad():
            predicted_ids = asr_model.generate(inputs["input_features"])
        transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    finally:
        # The original delete=False temp file was leaked on every request.
        if tmp_path is not None:
            os.unlink(tmp_path)
    return {"transcription": transcription}

@app.post("/llm/")
async def generate(input: LLMInput):
    """Generate a text completion for *input.prompt* with Flan-T5-small.

    Returns {"response": <generated text>}.
    """
    input_ids = llm_tokenizer.encode(input.prompt, return_tensors="pt")
    # Inference only: skip autograd bookkeeping (the original built a gradient
    # graph on every request for nothing), matching the /asr/ endpoint's style.
    with torch.no_grad():
        output_ids = llm_model.generate(input_ids, max_length=100)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"response": response}

@app.post("/tts/")
async def synthesize(input: LLMInput):
    """Synthesize speech for *input.prompt* and return the output WAV path.

    Returns {"audio_path": <path to generated wav>}.
    """
    # The original wrote to a fixed "output.wav", so concurrent requests
    # clobbered each other's audio. Give each request its own unique path.
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # TTS reopens the path itself; we only need the name.
    # NOTE(review): the generated file is never cleaned up (same as before) —
    # consider a background task or FileResponse if disk usage matters.
    path = tts.tts_to_file(text=input.prompt, file_path=out_path)
    return {"audio_path": path}

# Script entry point: serve the app on all interfaces, port 7860
# (the default Hugging Face Spaces port).
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)