"""Speech microservice: ASR (Whisper-small), text generation (Flan-T5-small),
and TTS (Coqui YourTTS) exposed as three FastAPI endpoints.

All models run on CPU and are loaded once at import time.
"""

import os
import tempfile

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    pipeline,  # NOTE(review): unused here; kept in case another module relies on this import
)
from TTS.api import TTS

app = FastAPI()

# Load ASR (Whisper small)
asr_processor = AutoProcessor.from_pretrained("openai/whisper-small")
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to("cpu")

# Load LLM (Flan-T5 small)
llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cpu")

# Load TTS (Coqui YourTTS — lightweight multilingual model)
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts")

WHISPER_SAMPLE_RATE = 16000  # Whisper models are trained on 16 kHz audio


class LLMInput(BaseModel):
    # Single text field shared by the /llm/ and /tts/ endpoints.
    prompt: str


@app.post("/asr/")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    The upload is spooled to a temporary file (torchaudio needs a path),
    downmixed to mono, resampled to 16 kHz, and decoded greedily.
    Returns ``{"transcription": <text>}``.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tmp.write(await file.read())
            tmp_path = tmp.name
        waveform, rate = torchaudio.load(tmp_path)
    finally:
        # delete=False leaks the temp file unless we remove it ourselves.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)

    # Whisper expects a single-channel signal; average stereo/multichannel down.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if rate != WHISPER_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(
            orig_freq=rate, new_freq=WHISPER_SAMPLE_RATE
        )
        waveform = resampler(waveform)

    inputs = asr_processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=WHISPER_SAMPLE_RATE,
        return_tensors="pt",
    )
    with torch.no_grad():
        predicted_ids = asr_model.generate(inputs["input_features"])
    transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"transcription": transcription}


@app.post("/llm/")
async def generate(payload: LLMInput):
    """Generate a text completion for ``payload.prompt`` with Flan-T5.

    Returns ``{"response": <text>}``.
    """
    input_ids = llm_tokenizer.encode(payload.prompt, return_tensors="pt")
    # no_grad: inference only — avoid building an autograd graph.
    with torch.no_grad():
        output_ids = llm_model.generate(input_ids, max_length=100)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"response": response}


@app.post("/tts/")
async def synthesize(payload: LLMInput):
    """Synthesize ``payload.prompt`` to a WAV file and return its path.

    Each request gets its own temporary output file so that concurrent
    requests do not overwrite one another's audio.
    Returns ``{"audio_path": <path>}``; the caller is responsible for
    cleaning up the file once consumed.
    """
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # TTS writes by path; keep only the name.
    path = tts.tts_to_file(text=payload.prompt, file_path=out_path)
    return {"audio_path": path}


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)