alidw commited on
Commit
df6024a
·
verified ·
1 Parent(s): b442b40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -1,18 +1,27 @@
1
- import torch
2
  from fastapi import FastAPI
3
  from fastapi.responses import StreamingResponse
4
- from chatterbox import ChatterboxMultilingualTTS # حسب التسمية في الريبو
 
 
 
5
 
6
  app = FastAPI()
7
 
8
- model = ChatterboxMultilingualTTS.from_pretrained("ResembleAI/chatterbox")
 
 
9
 
10
  @app.post("/tts")
11
- async def tts_endpoint(text: str, language: str = "ar"):
12
- # يمكنك ضبط باراميترات اللغة/الصوت حسب الدوكمنتشِن
13
- audio = model.tts(text, language=language)
14
- # audio: numpy array or torch tensor
15
- # حوّله إلى wav/ogg stream
16
- def iterfile():
17
- yield audio.tobytes()
18
- return StreamingResponse(iterfile(), media_type="audio/wav")
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from fastapi.responses import StreamingResponse
3
+ from transformers import AutoTokenizer, VitsModel
4
+ import torch
5
+ import soundfile as sf
6
+ import io
7
 
8
  app = FastAPI()
9
 
10
+ # تحميل موديل MMS-TTS عربي
11
+ tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-ara")
12
+ model = VitsModel.from_pretrained("facebook/mms-tts-ara")
13
 
14
  @app.post("/tts")
15
+ async def tts_endpoint(text: str):
16
+ inputs = tokenizer(text, return_tensors="pt")
17
+
18
+ with torch.no_grad():
19
+ output = model(**inputs).waveform
20
+
21
+ audio = output.squeeze().cpu().numpy()
22
+
23
+ buffer = io.BytesIO()
24
+ sf.write(buffer, audio, 16000, format="WAV")
25
+ buffer.seek(0)
26
+
27
+ return StreamingResponse(buffer, media_type="audio/wav")