Magjot committed on
Commit
c7460b4
·
verified ·
1 Parent(s): be5ee8a

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +53 -0
  2. requirements.txt +6 -0
  3. utils.py +1 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import tempfile

from fastapi import FastAPI, UploadFile, File
from pydantic import BaseModel
import torch
import torchaudio
from transformers import pipeline, AutoProcessor, AutoModelForSpeechSeq2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
from TTS.api import TTS
import uvicorn
9
app = FastAPI()

# --- Models are loaded once at import time, all pinned to CPU ---

# ASR: Whisper small (speech -> text)
asr_processor = AutoProcessor.from_pretrained("openai/whisper-small")
asr_model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small").to("cpu")

# LLM: Flan-T5 small (text -> text)
llm_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to("cpu")

# TTS: Coqui YourTTS (text -> speech).
# NOTE(review): the original comment said "Facebook MMS", but the model id
# below is YourTTS — the comment is corrected here, the model id is unchanged.
tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts")  # lightweight multilingual
+
23
+ class LLMInput(BaseModel):
24
+ prompt: str
25
+
26
+ @app.post("/asr/")
27
+ async def transcribe(file: UploadFile = File(...)):
28
+ with tempfile.NamedTemporaryFile(delete=False) as tmp:
29
+ tmp.write(await file.read())
30
+ waveform, rate = torchaudio.load(tmp.name)
31
+ if rate != 16000:
32
+ resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
33
+ waveform = resampler(waveform)
34
+ inputs = asr_processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
35
+ with torch.no_grad():
36
+ predicted_ids = asr_model.generate(inputs["input_features"])
37
+ transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
38
+ return {"transcription": transcription}
39
+
40
+ @app.post("/llm/")
41
+ async def generate(input: LLMInput):
42
+ input_ids = llm_tokenizer.encode(input.prompt, return_tensors="pt")
43
+ output_ids = llm_model.generate(input_ids, max_length=100)
44
+ response = llm_tokenizer.decode(output_ids[0], skip_special_tokens=True)
45
+ return {"response": response}
46
+
47
+ @app.post("/tts/")
48
+ async def synthesize(input: LLMInput):
49
+ path = tts.tts_to_file(text=input.prompt, file_path="output.wav")
50
+ return {"audio_path": path}
51
+
52
+ if __name__ == "__main__":
53
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ transformers
4
+ torchaudio
5
+ TTS
6
+ torch
utils.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Utility functions can go here