Magjot commited on
Commit
b2a0ae9
·
verified ·
1 Parent(s): f8b47b5

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +16 -0
  2. app.py +53 -0
  3. requirements.txt +6 -0
  4. utils.py +1 -0
README.md ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ASR + LLM + TTS REST API
2
+
3
+ This is a simple RESTful API built with FastAPI.
4
+
5
+ ## Endpoints
6
+
7
+ - `/asr/`: Upload a `.wav` file and get the transcription.
8
+ - `/llm/`: Send a prompt and get generated text.
9
+ - `/tts/`: Send text and get back audio path (WAV file).
10
+
11
+ ## Run
12
+
13
+ ```bash
14
+ pip install -r requirements.txt
15
+ python app.py
16
+ ```
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
import torch
import torchaudio
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
from TTS.api import TTS
import uvicorn
import tempfile

# Single FastAPI application serving three endpoints: /asr/, /llm/, /tts/.
app = FastAPI()

# Load ASR (Whisper small). Models are loaded once at import time so every
# request reuses them; pinned to CPU so the app runs without a GPU.
asr_processor = AutoProcessor.from_pretrained("openai/whisper-small")
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to("cpu")

# Load LLM (Flan-T5 small), also CPU-only.
llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cpu")

# Load TTS. NOTE(review): the original comment said "Facebook MMS", but the
# model id below is Coqui YourTTS ("your_tts") — the comment did not match.
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts")  # lightweight multilingual
23
class LLMInput(BaseModel):
    """Request body for the /llm/ and /tts/ endpoints.

    Attributes are validated by pydantic from the JSON payload.
    """

    # The text prompt to feed to the LLM (or to synthesize, for /tts/).
    prompt: str
26
+ @app.post("/asr/")
27
+ async def transcribe(file: UploadFile = File(...)):
28
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
29
+ tmp.write(await file.read())
30
+ waveform, rate = torchaudio.load(tmp.name)
31
+ if rate != 16000:
32
+ resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
33
+ waveform = resampler(waveform)
34
+ inputs = asr_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
35
+ with torch.no_grad():
36
+ predicted_ids = asr_model.generate(inputs["input_features"])
37
+ transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
38
+ return {"transcription": transcription}
39
+
40
+ @app.post("/llm/")
41
+ async def generate(input: LLMInput):
42
+ input_ids = llm_tokenizer.encode(input.prompt, return_tensors="pt")
43
+ output_ids = llm_model.generate(input_ids, max_length=100)
44
+ response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
+ return {"response": response}
46
+
47
+ @app.post("/tts/")
48
+ async def synthesize(input: LLMInput):
49
+ path = tts.tts_to_file(text=input.prompt, file_path="output.wav")
50
+ return {"audio_path": path}
51
+
52
# Script entry point: serve the app on all interfaces at port 7860
# (the conventional Hugging Face Spaces port).
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers
4
+ torchaudio
5
+ TTS
6
+ torch
utils.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Utility functions can go here