Spaces:
Configuration error
Configuration error
| from fastapi import FastAPI, UploadFile, File | |
| from pydantic import BaseModel | |
| import torch | |
| import torchaudio | |
| from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM | |
| from TTS.api import TTS | |
| import uvicorn | |
| import tempfile | |
# Checkpoint identifiers and the device all transformer models are moved to.
_DEVICE = "cpu"
_ASR_CHECKPOINT = "openai/whisper-small"
_LLM_CHECKPOINT = "google/flan-t5-small"
_TTS_MODEL_NAME = "tts_models/multilingual/multi-dataset/your_tts"

app = FastAPI()

# Speech recognition: Whisper small. The processor handles feature extraction
# and token decoding; the model produces the token ids.
asr_processor = AutoProcessor.from_pretrained(_ASR_CHECKPOINT)
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(_ASR_CHECKPOINT)
asr_model = asr_model.to(_DEVICE)

# Text generation: Flan-T5 small (sequence-to-sequence).
llm_tokenizer = AutoTokenizer.from_pretrained(_LLM_CHECKPOINT)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(_LLM_CHECKPOINT)
llm_model = llm_model.to(_DEVICE)

# Speech synthesis: Coqui TTS, loaded by model name (the multilingual YourTTS
# checkpoint).
tts = TTS(model_name=_TTS_MODEL_NAME)
class LLMInput(BaseModel):
    """Request body carrying a single text prompt."""

    # Free-form text supplied by the caller.
    prompt: str
@app.post("/transcribe")
async def transcribe(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Whisper ASR model.

    The upload is written to a temporary file so ``torchaudio`` can decode it
    by path, downmixed to mono, resampled to 16 kHz when necessary, and then
    passed through ``asr_processor`` / ``asr_model``.

    Args:
        file: The uploaded audio file.

    Returns:
        A dict of the form ``{"transcription": <decoded text>}``.
    """
    import os

    payload = await file.read()

    # delete=False lets torchaudio reopen the file by name after the handle is
    # closed (and therefore flushed); the file is removed in the `finally`
    # below so uploads do not pile up in the temp directory.
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(payload)
        tmp_path = tmp.name
    try:
        waveform, rate = torchaudio.load(tmp_path)
    finally:
        os.remove(tmp_path)

    # torchaudio returns (channels, frames). Average the channels so that a
    # stereo upload becomes one mono signal rather than a 2-D array.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    if rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
        waveform = resampler(waveform)

    inputs = asr_processor(
        waveform.squeeze(0).numpy(),
        sampling_rate=16000,
        return_tensors="pt",
    )
    with torch.no_grad():
        predicted_ids = asr_model.generate(inputs["input_features"])
    transcription = asr_processor.batch_decode(
        predicted_ids, skip_special_tokens=True
    )[0]
    return {"transcription": transcription}
@app.post("/generate")
async def generate(input: LLMInput):
    """Generate a text response to a prompt with the Flan-T5 model.

    Args:
        input: Request body whose ``prompt`` is fed to the model.

    Returns:
        A dict of the form ``{"response": <generated text>}``.
    """
    input_ids = llm_tokenizer.encode(input.prompt, return_tensors="pt")
    # Inference only: same no-grad guard as the ASR path.
    with torch.no_grad():
        output_ids = llm_model.generate(input_ids, max_length=100)
    response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return {"response": response}
@app.post("/synthesize")
async def synthesize(input: LLMInput):
    """Synthesize speech for a prompt and write it to a WAV file.

    Args:
        input: Request body whose ``prompt`` is the text to speak.

    Returns:
        A dict of the form ``{"audio_path": <value returned by tts_to_file>}``.
    """
    # Coqui's TTS API rejects calls to multi-speaker / multilingual models
    # that omit `speaker` / `language`, so fall back to the first entry the
    # loaded model advertises. Single-speaker, single-language models get no
    # extra arguments (passing them would be rejected as well).
    synth_kwargs = {}
    if tts.is_multi_speaker:
        synth_kwargs["speaker"] = tts.speakers[0]
    if tts.is_multi_lingual:
        synth_kwargs["language"] = tts.languages[0]

    # NOTE(review): every request writes the same "output.wav", so concurrent
    # requests overwrite each other's audio. Kept as-is to preserve the
    # returned path; consider a per-request file.
    path = tts.tts_to_file(
        text=input.prompt,
        file_path="output.wav",
        **synth_kwargs,
    )
    return {"audio_path": path}
if __name__ == "__main__":
    # Pass the app object rather than the "app:app" import string: the string
    # form makes uvicorn import this module a second time (loading every model
    # twice) and only works when the file is named app.py.
    uvicorn.run(app, host="0.0.0.0", port=7860)